In [30]:
import torch
import torch.nn as nn
import math

import torch.optim as optim
import matplotlib.pyplot as plt

class PositionalEncoding(nn.Module):
    """Fixed (non-trainable) sinusoidal positional-encoding table.

    The table has shape ``(1, max_seq_length, d_model)``: even feature
    columns hold ``sin(position * freq)`` and odd columns hold
    ``cos(position * freq)``. It is registered as a buffer named ``pe``, so it
    follows the module across ``.to(...)`` calls without being a parameter.

    Args:
        d_model: Number of feature columns per position (even or odd).
        max_seq_length: Number of positions to precompute.
    """

    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = position * div_term  # (max_seq_length, ceil(d_model / 2))

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so only the first d_model // 2 frequencies are used for the cosines.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions.

        Only the size of ``x`` along dimension 1 is read; the encoding is NOT
        added to ``x`` here, the caller is expected to do the addition.

        Returns:
            Tensor of shape ``(1, min(x.size(1), max_seq_length), d_model)``.
        """
        return self.pe[:, :x.size(1)]

In [31]:
class multihead(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, split into ``h`` heads of
    width ``dmodel // h``, attended independently, concatenated and projected
    once more through ``Wo``. None of the projections has a bias.

    Args:
        dmodel: Width of the input and output features.
        h: Number of attention heads; must divide ``dmodel`` exactly.

    Raises:
        AssertionError: If ``dmodel`` is not divisible by ``h``.
    """

    def __init__(self, dmodel, h):
        super().__init__()

        # Integer arithmetic on purpose: with float division, (dmodel / h) * h
        # rounds back to dmodel for non-divisible sizes (e.g. 10 / 3 * 3 == 10),
        # so the check silently passed and the failure only surfaced later as
        # an obscure .view() shape error.
        assert (
            dmodel % h == 0
        ), "Embedding size needs to be divisible by heads"

        self.dk = dmodel // h  # queries & keys dimension per head
        self.dv = self.dk  # values dimension per head

        self.h = h

        self.Wq = nn.Linear(dmodel, dmodel, bias=False)
        self.Wk = nn.Linear(dmodel, dmodel, bias=False)
        self.Wv = nn.Linear(dmodel, dmodel, bias=False)
        self.Wo = nn.Linear(dmodel, dmodel, bias=False)

    def attention(self, q, k, v, mask=None):
        """Scaled dot-product attention over per-head tensors.

        Args:
            q: Queries, ``(batch, h, t_q, dk)``.
            k: Keys, ``(batch, h, t_k, dk)``.
            v: Values, ``(batch, h, t_k, dv)``.
            mask: Optional tensor broadcastable to ``(batch, h, t_q, t_k)``;
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape ``(batch, h, t_q, dv)``.
        """
        # e.g. (2, 8, 9, 64) @ (2, 8, 64, 9) -> (2, 8, 9, 9)
        product = torch.matmul(q, k.transpose(-2, -1))

        if mask is not None:
            # A large finite negative (rather than -inf) keeps a fully masked
            # row finite: it becomes a uniform distribution instead of NaN.
            product = product.masked_fill(mask == 0, -1e20)

        product = product / math.sqrt(self.dk)
        score = torch.softmax(product, dim=-1)
        return torch.matmul(score, v)

    def forward(self, q, k, v, mask=None):
        """Apply multi-head attention.

        Args:
            q: Queries, ``(batch, t_q, dmodel)``.
            k: Keys, ``(batch, t_k, dmodel)``.
            v: Values; reshaped with the key length, so it must have the same
                sequence length as ``k``.
            mask: Optional mask forwarded to :meth:`attention`.

        Returns:
            Tensor of shape ``(batch, t_q, dmodel)``.
        """
        h = self.h
        dv = self.dv
        dk = self.dk

        b, t, dmod = k.size()
        t_q = q.size(1)

        q = self.Wq(q)
        k = self.Wk(k)
        v = self.Wv(v)

        # Split the feature axis into heads: (b, t, dmodel) -> (b, h, t, d_head)
        q = q.view(b, t_q, h, dk).transpose(1, 2)
        k = k.view(b, t, h, dk).transpose(1, 2)
        v = v.view(b, t, h, dv).transpose(1, 2)

        out = self.attention(q, k, v, mask)
        # Merge heads back: (b, h, t_q, dv) -> (b, t_q, dmodel)
        out = out.transpose(1, 2).contiguous().view(b, t_q, dmod)
        return self.Wo(out)

