## Implémentation d'un RNN Encodeur-Décodeur qui ne prend pas en entrée les mots cibles
*Figure (schéma de l'architecture) à faire*

In [9]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

### Partie Encodeur 
- Classique, pas de modifications ici

In [4]:
class Encoder(nn.Module):
    """Token-embedding layer followed by a (possibly stacked) Elman RNN.

    Args:
        emb_dim: Size of each token embedding vector.
        hidden_size: Number of features in the RNN hidden state.
        num_layers: Number of stacked RNN layers.
        vocab_size: Number of entries in the embedding table.
    """

    def __init__(self, emb_dim, hidden_size, num_layers=1, vocab_size=1000) -> None:
        super().__init__()

        self.num_layers = num_layers
        self.emb_dim = emb_dim
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        self.embedding = nn.Embedding(
            num_embeddings=self.vocab_size,
            embedding_dim=self.emb_dim,
        )

        # batch_first=True: the RNN consumes (batch, seq_len, emb_dim) inputs.
        self.rnn = nn.RNN(
            input_size=self.emb_dim,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
        )

    def forward(self, input_sequence, hidden=None):
        """Encode a batch of token-id sequences.

        Args:
            input_sequence: Integer tensor of token ids, (batch, seq_len).
            hidden: Optional initial hidden state,
                (num_layers, batch, hidden_size). When omitted, ``nn.RNN``
                starts from an all-zero state sized for the actual batch,
                so callers need no batch-size-specific initialiser.

        Returns:
            A tuple ``(output_sequence, hidden)`` where ``output_sequence``
            is (batch, seq_len, hidden_size) and ``hidden`` is the final
            state, (num_layers, batch, hidden_size).
        """
        embedded_sequence = self.embedding(input_sequence)
        output_sequence, hidden = self.rnn(embedded_sequence, hidden)

        return output_sequence, hidden


In [5]:
# Two-layer encoder: 100-dim embeddings, 50 hidden units, 1000-token vocabulary.
encodeur = Encoder(emb_dim=100, hidden_size=50, num_layers=2, vocab_size=1000)

In [30]:
# Random token ids: 32 sequences of 10 tokens drawn from [0, 1000).
batch_sequence = torch.randint(0, 1000, (32, 10))

# Arbitrary initial hidden state with integer-valued floats in [0, 10),
# laid out as (num_layers, batch, hidden_size).
batch_hidden_input = torch.randint(0, 10, (2, 32, 50), dtype=torch.float32)

# Run the encoder once and display the shapes of (outputs, final hidden state).
enc_res = encodeur(batch_sequence, batch_hidden_input)
tuple(t.shape for t in enc_res)

(torch.Size([32, 10, 50]), torch.Size([2, 32, 50]))

### Décodeur
- Partie avec des modifications vis-à-vis de l'architecture classique

In [31]:
class Decodeur(nn.Module):
    """RNN decoder that re-reads the encoder context at every time step.

    The context vector is concatenated to each token embedding before the
    RNN, and concatenated again (with the embedding and the RNN output)
    before the final vocabulary projection.

    Args:
        emb_dim: Size of each token embedding vector.
        hidden_size: Number of features in the RNN hidden state; the context
            vector must have the same width.
        num_layers: Number of stacked RNN layers.
        vocab_size: Size of the embedding table and of the output distribution.
    """

    def __init__(self, emb_dim, hidden_size, num_layers=1, vocab_size=500) -> None:
        super().__init__()

        self.vocab_size = vocab_size
        self.num_layers = num_layers

        self.embedding_dim = emb_dim
        self.hidden_size = hidden_size

        # Width of one RNN input step: [token embedding ; context vector].
        step_width = self.embedding_dim + self.hidden_size

        self.embedding = nn.Embedding(
            num_embeddings=self.vocab_size,
            embedding_dim=self.embedding_dim,
        )

        self.rnn = nn.RNN(
            input_size=step_width,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
        )

        # The projection reads [token embedding ; context ; RNN output].
        self.linear = nn.Linear(
            in_features=step_width + self.hidden_size,
            out_features=self.vocab_size,
            bias=False,
        )

    def forward(self, input_sequence, hidden, context):
        """Decode a batch of token-id sequences conditioned on ``context``.

        Args:
            input_sequence: Integer tensor of token ids, (batch, seq_len).
            hidden: Initial RNN state, (num_layers, batch, hidden_size).
            context: Tensor whose last entry along dim 0 is the
                (batch, hidden_size) context vector, e.g. an encoder's final
                hidden state of shape (num_layers, batch, hidden_size).

        Returns:
            A tuple ``(preds, hidden)``: ``preds`` holds softmax
            probabilities, (batch, seq_len, vocab_size), and ``hidden`` is
            the final RNN state, (num_layers, batch, hidden_size).
        """
        n_steps = input_sequence.shape[1]

        token_vectors = self.embedding(input_sequence)  # (batch, seq_len, emb_dim)

        # Broadcast the last-layer context over the time axis:
        # (batch, hidden_size) -> (batch, seq_len, hidden_size).
        step_context = context[-1].unsqueeze(1).expand(-1, n_steps, -1)

        rnn_input = torch.cat((token_vectors, step_context), dim=2)
        rnn_output, hidden = self.rnn(rnn_input, hidden)
        # rnn_output: (batch, seq_len, hidden_size)
        # hidden: (num_layers, batch, hidden_size)

        projection_input = torch.cat((rnn_input, rnn_output), dim=2)
        scores = self.linear(projection_input)
        preds = F.softmax(scores, dim=2)  # (batch, seq_len, vocab_size)

        return preds, hidden



In [32]:
# Two-layer decoder: 150-dim embeddings, 60 hidden units, 1000-token vocabulary.
decodeur = Decodeur(emb_dim=150, hidden_size=60, num_layers=2, vocab_size=1000)

In [33]:
# Stand-in for an encoder's final hidden state: (num_layers, batch, hidden_size)
# filled with integer-valued floats in [0, 10).
context = torch.randint(0, 10, (2, 32, 60), dtype=torch.float32)
# Random token ids: 32 sequences of 11 tokens drawn from [0, 1000).
input_sequence = torch.randint(0, 1000, (32, 11))
# The decoder starts from the same tensor it uses as context.
h_init = context

# Run the decoder once and display the shapes of (predictions, final hidden state).
res = decodeur(input_sequence, h_init, context)
tuple(t.shape for t in res)

(torch.Size([32, 11, 1000]), torch.Size([2, 32, 60]))