In [1]:
'''Import libraries'''
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import wandb
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
import os
from torch.utils.data import Dataset
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from imblearn.under_sampling import RandomUnderSampler
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertForSequenceClassification, AdamW, Trainer, TrainingArguments
from tqdm import tqdm
from torch.nn import functional as F
import torch.nn as nn
import pickle

wandb.login()

  from .autonotebook import tqdm as notebook_tqdm
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33malberto-rodero557[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
'''Variables and parameters'''

SAMPLES_TO_TRAIN=1000
DIMENSIONS=200

N_LABELS=2
MAX_LEN = 256
EPOCHS=100
PATIENCE=10
LEARNING_RATE=.00005
WEIGHT_DECAY=.01
BATCH_SIZE=16
METRIC_FOR_BEST_MODEL='eval_loss'
if METRIC_FOR_BEST_MODEL=='eval_loss':
    GREATER_IS_BETTER = False
else:
    GREATER_IS_BETTER = True

In [3]:
'''Preparing dataset'''

file_path = 'datasets/subtaskA_glove_train_dev_monolingual.pkl'

# Load the data from the pickle file
with open(file_path, 'rb') as file:
    data = pickle.load(file)

# Extract the individual datasets from the loaded data
train_x = data['train_x']
train_y = data['train_y']
dev_x = data['dev_x']
dev_y = data['dev_y']

if SAMPLES_TO_TRAIN!=-1:
    random_indices = np.random.choice(train_x.shape[0], size=SAMPLES_TO_TRAIN, replace=False)
    train_x = train_x[random_indices]
    train_y = train_y[random_indices]

    random_indices_dev = np.random.choice(dev_x.shape[0], size=int(train_x.shape[0] * 0.2), replace=False)
    dev_x = dev_x[random_indices_dev]
    dev_y = pd.Series(dev_y.to_numpy()[random_indices_dev])

print(train_x.shape)
print(train_y.shape)
print(dev_x.shape)
print(dev_y.shape)

print(type(train_y))
print(type(dev_y))

(1000, 200)
(1000,)
(200, 200)
(200,)
<class 'pandas.core.series.Series'>
<class 'pandas.core.series.Series'>


In [4]:
'''metrics'''

def compute_metrics(p):
    preds = np.argmax(p.predictions, axis=1)
    labels = p.label_ids
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    auc = roc_auc_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'auc': auc,
        'precision': precision,
        'recall': recall,
    }

In [5]:
class Data(Dataset):
    def __init__(self, X_train, y_train):
        self.X = torch.from_numpy(X_train.astype(np.float32))
        self.y = torch.from_numpy(y_train).type(torch.LongTensor)
        self.len = self.X.shape[0]

    def __getitem__(self, index):
        return {'input_ids': self.X[index], 'labels': self.y[index]}

    def __len__(self):
        return self.len

X_train, X_test, y_train, y_test = train_test_split(train_x, train_y.values, test_size=0.2, random_state=42)
traindata = Data(X_train, y_train)
testdata = Data(X_test, y_test)

devdata=Data(dev_x, dev_y.values)

print(traindata)

<__main__.Data object at 0x000001C58590B450>


In [6]:
# number of features (len of X cols)
input_dim = train_x.shape[-1]

# number of classes (unique of y)
output_dim = 2

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.linear1 = nn.Linear(input_dim, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.dropout1 = nn.Dropout(0.5)
        
        self.linear2 = nn.Linear(512, 512)
        self.bn2 = nn.BatchNorm1d(512)
        self.dropout2 = nn.Dropout(0.5)
        
        self.linear3 = nn.Linear(512, 256)
        self.bn3 = nn.BatchNorm1d(256)
        self.dropout3 = nn.Dropout(0.5)
        
        self.linear4 = nn.Linear(256, output_dim)
        
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input_ids, labels=None):
        x1 = F.leaky_relu(self.linear1(input_ids))
        x1 = self.bn1(x1)
        x1 = self.dropout1(x1)
        
        x2 = F.leaky_relu(self.linear2(x1))
        x2 = self.bn2(x2)
        x2 = self.dropout2(x2)
        
        # Adding the first skip connection
        x2 += x1
        
        x3 = F.leaky_relu(self.linear3(x2))
        x3 = self.bn3(x3)
        x3 = self.dropout3(x3)
        
        x4 = self.linear4(x3)
        
        outputs = (x4,)
        if labels is not None:
            loss = self.loss(x4, labels)
            outputs = (loss,) + outputs
            
        return (outputs if len(outputs) > 1 else outputs[0])

# Create the model
model = Network()

In [7]:
from transformers import EarlyStoppingCallback

model = Network()

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    warmup_steps=500,
    weight_decay=WEIGHT_DECAY,
    metric_for_best_model=METRIC_FOR_BEST_MODEL,
    greater_is_better=GREATER_IS_BETTER,
    logging_dir='./logs',
    logging_steps=15000,
    do_train=True,
    do_eval=True,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    push_to_hub=False,
    logging_first_step=False,
    load_best_model_at_end=True,
    save_total_limit=2,
    report_to="wandb"
)

# Create trainer
trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=traindata,             # training dataset
    eval_dataset=testdata, 
    compute_metrics=compute_metrics,# training dataset
    callbacks=[EarlyStoppingCallback(early_stopping_patience=PATIENCE)]
)

# Train the model
print(trainer.evaluate())

trainer.train()

print(trainer.evaluate())

  0%|          | 0/13 [00:00<?, ?it/s]

100%|██████████| 13/13 [00:00<00:00, 19.44it/s]


