In [1]:
import torch
from torch import nn
from torch.nn import NLLLoss
from torch.optim import Adam, SparseAdam
import numpy as np
import pickle
from icecream import ic
from tqdm.notebook import tqdm
from sklearn.metrics import f1_score
# Use the first visible CUDA GPU when there is one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class AMORE(nn.Module):
    """BiLSTM speaker tagger that scores each token against the speaker-embedding table.

    Each token's input is its word vector concatenated with the sum of the
    embeddings of the speaker ids attached to it. A bidirectional LSTM encodes
    the (unbatched) sequence, a linear layer projects every hidden state back
    into the speaker-embedding space, and the score for each class is the
    cosine similarity between that projection and every row of the
    speaker-embedding table (the padding row included). Scores go through
    ``LogSoftmax`` so the output pairs with ``NLLLoss``.

    Every constructor argument defaults to the value this notebook originally
    hard-coded, so ``AMORE()`` builds exactly the same model as before and the
    layer attribute names (and therefore ``state_dict`` keys) are unchanged.

    Args:
        word_input_size: length of each word vector in ``x[0]``.
        speaker_input_size: number of real speaker ids; id
            ``speaker_input_size`` is reserved as the embedding padding index.
        speaker_embedding_size: dimension of the speaker embeddings and of the
            classifier output that is compared against them.
        lstm_hidden_size: hidden size of each LSTM direction.
        embed_dropout: dropout probability on the concatenated input features.
        lstm_dropout: dropout probability on the LSTM outputs.
        logit_scale: multiplier applied to the cosine similarities before the
            log-softmax. Cosine similarity lies in [-1, 1], so at the default
            of 1.0 the logits can differ by at most 2; with 66 classes the
            per-token NLL can then never drop below ln(1 + 65 * e**-2) ~= 2.28.
            A larger scale (as in normalized-softmax / CosFace-style
            classifiers) lets the model express confident predictions.
            NOTE(review): the logged run plateaus near this floor with a flat
            validation F1 -- worth trying a larger value.
    """

    def __init__(
        self,
        word_input_size=300,
        speaker_input_size=65,
        speaker_embedding_size=134,
        lstm_hidden_size=459,
        embed_dropout=8e-3,
        lstm_dropout=13e-4,
        logit_scale=1.0,
    ):
        super().__init__()
        self.epochs = 0  # running count of training epochs, bumped by the training loop
        self.word_input_size = word_input_size
        self.speaker_input_size = speaker_input_size
        self.speaker_embedding_size = speaker_embedding_size
        self.lstm_hidden_size = lstm_hidden_size
        self.logit_scale = logit_scale

        # One extra row: index `speaker_input_size` is the padding id (zero vector, no gradient).
        self.SpeakerEmbed = nn.Embedding(
            speaker_input_size + 1, speaker_embedding_size, padding_idx=speaker_input_size
        )
        self.Dropout1 = nn.Dropout(embed_dropout)
        self.Activation = nn.Tanh()
        self.LSTM = nn.LSTM(
            input_size=word_input_size + speaker_embedding_size,
            hidden_size=lstm_hidden_size,
            bidirectional=True,
        )
        self.Dropout2 = nn.Dropout(lstm_dropout)
        # Bidirectional -> 2 * hidden features in, projected into speaker-embedding space.
        self.Classifier = nn.Linear(2 * lstm_hidden_size, speaker_embedding_size)
        self.Cos = nn.CosineSimilarity(dim=2)
        # Named "Softmax" historically, but it is a LogSoftmax (required by NLLLoss).
        self.Softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return log-probabilities over speaker-embedding rows for every token.

        Args:
            x: pair ``(word_items, speaker_ids)``. Each element ``item`` of
                ``word_items`` holds the token's 1-D word vector (length
                ``word_input_size``) at ``item[0]``. ``speaker_ids`` must
                convert to a rectangular integer tensor of shape
                ``(seq_len, k)``; the ``k`` speaker embeddings of a token are
                summed.

        Returns:
            Tensor of shape ``(seq_len, speaker_input_size + 1)``.
        """
        # Follow the module's own parameters rather than a notebook-global
        # device, so this works after model.to(...) / model.cpu().
        dev = self.SpeakerEmbed.weight.device
        words = torch.stack([item[0] for item in x[0]]).to(dev)
        speaker_ids = torch.LongTensor(x[1]).to(dev)
        speak_embed = self.SpeakerEmbed(speaker_ids).sum(dim=1)

        embed = torch.cat((words, speak_embed), dim=1)
        embed = self.Activation(self.Dropout1(embed))
        h1, _ = self.LSTM(embed)  # 2-D input -> unbatched LSTM over the sequence
        h1 = self.Dropout2(h1)
        e1 = self.Classifier(h1)

        # (seq, 1, E) against (C, E) broadcasts to (seq, C, E); cosine over E -> (seq, C).
        o1 = self.Cos(e1.unsqueeze(1), self.SpeakerEmbed.weight)
        return self.Softmax(self.logit_scale * o1)
        

In [3]:
def _load_split(path):
    """Unpickle and return one pre-padded data split stored at ``path``.

    NOTE(review): pickle.load can execute arbitrary code -- only point this at
    files produced by this project's own preprocessing.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


