In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchtext.datasets import IWSLT
from torchtext.data import Field, BucketIterator
import spacy
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

#execute on gpu if available, else on cpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#multi-head attention block
class SelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, mask = False):
        super().__init__()
        #split the embeded word into multiple heads that run in parallel
        self.num_heads = num_heads
        self.head_size = embed_dim // num_heads
        #optionally use mask to hide the next part of the sentence (used for the decoder)
        self.mask = mask
        assert(num_heads * self.head_size == embed_dim, "embed dim and number of heads aren't compatible")
        #linear layers
        self.lin1 = nn.Linear(self.head_size, self.head_size, bias = False)
        self.lin2 = nn.Linear(embed_dim, embed_dim, bias = False)
        
    def forward(self, queries, keys, values):
        b, t2 = queries.size(0), queries.size(1)
        h = self.num_heads
        d = self.head_size
        t = values.size(1)
        
        #receive queries, keys and values and pass through linear layers
        queries = self.lin1(queries.reshape(b, t2, h , d))
        keys = self.lin1(keys.reshape(b, t, h , d))
        values = self.lin1(values.reshape(b, t, h , d))
        
        #scaled dot product attention
        queries = queries.transpose(1,2).reshape(b * h, t2, d)
        keys = keys.transpose(1,2).reshape(b * h, t, d)
        matmul1 = torch.bmm(queries, keys.transpose(1,2))
        scale = (matmul1 / (d ** (1/2)))
        
        if self.mask:
            indices = torch.triu_indices(t2, t, offset = 1)
            scale[:, indices[0], indices[1]] = float('-inf')
            
        soft = F.softmax(scale, dim=2)
        values = values.transpose(1,2).reshape(b *  h, t, d)
        matmul2 = torch.bmm(soft, values)
        
        #concat and linear layer
        out = self.lin2(matmul2.reshape(b, h, t2, d).transpose(1, 2).reshape(b, t2, h * d))
        
        return out

#transformer block
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, expansion_size, drop = 0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        
        #use multi-head attention block
        self.attentionblock = SelfAttention(self.embed_dim, self.num_heads)
        
        #feed forward block
        self.ff = nn.Sequential(nn.Linear(embed_dim, expansion_size),
                                nn.ReLU(),
                                nn.Linear(expansion_size, embed_dim))
        #normalization layer
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(drop)
        
    def forward(self, queries, keys, values):
        
        #attention
        mha = self.attentionblock(queries, keys, values)
        
        #add & norm 1 block
        addnorm1 = self.norm(queries + self.dropout(mha))
        
        #feed forward
        feedfwd = self.ff(addnorm1)
        
        #add & norm block
        addnorm2 = self.norm(addnorm1 + self.dropout(feedfwd))
        
        return addnorm2
    
#encoder block
class Encoder(nn.Module):
    def __init__(self, embed_dim, num_heads, expansion_size, num_layers, dict_size, max_len=100, drop = 0.1):
        super().__init__()
        self.num_layers = num_layers
        
        #word embedding
        self.embed = nn.Embedding(dict_size, embed_dim)
        
        #positional encoding
        self.embed_pos = nn.Embedding(max_len, embed_dim)
        self.dropout = nn.Dropout(drop)
        
        #tranformer block
        self.transformblock = TransformerBlock(embed_dim, num_heads, expansion_size, drop)
        
    def forward(self, inp):
        b, t = inp.size()
        
        #combine tranformer block with embedding and positional encoding
        input_embed = self.embed(inp)
        pos = torch.arange(t).repeat(b, 1).to(device)
        pos_embed = self.embed_pos(pos)
        input_pos = self.dropout(pos_embed + input_embed)
        out = input_pos
        
        #make 'n' layers of encoder block
        for i in range(self.num_layers):
            out = self.transformblock.forward(out, out, out)
            
        return out

