In [1]:
import torch

import random
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
from transformers import DistilBertTokenizer, DistilBertModel
from sklearn.metrics import f1_score, roc_auc_score

SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x15a75f838b0>

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
args = {'batch_size': 128,
        'lr': 3e-3,
        'hidden_dim': 128,
        'n_layers': 3,
        'bidirectional': True,
        'dropout': 0.25,
        'n_epochs': 50
}

In [4]:
print(torch.cuda.is_available())

True


In [5]:
df = pd.read_csv("../data/training_data/data1.csv", sep=",")

for i in range(2, 11):
    if i != 3:
        filename = "../data/training_data/data" + str(i) + ".csv"
        df = df.append(pd.read_csv(filename, sep=","))
import math

train_num = math.ceil(0.7 * len(df))
valid_num = math.ceil(0.9 * len(df))
train_data = df.iloc[:train_num, :]
valid_data = df.iloc[train_num:valid_num, :].reset_index()
test_data = df.iloc[valid_num:, :].reset_index()

In [45]:
num_positive = (df["sentiment"] == "positive").sum()
num_negative = (df["sentiment"] == "negative").sum()
num_neutral = (df["sentiment"] == "neutral").sum()

args["weight"] = torch.tensor([num_negative / len(df), num_neutral / len(df), num_positive / len(df)], dtype=torch.float32)

print(args["weight"])

tensor([0.1568, 0.4639, 0.3793])


In [24]:
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

In [25]:
tokenized_train = train_data['text'].apply((
    lambda x: tokenizer.encode(x, add_special_tokens=True)))
tokenized_valid = valid_data['text'].apply((
    lambda x: tokenizer.encode(x, add_special_tokens=True)))
tokenized_test = test_data['text'].apply((
    lambda x: tokenizer.encode(x, add_special_tokens=True)))

In [26]:
# max_len = tokenizer.max_model_input_sizes['distilbert-base-uncased']

# print(max_len)

In [27]:
def get_max_len(tokenized):
    max_len = 0
    for i in tokenized.values:
        if len(i) > max_len:
            max_len = len(i)
    return max_len

In [28]:
max_len_train = get_max_len(tokenized_train)
print(max_len_train)
max_len_valid = get_max_len(tokenized_valid)
print(max_len_valid)
max_len_test = get_max_len(tokenized_test)
print(max_len_test)
max_len = max([max_len_train, max_len_valid, max_len_test])

112
60
55


In [29]:
padded_train = torch.tensor([i + [0] * (max_len - len(i)) 
                             for i in tokenized_train.values])
padded_valid = torch.tensor([i + [0] * (max_len - len(i)) 
                             for i in tokenized_valid.values])
padded_test = torch.tensor([i + [0] * (max_len - len(i)) 
                            for i in tokenized_test.values])

In [30]:
train_label = torch.tensor(train_data['sentiment'].replace(
    to_replace='positive', value=2).replace(
    to_replace='negative', value=0).replace(
    to_replace='neutral', value=1).to_numpy())
valid_label = torch.tensor(valid_data['sentiment'].replace(
    to_replace='positive', value=2).replace(
    to_replace='negative', value=0).replace(
    to_replace='neutral', value=1).to_numpy())
test_label = torch.tensor(test_data['sentiment'].replace(
    to_replace='positive', value=2).replace(
    to_replace='negative', value=0).replace(
    to_replace='neutral', value=1).to_numpy())

In [31]:
# Define the dataset and data iterators
class Dataset(data.Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, x, labels):
        'Initialization'
        self.x = x
        self.labels = labels

    def __len__(self):
        'Denotes the total number of samples'
        return self.x.shape[0]

    def __getitem__(self, index):
        'Generates one sample of data'

        # Load data and get label
        x = self.x[index]
        y = self.labels[index]

        return x, y

In [32]:
trainset = Dataset(padded_train, train_label)
validset = Dataset(padded_valid, valid_label)
testset = Dataset(padded_test, test_label)

train_loader = torch.utils.data.DataLoader(trainset,
                                           batch_size=args['batch_size'],
                                           shuffle=True,
                                           drop_last=True)
valid_loader = torch.utils.data.DataLoader(validset,
                                           batch_size=args['batch_size'],
                                           shuffle=True,
                                           drop_last=True)
test_loader = torch.utils.data.DataLoader(testset,
                                           batch_size=args['batch_size'],
                                           shuffle=True,
                                           drop_last=True)

In [33]:
bert = DistilBertModel.from_pretrained('distilbert-base-uncased')

