<a href="https://colab.research.google.com/github/MathBorgess/into_pytorch/blob/master/transformer_training_loop.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [59]:
%pip install spacy
%pip install torchtext
!python -m spacy download en
!python -m spacy download de

[38;5;3m⚠ As of spaCy v3.0, shortcuts like 'en' are deprecated. Please use the
full pipeline package name 'en_core_web_sm' instead.[0m
Collecting en-core-web-sm==3.5.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.5.0/en_core_web_sm-3.5.0-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m22.1 MB/s[0m eta [36m0:00:00[0m
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
[38;5;3m⚠ As of spaCy v3.0, shortcuts like 'de' are deprecated. Please use the
full pipeline package name 'de_core_news_sm' instead.[0m
Collecting de-core-news-sm==3.5.0
  Downloading https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.5.0/de_core_news_sm-3.5.0-py3-none-any.whl (14.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.6/14.6 MB[0m [31m34.0 MB/s[0m eta [36m0:00:00[0m
[38;5;2m✔ Dow

In [60]:
import torch
from torch import nn
from torch import optim
import spacy
import torchtext
from torchtext.datasets import Multi30k
from collections import Counter
#https://pytorch.org/tutorials/beginner/torchtext_translation_tutorial.html

In [61]:
url_base = 'https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/'
train_urls = ('train.de.gz', 'train.en.gz')
val_urls = ('val.de.gz', 'val.en.gz')
test_urls = ('test_2016_flickr.de.gz', 'test_2016_flickr.en.gz')

train_filepaths = [torchtext.utils.extract_archive(
    torchtext.utils.download_from_url(
        url_base+train_url
    )
)[0] for train_url in train_urls]

val_filepaths = [torchtext.utils.extract_archive(
    torchtext.utils.download_from_url(
        url_base+val_url
    )
)[0] for val_url in val_urls]

test_filepaths = [torchtext.utils.extract_archive(
    torchtext.utils.download_from_url(
        url_base+test_url
    )
)[0] for test_url in test_urls]

In [62]:
en_tokenizer = torchtext.data.get_tokenizer('spacy', language='en')
de_tokenizer = torchtext.data.get_tokenizer('spacy', language='de')



In [63]:
import io

def build_vocab(filepaths, tokenizer):
    counter = Counter()
    for filepath in filepaths:
      with io.open(filepath, encoding="utf-8") as file:
          for string_ in file:
              counter.update(tokenizer(string_))
    return torchtext.vocab.vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])

de_vocab, en_vocab = build_vocab([train_filepaths[0], val_filepaths[0], test_filepaths[0]], de_tokenizer), build_vocab([train_filepaths[1], val_filepaths[1], test_filepaths[1]], en_tokenizer)

In [64]:
def data_process(filepaths):
  raw_de_iter = iter(io.open(filepaths[0], encoding="utf8"))
  raw_en_iter = iter(io.open(filepaths[1], encoding="utf8"))
  data = []
  for (raw_de, raw_en) in zip(raw_de_iter, raw_en_iter):
    de_tensor_ = torch.tensor([de_vocab[token] for token in de_tokenizer(raw_de)],
                            dtype=torch.long)
    en_tensor_ = torch.tensor([en_vocab[token] for token in en_tokenizer(raw_en)],
                            dtype=torch.long)
    data.append((de_tensor_, en_tensor_))
  return data

train_data = data_process(train_filepaths)
val_data = data_process(val_filepaths)
test_data = data_process(test_filepaths)

In [69]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

BATCH_SIZE = 128
PAD_IDX = de_vocab.vocab.lookup_indices(['<pad>'])[0]
BOS_IDX = de_vocab.vocab.lookup_indices(['<bos>'])[0]
EOS_IDX = de_vocab.vocab.lookup_indices(['<eos>'])[0]

def generate_batch(data_batch):
  de_batch, en_batch = [], []
  for (de_item, en_item) in data_batch:
    de_batch.append(torch.cat([torch.tensor([BOS_IDX]), de_item, torch.tensor([EOS_IDX])], dim=0))
    en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
  de_batch = torch.nn.utils.rnn.pad_sequence(de_batch, padding_value=PAD_IDX)
  en_batch = torch.nn.utils.rnn.pad_sequence(en_batch, padding_value=PAD_IDX)
  return de_batch, en_batch

train_iter = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
valid_iter = torch.utils.data.DataLoader(val_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
test_iter = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE,
                       shuffle=True, collate_fn=generate_batch)

In [66]:
from torch import nn

class Translator(nn.Module):
    def __init__(self,
                 layers_units,
                 dim_model,
                 heads,
                 src_vocab_size,
                 tar_vocab_size,
                 src_pad_idx,
                 forward_expansion,
                 device,
                 max_length,
                 dropout):
        super(Translator, self).__init__()
        self.src_embedding = nn.Embedding(src_vocab_size,dim_model)
        self.src_pos_embedding = nn.Embedding(max_length, dim_model)

        self.tar_embedding = nn.Embedding(tar_vocab_size,dim_model)
        self.tar_pos_embedding = nn.Embedding(max_length, dim_model)

        self.device = device

        self.transformer = nn.Transformer(dim_model,
                                        heads,
                                        layers_units,
                                        layers_units,
                                        forward_expansion,
                                        dropout)
        self.fc_out = nn.Linear(dim_model, tar_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.src_pad_idx = src_pad_idx

    def create_src_mask(self, src):
        # the torch src_mask for transformer needs to be transposed
        src_mask = src.transpose(0, 1) == self.src_pad_idx
        return src_mask

    def forward(self, src, tar):
        src_seq_length, N = src.shape
        tar_seq_length, N = tar.shape

        #the arange will make a [0, 1..., length-1] tensor, unsqueeze(1) makes it 2D, and he expand will replicate it by N layers
        src_pos = (
            torch.arange(0, src_seq_length).unsqueeze(1)
            .expand(src_seq_length, N)
            .to(self.device)
        )
        tar_pos = (
            torch.arange(0, tar_seq_length).unsqueeze(1)
            .expand(tar_seq_length, N)
            .to(self.device)
        )

        embed_src = self.dropout(
            (self.src_embedding(src) + self.src_pos_embedding(src_pos))
        )
        embed_tar = self.dropout(
            (self.tar_embedding(tar) + self.tar_pos_embedding(tar_pos))
        )

        src_padding_mask = self.create_src_mask(src)
        tar_mask = self.transformer.generate_square_subsequent_mask(tar_seq_length).to(self.device)

        return self.transformer(
            embed_src,
            embed_tar,
            src_key_padding_mask= src_padding_mask,
            tgt_mask=tar_mask
            )

In [67]:
# Training hyperparameters
epochs = 5
lr =  1e-4
dim_model = 512
heads = 8
layers_units = 3
dropout = 0.1
max_length = 100
forward_expansion = 4

In [95]:
model = Translator(
    layers_units,
    dim_model,
    heads,
    len(de_vocab),
    len(en_vocab),
    PAD_IDX,
    forward_expansion,
    device,
    max_length,
    dropout
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr)
criterion = nn.CrossEntropyLoss(ignore_index=en_vocab.vocab.lookup_indices(['<pad>'])[0]).to(device)

In [118]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
step = 0

for epoch in range(epochs):
    print(f'Epoch: {epoch} / {epochs}')

    model.train()

    for _, (src, tar) in enumerate(train_iter):
        src, tar = src.to(device), tar.to(device)

        output = model(src, tar)

        output = output[1:].view(-1, output.shape[-1])
        tar = tar[1:].view(-1)

        optimizer.zero_grad()
        loss = criterion(output, tar)
        loss.backward()
        torch.nn.utils.clip_grad.clip_grad_norm(model.parameters(), max_norm= 1)
        optimizer.step()

        writer.add_scalar("training loss", loss, global_step=step)
        step += 1

Epoch: 0 / 5


IndexError: ignored