In [32]:
class feedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Args:
        dmodel: Width of the input and of the output.
        dff: Width of the hidden layer.
    """

    def __init__(self, dmodel, dff):
        super().__init__()
        # Expansion layer, non-linearity, then projection back to dmodel.
        self.linear1 = nn.Linear(dmodel, dff)
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(dff, dmodel)

    def forward(self, x):
        """Map ``x`` of shape ``(..., dmodel)`` to a tensor of the same shape."""
        hidden = self.activation(self.linear1(x))
        return self.linear2(hidden)

In [33]:
class TransformerBlock(nn.Module):
    """Attention sub-layer followed by a feed-forward sub-layer.

    Each sub-layer output goes through dropout, is added to the sub-layer
    input (residual connection) and is then layer-normalised.

    Args:
        dmodel: Feature width.
        h: Number of attention heads.
        dff: Hidden width of the feed-forward sub-layer.
        pdrop: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, dmodel, h, dff, pdrop):
        super().__init__()
        self.multihead = multihead(dmodel, h)
        self.FF = feedForward(dmodel, dff)
        self.layerNorm1 = nn.LayerNorm(dmodel)
        self.layerNorm2 = nn.LayerNorm(dmodel)
        self.dropout = nn.Dropout(pdrop)

    def forward(self, q, k, v, mask):
        """Attend from ``q`` over ``k``/``v`` and refine the result.

        The residual connection of the attention sub-layer is taken from ``q``.

        Returns:
            Tensor with the same shape as ``q``.
        """
        attended = self.dropout(self.multihead(q, k, v, mask))
        normed = self.layerNorm1(q + attended)

        transformed = self.dropout(self.FF(normed))
        return self.layerNorm2(normed + transformed)