In [34]:
class BERTGRUSentiment(nn.Module):
    def __init__(self,
                 bert,
                 hidden_dim,
                 output_dim,
                 n_layers,
                 bidirectional,
                 dropout):
        
        super().__init__()
        
        self.bert = bert
        
        embedding_dim = bert.config.to_dict()['dim']
        
        self.rnn = nn.GRU(embedding_dim,
                          hidden_dim,
                          num_layers = n_layers,
                          bidirectional = bidirectional,
                          batch_first = True,
                          dropout = 0 if n_layers < 2 else dropout)
        
        self.out = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, text):
        
        #text = [batch size, sent len]
        attention_mask = text.masked_fill(text != 0, 1)
                
        with torch.no_grad():
            embedded = self.bert(text, attention_mask=attention_mask)[0]
                
        #embedded = [batch size, sent len, emb dim]
        
        _, hidden = self.rnn(embedded)
        
        #hidden = [n layers * n directions, batch size, emb dim]
        
        if self.rnn.bidirectional:
            hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1))
        else:
            hidden = self.dropout(hidden[-1,:,:])
                
        #hidden = [batch size, hid dim]
        
        output = self.out(hidden)
        
        #output = [batch size, out dim]
        
        return output

In [35]:
model = BERTGRUSentiment(bert,
                         args['hidden_dim'],
                         3,
                         args['n_layers'],
                         args['bidirectional'],
                         args['dropout'])

In [36]:
for name, param in model.named_parameters():                
    if name.startswith('bert'):
        param.requires_grad = False

In [37]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 1,283,331 trainable parameters


In [38]:
for name, param in model.named_parameters():                
    if param.requires_grad:
        print(name)

rnn.weight_ih_l0
rnn.weight_hh_l0
rnn.bias_ih_l0
rnn.bias_hh_l0
rnn.weight_ih_l0_reverse
rnn.weight_hh_l0_reverse
rnn.bias_ih_l0_reverse
rnn.bias_hh_l0_reverse
rnn.weight_ih_l1
rnn.weight_hh_l1
rnn.bias_ih_l1
rnn.bias_hh_l1
rnn.weight_ih_l1_reverse
rnn.weight_hh_l1_reverse
rnn.bias_ih_l1_reverse
rnn.bias_hh_l1_reverse
rnn.weight_ih_l2
rnn.weight_hh_l2
rnn.bias_ih_l2
rnn.bias_hh_l2
rnn.weight_ih_l2_reverse
rnn.weight_hh_l2_reverse
rnn.bias_ih_l2_reverse
rnn.bias_hh_l2_reverse
out.weight
out.bias


In [47]:
optimizer = optim.Adam(model.parameters(), lr=args['lr'])
criterion = nn.CrossEntropyLoss(weight=args['weight'])
model = model.to(device)
criterion = criterion.to(device)

In [49]:
def multi_acc(y_pred, y_label):
    softmax = nn.Softmax(dim=1)
    y_pred_softmax = softmax(y_pred)
    _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)

    # accu
    correct_pred = (y_pred_tags == y_label).float()
    acc = correct_pred.sum() / len(y_label)

    # roc-auc
    one_hot_label = nn.functional.one_hot(y_label)
    roc_auc = roc_auc_score(one_hot_label.detach().cpu(), y_pred_softmax.detach().cpu(), average="macro")

    # f1
    f1 = f1_score(y_label.detach().cpu(), y_pred_tags.detach().cpu(), average='weighted')
    
    return acc, roc_auc, f1

In [50]:
def train(model, data_loader, optimizer, criterion):
    
    epoch_loss = 0
    epoch_acc = 0
    epoch_rocauc = 0
    epoch_f1 = 0
    
    model.train()
    
    for batch_idx, (data, target) in enumerate(data_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        
        predictions = model(data).squeeze(1)
        
        loss = criterion(predictions, target)
        
        acc, roc_auc, f1 = multi_acc(predictions, target)
        
        loss.backward()
        
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        epoch_rocauc += roc_auc
        epoch_f1 += f1

        print("batch idx {}: | train loss: {} | train accu: {:.3f} | train roc: {:.3f} | train f1: {}".format(
            batch_idx, loss.item(), acc.item(), roc_auc, f1))
        
    return epoch_loss / len(data_loader), epoch_acc / len(data_loader), epoch_rocauc / len(data_loader), epoch_f1 / len(data_loader)

In [51]:
def evaluate(model, data_loader, criterion):
    
    epoch_loss = 0
    epoch_acc = 0
    epoch_rocauc = 0
    epoch_f1 = 0
    model.eval()
    
    with torch.no_grad():
    
        for batch_idx, (data, target) in enumerate(data_loader):
            data, target = data.to(device), target.to(device)
            
            predictions = model(data).squeeze(1)
            
            loss = criterion(predictions, target)
            
            acc, roc_auc, f1 = multi_acc(predictions, target)

            epoch_loss += loss.item()
            epoch_acc += acc.item()
            epoch_rocauc += roc_auc
            epoch_f1 += f1
        
    return epoch_loss / len(data_loader), epoch_acc / len(data_loader), epoch_rocauc / len(data_loader), epoch_f1 / len(data_loader)

In [52]:
import time

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:
history = {
    "train_loss": [],
    "valid_loss": []
}

best_valid_loss = float('inf')

for epoch in range(args['n_epochs']):
    
    start_time = time.time()
    
    train_loss, train_acc, train_rocauc, train_f1 = train(model, train_loader, optimizer, criterion)
    history["train_loss"].append(train_loss)
    valid_loss, valid_acc, valid_rocauc, valid_f1 = evaluate(model, valid_loader, criterion)
    history["valid_loss"].append(valid_loss)
        
    end_time = time.time()
        
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
        
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'best_model_3.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f} | Train rocauc: {train_rocauc} | Train f1: {train_f1}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f} | Val. rocauc: {valid_rocauc} | Val. f1: {valid_f1}%')

