In [44]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [45]:
import pandas as pd
import numpy as np

from scipy.stats import wasserstein_distance, pearsonr, spearmanr

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences

import string
from nltk import word_tokenize
import nltk
nltk.download('punkt')

from tqdm import tqdm

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [46]:
from torchtext import data
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [47]:
from torchtext.vocab import FastText
vectors = FastText('simple')

In [48]:
BATCH_SIZE = 128
EPOCHS = 100
LR = 0.0005
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [49]:
train_path = '/content/drive/MyDrive/emotion detection/train/'
test_path = '/content/drive/MyDrive/emotion detection/test/'
train_csv = 'train_df.csv'
test_csv = 'test_df.csv'

In [50]:
train_df = pd.read_csv(train_path + train_csv, index_col='index')
test_df = pd.read_csv(test_path + test_csv, index_col='index')

In [51]:
target_cols = train_df.columns[-6:]
train_df_n = train_df.copy()
train_df_n[target_cols] = train_df_n[target_cols].apply(lambda it: it/it.sum() if it.sum() else 0, axis=1)
train_df_n['target'] = train_df_n[target_cols].apply(lambda it: list(it), axis=1)
test_df_n = test_df.copy()
test_df_n[target_cols] = test_df_n[target_cols].apply(lambda it: it/it.sum() if it.sum() else 0, axis=1)
test_df_n['target'] = test_df_n[target_cols].apply(lambda it: list(it), axis=1)
train_df_n.head()

Unnamed: 0_level_0,text,0,1,2,3,4,5,target
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
1,Mortar assault leaves at least 18 dead,0.148649,0.013514,0.405405,0.0,0.432432,0.0,"[0.14864864864864866, 0.013513513513513514, 0...."
2,Goal delight for Sheva,0.0,0.0,0.0,0.709924,0.0,0.290076,"[0.0, 0.0, 0.0, 0.7099236641221374, 0.0, 0.290..."
3,Nigeria hostage feared dead is freed,0.081448,0.0,0.235294,0.298643,0.090498,0.294118,"[0.08144796380090498, 0.0, 0.23529411764705882..."
4,Bombers kill shoppers,0.231579,0.136842,0.329825,0.0,0.301754,0.0,"[0.23157894736842105, 0.1368421052631579, 0.32..."
5,"Vegetables, not fruit, slow brain decline",0.0,0.0,0.252525,0.262626,0.020202,0.464646,"[0.0, 0.0, 0.25252525252525254, 0.262626262626..."


In [52]:
train_input = []
train_target = []
test_input = []
test_target = []
# preprocess
for i in train_df_n.index:
  t = train_df_n.loc[i, 'text'].lower()
  t = "".join([char for char in t if char not in string.punctuation])
  train_input.append(word_tokenize(t))
  train_target.append(list(train_df_n.loc[i, target_cols]))

for i in test_df_n.index:
  t = test_df_n.loc[i, 'text'].lower()
  t = "".join([char for char in t if char not in string.punctuation])
  test_input.append(word_tokenize(t))
  test_target.append(list(test_df_n.loc[i, target_cols]))

In [53]:
max_features = 100000

tk = Tokenizer(lower = True, filters='', num_words=max_features)
full_text = train_input + test_input
# full_text = list(test_df_n['text'].values) + list(test_df_n['text'].values)
tk.fit_on_texts(full_text)

In [54]:
word2idx = tk.get_config()['word_index']

In [55]:
word2idx = eval(word2idx)
max(word2idx.values())

3457

In [56]:
def load_glove(path):
    """
    creates a dictionary mapping words to vectors from a file in glove format.
    """
    with open(path) as f:
        glove = {}
        for line in f.readlines():
            values = line.split()
            word = values[0]
            vector = np.array(values[1:], dtype='float32')
            glove[word] = vector
        return glove
def load_glove_embeddings(path, word2idx, embedding_dim=100):
    with open(path) as f:
        embeddings = np.zeros((max(word2idx.values()) + 1, embedding_dim))
        for line in f.readlines():
            values = line.split()
            word = values[0]
            index = word2idx.get(word)
            if index:
                vector = np.array(values[1:], dtype='float32')
                embeddings[index] = vector
        return torch.from_numpy(embeddings).float()

