In [1]:
'''Import libraries'''
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import wandb
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
import os
from torch.utils.data import Dataset
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from imblearn.under_sampling import RandomUnderSampler
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertForSequenceClassification, AdamW, Trainer, TrainingArguments
from tqdm import tqdm
from torch.nn import functional as F
import torch.nn as nn

wandb.login()

  from .autonotebook import tqdm as notebook_tqdm
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33malberto-rodero557[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
'''Variables and parameters'''

SAMPLES_TO_TRAIN=-1
DIMENSIONS=200

N_LABELS=2
MAX_LEN = 256
EPOCHS=100
PATIENCE=10
LEARNING_RATE=.00005
WEIGHT_DECAY=.01
BATCH_SIZE=16
METRIC_FOR_BEST_MODEL='eval_loss'
if METRIC_FOR_BEST_MODEL=='eval_loss':
    GREATER_IS_BETTER = False
else:
    GREATER_IS_BETTER = True

In [4]:
'''Preparing dataset'''

import pickle

file_path = 'datasets/subtaskA_glove_train_dev_monolingual.pkl'

# Load the data from the pickle file
with open(file_path, 'rb') as file:
    data = pickle.load(file)

# Extract the individual datasets from the loaded data
train_x = data['train_x']
train_y = data['train_y']
dev_x = data['dev_x']
dev_y = data['dev_y']

if SAMPLES_TO_TRAIN!=-1:
    random_indices = np.random.choice(train_x.shape[0], size=SAMPLES_TO_TRAIN, replace=False)
    train_x = train_x[random_indices]
    train_y = train_y[random_indices]

    random_indices_dev = np.random.choice(dev_x.shape[0], size=int(train_x.shape[0] * 0.2), replace=False)
    dev_x = dev_x[random_indices_dev]
    dev_y = pd.Series(dev_y.to_numpy()[random_indices_dev])

print(train_x.shape)
print(train_y.shape)
print(dev_x.shape)
print(dev_y.shape)

print(type(train_y))
print(type(dev_y))

(119757, 200)
(119757,)
(5000, 200)
(5000,)
<class 'pandas.core.series.Series'>
<class 'pandas.core.series.Series'>


In [5]:
'''metrics'''

def compute_metrics(p):
    preds = np.argmax(p.predictions, axis=1)
    labels = p.label_ids
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    auc = roc_auc_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'auc': auc,
        'precision': precision,
        'recall': recall,
    }

In [6]:
class Data(Dataset):
    def __init__(self, X_train, y_train):
        self.X = torch.from_numpy(X_train.astype(np.float32))
        self.y = torch.from_numpy(y_train).type(torch.LongTensor)
        self.len = self.X.shape[0]

    def __getitem__(self, index):
        return {'input_ids': self.X[index], 'labels': self.y[index]}

    def __len__(self):
        return self.len

X_train, X_test, y_train, y_test = train_test_split(train_x, train_y.values, test_size=0.2, random_state=42)
traindata = Data(X_train, y_train)
testdata = Data(X_test, y_test)

devdata=Data(dev_x, dev_y.values)

In [7]:
# number of features (len of X cols)
input_dim = train_x.shape[-1]

# number of classes (unique of y)
output_dim = 2

class RNNModel(nn.Module):
    def __init__(self):
        super(RNNModel, self).__init__()
        
        self.lstm1 = nn.LSTM(input_dim, 512, batch_first=True)
        self.ln1 = nn.LayerNorm(512)
        self.dropout1 = nn.Dropout(0.2)
        
        self.lstm2 = nn.LSTM(512, 512, batch_first=True)
        self.ln2 = nn.LayerNorm(512)
        self.dropout2 = nn.Dropout(0.2)
        
        self.fc = nn.Linear(512, output_dim)
        
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input_ids, labels=None):
        # print(f"Input shape: {input_ids.shape}")
        
        x, _ = self.lstm1(input_ids)
        # print(f"After LSTM1: {x.shape}")

        x = self.ln1(x)
        x = self.dropout1(x)
        
        # print(f"Before LSTM2: {x.shape}")
        
        x, _ = self.lstm2(x)
        # print(f"After LSTM2: {x.shape}")

        x = self.ln2(x)
        x = self.dropout2(x)
        
        x = self.fc(x)
        # print(f"Output shape: {x.shape}")
        
        outputs = (x,)
        if labels is not None:
            loss = self.loss(x, labels)
            outputs = (loss,) + outputs
            
        return (outputs if len(outputs) > 1 else outputs[0])