batch idx 0: | train loss: 0.7391921877861023 | train accu: 0.609 | train roc: 0.765 | train f1: 0.5658255912162161
batch idx 1: | train loss: 0.6732923984527588 | train accu: 0.641 | train roc: 0.823 | train f1: 0.5774831964152353
batch idx 2: | train loss: 0.5953079462051392 | train accu: 0.672 | train roc: 0.859 | train f1: 0.637475397040858
batch idx 3: | train loss: 0.5380573272705078 | train accu: 0.742 | train roc: 0.857 | train f1: 0.7267660440613026
batch idx 4: | train loss: 0.6385010480880737 | train accu: 0.656 | train roc: 0.818 | train f1: 0.6272423094997107
batch idx 5: | train loss: 0.7005694508552551 | train accu: 0.594 | train roc: 0.817 | train f1: 0.5361637654613817
batch idx 6: | train loss: 0.6688470244407654 | train accu: 0.617 | train roc: 0.829 | train f1: 0.5833167170105893
batch idx 7: | train loss: 0.5266201496124268 | train accu: 0.742 | train roc: 0.887 | train f1: 0.7357954545454546
batch idx 8: | train loss: 0.6955465078353882 | train accu: 0.609 | train

batch idx 71: | train loss: 0.6074802875518799 | train accu: 0.641 | train roc: 0.868 | train f1: 0.6171826893472906
batch idx 72: | train loss: 0.6048340797424316 | train accu: 0.688 | train roc: 0.849 | train f1: 0.6736773757740773
batch idx 73: | train loss: 0.6174636483192444 | train accu: 0.688 | train roc: 0.844 | train f1: 0.6816752772177419
batch idx 74: | train loss: 0.6100317239761353 | train accu: 0.742 | train roc: 0.847 | train f1: 0.7347996760548297
batch idx 75: | train loss: 0.6456020474433899 | train accu: 0.656 | train roc: 0.857 | train f1: 0.6469127463445645
batch idx 76: | train loss: 0.6313872337341309 | train accu: 0.672 | train roc: 0.866 | train f1: 0.6690546772068511
batch idx 77: | train loss: 0.5901423692703247 | train accu: 0.664 | train roc: 0.907 | train f1: 0.6462786181672503
batch idx 78: | train loss: 0.677832305431366 | train accu: 0.641 | train roc: 0.793 | train f1: 0.6182996323529412
batch idx 79: | train loss: 0.7062072157859802 | train accu: 0.63

batch idx 141: | train loss: 0.5984336733818054 | train accu: 0.656 | train roc: 0.863 | train f1: 0.6529928167721648
batch idx 142: | train loss: 0.622446596622467 | train accu: 0.656 | train roc: 0.849 | train f1: 0.6401164367358478
batch idx 143: | train loss: 0.5644274950027466 | train accu: 0.711 | train roc: 0.876 | train f1: 0.6947164967385555
batch idx 144: | train loss: 0.5138830542564392 | train accu: 0.742 | train roc: 0.876 | train f1: 0.6991050642587094
batch idx 145: | train loss: 0.6772159934043884 | train accu: 0.562 | train roc: 0.846 | train f1: 0.522317587334366
batch idx 146: | train loss: 0.6039571166038513 | train accu: 0.648 | train roc: 0.854 | train f1: 0.6362598288621646
batch idx 147: | train loss: 0.6610156893730164 | train accu: 0.680 | train roc: 0.809 | train f1: 0.6675848606488051
batch idx 148: | train loss: 0.49540936946868896 | train accu: 0.711 | train roc: 0.921 | train f1: 0.6828265249222024
batch idx 149: | train loss: 0.5505807995796204 | train a

batch idx 211: | train loss: 0.6106320023536682 | train accu: 0.695 | train roc: 0.854 | train f1: 0.6808676421957671
batch idx 212: | train loss: 0.6129406094551086 | train accu: 0.664 | train roc: 0.842 | train f1: 0.6531454392067959
batch idx 213: | train loss: 0.6045274138450623 | train accu: 0.648 | train roc: 0.850 | train f1: 0.6326841490408264
batch idx 214: | train loss: 0.6144751310348511 | train accu: 0.703 | train roc: 0.854 | train f1: 0.6887547348484848
batch idx 215: | train loss: 0.6692363023757935 | train accu: 0.641 | train roc: 0.829 | train f1: 0.6147632129774988
batch idx 216: | train loss: 0.5718383193016052 | train accu: 0.672 | train roc: 0.860 | train f1: 0.6419596006530848
batch idx 217: | train loss: 0.5724655389785767 | train accu: 0.695 | train roc: 0.841 | train f1: 0.6616468702865761
batch idx 218: | train loss: 0.6330844163894653 | train accu: 0.672 | train roc: 0.831 | train f1: 0.6449165059078852
batch idx 219: | train loss: 0.6099568009376526 | train 

