In [359]:
import torchtext.vocab as tv
import torch as torch
import torch.nn as nn

In [360]:
if torch.cuda.is_available():  
    dev = "cuda:0" 
else:  
    dev = "cpu" 
device = torch.device(dev)

# 1. Préparation des données

In [361]:
# Fonction pour récupérer les données d'un fichier
def open_file(file):
    texts = []
    emotions = []
    with open(file, newline = None) as f:
        for line in f:
            splited = line.split(";")
            texts.append(splited[0].split()) # Retire les espaces
            emotions.append(splited[1][:-1]) # Permet de retirer le \n en fin de ligne
    return texts, emotions

In [362]:
# Récupération des données 
train_texts, train_emotions = open_file("train.txt")

### Rembourrage et rognage des phrases 

In [363]:
def rembourrage_rognage(liste, n_final, char):
    if(len(liste) > n_final):
        return liste[:n_final]
    if(len(liste) < n_final):
        return liste + [char]*(n_final - len(liste))
    return liste

### Création du vocabulaire

In [364]:
def get_vocabs(texts, emotions):
    #vocabulaire text
    text_vocab = tv.build_vocab_from_iterator(iter(texts), specials = ["<unk>"])
    
    # Ajout index défaut
    unknown_id = text_vocab.forward(["<unk>"])
    text_vocab.set_default_index(unknown_id[0])
    
    #vocabulaire emotion
    emotion_vocab = tv.build_vocab_from_iterator(iter([[emotion] for emotion in emotions]))
    
    return text_vocab, emotion_vocab

def forward_vocab(texts, emotions, text_vocab, emotion_vocab):
    unknown_id = text_vocab.forward(["<unk>"])[0]
#     sentence_length = max([len(text) for text in texts])
    sentence_length = 15
    
    texts_id = [text_vocab.forward(text) for text in texts]
    texts_id = torch.tensor([rembourrage_rognage(text, sentence_length, unknown_id) for text in texts_id])
    
    emotions_id = emotion_vocab.forward(emotions)
    emotions_id = torch.tensor(emotions_id)
    return texts_id, emotions_id

In [365]:
# (il est possible d'oublier certains mots qui sont dans le data set de train et pas dans le dataset de test)
text_vocab, emotion_vocab = get_vocabs(train_texts, train_emotions)
train_texts_id, train_emotions_id = forward_vocab(train_texts, train_emotions, text_vocab, emotion_vocab)

print(len(text_vocab))

15213


### Conversion en one hot

In [263]:
# Ne fonctionne pas sur ma machine car pas assez de mémoire, on va plutot encoder une nouvelle fois à chaque batch
texts_one_hot = torch.nn.functional.one_hot(torch.tensor(texts_id), len(text_vocab))
texts_one_hot.size()

  texts_one_hot = torch.nn.functional.one_hot(torch.tensor(texts_id), len(vocab))


RuntimeError: [enforce fail at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\c10\core\impl\alloc_cpu.cpp:72] data. DefaultCPUAllocator: not enough memory: you tried to allocate 128519424000 bytes.

# 2. Architecture du réseau

In [376]:
# Adapté de https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial
import torch.nn as nn

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, emb_size):
        super(RNN, self).__init__()

        self.act = nn.ReLU()
        self.hidden_size = hidden_size
        self.i2e = nn.Linear(input_size, emb_size)
        self.i2h = nn.Linear(emb_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(emb_size + hidden_size, output_size)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input, hidden):
        embedded = self.i2e(input.float())
#         combined = torch.cat((input, hidden), 1)
        combined = torch.cat((embedded, hidden), 1)
        combined = self.act(combined)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self, batch_size):
        hidden = torch.zeros(batch_size, self.hidden_size)
        return hidden.to(device)

# 3. Préparation des batchs

In [273]:
# Construction du modèle
n_hidden = 128
n_categories = len(emotion_vocab)
n_words = len(text_vocab)
emb_size = 100
rnn = RNN(n_words, n_hidden, n_categories, emb_size)

# Test du modèle avec un seul mot
word = text_vocab.forward(["cat"])
word = torch.tensor(word)
word = torch.nn.functional.one_hot(word, len(text_vocab))
hidden = torch.zeros(1, n_hidden)
print("output : ", rnn.forward(word, hidden))