{'eval_loss': 0.698131799697876, 'eval_accuracy': 0.435, 'eval_f1': 0.6062717770034843, 'eval_auc': 0.5, 'eval_precision': 0.435, 'eval_recall': 1.0, 'eval_runtime': 1.2849, 'eval_samples_per_second': 155.65, 'eval_steps_per_second': 10.117}


                                                  
  2%|▏         | 90/5000 [00:00<00:18, 271.23it/s]

{'eval_loss': 0.7775034308433533, 'eval_accuracy': 0.435, 'eval_f1': 0.6062717770034843, 'eval_auc': 0.5, 'eval_precision': 0.435, 'eval_recall': 1.0, 'eval_runtime': 0.0135, 'eval_samples_per_second': 14808.565, 'eval_steps_per_second': 962.557, 'epoch': 1.0}


                                                  
  2%|▏         | 121/5000 [00:00<00:21, 223.00it/s]

{'eval_loss': 0.7403948307037354, 'eval_accuracy': 0.495, 'eval_f1': 0.5589519650655022, 'eval_auc': 0.5226833485911911, 'eval_precision': 0.4507042253521127, 'eval_recall': 0.735632183908046, 'eval_runtime': 0.0135, 'eval_samples_per_second': 14810.134, 'eval_steps_per_second': 962.659, 'epoch': 2.0}


                                                   
  4%|▍         | 199/5000 [00:00<00:15, 309.68it/s]

{'eval_loss': 0.7177556753158569, 'eval_accuracy': 0.525, 'eval_f1': 0.5454545454545454, 'eval_auc': 0.5399755874275252, 'eval_precision': 0.4672131147540984, 'eval_recall': 0.6551724137931034, 'eval_runtime': 0.016, 'eval_samples_per_second': 12489.925, 'eval_steps_per_second': 811.845, 'epoch': 3.0}


                                                   
  4%|▍         | 200/5000 [00:00<00:15, 309.68it/s]

{'eval_loss': 0.7071982026100159, 'eval_accuracy': 0.515, 'eval_f1': 0.526829268292683, 'eval_auc': 0.5271589868782423, 'eval_precision': 0.4576271186440678, 'eval_recall': 0.6206896551724138, 'eval_runtime': 0.014, 'eval_samples_per_second': 14239.218, 'eval_steps_per_second': 925.549, 'epoch': 4.0}


                                                   
  5%|▌         | 250/5000 [00:00<00:15, 315.58it/s]

{'eval_loss': 0.6949342489242554, 'eval_accuracy': 0.56, 'eval_f1': 0.5925925925925927, 'eval_auc': 0.5802054724849964, 'eval_precision': 0.49612403100775193, 'eval_recall': 0.735632183908046, 'eval_runtime': 0.0151, 'eval_samples_per_second': 13260.106, 'eval_steps_per_second': 861.907, 'epoch': 5.0}


                                                   
  6%|▌         | 302/5000 [00:01<00:14, 316.76it/s]

{'eval_loss': 0.6778982281684875, 'eval_accuracy': 0.57, 'eval_f1': 0.5943396226415095, 'eval_auc': 0.5877326823314006, 'eval_precision': 0.504, 'eval_recall': 0.7241379310344828, 'eval_runtime': 0.0135, 'eval_samples_per_second': 14806.736, 'eval_steps_per_second': 962.438, 'epoch': 6.0}


                                                   
  7%|▋         | 350/5000 [00:01<00:13, 356.68it/s]

{'eval_loss': 0.6682389974594116, 'eval_accuracy': 0.6, 'eval_f1': 0.6261682242990655, 'eval_auc': 0.619570745600651, 'eval_precision': 0.5275590551181102, 'eval_recall': 0.7701149425287356, 'eval_runtime': 0.0142, 'eval_samples_per_second': 14108.697, 'eval_steps_per_second': 917.065, 'epoch': 7.0}


                                                   
  8%|▊         | 400/5000 [00:01<00:13, 350.92it/s]

{'eval_loss': 0.6487293839454651, 'eval_accuracy': 0.6, 'eval_f1': 0.6078431372549019, 'eval_auc': 0.6129590072220528, 'eval_precision': 0.5299145299145299, 'eval_recall': 0.7126436781609196, 'eval_runtime': 0.0122, 'eval_samples_per_second': 16339.64, 'eval_steps_per_second': 1062.077, 'epoch': 8.0}


                                                   
  9%|▉         | 457/5000 [00:01<00:13, 337.41it/s]

{'eval_loss': 0.6412814259529114, 'eval_accuracy': 0.61, 'eval_f1': 0.625, 'eval_auc': 0.6257756077713357, 'eval_precision': 0.5371900826446281, 'eval_recall': 0.7471264367816092, 'eval_runtime': 0.0142, 'eval_samples_per_second': 14113.444, 'eval_steps_per_second': 917.374, 'epoch': 9.0}


                                                   
 10%|█         | 501/5000 [00:01<00:12, 347.20it/s]

{'eval_loss': 0.6373379230499268, 'eval_accuracy': 0.63, 'eval_f1': 0.6407766990291263, 'eval_auc': 0.6447970704913029, 'eval_precision': 0.5546218487394958, 'eval_recall': 0.7586206896551724, 'eval_runtime': 0.0125, 'eval_samples_per_second': 15993.533, 'eval_steps_per_second': 1039.58, 'epoch': 10.0}


                                                   
 11%|█         | 550/5000 [00:01<00:11, 378.28it/s]

{'eval_loss': 0.6243371963500977, 'eval_accuracy': 0.66, 'eval_f1': 0.6344086021505375, 'eval_auc': 0.662089309327637, 'eval_precision': 0.5959595959595959, 'eval_recall': 0.6781609195402298, 'eval_runtime': 0.0156, 'eval_samples_per_second': 12838.789, 'eval_steps_per_second': 834.521, 'epoch': 11.0}


                                                   
 12%|█▏        | 624/5000 [00:01<00:12, 362.71it/s]