batch idx 38: | train loss: 0.5696861147880554 | train accu: 0.695 | train roc: 0.898 | train f1: 0.6736474419354603
batch idx 39: | train loss: 0.6721664071083069 | train accu: 0.648 | train roc: 0.854 | train f1: 0.628418823455864
batch idx 40: | train loss: 0.5255534052848816 | train accu: 0.742 | train roc: 0.893 | train f1: 0.7334146594684385
batch idx 41: | train loss: 0.4659331440925598 | train accu: 0.750 | train roc: 0.913 | train f1: 0.7356750009370783
batch idx 42: | train loss: 0.5757931470870972 | train accu: 0.672 | train roc: 0.879 | train f1: 0.6410280233025312
batch idx 43: | train loss: 0.5607723593711853 | train accu: 0.711 | train roc: 0.879 | train f1: 0.6778058307926829
batch idx 44: | train loss: 0.5932444930076599 | train accu: 0.688 | train roc: 0.857 | train f1: 0.6804998531717875
batch idx 45: | train loss: 0.5335450172424316 | train accu: 0.703 | train roc: 0.887 | train f1: 0.6966401143790848
batch idx 46: | train loss: 0.5549493432044983 | train accu: 0.72

batch idx 109: | train loss: 0.5792325735092163 | train accu: 0.742 | train roc: 0.854 | train f1: 0.7171026524644946
batch idx 110: | train loss: 0.5324451327323914 | train accu: 0.688 | train roc: 0.880 | train f1: 0.665759154040404
batch idx 111: | train loss: 0.5531772375106812 | train accu: 0.727 | train roc: 0.872 | train f1: 0.7103921058714748
batch idx 112: | train loss: 0.5147844552993774 | train accu: 0.742 | train roc: 0.865 | train f1: 0.7177002644123343
batch idx 113: | train loss: 0.6302530169487 | train accu: 0.641 | train roc: 0.845 | train f1: 0.6203343949044586
batch idx 114: | train loss: 0.6487938165664673 | train accu: 0.617 | train roc: 0.837 | train f1: 0.563745843280275
batch idx 115: | train loss: 0.557648777961731 | train accu: 0.719 | train roc: 0.883 | train f1: 0.7082290516501042
batch idx 116: | train loss: 0.5709347128868103 | train accu: 0.727 | train roc: 0.868 | train f1: 0.7191883350388075
batch idx 117: | train loss: 0.6871102452278137 | train accu: 

batch idx 179: | train loss: 0.5058582425117493 | train accu: 0.711 | train roc: 0.894 | train f1: 0.6716162008281574
batch idx 180: | train loss: 0.6031274199485779 | train accu: 0.594 | train roc: 0.850 | train f1: 0.5335881090454262
batch idx 181: | train loss: 0.5133737921714783 | train accu: 0.688 | train roc: 0.881 | train f1: 0.6554657443410796
batch idx 182: | train loss: 0.5930353403091431 | train accu: 0.695 | train roc: 0.855 | train f1: 0.6678453947368421
batch idx 183: | train loss: 0.5303606390953064 | train accu: 0.734 | train roc: 0.880 | train f1: 0.7233416204131977
batch idx 184: | train loss: 0.5928698182106018 | train accu: 0.727 | train roc: 0.852 | train f1: 0.7163106304639735
batch idx 185: | train loss: 0.4218425750732422 | train accu: 0.789 | train roc: 0.938 | train f1: 0.7831835074327038
batch idx 186: | train loss: 0.601362407207489 | train accu: 0.711 | train roc: 0.853 | train f1: 0.6990152994791667
batch idx 187: | train loss: 0.6507884860038757 | train a

batch idx 6: | train loss: 0.42570847272872925 | train accu: 0.820 | train roc: 0.922 | train f1: 0.8172683566433566
batch idx 7: | train loss: 0.5803934335708618 | train accu: 0.711 | train roc: 0.879 | train f1: 0.7072859274563821
batch idx 8: | train loss: 0.5508607029914856 | train accu: 0.742 | train roc: 0.878 | train f1: 0.7377735040869976
batch idx 9: | train loss: 0.5071380734443665 | train accu: 0.781 | train roc: 0.888 | train f1: 0.7702809343434345
batch idx 10: | train loss: 0.514122486114502 | train accu: 0.750 | train roc: 0.907 | train f1: 0.7396623717489859
batch idx 11: | train loss: 0.5722194910049438 | train accu: 0.719 | train roc: 0.866 | train f1: 0.7034976420248096
batch idx 12: | train loss: 0.5115107893943787 | train accu: 0.734 | train roc: 0.902 | train f1: 0.6986838825293638
batch idx 13: | train loss: 0.4541555345058441 | train accu: 0.781 | train roc: 0.934 | train f1: 0.7751878246144785
batch idx 14: | train loss: 0.5093557834625244 | train accu: 0.750 |

