In [1]:
import numpy as np
from numpy import savetxt
import pandas as pd
import matplotlib.pyplot as plt
import time
from tqdm.notebook import tqdm
import sys

# gensim for pretrained embedding
from gensim.models import KeyedVectors
from gensim.scripts.glove2word2vec import glove2word2vec
from gensim.test.utils import datapath, get_tmpfile


# pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils import data
from torch.autograd import Variable

# torchtext
import torchtext.vocab as vocab


from IPython.core.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))

unable to import 'smart_open.gcs', disabling that module


In [2]:
print (torch.cuda.is_available())
print (torch.cuda.current_device())
print (torch.cuda.get_device_name(0))
print (torch.cuda.memory_allocated())
print (torch.cuda.memory_cached())

True
0
GeForce GTX 970
0
0


In [3]:
df = pd.read_csv("data/cleaned_amzn_data_4-15_10Kwords.csv", encoding='utf8', index_col=0)

In [4]:
drop_cols = ['review', 'cleaned_reviews']

try:
    df.drop(drop_cols, axis=1, inplace=True)
except:
    print ("Probably dropped already")
df = df.rename(columns={'overall': 'recommendation'})
df.head()

Unnamed: 0,recommendation,encoded_1,encoded_2,encoded_3,encoded_4,encoded_5,encoded_6,encoded_7,encoded_8,encoded_9,...,encoded_185,encoded_186,encoded_187,encoded_188,encoded_189,encoded_190,encoded_191,encoded_192,encoded_193,encoded_194
0,0,0,0,0,0,0,0,0,0,0,...,4059,9289,8594,9289,4934,7474,3382,652,2097,2876
1,1,0,0,0,0,0,0,0,0,0,...,3340,8561,9289,214,5126,6257,2827,6823,1256,8798
2,0,0,0,0,0,0,0,0,0,0,...,1745,5242,506,2434,7599,8764,5242,7146,6949,3506
4,1,0,0,0,0,0,0,0,0,0,...,7514,5853,5815,9606,595,8561,243,2076,2734,9289
5,1,0,0,0,0,0,0,0,0,0,...,8375,3595,1356,2298,8561,7502,2298,1329,6555,6758


# Parameters for Grid Search

In [5]:
LEARNING_RATES_LIST = [0.001, 0.0001]
EPOCHS_LIST = [50]
BATCH_SIZES_LIST = [64, 128]
LSTM_DIMS_LIST = [64, 128]
EMBED_DIMS_LIST = [100]

**Storage for global vars**

In [6]:
best_model = None
highest_acc = 0

# Load pretrained embedding

In [7]:
# google word2vec embedding #
# embed_path = 'data/GoogleNews-vectors-negative300.bin.gz'
# word2vec = KeyedVectors.load_word2vec_format(embed_path, binary=True)
# weights = word2vec.wv.vectors
# weights


# glove embedding #
glove_input_file = 'data/glove.6b/glove.6B.100d.txt'
word2vec_output_file = 'data/glove_to_word2vec.txt'

try:
    pretrained_embedding = KeyedVectors.load_word2vec_format(word2vec_output_file, binary=False)
except:
    print ("Converting word2vec file. If this fails, please download the glove.6b.100d file")
    glove2word2vec(glove_input_file, word2vec_output_file)