{'eval_loss': 0.6106968522071838, 'eval_accuracy': 0.66, 'eval_f1': 0.6263736263736264, 'eval_auc': 0.6594446139761976, 'eval_precision': 0.6, 'eval_recall': 0.6551724137931034, 'eval_runtime': 0.0135, 'eval_samples_per_second': 14799.421, 'eval_steps_per_second': 961.962, 'epoch': 12.0}


                                                   
 13%|█▎        | 661/5000 [00:01<00:12, 354.59it/s]

{'eval_loss': 0.6032118797302246, 'eval_accuracy': 0.665, 'eval_f1': 0.633879781420765, 'eval_auc': 0.6651917404129792, 'eval_precision': 0.6041666666666666, 'eval_recall': 0.6666666666666666, 'eval_runtime': 0.0145, 'eval_samples_per_second': 13787.075, 'eval_steps_per_second': 896.16, 'epoch': 13.0}


                                                   
 14%|█▍        | 701/5000 [00:02<00:12, 347.21it/s]

{'eval_loss': 0.5924349427223206, 'eval_accuracy': 0.65, 'eval_f1': 0.5930232558139535, 'eval_auc': 0.6426609703997558, 'eval_precision': 0.6, 'eval_recall': 0.5862068965517241, 'eval_runtime': 0.0136, 'eval_samples_per_second': 14735.987, 'eval_steps_per_second': 957.839, 'epoch': 14.0}


                                                   
 15%|█▌        | 750/5000 [00:02<00:12, 354.16it/s]

{'eval_loss': 0.5925384759902954, 'eval_accuracy': 0.65, 'eval_f1': 0.5930232558139535, 'eval_auc': 0.6426609703997558, 'eval_precision': 0.6, 'eval_recall': 0.5862068965517241, 'eval_runtime': 0.016, 'eval_samples_per_second': 12485.649, 'eval_steps_per_second': 811.567, 'epoch': 15.0}


                                                   
 16%|█▌        | 809/5000 [00:02<00:12, 330.96it/s]

{'eval_loss': 0.5881457328796387, 'eval_accuracy': 0.67, 'eval_f1': 0.6162790697674418, 'eval_auc': 0.663004780795443, 'eval_precision': 0.6235294117647059, 'eval_recall': 0.6091954022988506, 'eval_runtime': 0.0145, 'eval_samples_per_second': 13773.04, 'eval_steps_per_second': 895.248, 'epoch': 16.0}


                                                   
 17%|█▋        | 851/5000 [00:02<00:12, 339.81it/s]

{'eval_loss': 0.5944240689277649, 'eval_accuracy': 0.64, 'eval_f1': 0.5813953488372092, 'eval_auc': 0.6324890652019123, 'eval_precision': 0.5882352941176471, 'eval_recall': 0.5747126436781609, 'eval_runtime': 0.013, 'eval_samples_per_second': 15359.813, 'eval_steps_per_second': 998.388, 'epoch': 17.0}


                                                   
 18%|█▊        | 900/5000 [00:02<00:10, 379.59it/s]

{'eval_loss': 0.5799891352653503, 'eval_accuracy': 0.675, 'eval_f1': 0.6060606060606061, 'eval_auc': 0.663462516529346, 'eval_precision': 0.6410256410256411, 'eval_recall': 0.5747126436781609, 'eval_runtime': 0.013, 'eval_samples_per_second': 15368.536, 'eval_steps_per_second': 998.955, 'epoch': 18.0}


                                                   
 19%|█▉        | 950/5000 [00:02<00:11, 367.21it/s]

{'eval_loss': 0.5664207935333252, 'eval_accuracy': 0.685, 'eval_f1': 0.6134969325153374, 'eval_auc': 0.6723120740514698, 'eval_precision': 0.6578947368421053, 'eval_recall': 0.5747126436781609, 'eval_runtime': 0.0189, 'eval_samples_per_second': 10598.502, 'eval_steps_per_second': 688.903, 'epoch': 19.0}


                                                   
 20%|██        | 1000/5000 [00:02<00:11, 346.63it/s]

{'eval_loss': 0.5694951415061951, 'eval_accuracy': 0.67, 'eval_f1': 0.6024096385542168, 'eval_auc': 0.6590377377682841, 'eval_precision': 0.6329113924050633, 'eval_recall': 0.5747126436781609, 'eval_runtime': 0.0145, 'eval_samples_per_second': 13765.13, 'eval_steps_per_second': 894.733, 'epoch': 20.0}


                                                    
 21%|██        | 1051/5000 [00:03<00:11, 340.81it/s]

{'eval_loss': 0.5580580830574036, 'eval_accuracy': 0.685, 'eval_f1': 0.6272189349112427, 'eval_auc': 0.676279117078629, 'eval_precision': 0.6463414634146342, 'eval_recall': 0.6091954022988506, 'eval_runtime': 0.0136, 'eval_samples_per_second': 14666.168, 'eval_steps_per_second': 953.301, 'epoch': 21.0}


                                                    
 22%|██▏       | 1100/5000 [00:03<00:10, 378.40it/s]

{'eval_loss': 0.5564488768577576, 'eval_accuracy': 0.695, 'eval_f1': 0.6347305389221558, 'eval_auc': 0.6851286746007527, 'eval_precision': 0.6625, 'eval_recall': 0.6091954022988506, 'eval_runtime': 0.0165, 'eval_samples_per_second': 12102.503, 'eval_steps_per_second': 786.663, 'epoch': 22.0}


                                                    
 24%|██▎       | 1175/5000 [00:03<00:10, 353.42it/s]