#decoder block
class Decoder(nn.Module):
    
    def __init__(self, embed_dim, num_heads, expansion_size, num_layers, dict_size, max_len=100, drop = 0.1):
        super().__init__()
        self.num_layers = num_layers
        
        #word embedding
        self.embed = nn.Embedding(dict_size, embed_dim)
        
        #positional encoding
        self.embed_pos = nn.Embedding(max_len, embed_dim)
        self.dropout = nn.Dropout(drop)
        
        #attention block with mask
        self.attentionblock = SelfAttention(embed_dim, num_heads, mask = True)
        
        #norm layer
        self.norm = nn.LayerNorm(embed_dim)
        
        #transformer block
        self.transformblock = TransformerBlock(embed_dim, num_heads, expansion_size, drop)
        self.lin = nn.Linear(embed_dim, dict_size)
        
    def forward(self, inp, inp_enc):
        
        #same architecture as the encoder with masked attention and addnorm block added before transformer block
        b, t = inp.size()
        input_embed = self.embed(inp)
        pos = torch.arange(t).repeat(b, 1).to(device)
        pos_embed = self.embed_pos(pos)
        input_pos = self.dropout(pos_embed + input_embed)
        out = input_pos
        
        for i in range(self.num_layers):
            out = self.norm(out + self.dropout(self.attentionblock.forward(out, out, out)))
            out = self.transformblock.forward(out, inp_enc, inp_enc)
            
        out =  self.dropout(self.lin(out))
        
        return out

#complete transformer
class Transformer(nn.Module):
    
    def __init__(self, embed_dim, num_heads, expansion_size, num_layers, src_dict_size, trg_dict_size, max_len=100, drop = 0.1):
        super().__init__()
        self.enc = Encoder(embed_dim, num_heads, expansion_size, num_layers, src_dict_size, max_len, drop)
        self.dec = Decoder(embed_dim, num_heads, expansion_size, num_layers, trg_dict_size, max_len, drop)
        
    def forward(self, inp, out):
        
        #combine encoder with decoder
        return self.dec.forward(out, self.enc.forward(inp))

#function to translate sentence
def translate_sentence(sentence, model, french, english, max_len=100):
    
    #load spacy language module to build vocabulary
    spacy_fr = spacy.load("fr")
    
    #use spacy tokenization function
    if type(sentence) == str:
        tokens = [token.text.lower() for token in spacy_fr(sentence)]
    else:
        tokens = [token.lower() for token in sentence]
        
    #insert start and end of sentence tokens
    tokens.insert(0, french.init_token)
    tokens.append(french.eos_token)
    
    #convert tokens to indices
    text_to_indices = [french.vocab.stoi[tok] for tok in tokens]
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(0).to(device)
    
    #second param will be start of sentence index
    outputs = [english.vocab.stoi["<sos>"]]
    
    #for each loop the model will predict the next word in the translated sentence
    for i in range(max_len):
        
        #convert target to tensor
        trg_tensor = torch.LongTensor(outputs).unsqueeze(0).to(device)
        
        #predict next word
        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)
        best_guess = output.argmax(2)[:, -1].item()
        outputs.append(best_guess)
        
        
        if best_guess == english.vocab.stoi["<eos>"]:
            break
    #convert indices into words
    translated_sentence = [english.vocab.itos[idx] for idx in outputs]
    
    #return sentence without "start" and "end" tokens
    return translated_sentence[1:-1]


#load spacy language modules to build vocabularies
spacy_fr = spacy.load("fr")
spacy_en = spacy.load("en")

#use spacy tokenization functions
def tokenize_fr(text):
    return [tok.text for tok in spacy_fr.tokenizer(text)]

def tokenize_en(text):
    return [tok.text for tok in spacy_en.tokenizer(text)]


french = Field(tokenize=tokenize_fr, lower=True, init_token="<sos>", eos_token="<eos>")
english = Field(tokenize=tokenize_en, lower=True, init_token="<sos>", eos_token="<eos>")

#use the ISWLT dataset of TED talks
train_data, valid_data, test_data = IWSLT.splits(exts=(".fr", ".en"), fields=(french, english))

#build vocabularies
french.build_vocab(train_data, max_size=10000, min_freq=2)
english.build_vocab(train_data, max_size=10000, min_freq=2)

#model hyperparameters
embed_dim = 512
num_heads = 8
expansion_size = 2048
num_layers = 3
src_dict_size = len(french.vocab)
trg_dict_size = len(english.vocab)
max_len = 100
dropout = 0.1

#training hyperparameters
batch_size = 32
num_epochs = 10
learn_rate = 0.0003

load_model = False
save_model = True

#use Tensorboard summary writer to plot loss and accuracy during training
writer = SummaryWriter("runs/loss-plot")
step = 0