In [8]:
weights = pretrained_embedding.wv.vectors
pretrained_embedding.wv.vectors.shape

  """Entry point for launching an IPython kernel.
  


(400000, 100)

In [9]:
MAX_SEQ_LEN = len(df.columns.tolist())-1
VOCAB_SIZE = 14845 # 10746 - but need to use max(amzn_vocab, steam_vocab)

In [10]:
# only need 2-3 lines for attention
class Attention(nn.Module):
    def __init__(self, feature_dim, step_dim, bias=True, **kwargs):
        super(Attention, self).__init__(**kwargs)
        
        self.supports_masking = True

        self.bias = bias
        self.feature_dim = feature_dim
        self.step_dim = step_dim
        self.features_dim = 0
        
        weight = torch.zeros(feature_dim, 1)
        nn.init.kaiming_uniform_(weight)
        self.weight = nn.Parameter(weight)
        
        if bias:
            self.b = nn.Parameter(torch.zeros(step_dim))
    
    def forward(self, x, mask=None):
        feature_dim = self.feature_dim 
        step_dim = self.step_dim

        eij = torch.mm(
            x.contiguous().view(-1, feature_dim), 
            self.weight
        ).view(-1, step_dim)
        
        if self.bias:
            eij = eij + self.b
            
        eij = torch.tanh(eij)
        a = torch.exp(eij)
        
        if mask is not None:
            a = a * mask

        a = a / (torch.sum(a, 1, keepdim=True) + 1e-10)

        weighted_input = x * torch.unsqueeze(a, -1)
        return torch.sum(weighted_input, 1)

In [11]:
# build pytorch model
DROPOUT = 0.1
for LEARNING_RATE in LEARNING_RATES_LIST:
    for EPOCHS in EPOCHS_LIST:
        for EMBED_DIM in EMBED_DIMS_LIST:
            for LSTM_DIM in LSTM_DIMS_LIST:
                for BATCH_SIZE in BATCH_SIZES_LIST:
                    class Attention_Net(nn.Module):
                        def __init__(self):
                            super(Attention_Net, self).__init__()

                            # define architecture
                            # self.embedding = nn.Embedding(VOCAB_SIZE, EMBED_DIM) # add pretrained embeding
                            weights_ = Variable(torch.from_numpy(weights))
                            print (weights_.size())
                            self.embedding = nn.Embedding.from_pretrained(weights_)

                            self.lstm = nn.LSTM(EMBED_DIM, 
                                                LSTM_DIM, 
                                                bidirectional=True,
                                                dropout=0.2,
                                                batch_first=True)

                            # attention layer
                    #         self.attention_layer = Attention(LSTM_DIM * 2, MAX_SEQ_LEN)
                            # try tanh

                            self.linear = nn.Linear(LSTM_DIM*2, 2)

                        def forward(self, x):
                            embedding = self.embedding(x)
                            embedding = torch.squeeze(torch.unsqueeze(embedding, 0)).view(BATCH_SIZE, MAX_SEQ_LEN, -1)
                            lstm_out, (hidden, cell) = self.lstm(embedding)
                    #         attention = self.attention_layer(lstm_out)

                            out = self.linear(lstm_out[:, -1, :])
                            return out

                    class AmznDataset(data.Dataset):
                        def __init__(self, data):
                            #'Initialization'
                            self.data = data
                            text_cols = [x for x in df.columns.tolist() if x.startswith("encoded")]
                            self.train = torch.tensor(data[text_cols].values).type(torch.LongTensor).cuda()
                            labels = data['recommendation'].tolist()

                            self.one_hot_labels = torch.tensor(np.array(labels)).squeeze().type(torch.LongTensor).cuda() # change to longtensor if using custom loss


                        def __len__(self):
                            #'Denotes the total number of samples'
                            return len(self.data)

                        def __getitem__(self, index):
                            #'Generates one sample of data'

                            # Load data and get label
                            X = self.train[index]
                            Y = self.one_hot_labels[index]
                            return X, Y

                    train_num = int(0.8 * len(df))
                    amzn_dataset = AmznDataset(df[:train_num])
                    amzn_data_loader = data.DataLoader(amzn_dataset, batch_size=BATCH_SIZE, num_workers=0, drop_last=True, shuffle=True)
                    amzn_data_loader

                    attention_model = Attention_Net().cuda()
                    loss_function = nn.CrossEntropyLoss()
                    optimizer = optim.Adam(attention_model.parameters(), lr=LEARNING_RATE) # even lower for transfer learning

                    # training loop
                    start = time.time()

                    for i in range(EPOCHS):
                        second_start = time.time()
                        running_loss = 0
                        correct = 0
                        attention_model.train()

                        with tqdm(total=len(amzn_data_loader), file=sys.stdout) as pbar:
                            for idx, (train_X, train_Y) in enumerate(amzn_data_loader):

                                optimizer.zero_grad()

                                pred_y = attention_model(train_X) 
                                loss = loss_function(pred_y, train_Y)
                                loss.backward()
                                optimizer.step()
                                running_loss += loss

                                # calc accuracy
                                pred1_mask = pred_y[:, 1] > 0.5
                                masked_trainY_1 = train_Y[pred1_mask]
                                masked_trainY_0 = train_Y[~pred1_mask]
                                ones_predicted_correct = torch.sum(masked_trainY_1)
                                zeros_predicted_correct = torch.sum(masked_trainY_0)
                                correct += ones_predicted_correct.add(zeros_predicted_correct)
                                correct_ = correct.cpu().numpy()

                                # update progress bar
                                pbar.set_description('ep{} | loss: {} | acc: {}%'.format(i+1, torch.round(running_loss), round(correct_ / ((idx+1) * BATCH_SIZE)*100, 1)))
                                pbar.update(1)
                                tqdm._instances.clear()



                        print ('Epoch {} | took {} seconds | summed loss: {} | avg loss: {}'
                                       .format(i+1, time.time() - second_start, running_loss, running_loss / (len(amzn_data_loader) * BATCH_SIZE)))

                    print ("Took {} seconds".format(time.time() - start))

                    print (attention_model)

                    amzn_eval_dataset = AmznDataset(df[train_num:])
                    amzn_eval_data_loader = data.DataLoader(amzn_eval_dataset, batch_size=BATCH_SIZE, num_workers=0, drop_last=True)

                    # evaluate
                    correct = 0
                    eval_loss = 0
                    attention_model.eval()
                    with torch.no_grad():
                        for i, (test_X, test_Y) in enumerate(amzn_eval_data_loader):
                            preds = attention_model(test_X).squeeze()
                            preds = torch.nn.functional.softmax(preds)
                            for idx, each_pred in enumerate(preds):
                                if each_pred[0] >= 0.5 and test_Y[idx] == 0:
                                    correct += 1
                                elif each_pred[0] < 0.5 and test_Y[idx] == 1:
                                    correct += 1
                            loss = loss_function(preds, test_Y)
                            eval_loss += loss
                    
                    #save best run
                    accuracy = correct / len(amzn_eval_dataset)
                    if accuracy > highest_acc:
                        highest_acc = accuracy
                        best_model = attention_model.state_dict()
                        best_batch = BATCH_SIZE
                        best_lstm_dim = LSTM_DIM
                        best_epochs = EPOCHS
                        best_lr = LEARNING_RATE
                        
                        
                    print ("Eval accuracy: {}".format(correct / len(amzn_eval_dataset)))
                    print ("Eval summed loss: {} | avg loss: {}".format(eval_loss, eval_loss / len(amzn_eval_dataset)))

torch.Size([400000, 100])


  "num_layers={}".format(dropout, num_layers))


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 1 | took 48.88352966308594 seconds | summed loss: 1139.0201416015625 | avg loss: 0.007348137907683849


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 2 | took 48.80402636528015 seconds | summed loss: 948.1005859375 | avg loss: 0.006116462405771017


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 3 | took 48.43626379966736 seconds | summed loss: 870.7559204101562 | avg loss: 0.005617490503937006


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 4 | took 47.24839520454407 seconds | summed loss: 818.929931640625 | avg loss: 0.0052831461653113365


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 5 | took 47.31571841239929 seconds | summed loss: 779.294677734375 | avg loss: 0.005027448292821646


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 6 | took 48.493967056274414 seconds | summed loss: 744.1740112304688 | avg loss: 0.0048008752055466175


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 7 | took 49.741090059280396 seconds | summed loss: 713.6930541992188 | avg loss: 0.004604233894497156


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 8 | took 49.833125829696655 seconds | summed loss: 686.5621337890625 | avg loss: 0.004429204855114222


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 9 | took 49.32012438774109 seconds | summed loss: 657.7935180664062 | avg loss: 0.004243610426783562


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 10 | took 49.733802795410156 seconds | summed loss: 633.8463134765625 | avg loss: 0.00408912030979991


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 11 | took 49.22477602958679 seconds | summed loss: 605.163818359375 | avg loss: 0.003904081415385008


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 12 | took 49.09643363952637 seconds | summed loss: 581.6756591796875 | avg loss: 0.0037525526713579893


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 13 | took 49.03535223007202 seconds | summed loss: 559.8792724609375 | avg loss: 0.0036119380965828896


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 14 | took 49.17172884941101 seconds | summed loss: 539.0673217773438 | avg loss: 0.0034776742104440928


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 15 | took 47.125707149505615 seconds | summed loss: 514.1107788085938 | avg loss: 0.0033166727516800165


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 16 | took 48.801743268966675 seconds | summed loss: 494.9975280761719 | avg loss: 0.00319336773827672


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 17 | took 48.70255970954895 seconds | summed loss: 477.7843322753906 | avg loss: 0.003082320559769869


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 18 | took 49.11364960670471 seconds | summed loss: 454.3981018066406 | avg loss: 0.002931449329480529


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 19 | took 49.21627593040466 seconds | summed loss: 438.31195068359375 | avg loss: 0.0028276732191443443


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 20 | took 48.3027606010437 seconds | summed loss: 422.8690185546875 | avg loss: 0.0027280463837087154


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 21 | took 50.30768322944641 seconds | summed loss: 406.5371398925781 | avg loss: 0.0026226849295198917


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 22 | took 50.21878981590271 seconds | summed loss: 391.7413024902344 | avg loss: 0.0025272329803556204


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 23 | took 48.39933228492737 seconds | summed loss: 376.0062561035156 | avg loss: 0.002425721613690257


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 24 | took 50.13641023635864 seconds | summed loss: 364.40557861328125 | avg loss: 0.0023508823942393064


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 25 | took 49.19547986984253 seconds | summed loss: 349.1036071777344 | avg loss: 0.0022521652281284332


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 26 | took 49.45544147491455 seconds | summed loss: 339.3089599609375 | avg loss: 0.002188977086916566


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 27 | took 49.40216326713562 seconds | summed loss: 328.8640441894531 | avg loss: 0.002121594035997987


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 28 | took 47.26447153091431 seconds | summed loss: 318.01922607421875 | avg loss: 0.002051631221547723


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 29 | took 46.8647096157074 seconds | summed loss: 305.2007751464844 | avg loss: 0.0019689355976879597


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 30 | took 46.25269818305969 seconds | summed loss: 302.7147521972656 | avg loss: 0.001952897640876472


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 31 | took 46.23296666145325 seconds | summed loss: 289.0130310058594 | avg loss: 0.0018645040690898895


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 32 | took 48.065335273742676 seconds | summed loss: 281.2535095214844 | avg loss: 0.001814445131458342


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 33 | took 50.31537652015686 seconds | summed loss: 272.8815002441406 | avg loss: 0.0017604349413886666


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 34 | took 47.88261699676514 seconds | summed loss: 300.6541748046875 | avg loss: 0.0019396042916923761


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 35 | took 70.49311089515686 seconds | summed loss: 252.24070739746094 | avg loss: 0.001627275487408042


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 36 | took 74.59058880805969 seconds | summed loss: 266.898193359375 | avg loss: 0.001721834996715188


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 37 | took 73.00527954101562 seconds | summed loss: 261.4459533691406 | avg loss: 0.0016866610385477543


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 38 | took 73.75021457672119 seconds | summed loss: 248.8384552001953 | avg loss: 0.0016053265426307917


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 39 | took 72.81670069694519 seconds | summed loss: 249.13270568847656 | avg loss: 0.0016072249272838235


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 40 | took 73.56586050987244 seconds | summed loss: 244.26446533203125 | avg loss: 0.0015758185181766748


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 41 | took 73.4819483757019 seconds | summed loss: 237.51812744140625 | avg loss: 0.0015322959516197443


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 42 | took 71.70776653289795 seconds | summed loss: 233.3733673095703 | avg loss: 0.001505556982010603


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 43 | took 74.14649844169617 seconds | summed loss: 231.6949462890625 | avg loss: 0.001494728960096836


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 44 | took 75.97145128250122 seconds | summed loss: 227.80239868164062 | avg loss: 0.0014696171274408698


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 45 | took 76.65492343902588 seconds | summed loss: 227.6731414794922 | avg loss: 0.0014687832444906235


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 46 | took 78.86441922187805 seconds | summed loss: 218.3759002685547 | avg loss: 0.0014088040916249156


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 47 | took 76.14458012580872 seconds | summed loss: 212.96339416503906 | avg loss: 0.0013738864799961448


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 48 | took 85.59425830841064 seconds | summed loss: 237.20150756835938 | avg loss: 0.001530253328382969


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 49 | took 81.50543427467346 seconds | summed loss: 219.9207763671875 | avg loss: 0.0014187705237418413


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 50 | took 82.26378989219666 seconds | summed loss: 202.0861053466797 | avg loss: 0.0013037141179665923
Took 2870.2313010692596 seconds
Attention_Net(
  (embedding): Embedding(400000, 100)
  (lstm): LSTM(100, 64, batch_first=True, dropout=0.2, bidirectional=True)
  (linear): Linear(in_features=128, out_features=2, bias=True)
)




Eval accuracy: 0.8277810750135474
Eval summed loss: 288.8629455566406 | avg loss: 0.007453950587660074
torch.Size([400000, 100])


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 1 | took 56.5425660610199 seconds | summed loss: 587.4685668945312 | avg loss: 0.003789924317970872


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 2 | took 32.141677379608154 seconds | summed loss: 493.9292907714844 | avg loss: 0.003186476184055209


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 3 | took 32.20674538612366 seconds | summed loss: 453.40216064453125 | avg loss: 0.0029250243678689003


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 4 | took 32.834548473358154 seconds | summed loss: 429.5970458984375 | avg loss: 0.0027714509051293135


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 5 | took 33.48216509819031 seconds | summed loss: 410.0317687988281 | avg loss: 0.0026452296879142523


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 6 | took 33.54922294616699 seconds | summed loss: 396.97412109375 | avg loss: 0.0025609913282096386


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 7 | took 33.55231261253357 seconds | summed loss: 384.99615478515625 | avg loss: 0.0024837180972099304


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 8 | took 33.24259090423584 seconds | summed loss: 372.2865905761719 | avg loss: 0.0024017251562327147


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 9 | took 33.25823664665222 seconds | summed loss: 362.5625915527344 | avg loss: 0.0023389928974211216


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 10 | took 33.14546871185303 seconds | summed loss: 350.88458251953125 | avg loss: 0.0022636547219008207


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 11 | took 33.393667459487915 seconds | summed loss: 340.99835205078125 | avg loss: 0.0021998758893460035


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 12 | took 33.5687141418457 seconds | summed loss: 333.1418151855469 | avg loss: 0.0021491912193596363


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 13 | took 33.60169053077698 seconds | summed loss: 322.1197204589844 | avg loss: 0.002078084507957101


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 14 | took 32.87501788139343 seconds | summed loss: 312.8095397949219 | avg loss: 0.002018021885305643


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 15 | took 34.621339559555054 seconds | summed loss: 304.20892333984375 | avg loss: 0.001962536945939064


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 16 | took 32.75344491004944 seconds | summed loss: 295.61822509765625 | avg loss: 0.0019071160349994898


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 17 | took 31.89088225364685 seconds | summed loss: 284.0843200683594 | avg loss: 0.0018327075522392988


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 18 | took 54.332340717315674 seconds | summed loss: 276.16656494140625 | avg loss: 0.0017816278850659728


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 19 | took 54.146952629089355 seconds | summed loss: 264.9232177734375 | avg loss: 0.0017090939218178391


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 20 | took 53.00317025184631 seconds | summed loss: 260.1147155761719 | avg loss: 0.0016780728474259377


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 21 | took 53.254287242889404 seconds | summed loss: 251.3904266357422 | avg loss: 0.0016217901138588786


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 22 | took 53.147114515304565 seconds | summed loss: 243.0941925048828 | avg loss: 0.0015682687517255545


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 23 | took 51.45812726020813 seconds | summed loss: 234.439208984375 | avg loss: 0.0015124330529943109


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 24 | took 52.51570105552673 seconds | summed loss: 227.121337890625 | avg loss: 0.0014652233803644776


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 25 | took 52.80074954032898 seconds | summed loss: 220.68621826171875 | avg loss: 0.0014237086288630962


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 26 | took 50.30807042121887 seconds | summed loss: 212.16900634765625 | avg loss: 0.0013687616446986794


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 27 | took 51.88766074180603 seconds | summed loss: 205.9298553466797 | avg loss: 0.0013285111635923386


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 28 | took 52.086028814315796 seconds | summed loss: 202.8035125732422 | avg loss: 0.001308342325501144


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 29 | took 51.34104633331299 seconds | summed loss: 193.78004455566406 | avg loss: 0.0012501293094828725


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 30 | took 53.926902294158936 seconds | summed loss: 190.06228637695312 | avg loss: 0.001226145075634122


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 31 | took 53.598963260650635 seconds | summed loss: 182.82296752929688 | avg loss: 0.0011794421589002013


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 32 | took 51.69017767906189 seconds | summed loss: 178.79385375976562 | avg loss: 0.001153449178673327


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 33 | took 53.83328890800476 seconds | summed loss: 171.02012634277344 | avg loss: 0.0011032987385988235


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 34 | took 53.90896272659302 seconds | summed loss: 165.94869995117188 | avg loss: 0.0010705814929679036


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 35 | took 53.87122583389282 seconds | summed loss: 162.10494995117188 | avg loss: 0.0010457844473421574


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 36 | took 34.46932911872864 seconds | summed loss: 162.1345977783203 | avg loss: 0.001045975717715919


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 37 | took 32.135090351104736 seconds | summed loss: 152.13345336914062 | avg loss: 0.000981455552391708


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 38 | took 35.22176384925842 seconds | summed loss: 151.33480834960938 | avg loss: 0.0009763032430782914


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 39 | took 34.25524973869324 seconds | summed loss: 145.06292724609375 | avg loss: 0.0009358415845781565


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 40 | took 32.651774644851685 seconds | summed loss: 142.70704650878906 | avg loss: 0.0009206431568600237


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 41 | took 38.71059966087341 seconds | summed loss: 139.06381225585938 | avg loss: 0.0008971396018750966


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 42 | took 51.323944091796875 seconds | summed loss: 138.46287536621094 | avg loss: 0.0008932627970352769


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 43 | took 47.927802324295044 seconds | summed loss: 128.19403076171875 | avg loss: 0.0008270156104117632


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 44 | took 47.54757523536682 seconds | summed loss: 128.4368438720703 | avg loss: 0.0008285820367746055


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 45 | took 48.14254140853882 seconds | summed loss: 131.18313598632812 | avg loss: 0.0008462991681881249


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 46 | took 47.882190227508545 seconds | summed loss: 128.59568786621094 | avg loss: 0.0008296067826449871


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 47 | took 47.598814249038696 seconds | summed loss: 123.96809387207031 | avg loss: 0.0007997528882697225


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 48 | took 47.45559477806091 seconds | summed loss: 114.2665786743164 | avg loss: 0.0007371656829491258


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 49 | took 47.155938386917114 seconds | summed loss: 116.16864776611328 | avg loss: 0.0007494364981539547


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 50 | took 48.15639638900757 seconds | summed loss: 114.8936767578125 | avg loss: 0.0007412112900055945
Took 2178.492645740509 seconds
Attention_Net(
  (embedding): Embedding(400000, 100)
  (lstm): LSTM(100, 64, batch_first=True, dropout=0.2, bidirectional=True)
  (linear): Linear(in_features=128, out_features=2, bias=True)
)
Eval accuracy: 0.7884034784403788
Eval summed loss: 154.44448852539062 | avg loss: 0.003985355608165264
torch.Size([400000, 100])


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 1 | took 95.6023952960968 seconds | summed loss: 1144.7178955078125 | avg loss: 0.00738489581272006


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 2 | took 95.90421104431152 seconds | summed loss: 926.1547241210938 | avg loss: 0.005974883679300547


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 3 | took 97.33690071105957 seconds | summed loss: 831.4724731445312 | avg loss: 0.005364061798900366


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 4 | took 98.01703596115112 seconds | summed loss: 772.2047119140625 | avg loss: 0.004981708712875843


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 5 | took 100.7918848991394 seconds | summed loss: 713.7374877929688 | avg loss: 0.00460452027618885


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 6 | took 102.20597839355469 seconds | summed loss: 652.7891235351562 | avg loss: 0.004211325664073229


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 7 | took 102.68026161193848 seconds | summed loss: 591.97412109375 | avg loss: 0.003818990895524621


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 8 | took 104.19538950920105 seconds | summed loss: 525.8984375 | avg loss: 0.003392718033865094


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 9 | took 98.67047047615051 seconds | summed loss: 457.4631652832031 | avg loss: 0.002951222937554121


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 10 | took 75.0800416469574 seconds | summed loss: 395.49237060546875 | avg loss: 0.0025514320004731417


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 11 | took 65.29516363143921 seconds | summed loss: 339.6652526855469 | avg loss: 0.002191275591030717


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 12 | took 62.87720441818237 seconds | summed loss: 290.3785400390625 | avg loss: 0.0018733133329078555


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 13 | took 61.38602638244629 seconds | summed loss: 248.8980255126953 | avg loss: 0.0016057108296081424


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 14 | took 61.42015242576599 seconds | summed loss: 213.4970703125 | avg loss: 0.0013773293467238545


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 15 | took 61.335211753845215 seconds | summed loss: 187.37588500976562 | avg loss: 0.0012088143266737461


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 16 | took 61.533154249191284 seconds | summed loss: 171.5428924560547 | avg loss: 0.0011066712904721498


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 17 | took 61.3236825466156 seconds | summed loss: 148.68997192382812 | avg loss: 0.000959240656811744


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 18 | took 61.367443561553955 seconds | summed loss: 139.9540252685547 | avg loss: 0.0009028826025314629


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 19 | took 61.29001712799072 seconds | summed loss: 128.76803588867188 | avg loss: 0.0008307186653837562


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 20 | took 61.36760091781616 seconds | summed loss: 123.72327423095703 | avg loss: 0.0007981734815984964


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 21 | took 61.320016860961914 seconds | summed loss: 113.62411499023438 | avg loss: 0.0007330210064537823


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 22 | took 61.333799600601196 seconds | summed loss: 115.96735382080078 | avg loss: 0.0007481378852389753


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 23 | took 61.29945921897888 seconds | summed loss: 104.17273712158203 | avg loss: 0.0006720474921166897


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 24 | took 61.34031701087952 seconds | summed loss: 97.13455200195312 | avg loss: 0.000626642198767513


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 25 | took 61.37286138534546 seconds | summed loss: 102.67903900146484 | avg loss: 0.0006624112138524652


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 26 | took 61.38428783416748 seconds | summed loss: 93.85041046142578 | avg loss: 0.0006054553086869419


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 27 | took 61.2910373210907 seconds | summed loss: 92.87922668457031 | avg loss: 0.0005991898942738771


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 28 | took 61.40554857254028 seconds | summed loss: 97.53352355957031 | avg loss: 0.0006292160833254457


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 29 | took 61.40261149406433 seconds | summed loss: 87.65447998046875 | avg loss: 0.0005654835840687156


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 30 | took 61.78601264953613 seconds | summed loss: 94.87063598632812 | avg loss: 0.0006120370235294104


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 31 | took 61.25092911720276 seconds | summed loss: 86.14989471435547 | avg loss: 0.0005557771073654294


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 32 | took 61.268234729766846 seconds | summed loss: 84.04969024658203 | avg loss: 0.0005422281101346016


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 33 | took 61.27816200256348 seconds | summed loss: 83.9260482788086 | avg loss: 0.0005414304323494434


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 34 | took 61.34725213050842 seconds | summed loss: 82.42933654785156 | avg loss: 0.0005317747127264738


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 35 | took 61.39849662780762 seconds | summed loss: 80.75968933105469 | avg loss: 0.0005210033850744367


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 36 | took 61.39077687263489 seconds | summed loss: 81.61888885498047 | avg loss: 0.0005265463260002434


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 37 | took 61.286330699920654 seconds | summed loss: 80.87782287597656 | avg loss: 0.0005217654979787767


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 38 | took 61.30494713783264 seconds | summed loss: 78.82878112792969 | avg loss: 0.0005085465381853282


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 39 | took 61.30362343788147 seconds | summed loss: 75.83354187011719 | avg loss: 0.0004892233991995454


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 40 | took 61.36105465888977 seconds | summed loss: 81.74138641357422 | avg loss: 0.0005273365532048047


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 41 | took 61.32674813270569 seconds | summed loss: 74.13951873779297 | avg loss: 0.00047829479444772005


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 42 | took 61.336336612701416 seconds | summed loss: 73.6639175415039 | avg loss: 0.00047522655222564936


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 43 | took 61.34622383117676 seconds | summed loss: 77.8762435913086 | avg loss: 0.0005024014390073717


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 44 | took 61.50438046455383 seconds | summed loss: 71.68363189697266 | avg loss: 0.00046245119301602244


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 45 | took 61.41961669921875 seconds | summed loss: 75.30152130126953 | avg loss: 0.00048579121357761323


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 46 | took 61.430357694625854 seconds | summed loss: 71.13533020019531 | avg loss: 0.0004589139425661415


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 47 | took 61.48274898529053 seconds | summed loss: 76.2435302734375 | avg loss: 0.0004918683553114533


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 48 | took 61.46987199783325 seconds | summed loss: 65.12422180175781 | avg loss: 0.0004201346018817276


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 49 | took 61.450555086135864 seconds | summed loss: 73.4808349609375 | avg loss: 0.0004740454605780542


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 50 | took 61.450563192367554 seconds | summed loss: 64.67162322998047 | avg loss: 0.00041721476009115577
Took 3431.0955550670624 seconds
Attention_Net(
  (embedding): Embedding(400000, 100)
  (lstm): LSTM(100, 128, batch_first=True, dropout=0.2, bidirectional=True)
  (linear): Linear(in_features=256, out_features=2, bias=True)
)
Eval accuracy: 0.8286584264444044
Eval summed loss: 289.04754638671875 | avg loss: 0.007458714302629232
torch.Size([400000, 100])


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 1 | took 44.10019874572754 seconds | summed loss: 592.372802734375 | avg loss: 0.0038215629756450653


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 2 | took 43.97172474861145 seconds | summed loss: 486.9513854980469 | avg loss: 0.0031414597760885954


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 3 | took 44.19677925109863 seconds | summed loss: 436.83209228515625 | avg loss: 0.002818126231431961


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 4 | took 43.96434760093689 seconds | summed loss: 407.0212707519531 | avg loss: 0.0026258081197738647


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 5 | took 43.952603816986084 seconds | summed loss: 380.27191162109375 | avg loss: 0.002453240565955639


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 6 | took 43.948357343673706 seconds | summed loss: 356.1169128417969 | avg loss: 0.0022974100429564714


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 7 | took 43.93590450286865 seconds | summed loss: 330.2706298828125 | avg loss: 0.0021306683775037527


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 8 | took 43.93308448791504 seconds | summed loss: 304.6337890625 | avg loss: 0.0019652778282761574


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 9 | took 43.96021366119385 seconds | summed loss: 278.2686767578125 | avg loss: 0.0017951892223209143


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 10 | took 43.92073321342468 seconds | summed loss: 252.37496948242188 | avg loss: 0.0016281416174024343


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 11 | took 43.950252294540405 seconds | summed loss: 225.3983917236328 | avg loss: 0.0014541081618517637


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 12 | took 43.846924781799316 seconds | summed loss: 199.12049865722656 | avg loss: 0.0012845820747315884


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 13 | took 43.809571504592896 seconds | summed loss: 174.6566925048828 | avg loss: 0.0011267592199146748


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 14 | took 43.853286266326904 seconds | summed loss: 151.91043090820312 | avg loss: 0.0009800167754292488


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 15 | took 43.81916379928589 seconds | summed loss: 134.84487915039062 | avg loss: 0.0008699221070855856


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 16 | took 43.80668354034424 seconds | summed loss: 115.26805877685547 | avg loss: 0.0007436265586875379


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 17 | took 43.80595541000366 seconds | summed loss: 103.56733703613281 | avg loss: 0.0006681418744847178


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 18 | took 43.8823983669281 seconds | summed loss: 92.32363891601562 | avg loss: 0.0005956056411378086


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 19 | took 43.92766094207764 seconds | summed loss: 79.15406799316406 | avg loss: 0.0005106450407765806


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 20 | took 43.83591294288635 seconds | summed loss: 72.01655578613281 | avg loss: 0.000464598968392238


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 21 | took 43.84594106674194 seconds | summed loss: 64.00621795654297 | avg loss: 0.0004129220324102789


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 22 | took 43.82134008407593 seconds | summed loss: 64.42146301269531 | avg loss: 0.00041560089448466897


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 23 | took 43.86390805244446 seconds | summed loss: 56.703670501708984 | avg loss: 0.0003658112545963377


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 24 | took 43.846463203430176 seconds | summed loss: 47.14365768432617 | avg loss: 0.00030413694912567735


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 25 | took 43.90849423408508 seconds | summed loss: 52.93524169921875 | avg loss: 0.00034150006831623614


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 26 | took 43.909220933914185 seconds | summed loss: 48.16199493408203 | avg loss: 0.0003107065276708454


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 27 | took 43.82876515388489 seconds | summed loss: 44.51277542114258 | avg loss: 0.0002871643810067326


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 28 | took 43.84507703781128 seconds | summed loss: 42.94059371948242 | avg loss: 0.0002770218125078827


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 29 | took 43.834248781204224 seconds | summed loss: 39.03020095825195 | avg loss: 0.0002517947577871382


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 30 | took 43.7915620803833 seconds | summed loss: 40.368682861328125 | avg loss: 0.0002604296896606684


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 31 | took 43.85541558265686 seconds | summed loss: 46.81563949584961 | avg loss: 0.00030202080961316824


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 32 | took 43.82013750076294 seconds | summed loss: 35.048362731933594 | avg loss: 0.00022610680025536567


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 33 | took 43.820286989212036 seconds | summed loss: 38.33414077758789 | avg loss: 0.00024730426957830787


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 34 | took 43.87869381904602 seconds | summed loss: 38.68189239501953 | avg loss: 0.00024954770924523473


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 35 | took 43.78467082977295 seconds | summed loss: 34.19429397583008 | avg loss: 0.00022059696493670344


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 36 | took 43.86130714416504 seconds | summed loss: 40.655941009521484 | avg loss: 0.0002622828760650009


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 37 | took 43.8692045211792 seconds | summed loss: 34.122745513916016 | avg loss: 0.00022013539273757488


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 38 | took 43.81885504722595 seconds | summed loss: 36.915428161621094 | avg loss: 0.0002381517697358504


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 39 | took 43.842447996139526 seconds | summed loss: 30.359521865844727 | avg loss: 0.00019585777772590518


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 40 | took 44.254136085510254 seconds | summed loss: 30.586219787597656 | avg loss: 0.0001973202743101865


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 41 | took 43.75328993797302 seconds | summed loss: 31.725357055664062 | avg loss: 0.0002046691661234945


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 42 | took 43.81260108947754 seconds | summed loss: 37.66170883178711 | avg loss: 0.00024296622723340988


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 43 | took 43.794564723968506 seconds | summed loss: 27.79277992248535 | avg loss: 0.00017929900786839426


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 44 | took 43.88371729850769 seconds | summed loss: 27.828577041625977 | avg loss: 0.00017952994676306844


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 45 | took 43.793243169784546 seconds | summed loss: 32.41975402832031 | avg loss: 0.00020914891501888633


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 46 | took 43.801645278930664 seconds | summed loss: 33.345211029052734 | avg loss: 0.00021511930390261114


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 47 | took 43.747464418411255 seconds | summed loss: 31.239288330078125 | avg loss: 0.0002015334030147642


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 48 | took 43.873533725738525 seconds | summed loss: 25.867677688598633 | avg loss: 0.00016687963216099888


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 49 | took 43.80359101295471 seconds | summed loss: 28.61679458618164 | avg loss: 0.00018461495346855372


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 50 | took 43.80343317985535 seconds | summed loss: 28.6556453704834 | avg loss: 0.0001848655956564471
Took 2193.8549666404724 seconds
Attention_Net(
  (embedding): Embedding(400000, 100)
  (lstm): LSTM(100, 128, batch_first=True, dropout=0.2, bidirectional=True)
  (linear): Linear(in_features=256, out_features=2, bias=True)
)
Eval accuracy: 0.8340773617526385
Eval summed loss: 142.3045654296875 | avg loss: 0.003672091756016016
torch.Size([400000, 100])


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 1 | took 44.71309018135071 seconds | summed loss: 1280.8297119140625 | avg loss: 0.008262990973889828


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 2 | took 44.7454776763916 seconds | summed loss: 1181.5570068359375 | avg loss: 0.007622555363923311


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 3 | took 44.76343011856079 seconds | summed loss: 1110.148681640625 | avg loss: 0.007161879912018776


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 4 | took 44.764230251312256 seconds | summed loss: 1055.34033203125 | avg loss: 0.0068082963116467


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 5 | took 44.719478607177734 seconds | summed loss: 1018.7325439453125 | avg loss: 0.00657212920486927


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 6 | took 44.70217561721802 seconds | summed loss: 989.232421875 | avg loss: 0.0063818152993917465


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 7 | took 44.63876748085022 seconds | summed loss: 965.9315185546875 | avg loss: 0.006231494713574648


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 8 | took 44.736103773117065 seconds | summed loss: 945.9450073242188 | avg loss: 0.006102556362748146


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 9 | took 44.847270488739014 seconds | summed loss: 929.41259765625 | avg loss: 0.005995900835841894


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 10 | took 45.20586085319519 seconds | summed loss: 917.6075439453125 | avg loss: 0.005919743329286575


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 11 | took 44.732054233551025 seconds | summed loss: 902.85400390625 | avg loss: 0.005824564024806023


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 12 | took 44.7690749168396 seconds | summed loss: 893.9813232421875 | avg loss: 0.005767324008047581


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 13 | took 44.72795820236206 seconds | summed loss: 882.4024658203125 | avg loss: 0.005692625418305397


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 14 | took 44.732829332351685 seconds | summed loss: 873.1064453125 | avg loss: 0.005632654298096895


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 15 | took 44.76404929161072 seconds | summed loss: 866.9088745117188 | avg loss: 0.005592672154307365


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 16 | took 44.766642332077026 seconds | summed loss: 857.791748046875 | avg loss: 0.005533854942768812


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 17 | took 44.75744342803955 seconds | summed loss: 852.8198852539062 | avg loss: 0.005501780193299055


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 18 | took 44.71976709365845 seconds | summed loss: 843.6814575195312 | avg loss: 0.005442825611680746


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 19 | took 44.72535300254822 seconds | summed loss: 838.2529296875 | avg loss: 0.005407804623246193


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 20 | took 44.81429052352905 seconds | summed loss: 831.7166137695312 | avg loss: 0.005365636665374041


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 21 | took 44.78931641578674 seconds | summed loss: 826.2283325195312 | avg loss: 0.005330230575054884


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 22 | took 44.76087307929993 seconds | summed loss: 822.5127563476562 | avg loss: 0.005306260194629431


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 23 | took 44.69442558288574 seconds | summed loss: 815.3113403320312 | avg loss: 0.005259801633656025


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 24 | took 44.72688150405884 seconds | summed loss: 809.4909057617188 | avg loss: 0.005222252570092678


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 25 | took 44.86675000190735 seconds | summed loss: 804.3583374023438 | avg loss: 0.005189140792936087


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 26 | took 44.948097229003906 seconds | summed loss: 801.30126953125 | avg loss: 0.005169419106096029


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 27 | took 44.864299297332764 seconds | summed loss: 794.9715576171875 | avg loss: 0.005128584336489439


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 28 | took 44.788838386535645 seconds | summed loss: 790.7387084960938 | avg loss: 0.0051012770272791386


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 29 | took 44.81825137138367 seconds | summed loss: 788.8309936523438 | avg loss: 0.005088969599455595


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 30 | took 44.94153952598572 seconds | summed loss: 784.5574340820312 | avg loss: 0.005061399657279253


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 31 | took 44.71788835525513 seconds | summed loss: 780.4039916992188 | avg loss: 0.005034604575484991


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 32 | took 44.85487103462219 seconds | summed loss: 774.4432983398438 | avg loss: 0.004996150732040405


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 33 | took 44.94241189956665 seconds | summed loss: 771.7864990234375 | avg loss: 0.004979010671377182


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 34 | took 44.81264591217041 seconds | summed loss: 767.6890869140625 | avg loss: 0.004952577408403158


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 35 | took 44.70714497566223 seconds | summed loss: 764.7955322265625 | avg loss: 0.0049339099787175655


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 36 | took 44.74571704864502 seconds | summed loss: 760.8866577148438 | avg loss: 0.0049086930230259895


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 37 | took 44.728922843933105 seconds | summed loss: 757.4923095703125 | avg loss: 0.004886795300990343


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 38 | took 44.81194067001343 seconds | summed loss: 753.6105346679688 | avg loss: 0.00486175250262022


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 39 | took 44.77926802635193 seconds | summed loss: 750.8595581054688 | avg loss: 0.004844005219638348


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 40 | took 44.770293951034546 seconds | summed loss: 747.2393798828125 | avg loss: 0.004820650443434715


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 41 | took 44.665794372558594 seconds | summed loss: 742.2608032226562 | avg loss: 0.004788532387465239


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 42 | took 44.78773307800293 seconds | summed loss: 739.6954956054688 | avg loss: 0.0047719827853143215


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 43 | took 44.760796546936035 seconds | summed loss: 736.1255493164062 | avg loss: 0.004748952109366655


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 44 | took 44.71010613441467 seconds | summed loss: 734.7365112304688 | avg loss: 0.004739990923553705


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 45 | took 44.7234525680542 seconds | summed loss: 731.02978515625 | avg loss: 0.004716077819466591


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 46 | took 44.701518535614014 seconds | summed loss: 727.437744140625 | avg loss: 0.0046929046511650085


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 47 | took 44.73020124435425 seconds | summed loss: 725.6038818359375 | avg loss: 0.004681074060499668


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 48 | took 44.8558144569397 seconds | summed loss: 721.8973388671875 | avg loss: 0.004657161887735128


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 49 | took 44.87369918823242 seconds | summed loss: 717.9417114257812 | avg loss: 0.004631643183529377


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 50 | took 44.766170263290405 seconds | summed loss: 715.10205078125 | avg loss: 0.004613323602825403
Took 2239.0489325523376 seconds
Attention_Net(
  (embedding): Embedding(400000, 100)
  (lstm): LSTM(100, 64, batch_first=True, dropout=0.2, bidirectional=True)
  (linear): Linear(in_features=128, out_features=2, bias=True)
)
Eval accuracy: 0.8570949345857095
Eval summed loss: 280.8079528808594 | avg loss: 0.0072460961528122425
torch.Size([400000, 100])


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 1 | took 30.200644731521606 seconds | summed loss: 653.821533203125 | avg loss: 0.004217986017465591


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 2 | took 30.257193565368652 seconds | summed loss: 609.2301025390625 | avg loss: 0.003930313978344202


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 3 | took 30.147531986236572 seconds | summed loss: 586.0955810546875 | avg loss: 0.0037810667417943478


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 4 | took 30.286689519882202 seconds | summed loss: 563.2030639648438 | avg loss: 0.0036333806347101927


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 5 | took 30.271993160247803 seconds | summed loss: 541.1929931640625 | avg loss: 0.003491387702524662


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 6 | took 30.245906591415405 seconds | summed loss: 524.6987915039062 | avg loss: 0.003384978976100683


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 7 | took 30.227763652801514 seconds | summed loss: 512.4710083007812 | avg loss: 0.003306094091385603


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 8 | took 30.24718475341797 seconds | summed loss: 501.2860107421875 | avg loss: 0.003233936382457614


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 9 | took 30.217519760131836 seconds | summed loss: 492.1490478515625 | avg loss: 0.0031749913468956947


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 10 | took 30.172611713409424 seconds | summed loss: 485.8370056152344 | avg loss: 0.003134270664304495


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 11 | took 30.222018003463745 seconds | summed loss: 477.46551513671875 | avg loss: 0.0030802637338638306


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 12 | took 30.279457807540894 seconds | summed loss: 472.64910888671875 | avg loss: 0.0030491917859762907


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 13 | took 30.163031578063965 seconds | summed loss: 466.142822265625 | avg loss: 0.0030072180088609457


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 14 | took 30.165637016296387 seconds | summed loss: 462.956298828125 | avg loss: 0.0029866606928408146


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 15 | took 30.13159441947937 seconds | summed loss: 457.8348083496094 | avg loss: 0.0029536206275224686


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 16 | took 30.163182497024536 seconds | summed loss: 454.3245849609375 | avg loss: 0.002930975053459406


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 17 | took 30.168170928955078 seconds | summed loss: 450.181640625 | avg loss: 0.002904247958213091


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 18 | took 30.17327880859375 seconds | summed loss: 447.7882995605469 | avg loss: 0.0028888077940791845


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 19 | took 30.142637968063354 seconds | summed loss: 442.6481018066406 | avg loss: 0.002855646889656782


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 20 | took 30.19191336631775 seconds | summed loss: 441.17626953125 | avg loss: 0.00284615159034729


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 21 | took 30.13976740837097 seconds | summed loss: 436.9469909667969 | avg loss: 0.0028188673313707113


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 22 | took 30.14277720451355 seconds | summed loss: 435.4394226074219 | avg loss: 0.0028091417625546455


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 23 | took 30.323749542236328 seconds | summed loss: 431.40496826171875 | avg loss: 0.0027831143233925104


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 24 | took 30.375959157943726 seconds | summed loss: 430.6361389160156 | avg loss: 0.002778154332190752


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 25 | took 30.597779989242554 seconds | summed loss: 427.10308837890625 | avg loss: 0.0027553616091609


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 26 | took 30.72792887687683 seconds | summed loss: 424.5196533203125 | avg loss: 0.002738695126026869


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 27 | took 30.71666717529297 seconds | summed loss: 422.3311767578125 | avg loss: 0.0027245767414569855


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 28 | took 30.78982162475586 seconds | summed loss: 422.29241943359375 | avg loss: 0.0027243266813457012


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 29 | took 30.82657790184021 seconds | summed loss: 419.7847595214844 | avg loss: 0.002708149142563343


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 30 | took 30.64208459854126 seconds | summed loss: 417.7420349121094 | avg loss: 0.002694970928132534


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 31 | took 30.60966157913208 seconds | summed loss: 414.9217224121094 | avg loss: 0.0026767761446535587


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 32 | took 30.579084157943726 seconds | summed loss: 413.0580749511719 | avg loss: 0.00266475323587656


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 33 | took 30.62530827522278 seconds | summed loss: 411.10162353515625 | avg loss: 0.0026521317195147276


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 34 | took 30.527672052383423 seconds | summed loss: 409.14923095703125 | avg loss: 0.002639536280184984


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 35 | took 30.581414461135864 seconds | summed loss: 408.9769287109375 | avg loss: 0.0026384247466921806


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 36 | took 30.629358530044556 seconds | summed loss: 407.250732421875 | avg loss: 0.002627288457006216


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 37 | took 30.67992925643921 seconds | summed loss: 404.75872802734375 | avg loss: 0.0026112119667232037


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 38 | took 30.590718507766724 seconds | summed loss: 404.6126708984375 | avg loss: 0.0026102697011083364


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 39 | took 30.540529251098633 seconds | summed loss: 403.1177062988281 | avg loss: 0.002600625157356262


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 40 | took 30.468977212905884 seconds | summed loss: 401.51824951171875 | avg loss: 0.0025903068017214537


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 41 | took 30.618051052093506 seconds | summed loss: 399.0054931640625 | avg loss: 0.0025740962009876966


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 42 | took 31.066184997558594 seconds | summed loss: 397.3922424316406 | avg loss: 0.0025636886712163687


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 43 | took 31.096785068511963 seconds | summed loss: 395.7444152832031 | avg loss: 0.0025530580896884203


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 44 | took 30.56187081336975 seconds | summed loss: 395.266357421875 | avg loss: 0.002549974014982581


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 45 | took 30.607177734375 seconds | summed loss: 393.42791748046875 | avg loss: 0.002538113621994853


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 46 | took 30.64014744758606 seconds | summed loss: 392.538330078125 | avg loss: 0.0025323748122900724


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 47 | took 30.625843286514282 seconds | summed loss: 390.3876647949219 | avg loss: 0.0025185002014040947


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 48 | took 30.61044216156006 seconds | summed loss: 388.1963806152344 | avg loss: 0.0025043636560440063


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 49 | took 30.58130383491516 seconds | summed loss: 389.28887939453125 | avg loss: 0.0025114116724580526


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 50 | took 30.690624475479126 seconds | summed loss: 386.72015380859375 | avg loss: 0.0024948399513959885
Took 1522.3530540466309 seconds
Attention_Net(
  (embedding): Embedding(400000, 100)
  (lstm): LSTM(100, 64, batch_first=True, dropout=0.2, bidirectional=True)
  (linear): Linear(in_features=128, out_features=2, bias=True)
)
Eval accuracy: 0.8523469150775423
Eval summed loss: 141.1426544189453 | avg loss: 0.0036421092227101326
torch.Size([400000, 100])


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 1 | took 61.995505571365356 seconds | summed loss: 1266.7259521484375 | avg loss: 0.00817200355231762


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 2 | took 61.77913761138916 seconds | summed loss: 1173.651123046875 | avg loss: 0.007571551948785782


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 3 | took 61.94371557235718 seconds | summed loss: 1085.311767578125 | avg loss: 0.007001650054007769


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 4 | took 62.43803906440735 seconds | summed loss: 1021.7838134765625 | avg loss: 0.006591813638806343


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 5 | took 62.026779651641846 seconds | summed loss: 979.8098754882812 | avg loss: 0.006321027874946594


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 6 | took 61.91066908836365 seconds | summed loss: 950.5851440429688 | avg loss: 0.006132490932941437


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 7 | took 61.944419145584106 seconds | summed loss: 927.5241088867188 | avg loss: 0.005983717739582062


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 8 | took 61.84855127334595 seconds | summed loss: 909.4969482421875 | avg loss: 0.0058674197643995285


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 9 | took 61.82961893081665 seconds | summed loss: 891.1504516601562 | avg loss: 0.0057490612380206585


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 10 | took 61.760247230529785 seconds | summed loss: 877.4931030273438 | avg loss: 0.0056609539315104485


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 11 | took 61.795857429504395 seconds | summed loss: 864.0142822265625 | avg loss: 0.0055739982053637505


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 12 | took 61.88647174835205 seconds | summed loss: 851.3230590820312 | avg loss: 0.005492123309522867


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 13 | took 61.864519357681274 seconds | summed loss: 843.6232299804688 | avg loss: 0.005442449823021889


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 14 | took 61.85058617591858 seconds | summed loss: 829.4486083984375 | avg loss: 0.005351005122065544


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 15 | took 61.81273579597473 seconds | summed loss: 823.3616943359375 | avg loss: 0.005311736837029457


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 16 | took 61.769222021102905 seconds | summed loss: 812.5009155273438 | avg loss: 0.005241671111434698


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 17 | took 61.87558555603027 seconds | summed loss: 805.9449462890625 | avg loss: 0.005199376493692398


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 18 | took 62.07761216163635 seconds | summed loss: 795.0942993164062 | avg loss: 0.005129375960677862


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 19 | took 62.00162363052368 seconds | summed loss: 788.7221069335938 | avg loss: 0.005088267382234335


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 20 | took 61.9319908618927 seconds | summed loss: 782.1483154296875 | avg loss: 0.00504585774615407


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 21 | took 61.83932447433472 seconds | summed loss: 775.09765625 | avg loss: 0.00500037195160985


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 22 | took 61.862343072891235 seconds | summed loss: 765.4334106445312 | avg loss: 0.004938025493174791


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 23 | took 61.79703712463379 seconds | summed loss: 759.4224853515625 | avg loss: 0.004899247083812952


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 24 | took 61.89706206321716 seconds | summed loss: 752.2681274414062 | avg loss: 0.004853092599660158


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 25 | took 61.81681036949158 seconds | summed loss: 743.8848266601562 | avg loss: 0.004799009300768375


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 26 | took 61.80924725532532 seconds | summed loss: 741.3629760742188 | avg loss: 0.004782740026712418


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 27 | took 61.80795454978943 seconds | summed loss: 732.3431396484375 | avg loss: 0.004724550526589155


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 28 | took 61.88516068458557 seconds | summed loss: 727.698974609375 | avg loss: 0.004694589879363775


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 29 | took 61.88806176185608 seconds | summed loss: 719.4456787109375 | avg loss: 0.004641345702111721


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 30 | took 61.83678364753723 seconds | summed loss: 710.2293701171875 | avg loss: 0.004581888671964407


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 31 | took 61.92767596244812 seconds | summed loss: 705.8535766601562 | avg loss: 0.004553659353405237


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 32 | took 61.91239786148071 seconds | summed loss: 698.4500732421875 | avg loss: 0.00450589694082737


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 33 | took 61.909178256988525 seconds | summed loss: 692.3003540039062 | avg loss: 0.004466223530471325


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 34 | took 61.919256925582886 seconds | summed loss: 686.3583984375 | avg loss: 0.004427890293300152


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 35 | took 61.77137589454651 seconds | summed loss: 680.2569580078125 | avg loss: 0.004388528410345316


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 36 | took 61.749327659606934 seconds | summed loss: 671.3093872070312 | avg loss: 0.004330805037170649


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 37 | took 61.978156328201294 seconds | summed loss: 667.147705078125 | avg loss: 0.0043039568699896336


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 38 | took 61.91590142250061 seconds | summed loss: 655.3480224609375 | avg loss: 0.0042278338223695755


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 39 | took 61.83102750778198 seconds | summed loss: 653.2134399414062 | avg loss: 0.004214062821120024


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 40 | took 61.85589814186096 seconds | summed loss: 646.64404296875 | avg loss: 0.004171682056039572


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 41 | took 61.957563638687134 seconds | summed loss: 638.3878784179688 | avg loss: 0.004118419252336025


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 42 | took 61.90033745765686 seconds | summed loss: 631.7467651367188 | avg loss: 0.004075575154274702


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 43 | took 61.88641357421875 seconds | summed loss: 624.131103515625 | avg loss: 0.004026444628834724


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 44 | took 61.746490240097046 seconds | summed loss: 619.0786743164062 | avg loss: 0.003993849735707045


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 45 | took 61.76023864746094 seconds | summed loss: 612.4217529296875 | avg loss: 0.003950904123485088


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 46 | took 61.8553569316864 seconds | summed loss: 606.2111206054688 | avg loss: 0.003910837695002556


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 47 | took 61.914551973342896 seconds | summed loss: 597.7765502929688 | avg loss: 0.0038564240094274282


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 48 | took 61.89403581619263 seconds | summed loss: 591.9274291992188 | avg loss: 0.003818689612671733


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 49 | took 61.820072650909424 seconds | summed loss: 582.473876953125 | avg loss: 0.003757702186703682


HBox(children=(IntProgress(value=0, max=2422), HTML(value='')))


Epoch 50 | took 61.8701171875 seconds | summed loss: 576.369873046875 | avg loss: 0.003718323539942503
Took 3094.2199823856354 seconds
Attention_Net(
  (embedding): Embedding(400000, 100)
  (lstm): LSTM(100, 128, batch_first=True, dropout=0.2, bidirectional=True)
  (linear): Linear(in_features=256, out_features=2, bias=True)
)
Eval accuracy: 0.8539467912161639
Eval summed loss: 278.04278564453125 | avg loss: 0.00717474240809679
torch.Size([400000, 100])


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 1 | took 44.93777561187744 seconds | summed loss: 643.0120849609375 | avg loss: 0.004148250911384821


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 2 | took 44.92331290245056 seconds | summed loss: 600.9609375 | avg loss: 0.00387696735560894


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 3 | took 44.871695041656494 seconds | summed loss: 569.0142211914062 | avg loss: 0.003670870093628764


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 4 | took 44.88446307182312 seconds | summed loss: 536.1664428710938 | avg loss: 0.0034589599817991257


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 5 | took 44.830106019973755 seconds | summed loss: 511.7701110839844 | avg loss: 0.003301572287455201


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 6 | took 44.85476613044739 seconds | summed loss: 495.5711364746094 | avg loss: 0.003197068115696311


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 7 | took 44.844929456710815 seconds | summed loss: 480.5784606933594 | avg loss: 0.0031003463082015514


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 8 | took 44.851093769073486 seconds | summed loss: 470.5377197265625 | avg loss: 0.003035570727661252


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 9 | took 44.81894612312317 seconds | summed loss: 462.21221923828125 | avg loss: 0.0029818604234606028


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 10 | took 44.82874655723572 seconds | summed loss: 454.8736267089844 | avg loss: 0.0029345171060413122


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 11 | took 44.80732774734497 seconds | summed loss: 448.49517822265625 | avg loss: 0.0028933680150657892


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 12 | took 44.87572383880615 seconds | summed loss: 442.9441223144531 | avg loss: 0.002857556566596031


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 13 | took 44.78614592552185 seconds | summed loss: 438.10321044921875 | avg loss: 0.0028263265267014503


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 14 | took 44.854058027267456 seconds | summed loss: 433.5069885253906 | avg loss: 0.0027966750785708427


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 15 | took 44.7822265625 seconds | summed loss: 431.24371337890625 | avg loss: 0.002782074036076665


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 16 | took 45.00028586387634 seconds | summed loss: 424.31787109375 | avg loss: 0.0027373933698982


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 17 | took 44.981684923172 seconds | summed loss: 420.46044921875 | avg loss: 0.002712508197873831


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 18 | took 44.85071158409119 seconds | summed loss: 417.1790466308594 | avg loss: 0.0026913390029221773


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 19 | took 44.86944842338562 seconds | summed loss: 414.786376953125 | avg loss: 0.0026759030297398567


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 20 | took 45.05578327178955 seconds | summed loss: 411.0426330566406 | avg loss: 0.002651751274242997


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 21 | took 44.93572402000427 seconds | summed loss: 408.1721496582031 | avg loss: 0.002633232856169343


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 22 | took 45.021727561950684 seconds | summed loss: 404.22607421875 | avg loss: 0.002607775619253516


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 23 | took 44.97559905052185 seconds | summed loss: 402.070556640625 | avg loss: 0.002593869809061289


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 24 | took 44.951590061187744 seconds | summed loss: 400.81427001953125 | avg loss: 0.0025857652071863413


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 25 | took 44.88163638114929 seconds | summed loss: 396.02117919921875 | avg loss: 0.002554843667894602


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 26 | took 44.992199659347534 seconds | summed loss: 394.7840881347656 | avg loss: 0.0025468626990914345


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 27 | took 45.00008749961853 seconds | summed loss: 391.6030578613281 | avg loss: 0.0025263410061597824


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 28 | took 45.00247526168823 seconds | summed loss: 388.4844970703125 | avg loss: 0.002506222343072295


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 29 | took 44.7040581703186 seconds | summed loss: 387.0380554199219 | avg loss: 0.0024968909565359354


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 30 | took 44.7946503162384 seconds | summed loss: 383.5999450683594 | avg loss: 0.002474710810929537


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 31 | took 44.85105061531067 seconds | summed loss: 383.44049072265625 | avg loss: 0.0024736819323152304


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 32 | took 44.69736313819885 seconds | summed loss: 379.507568359375 | avg loss: 0.0024483096785843372


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 33 | took 44.766770124435425 seconds | summed loss: 378.1918029785156 | avg loss: 0.002439821371808648


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 34 | took 44.69220018386841 seconds | summed loss: 375.13311767578125 | avg loss: 0.002420088741928339


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 35 | took 44.65741467475891 seconds | summed loss: 372.70721435546875 | avg loss: 0.002404438564553857


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 36 | took 44.99077653884888 seconds | summed loss: 370.1048889160156 | avg loss: 0.0023876503109931946


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 37 | took 44.82302522659302 seconds | summed loss: 369.76904296875 | avg loss: 0.0023854835890233517


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 38 | took 44.71080279350281 seconds | summed loss: 366.97100830078125 | avg loss: 0.0023674326948821545


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 39 | took 44.63003897666931 seconds | summed loss: 363.69598388671875 | avg loss: 0.0023463047109544277


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 40 | took 44.61619830131531 seconds | summed loss: 362.0131530761719 | avg loss: 0.002335448283702135


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 41 | took 44.68460726737976 seconds | summed loss: 359.1317443847656 | avg loss: 0.0023168595507740974


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 42 | took 44.731343507766724 seconds | summed loss: 356.53912353515625 | avg loss: 0.0023001336958259344


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 43 | took 44.74559283256531 seconds | summed loss: 355.0207824707031 | avg loss: 0.002290338510647416


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 44 | took 44.62836837768555 seconds | summed loss: 353.0696105957031 | avg loss: 0.002277750987559557


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 45 | took 44.65640735626221 seconds | summed loss: 350.471923828125 | avg loss: 0.0022609925363212824


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 46 | took 44.719045639038086 seconds | summed loss: 349.72491455078125 | avg loss: 0.0022561734076589346


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 47 | took 44.69064259529114 seconds | summed loss: 347.3761901855469 | avg loss: 0.0022410210222005844


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 48 | took 44.73613357543945 seconds | summed loss: 346.88104248046875 | avg loss: 0.002237826818600297


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 49 | took 44.817145586013794 seconds | summed loss: 343.0782470703125 | avg loss: 0.002213293919339776


HBox(children=(IntProgress(value=0, max=1211), HTML(value='')))


Epoch 50 | took 44.78747820854187 seconds | summed loss: 342.0395202636719 | avg loss: 0.002206592820584774
Took 2241.4633293151855 seconds
Attention_Net(
  (embedding): Embedding(400000, 100)
  (lstm): LSTM(100, 128, batch_first=True, dropout=0.2, bidirectional=True)
  (linear): Linear(in_features=256, out_features=2, bias=True)
)
Eval accuracy: 0.8533790932314917
Eval summed loss: 138.96514892578125 | avg loss: 0.0035859199706465006


In [12]:
print(highest_acc)

0.8570949345857095


In [13]:
PATH = 'models/amzn_date4-20_gridsearch_batch' + str(best_batch) + '_epoch' + str(best_epochs) + '_lstm' + str(best_lstm_dim) + '.pt'
print(highest_acc)
torch.save(best_model, PATH)

0.8570949345857095


In [14]:
np.sum(df.recommendation.tolist()) / len(df)

0.7620507421399227