{'eval_loss': 0.5523406267166138, 'eval_accuracy': 0.71, 'eval_f1': 0.6282051282051283, 'eval_auc': 0.6931136201810598, 'eval_precision': 0.7101449275362319, 'eval_recall': 0.5632183908045977, 'eval_runtime': 0.014, 'eval_samples_per_second': 14249.377, 'eval_steps_per_second': 926.209, 'epoch': 23.0}


                                                    
 24%|██▍       | 1211/5000 [00:03<00:11, 333.20it/s]

{'eval_loss': 0.5511720776557922, 'eval_accuracy': 0.705, 'eval_f1': 0.6289308176100629, 'eval_auc': 0.6900111890957177, 'eval_precision': 0.6944444444444444, 'eval_recall': 0.5747126436781609, 'eval_runtime': 0.0141, 'eval_samples_per_second': 14233.902, 'eval_steps_per_second': 925.204, 'epoch': 24.0}


                                                    
 25%|██▌       | 1251/5000 [00:03<00:11, 337.27it/s]

{'eval_loss': 0.5513394474983215, 'eval_accuracy': 0.68, 'eval_f1': 0.5949367088607594, 'eval_auc': 0.663920252263249, 'eval_precision': 0.6619718309859155, 'eval_recall': 0.5402298850574713, 'eval_runtime': 0.0141, 'eval_samples_per_second': 14216.534, 'eval_steps_per_second': 924.075, 'epoch': 25.0}


                                                    
 26%|██▌       | 1300/5000 [00:03<00:09, 371.63it/s]

{'eval_loss': 0.544624388217926, 'eval_accuracy': 0.715, 'eval_f1': 0.632258064516129, 'eval_auc': 0.6975383989421218, 'eval_precision': 0.7205882352941176, 'eval_recall': 0.5632183908045977, 'eval_runtime': 0.0125, 'eval_samples_per_second': 15984.695, 'eval_steps_per_second': 1039.005, 'epoch': 26.0}


                                                    
 27%|██▋       | 1350/5000 [00:03<00:10, 364.84it/s]

{'eval_loss': 0.5486124157905579, 'eval_accuracy': 0.715, 'eval_f1': 0.6274509803921569, 'eval_auc': 0.6962160512664022, 'eval_precision': 0.7272727272727273, 'eval_recall': 0.5517241379310345, 'eval_runtime': 0.015, 'eval_samples_per_second': 13307.224, 'eval_steps_per_second': 864.97, 'epoch': 27.0}


                                                    
 28%|██▊       | 1409/5000 [00:04<00:10, 339.57it/s]

{'eval_loss': 0.5437541604042053, 'eval_accuracy': 0.71, 'eval_f1': 0.6329113924050632, 'eval_auc': 0.6944359678567795, 'eval_precision': 0.704225352112676, 'eval_recall': 0.5747126436781609, 'eval_runtime': 0.0146, 'eval_samples_per_second': 13705.532, 'eval_steps_per_second': 890.86, 'epoch': 28.0}


                                                    
 29%|██▉       | 1451/5000 [00:04<00:10, 339.21it/s]

{'eval_loss': 0.5432240962982178, 'eval_accuracy': 0.715, 'eval_f1': 0.6503067484662576, 'eval_auc': 0.7028277896450006, 'eval_precision': 0.6973684210526315, 'eval_recall': 0.6091954022988506, 'eval_runtime': 0.013, 'eval_samples_per_second': 15366.565, 'eval_steps_per_second': 998.827, 'epoch': 29.0}


                                                    
 30%|███       | 1500/5000 [00:04<00:09, 355.00it/s]

{'eval_loss': 0.5351526737213135, 'eval_accuracy': 0.725, 'eval_f1': 0.6583850931677019, 'eval_auc': 0.7116773471671244, 'eval_precision': 0.7162162162162162, 'eval_recall': 0.6091954022988506, 'eval_runtime': 0.016, 'eval_samples_per_second': 12482.676, 'eval_steps_per_second': 811.374, 'epoch': 30.0}


                                                    
 31%|███       | 1550/5000 [00:04<00:10, 339.49it/s]

{'eval_loss': 0.5458622574806213, 'eval_accuracy': 0.7, 'eval_f1': 0.6511627906976745, 'eval_auc': 0.6935204963889738, 'eval_precision': 0.6588235294117647, 'eval_recall': 0.6436781609195402, 'eval_runtime': 0.0155, 'eval_samples_per_second': 12893.847, 'eval_steps_per_second': 838.1, 'epoch': 31.0}


                                                    
 32%|███▏      | 1601/5000 [00:04<00:10, 322.41it/s]

{'eval_loss': 0.5394492149353027, 'eval_accuracy': 0.73, 'eval_f1': 0.674698795180723, 'eval_auc': 0.7200691689553453, 'eval_precision': 0.7088607594936709, 'eval_recall': 0.6436781609195402, 'eval_runtime': 0.0135, 'eval_samples_per_second': 14788.203, 'eval_steps_per_second': 961.233, 'epoch': 32.0}


                                                    
 33%|███▎      | 1650/5000 [00:04<00:09, 351.13it/s]

{'eval_loss': 0.5396431684494019, 'eval_accuracy': 0.71, 'eval_f1': 0.6588235294117646, 'eval_auc': 0.7023700539110975, 'eval_precision': 0.6746987951807228, 'eval_recall': 0.6436781609195402, 'eval_runtime': 0.015, 'eval_samples_per_second': 13309.124, 'eval_steps_per_second': 865.093, 'epoch': 33.0}


                                                    
 34%|███▍      | 1700/5000 [00:05<00:10, 325.11it/s]