In [57]:
train_tokenized = tk.texts_to_sequences(train_df_n['text'].fillna('missing'))
test_tokenized = tk.texts_to_sequences(test_df_n['text'].fillna('missing'))

max_len = 72
maxlen = 72
X_train = torch.tensor(pad_sequences(train_tokenized, maxlen = max_len), dtype=torch.long)
X_test = torch.tensor(pad_sequences(test_tokenized, maxlen = max_len), dtype=torch.long)

In [58]:
y_train = torch.tensor(train_df_n[target_cols].values, dtype=torch.float32)
y_test = torch.tensor(test_df_n[target_cols].values, dtype=torch.float32)

In [59]:
X_train

tensor([[   0,    0,    0,  ...,  676,  677,   50],
        [   0,    0,    0,  ..., 1244,    3, 1245],
        [   0,    0,    0,  ...,   50,   11,  282],
        ...,
        [   0,    0,    0,  ...,  285,  142,  436],
        [   0,    0,    0,  ...,   48,  180,  174],
        [   0,    0,    0,  ...,   43,   73, 1686]])

In [60]:
X_train.shape, y_train.shape

(torch.Size([250, 72]), torch.Size([250, 6]))

In [61]:
train = torch.utils.data.TensorDataset(X_train, y_train)
test = torch.utils.data.TensorDataset(X_test, y_test)
    
train_loader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test, batch_size=1000, shuffle=False)

In [62]:
# Bi-LSTM(Attention) Parameters
embedding_dim = 100
n_hidden = 100
num_classes = len(train_target[0])
words = set([c for sublist in train_input+test_input for c in sublist])
vocab_size = len(words)
classes = 6

In [100]:
glove_twitter_path = '/content/drive/MyDrive/emotion detection/glove.twitter.27B.100d.txt'
glove_wiki_path = '/content/drive/MyDrive/emotion detection/glove.6B.100d.txt'
glove = load_glove(glove_twitter_path)

In [101]:
toy_embeddings = load_glove_embeddings(glove_wiki_path, word2idx)
toy_embeddings.requires_grad = False

In [102]:
class AttentionModel(torch.nn.Module):
	def __init__(self, weights = None):
		super(AttentionModel, self).__init__()
		"""
		Arguments
		---------
		batch_size : Size of the batch which is same as the batch_size of the data returned by the TorchText BucketIterator
		output_size : 2 = (pos, neg)
		hidden_sie : Size of the hidden_state of the LSTM
		vocab_size : Size of the vocabulary containing unique words
		embedding_length : Embeddding dimension of GloVe word embeddings
		weights : Pre-trained GloVe word_embeddings which we will use to create our word_embedding look-up table 
		
		--------
		
		"""
		self.batch_size = BATCH_SIZE
		self.output_size = 6
		self.hidden_size = n_hidden
		self.vocab_size = vocab_size
		self.embedding_length = embedding_dim
		self.word_embeddings = nn.Embedding(vocab_size, self.embedding_length).from_pretrained(weights)
		# self.word_embeddings.weights = nn.Parameter(weights, requires_grad=False)
		self.lstm = nn.LSTM(self.embedding_length, self.hidden_size, bidirectional=True, dropout=0.2)
		self.label = nn.Linear(self.hidden_size*2, self.output_size)
		self.dropout1 = nn.Dropout(0)
		self.dropout2 = nn.Dropout(0.5)
 
		
	def attention_net(self, lstm_output, final_state):
		""" 
		Now we will incorporate Attention mechanism in our LSTM model. In this new model, we will use attention to compute soft alignment score corresponding
		between each of the hidden_state and the last hidden_state of the LSTM. We will be using torch.bmm for the batch matrix multiplication.
		
		Arguments
		---------
		
		lstm_output : Final output of the LSTM which contains hidden layer outputs for each sequence.
		final_state : Final time-step hidden state (h_n) of the LSTM
		
		---------
		
		Returns : It performs attention mechanism by first computing weights for each of the sequence present in lstm_output and and then finally computing the
				  new hidden state.
				  
		Tensor Size :
					hidden.size() = (batch_size, hidden_size)
					attn_weights.size() = (batch_size, num_seq)
					soft_attn_weights.size() = (batch_size, num_seq)
					new_hidden_state.size() = (batch_size, hidden_size)
					  
		"""
		hidden = final_state.view(-1, n_hidden*2, 1)
		attn_weights = torch.bmm(lstm_output, hidden).squeeze(2)
		soft_attn_weights = F.softmax(attn_weights, 1)
		new_hidden_state = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)
		
		return new_hidden_state
	
	def forward(self, input_sentences, batch_size=None):
		""" 
		Parameters
		----------
		input_sentence: input_sentence of shape = (batch_size, num_sequences)
		batch_size : default = None. Used only for prediction on a single sentence after training (batch_size = 1)
		
		Returns
		-------
		Output of the linear layer containing logits for pos & neg class which receives its input as the new_hidden_state which is basically the output of the Attention network.
		final_output.shape = (batch_size, output_size)
		
		"""
		input = self.word_embeddings(input_sentences)
		input = input.permute(1, 0, 2)
		if batch_size is None:
			h_0 = Variable(torch.zeros(2, self.batch_size, self.hidden_size))
			c_0 = Variable(torch.zeros(2, self.batch_size, self.hidden_size))
		else:
			h_0 = Variable(torch.zeros(2, batch_size, self.hidden_size))
			c_0 = Variable(torch.zeros(2, batch_size, self.hidden_size))
		output, (final_hidden_state, final_cell_state) = self.lstm(input, (h_0, c_0))
		output = output.permute(1, 0, 2)
		attn_output = self.attention_net(self.dropout1(output), final_hidden_state)
		logits = self.label(self.dropout2(attn_output))
		return logits
		# return torch.nn.functional.softmax(logits)