output :  (tensor([[-1.8118, -1.7736, -1.7852, -1.7635, -1.7650, -1.8545]],
       grad_fn=<LogSoftmaxBackward0>), tensor([[ 0.0150, -0.0220, -0.0057,  0.0171,  0.0417,  0.0316, -0.0133, -0.0637,
         -0.0425, -0.0549,  0.0423, -0.0502, -0.0308, -0.0556,  0.0100, -0.0539,
          0.0556,  0.0008,  0.0298, -0.0210,  0.0331,  0.0692, -0.0225, -0.0225,
          0.0676,  0.0201, -0.0125, -0.0298, -0.0620,  0.0578, -0.0011,  0.0487,
         -0.0470, -0.0161, -0.0609, -0.0472, -0.0196, -0.0406, -0.0429,  0.0258,
          0.0591,  0.0395, -0.0121,  0.0154, -0.0337,  0.0229, -0.0476, -0.0457,
         -0.0603, -0.0411,  0.0134,  0.0222, -0.0201, -0.0050, -0.0558,  0.0547,
         -0.0153, -0.0683, -0.0097, -0.0247, -0.0481, -0.0069, -0.0155, -0.0676,
          0.0486,  0.0568,  0.0521, -0.0302,  0.0447,  0.0465, -0.0010, -0.0071,
         -0.0266, -0.0197, -0.0349,  0.0272, -0.0371, -0.0046, -0.0485, -0.0597,
         -0.0069,  0.0232,  0.0127, -0.0457,  0.0479, -0.0506, -0.0371, -0.

# 4. Apprentissage du réseau

In [383]:
# Données sous la forme (batch_size, sentence_length, vocabulary_size)
def forward_model(model, X):
    hidden = model.initHidden(X.size()[0])
    optimizer.zero_grad()
    model.zero_grad()
    for i in range(X.size()[1]):
        output, hidden = model(X[:,i,:], hidden)
    return output

def train_batch(Y, X, model, learning_rate, criterion, optimizer):
    output = forward_model(model, X)

    max_index_X = torch.argmax(output, dim = 1)
    max_index_Y = torch.argmax(Y, dim = 1)
    acc = int(torch.sum(max_index_X==max_index_Y).item()/X.size()[0]*100)
    
    loss = criterion(output, Y)
    loss.backward()
    optimizer.step()
    # Add parameters' gradients to their values, multiplied by learning rate
    for p in rnn.parameters():
        p.data.add_(p.grad.data, alpha=-learning_rate)

    return output, int(loss.item()*100)/100, acc

def train(model, batch_size, epochs, X_id, Y_id, nXvocab, nYvocab, learning_rate, criterion, optimizer, early_stop = None):
    print(X_id.size())
    n_batch = len(X_id) // batch_size
    
    min_loss = float("inf")
    early_stop_counter = 0
    for epoch in range(epochs):
        for batch in range(n_batch):
            XBatch =  X_id[batch*batch_size:(batch+1)*batch_size]
            XBatch = torch.nn.functional.one_hot(XBatch, nXvocab)
            
            YBatch = Y_id[batch*batch_size:(batch+1)*batch_size]
            YBatch = torch.nn.functional.one_hot(YBatch, nYvocab)
            
            output, loss, acc = train_batch(YBatch.to(device).to(torch.float32), XBatch.to(device).to(torch.float32), model, learning_rate, criterion, optimizer)
            
            if(early_stop):
                if(loss > min_loss):
                    early_stop_counter += 1
                    if(early_stop_counter >= early_stop):
                        return
                else:
                    early_stop_counter = 0
                    min_loss = loss
            
#             if(batch == 0):
            print(f'Epoch: {epoch+1}/{epochs}, Batch: {batch+1}/{n_batch}, Loss: {loss}, Accuracy: {acc}%')

In [None]:

batch_size = 16
epochs = 30
n_data = 10000

n_hidden = 128
n_categories = len(emotion_vocab)
n_words = len(text_vocab)
emb_size = 8
rnn = RNN(n_words, n_hidden, n_categories, emb_size)

learning_rate = 0.01
early_stop = 200
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.0001)

train(rnn, batch_size, epochs, train_texts_id[:n_data], train_emotions_id[:n_data], len(text_vocab), len(emotion_vocab), learning_rate, criterion, optimizer)


torch.Size([10000, 15])
Epoch: 1/30, Batch: 1/625, Loss: 1.79, Accuracy: 31%
Epoch: 1/30, Batch: 2/625, Loss: 1.79, Accuracy: 37%
Epoch: 1/30, Batch: 3/625, Loss: 1.79, Accuracy: 12%
Epoch: 1/30, Batch: 4/625, Loss: 1.79, Accuracy: 18%
Epoch: 1/30, Batch: 5/625, Loss: 1.78, Accuracy: 37%
Epoch: 1/30, Batch: 6/625, Loss: 1.79, Accuracy: 25%
Epoch: 1/30, Batch: 7/625, Loss: 1.79, Accuracy: 18%
Epoch: 1/30, Batch: 8/625, Loss: 1.79, Accuracy: 25%
Epoch: 1/30, Batch: 9/625, Loss: 1.79, Accuracy: 12%
Epoch: 1/30, Batch: 10/625, Loss: 1.79, Accuracy: 18%
Epoch: 1/30, Batch: 11/625, Loss: 1.79, Accuracy: 18%
Epoch: 1/30, Batch: 12/625, Loss: 1.78, Accuracy: 18%
Epoch: 1/30, Batch: 13/625, Loss: 1.78, Accuracy: 25%
Epoch: 1/30, Batch: 14/625, Loss: 1.79, Accuracy: 31%
Epoch: 1/30, Batch: 15/625, Loss: 1.79, Accuracy: 12%
Epoch: 1/30, Batch: 16/625, Loss: 1.78, Accuracy: 43%
Epoch: 1/30, Batch: 17/625, Loss: 1.78, Accuracy: 43%
Epoch: 1/30, Batch: 18/625, Loss: 1.78, Accuracy: 43%
Epoch: 1/30, 

Epoch: 1/30, Batch: 153/625, Loss: 1.7, Accuracy: 37%
Epoch: 1/30, Batch: 154/625, Loss: 1.62, Accuracy: 37%
Epoch: 1/30, Batch: 155/625, Loss: 1.68, Accuracy: 31%
Epoch: 1/30, Batch: 156/625, Loss: 1.61, Accuracy: 37%
Epoch: 1/30, Batch: 157/625, Loss: 1.58, Accuracy: 50%
Epoch: 1/30, Batch: 158/625, Loss: 1.67, Accuracy: 37%
Epoch: 1/30, Batch: 159/625, Loss: 1.81, Accuracy: 18%
Epoch: 1/30, Batch: 160/625, Loss: 1.48, Accuracy: 62%
Epoch: 1/30, Batch: 161/625, Loss: 1.71, Accuracy: 25%
Epoch: 1/30, Batch: 162/625, Loss: 1.67, Accuracy: 25%
Epoch: 1/30, Batch: 163/625, Loss: 1.7, Accuracy: 18%
Epoch: 1/30, Batch: 164/625, Loss: 1.7, Accuracy: 25%
Epoch: 1/30, Batch: 165/625, Loss: 1.55, Accuracy: 31%
Epoch: 1/30, Batch: 166/625, Loss: 1.58, Accuracy: 43%
Epoch: 1/30, Batch: 167/625, Loss: 1.58, Accuracy: 43%
Epoch: 1/30, Batch: 168/625, Loss: 1.61, Accuracy: 37%
Epoch: 1/30, Batch: 169/625, Loss: 1.69, Accuracy: 18%
Epoch: 1/30, Batch: 170/625, Loss: 1.7, Accuracy: 31%
Epoch: 1/30, B

Epoch: 1/30, Batch: 303/625, Loss: 1.69, Accuracy: 43%
Epoch: 1/30, Batch: 304/625, Loss: 1.66, Accuracy: 37%
Epoch: 1/30, Batch: 305/625, Loss: 1.58, Accuracy: 50%
Epoch: 1/30, Batch: 306/625, Loss: 1.54, Accuracy: 43%
Epoch: 1/30, Batch: 307/625, Loss: 1.78, Accuracy: 18%
Epoch: 1/30, Batch: 308/625, Loss: 1.69, Accuracy: 25%
Epoch: 1/30, Batch: 309/625, Loss: 1.72, Accuracy: 31%
Epoch: 1/30, Batch: 310/625, Loss: 1.69, Accuracy: 12%
Epoch: 1/30, Batch: 311/625, Loss: 1.61, Accuracy: 43%
Epoch: 1/30, Batch: 312/625, Loss: 1.62, Accuracy: 25%
Epoch: 1/30, Batch: 313/625, Loss: 1.68, Accuracy: 18%
Epoch: 1/30, Batch: 314/625, Loss: 1.69, Accuracy: 43%
Epoch: 1/30, Batch: 315/625, Loss: 1.75, Accuracy: 31%
Epoch: 1/30, Batch: 316/625, Loss: 1.64, Accuracy: 37%
Epoch: 1/30, Batch: 317/625, Loss: 1.61, Accuracy: 31%
Epoch: 1/30, Batch: 318/625, Loss: 1.68, Accuracy: 25%
Epoch: 1/30, Batch: 319/625, Loss: 1.63, Accuracy: 43%
Epoch: 1/30, Batch: 320/625, Loss: 1.66, Accuracy: 37%
Epoch: 1/3

Epoch: 1/30, Batch: 457/625, Loss: 1.66, Accuracy: 37%
Epoch: 1/30, Batch: 458/625, Loss: 1.76, Accuracy: 18%
Epoch: 1/30, Batch: 459/625, Loss: 1.64, Accuracy: 43%
Epoch: 1/30, Batch: 460/625, Loss: 1.74, Accuracy: 25%
Epoch: 1/30, Batch: 461/625, Loss: 1.57, Accuracy: 50%
Epoch: 1/30, Batch: 462/625, Loss: 1.62, Accuracy: 37%
Epoch: 1/30, Batch: 463/625, Loss: 1.71, Accuracy: 18%
Epoch: 1/30, Batch: 464/625, Loss: 1.63, Accuracy: 31%
Epoch: 1/30, Batch: 465/625, Loss: 1.77, Accuracy: 18%
Epoch: 1/30, Batch: 466/625, Loss: 1.6, Accuracy: 50%
Epoch: 1/30, Batch: 467/625, Loss: 1.58, Accuracy: 68%
Epoch: 1/30, Batch: 468/625, Loss: 1.69, Accuracy: 25%
Epoch: 1/30, Batch: 469/625, Loss: 1.74, Accuracy: 18%
Epoch: 1/30, Batch: 470/625, Loss: 1.76, Accuracy: 25%
Epoch: 1/30, Batch: 471/625, Loss: 1.6, Accuracy: 43%
Epoch: 1/30, Batch: 472/625, Loss: 1.75, Accuracy: 18%
Epoch: 1/30, Batch: 473/625, Loss: 1.64, Accuracy: 50%
Epoch: 1/30, Batch: 474/625, Loss: 1.76, Accuracy: 12%
Epoch: 1/30,

Epoch: 1/30, Batch: 610/625, Loss: 1.67, Accuracy: 25%
Epoch: 1/30, Batch: 611/625, Loss: 1.67, Accuracy: 37%
Epoch: 1/30, Batch: 612/625, Loss: 1.7, Accuracy: 31%
Epoch: 1/30, Batch: 613/625, Loss: 1.6, Accuracy: 43%
Epoch: 1/30, Batch: 614/625, Loss: 1.73, Accuracy: 25%
Epoch: 1/30, Batch: 615/625, Loss: 1.71, Accuracy: 25%
Epoch: 1/30, Batch: 616/625, Loss: 1.69, Accuracy: 37%
Epoch: 1/30, Batch: 617/625, Loss: 1.65, Accuracy: 43%
Epoch: 1/30, Batch: 618/625, Loss: 1.62, Accuracy: 43%
Epoch: 1/30, Batch: 619/625, Loss: 1.84, Accuracy: 6%
Epoch: 1/30, Batch: 620/625, Loss: 1.73, Accuracy: 25%
Epoch: 1/30, Batch: 621/625, Loss: 1.72, Accuracy: 31%
Epoch: 1/30, Batch: 622/625, Loss: 1.67, Accuracy: 31%
Epoch: 1/30, Batch: 623/625, Loss: 1.65, Accuracy: 43%
Epoch: 1/30, Batch: 624/625, Loss: 1.67, Accuracy: 31%
Epoch: 1/30, Batch: 625/625, Loss: 1.67, Accuracy: 31%
Epoch: 2/30, Batch: 1/625, Loss: 1.71, Accuracy: 25%
Epoch: 2/30, Batch: 2/625, Loss: 1.68, Accuracy: 25%
Epoch: 2/30, Batc

Epoch: 2/30, Batch: 138/625, Loss: 1.66, Accuracy: 37%
Epoch: 2/30, Batch: 139/625, Loss: 1.79, Accuracy: 6%
Epoch: 2/30, Batch: 140/625, Loss: 1.76, Accuracy: 31%
Epoch: 2/30, Batch: 141/625, Loss: 1.61, Accuracy: 37%
Epoch: 2/30, Batch: 142/625, Loss: 1.69, Accuracy: 37%
Epoch: 2/30, Batch: 143/625, Loss: 1.65, Accuracy: 43%
Epoch: 2/30, Batch: 144/625, Loss: 1.67, Accuracy: 31%
Epoch: 2/30, Batch: 145/625, Loss: 1.6, Accuracy: 43%
Epoch: 2/30, Batch: 146/625, Loss: 1.72, Accuracy: 31%
Epoch: 2/30, Batch: 147/625, Loss: 1.78, Accuracy: 18%
Epoch: 2/30, Batch: 148/625, Loss: 1.66, Accuracy: 25%
Epoch: 2/30, Batch: 149/625, Loss: 1.69, Accuracy: 37%
Epoch: 2/30, Batch: 150/625, Loss: 1.66, Accuracy: 37%
Epoch: 2/30, Batch: 151/625, Loss: 1.76, Accuracy: 12%
Epoch: 2/30, Batch: 152/625, Loss: 1.62, Accuracy: 50%
Epoch: 2/30, Batch: 153/625, Loss: 1.71, Accuracy: 37%
Epoch: 2/30, Batch: 154/625, Loss: 1.61, Accuracy: 37%
Epoch: 2/30, Batch: 155/625, Loss: 1.67, Accuracy: 31%
Epoch: 2/30,

Epoch: 2/30, Batch: 291/625, Loss: 1.7, Accuracy: 31%
Epoch: 2/30, Batch: 292/625, Loss: 1.74, Accuracy: 18%
Epoch: 2/30, Batch: 293/625, Loss: 1.68, Accuracy: 18%
Epoch: 2/30, Batch: 294/625, Loss: 1.67, Accuracy: 37%
Epoch: 2/30, Batch: 295/625, Loss: 1.73, Accuracy: 31%
Epoch: 2/30, Batch: 296/625, Loss: 1.63, Accuracy: 18%
Epoch: 2/30, Batch: 297/625, Loss: 1.64, Accuracy: 37%
Epoch: 2/30, Batch: 298/625, Loss: 1.79, Accuracy: 18%
Epoch: 2/30, Batch: 299/625, Loss: 1.68, Accuracy: 12%
Epoch: 2/30, Batch: 300/625, Loss: 1.7, Accuracy: 31%
Epoch: 2/30, Batch: 301/625, Loss: 1.71, Accuracy: 12%
Epoch: 2/30, Batch: 302/625, Loss: 1.7, Accuracy: 43%
Epoch: 2/30, Batch: 303/625, Loss: 1.69, Accuracy: 43%
Epoch: 2/30, Batch: 304/625, Loss: 1.66, Accuracy: 37%
Epoch: 2/30, Batch: 305/625, Loss: 1.59, Accuracy: 50%
Epoch: 2/30, Batch: 306/625, Loss: 1.53, Accuracy: 43%
Epoch: 2/30, Batch: 307/625, Loss: 1.77, Accuracy: 18%
Epoch: 2/30, Batch: 308/625, Loss: 1.69, Accuracy: 25%
Epoch: 2/30, 

Epoch: 2/30, Batch: 445/625, Loss: 1.63, Accuracy: 37%
Epoch: 2/30, Batch: 446/625, Loss: 1.59, Accuracy: 50%
Epoch: 2/30, Batch: 447/625, Loss: 1.57, Accuracy: 43%
Epoch: 2/30, Batch: 448/625, Loss: 1.69, Accuracy: 25%
Epoch: 2/30, Batch: 449/625, Loss: 1.72, Accuracy: 25%
Epoch: 2/30, Batch: 450/625, Loss: 1.66, Accuracy: 37%
Epoch: 2/30, Batch: 451/625, Loss: 1.76, Accuracy: 18%
Epoch: 2/30, Batch: 452/625, Loss: 1.6, Accuracy: 43%
Epoch: 2/30, Batch: 453/625, Loss: 1.7, Accuracy: 31%
Epoch: 2/30, Batch: 454/625, Loss: 1.77, Accuracy: 31%
Epoch: 2/30, Batch: 455/625, Loss: 1.66, Accuracy: 37%
Epoch: 2/30, Batch: 456/625, Loss: 1.59, Accuracy: 43%
Epoch: 2/30, Batch: 457/625, Loss: 1.66, Accuracy: 37%
Epoch: 2/30, Batch: 458/625, Loss: 1.76, Accuracy: 18%
Epoch: 2/30, Batch: 459/625, Loss: 1.64, Accuracy: 43%
Epoch: 2/30, Batch: 460/625, Loss: 1.74, Accuracy: 25%
Epoch: 2/30, Batch: 461/625, Loss: 1.57, Accuracy: 50%
Epoch: 2/30, Batch: 462/625, Loss: 1.62, Accuracy: 37%
Epoch: 2/30,

Epoch: 2/30, Batch: 597/625, Loss: 1.64, Accuracy: 25%
Epoch: 2/30, Batch: 598/625, Loss: 1.7, Accuracy: 25%
Epoch: 2/30, Batch: 599/625, Loss: 1.67, Accuracy: 43%
Epoch: 2/30, Batch: 600/625, Loss: 1.8, Accuracy: 12%
Epoch: 2/30, Batch: 601/625, Loss: 1.6, Accuracy: 50%
Epoch: 2/30, Batch: 602/625, Loss: 1.71, Accuracy: 18%
Epoch: 2/30, Batch: 603/625, Loss: 1.67, Accuracy: 31%
Epoch: 2/30, Batch: 604/625, Loss: 1.68, Accuracy: 18%
Epoch: 2/30, Batch: 605/625, Loss: 1.73, Accuracy: 37%
Epoch: 2/30, Batch: 606/625, Loss: 1.54, Accuracy: 50%
Epoch: 2/30, Batch: 607/625, Loss: 1.64, Accuracy: 37%
Epoch: 2/30, Batch: 608/625, Loss: 1.62, Accuracy: 31%
Epoch: 2/30, Batch: 609/625, Loss: 1.58, Accuracy: 37%
Epoch: 2/30, Batch: 610/625, Loss: 1.67, Accuracy: 25%
Epoch: 2/30, Batch: 611/625, Loss: 1.67, Accuracy: 37%
Epoch: 2/30, Batch: 612/625, Loss: 1.7, Accuracy: 31%
Epoch: 2/30, Batch: 613/625, Loss: 1.6, Accuracy: 43%
Epoch: 2/30, Batch: 614/625, Loss: 1.73, Accuracy: 25%
Epoch: 2/30, Ba

Epoch: 3/30, Batch: 125/625, Loss: 1.7, Accuracy: 31%
Epoch: 3/30, Batch: 126/625, Loss: 1.75, Accuracy: 18%
Epoch: 3/30, Batch: 127/625, Loss: 1.62, Accuracy: 43%
Epoch: 3/30, Batch: 128/625, Loss: 1.66, Accuracy: 37%
Epoch: 3/30, Batch: 129/625, Loss: 1.6, Accuracy: 43%
Epoch: 3/30, Batch: 130/625, Loss: 1.75, Accuracy: 18%
Epoch: 3/30, Batch: 131/625, Loss: 1.69, Accuracy: 12%
Epoch: 3/30, Batch: 132/625, Loss: 1.69, Accuracy: 37%
Epoch: 3/30, Batch: 133/625, Loss: 1.65, Accuracy: 43%
Epoch: 3/30, Batch: 134/625, Loss: 1.61, Accuracy: 37%
Epoch: 3/30, Batch: 135/625, Loss: 1.81, Accuracy: 12%
Epoch: 3/30, Batch: 136/625, Loss: 1.66, Accuracy: 37%
Epoch: 3/30, Batch: 137/625, Loss: 1.61, Accuracy: 37%
Epoch: 3/30, Batch: 138/625, Loss: 1.66, Accuracy: 37%
Epoch: 3/30, Batch: 139/625, Loss: 1.79, Accuracy: 6%
Epoch: 3/30, Batch: 140/625, Loss: 1.75, Accuracy: 31%
Epoch: 3/30, Batch: 141/625, Loss: 1.61, Accuracy: 37%
Epoch: 3/30, Batch: 142/625, Loss: 1.69, Accuracy: 37%
Epoch: 3/30, 

Epoch: 3/30, Batch: 277/625, Loss: 1.55, Accuracy: 50%
Epoch: 3/30, Batch: 278/625, Loss: 1.62, Accuracy: 31%
Epoch: 3/30, Batch: 279/625, Loss: 1.76, Accuracy: 25%
Epoch: 3/30, Batch: 280/625, Loss: 1.66, Accuracy: 37%
Epoch: 3/30, Batch: 281/625, Loss: 1.55, Accuracy: 56%
Epoch: 3/30, Batch: 282/625, Loss: 1.7, Accuracy: 25%
Epoch: 3/30, Batch: 283/625, Loss: 1.72, Accuracy: 25%
Epoch: 3/30, Batch: 284/625, Loss: 1.55, Accuracy: 50%
Epoch: 3/30, Batch: 285/625, Loss: 1.74, Accuracy: 25%
Epoch: 3/30, Batch: 286/625, Loss: 1.69, Accuracy: 18%
Epoch: 3/30, Batch: 287/625, Loss: 1.64, Accuracy: 43%
Epoch: 3/30, Batch: 288/625, Loss: 1.7, Accuracy: 31%
Epoch: 3/30, Batch: 289/625, Loss: 1.56, Accuracy: 31%
Epoch: 3/30, Batch: 290/625, Loss: 1.75, Accuracy: 18%
Epoch: 3/30, Batch: 291/625, Loss: 1.7, Accuracy: 31%
Epoch: 3/30, Batch: 292/625, Loss: 1.74, Accuracy: 18%
Epoch: 3/30, Batch: 293/625, Loss: 1.69, Accuracy: 18%
Epoch: 3/30, Batch: 294/625, Loss: 1.67, Accuracy: 25%
Epoch: 3/30, 

Epoch: 3/30, Batch: 430/625, Loss: 1.66, Accuracy: 37%
Epoch: 3/30, Batch: 431/625, Loss: 1.67, Accuracy: 31%
Epoch: 3/30, Batch: 432/625, Loss: 1.57, Accuracy: 56%
Epoch: 3/30, Batch: 433/625, Loss: 1.73, Accuracy: 25%
Epoch: 3/30, Batch: 434/625, Loss: 1.71, Accuracy: 25%
Epoch: 3/30, Batch: 435/625, Loss: 1.73, Accuracy: 25%
Epoch: 3/30, Batch: 436/625, Loss: 1.71, Accuracy: 25%
Epoch: 3/30, Batch: 437/625, Loss: 1.59, Accuracy: 56%
Epoch: 3/30, Batch: 438/625, Loss: 1.55, Accuracy: 62%
Epoch: 3/30, Batch: 439/625, Loss: 1.68, Accuracy: 31%
Epoch: 3/30, Batch: 440/625, Loss: 1.6, Accuracy: 37%
Epoch: 3/30, Batch: 441/625, Loss: 1.71, Accuracy: 31%
Epoch: 3/30, Batch: 442/625, Loss: 1.52, Accuracy: 56%
Epoch: 3/30, Batch: 443/625, Loss: 1.79, Accuracy: 18%
Epoch: 3/30, Batch: 444/625, Loss: 1.73, Accuracy: 18%
Epoch: 3/30, Batch: 445/625, Loss: 1.63, Accuracy: 37%
Epoch: 3/30, Batch: 446/625, Loss: 1.59, Accuracy: 50%
Epoch: 3/30, Batch: 447/625, Loss: 1.57, Accuracy: 43%
Epoch: 3/30

Epoch: 3/30, Batch: 580/625, Loss: 1.67, Accuracy: 31%
Epoch: 3/30, Batch: 581/625, Loss: 1.83, Accuracy: 12%
Epoch: 3/30, Batch: 582/625, Loss: 1.72, Accuracy: 31%
Epoch: 3/30, Batch: 583/625, Loss: 1.62, Accuracy: 31%
Epoch: 3/30, Batch: 584/625, Loss: 1.69, Accuracy: 37%
Epoch: 3/30, Batch: 585/625, Loss: 1.61, Accuracy: 50%
Epoch: 3/30, Batch: 586/625, Loss: 1.63, Accuracy: 37%
Epoch: 3/30, Batch: 587/625, Loss: 1.78, Accuracy: 25%
Epoch: 3/30, Batch: 588/625, Loss: 1.63, Accuracy: 37%
Epoch: 3/30, Batch: 589/625, Loss: 1.7, Accuracy: 31%
Epoch: 3/30, Batch: 590/625, Loss: 1.69, Accuracy: 25%
Epoch: 3/30, Batch: 591/625, Loss: 1.64, Accuracy: 43%
Epoch: 3/30, Batch: 592/625, Loss: 1.66, Accuracy: 18%
Epoch: 3/30, Batch: 593/625, Loss: 1.69, Accuracy: 18%
Epoch: 3/30, Batch: 594/625, Loss: 1.71, Accuracy: 25%
Epoch: 3/30, Batch: 595/625, Loss: 1.65, Accuracy: 25%
Epoch: 3/30, Batch: 596/625, Loss: 1.63, Accuracy: 43%
Epoch: 3/30, Batch: 597/625, Loss: 1.65, Accuracy: 25%
Epoch: 3/30

Epoch: 4/30, Batch: 111/625, Loss: 1.65, Accuracy: 37%
Epoch: 4/30, Batch: 112/625, Loss: 1.61, Accuracy: 50%
Epoch: 4/30, Batch: 113/625, Loss: 1.53, Accuracy: 50%
Epoch: 4/30, Batch: 114/625, Loss: 1.68, Accuracy: 25%
Epoch: 4/30, Batch: 115/625, Loss: 1.71, Accuracy: 25%
Epoch: 4/30, Batch: 116/625, Loss: 1.67, Accuracy: 31%
Epoch: 4/30, Batch: 117/625, Loss: 1.7, Accuracy: 37%
Epoch: 4/30, Batch: 118/625, Loss: 1.59, Accuracy: 50%
Epoch: 4/30, Batch: 119/625, Loss: 1.68, Accuracy: 25%
Epoch: 4/30, Batch: 120/625, Loss: 1.67, Accuracy: 31%
Epoch: 4/30, Batch: 121/625, Loss: 1.59, Accuracy: 50%
Epoch: 4/30, Batch: 122/625, Loss: 1.72, Accuracy: 25%
Epoch: 4/30, Batch: 123/625, Loss: 1.69, Accuracy: 25%
Epoch: 4/30, Batch: 124/625, Loss: 1.63, Accuracy: 37%
Epoch: 4/30, Batch: 125/625, Loss: 1.69, Accuracy: 31%
Epoch: 4/30, Batch: 126/625, Loss: 1.76, Accuracy: 18%
Epoch: 4/30, Batch: 127/625, Loss: 1.61, Accuracy: 43%
Epoch: 4/30, Batch: 128/625, Loss: 1.65, Accuracy: 37%
Epoch: 4/30

Epoch: 4/30, Batch: 264/625, Loss: 1.72, Accuracy: 31%
Epoch: 4/30, Batch: 265/625, Loss: 1.56, Accuracy: 50%
Epoch: 4/30, Batch: 266/625, Loss: 1.73, Accuracy: 25%
Epoch: 4/30, Batch: 267/625, Loss: 1.55, Accuracy: 50%
Epoch: 4/30, Batch: 268/625, Loss: 1.75, Accuracy: 18%
Epoch: 4/30, Batch: 269/625, Loss: 1.6, Accuracy: 37%
Epoch: 4/30, Batch: 270/625, Loss: 1.71, Accuracy: 25%
Epoch: 4/30, Batch: 271/625, Loss: 1.8, Accuracy: 18%
Epoch: 4/30, Batch: 272/625, Loss: 1.62, Accuracy: 43%
Epoch: 4/30, Batch: 273/625, Loss: 1.65, Accuracy: 37%
Epoch: 4/30, Batch: 274/625, Loss: 1.63, Accuracy: 37%
Epoch: 4/30, Batch: 275/625, Loss: 1.63, Accuracy: 37%
Epoch: 4/30, Batch: 276/625, Loss: 1.59, Accuracy: 43%
Epoch: 4/30, Batch: 277/625, Loss: 1.53, Accuracy: 50%
Epoch: 4/30, Batch: 278/625, Loss: 1.63, Accuracy: 31%
Epoch: 4/30, Batch: 279/625, Loss: 1.77, Accuracy: 25%
Epoch: 4/30, Batch: 280/625, Loss: 1.64, Accuracy: 37%
Epoch: 4/30, Batch: 281/625, Loss: 1.52, Accuracy: 56%
Epoch: 4/30,

Epoch: 4/30, Batch: 418/625, Loss: 1.8, Accuracy: 12%
Epoch: 4/30, Batch: 419/625, Loss: 1.69, Accuracy: 50%
Epoch: 4/30, Batch: 420/625, Loss: 1.5, Accuracy: 50%
Epoch: 4/30, Batch: 421/625, Loss: 1.6, Accuracy: 37%
Epoch: 4/30, Batch: 422/625, Loss: 1.57, Accuracy: 43%
Epoch: 4/30, Batch: 423/625, Loss: 1.45, Accuracy: 43%
Epoch: 4/30, Batch: 424/625, Loss: 1.58, Accuracy: 50%
Epoch: 4/30, Batch: 425/625, Loss: 1.85, Accuracy: 12%
Epoch: 4/30, Batch: 426/625, Loss: 1.59, Accuracy: 37%
Epoch: 4/30, Batch: 427/625, Loss: 1.7, Accuracy: 43%
Epoch: 4/30, Batch: 428/625, Loss: 1.78, Accuracy: 25%
Epoch: 4/30, Batch: 429/625, Loss: 1.63, Accuracy: 43%
Epoch: 4/30, Batch: 430/625, Loss: 1.57, Accuracy: 37%
Epoch: 4/30, Batch: 431/625, Loss: 1.61, Accuracy: 50%
Epoch: 4/30, Batch: 432/625, Loss: 1.44, Accuracy: 56%
Epoch: 4/30, Batch: 433/625, Loss: 1.69, Accuracy: 25%
Epoch: 4/30, Batch: 434/625, Loss: 1.65, Accuracy: 43%
Epoch: 4/30, Batch: 435/625, Loss: 1.69, Accuracy: 37%
Epoch: 4/30, B

Epoch: 4/30, Batch: 572/625, Loss: 1.62, Accuracy: 37%
Epoch: 4/30, Batch: 573/625, Loss: 1.67, Accuracy: 31%
Epoch: 4/30, Batch: 574/625, Loss: 1.65, Accuracy: 31%
Epoch: 4/30, Batch: 575/625, Loss: 1.58, Accuracy: 62%
Epoch: 4/30, Batch: 576/625, Loss: 1.64, Accuracy: 31%
Epoch: 4/30, Batch: 577/625, Loss: 1.74, Accuracy: 31%
Epoch: 4/30, Batch: 578/625, Loss: 1.64, Accuracy: 56%
Epoch: 4/30, Batch: 579/625, Loss: 1.53, Accuracy: 50%
Epoch: 4/30, Batch: 580/625, Loss: 1.63, Accuracy: 50%
Epoch: 4/30, Batch: 581/625, Loss: 1.82, Accuracy: 12%
Epoch: 4/30, Batch: 582/625, Loss: 1.67, Accuracy: 43%
Epoch: 4/30, Batch: 583/625, Loss: 1.56, Accuracy: 31%
Epoch: 4/30, Batch: 584/625, Loss: 1.67, Accuracy: 43%
Epoch: 4/30, Batch: 585/625, Loss: 1.53, Accuracy: 56%
Epoch: 4/30, Batch: 586/625, Loss: 1.63, Accuracy: 37%
Epoch: 4/30, Batch: 587/625, Loss: 1.71, Accuracy: 25%
Epoch: 4/30, Batch: 588/625, Loss: 1.58, Accuracy: 56%
Epoch: 4/30, Batch: 589/625, Loss: 1.63, Accuracy: 31%
Epoch: 4/3

Epoch: 5/30, Batch: 102/625, Loss: 1.74, Accuracy: 37%
Epoch: 5/30, Batch: 103/625, Loss: 1.71, Accuracy: 50%
Epoch: 5/30, Batch: 104/625, Loss: 1.51, Accuracy: 56%
Epoch: 5/30, Batch: 105/625, Loss: 1.67, Accuracy: 31%
Epoch: 5/30, Batch: 106/625, Loss: 1.64, Accuracy: 43%
Epoch: 5/30, Batch: 107/625, Loss: 1.56, Accuracy: 56%
Epoch: 5/30, Batch: 108/625, Loss: 1.62, Accuracy: 37%
Epoch: 5/30, Batch: 109/625, Loss: 1.67, Accuracy: 37%
Epoch: 5/30, Batch: 110/625, Loss: 1.66, Accuracy: 50%
Epoch: 5/30, Batch: 111/625, Loss: 1.57, Accuracy: 50%
Epoch: 5/30, Batch: 112/625, Loss: 1.55, Accuracy: 56%
Epoch: 5/30, Batch: 113/625, Loss: 1.47, Accuracy: 56%
Epoch: 5/30, Batch: 114/625, Loss: 1.69, Accuracy: 25%
Epoch: 5/30, Batch: 115/625, Loss: 1.65, Accuracy: 50%
Epoch: 5/30, Batch: 116/625, Loss: 1.62, Accuracy: 37%
Epoch: 5/30, Batch: 117/625, Loss: 1.64, Accuracy: 43%
Epoch: 5/30, Batch: 118/625, Loss: 1.49, Accuracy: 56%
Epoch: 5/30, Batch: 119/625, Loss: 1.67, Accuracy: 31%
Epoch: 5/3

# 5. Test du modèle

In [340]:
# Importation des données
test_texts, test_emotions = open_file("test.txt")
test_texts_id, test_emotions_id = forward_vocab(test_texts, test_emotions, text_vocab, emotion_vocab)


In [385]:
for i in range(test_texts_id.size()[0]):
    sentence = test_texts_id[i, :]
    sentence_one_hot = torch.nn.functional.one_hot(sentence,  len(text_vocab))
    output = forward_model(rnn, sentence_one_hot[None, ])
    if(i< 10):
        print(output)

tensor([[1., 0., 0., 0., 0., 0.]], grad_fn=<SoftmaxBackward0>)
tensor([[1., 0., 0., 0., 0., 0.]], grad_fn=<SoftmaxBackward0>)
tensor([[1., 0., 0., 0., 0., 0.]], grad_fn=<SoftmaxBackward0>)
tensor([[1., 0., 0., 0., 0., 0.]], grad_fn=<SoftmaxBackward0>)
tensor([[1., 0., 0., 0., 0., 0.]], grad_fn=<SoftmaxBackward0>)
tensor([[1., 0., 0., 0., 0., 0.]], grad_fn=<SoftmaxBackward0>)
tensor([[1., 0., 0., 0., 0., 0.]], grad_fn=<SoftmaxBackward0>)
tensor([[1., 0., 0., 0., 0., 0.]], grad_fn=<SoftmaxBackward0>)
tensor([[1., 0., 0., 0., 0., 0.]], grad_fn=<SoftmaxBackward0>)
tensor([[1., 0., 0., 0., 0., 0.]], grad_fn=<SoftmaxBackward0>)