# Instantiate the model with appropriate dimensions
model = RNNModel()

In [8]:
from transformers import EarlyStoppingCallback

model = RNNModel()

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    warmup_steps=500,
    weight_decay=WEIGHT_DECAY,
    metric_for_best_model=METRIC_FOR_BEST_MODEL,
    greater_is_better=GREATER_IS_BETTER,
    logging_dir='./logs',
    logging_steps=15000,
    do_train=True,
    do_eval=True,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    push_to_hub=False,
    logging_first_step=False,
    load_best_model_at_end=True,
    save_total_limit=2,
    report_to="wandb"
)

# Create trainer
trainer = Trainer(
    model=model, 
    args=training_args, 
    train_dataset=traindata,
    eval_dataset=testdata, 
    compute_metrics=compute_metrics,# training dataset
    callbacks=[EarlyStoppingCallback(early_stopping_patience=PATIENCE)]
)

# Train the model
print(trainer.evaluate())

trainer.train()

print(trainer.evaluate())

 95%|█████████▍| 1417/1497 [00:01<00:00, 861.32it/s]

100%|██████████| 1497/1497 [00:02<00:00, 627.44it/s]


{'eval_loss': 0.8866190910339355, 'eval_accuracy': 0.47828991315965264, 'eval_f1': 0.6470854044283778, 'eval_auc': 0.5, 'eval_precision': 0.47828991315965264, 'eval_recall': 1.0, 'eval_runtime': 1.9408, 'eval_samples_per_second': 12341.101, 'eval_steps_per_second': 771.319}


                                                      
  1%|          | 6006/598800 [00:23<4:01:07, 40.97it/s]

{'eval_loss': 0.5327990651130676, 'eval_accuracy': 0.7395207080828323, 'eval_f1': 0.6770202412382876, 'eval_auc': 0.7324993718660363, 'eval_precision': 0.8318280117033456, 'eval_recall': 0.5707925977653632, 'eval_runtime': 1.9799, 'eval_samples_per_second': 12097.297, 'eval_steps_per_second': 756.081, 'epoch': 1.0}


                                                        
  2%|▏         | 12003/598800 [00:46<4:03:39, 40.14it/s]

{'eval_loss': 0.4663628339767456, 'eval_accuracy': 0.7807698730794923, 'eval_f1': 0.7455294402713835, 'eval_auc': 0.7762202340503151, 'eval_precision': 0.8379997821113411, 'eval_recall': 0.6714385474860335, 'eval_runtime': 1.9872, 'eval_samples_per_second': 12053.14, 'eval_steps_per_second': 753.321, 'epoch': 2.0}


  3%|▎         | 15043/598800 [00:58<38:16, 254.22it/s]  

{'loss': 0.5356, 'learning_rate': 4.878823332776199e-05, 'epoch': 2.51}


                                                       
  3%|▎         | 17971/598800 [01:12<4:06:21, 39.29it/s]

{'eval_loss': 0.4239031970500946, 'eval_accuracy': 0.8040247160988644, 'eval_f1': 0.7820798514391829, 'eval_auc': 0.8011626849083325, 'eval_precision': 0.8352836176120587, 'eval_recall': 0.7352479050279329, 'eval_runtime': 1.8966, 'eval_samples_per_second': 12629.115, 'eval_steps_per_second': 789.32, 'epoch': 3.0}


                                                         
  4%|▍         | 23960/598800 [01:36<4:03:20, 39.37it/s]