batch idx 77: | train loss: 0.4996533691883087 | train accu: 0.758 | train roc: 0.883 | train f1: 0.7500904268145647
batch idx 78: | train loss: 0.5635194182395935 | train accu: 0.742 | train roc: 0.865 | train f1: 0.7328509852216748
batch idx 79: | train loss: 0.5787677764892578 | train accu: 0.633 | train roc: 0.901 | train f1: 0.6083029398762158
batch idx 80: | train loss: 0.522854208946228 | train accu: 0.727 | train roc: 0.883 | train f1: 0.7029634757383967
batch idx 81: | train loss: 0.49080029129981995 | train accu: 0.742 | train roc: 0.877 | train f1: 0.7275720526720322
batch idx 82: | train loss: 0.5613107085227966 | train accu: 0.727 | train roc: 0.878 | train f1: 0.7232830413633985
batch idx 83: | train loss: 0.5220333337783813 | train accu: 0.719 | train roc: 0.891 | train f1: 0.697696314102564
batch idx 84: | train loss: 0.5728065371513367 | train accu: 0.680 | train roc: 0.883 | train f1: 0.6770726625069845
batch idx 85: | train loss: 0.592415452003479 | train accu: 0.672

batch idx 147: | train loss: 0.5605822205543518 | train accu: 0.727 | train roc: 0.894 | train f1: 0.7133275493510628
batch idx 148: | train loss: 0.4932960569858551 | train accu: 0.758 | train roc: 0.855 | train f1: 0.7410166715391531
batch idx 149: | train loss: 0.4991525411605835 | train accu: 0.734 | train roc: 0.910 | train f1: 0.7176307624113476
batch idx 150: | train loss: 0.5111368298530579 | train accu: 0.734 | train roc: 0.892 | train f1: 0.7251173071443835
batch idx 151: | train loss: 0.5436308979988098 | train accu: 0.734 | train roc: 0.876 | train f1: 0.7109334625322998
batch idx 152: | train loss: 0.49769940972328186 | train accu: 0.750 | train roc: 0.887 | train f1: 0.7351349430062062
batch idx 153: | train loss: 0.5504583120346069 | train accu: 0.695 | train roc: 0.880 | train f1: 0.6773793522785458
batch idx 154: | train loss: 0.4165370762348175 | train accu: 0.820 | train roc: 0.920 | train f1: 0.8120258652066306
batch idx 155: | train loss: 0.5930152535438538 | train

batch idx 217: | train loss: 0.4789619743824005 | train accu: 0.742 | train roc: 0.906 | train f1: 0.7246473116384712
batch idx 218: | train loss: 0.6116453409194946 | train accu: 0.695 | train roc: 0.869 | train f1: 0.6832212273582687
batch idx 219: | train loss: 0.5208355784416199 | train accu: 0.758 | train roc: 0.905 | train f1: 0.7514245977502868
batch idx 220: | train loss: 0.4661950170993805 | train accu: 0.781 | train roc: 0.897 | train f1: 0.7773222117794486
batch idx 221: | train loss: 0.60245680809021 | train accu: 0.648 | train roc: 0.862 | train f1: 0.6316730242566511
batch idx 222: | train loss: 0.48790937662124634 | train accu: 0.773 | train roc: 0.897 | train f1: 0.763033536585366
batch idx 223: | train loss: 0.4947451055049896 | train accu: 0.719 | train roc: 0.899 | train f1: 0.6885989010989011
batch idx 224: | train loss: 0.6108207702636719 | train accu: 0.680 | train roc: 0.861 | train f1: 0.6610362465111643
batch idx 225: | train loss: 0.5377004146575928 | train ac

batch idx 44: | train loss: 0.5451882481575012 | train accu: 0.734 | train roc: 0.885 | train f1: 0.7247570784384795
batch idx 45: | train loss: 0.659453272819519 | train accu: 0.641 | train roc: 0.825 | train f1: 0.6307031250000001
batch idx 46: | train loss: 0.6251614093780518 | train accu: 0.680 | train roc: 0.866 | train f1: 0.6756905241935485
batch idx 47: | train loss: 0.5799912810325623 | train accu: 0.719 | train roc: 0.875 | train f1: 0.7056414662084765
batch idx 48: | train loss: 0.5297915935516357 | train accu: 0.773 | train roc: 0.875 | train f1: 0.7619582636566331
batch idx 49: | train loss: 0.5552845597267151 | train accu: 0.664 | train roc: 0.873 | train f1: 0.6422615361628737
batch idx 50: | train loss: 0.5301243662834167 | train accu: 0.680 | train roc: 0.898 | train f1: 0.647051282051282
batch idx 51: | train loss: 0.6028410792350769 | train accu: 0.672 | train roc: 0.841 | train f1: 0.6288244912790699
batch idx 52: | train loss: 0.5337679386138916 | train accu: 0.656

