### ResNet-18: CIFAR-100 image classification

In [31]:
import torch
from torchvision.models import resnet18
from torchvision.datasets import CIFAR100
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, ToTensor

from hook_management import HookManager

In [2]:
# CIFAR-100 training split read from ../data/. No download is requested, so the
# dataset files must already be present there. Each image is resized to
# 224x224 and then converted to a tensor.
cifar100 = CIFAR100(
    root='../data/',
    train=True,
    transform=Compose([
        Resize((224, 224)),
        ToTensor(),
    ]),
)
# Loader left at DataLoader defaults (one sample per batch, no shuffling).
cifar_ldr = DataLoader(dataset=cifar100)

In [41]:
# Randomly initialised ResNet-18 (no pretrained weights are requested).
# The classifier head is sized to the dataset's label set. torchvision's default
# of 1000 outputs matches ImageNet, not the 100 CIFAR-100 classes.
resnet = resnet18(num_classes=len(cifar100.classes))
resnet.cuda()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

### XLM: Multi30k English-to-German machine translation

In [85]:
import torch
from transformers import XLMTokenizer, XLMWithLMHeadModel
import spacy
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator
import torch.nn as nn

In [72]:
# spaCy pipelines; in this notebook they are only used for their tokenizers.
# NOTE(review): the 'en' / 'de' shortcut names only resolve on spaCy 2.x with
# linked models. spaCy 3 needs full package names. Confirm the pinned version.
english = spacy.load('en')
german = spacy.load('de')


def tokenize_en(text):
    """Split English *text* into a list of token strings via spaCy's tokenizer."""
    doc = english.tokenizer(text)
    return [token.text for token in doc]


def tokenize_de(text):
    """Split German *text* into a list of token strings via spaCy's tokenizer."""
    doc = german.tokenizer(text)
    return [token.text for token in doc]

# Lower-cased, tokenized text fields: English is the source side, German the
# target side (matching the exts/fields order passed to Multi30k below).
# Note that batch_first is left at its default (False), so batches from these
# fields are presumably (seq_len, batch). Transpose before feeding a
# batch-first model.
en_text = Field(tokenize=tokenize_en, lower=True, sequential=True, use_vocab=True)
de_text = Field(tokenize=tokenize_de, lower=True, sequential=True, use_vocab=True)

train, val, test = Multi30k.splits(
    exts=('.en', '.de'),
    fields=(en_text, de_text),
    root='../data',
)

# Both vocabularies are built from the training split only: capped at 30k
# entries, keeping words seen at least 3 times.
de_text.build_vocab(train, max_size=30000, min_freq=3)
en_text.build_vocab(train, max_size=30000, min_freq=3)
vocab_en, vocab_de = en_text.vocab, de_text.vocab
# Padding index of the German (target) vocabulary.
pad_idx = vocab_de.stoi['<pad>']

# BucketIterators over the three splits, 5 examples per batch.
train_ldr, val_ldr, test_ldr = BucketIterator.splits(
    (train, val, test),
    batch_size=5,
)

In [66]:
# Load the pretrained 'xlm-mlm-ende-1024' checkpoint. Then replace its token
# embedding and output projection with freshly initialised layers sized to the
# torchtext vocabularies: English ids in, German logits out. Every other
# pretrained parameter (position/language embeddings, attention, FFN, ...) is kept.
xlm = XLMWithLMHeadModel.from_pretrained('xlm-mlm-ende-1024')
# The input embedding indexes the English vocabulary, so its padding row must
# be the English <pad> index. pad_idx belongs to the German (target) vocabulary,
# and the two are not guaranteed to coincide.
xlm.transformer.embeddings = nn.Embedding(
    len(vocab_en), xlm.config.emb_dim, padding_idx=vocab_en.stoi['<pad>']
)
xlm.pred_layer.proj = nn.Linear(xlm.config.emb_dim, len(vocab_de), bias=True)
xlm.cuda()

XLMWithLMHeadModel(
  (transformer): XLMModel(
    (position_embeddings): Embedding(512, 1024)
    (lang_embeddings): Embedding(2, 1024)
    (embeddings): Embedding(4554, 1024, padding_idx=1)
    (layer_norm_emb): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
    (attentions): ModuleList(
      (0): MultiHeadAttention(
        (q_lin): Linear(in_features=1024, out_features=1024, bias=True)
        (k_lin): Linear(in_features=1024, out_features=1024, bias=True)
        (v_lin): Linear(in_features=1024, out_features=1024, bias=True)
        (out_lin): Linear(in_features=1024, out_features=1024, bias=True)
      )
      (1): MultiHeadAttention(
        (q_lin): Linear(in_features=1024, out_features=1024, bias=True)
        (k_lin): Linear(in_features=1024, out_features=1024, bias=True)
        (v_lin): Linear(in_features=1024, out_features=1024, bias=True)
        (out_lin): Linear(in_features=1024, out_features=1024, bias=True)
      )
      (2): MultiHeadAttention(
        (q_l

In [151]:
# One smoke-test training step on a single batch.
xent = nn.CrossEntropyLoss()

batch = next(iter(train_ldr))
# The torchtext Fields are built without batch_first=True, so batch.src and
# batch.trg are (seq_len, batch). XLM takes (batch, seq_len) input ids, so
# transpose both. Otherwise attention would mix tokens across sentences
# instead of running along each sentence.
src = batch.src.to(0).t().contiguous()
trg = batch.trg.to(0).t().contiguous()

# Mark real (non-pad) source tokens explicitly.
# NOTE(review): without attention_mask, XLM appears to infer padding from its
# own config.pad_index, which is unrelated to the torchtext English vocabulary.
# Confirm against the installed transformers version.
src_mask = (src != vocab_en.stoi['<pad>']).long()
# [0] selects the LM logits whether the model returns a tuple or a ModelOutput.
out = xlm(src, attention_mask=src_mask)[0]  # (batch, src_len, len(vocab_de))

# Source and target lengths differ. Compare position-by-position over the
# shared prefix of the time axis (dim 1 in batch-first layout).
min_idx = min(out.shape[1], trg.shape[1])
out, trg = out[:, :min_idx], trg[:, :min_idx]

# Score only positions whose target token is not padding (pad_idx is the
# German <pad> index). The comparison already yields a bool mask.
mask = trg != pad_idx
loss = xent(out[mask], trg[mask])
loss.backward()