{'eval_loss': 0.4091581702232361, 'eval_accuracy': 0.813502004008016, 'eval_f1': 0.7919906868451688, 'eval_auc': 0.8105398196696686, 'eval_precision': 0.8487873041221679, 'eval_recall': 0.74231843575419, 'eval_runtime': 1.9429, 'eval_samples_per_second': 12327.888, 'eval_steps_per_second': 770.493, 'epoch': 4.0}


                                                         
  5%|▌         | 29952/598800 [02:00<3:56:50, 40.03it/s]

{'eval_loss': 0.3986860513687134, 'eval_accuracy': 0.8206412825651302, 'eval_f1': 0.7998882057015093, 'eval_auc': 0.8176798698488544, 'eval_precision': 0.8575709149021175, 'eval_recall': 0.7494762569832403, 'eval_runtime': 1.9296, 'eval_samples_per_second': 12413.153, 'eval_steps_per_second': 775.822, 'epoch': 5.0}


  5%|▌         | 30022/598800 [02:00<1:58:46, 79.81it/s]

{'loss': 0.4331, 'learning_rate': 4.753468159786061e-05, 'epoch': 5.01}


                                                         
  6%|▌         | 35944/598800 [02:23<3:44:31, 41.78it/s]

{'eval_loss': 0.42204558849334717, 'eval_accuracy': 0.8106212424849699, 'eval_f1': 0.8100502512562815, 'eval_auc': 0.8120216346325796, 'eval_precision': 0.7784932388924662, 'eval_recall': 0.8442737430167597, 'eval_runtime': 1.9339, 'eval_samples_per_second': 12385.495, 'eval_steps_per_second': 774.093, 'epoch': 6.0}


                                                         
  7%|▋         | 41927/598800 [02:46<3:57:15, 39.12it/s]

{'eval_loss': 0.3825473189353943, 'eval_accuracy': 0.8296175684702739, 'eval_f1': 0.8189521316711771, 'eval_auc': 0.8286219187905494, 'eval_precision': 0.8326567433468651, 'eval_recall': 0.8056913407821229, 'eval_runtime': 2.1015, 'eval_samples_per_second': 11397.79, 'eval_steps_per_second': 712.362, 'epoch': 7.0}


  8%|▊         | 45040/598800 [02:57<31:39, 291.59it/s]  

{'loss': 0.3974, 'learning_rate': 4.628112986795922e-05, 'epoch': 7.52}


                                                       
  8%|▊         | 47908/598800 [03:09<3:48:08, 40.25it/s]

{'eval_loss': 0.3788236081600189, 'eval_accuracy': 0.8300768203072812, 'eval_f1': 0.8244176013805006, 'eval_auc': 0.8302426050079043, 'eval_precision': 0.8149948822927329, 'eval_recall': 0.8340607541899442, 'eval_runtime': 2.0588, 'eval_samples_per_second': 11633.785, 'eval_steps_per_second': 727.112, 'epoch': 8.0}


                                                         
  9%|▉         | 53907/598800 [03:32<3:37:11, 41.81it/s]

{'eval_loss': 0.3654007911682129, 'eval_accuracy': 0.8396376085504342, 'eval_f1': 0.8285191303183177, 'eval_auc': 0.8384029816915715, 'eval_precision': 0.8479393219409669, 'eval_recall': 0.8099685754189944, 'eval_runtime': 2.0002, 'eval_samples_per_second': 11974.533, 'eval_steps_per_second': 748.408, 'epoch': 9.0}


                                                         
 10%|█         | 59888/598800 [03:55<3:42:33, 40.36it/s]

