In [1]:
import torch
import torch.nn as nn
import math

import torch.optim as optim
from torch.optim.lr_scheduler import _LRScheduler
import matplotlib.pyplot as plt

# Print floating-point tensors with two digits of precision in cell outputs.
torch.set_printoptions(precision=2)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding table.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / d_model)``.

    Args:
        d_model: Number of features per position (even or odd).
        max_seq_length: Number of positions pre-computed in the table.
    """

    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = position * div_term  # (max_seq_length, ceil(d_model / 2))

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so the angles are trimmed to fit; for an even d_model the
        # slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Stored as a buffer: moves with the module but is not a parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions.

        Only the size of dimension 1 of ``x`` is read; its values are ignored.
        The result has shape ``(1, min(x.size(1), max_seq_length), d_model)``.
        """
        return self.pe[:, :x.size(1)]

In [2]:
class multihead(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        dmodel: Feature size of the inputs and of the output.
        h: Number of attention heads; must divide ``dmodel`` exactly.

    Raises:
        AssertionError: If ``dmodel`` is not divisible by ``h``.
    """

    def __init__(self, dmodel, h):
        super().__init__()

        # Integer check. A float check such as (dmodel / h) * h == dmodel can
        # round back to dmodel for non-divisible sizes (e.g. 10 and 3) and let
        # an invalid configuration through until view() fails in forward().
        assert (
            dmodel % h == 0
        ), "Embedding size needs to be divisible by heads"

        self.dk = dmodel // h  # per-head dimension of queries and keys
        self.dv = self.dk      # per-head dimension of values

        self.h = h

        self.Wq = nn.Linear(dmodel, dmodel, bias=False)
        self.Wk = nn.Linear(dmodel, dmodel, bias=False)
        self.Wv = nn.Linear(dmodel, dmodel, bias=False)
        self.Wo = nn.Linear(dmodel, dmodel, bias=False)

    def attention(self, q, k, v, mask=None):
        """Compute attention over tensors already split into heads.

        ``q`` is ``(b, h, t_q, dk)``, ``k`` is ``(b, h, t, dk)`` and ``v`` is
        ``(b, h, t, dv)``. Positions where ``mask == 0`` are excluded; ``mask``
        must broadcast against the ``(b, h, t_q, t)`` score tensor.
        """
        # (b, h, t_q, dk) @ (b, h, dk, t) -> (b, h, t_q, t)
        product = torch.matmul(q, k.transpose(-2, -1))
        if mask is not None:
            # A very large negative score becomes ~0 after the softmax.
            product = product.masked_fill(mask == 0, float("-1e20"))

        product = product / math.sqrt(self.dk)
        score = torch.softmax(product, dim=-1)
        return torch.matmul(score, v)

    def forward(self, q, k, v, mask=None):
        """Project the inputs, attend per head, and merge the heads.

        ``q`` is ``(b, t_q, dmodel)``; ``k`` and ``v`` are ``(b, t, dmodel)``.
        Returns a tensor of shape ``(b, t_q, dmodel)``.
        """
        h, dk, dv = self.h, self.dk, self.dv

        b, t, dmod = k.size()
        t_q = q.size(1)  # query length may differ from key/value length

        # Linear projection, then split the feature axis into h heads.
        q = self.Wq(q).view(b, t_q, h, dk).transpose(1, 2)
        k = self.Wk(k).view(b, t, h, dk).transpose(1, 2)
        v = self.Wv(v).view(b, t, h, dv).transpose(1, 2)

        out = self.attention(q, k, v, mask)
        # (b, h, t_q, dv) -> (b, t_q, h * dv): concatenate the heads again.
        out = out.transpose(1, 2).contiguous().view(b, t_q, dmod)
        return self.Wo(out)

In [3]:
class feedForward(nn.Module):
    """Two-layer feed-forward network: Linear(dmodel -> dff), ReLU, Linear(dff -> dmodel)."""

    def __init__(self, dmodel, dff):
        super(feedForward, self).__init__()
        # Expansion layer first, projection layer second.
        self.linear1 = nn.Linear(in_features=dmodel, out_features=dff)
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(in_features=dff, out_features=dmodel)

    def forward(self, x):
        """Apply expand -> ReLU -> project to ``x`` and return the result."""
        return self.linear2(self.activation(self.linear1(x)))

In [4]:
class TransformerBlock(nn.Module):
    """Attention sub-layer followed by a feed-forward sub-layer.

    Each sub-layer output goes through dropout, is added to the sub-layer
    input (residual connection) and is then layer-normalised.
    """

    def __init__(self, dmodel, h, dff, pdrop):
        super().__init__()
        self.multihead = multihead(dmodel, h)
        self.FF = feedForward(dmodel, dff)
        self.layerNorm1, self.layerNorm2 = nn.LayerNorm(dmodel), nn.LayerNorm(dmodel)
        self.dropout1, self.dropout2 = nn.Dropout(pdrop), nn.Dropout(pdrop)

    def forward(self, q, k, v, mask):
        """Attend from ``q`` to ``k``/``v`` under ``mask``, then apply the feed-forward net."""
        # The residual of the attention sub-layer is taken around the queries.
        attended = self.dropout1(self.multihead(q, k, v, mask))
        hidden = self.layerNorm1(q + attended)

        transformed = self.dropout2(self.FF(hidden))
        return self.layerNorm2(hidden + transformed)

In [5]:
class Decoder(nn.Module):
    """Decoder layer: self-attention on the target, then a TransformerBlock over the encoder output."""

    def __init__(self, dmodel, h, dff, pdrop):
        super().__init__()
        self.multihead = multihead(dmodel, h)
        self.norm = nn.LayerNorm(dmodel)
        self.dropout = nn.Dropout(pdrop)
        self.block = TransformerBlock(dmodel, h, dff, pdrop)

    def forward(self, x, enc_output, src_mask, trg_mask=None):
        """Run one decoder layer.

        ``x`` attends to itself under ``trg_mask``; the normalised result is
        then used as the query against ``enc_output`` (keys and values) under
        ``src_mask``.
        """
        self_attended = self.dropout(self.multihead(x, x, x, trg_mask))
        queries = self.norm(x + self_attended)
        return self.block(queries, enc_output, enc_output, src_mask)

In [6]:
def toTokens(logits):
    """Return, along the last dimension, the index of the largest score."""
    return logits.argmax(dim=-1)

In [7]:
class Scheduler(_LRScheduler):
    """Warm-up then inverse-square-root learning-rate schedule.

    ``lr = dmodel ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)``

    The rate computed here replaces whatever ``lr`` the optimizer was built
    with, and the same rate is applied to every parameter group.

    Args:
        optim: Optimizer whose parameter groups are updated.
        dmodel: Model dimension used to scale the rate.
        warmup_steps: Number of steps of linear warm-up; must be positive.
    """

    def __init__(self, optim, dmodel, warmup_steps):
        # These must be set before super().__init__(), which already performs
        # an initial step() and therefore calls get_lr().
        self.dmodel = dmodel
        self.warmup_steps = warmup_steps
        self.num_param_groups = len(optim.param_groups)
        super().__init__(optim)

    def get_lr(self):
        """Return the current learning rate, one entry per parameter group."""
        # Clamp to 1 so a zero step count can never reach 0 ** -0.5.
        step_num = max(self._step_count, 1)

        lrate = self.dmodel ** (-0.5) * min(
            step_num ** (-0.5), step_num * self.warmup_steps ** (-1.5)
        )
        # A single-element list would leave every group after the first
        # without a scheduled rate.
        return [lrate] * len(self.optimizer.param_groups)

In [8]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer for sequence-to-sequence token prediction.

    Args:
        N: Number of encoder layers and of decoder layers.
        dmodel: Embedding / hidden size.
        h: Number of attention heads.
        dff: Hidden size of the feed-forward sub-layers.
        pdrop: Dropout probability.
        src_vocab_size: Size of the source vocabulary.
        trg_vocab_size: Size of the target vocabulary.
        max_seq_length: Number of positions in the positional-encoding table;
            ``predict`` generates ``max_seq_length - 1`` tokens.
        BOS, EOS, PAD: Integer ids of the special tokens.
        device: Device on which the special-token tensors are created.

    Note:
        Positions equal to ``PAD`` are masked in attention and ignored by the
        training loss. If ``EOS`` shares the same id as ``PAD``, EOS targets
        are therefore ignored as well.
    """

    def __init__(self, N, dmodel, h, dff, pdrop, src_vocab_size, trg_vocab_size, max_seq_length, BOS, EOS, PAD, device):
        super().__init__()

        self.dmodel = dmodel
        self.max_seq_length = max_seq_length

        # Non-persistent buffers: they follow the module through .to()/.cuda()
        # but are not stored in state_dict, so checkpoint keys are unchanged.
        self.register_buffer("BOS", torch.tensor([BOS], dtype=torch.long, device=device), persistent=False)
        self.register_buffer("EOS", torch.tensor([EOS], dtype=torch.long, device=device), persistent=False)
        self.register_buffer("PAD", torch.tensor([PAD], dtype=torch.long, device=device), persistent=False)

        self.src_embedding = nn.Embedding(src_vocab_size, dmodel)
        self.trg_embedding = nn.Embedding(trg_vocab_size, dmodel)
        self.positional_encoding = PositionalEncoding(dmodel, max_seq_length)

        self.encoder = nn.ModuleList([TransformerBlock(dmodel, h, dff, pdrop) for _ in range(N)])
        self.decoder = nn.ModuleList([Decoder(dmodel, h, dff, pdrop) for _ in range(N)])

        self.linear = nn.Linear(dmodel, trg_vocab_size)
        self.dropout1 = nn.Dropout(pdrop)
        self.dropout2 = nn.Dropout(pdrop)

    @staticmethod
    def _causal_mask(length, device):
        """Return a (1, length, length) lower-triangular boolean mask on ``device``."""
        return torch.tril(torch.ones(1, length, length, device=device)).bool()

    def create_mask(self, src, trg=None):
        """Build the attention masks.

        Args:
            src: Source token ids, shape ``(b, t_src)``.
            trg: Optional target token ids, shape ``(b, t_trg)``.

        Returns:
            ``src_mask`` of shape ``(b, 1, 1, t_src)`` (True where the token is
            not PAD) when ``trg`` is None; otherwise the tuple
            ``(src_mask, trg_mask)`` where ``trg_mask`` has shape
            ``(b, 1, t_trg, t_trg)`` and combines the non-PAD query positions
            with the causal (lower-triangular) constraint.
        """
        src_mask = (src != self.PAD.to(src.device)).unsqueeze(1).unsqueeze(2)
        if trg is None:
            return src_mask

        trg_mask = (trg != self.PAD.to(trg.device)).unsqueeze(1).unsqueeze(3)
        # The causal mask is created on the target's device; a CPU mask cannot
        # be combined with a CUDA padding mask.
        trg_mask = trg_mask & self._causal_mask(trg.size(1), trg.device)
        return src_mask, trg_mask

    def _encode(self, src, src_mask):
        """Embed ``src`` and run it through the encoder stack."""
        enc_out = self.dropout1(self.src_embedding(src) + self.positional_encoding(src))
        for enc_layer in self.encoder:
            enc_out = enc_layer(enc_out, enc_out, enc_out, src_mask)
        return enc_out

    def _decode(self, trg, enc_out, src_mask, trg_mask):
        """Embed ``trg``, run the decoder stack and return unnormalised logits."""
        dec_out = self.dropout2(self.trg_embedding(trg) + self.positional_encoding(trg))
        for dec_layer in self.decoder:
            dec_out = dec_layer(dec_out, enc_out, src_mask, trg_mask)
        return self.linear(dec_out)

    def _logits(self, src, trg):
        """Return unnormalised scores of shape ``(b, t_trg, trg_vocab_size)``."""
        assert (
            src.size(0) == trg.size(0)
        ), "batch src and batch trg are not equals"

        src_mask, trg_mask = self.create_mask(src, trg)
        enc_out = self._encode(src, src_mask)
        return self._decode(trg, enc_out, src_mask, trg_mask)

    def forward(self, src, trg):
        """Return next-token probabilities of shape ``(b, t_trg, trg_vocab_size)``.

        Raises:
            AssertionError: If ``src`` and ``trg`` have different batch sizes.
        """
        return torch.softmax(self._logits(src, trg), dim=2)

    def trainer(self, src, trg, epoch=1, printersrc='My father'):
        """Train on ``(src, trg)`` one sequence at a time and plot the loss curve.

        Args:
            src: Source token ids, shape ``(n, t_src)``.
            trg: Target token ids, shape ``(n, t_trg)``; ``trg[:, :-1]`` is the
                decoder input and ``trg[:, 1:]`` the prediction target.
            epoch: Number of passes over the data.
            printersrc: Unused; kept so existing calls keep working.

        Prints the summed loss after every epoch and shows a matplotlib plot
        of those sums. Leaves the module in training mode.
        """
        # CrossEntropyLoss applies log-softmax itself, so it is fed raw logits
        # rather than the probabilities returned by forward(); ignore_index
        # takes a plain int.
        criterion = nn.CrossEntropyLoss(ignore_index=int(self.PAD.item()), label_smoothing=0.1)
        optimizer = optim.Adam(self.parameters(), betas=(0.9, 0.98), eps=1.0e-9)
        scheduler = Scheduler(optimizer, self.dmodel, warmup_steps=4000)

        errors_list = []

        self.train()
        for epoch_idx in range(epoch):
            error = 0
            for batch in range(src.size(0)):
                optimizer.zero_grad()

                logits = self._logits(src[batch].unsqueeze(0), trg[batch, :-1].unsqueeze(0))
                loss = criterion(logits[0], trg[batch, 1:])

                loss.backward()
                optimizer.step()
                scheduler.step()

                error += loss.item()

            print(f"Epoch: {epoch_idx+1}, Loss: {error}")
            errors_list.append(error)

        plt.plot(errors_list)
        plt.title("Errors")
        plt.show()

    def predict(self, src):
        """Greedily decode ``max_seq_length - 1`` tokens for every source row.

        Switches the module to eval mode. Decoding does not stop early on EOS.

        Returns:
            Token ids of shape ``(b, max_seq_length - 1)``, without the
            leading BOS.
        """
        self.eval()
        with torch.no_grad():
            # Every generated sequence starts with BOS.
            Tokens_ids = self.BOS.to(src.device).repeat(src.size(0), 1)

            src_mask = self.create_mask(src)
            enc_out = self._encode(src, src_mask)

            for _ in range(self.max_seq_length - 1):
                # Same causal constraint as in training, so the hidden states
                # of earlier positions do not depend on later tokens.
                trg_mask = self._causal_mask(Tokens_ids.size(1), Tokens_ids.device)
                logits = self._decode(Tokens_ids, enc_out, src_mask, trg_mask)

                # Only the scores of the last position pick the next token;
                # argmax over logits equals argmax over their softmax.
                tok = toTokens(logits[:, -1, :])
                Tokens_ids = torch.cat((Tokens_ids, tok.unsqueeze(1)), dim=1)

            return Tokens_ids[:, 1:]  # drop the leading BOS

    def infer(self, src, trg):
        """Predict, with teacher forcing, the token following each position of ``trg[:, :-1]``.

        Switches the module to eval mode.

        Returns:
            Token ids of shape ``(b, t_trg - 1)``.
        """
        self.eval()
        with torch.no_grad():
            output = self(src, trg[:, :-1])
            return toTokens(output)