In [103]:
# from torchtext.vocab import FastText
# ft_embedding = FastText('simple')
# weights = torch.FloatTensor(ft_embedding.vectors)

In [104]:
model = AttentionModel(nn.Parameter(toy_embeddings, requires_grad=False))
# model.word_embeddings.weight = nn.Parameter(toy_embeddings, requires_grad=False)
optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=0.001)
criterion = nn.MSELoss()

  "num_layers={}".format(dropout, num_layers))


In [105]:
best_param = 0
best_loss = 0
best_metric_param = 0
best_metric = 0
for e in range(EPOCHS):
    for i, (texts, labels) in tqdm(enumerate(train_loader)):
      optimizer.zero_grad()
      pred = model(texts, batch_size=texts.shape[0])
      loss = criterion(pred, labels)
      loss.backward()
      optimizer.step()
    with torch.no_grad():
        for i, (texts, labels) in tqdm(enumerate(test_loader)):
          pred = model(texts, batch_size=1000)
          loss = criterion(pred, labels)
          RMSED = np.mean([np.sqrt(criterion(pred[i], labels[i])) for i in range(1000)])
          if best_param == 0 or (loss.item() < best_loss):
              best_param = model.state_dict()
              best_loss = loss.item()
          if best_metric_param == 0 or (RMSED < best_metric):
              best_metric_param = model.state_dict()
              best_metric = RMSED
          if (e % 2 == 0):
            print(f'epoch {e}: loss: {loss.item()}, RMSED: {RMSED}')

2it [00:00,  3.23it/s]
1it [00:00,  1.04it/s]


epoch 0: loss: 0.06463617086410522, RMSED: 0.2500159442424774


2it [00:00,  2.85it/s]
1it [00:00,  1.03it/s]
2it [00:00,  2.92it/s]
1it [00:01,  1.02s/it]


epoch 2: loss: 0.05602327361702919, RMSED: 0.23151318728923798


2it [00:00,  2.85it/s]
1it [00:01,  1.01s/it]
2it [00:00,  3.29it/s]
1it [00:00,  1.08it/s]


epoch 4: loss: 0.048512063920497894, RMSED: 0.2136252224445343


2it [00:00,  3.28it/s]
1it [00:00,  1.05it/s]
2it [00:00,  2.90it/s]
1it [00:00,  1.05it/s]


epoch 6: loss: 0.04310239106416702, RMSED: 0.19991110265254974


2it [00:00,  3.39it/s]
1it [00:00,  1.05it/s]
2it [00:00,  2.94it/s]
1it [00:01,  1.00s/it]


epoch 8: loss: 0.040205344557762146, RMSED: 0.19158005714416504