batch idx 114: | train loss: 0.5892453193664551 | train accu: 0.703 | train roc: 0.826 | train f1: 0.690165626047603
batch idx 115: | train loss: 0.470944344997406 | train accu: 0.750 | train roc: 0.909 | train f1: 0.7475039821690503
batch idx 116: | train loss: 0.5098384022712708 | train accu: 0.742 | train roc: 0.913 | train f1: 0.7289708646616542
batch idx 117: | train loss: 0.5981770157814026 | train accu: 0.656 | train roc: 0.874 | train f1: 0.6381174865779645
batch idx 118: | train loss: 0.5390107035636902 | train accu: 0.766 | train roc: 0.892 | train f1: 0.756036895253225
batch idx 119: | train loss: 0.5079920887947083 | train accu: 0.766 | train roc: 0.905 | train f1: 0.7655551912098075
batch idx 120: | train loss: 0.5625738501548767 | train accu: 0.695 | train roc: 0.913 | train f1: 0.6913247473870956
batch idx 121: | train loss: 0.5510715246200562 | train accu: 0.750 | train roc: 0.891 | train f1: 0.7416266925680874
batch idx 122: | train loss: 0.5150070190429688 | train acc

batch idx 184: | train loss: 0.5194953083992004 | train accu: 0.680 | train roc: 0.901 | train f1: 0.6506923811544991
batch idx 185: | train loss: 0.45370545983314514 | train accu: 0.773 | train roc: 0.932 | train f1: 0.7666164169026763
batch idx 186: | train loss: 0.5184860229492188 | train accu: 0.742 | train roc: 0.859 | train f1: 0.7466261108760051
batch idx 187: | train loss: 0.5567883849143982 | train accu: 0.727 | train roc: 0.889 | train f1: 0.725292248337689
batch idx 188: | train loss: 0.49478861689567566 | train accu: 0.727 | train roc: 0.899 | train f1: 0.7175138533674339
batch idx 189: | train loss: 0.43352270126342773 | train accu: 0.727 | train roc: 0.935 | train f1: 0.6909473623918896
batch idx 190: | train loss: 0.5110863447189331 | train accu: 0.727 | train roc: 0.853 | train f1: 0.6992797298384706
batch idx 191: | train loss: 0.5210874676704407 | train accu: 0.727 | train roc: 0.912 | train f1: 0.705016121031746
batch idx 192: | train loss: 0.4587893486022949 | train

batch idx 11: | train loss: 0.5052547454833984 | train accu: 0.734 | train roc: 0.925 | train f1: 0.7223660714285715
batch idx 12: | train loss: 0.39660245180130005 | train accu: 0.781 | train roc: 0.928 | train f1: 0.7676873059006211
batch idx 13: | train loss: 0.5216965675354004 | train accu: 0.727 | train roc: 0.901 | train f1: 0.7126677119165479
batch idx 14: | train loss: 0.5364985466003418 | train accu: 0.727 | train roc: 0.911 | train f1: 0.7203458536536615
batch idx 15: | train loss: 0.48108771443367004 | train accu: 0.789 | train roc: 0.896 | train f1: 0.7769454661051711
batch idx 16: | train loss: 0.5330414772033691 | train accu: 0.750 | train roc: 0.913 | train f1: 0.7517387068021923
batch idx 17: | train loss: 0.5723127722740173 | train accu: 0.758 | train roc: 0.899 | train f1: 0.7483634492767415
batch idx 18: | train loss: 0.6368551254272461 | train accu: 0.648 | train roc: 0.883 | train f1: 0.6407609554248493
batch idx 19: | train loss: 0.5350262522697449 | train accu: 0

batch idx 82: | train loss: 0.4897913634777069 | train accu: 0.758 | train roc: 0.916 | train f1: 0.7572244673237816
batch idx 83: | train loss: 0.4449706971645355 | train accu: 0.828 | train roc: 0.928 | train f1: 0.8257431315549476
batch idx 84: | train loss: 0.42794522643089294 | train accu: 0.789 | train roc: 0.928 | train f1: 0.7848121279761905
batch idx 85: | train loss: 0.44691866636276245 | train accu: 0.742 | train roc: 0.926 | train f1: 0.7183342086834734
batch idx 86: | train loss: 0.5146911144256592 | train accu: 0.789 | train roc: 0.894 | train f1: 0.7742956349206349
batch idx 87: | train loss: 0.39830559492111206 | train accu: 0.789 | train roc: 0.936 | train f1: 0.776309258993494
batch idx 88: | train loss: 0.4327946603298187 | train accu: 0.820 | train roc: 0.915 | train f1: 0.8108938079542163
batch idx 89: | train loss: 0.5167926549911499 | train accu: 0.695 | train roc: 0.890 | train f1: 0.6723559864723657
batch idx 90: | train loss: 0.547295093536377 | train accu: 0.