In [9]:
# Model hyper-parameters.
N = 6          # number of encoder layers and of decoder layers
dmodel = 512   # embedding / hidden size
h = 8          # attention heads
dff = 2048     # feed-forward hidden size
pdrop = 0.1    # dropout probability

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Toy parallel corpus: every row is [BOS, a, b, c, 0] and the three target
# tokens are twice the corresponding source tokens.
src = torch.tensor([[1, 2, 3, 4, 0], [1, 3, 4, 5, 0], [1, 4, 5, 6, 0], [1, 5, 6, 7, 0], [1, 6, 7, 8, 0], [1, 7, 8, 9, 0]], dtype=torch.long, device=device)
trg = torch.tensor([[1, 4, 6, 8, 0], [1, 6, 8, 10, 0], [1, 8, 10, 12, 0], [1, 10, 12, 14, 0], [1, 12, 14, 16, 0], [1, 14, 16, 18, 0]], dtype=torch.long, device=device)

# Special token ids. EOS and PAD share id 0 here, so the trailing 0 of each
# row is treated as padding by the model's masks and by its training loss.
BOS = 1
EOS = 0
PAD = 0
max_length = 5

src_vocab_size = 20
trg_vocab_size = 20

# Seed before building the model so initialisation and dropout are repeatable.
torch.manual_seed(0)

# The parameters must live on the same device as the data tensors above.
transfo = Transformer(N, dmodel, h, dff, pdrop, src_vocab_size, trg_vocab_size, max_length, BOS, EOS, PAD, device).to(device)
transfo.trainer(src, trg, 50)
transfo.predict(src)