2it [00:00,  3.27it/s]
1it [00:00,  1.00it/s]
2it [00:00,  2.93it/s]
1it [00:00,  1.02it/s]


epoch 10: loss: 0.03894900903105736, RMSED: 0.1883058249950409


2it [00:00,  3.26it/s]
1it [00:01,  1.01s/it]
2it [00:00,  2.89it/s]
1it [00:01,  1.02s/it]


epoch 12: loss: 0.039757028222084045, RMSED: 0.19072064757347107


2it [00:00,  2.99it/s]
1it [00:00,  1.07it/s]
2it [00:00,  3.37it/s]
1it [00:00,  1.08it/s]


epoch 14: loss: 0.03882193937897682, RMSED: 0.18872299790382385


2it [00:00,  3.29it/s]
1it [00:01,  1.25s/it]
2it [00:00,  2.71it/s]
1it [00:01,  1.05s/it]


epoch 16: loss: 0.03861618414521217, RMSED: 0.18811435997486115


2it [00:00,  2.19it/s]
1it [00:02,  2.22s/it]
2it [00:02,  1.03s/it]
1it [00:02,  2.57s/it]


epoch 18: loss: 0.038888223469257355, RMSED: 0.1888512820005417


2it [00:00,  2.10it/s]
1it [00:01,  1.35s/it]
2it [00:00,  3.26it/s]
1it [00:01,  1.54s/it]


epoch 20: loss: 0.039002664387226105, RMSED: 0.18922030925750732


2it [00:00,  3.34it/s]
1it [00:01,  1.02s/it]
2it [00:00,  3.33it/s]
1it [00:00,  1.00it/s]


epoch 22: loss: 0.03910086676478386, RMSED: 0.18945574760437012


2it [00:01,  1.39it/s]
1it [00:01,  1.44s/it]
2it [00:00,  2.90it/s]
1it [00:01,  1.00s/it]


epoch 24: loss: 0.038663942366838455, RMSED: 0.18837299942970276


2it [00:00,  3.24it/s]
1it [00:00,  1.07it/s]
2it [00:00,  3.07it/s]
1it [00:01,  1.02s/it]


epoch 26: loss: 0.03864944353699684, RMSED: 0.18823929131031036


2it [00:00,  2.70it/s]
1it [00:01,  1.03s/it]
2it [00:00,  3.23it/s]
1it [00:00,  1.08it/s]


epoch 28: loss: 0.03849542513489723, RMSED: 0.18782010674476624


2it [00:00,  3.16it/s]
1it [00:01,  1.00s/it]
2it [00:00,  3.22it/s]
1it [00:00,  1.06it/s]


epoch 30: loss: 0.038697078824043274, RMSED: 0.18821027874946594


2it [00:00,  2.95it/s]
1it [00:01,  1.02s/it]
2it [00:00,  3.05it/s]
1it [00:00,  1.04it/s]


epoch 32: loss: 0.03830481693148613, RMSED: 0.18731771409511566


2it [00:00,  2.79it/s]
1it [00:01,  1.04s/it]
2it [00:00,  3.16it/s]
1it [00:01,  1.00s/it]


epoch 34: loss: 0.038081347942352295, RMSED: 0.18658718466758728


2it [00:00,  2.87it/s]
1it [00:01,  1.03s/it]
2it [00:00,  2.79it/s]
1it [00:00,  1.06it/s]


epoch 36: loss: 0.03818010166287422, RMSED: 0.18677353858947754


2it [00:00,  3.26it/s]
1it [00:00,  1.00it/s]
2it [00:00,  2.90it/s]
1it [00:01,  1.02s/it]


epoch 38: loss: 0.03796951845288277, RMSED: 0.18662549555301666


2it [00:00,  2.93it/s]
1it [00:01,  1.01s/it]
2it [00:00,  2.92it/s]
1it [00:01,  1.01s/it]


epoch 40: loss: 0.03776027262210846, RMSED: 0.18599772453308105


2it [00:00,  2.88it/s]
1it [00:01,  1.03s/it]
2it [00:00,  2.87it/s]
1it [00:01,  1.01s/it]


epoch 42: loss: 0.0378168523311615, RMSED: 0.1860327124595642


2it [00:00,  2.84it/s]
1it [00:01,  1.01s/it]
2it [00:00,  2.96it/s]
1it [00:00,  1.01it/s]