{'eval_loss': 0.5376358032226562, 'eval_accuracy': 0.725, 'eval_f1': 0.6666666666666666, 'eval_auc': 0.7143220425185638, 'eval_precision': 0.7051282051282052, 'eval_recall': 0.632183908045977, 'eval_runtime': 0.014, 'eval_samples_per_second': 14243.812, 'eval_steps_per_second': 925.848, 'epoch': 34.0}


                                                    
 35%|███▌      | 1751/5000 [00:05<00:10, 312.39it/s]

{'eval_loss': 0.5418542623519897, 'eval_accuracy': 0.71, 'eval_f1': 0.6627906976744184, 'eval_auc': 0.7036924015868172, 'eval_precision': 0.6705882352941176, 'eval_recall': 0.6551724137931034, 'eval_runtime': 0.0147, 'eval_samples_per_second': 13609.254, 'eval_steps_per_second': 884.602, 'epoch': 35.0}


                                                    
 36%|███▌      | 1800/5000 [00:05<00:09, 337.31it/s]

{'eval_loss': 0.5391243100166321, 'eval_accuracy': 0.71, 'eval_f1': 0.6506024096385542, 'eval_auc': 0.6997253585596581, 'eval_precision': 0.6835443037974683, 'eval_recall': 0.6206896551724138, 'eval_runtime': 0.0161, 'eval_samples_per_second': 12395.248, 'eval_steps_per_second': 805.691, 'epoch': 36.0}


                                                    
 37%|███▋      | 1861/5000 [00:05<00:09, 318.39it/s]

{'eval_loss': 0.5348414778709412, 'eval_accuracy': 0.73, 'eval_f1': 0.6625, 'eval_auc': 0.7161021259281863, 'eval_precision': 0.726027397260274, 'eval_recall': 0.6091954022988506, 'eval_runtime': 0.0155, 'eval_samples_per_second': 12897.415, 'eval_steps_per_second': 838.332, 'epoch': 37.0}


                                                    
 38%|███▊      | 1901/5000 [00:05<00:09, 322.19it/s]

{'eval_loss': 0.5357725620269775, 'eval_accuracy': 0.73, 'eval_f1': 0.6582278481012659, 'eval_auc': 0.7147797782524667, 'eval_precision': 0.7323943661971831, 'eval_recall': 0.5977011494252874, 'eval_runtime': 0.0131, 'eval_samples_per_second': 15316.064, 'eval_steps_per_second': 995.544, 'epoch': 38.0}


                                                    
 39%|███▉      | 1950/5000 [00:05<00:08, 351.26it/s]

{'eval_loss': 0.5341449975967407, 'eval_accuracy': 0.72, 'eval_f1': 0.6455696202531646, 'eval_auc': 0.7046078730546231, 'eval_precision': 0.7183098591549296, 'eval_recall': 0.5862068965517241, 'eval_runtime': 0.0166, 'eval_samples_per_second': 12043.945, 'eval_steps_per_second': 782.856, 'epoch': 39.0}


                                                    
 40%|████      | 2000/5000 [00:05<00:08, 340.59it/s]

{'eval_loss': 0.5370564460754395, 'eval_accuracy': 0.71, 'eval_f1': 0.646341463414634, 'eval_auc': 0.6984030108839385, 'eval_precision': 0.6883116883116883, 'eval_recall': 0.6091954022988506, 'eval_runtime': 0.015, 'eval_samples_per_second': 13295.624, 'eval_steps_per_second': 864.216, 'epoch': 40.0}


                                                    
 41%|████      | 2051/5000 [00:06<00:09, 317.33it/s]

{'eval_loss': 0.5408334136009216, 'eval_accuracy': 0.7, 'eval_f1': 0.6000000000000001, 'eval_auc': 0.6789746719560573, 'eval_precision': 0.7142857142857143, 'eval_recall': 0.5172413793103449, 'eval_runtime': 0.015, 'eval_samples_per_second': 13294.149, 'eval_steps_per_second': 864.12, 'epoch': 41.0}


                                                    
 42%|████▏     | 2100/5000 [00:06<00:08, 345.34it/s]