train_loader = _load_split("../output/train_padded.pkl")
test_loader = _load_split("../output/test_padded.pkl")
val_loader = _load_split("../output/val_padded.pkl")

In [4]:
# model initialization

# Seed torch's RNGs (CPU and CUDA) so weight init and dropout masks are reproducible.
SEED = 0
torch.manual_seed(SEED)

model = AMORE()
model.to(device)
# NLLLoss expects log-probabilities, which AMORE.forward returns via LogSoftmax.
loss_func = NLLLoss()
optimizer = Adam(model.parameters(), lr=5e-4)

In [5]:
# training loop
# Trains for `num_epochs` more epochs. model.epochs keeps a running total, so
# re-running this cell continues training rather than restarting it. After
# each epoch the model is scored on val_loader (mean loss + macro-F1).

num_epochs = 50
for _ in range(num_epochs):
    model.train()
    model.epochs += 1
    train_loss = 0
    print("Epoch:", model.epochs)
    for sent in train_loader:
        model.zero_grad()
        X, Y = sent
        y = model(X)
        loss = loss_func(y, torch.LongTensor(Y).to(device))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)
    print("Train Loss:", train_loss)

    # validation
    model.eval()
    val_loss = 0
    targets = []  # kept as a notebook global: a later cell inspects targets[0]
    preds = []
    # Evaluation needs no gradients; no_grad avoids building autograd graphs.
    with torch.no_grad():
        for sent in val_loader:
            X, Y = sent
            targets += [i.tolist() for i in Y]
            y = model(X)
            # .tolist() yields plain ints instead of a list of 0-d CPU tensors.
            preds += torch.argmax(y, dim=1).tolist()
            loss = loss_func(y, torch.LongTensor(Y).to(device))
            val_loss += loss.item()
    val_loss /= len(val_loader)
    ic("Validation Loss:", val_loss)
    ic("F1:", f1_score(targets, preds, average='macro'))

Epoch: 1
Train Loss: 3.238766387462616


ic| 'Validation Loss:', val_loss: 3.1485816662128157
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 2
Train Loss: 3.0604518032073975


ic| 'Validation Loss:', val_loss: 2.9787194435413067
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 3
Train Loss: 2.9093290820121767


ic| 'Validation Loss:', val_loss: 2.8482879638671874
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 4
Train Loss: 2.7994093880653383


ic| 'Validation Loss:', val_loss: 2.7562070259681115
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 5
Train Loss: 2.7213806071281432


ic| 'Validation Loss:', val_loss: 2.6897900654719424
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 6
Train Loss: 2.6650792541503905


ic| 'Validation Loss:', val_loss: 2.641460290321937
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 7
Train Loss: 2.623841689109802


ic| 'Validation Loss:', val_loss: 2.605931050960834
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 8
Train Loss: 2.593778124332428


ic| 'Validation Loss:', val_loss: 2.5801412178919865
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 9
Train Loss: 2.572117781639099


ic| 'Validation Loss:', val_loss: 2.5617798034961408
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 10
Train Loss: 2.5568638896942137


ic| 'Validation Loss:', val_loss: 2.548923892241258
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 11
Train Loss: 2.5463603281974794


ic| 'Validation Loss:', val_loss: 2.540395711018489
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 12
Train Loss: 2.5391981711387634


ic| 'Validation Loss:', val_loss: 2.534191898199228
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 13
Train Loss: 2.5344146490097046


ic| 'Validation Loss:', val_loss: 2.530274823995737
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 14
Train Loss: 2.5311341638565064


ic| 'Validation Loss:', val_loss: 2.5274132325099066
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 15
Train Loss: 2.5288338141441344


ic| 'Validation Loss:', val_loss: 2.5253447532653808
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 16
Train Loss: 2.527147609233856


ic| 'Validation Loss:', val_loss: 2.5238500705132116
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 17
Train Loss: 2.525817913532257


ic| 'Validation Loss:', val_loss: 2.522659202722403
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 18
Train Loss: 2.52478932094574


ic| 'Validation Loss:', val_loss: 2.521511833484356
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 19
Train Loss: 2.5239252614974976


ic| 'Validation Loss:', val_loss: 2.5206856764279877
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 20
Train Loss: 2.523239028930664