epoch 44: loss: 0.037792764604091644, RMSED: 0.1861218810081482


2it [00:00,  3.28it/s]
1it [00:00,  1.05it/s]
2it [00:00,  3.34it/s]
1it [00:00,  1.03it/s]


epoch 46: loss: 0.0376753956079483, RMSED: 0.18576352298259735


2it [00:00,  2.86it/s]
1it [00:01,  1.00s/it]
2it [00:00,  3.28it/s]
1it [00:00,  1.07it/s]


epoch 48: loss: 0.03750083968043327, RMSED: 0.18528996407985687


2it [00:00,  3.37it/s]
1it [00:00,  1.04it/s]
2it [00:00,  2.80it/s]
1it [00:01,  1.05s/it]


epoch 50: loss: 0.0371403768658638, RMSED: 0.184238001704216


2it [00:00,  2.76it/s]
1it [00:00,  1.01it/s]
2it [00:00,  3.26it/s]
1it [00:00,  1.08it/s]


epoch 52: loss: 0.03703632578253746, RMSED: 0.1837502419948578


2it [00:00,  3.38it/s]
1it [00:00,  1.00it/s]
2it [00:01,  1.83it/s]
1it [00:01,  1.47s/it]


epoch 54: loss: 0.03663594275712967, RMSED: 0.1830284595489502


2it [00:00,  2.01it/s]
1it [00:01,  1.88s/it]
2it [00:01,  1.35it/s]
1it [00:01,  1.80s/it]


epoch 56: loss: 0.036513689905405045, RMSED: 0.18274277448654175


2it [00:01,  1.67it/s]
1it [00:00,  1.08it/s]
2it [00:00,  3.22it/s]
1it [00:01,  1.76s/it]


epoch 58: loss: 0.03632550686597824, RMSED: 0.18221347033977509


2it [00:01,  1.82it/s]
1it [00:01,  1.32s/it]
2it [00:01,  1.96it/s]
1it [00:01,  1.70s/it]


epoch 60: loss: 0.03580443561077118, RMSED: 0.18079394102096558


2it [00:00,  2.33it/s]
1it [00:01,  1.65s/it]
2it [00:00,  2.39it/s]
1it [00:01,  1.23s/it]


epoch 62: loss: 0.03573842719197273, RMSED: 0.18038317561149597


2it [00:00,  2.29it/s]
1it [00:01,  1.28s/it]
2it [00:01,  1.88it/s]
1it [00:01,  1.53s/it]


epoch 64: loss: 0.035457734018564224, RMSED: 0.17988277971744537


2it [00:01,  1.06it/s]
1it [00:01,  1.65s/it]
2it [00:00,  2.06it/s]
1it [00:01,  1.37s/it]


epoch 66: loss: 0.03505782037973404, RMSED: 0.178359717130661


2it [00:00,  2.40it/s]
1it [00:01,  1.39s/it]
2it [00:01,  1.44it/s]
1it [00:01,  1.32s/it]


epoch 68: loss: 0.035267192870378494, RMSED: 0.1789679080247879


2it [00:00,  2.15it/s]
1it [00:01,  1.62s/it]
2it [00:01,  1.92it/s]
1it [00:01,  1.32s/it]


epoch 70: loss: 0.035335712134838104, RMSED: 0.17930884659290314


2it [00:00,  2.02it/s]
1it [00:01,  1.21s/it]
2it [00:00,  3.29it/s]
1it [00:00,  1.01it/s]


epoch 72: loss: 0.03531739488244057, RMSED: 0.1786946952342987


2it [00:00,  2.91it/s]
1it [00:00,  1.02it/s]
2it [00:00,  2.80it/s]
1it [00:01,  1.03s/it]


epoch 74: loss: 0.03535587713122368, RMSED: 0.17876224219799042


2it [00:00,  2.89it/s]
1it [00:01,  1.04s/it]
2it [00:00,  3.01it/s]
1it [00:00,  1.06it/s]


epoch 76: loss: 0.035102568566799164, RMSED: 0.178348109126091


2it [00:00,  3.28it/s]
1it [00:00,  1.07it/s]
2it [00:00,  3.36it/s]
1it [00:00,  1.07it/s]


epoch 78: loss: 0.034587759524583817, RMSED: 0.1769094169139862