{'eval_loss': 0.5351957082748413, 'eval_accuracy': 0.73, 'eval_f1': 0.6823529411764706, 'eval_auc': 0.7227138643067846, 'eval_precision': 0.6987951807228916, 'eval_recall': 0.6666666666666666, 'eval_runtime': 0.015, 'eval_samples_per_second': 13315.039, 'eval_steps_per_second': 865.478, 'epoch': 42.0}


                                                    
[A                                   

{'eval_loss': 0.5277564525604248, 'eval_accuracy': 0.73, 'eval_f1': 0.6666666666666666, 'eval_auc': 0.7174244736039059, 'eval_precision': 0.72, 'eval_recall': 0.6206896551724138, 'eval_runtime': 0.0161, 'eval_samples_per_second': 12426.647, 'eval_steps_per_second': 807.732, 'epoch': 43.0}

 43%|████▎     | 2162/5000 [00:06<00:09, 302.40it/s]




                                                    
 44%|████▍     | 2201/5000 [00:06<00:09, 305.58it/s]

{'eval_loss': 0.5218561291694641, 'eval_accuracy': 0.755, 'eval_f1': 0.6993865030674846, 'eval_auc': 0.7435154104363747, 'eval_precision': 0.75, 'eval_recall': 0.6551724137931034, 'eval_runtime': 0.0145, 'eval_samples_per_second': 13778.696, 'eval_steps_per_second': 895.615, 'epoch': 44.0}


                                                    
 45%|████▌     | 2250/5000 [00:06<00:08, 329.59it/s]

{'eval_loss': 0.5294960737228394, 'eval_accuracy': 0.73, 'eval_f1': 0.6538461538461539, 'eval_auc': 0.713457430576747, 'eval_precision': 0.7391304347826086, 'eval_recall': 0.5862068965517241, 'eval_runtime': 0.0164, 'eval_samples_per_second': 12188.138, 'eval_steps_per_second': 792.229, 'epoch': 45.0}


                                                    
 46%|████▌     | 2308/5000 [00:06<00:08, 319.55it/s]

{'eval_loss': 0.529761016368866, 'eval_accuracy': 0.74, 'eval_f1': 0.675, 'eval_auc': 0.7262740311260298, 'eval_precision': 0.7397260273972602, 'eval_recall': 0.6206896551724138, 'eval_runtime': 0.0136, 'eval_samples_per_second': 14737.023, 'eval_steps_per_second': 957.906, 'epoch': 46.0}


                                                    
 47%|████▋     | 2350/5000 [00:07<00:07, 346.14it/s]

{'eval_loss': 0.5282353162765503, 'eval_accuracy': 0.725, 'eval_f1': 0.6666666666666666, 'eval_auc': 0.7143220425185638, 'eval_precision': 0.7051282051282052, 'eval_recall': 0.632183908045977, 'eval_runtime': 0.0134, 'eval_samples_per_second': 14943.366, 'eval_steps_per_second': 971.319, 'epoch': 47.0}


                                                    
 48%|████▊     | 2400/5000 [00:07<00:08, 322.33it/s]

{'eval_loss': 0.5235162973403931, 'eval_accuracy': 0.73, 'eval_f1': 0.6625, 'eval_auc': 0.7161021259281863, 'eval_precision': 0.726027397260274, 'eval_recall': 0.6091954022988506, 'eval_runtime': 0.0166, 'eval_samples_per_second': 12037.724, 'eval_steps_per_second': 782.452, 'epoch': 48.0}


                                                    
 49%|████▉     | 2451/5000 [00:07<00:08, 298.18it/s]

{'eval_loss': 0.5212118029594421, 'eval_accuracy': 0.745, 'eval_f1': 0.679245283018868, 'eval_auc': 0.7306988098870918, 'eval_precision': 0.75, 'eval_recall': 0.6206896551724138, 'eval_runtime': 0.015, 'eval_samples_per_second': 13313.983, 'eval_steps_per_second': 865.409, 'epoch': 49.0}


                                                    
 50%|█████     | 2500/5000 [00:07<00:07, 322.18it/s]

{'eval_loss': 0.5210397243499756, 'eval_accuracy': 0.74, 'eval_f1': 0.6708860759493671, 'eval_auc': 0.7249516834503102, 'eval_precision': 0.7464788732394366, 'eval_recall': 0.6091954022988506, 'eval_runtime': 0.0144, 'eval_samples_per_second': 13919.073, 'eval_steps_per_second': 904.74, 'epoch': 50.0}


                                                    
 51%|█████     | 2550/5000 [00:07<00:07, 310.56it/s]

{'eval_loss': 0.5270293354988098, 'eval_accuracy': 0.73, 'eval_f1': 0.6582278481012659, 'eval_auc': 0.7147797782524667, 'eval_precision': 0.7323943661971831, 'eval_recall': 0.5977011494252874, 'eval_runtime': 0.0165, 'eval_samples_per_second': 12106.87, 'eval_steps_per_second': 786.947, 'epoch': 51.0}


                                                    
 52%|█████▏    | 2600/5000 [00:07<00:07, 327.95it/s]

{'eval_loss': 0.5227335095405579, 'eval_accuracy': 0.735, 'eval_f1': 0.6624203821656051, 'eval_auc': 0.7192045570135287, 'eval_precision': 0.7428571428571429, 'eval_recall': 0.5977011494252874, 'eval_runtime': 0.0155, 'eval_samples_per_second': 12869.319, 'eval_steps_per_second': 836.506, 'epoch': 52.0}


                                                    
 53%|█████▎    | 2663/5000 [00:08<00:07, 312.30it/s]

{'eval_loss': 0.5345630645751953, 'eval_accuracy': 0.72, 'eval_f1': 0.631578947368421, 'eval_auc': 0.7006408300274641, 'eval_precision': 0.7384615384615385, 'eval_recall': 0.5517241379310345, 'eval_runtime': 0.013, 'eval_samples_per_second': 15357.844, 'eval_steps_per_second': 998.26, 'epoch': 53.0}


                                                    
 54%|█████▍    | 2701/5000 [00:08<00:07, 310.87it/s]

{'eval_loss': 0.5262733697891235, 'eval_accuracy': 0.745, 'eval_f1': 0.6832298136645962, 'eval_auc': 0.7320211575628115, 'eval_precision': 0.7432432432432432, 'eval_recall': 0.632183908045977, 'eval_runtime': 0.0165, 'eval_samples_per_second': 12104.424, 'eval_steps_per_second': 786.788, 'epoch': 54.0}


                                                    
 55%|█████▌    | 2750/5000 [00:08<00:06, 346.02it/s]

{'eval_loss': 0.5284649133682251, 'eval_accuracy': 0.735, 'eval_f1': 0.6580645161290322, 'eval_auc': 0.717882209337809, 'eval_precision': 0.75, 'eval_recall': 0.5862068965517241, 'eval_runtime': 0.0145, 'eval_samples_per_second': 13780.96, 'eval_steps_per_second': 895.762, 'epoch': 55.0}


                                                    
 56%|█████▌    | 2800/5000 [00:08<00:06, 336.85it/s]

{'eval_loss': 0.5277034640312195, 'eval_accuracy': 0.745, 'eval_f1': 0.6832298136645962, 'eval_auc': 0.7320211575628115, 'eval_precision': 0.7432432432432432, 'eval_recall': 0.632183908045977, 'eval_runtime': 0.014, 'eval_samples_per_second': 14269.252, 'eval_steps_per_second': 927.501, 'epoch': 56.0}


                                                    
 57%|█████▋    | 2851/5000 [00:08<00:06, 316.62it/s]

{'eval_loss': 0.5270910859107971, 'eval_accuracy': 0.735, 'eval_f1': 0.6624203821656051, 'eval_auc': 0.7192045570135287, 'eval_precision': 0.7428571428571429, 'eval_recall': 0.5977011494252874, 'eval_runtime': 0.0142, 'eval_samples_per_second': 14051.27, 'eval_steps_per_second': 913.333, 'epoch': 57.0}


                                                    
 58%|█████▊    | 2900/5000 [00:08<00:06, 343.93it/s]

{'eval_loss': 0.5203037261962891, 'eval_accuracy': 0.74, 'eval_f1': 0.6790123456790123, 'eval_auc': 0.7275963788017495, 'eval_precision': 0.7333333333333333, 'eval_recall': 0.632183908045977, 'eval_runtime': 0.0155, 'eval_samples_per_second': 12893.054, 'eval_steps_per_second': 838.049, 'epoch': 58.0}


                                                    
 59%|█████▉    | 2950/5000 [00:08<00:06, 326.95it/s]

{'eval_loss': 0.5209954977035522, 'eval_accuracy': 0.755, 'eval_f1': 0.6956521739130436, 'eval_auc': 0.7421930627606551, 'eval_precision': 0.7567567567567568, 'eval_recall': 0.6436781609195402, 'eval_runtime': 0.0151, 'eval_samples_per_second': 13240.223, 'eval_steps_per_second': 860.614, 'epoch': 59.0}


                                                    
 60%|██████    | 3001/5000 [00:09<00:06, 320.07it/s]

{'eval_loss': 0.5240082144737244, 'eval_accuracy': 0.745, 'eval_f1': 0.6832298136645962, 'eval_auc': 0.7320211575628115, 'eval_precision': 0.7432432432432432, 'eval_recall': 0.632183908045977, 'eval_runtime': 0.0132, 'eval_samples_per_second': 15100.461, 'eval_steps_per_second': 981.53, 'epoch': 60.0}


                                                    
 61%|██████    | 3050/5000 [00:09<00:05, 345.32it/s]

{'eval_loss': 0.5214991569519043, 'eval_accuracy': 0.74, 'eval_f1': 0.6790123456790123, 'eval_auc': 0.7275963788017495, 'eval_precision': 0.7333333333333333, 'eval_recall': 0.632183908045977, 'eval_runtime': 0.0155, 'eval_samples_per_second': 12929.819, 'eval_steps_per_second': 840.438, 'epoch': 61.0}


                                                    
 62%|██████▏   | 3100/5000 [00:09<00:05, 332.03it/s]

{'eval_loss': 0.5214917063713074, 'eval_accuracy': 0.74, 'eval_f1': 0.6829268292682927, 'eval_auc': 0.7289187264774692, 'eval_precision': 0.7272727272727273, 'eval_recall': 0.6436781609195402, 'eval_runtime': 0.0155, 'eval_samples_per_second': 12896.226, 'eval_steps_per_second': 838.255, 'epoch': 62.0}


                                                    
 63%|██████▎   | 3151/5000 [00:09<00:05, 316.95it/s]

{'eval_loss': 0.5169082880020142, 'eval_accuracy': 0.745, 'eval_f1': 0.6832298136645962, 'eval_auc': 0.7320211575628115, 'eval_precision': 0.7432432432432432, 'eval_recall': 0.632183908045977, 'eval_runtime': 0.014, 'eval_samples_per_second': 14267.311, 'eval_steps_per_second': 927.375, 'epoch': 63.0}


                                                    
 64%|██████▍   | 3200/5000 [00:09<00:05, 341.58it/s]

{'eval_loss': 0.5207852721214294, 'eval_accuracy': 0.74, 'eval_f1': 0.6708860759493671, 'eval_auc': 0.7249516834503102, 'eval_precision': 0.7464788732394366, 'eval_recall': 0.6091954022988506, 'eval_runtime': 0.0216, 'eval_samples_per_second': 9262.63, 'eval_steps_per_second': 602.071, 'epoch': 64.0}


                                                    
 65%|██████▌   | 3250/5000 [00:09<00:05, 328.09it/s]

{'eval_loss': 0.5221735239028931, 'eval_accuracy': 0.75, 'eval_f1': 0.6875000000000001, 'eval_auc': 0.7364459363238735, 'eval_precision': 0.7534246575342466, 'eval_recall': 0.632183908045977, 'eval_runtime': 0.0158, 'eval_samples_per_second': 12646.778, 'eval_steps_per_second': 822.041, 'epoch': 65.0}


                                                    
 66%|██████▌   | 3301/5000 [00:09<00:05, 312.12it/s]

{'eval_loss': 0.531046450138092, 'eval_accuracy': 0.745, 'eval_f1': 0.679245283018868, 'eval_auc': 0.7306988098870918, 'eval_precision': 0.75, 'eval_recall': 0.6206896551724138, 'eval_runtime': 0.0155, 'eval_samples_per_second': 12899.002, 'eval_steps_per_second': 838.435, 'epoch': 66.0}


                                                    
 67%|██████▋   | 3350/5000 [00:10<00:04, 341.23it/s]

{'eval_loss': 0.5236904621124268, 'eval_accuracy': 0.745, 'eval_f1': 0.6709677419354838, 'eval_auc': 0.7280541145356526, 'eval_precision': 0.7647058823529411, 'eval_recall': 0.5977011494252874, 'eval_runtime': 0.0152, 'eval_samples_per_second': 13187.976, 'eval_steps_per_second': 857.218, 'epoch': 67.0}


                                                    
 68%|██████▊   | 3400/5000 [00:10<00:04, 330.43it/s]

{'eval_loss': 0.5234100222587585, 'eval_accuracy': 0.745, 'eval_f1': 0.679245283018868, 'eval_auc': 0.7306988098870918, 'eval_precision': 0.75, 'eval_recall': 0.6206896551724138, 'eval_runtime': 0.0136, 'eval_samples_per_second': 14716.082, 'eval_steps_per_second': 956.545, 'epoch': 68.0}


                                                    
 69%|██████▉   | 3451/5000 [00:10<00:04, 312.34it/s]

{'eval_loss': 0.5211723446846008, 'eval_accuracy': 0.74, 'eval_f1': 0.675, 'eval_auc': 0.7262740311260298, 'eval_precision': 0.7397260273972602, 'eval_recall': 0.6206896551724138, 'eval_runtime': 0.018, 'eval_samples_per_second': 11133.596, 'eval_steps_per_second': 723.684, 'epoch': 69.0}


                                                    
 70%|███████   | 3515/5000 [00:10<00:05, 284.71it/s]

{'eval_loss': 0.529462456703186, 'eval_accuracy': 0.74, 'eval_f1': 0.6708860759493671, 'eval_auc': 0.7249516834503102, 'eval_precision': 0.7464788732394366, 'eval_recall': 0.6091954022988506, 'eval_runtime': 0.0207, 'eval_samples_per_second': 9673.433, 'eval_steps_per_second': 628.773, 'epoch': 70.0}


                                                    
 71%|███████   | 3551/5000 [00:10<00:05, 285.01it/s]

{'eval_loss': 0.5328310132026672, 'eval_accuracy': 0.74, 'eval_f1': 0.6708860759493671, 'eval_auc': 0.7249516834503102, 'eval_precision': 0.7464788732394366, 'eval_recall': 0.6091954022988506, 'eval_runtime': 0.0145, 'eval_samples_per_second': 13782.998, 'eval_steps_per_second': 895.895, 'epoch': 71.0}


                                                    
 73%|███████▎  | 3646/5000 [00:11<00:04, 293.51it/s]

{'eval_loss': 0.5243722796440125, 'eval_accuracy': 0.745, 'eval_f1': 0.679245283018868, 'eval_auc': 0.7306988098870918, 'eval_precision': 0.75, 'eval_recall': 0.6206896551724138, 'eval_runtime': 0.015, 'eval_samples_per_second': 13309.969, 'eval_steps_per_second': 865.148, 'epoch': 72.0}


                                                    
 73%|███████▎  | 3650/5000 [00:11<00:04, 293.51it/s]

{'eval_loss': 0.5278418064117432, 'eval_accuracy': 0.74, 'eval_f1': 0.6708860759493671, 'eval_auc': 0.7249516834503102, 'eval_precision': 0.7464788732394366, 'eval_recall': 0.6091954022988506, 'eval_runtime': 0.0175, 'eval_samples_per_second': 11408.105, 'eval_steps_per_second': 741.527, 'epoch': 73.0}


 73%|███████▎  | 3650/5000 [00:11<00:04, 326.15it/s]


{'train_runtime': 11.191, 'train_samples_per_second': 7148.588, 'train_steps_per_second': 446.787, 'train_loss': 0.5470019865689212, 'epoch': 73.0}


100%|██████████| 13/13 [00:00<00:00, 552.46it/s]

{'eval_loss': 0.5169082880020142, 'eval_accuracy': 0.745, 'eval_f1': 0.6832298136645962, 'eval_auc': 0.7320211575628115, 'eval_precision': 0.7432432432432432, 'eval_recall': 0.632183908045977, 'eval_runtime': 0.0265, 'eval_samples_per_second': 7538.221, 'eval_steps_per_second': 489.984, 'epoch': 73.0}





In [8]:
trainer.evaluate(devdata)

100%|██████████| 13/13 [00:00<00:00, 886.83it/s]


{'eval_loss': 1.4113526344299316,
 'eval_accuracy': 0.465,
 'eval_f1': 0.17054263565891473,
 'eval_auc': 0.45475928335501953,
 'eval_precision': 0.34375,
 'eval_recall': 0.1134020618556701,
 'eval_runtime': 0.0154,
 'eval_samples_per_second': 12977.627,
 'eval_steps_per_second': 843.546,
 'epoch': 73.0}

In [9]:
trainer.save_model('SavedModels/'+'dense_'+str(round(SAMPLES_TO_TRAIN/1000))+'k')

In [None]:
# All samples trained with training dataset
print(''''eval_loss': 0.2924209237098694, 'eval_accuracy': 0.8761690046760187, 'eval_f1': 0.8671385056441498, 'eval_auc': 0.8748671641964536, 'eval_precision': 0.8905962458594038, 'eval_recall': 0.8448847765363129''')

'eval_loss': 0.2924209237098694, 'eval_accuracy': 0.8761690046760187, 'eval_f1': 0.8671385056441498, 'eval_auc': 0.8748671641964536, 'eval_precision': 0.8905962458594038, 'eval_recall': 0.8448847765363129