ic| 'Validation Loss:', val_loss: 2.520109561773447
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 21
Train Loss: 2.522670226097107


ic| 'Validation Loss:', val_loss: 2.519399037727943
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 22
Train Loss: 2.522194649219513


ic| 'Validation Loss:', val_loss: 2.5189257108248198
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 23
Train Loss: 2.5218241395950316


ic| 'Validation Loss:', val_loss: 2.518518726642315
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 24
Train Loss: 2.521494788646698


ic| 'Validation Loss:', val_loss: 2.5181983727675217
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 25
Train Loss: 2.5212567162513735


ic| 'Validation Loss:', val_loss: 2.517916639034565
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 26
Train Loss: 2.52105970621109


ic| 'Validation Loss:', val_loss: 2.5176520641033466
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 27
Train Loss: 2.520910499095917


ic| 'Validation Loss:', val_loss: 2.5174741378197303
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 28
Train Loss: 2.5207827134132383


ic| 'Validation Loss:', val_loss: 2.517328610787025
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 29
Train Loss: 2.520685447216034


ic| 'Validation Loss:', val_loss: 2.517168851999136
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 30
Train Loss: 2.5206080226898195


ic| 'Validation Loss:', val_loss: 2.517084162051861
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 31
Train Loss: 2.520549430370331


ic| 'Validation Loss:', val_loss: 2.516924869097196
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 32
Train Loss: 2.5204973487854003


ic| 'Validation Loss:', val_loss: 2.5168614864349363
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 33
Train Loss: 2.5204399604797363


ic| 'Validation Loss:', val_loss: 2.516773359592144
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 34
Train Loss: 2.520409378051758


ic| 'Validation Loss:', val_loss: 2.516688376206618
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 35
Train Loss: 2.5203801727294923


ic| 'Validation Loss:', val_loss: 2.5166514286628137
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 36
Train Loss: 2.5203503470420836


ic| 'Validation Loss:', val_loss: 2.5166221398573656
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 37
Train Loss: 2.5203251123428343


ic| 'Validation Loss:', val_loss: 2.5165519897754374
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 38
Train Loss: 2.520300879955292


ic| 'Validation Loss:', val_loss: 2.5164950700906608
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 39
Train Loss: 2.5202920541763305


ic| 'Validation Loss:', val_loss: 2.516451116708609
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 40
Train Loss: 2.5202840909957884


ic| 'Validation Loss:', val_loss: 2.5164192309746376
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 41
Train Loss: 2.52026181268692


ic| 'Validation Loss:', val_loss: 2.5163862998668964
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 42
Train Loss: 2.5202559123039245


ic| 'Validation Loss:', val_loss: 2.5163505737598126
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 43
Train Loss: 2.520241982936859


ic| 'Validation Loss:', val_loss: 2.5163071338947
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 44
Train Loss: 2.520229236602783


ic| 'Validation Loss:', val_loss: 2.516283207673293
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 45
Train Loss: 2.520224304199219


ic| 'Validation Loss:', val_loss: 2.516253922535823
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 46
Train Loss: 2.520217573642731


ic| 'Validation Loss:', val_loss: 2.516247606277466
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 47
Train Loss: 2.520214339256287


ic| 'Validation Loss:', val_loss: 2.516212558746338
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 48
Train Loss: 2.520199770450592


ic| 'Validation Loss:', val_loss: 2.516203212738037
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 49
Train Loss: 2.5202016887664795


ic| 'Validation Loss:', val_loss: 2.516163206100464
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


Epoch: 50
Train Loss: 2.520198531627655


ic| 'Validation Loss:', val_loss: 2.5161672922281118
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.03127870264064294


In [6]:
# Peek at the first label gathered by the training cell's last validation pass.
# NOTE(review): relies on `targets` left over from that cell (hidden state) --
# it only shows validation labels when run right after training.
targets[0]

[0]

In [7]:
# testing loop
# Scores the trained model on test_loader: mean per-batch loss and macro-F1.
model.eval()
test_loss = 0
targets = []
preds = []
# Evaluation needs no gradients; no_grad avoids building autograd graphs.
with torch.no_grad():
    for sent in test_loader:
        X, Y = sent
        targets += [i.tolist() for i in Y]
        y = model(X)
        # .tolist() yields plain ints instead of a list of 0-d CPU tensors.
        preds += torch.argmax(y, dim=1).tolist()
        loss = loss_func(y, torch.LongTensor(Y).to(device))
        test_loss += loss.item()
test_loss /= len(test_loader)
ic("Test Loss:", test_loss)
ic("F1:", f1_score(targets, preds, average='macro'))

ic| 'Test Loss:', test_loss: 2.516904660633632
ic| "F1:": 'F1:'
    f1_score(targets, preds, average='macro'): 0.020804552418472313


('F1:', 0.020804552418472313)