batch idx 152: | train loss: 0.5270892977714539 | train accu: 0.750 | train roc: 0.903 | train f1: 0.7305076215127984
batch idx 153: | train loss: 0.45883703231811523 | train accu: 0.781 | train roc: 0.905 | train f1: 0.7490851963932805
batch idx 154: | train loss: 0.3975779414176941 | train accu: 0.797 | train roc: 0.925 | train f1: 0.7799836601307191
batch idx 155: | train loss: 0.48022088408470154 | train accu: 0.727 | train roc: 0.917 | train f1: 0.7133826894423159
batch idx 156: | train loss: 0.44628089666366577 | train accu: 0.766 | train roc: 0.929 | train f1: 0.7512516408401102
batch idx 157: | train loss: 0.4985222816467285 | train accu: 0.766 | train roc: 0.908 | train f1: 0.7548651183341917
batch idx 158: | train loss: 0.5000792741775513 | train accu: 0.711 | train roc: 0.902 | train f1: 0.7060424498746867
batch idx 159: | train loss: 0.4544568359851837 | train accu: 0.750 | train roc: 0.919 | train f1: 0.7441307977736549
batch idx 160: | train loss: 0.7284483313560486 | tra

batch idx 222: | train loss: 0.3828570246696472 | train accu: 0.781 | train roc: 0.934 | train f1: 0.7565518465909091
batch idx 223: | train loss: 0.5068293213844299 | train accu: 0.750 | train roc: 0.900 | train f1: 0.7428868898883891
batch idx 224: | train loss: 0.38458847999572754 | train accu: 0.812 | train roc: 0.927 | train f1: 0.7978414035946029
batch idx 225: | train loss: 0.5654535293579102 | train accu: 0.703 | train roc: 0.871 | train f1: 0.6783869864179047
batch idx 226: | train loss: 0.465587854385376 | train accu: 0.727 | train roc: 0.907 | train f1: 0.6921337841969382
batch idx 227: | train loss: 0.4516223073005676 | train accu: 0.773 | train roc: 0.926 | train f1: 0.7700430052403735
batch idx 228: | train loss: 0.5901066660881042 | train accu: 0.719 | train roc: 0.858 | train f1: 0.7056899641577061
batch idx 229: | train loss: 0.45934247970581055 | train accu: 0.805 | train roc: 0.918 | train f1: 0.8026514938521517
batch idx 230: | train loss: 0.45462626218795776 | trai

batch idx 49: | train loss: 0.5526480674743652 | train accu: 0.750 | train roc: 0.884 | train f1: 0.7411855376325516
batch idx 50: | train loss: 0.514997661113739 | train accu: 0.789 | train roc: 0.890 | train f1: 0.7872556987900914
batch idx 51: | train loss: 0.5215384364128113 | train accu: 0.781 | train roc: 0.914 | train f1: 0.7842671251348436
batch idx 52: | train loss: 0.378592312335968 | train accu: 0.852 | train roc: 0.944 | train f1: 0.851278913486669
batch idx 53: | train loss: 0.5211778283119202 | train accu: 0.781 | train roc: 0.901 | train f1: 0.7790276759530792
batch idx 54: | train loss: 0.4427436888217926 | train accu: 0.758 | train roc: 0.932 | train f1: 0.7573418090062112
batch idx 55: | train loss: 0.44135236740112305 | train accu: 0.773 | train roc: 0.941 | train f1: 0.7637286389897857
batch idx 56: | train loss: 0.4947296977043152 | train accu: 0.750 | train roc: 0.903 | train f1: 0.74046875
batch idx 57: | train loss: 0.423003613948822 | train accu: 0.773 | train 

batch idx 119: | train loss: 0.4155066907405853 | train accu: 0.789 | train roc: 0.942 | train f1: 0.7864774816176471
batch idx 120: | train loss: 0.5265995860099792 | train accu: 0.727 | train roc: 0.863 | train f1: 0.7082142857142857
batch idx 121: | train loss: 0.3831344246864319 | train accu: 0.805 | train roc: 0.934 | train f1: 0.7958469105758126
batch idx 122: | train loss: 0.41346174478530884 | train accu: 0.812 | train roc: 0.927 | train f1: 0.8019551282051283
batch idx 123: | train loss: 0.5179879069328308 | train accu: 0.758 | train roc: 0.906 | train f1: 0.7438032710544141
batch idx 124: | train loss: 0.4221670925617218 | train accu: 0.766 | train roc: 0.920 | train f1: 0.7335539253000432
batch idx 125: | train loss: 0.4542403519153595 | train accu: 0.781 | train roc: 0.923 | train f1: 0.770686544991511
batch idx 126: | train loss: 0.347992867231369 | train accu: 0.844 | train roc: 0.948 | train f1: 0.8409844509179251
batch idx 127: | train loss: 0.44514548778533936 | train 