{'eval_loss': 0.3608098328113556, 'eval_accuracy': 0.8414328657314629, 'eval_f1': 0.8288726682887267, 'eval_auc': 0.8398293040365096, 'eval_precision': 0.8565840938722294, 'eval_recall': 0.8028980446927374, 'eval_runtime': 1.9924, 'eval_samples_per_second': 12021.573, 'eval_steps_per_second': 751.348, 'epoch': 10.0}


 10%|█         | 60044/598800 [03:56<58:04, 154.60it/s]  

{'loss': 0.3767, 'learning_rate': 4.5027578138057834e-05, 'epoch': 10.02}


                                                       
 11%|█         | 65900/598800 [04:18<2:29:44, 59.31it/s]

{'eval_loss': 0.35505449771881104, 'eval_accuracy': 0.8427271209084837, 'eval_f1': 0.8357818562273858, 'eval_auc': 0.8424790905156689, 'eval_precision': 0.8347992684838457, 'eval_recall': 0.8367667597765364, 'eval_runtime': 1.9051, 'eval_samples_per_second': 12572.566, 'eval_steps_per_second': 785.785, 'epoch': 11.0}


                                                         
 12%|█▏        | 71863/598800 [04:40<3:42:28, 39.47it/s]

{'eval_loss': 0.3538796007633209, 'eval_accuracy': 0.8436038744154977, 'eval_f1': 0.8331254454739843, 'eval_auc': 0.8424657342863684, 'eval_precision': 0.8507096069868996, 'eval_recall': 0.8162534916201117, 'eval_runtime': 2.1439, 'eval_samples_per_second': 11172.418, 'eval_steps_per_second': 698.276, 'epoch': 12.0}


 13%|█▎        | 75057/598800 [04:51<29:26, 296.48it/s]  

{'loss': 0.3603, 'learning_rate': 4.377402640815644e-05, 'epoch': 12.53}


                                                       
 13%|█▎        | 77861/598800 [05:02<3:28:26, 41.65it/s]

{'eval_loss': 0.36157843470573425, 'eval_accuracy': 0.8468603874415498, 'eval_f1': 0.8329081632653061, 'eval_auc': 0.8448275515203972, 'eval_precision': 0.8709984756097561, 'eval_recall': 0.7980097765363129, 'eval_runtime': 1.9282, 'eval_samples_per_second': 12422.078, 'eval_steps_per_second': 776.38, 'epoch': 13.0}


                                                         
 14%|█▍        | 83848/598800 [05:25<3:20:33, 42.79it/s]

{'eval_loss': 0.36190080642700195, 'eval_accuracy': 0.8423513694054776, 'eval_f1': 0.8398642917726887, 'eval_auc': 0.8432668286253836, 'eval_precision': 0.8167271527548664, 'eval_recall': 0.8643505586592178, 'eval_runtime': 2.0162, 'eval_samples_per_second': 11880.047, 'eval_steps_per_second': 742.503, 'epoch': 14.0}


                                                         
 15%|█▌        | 89839/598800 [05:47<3:27:47, 40.82it/s]

{'eval_loss': 0.34772762656211853, 'eval_accuracy': 0.8506596526386105, 'eval_f1': 0.8362554360265507, 'eval_auc': 0.8484396571148577, 'eval_precision': 0.8791991529502359, 'eval_recall': 0.7973114525139665, 'eval_runtime': 1.9402, 'eval_samples_per_second': 12344.98, 'eval_steps_per_second': 771.561, 'epoch': 15.0}


 15%|█▌        | 90031/598800 [05:48<47:19, 179.18it/s]  

{'loss': 0.3471, 'learning_rate': 4.2520474678255056e-05, 'epoch': 15.03}


                                                       
 16%|█▌        | 95840/598800 [06:10<2:25:16, 57.70it/s]

{'eval_loss': 0.32999366521835327, 'eval_accuracy': 0.8556696726786908, 'eval_f1': 0.8431986211275911, 'eval_auc': 0.8538260176217283, 'eval_precision': 0.8776319516570673, 'eval_recall': 0.8113652234636871, 'eval_runtime': 1.9258, 'eval_samples_per_second': 12437.52, 'eval_steps_per_second': 777.345, 'epoch': 16.0}


                                                         
 17%|█▋        | 101812/598800 [06:33<3:17:16, 41.99it/s]