#use the BucketIterator module to split data into batches sorted by sentence length
train_it, valid_it, test_it = BucketIterator.splits((train_data, valid_data, test_data), batch_size=batch_size, sort_within_batch=True, sort_key= lambda x: len(x.src), device=device)

#make instance of tranformer model
model = Transformer(embed_dim, num_heads, expansion_size, num_layers, src_dict_size, trg_dict_size, max_len=max_len, drop=dropout).to(device)

#make optimizer
optimizer = optim.Adam(model.parameters(), lr=learn_rate)

#get padding index and ignore it when calculating loss
pad_idx = english.vocab.stoi["<pad>"]
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

#can load saved model and optimizer
if load_model:
    print("=> loading checkpoint")
    checkpoint = torch.load("my_checkpoint.pth.tar")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    
#test sentence
sentence = "un chat noir et blanc est mignon."

#loop over epochs
for epoch in range(num_epochs):
    
    losses = []
    print(f"epoch {epoch + 1}/{num_epochs}")
    
    #can save model and optimizer in ".tar" file
    if save_model:
        print("=> saving checkpoint")
        checkpoint = {"state_dict":model.state_dict(), "optimizer":optimizer.state_dict()}
        torch.save(checkpoint, "my_checkpoint.pth.tar")
        
    #test model while training
    model.eval()
    translated_sentence = translate_sentence(sentence, model, french, english)
    print(f"translated exemple sentence: \n {translated_sentence}")
    model.train()
    
    #loop over batches (use tqdm to make a progress bar)
    for batch in tqdm(train_it):
        
        #ignore batches with sentence length over 100
        if (batch.src.size(0) > 100) or (batch.trg.size(0) > 100):
          continue
        
        #convert the input and the target into the adquate shape
        inp = batch.src.transpose(0, 1).to(device)
        trg = batch.trg.transpose(0, 1).to(device)
        
        #pass them through the model
        out = model(inp, trg[:, :-1])
        out = out.reshape(-1, out.shape[2])
        trg = trg[:, 1:].reshape(-1)
        
        #zero the gradients
        optimizer.zero_grad()
        
        #calculate the loss backward it
        loss = criterion(out, trg)
        losses.append(loss.item())
        loss.backward()
        
        #clip the gradients
        torch.nn.utils.clip_grad_norm(model.parameters(), max_norm=1)
        
        optimizer.step()
        
        #add summary writer step
        writer.add_scalar("Training Loss", loss, global_step=step)
        step += 1
        
    #calculate the average loss over one epoch
    mean_loss = sum(losses)/len(losses)
    print("loss :",mean_loss)
    

  assert(num_heads * self.head_size == embed_dim, "embed dim and number of heads aren't compatible")


epoch 1/10
=> saving checkpoint


  0%|          | 0/6888 [00:00<?, ?it/s]