batch idx 189: | train loss: 0.4805808961391449 | train accu: 0.781 | train roc: 0.913 | train f1: 0.7659125707718177
batch idx 190: | train loss: 0.5184698700904846 | train accu: 0.734 | train roc: 0.898 | train f1: 0.7275904778364197
batch idx 191: | train loss: 0.4582973122596741 | train accu: 0.750 | train roc: 0.911 | train f1: 0.7444622312362548
batch idx 192: | train loss: 0.5622130036354065 | train accu: 0.750 | train roc: 0.904 | train f1: 0.7473807251908398
batch idx 193: | train loss: 0.4014185070991516 | train accu: 0.789 | train roc: 0.928 | train f1: 0.7809899566013697
batch idx 194: | train loss: 0.45610469579696655 | train accu: 0.797 | train roc: 0.928 | train f1: 0.7920298946101767
batch idx 195: | train loss: 0.5288923382759094 | train accu: 0.734 | train roc: 0.889 | train f1: 0.7160335354300871
batch idx 196: | train loss: 0.5613672733306885 | train accu: 0.695 | train roc: 0.889 | train f1: 0.683034093855328
batch idx 197: | train loss: 0.3758898675441742 | train 

batch idx 16: | train loss: 0.40063098073005676 | train accu: 0.797 | train roc: 0.903 | train f1: 0.7748122358816751
batch idx 17: | train loss: 0.3672890365123749 | train accu: 0.820 | train roc: 0.953 | train f1: 0.812591131314864
batch idx 18: | train loss: 0.3326413333415985 | train accu: 0.820 | train roc: 0.956 | train f1: 0.810055441634389
batch idx 19: | train loss: 0.4176826775074005 | train accu: 0.750 | train roc: 0.925 | train f1: 0.729828477443609
batch idx 20: | train loss: 0.45345360040664673 | train accu: 0.766 | train roc: 0.910 | train f1: 0.7637880986937591
batch idx 21: | train loss: 0.38278165459632874 | train accu: 0.844 | train roc: 0.938 | train f1: 0.8418537801484229
batch idx 22: | train loss: 0.3975241780281067 | train accu: 0.820 | train roc: 0.942 | train f1: 0.816950002929802
batch idx 23: | train loss: 0.4432690143585205 | train accu: 0.789 | train roc: 0.938 | train f1: 0.7877617431733286
batch idx 24: | train loss: 0.4504992663860321 | train accu: 0.81

batch idx 87: | train loss: 0.42158886790275574 | train accu: 0.734 | train roc: 0.925 | train f1: 0.7213046907107891
batch idx 88: | train loss: 0.49347206950187683 | train accu: 0.773 | train roc: 0.904 | train f1: 0.7708670236013986
batch idx 89: | train loss: 0.37206050753593445 | train accu: 0.781 | train roc: 0.942 | train f1: 0.7725095319976076
batch idx 90: | train loss: 0.4351467490196228 | train accu: 0.750 | train roc: 0.934 | train f1: 0.7385900444664033
batch idx 91: | train loss: 0.429010272026062 | train accu: 0.797 | train roc: 0.939 | train f1: 0.7934189003799071
batch idx 92: | train loss: 0.44829487800598145 | train accu: 0.742 | train roc: 0.931 | train f1: 0.722532236346733
batch idx 93: | train loss: 0.3481709659099579 | train accu: 0.859 | train roc: 0.951 | train f1: 0.8542998120300752
batch idx 94: | train loss: 0.4007661044597626 | train accu: 0.773 | train roc: 0.940 | train f1: 0.7704581314548005
batch idx 95: | train loss: 0.45691630244255066 | train accu: 

In [None]:
model.load_state_dict(torch.load('best_model_3.pt'))

In [None]:
valid_loss, valid_acc, valid_rocauc, valid_f1 = evaluate(model, valid_loader, criterion)
print("Valid loss: {} | Valid Acc: {:.3f} | Valid ROC-AUC: {} | Valid f1: {}".format(
    valid_loss, valid_acc, valid_rocauc, valid_f1))
test_loss, test_acc, test_rocauc, test_f1 = evaluate(model, test_loader, criterion)
print("Test loss: {} | Test Acc: {:.3f} | Test ROC-AUC: {} | Test f1: {}".format(
    test_loss, test_acc, test_rocauc, test_f1))

In [None]:
import matplotlib.pyplot as plt

def plot_history(hist):
    plt.figure(figsize=(10, 7))
    plt.plot(np.arange(1, args["n_epochs"] + 1), history["train_loss"], label="training loss")
    plt.plot(np.arange(1, args["n_epochs"] + 1), history["valid_loss"], label="validation loss")
    plt.legend(loc="best")
    plt.title("Training and Validation Losses")
    plt.show()

In [None]:
plot_history(history)

In [None]:
a= "COVID fears in Toronto: to me single biggest worry right now is this: the situation is massively worse for the average person now than it was at peak. why? Because at peak it was almost all LTCFs. Now? Unchecked community spread. That's terrifying to me."
a_encoded = tokenizer.encode(a, add_special_tokens=True)
a_final = torch.tensor(a_encoded + [0] * (max_len - len(a_encoded))).view(1, -1).to(device)
softmax = nn.Softmax(dim=1)
probs = softmax(model(a_final))
probs

## Randomized Search for optimal hyper