{'eval_loss': 0.34285032749176025, 'eval_accuracy': 0.8484051436205745, 'eval_f1': 0.8298899039587726, 'eval_auc': 0.8452727777916866, 'eval_precision': 0.8956416220042471, 'eval_recall': 0.7731319832402235, 'eval_runtime': 1.933, 'eval_samples_per_second': 12391.172, 'eval_steps_per_second': 774.448, 'epoch': 17.0}


 18%|█▊        | 105022/598800 [06:44<31:49, 258.65it/s]  

{'loss': 0.3347, 'learning_rate': 4.126692294835367e-05, 'epoch': 17.54}


                                                        
 18%|█▊        | 107799/598800 [06:58<4:07:55, 33.01it/s]

{'eval_loss': 0.33708688616752625, 'eval_accuracy': 0.853498663994656, 'eval_f1': 0.8404202101050525, 'eval_auc': 0.8515455672072045, 'eval_precision': 0.8772429507262888, 'eval_recall': 0.8065642458100558, 'eval_runtime': 2.2543, 'eval_samples_per_second': 10624.84, 'eval_steps_per_second': 664.052, 'epoch': 18.0}


                                                          
 19%|█▉        | 113772/598800 [07:22<30:33, 264.59it/s]

{'eval_loss': 0.3374324440956116, 'eval_accuracy': 0.8556279225116901, 'eval_f1': 0.8406451612903226, 'eval_auc': 0.8531539589875464, 'eval_precision': 0.8903748535728231, 'eval_recall': 0.7961766759776536, 'eval_runtime': 1.9626, 'eval_samples_per_second': 12204.219, 'eval_steps_per_second': 762.764, 'epoch': 19.0}


                                                          
 20%|██        | 119780/598800 [07:46<3:22:38, 39.40it/s]

{'eval_loss': 0.33525577187538147, 'eval_accuracy': 0.8590514362057449, 'eval_f1': 0.8471568272365086, 'eval_auc': 0.8572886340388701, 'eval_precision': 0.8799849510910459, 'eval_recall': 0.8166899441340782, 'eval_runtime': 1.9821, 'eval_samples_per_second': 12083.883, 'eval_steps_per_second': 755.243, 'epoch': 20.0}


 20%|██        | 120052/598800 [07:47<33:36, 237.41it/s]  

{'loss': 0.325, 'learning_rate': 4.0013371218452283e-05, 'epoch': 20.04}


                                                        
 21%|██        | 125771/598800 [08:10<3:34:37, 36.73it/s]

{'eval_loss': 0.33213648200035095, 'eval_accuracy': 0.8582999331997327, 'eval_f1': 0.8451642335766424, 'eval_auc': 0.8562305859662803, 'eval_precision': 0.8852255351681957, 'eval_recall': 0.8085719273743017, 'eval_runtime': 1.9815, 'eval_samples_per_second': 12087.589, 'eval_steps_per_second': 755.474, 'epoch': 21.0}


                                                          
 22%|██▏       | 131749/598800 [08:34<3:31:58, 36.72it/s]

{'eval_loss': 0.33138397336006165, 'eval_accuracy': 0.8569639278557114, 'eval_f1': 0.8520725388601036, 'eval_auc': 0.8571441744039656, 'eval_precision': 0.8430451127819549, 'eval_recall': 0.8612953910614525, 'eval_runtime': 1.9736, 'eval_samples_per_second': 12136.336, 'eval_steps_per_second': 758.521, 'epoch': 22.0}


 23%|██▎       | 135032/598800 [08:46<30:01, 257.43it/s]  