In [34]:
class Decoder(nn.Module):
    """One decoder layer: masked self-attention, then a TransformerBlock
    that attends over the encoder output.

    Args:
        dmodel: Feature width.
        h: Number of attention heads.
        dff: Hidden width of the feed-forward sub-layer.
        pdrop: Dropout probability.
    """

    def __init__(self, dmodel, h, dff, pdrop):
        super().__init__()
        self.multihead = multihead(dmodel, h)
        self.norm = nn.LayerNorm(dmodel)
        self.dropout = nn.Dropout(pdrop)
        self.block = TransformerBlock(dmodel, h, dff, pdrop)

    def forward(self, x, enc_output, trg_mask, src_mask):
        """Run the layer on decoder input ``x``.

        Args:
            x: Decoder-side features.
            enc_output: Encoder output, used as keys and values of the block.
            trg_mask: Mask for the self-attention over ``x``.
            src_mask: Mask for the attention over ``enc_output``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        self_attended = self.dropout(self.multihead(x, x, x, trg_mask))
        queries = self.norm(x + self_attended)
        # The normalised self-attention output queries the encoder output.
        return self.block(queries, enc_output, enc_output, src_mask)

In [35]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer over integer token ids.

    Token id 0 is treated as padding: it is masked out of attention in
    :meth:`create_mask` and ignored by the training loss in :meth:`trainer`.

    Args:
        N: Number of encoder layers and of decoder layers.
        dmodel: Embedding / feature width.
        h: Number of attention heads.
        dff: Hidden width of the feed-forward sub-layers.
        pdrop: Dropout probability.
        src_vocab_size: Size of the source vocabulary.
        trg_vocab_size: Size of the target vocabulary.
        max_seq_length: Number of positions in the positional-encoding table.
    """

    def __init__(self, N, dmodel, h, dff, pdrop, src_vocab_size, trg_vocab_size, max_seq_length):
        super().__init__()
        self.src_embedding = nn.Embedding(src_vocab_size, dmodel)
        self.trg_embedding = nn.Embedding(trg_vocab_size, dmodel)
        self.positional_encoding = PositionalEncoding(dmodel, max_seq_length)

        self.encoder = nn.ModuleList([TransformerBlock(dmodel, h, dff, pdrop) for _ in range(N)])
        self.decoder = nn.ModuleList([Decoder(dmodel, h, dff, pdrop) for _ in range(N)])

        self.linear = nn.Linear(dmodel, trg_vocab_size)
        self.dropout = nn.Dropout(pdrop)

    def create_mask(self, src, trg):
        """Build the attention masks from the token ids.

        Args:
            src: Source token ids, ``(batch, src_len)``.
            trg: Target token ids, ``(batch, trg_len)``.

        Returns:
            ``(src_mask, trg_mask)`` boolean tensors where ``True`` means
            "may attend". ``src_mask`` is ``(batch, 1, 1, src_len)`` and hides
            padded source positions. ``trg_mask`` is
            ``(batch, 1, trg_len, trg_len)``: lower-triangular (no look-ahead),
            with the rows of padded target positions entirely ``False``.
        """
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        trg_mask = (trg != 0).unsqueeze(1).unsqueeze(3)
        trg_len = trg.size(1)
        # Built on trg's device so the `&` below also works off the CPU.
        diago = torch.tril(torch.ones(1, trg_len, trg_len, device=trg.device)).bool()
        trg_mask = trg_mask & diago
        return src_mask, trg_mask

    def forward(self, src, trg, return_logits=False):
        """Encode ``src`` and decode ``trg`` in a single pass.

        Args:
            src: Source token ids, ``(batch, src_len)``.
            trg: Target token ids fed to the decoder, ``(batch, trg_len)``.
            return_logits: When ``True`` return the raw scores of the output
                layer instead of probabilities (needed by losses such as
                ``nn.CrossEntropyLoss`` that apply their own log-softmax).

        Returns:
            Tensor of shape ``(batch, trg_len, trg_vocab_size)``: softmax
            probabilities by default, raw logits if ``return_logits``.
        """
        src_mask, trg_mask = self.create_mask(src, trg)

        src_embed = self.dropout(self.src_embedding(src) + self.positional_encoding(src))
        trg_embed = self.dropout(self.trg_embedding(trg) + self.positional_encoding(trg))

        enc_out = src_embed
        for enc_layer in self.encoder:
            enc_out = enc_layer(enc_out, enc_out, enc_out, src_mask)

        dec_out = trg_embed
        for dec_layer in self.decoder:
            dec_out = dec_layer(dec_out, enc_out, trg_mask, src_mask)

        output = self.linear(dec_out)
        if return_logits:
            return output
        return torch.softmax(output, dim=-1)

    def infer(self, src, trg):
        """Print the most likely token at every target position.

        Runs one forward pass with ``trg[:, :-1]`` as decoder input (this is
        not step-by-step autoregressive decoding) and prints the argmax token
        ids. Switches the module to eval mode and leaves it there.
        """
        self.eval()
        with torch.no_grad():
            output = self(src, trg[:, :-1])
            print("Tokens:", torch.argmax(output, dim=-1))

    def trainer(self, src, trg, epoch):
        """Fit the model on a single batch and plot the loss curve.

        Args:
            src: Source token ids, ``(batch, src_len)``.
            trg: Target token ids, ``(batch, trg_len)``; ``trg[:, :-1]`` is the
                decoder input and ``trg[:, 1:]`` the prediction target.
            epoch: Number of optimisation steps to run on the batch.
        """
        criterion = nn.CrossEntropyLoss(ignore_index=0)
        optimizer = optim.Adam(self.parameters())
        errors = []

        # .long() converts the dtype without forcing the tensors onto the CPU.
        src = src.long()
        trg = trg.long()
        targets = trg[:, 1:].reshape(-1)

        self.train()
        for step in range(epoch):
            # CrossEntropyLoss applies log-softmax itself, so it must receive
            # raw logits rather than the probabilities forward() returns by default.
            logits = self(src, trg[:, :-1], return_logits=True)
            # Flatten batch and time so every sequence contributes to the
            # loss, not only the first one of the batch.
            loss = criterion(logits.reshape(-1, logits.size(-1)), targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            errors.append(loss.item())
            print(f"Epoch: {step+1}, Loss: {loss.item()}")

        plt.plot(errors)
        plt.title("Errors")
        plt.show()


In [36]:
# Model hyper-parameters.
N = 6          # number of encoder layers and of decoder layers
dmodel = 512   # embedding / feature width
h = 8          # attention heads
dff = 2048     # hidden width of the feed-forward sub-layers
pdrop = 0.1    # dropout probability

# Toy vocabulary: id 0 is the padding id, id 1 is the only real token.
vocab_size = 2

# Two right-padded sequences; the target batch has the same content as the source.
src = torch.tensor(
    [
        [1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0],
    ],
    dtype=torch.long,
)
trg = torch.tensor(
    [
        [1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0],
    ],
    dtype=torch.long,
)
max_seq_length = src.size(1)

transfo = Transformer(
    N,
    dmodel,
    h,
    dff,
    pdrop,
    src_vocab_size=vocab_size,
    trg_vocab_size=vocab_size,
    max_seq_length=max_seq_length,
)
# The model is untrained here, so this only checks that a forward pass runs.
transfo.infer(src, trg)
# To fit the toy batch instead, run: transfo.trainer(src, trg, 2)

Tokens: tensor([[1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1]])