translated exemple sentence: 
 ['tanzania', 'dot', 'refugee', 'backed', 'twisted', 'returning', 'danny', 'warn', 'neat', 'counseling', 'biases', 'non', 'phase', 'jazz', 'driving', 'bipolar', 'beloved', 'potentially', 'disability', 'pond', 'derive', 'speakers', 'underlies', 'bipolar', 'tip', 'connecticut', 'ideally', 'sudan', 'stores', 'period', 'knit', 'nba', 'daylight', 'optimal', 'currents', 'p', 'injured', 'bench', 'download', 'underlies', 'ashamed', 'kentucky', 'opinion', 'existed', 'liked', 'l.a.', 'bj', 'twisted', 'returning', 'unacceptable', 'underlies', 'introduced', 'solitary', 'movement', 'hurricane', 'magazines', 'underlies', 'puzzle', 'infections', 'digging', 'engineering', 'lapse', 'reminds', 'conclusions', 'bus', 'biases', 'receives', 'integration', 'shooting', 'may', 'legacy', 'solitary', 'non', 'april', 'exceptional', 'participants', 'avatar', 'graders', '20', 'dancing', 'cholera', 'underlies', 'a.i.', 'prayer', 'constitution', 'critics', 'operates', 'bleed', 'pit', 'ic

100%|██████████| 6888/6888 [04:51<00:00, 23.67it/s]


70
epoch 2/10
=> saving checkpoint


  0%|          | 0/6888 [00:00<?, ?it/s]

translated exemple sentence: 
 ['a', 'black', 'cat', 'and', 'white', 'is', 'beautiful', '.', '<eos>']


100%|██████████| 6888/6888 [04:55<00:00, 23.34it/s]


69
epoch 3/10
=> saving checkpoint


  0%|          | 0/6888 [00:00<?, ?it/s]

translated exemple sentence: 
 ['black', 'cat', 'and', 'white', 'and', 'white', 'is', 'cute', '.', '<eos>']


100%|██████████| 6888/6888 [04:55<00:00, 23.28it/s]


69
epoch 4/10
=> saving checkpoint


  0%|          | 0/6888 [00:00<?, ?it/s]

translated exemple sentence: 
 ['black', 'cat', 'and', 'white', 'is', 'white', '.', '<eos>']


100%|██████████| 6888/6888 [04:52<00:00, 23.55it/s]


69
epoch 5/10
=> saving checkpoint


  0%|          | 0/6888 [00:00<?, ?it/s]

translated exemple sentence: 
 ['black', 'cat', 'and', 'white', 'is', 'cute', '.', '<eos>']


100%|██████████| 6888/6888 [04:59<00:00, 23.04it/s]


69
epoch 6/10
=> saving checkpoint


  0%|          | 0/6888 [00:00<?, ?it/s]

translated exemple sentence: 
 ['a', 'cat', 'and', 'white', 'cat', 'is', 'cute', '.', '<eos>']


100%|██████████| 6888/6888 [04:57<00:00, 23.14it/s]


69
epoch 7/10
=> saving checkpoint


  0%|          | 0/6888 [00:00<?, ?it/s]

translated exemple sentence: 
 ['a', 'black', 'cat', 'and', 'white', 'is', 'cute', '.', '<eos>']


100%|██████████| 6888/6888 [04:55<00:00, 23.31it/s]


69
epoch 8/10
=> saving checkpoint


  0%|          | 0/6888 [00:00<?, ?it/s]

translated exemple sentence: 
 ['a', 'black', 'cat', 'and', 'white', 'is', 'cute', '.', '<eos>']


100%|██████████| 6888/6888 [04:54<00:00, 23.36it/s]


69
epoch 9/10
=> saving checkpoint


  0%|          | 0/6888 [00:00<?, ?it/s]

translated exemple sentence: 
 ['black', 'cat', 'and', 'white', 'is', 'cute', '.', '<eos>']


100%|██████████| 6888/6888 [04:54<00:00, 23.41it/s]


69
epoch 10/10
=> saving checkpoint


  0%|          | 0/6888 [00:00<?, ?it/s]

translated exemple sentence: 
 ['black', 'cat', 'is', 'cute', '.', '<eos>']


100%|██████████| 6888/6888 [04:54<00:00, 23.36it/s]

69





In [2]:
#install spacy laguage modules
!python -m spacy download en
!python -m spacy download fr

[38;5;2m✔ Download and installation successful[0m
You can now load the model via spacy.load('en_core_web_sm')
[38;5;2m✔ Linking successful[0m
/usr/local/lib/python3.6/dist-packages/en_core_web_sm -->
/usr/local/lib/python3.6/dist-packages/spacy/data/en
You can now load the model via spacy.load('en')
Collecting fr_core_news_sm==2.2.5
[?25l  Downloading https://github.com/explosion/spacy-models/releases/download/fr_core_news_sm-2.2.5/fr_core_news_sm-2.2.5.tar.gz (14.7MB)
[K     |████████████████████████████████| 14.7MB 3.8MB/s 
Building wheels for collected packages: fr-core-news-sm
  Building wheel for fr-core-news-sm (setup.py) ... [?25l[?25hdone
  Created wheel for fr-core-news-sm: filename=fr_core_news_sm-2.2.5-cp36-none-any.whl size=14727027 sha256=1b26a348dc5a015f6b8600ac4539394b2dfb2699b864a19f08984b2c39860907
  Stored in directory: /tmp/pip-ephem-wheel-cache-x5xfsulg/wheels/46/1b/e6/29b020e3f9420a24c3f463343afe5136aaaf955dbc9e46dfc5
Successfully built fr-core-news-sm
Inst