{'loss': 0.315, 'learning_rate': 3.87598194885509e-05, 'epoch': 22.55}


                                                        
 23%|██▎       | 137730/598800 [08:58<3:09:36, 40.53it/s]

{'eval_loss': 0.3253442049026489, 'eval_accuracy': 0.8594689378757515, 'eval_f1': 0.8480635551142005, 'eval_auc': 0.857826795077218, 'eval_precision': 0.8781080575808562, 'eval_recall': 0.8200069832402235, 'eval_runtime': 2.0152, 'eval_samples_per_second': 11885.757, 'eval_steps_per_second': 742.86, 'epoch': 23.0}


                                                          
 24%|██▍       | 143728/598800 [09:21<2:52:51, 43.88it/s]

{'eval_loss': 0.3358996510505676, 'eval_accuracy': 0.8586756847027388, 'eval_f1': 0.84385811153651, 'eval_auc': 0.856169337316433, 'eval_precision': 0.8947471388046562, 'eval_recall': 0.7984462290502793, 'eval_runtime': 1.9023, 'eval_samples_per_second': 12591.384, 'eval_steps_per_second': 786.962, 'epoch': 24.0}


                                                          
 25%|██▌       | 149706/598800 [09:43<2:57:50, 42.09it/s]

{'eval_loss': 0.32970255613327026, 'eval_accuracy': 0.8581746826987308, 'eval_f1': 0.8536974029889315, 'eval_auc': 0.8584643734039586, 'eval_precision': 0.8425571707897646, 'eval_recall': 0.8651361731843575, 'eval_runtime': 2.0078, 'eval_samples_per_second': 11929.343, 'eval_steps_per_second': 745.584, 'epoch': 25.0}


 25%|██▌       | 150048/598800 [09:44<28:32, 262.05it/s]  

{'loss': 0.3053, 'learning_rate': 3.750626775864951e-05, 'epoch': 25.05}


                                                        
 26%|██▌       | 155716/598800 [10:06<2:49:49, 43.48it/s]

{'eval_loss': 0.32006365060806274, 'eval_accuracy': 0.8633517034068137, 'eval_f1': 0.8531694405814005, 'eval_auc': 0.8619657172976917, 'eval_precision': 0.877618827872635, 'eval_recall': 0.8300453910614525, 'eval_runtime': 1.9133, 'eval_samples_per_second': 12518.777, 'eval_steps_per_second': 782.424, 'epoch': 26.0}


                                                          
 27%|██▋       | 161701/598800 [10:28<2:43:23, 44.59it/s]

{'eval_loss': 0.3298362195491791, 'eval_accuracy': 0.8588426853707415, 'eval_f1': 0.8548989313763359, 'eval_auc': 0.8592825681871831, 'eval_precision': 0.8408611228366399, 'eval_recall': 0.869413407821229, 'eval_runtime': 1.8609, 'eval_samples_per_second': 12870.916, 'eval_steps_per_second': 804.432, 'epoch': 27.0}


 28%|██▊       | 165039/598800 [10:39<25:55, 278.92it/s] 

{'loss': 0.2958, 'learning_rate': 3.6252716028748125e-05, 'epoch': 27.56}


                                                        
 28%|██▊       | 167695/598800 [10:50<2:02:36, 58.60it/s]

{'eval_loss': 0.3158232867717743, 'eval_accuracy': 0.8643954575818303, 'eval_f1': 0.8555674137317681, 'eval_auc': 0.8633692390503509, 'eval_precision': 0.8720087019579406, 'eval_recall': 0.8397346368715084, 'eval_runtime': 1.9022, 'eval_samples_per_second': 12591.837, 'eval_steps_per_second': 786.99, 'epoch': 28.0}


                                                          
 29%|██▉       | 173663/598800 [11:13<2:36:50, 45.18it/s]