2it [00:00,  2.91it/s]
1it [00:00,  1.06it/s]
2it [00:00,  3.26it/s]
1it [00:00,  1.08it/s]


epoch 80: loss: 0.0344538539648056, RMSED: 0.17651578783988953


2it [00:00,  3.33it/s]
1it [00:00,  1.09it/s]
2it [00:00,  2.88it/s]
1it [00:01,  1.03s/it]


epoch 82: loss: 0.034337498247623444, RMSED: 0.176080584526062


2it [00:00,  3.04it/s]
1it [00:00,  1.03it/s]
2it [00:00,  2.82it/s]
1it [00:01,  1.03s/it]


epoch 84: loss: 0.03484064340591431, RMSED: 0.17731580138206482


2it [00:00,  3.21it/s]
1it [00:00,  1.08it/s]
2it [00:00,  3.32it/s]
1it [00:00,  1.07it/s]


epoch 86: loss: 0.0347614660859108, RMSED: 0.17669183015823364


2it [00:00,  3.38it/s]
1it [00:01,  1.15s/it]
2it [00:00,  3.37it/s]
1it [00:00,  1.10it/s]


epoch 88: loss: 0.034263622015714645, RMSED: 0.1753934919834137


2it [00:00,  3.38it/s]
1it [00:01,  1.01s/it]
2it [00:00,  3.13it/s]
1it [00:01,  1.20s/it]


epoch 90: loss: 0.03416528180241585, RMSED: 0.17498010396957397


2it [00:00,  2.25it/s]
1it [00:01,  1.39s/it]
2it [00:00,  2.13it/s]
1it [00:00,  1.05it/s]


epoch 92: loss: 0.0340050607919693, RMSED: 0.17480850219726562


2it [00:00,  2.99it/s]
1it [00:01,  1.09s/it]
2it [00:00,  3.02it/s]
1it [00:00,  1.06it/s]


epoch 94: loss: 0.0339960902929306, RMSED: 0.17443545162677765


2it [00:00,  2.88it/s]
1it [00:00,  1.00it/s]
2it [00:00,  3.07it/s]
1it [00:00,  1.06it/s]


epoch 96: loss: 0.03390655666589737, RMSED: 0.1745503842830658


2it [00:00,  3.40it/s]
1it [00:00,  1.09it/s]
2it [00:00,  3.36it/s]
1it [00:00,  1.06it/s]


epoch 98: loss: 0.034050874412059784, RMSED: 0.17420363426208496


2it [00:00,  3.36it/s]
1it [00:00,  1.08it/s]


In [106]:
best_loss, best_metric

(0.033519282937049866, 0.17346622)

In [107]:
model.load_state_dict(best_metric_param)

<All keys matched successfully>

In [108]:
def acc(x, y):
  return np.argmax(x) == np.argmax(y)

In [109]:
with torch.no_grad():
  for _, (texts, labels) in tqdm(enumerate(test_loader)):
    pred = model(texts, batch_size=1000)
    loss = criterion(pred, labels)
    acc_1 =  np.mean([acc(pred[i], labels[i]) for i in range(1000)])
    APd = np.mean([pearsonr(pred[i], labels[i]) for i in range(1000)])
    ASd = np.mean([spearmanr(pred[i], labels[i]) for i in range(1000)])
    APe = np.mean([pearsonr(pred[:, i], labels[:, i]) for i in range(6)])
    RMSED = np.mean([np.sqrt(criterion(pred[i], labels[i])) for i in range(1000)])
    WD = np.mean([wasserstein_distance(pred[i], labels[i]) for i in range(1000)])

1it [00:02,  2.06s/it]


In [110]:
loss.item(), acc_1, APd, ASd, APe, RMSED, WD

(0.03349388763308525,
 0.414,
 0.378681727369863,
 0.377293546274097,
 0.15366352624215018,
 0.17295001,
 0.10663217979374652)

In [112]:
print("------\nAcc@1: ", acc_1, "\nAPd: ", APd, "\nAPe: ",APe, "\nRMSED: ",RMSED, "\nWDD: ",WD)
'------'

------
Acc@1:  0.414 
APd:  0.378681727369863 
APe:  0.15366352624215018 
RMSED:  0.17295001 
WDD:  0.10663217979374652


'------'