In [1]:
import math

import matplotlib.pyplot as plt
import torch
import torchvision
from torch import nn

# Création de la classe Transformer

Création des différentes couches

In [5]:
class FeedForwardNetwork(nn.Module):
    """Two-layer feed-forward block: Linear -> ReLU -> Linear.

    The first linear layer maps the last dimension from ``d_model`` to ``d_ff``;
    the second maps it back to ``d_model``, so the output has the same shape as
    the input.

    Args:
        d_model: Size of the input and output features.
        d_ff: Size of the hidden layer between the two linear maps.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()

        # Expand to the hidden width, then project back to the model width.
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply fc1, ReLU, then fc2 to ``x`` and return the result."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)

In [7]:
# Optionnel : coder notre propre couche d'attention multi-têtes (pour ne pas utiliser celle fournie par torch.nn)

In [9]:
def mask(x):
    """Build an additive causal (look-ahead) attention mask.

    The result is a square float matrix of side ``x.size(0)``: entries strictly
    above the main diagonal are ``-inf`` and every other entry is ``0``. Added to
    attention scores, it prevents position ``i`` from attending to positions
    ``j > i``.

    Args:
        x: Tensor whose first dimension gives the side length of the mask.

    Returns:
        Tensor of shape ``(x.size(0), x.size(0))`` created on ``x``'s device.
    """
    length = x.size(0)
    # Start from a matrix full of -inf and keep only the strict upper triangle;
    # triu sets the remaining entries to 0. Multiplying a 0/1 matrix by -inf
    # instead would produce NaN wherever the matrix is 0 (0 * inf is NaN).
    return torch.triu(
        torch.full((length, length), float("-inf"), device=x.device),
        diagonal=1,
    )

In [11]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first input.

    A table of shape ``(1, max_length, d_model)`` is precomputed and stored as
    the non-trainable buffer ``pe``. Even feature columns hold sines and odd
    feature columns hold cosines of ``position * div_term``.

    Args:
        max_length: Number of positions in the precomputed table.
        d_model: Number of features per position (may be odd).
    """

    def __init__(self, max_length, d_model):
        super().__init__()

        position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = position * div_term  # (max_length, ceil(d_model / 2))

        pe = torch.zeros(max_length, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so drop the last frequency for the cosine half.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Buffer: moves with the module across devices but is not a parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dim 1 indexes positions and whose last dim is
                ``d_model``, i.e. ``(batch, seq, d_model)``.

        Raises:
            ValueError: If ``x.size(1)`` exceeds the table length ``max_length``.
        """
        seq_len = x.size(1)
        max_length = self.pe.size(1)
        if seq_len > max_length:
            # Without this check the addition fails with an opaque broadcast
            # error, or broadcasts silently when max_length == 1.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_length {max_length}"
            )
        return x + self.pe[:, :seq_len]
        

In [13]:
class Encoder(nn.Module):
    """One encoder layer: self-attention, then a feed-forward network.

    Each sub-layer's output is added to its input and the sum is passed through
    a LayerNorm.

    Args:
        d_model: Feature size of the inputs and outputs.
        n_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward network.
    """

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()

        # batch_first=True so inputs are (batch, seq, d_model), the layout
        # PositionalEncoding uses (it reads the sequence length from dim 1).
        self.attention = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = FeedForwardNetwork(d_model, d_ff)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Encode ``x`` of shape ``(batch, seq, d_model)``; output has the same shape."""
        # MultiheadAttention returns (output, attention_weights); only the
        # output is used, so skip computing the weights.
        attention_output, _ = self.attention(x, x, x, need_weights=False)
        x = self.norm1(x + attention_output)
        ffn_output = self.ffn(x)
        x = self.norm2(x + ffn_output)
        return x



In [15]:
class Decoder(nn.Module):
    """One decoder layer: masked self-attention, cross-attention, feed-forward.

    Each sub-layer's output is added to its input and the sum is passed through
    a LayerNorm.

    Args:
        d_model: Feature size of the inputs and outputs.
        num_heads: Number of attention heads in both attention sub-layers.
        d_ff: Hidden size of the feed-forward network.
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()

        # batch_first=True so inputs are (batch, seq, d_model), the layout
        # PositionalEncoding uses (it reads the sequence length from dim 1).
        self.attention1 = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.attention2 = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = FeedForwardNetwork(d_model, d_ff)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_output, mask):
        """Decode ``x`` while attending to the encoder output.

        Args:
            x: Decoder input, ``(batch, target_seq, d_model)``.
            enc_output: Encoder output, ``(batch, source_seq, d_model)``.
            mask: Attention mask passed as ``attn_mask`` to the self-attention;
                presumably a ``(target_seq, target_seq)`` causal mask.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # MultiheadAttention returns (output, attention_weights); keep the output.
        attention_output1, _ = self.attention1(x, x, x, attn_mask=mask, need_weights=False)
        x = self.norm1(x + attention_output1)
        # Cross-attention: queries from the decoder, keys/values from the encoder.
        attention_output2, _ = self.attention2(x, enc_output, enc_output, need_weights=False)
        x = self.norm2(x + attention_output2)
        ffn_output = self.ffn(x)
        x = self.norm3(x + ffn_output)
        return x



In [19]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer built from the Encoder and Decoder layers.

    Args:
        vocab_size: Number of entries in the source embedding table.
        target_size: Number of entries in the target embedding table; also the
            size of the output distribution.
        max_length: Maximum sequence length supported by the positional encoding.
        d_model: Feature size used throughout the model.
        num_heads: Number of attention heads per attention layer.
        d_ff: Hidden size of the feed-forward networks.
        n_layers: Number of encoder layers and of decoder layers.
    """

    def __init__(self, vocab_size, target_size, max_length, d_model, num_heads, d_ff, n_layers):
        super().__init__()

        self.enc_embedding = nn.Embedding(vocab_size, d_model)
        self.dec_embedding = nn.Embedding(target_size, d_model)
        self.positional_encoding = PositionalEncoding(max_length, d_model)

        # nn.ModuleList registers the layers as sub-modules; with a plain Python
        # list their parameters would be missing from parameters(), state_dict()
        # and .to(device).
        self.encoder_layers = nn.ModuleList(
            Encoder(d_model, num_heads, d_ff) for _ in range(n_layers)
        )
        self.decoder_layers = nn.ModuleList(
            Decoder(d_model, num_heads, d_ff) for _ in range(n_layers)
        )

        self.linear = nn.Linear(d_model, target_size)
        # Normalize over the target vocabulary (last dim). Without an explicit
        # dim, the legacy default for 3-D input is dim 0, i.e. the batch.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, inp, out):
        """Compute output probabilities for each target position.

        Args:
            inp: Source token ids, assumed ``(batch, source_seq)``.
            out: Target token ids, assumed ``(batch, target_seq)``.

        Returns:
            Tensor of shape ``(batch, target_seq, target_size)`` whose last
            dimension sums to 1. These are probabilities, not logits.
        """
        # mask() takes its side length from dim 0 of its argument, while the
        # token tensors here are batch-first; transpose so dim 0 is the target
        # sequence length. The local name differs from the helper so the
        # function is not shadowed inside this method.
        causal_mask = mask(out.transpose(0, 1)).to(out.device)

        out_embedded = self.positional_encoding(self.dec_embedding(out))
        inp_embedded = self.positional_encoding(self.enc_embedding(inp))

        enc_output = inp_embedded
        for encoder in self.encoder_layers:
            enc_output = encoder(enc_output)

        dec_output = out_embedded
        for decoder in self.decoder_layers:
            dec_output = decoder(dec_output, enc_output, causal_mask)

        return self.softmax(self.linear(dec_output))

# Mise en forme des données

In [None]:
# TODO : Absolument tout