{'eval_loss': 0.3317015767097473, 'eval_accuracy': 0.8612224448897795, 'eval_f1': 0.8450494126421778, 'eval_auc': 0.858308625240524, 'eval_precision': 0.9067627050820328, 'eval_recall': 0.7912011173184358, 'eval_runtime': 1.8512, 'eval_samples_per_second': 12938.792, 'eval_steps_per_second': 808.674, 'epoch': 29.0}


                                                         
 30%|███       | 179664/598800 [11:34<2:00:24, 58.02it/s]

{'eval_loss': 0.3462073504924774, 'eval_accuracy': 0.8573396793587175, 'eval_f1': 0.8393134258170704, 'eval_auc': 0.8540788918375668, 'eval_precision': 0.9097767356509329, 'eval_recall': 0.7789804469273743, 'eval_runtime': 1.816, 'eval_samples_per_second': 13189.306, 'eval_steps_per_second': 824.332, 'epoch': 30.0}


 30%|███       | 180053/598800 [11:36<23:42, 294.44it/s] 

{'loss': 0.2893, 'learning_rate': 3.499916429884673e-05, 'epoch': 30.06}


                                                        
 31%|███       | 185646/598800 [11:56<2:32:20, 45.20it/s]

{'eval_loss': 0.3399714231491089, 'eval_accuracy': 0.8586756847027388, 'eval_f1': 0.8415633044699274, 'eval_auc': 0.855599043090437, 'eval_precision': 0.9072560298718337, 'eval_recall': 0.7847416201117319, 'eval_runtime': 1.8785, 'eval_samples_per_second': 12750.298, 'eval_steps_per_second': 796.894, 'epoch': 31.0}


                                                         
 32%|███▏      | 191616/598800 [12:18<21:59, 308.55it/s]

{'eval_loss': 0.33328384160995483, 'eval_accuracy': 0.8638944555778223, 'eval_f1': 0.853601580743668, 'eval_auc': 0.86246772151446, 'eval_precision': 0.8790233074361821, 'eval_recall': 0.8296089385474861, 'eval_runtime': 1.8521, 'eval_samples_per_second': 12932.413, 'eval_steps_per_second': 808.276, 'epoch': 32.0}


 33%|███▎      | 195053/598800 [12:29<21:52, 307.56it/s] 

{'loss': 0.2805, 'learning_rate': 3.3745612568945346e-05, 'epoch': 32.57}


                                                        
 33%|███▎      | 197630/598800 [12:40<2:40:28, 41.66it/s]

{'eval_loss': 0.3212178647518158, 'eval_accuracy': 0.865313961255845, 'eval_f1': 0.8561106155218555, 'eval_auc': 0.864165974452607, 'eval_precision': 0.8753192265596498, 'eval_recall': 0.8377269553072626, 'eval_runtime': 2.0592, 'eval_samples_per_second': 11631.977, 'eval_steps_per_second': 726.999, 'epoch': 33.0}


                                                         
 34%|███▍      | 203604/598800 [13:04<2:46:15, 39.62it/s]

{'eval_loss': 0.32105857133865356, 'eval_accuracy': 0.8658984635938544, 'eval_f1': 0.8590115003072601, 'eval_auc': 0.8654090538022448, 'eval_precision': 0.8639413738301254, 'eval_recall': 0.8541375698324022, 'eval_runtime': 2.0512, 'eval_samples_per_second': 11676.914, 'eval_steps_per_second': 729.807, 'epoch': 34.0}


                                                         
 35%|███▌      | 209603/598800 [13:28<2:31:49, 42.73it/s]

{'eval_loss': 0.3302525579929352, 'eval_accuracy': 0.864687708750835, 'eval_f1': 0.8546310832025119, 'eval_auc': 0.8633115110801938, 'eval_precision': 0.8789556232124734, 'eval_recall': 0.8316166201117319, 'eval_runtime': 1.8358, 'eval_samples_per_second': 13047.029, 'eval_steps_per_second': 815.439, 'epoch': 35.0}


 35%|███▌      | 210024/598800 [13:29<24:12, 267.74it/s] 

{'loss': 0.2714, 'learning_rate': 3.249206083904396e-05, 'epoch': 35.07}


                                                        
 36%|███▌      | 215580/598800 [13:51<2:22:49, 44.72it/s]

{'eval_loss': 0.32847294211387634, 'eval_accuracy': 0.8686122244488977, 'eval_f1': 0.8576597765615812, 'eval_auc': 0.8669056220895716, 'eval_precision': 0.8899840420538815, 'eval_recall': 0.8276012569832403, 'eval_runtime': 1.8511, 'eval_samples_per_second': 12939.474, 'eval_steps_per_second': 808.717, 'epoch': 36.0}


                                                         
 37%|███▋      | 221578/598800 [14:14<2:32:30, 41.22it/s]

{'eval_loss': 0.3293704092502594, 'eval_accuracy': 0.8685704742818972, 'eval_f1': 0.8592254717824882, 'eval_auc': 0.8673232976451906, 'eval_precision': 0.8808912525215478, 'eval_recall': 0.8385998603351955, 'eval_runtime': 1.8389, 'eval_samples_per_second': 13025.179, 'eval_steps_per_second': 814.074, 'epoch': 37.0}


 38%|███▊      | 225054/598800 [14:27<22:30, 276.67it/s] 

{'loss': 0.2646, 'learning_rate': 3.1238509109142574e-05, 'epoch': 37.58}


                                                        
 38%|███▊      | 227544/598800 [14:38<23:52, 259.16it/s]


{'eval_loss': 0.31859129667282104, 'eval_accuracy': 0.8669839679358717, 'eval_f1': 0.8577424540096446, 'eval_auc': 0.8657955461948942, 'eval_precision': 0.8779707495429616, 'eval_recall': 0.838425279329609, 'eval_runtime': 1.9102, 'eval_samples_per_second': 12539.03, 'eval_steps_per_second': 783.689, 'epoch': 38.0}
{'train_runtime': 878.0052, 'train_samples_per_second': 10911.666, 'train_steps_per_second': 682.001, 'train_loss': 0.3412613080271529, 'epoch': 38.0}


100%|██████████| 1497/1497 [00:01<00:00, 792.03it/s]

{'eval_loss': 0.3158232867717743, 'eval_accuracy': 0.8643954575818303, 'eval_f1': 0.8555674137317681, 'eval_auc': 0.8633692390503509, 'eval_precision': 0.8720087019579406, 'eval_recall': 0.8397346368715084, 'eval_runtime': 1.8921, 'eval_samples_per_second': 12659.028, 'eval_steps_per_second': 791.189, 'epoch': 38.0}





In [9]:
trainer.evaluate(devdata)

100%|██████████| 313/313 [00:00<00:00, 705.31it/s]


{'eval_loss': 1.2049089670181274,
 'eval_accuracy': 0.551,
 'eval_f1': 0.3631205673758865,
 'eval_auc': 0.5509999999999999,
 'eval_precision': 0.624390243902439,
 'eval_recall': 0.256,
 'eval_runtime': 0.4448,
 'eval_samples_per_second': 11241.57,
 'eval_steps_per_second': 703.722,
 'epoch': 38.0}

In [10]:
trainer.save_model('SavedModels/'+'lstm_'+str(round(SAMPLES_TO_TRAIN/1000))+'k')

In [11]:
# All samples trained with training dataset
print(''''eval_loss': 0.3158232867717743, 'eval_accuracy': 0.8643954575818303, 'eval_f1': 0.8555674137317681, 'eval_auc': 0.8633692390503509, 'eval_precision': 0.8720087019579406, 'eval_recall': 0.8397346368715084''')

'eval_loss': 0.3158232867717743, 'eval_accuracy': 0.8643954575818303, 'eval_f1': 0.8555674137317681, 'eval_auc': 0.8633692390503509, 'eval_precision': 0.8720087019579406, 'eval_recall': 0.8397346368715084
