In [1]:
'''Import libraries'''
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import wandb
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
import os
from torch.utils.data import Dataset
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from imblearn.under_sampling import RandomUnderSampler
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertForSequenceClassification, AdamW, Trainer, TrainingArguments
from tqdm import tqdm
from torch.nn import functional as F
import torch.nn as nn

wandb.login()

  from .autonotebook import tqdm as notebook_tqdm
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33malberto-rodero557[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
'''Variables and parameters'''

SAMPLES_TO_TRAIN=-1
DIMENSIONS=200

N_LABELS=2
MAX_LEN = 256
EPOCHS=100
PATIENCE=10
LEARNING_RATE=.00005
WEIGHT_DECAY=.01
BATCH_SIZE=16
METRIC_FOR_BEST_MODEL='eval_loss'
if METRIC_FOR_BEST_MODEL=='eval_loss':
    GREATER_IS_BETTER = False
else:
    GREATER_IS_BETTER = True

In [3]:
'''Preparing dataset'''
import pickle

file_path = 'datasets/subtaskA_glove_train_dev_monolingual.pkl'

# Load the data from the pickle file
with open(file_path, 'rb') as file:
    data = pickle.load(file)

# Extract the individual datasets from the loaded data
train_x = data['train_x']
train_y = data['train_y']
dev_x = data['dev_x']
dev_y = data['dev_y']

if SAMPLES_TO_TRAIN!=-1:
    random_indices = np.random.choice(train_x.shape[0], size=SAMPLES_TO_TRAIN, replace=False)
    train_x = train_x[random_indices]
    train_y = train_y[random_indices]

    random_indices_dev = np.random.choice(dev_x.shape[0], size=int(train_x.shape[0] * 0.2), replace=False)
    dev_x = dev_x[random_indices_dev]
    dev_y = pd.Series(dev_y.to_numpy()[random_indices_dev])

print(train_x.shape)
print(train_y.shape)
print(dev_x.shape)
print(dev_y.shape)

print(type(train_y))
print(type(dev_y))

(119757, 200)
(119757,)
(5000, 200)
(5000,)
<class 'pandas.core.series.Series'>
<class 'pandas.core.series.Series'>


In [4]:
'''metrics'''

def compute_metrics(p):
    preds = np.argmax(p.predictions, axis=1)
    labels = p.label_ids
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    auc = roc_auc_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'auc': auc,
        'precision': precision,
        'recall': recall,
    }

In [5]:
class Data(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
        self.y = torch.tensor(y, dtype=torch.long)
        self.len = len(X)

    def __getitem__(self, index):
        return {'x': self.X[index], 'label': self.y[index], 'label_ids': index}

    def __len__(self):
        return self.len

X_train, X_test, y_train, y_test = train_test_split(train_x, train_y.values, test_size=0.2, random_state=42)
traindata = Data(X_train, y_train)
testdata = Data(X_test, y_test)

devdata=Data(dev_x, dev_y.values)

In [6]:
# number of features (len of X cols)
input_dim = train_x.shape[-1]

# number of classes (unique of y)
output_dim = 2

class CNN1D(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(CNN1D, self).__init__()
        
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=100, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm1d(100)
        self.conv2 = nn.Conv1d(in_channels=100, out_channels=150, kernel_size=5, stride=1, padding=2)
        self.bn2 = nn.BatchNorm1d(150)
        self.dropout1 = nn.Dropout(0.5)
        
        self.fc1 = nn.Linear(150 * 200, 256)
        self.dropout2 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x, labels=None):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.dropout1(x)
        
        # Flatten the output for the dense layer
        x = torch.flatten(x, 1) 
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(x, labels)
            return loss, x
        
        return x

# Instantiate the model with appropriate dimensions
model = CNN1D(input_dim=200, num_classes=2)

In [7]:
from transformers import EarlyStoppingCallback

model = CNN1D(input_dim=200, num_classes=2)

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    warmup_steps=500,
    weight_decay=WEIGHT_DECAY,
    metric_for_best_model=METRIC_FOR_BEST_MODEL,
    greater_is_better=GREATER_IS_BETTER,
    logging_dir='./logs',
    logging_steps=15000,
    do_train=True,
    do_eval=True,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    push_to_hub=False,
    logging_first_step=False,
    load_best_model_at_end=True,
    save_total_limit=2,
    report_to="wandb"
)

# Create trainer
trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=traindata,             # training dataset
    eval_dataset=testdata, 
    compute_metrics=compute_metrics,# training dataset
    callbacks=[EarlyStoppingCallback(early_stopping_patience=PATIENCE)]
)

# Train the model
print(trainer.evaluate())

trainer.train()

print(trainer.evaluate())

 99%|█████████▉| 1479/1497 [00:01<00:00, 1209.74it/s]

100%|██████████| 1497/1497 [00:01<00:00, 789.28it/s] 


{'eval_loss': 0.6937636137008667, 'eval_accuracy': 0.47828991315965264, 'eval_f1': 0.6470854044283778, 'eval_auc': 0.5, 'eval_precision': 0.47828991315965264, 'eval_recall': 1.0, 'eval_runtime': 3.0591, 'eval_samples_per_second': 7829.627, 'eval_steps_per_second': 489.352}


                                                      
  1%|          | 6010/598800 [00:20<3:13:58, 50.94it/s]

{'eval_loss': 0.45240145921707153, 'eval_accuracy': 0.7842768871075484, 'eval_f1': 0.7459060732726825, 'eval_auc': 0.7791890052861609, 'eval_precision': 0.8541502421443856, 'eval_recall': 0.6620111731843575, 'eval_runtime': 1.5479, 'eval_samples_per_second': 15473.835, 'eval_steps_per_second': 967.115, 'epoch': 1.0}


                                                        
  2%|▏         | 11979/598800 [00:41<2:56:54, 55.28it/s]

{'eval_loss': 0.41438066959381104, 'eval_accuracy': 0.8105794923179692, 'eval_f1': 0.7932558669400774, 'eval_auc': 0.8084654128427242, 'eval_precision': 0.8298217179902755, 'eval_recall': 0.7597765363128491, 'eval_runtime': 1.4165, 'eval_samples_per_second': 16909.035, 'eval_steps_per_second': 1056.815, 'epoch': 2.0}


  3%|▎         | 15032/598800 [00:50<31:03, 313.28it/s]  

{'loss': 0.4943, 'learning_rate': 4.878823332776199e-05, 'epoch': 2.51}


                                                       
  3%|▎         | 17974/598800 [01:01<2:35:03, 62.43it/s]

{'eval_loss': 0.40199750661849976, 'eval_accuracy': 0.814128256513026, 'eval_f1': 0.7955922865013774, 'eval_auc': 0.8117212032990221, 'eval_precision': 0.8392096086788067, 'eval_recall': 0.7562849162011173, 'eval_runtime': 1.3778, 'eval_samples_per_second': 17383.869, 'eval_steps_per_second': 1086.492, 'epoch': 3.0}


                                                         
  4%|▍         | 23956/598800 [01:20<2:36:38, 61.16it/s]

{'eval_loss': 0.39006248116493225, 'eval_accuracy': 0.8210170340681363, 'eval_f1': 0.8030504892727522, 'eval_auc': 0.8185993819698281, 'eval_precision': 0.8476384443797885, 'eval_recall': 0.7629189944134078, 'eval_runtime': 1.4102, 'eval_samples_per_second': 16984.479, 'eval_steps_per_second': 1061.53, 'epoch': 4.0}


                                                         
  5%|▌         | 29952/598800 [01:40<2:27:29, 64.28it/s]

{'eval_loss': 0.3851029574871063, 'eval_accuracy': 0.8255260521042084, 'eval_f1': 0.8103471749489449, 'eval_auc': 0.8236036649046131, 'eval_precision': 0.8439360998203989, 'eval_recall': 0.7793296089385475, 'eval_runtime': 1.3735, 'eval_samples_per_second': 17438.262, 'eval_steps_per_second': 1089.891, 'epoch': 5.0}


  5%|▌         | 30036/598800 [01:40<1:18:53, 120.17it/s]

{'loss': 0.4362, 'learning_rate': 4.753468159786061e-05, 'epoch': 5.01}


                                                         
  6%|▌         | 35957/598800 [01:59<1:57:27, 79.87it/s]

{'eval_loss': 0.40294814109802246, 'eval_accuracy': 0.8128340013360054, 'eval_f1': 0.8121674278292202, 'eval_auc': 0.8142149621957238, 'eval_precision': 0.7809201514785271, 'eval_recall': 0.8460195530726257, 'eval_runtime': 1.3517, 'eval_samples_per_second': 17720.214, 'eval_steps_per_second': 1107.513, 'epoch': 6.0}


                                                         
  7%|▋         | 41918/598800 [02:19<2:25:56, 63.59it/s]

{'eval_loss': 0.37403255701065063, 'eval_accuracy': 0.8303690714762859, 'eval_f1': 0.8213201987774308, 'eval_auc': 0.8297344535726293, 'eval_precision': 0.8276167685899141, 'eval_recall': 0.8151187150837989, 'eval_runtime': 1.384, 'eval_samples_per_second': 17306.071, 'eval_steps_per_second': 1081.629, 'epoch': 7.0}


  8%|▊         | 45057/598800 [02:28<27:31, 335.21it/s]  

{'loss': 0.4152, 'learning_rate': 4.628112986795922e-05, 'epoch': 7.52}


                                                       
  8%|▊         | 47924/598800 [02:38<2:21:53, 64.71it/s]

{'eval_loss': 0.37612423300743103, 'eval_accuracy': 0.8285738142952572, 'eval_f1': 0.8216643502432244, 'eval_auc': 0.8284534291196647, 'eval_precision': 0.8176867219917012, 'eval_recall': 0.8256808659217877, 'eval_runtime': 1.3194, 'eval_samples_per_second': 18153.44, 'eval_steps_per_second': 1134.59, 'epoch': 8.0}


                                                         
  9%|▉         | 53912/598800 [02:57<2:17:35, 66.00it/s]

{'eval_loss': 0.36518043279647827, 'eval_accuracy': 0.832999331997328, 'eval_f1': 0.8180163785259328, 'eval_auc': 0.8309911685705906, 'eval_precision': 0.8542379323451159, 'eval_recall': 0.7847416201117319, 'eval_runtime': 1.3339, 'eval_samples_per_second': 17955.83, 'eval_steps_per_second': 1122.239, 'epoch': 9.0}


                                                         
 10%|█         | 59887/598800 [03:17<2:23:40, 62.51it/s]

{'eval_loss': 0.36101606488227844, 'eval_accuracy': 0.8380511022044088, 'eval_f1': 0.8229009724695248, 'eval_auc': 0.8359126317069507, 'eval_precision': 0.8626399923422993, 'eval_recall': 0.7866620111731844, 'eval_runtime': 1.3865, 'eval_samples_per_second': 17275.568, 'eval_steps_per_second': 1079.723, 'epoch': 10.0}


 10%|█         | 60042/598800 [03:17<48:22, 185.63it/s]  

{'loss': 0.4009, 'learning_rate': 4.5027578138057834e-05, 'epoch': 10.02}


                                                       
 11%|█         | 65875/598800 [03:37<2:16:49, 64.91it/s]

{'eval_loss': 0.36297622323036194, 'eval_accuracy': 0.8383433533734135, 'eval_f1': 0.8333620244448271, 'eval_auc': 0.8386264610261877, 'eval_precision': 0.8219015280135823, 'eval_recall': 0.8451466480446927, 'eval_runtime': 1.3641, 'eval_samples_per_second': 17558.247, 'eval_steps_per_second': 1097.39, 'epoch': 11.0}


                                                         
 12%|█▏        | 71878/598800 [03:56<2:14:50, 65.12it/s]

{'eval_loss': 0.3523162007331848, 'eval_accuracy': 0.841808617234469, 'eval_f1': 0.8282956450808899, 'eval_auc': 0.8399751048827245, 'eval_precision': 0.8612760343040241, 'eval_recall': 0.7977479050279329, 'eval_runtime': 1.3576, 'eval_samples_per_second': 17643.194, 'eval_steps_per_second': 1102.7, 'epoch': 12.0}


 13%|█▎        | 75056/598800 [04:06<25:30, 342.19it/s]  

{'loss': 0.3885, 'learning_rate': 4.377402640815644e-05, 'epoch': 12.53}


                                                       
 13%|█▎        | 77863/598800 [04:15<2:07:52, 67.89it/s]

{'eval_loss': 0.34894776344299316, 'eval_accuracy': 0.8463593854375417, 'eval_f1': 0.8348590917250045, 'eval_auc': 0.8449285894391234, 'eval_precision': 0.8590690801625416, 'eval_recall': 0.8119762569832403, 'eval_runtime': 1.3091, 'eval_samples_per_second': 18296.907, 'eval_steps_per_second': 1143.557, 'epoch': 13.0}


                                                         
 14%|█▍        | 83832/598800 [04:34<25:31, 336.34it/s]

{'eval_loss': 0.35879912972450256, 'eval_accuracy': 0.8421426185704742, 'eval_f1': 0.8363699311896827, 'eval_auc': 0.8421986097003555, 'eval_precision': 0.8293708694532658, 'eval_recall': 0.8434881284916201, 'eval_runtime': 1.3255, 'eval_samples_per_second': 18070.256, 'eval_steps_per_second': 1129.391, 'epoch': 14.0}


                                                         
 15%|█▌        | 89820/598800 [04:56<27:45, 305.66it/s]

{'eval_loss': 0.34714072942733765, 'eval_accuracy': 0.8461088844355378, 'eval_f1': 0.8335140018066848, 'eval_auc': 0.8444160790670893, 'eval_precision': 0.8636278547360539, 'eval_recall': 0.805429469273743, 'eval_runtime': 1.4324, 'eval_samples_per_second': 16721.918, 'eval_steps_per_second': 1045.12, 'epoch': 15.0}


 15%|█▌        | 90048/598800 [04:57<40:39, 208.53it/s]  

{'loss': 0.3811, 'learning_rate': 4.2520474678255056e-05, 'epoch': 15.03}


                                                       
 16%|█▌        | 95825/598800 [05:18<2:32:03, 55.13it/s]

{'eval_loss': 0.3427473306655884, 'eval_accuracy': 0.8442718770875084, 'eval_f1': 0.8256357516828722, 'eval_auc': 0.84121706655627, 'eval_precision': 0.8887882447665056, 'eval_recall': 0.7708624301675978, 'eval_runtime': 1.4042, 'eval_samples_per_second': 17057.416, 'eval_steps_per_second': 1066.088, 'epoch': 16.0}


                                                         
 17%|█▋        | 101800/598800 [05:40<2:24:11, 57.44it/s]

{'eval_loss': 0.3426991403102875, 'eval_accuracy': 0.8442718770875084, 'eval_f1': 0.8267211743937564, 'eval_auc': 0.8414604405253256, 'eval_precision': 0.8836146971201589, 'eval_recall': 0.7767108938547486, 'eval_runtime': 1.3706, 'eval_samples_per_second': 17474.929, 'eval_steps_per_second': 1092.183, 'epoch': 17.0}


 18%|█▊        | 105062/598800 [05:51<26:41, 308.35it/s]  

{'loss': 0.3721, 'learning_rate': 4.126692294835367e-05, 'epoch': 17.54}


                                                        
 18%|█▊        | 107791/598800 [06:02<2:26:47, 55.75it/s]

{'eval_loss': 0.33992695808410645, 'eval_accuracy': 0.8478206412825652, 'eval_f1': 0.8320818169254157, 'eval_auc': 0.8453446443644088, 'eval_precision': 0.8809872207589503, 'eval_recall': 0.788320530726257, 'eval_runtime': 1.3543, 'eval_samples_per_second': 17686.443, 'eval_steps_per_second': 1105.403, 'epoch': 18.0}


                                                          
 19%|█▉        | 113772/598800 [06:23<28:28, 283.88it/s]

{'eval_loss': 0.3361303210258484, 'eval_accuracy': 0.8500751503006012, 'eval_f1': 0.8353734011827809, 'eval_auc': 0.8477959315696106, 'eval_precision': 0.8796948923433426, 'eval_recall': 0.7953037709497207, 'eval_runtime': 1.3727, 'eval_samples_per_second': 17448.927, 'eval_steps_per_second': 1090.558, 'epoch': 19.0}


                                                          
 20%|██        | 119785/598800 [06:44<2:20:28, 56.83it/s]

{'eval_loss': 0.33826765418052673, 'eval_accuracy': 0.8516616566466266, 'eval_f1': 0.8405940149849701, 'eval_auc': 0.8502499570812381, 'eval_precision': 0.8647650696944521, 'eval_recall': 0.8177374301675978, 'eval_runtime': 1.3961, 'eval_samples_per_second': 17156.645, 'eval_steps_per_second': 1072.29, 'epoch': 20.0}


 20%|██        | 120045/598800 [06:45<33:50, 235.82it/s]  

{'loss': 0.3665, 'learning_rate': 4.0013371218452283e-05, 'epoch': 20.04}


                                                        
 21%|██        | 125767/598800 [07:05<2:21:38, 55.66it/s]

{'eval_loss': 0.33561617136001587, 'eval_accuracy': 0.8526636606546426, 'eval_f1': 0.8415712682379349, 'eval_auc': 0.8512284266160701, 'eval_precision': 0.8663462427211387, 'eval_recall': 0.8181738826815642, 'eval_runtime': 1.3801, 'eval_samples_per_second': 17355.891, 'eval_steps_per_second': 1084.743, 'epoch': 21.0}


                                                          
 22%|██▏       | 131736/598800 [07:25<24:29, 317.81it/s]

{'eval_loss': 0.33401724696159363, 'eval_accuracy': 0.8536656646626587, 'eval_f1': 0.8456151169448972, 'eval_auc': 0.8530096670040559, 'eval_precision': 0.8534720369876412, 'eval_recall': 0.8379015363128491, 'eval_runtime': 1.3783, 'eval_samples_per_second': 17377.738, 'eval_steps_per_second': 1086.109, 'epoch': 22.0}


 23%|██▎       | 135059/598800 [07:34<21:57, 351.85it/s]  

{'loss': 0.3587, 'learning_rate': 3.87598194885509e-05, 'epoch': 22.55}


                                                        
 23%|██▎       | 137725/598800 [07:44<1:58:43, 64.73it/s]

{'eval_loss': 0.3309616446495056, 'eval_accuracy': 0.852747160988644, 'eval_f1': 0.839118733749943, 'eval_auc': 0.8506727739468808, 'eval_precision': 0.8787618228718831, 'eval_recall': 0.8028980446927374, 'eval_runtime': 1.3696, 'eval_samples_per_second': 17487.842, 'eval_steps_per_second': 1092.99, 'epoch': 23.0}


                                                          
 24%|██▍       | 143712/598800 [08:03<22:19, 339.84it/s]

{'eval_loss': 0.338632732629776, 'eval_accuracy': 0.8517869071476286, 'eval_f1': 0.8339725002338415, 'eval_auc': 0.8487281293142297, 'eval_precision': 0.8982470280072536, 'eval_recall': 0.778282122905028, 'eval_runtime': 1.3784, 'eval_samples_per_second': 17376.992, 'eval_steps_per_second': 1086.062, 'epoch': 24.0}


                                                         
 25%|██▌       | 149709/598800 [08:23<1:55:37, 64.73it/s]

{'eval_loss': 0.33994778990745544, 'eval_accuracy': 0.852997661990648, 'eval_f1': 0.8476351205158162, 'eval_auc': 0.8530777893618696, 'eval_precision': 0.8404702651677679, 'eval_recall': 0.8549231843575419, 'eval_runtime': 1.3335, 'eval_samples_per_second': 17962.26, 'eval_steps_per_second': 1122.641, 'epoch': 25.0}


 25%|██▌       | 150041/598800 [08:24<25:13, 296.50it/s] 

{'loss': 0.3538, 'learning_rate': 3.750626775864951e-05, 'epoch': 25.05}


                                                        
 26%|██▌       | 155716/598800 [08:42<1:29:31, 82.49it/s]

{'eval_loss': 0.3265962600708008, 'eval_accuracy': 0.856629926519706, 'eval_f1': 0.843110380116959, 'eval_auc': 0.8544993056996116, 'eval_precision': 0.8844900306748467, 'eval_recall': 0.805429469273743, 'eval_runtime': 1.3246, 'eval_samples_per_second': 18081.832, 'eval_steps_per_second': 1130.114, 'epoch': 26.0}


                                                          
 27%|██▋       | 161679/598800 [09:01<1:52:25, 64.80it/s]

{'eval_loss': 0.3298064172267914, 'eval_accuracy': 0.8555026720106881, 'eval_f1': 0.8492004705677313, 'eval_auc': 0.8553005676900407, 'eval_precision': 0.8477598956067856, 'eval_recall': 0.8506459497206704, 'eval_runtime': 1.3094, 'eval_samples_per_second': 18292.853, 'eval_steps_per_second': 1143.303, 'epoch': 27.0}


 28%|██▊       | 165059/598800 [09:11<21:21, 338.47it/s] 

{'loss': 0.3464, 'learning_rate': 3.6252716028748125e-05, 'epoch': 27.56}


                                                        
 28%|██▊       | 167668/598800 [09:20<1:51:05, 64.69it/s]

{'eval_loss': 0.32529863715171814, 'eval_accuracy': 0.8570474281897128, 'eval_f1': 0.8462091268415378, 'eval_auc': 0.8556004960693567, 'eval_precision': 0.8715766099185789, 'eval_recall': 0.8222765363128491, 'eval_runtime': 1.3649, 'eval_samples_per_second': 17548.709, 'eval_steps_per_second': 1096.794, 'epoch': 28.0}


                                                         
 29%|██▉       | 173659/598800 [09:40<1:49:25, 64.76it/s]

{'eval_loss': 0.3284814953804016, 'eval_accuracy': 0.856003674014696, 'eval_f1': 0.8395888563322635, 'eval_auc': 0.8531689917309851, 'eval_precision': 0.8985564957690393, 'eval_recall': 0.7878840782122905, 'eval_runtime': 1.3615, 'eval_samples_per_second': 17592.725, 'eval_steps_per_second': 1099.545, 'epoch': 29.0}


                                                         
 30%|███       | 179640/598800 [09:59<21:02, 332.03it/s]

{'eval_loss': 0.3306010961532593, 'eval_accuracy': 0.8531646626586507, 'eval_f1': 0.8351380490320162, 'eval_auc': 0.8500194922710463, 'eval_precision': 0.9018932874354562, 'eval_recall': 0.7775837988826816, 'eval_runtime': 1.3438, 'eval_samples_per_second': 17824.121, 'eval_steps_per_second': 1114.008, 'epoch': 30.0}


 30%|███       | 180044/598800 [10:00<22:27, 310.67it/s]  

{'loss': 0.3437, 'learning_rate': 3.499916429884673e-05, 'epoch': 30.06}


                                                        
 31%|███       | 185635/598800 [10:18<1:43:32, 66.50it/s]

{'eval_loss': 0.33182939887046814, 'eval_accuracy': 0.8548346693386774, 'eval_f1': 0.8388263106661105, 'eval_auc': 0.852128547056846, 'eval_precision': 0.8943362656914104, 'eval_recall': 0.789804469273743, 'eval_runtime': 1.3076, 'eval_samples_per_second': 18317.708, 'eval_steps_per_second': 1144.857, 'epoch': 31.0}


                                                          
 32%|███▏      | 191616/598800 [10:36<20:13, 335.60it/s]

{'eval_loss': 0.3245355486869812, 'eval_accuracy': 0.8580494321977288, 'eval_f1': 0.84453589391861, 'eval_auc': 0.8558888006173149, 'eval_precision': 0.8867870174764739, 'eval_recall': 0.8061277932960894, 'eval_runtime': 1.2973, 'eval_samples_per_second': 18462.94, 'eval_steps_per_second': 1153.934, 'epoch': 32.0}


 33%|███▎      | 195041/598800 [10:48<22:49, 294.76it/s]  

{'loss': 0.3393, 'learning_rate': 3.3745612568945346e-05, 'epoch': 32.57}


                                                        
 33%|███▎      | 197610/598800 [10:59<2:04:36, 53.66it/s]

{'eval_loss': 0.323408305644989, 'eval_accuracy': 0.8582999331997327, 'eval_f1': 0.8466335291459558, 'eval_auc': 0.8566119929327105, 'eval_precision': 0.8776466179501593, 'eval_recall': 0.8177374301675978, 'eval_runtime': 1.3949, 'eval_samples_per_second': 17171.251, 'eval_steps_per_second': 1073.203, 'epoch': 33.0}


                                                         
 34%|███▍      | 203595/598800 [11:20<1:57:29, 56.06it/s]

{'eval_loss': 0.32271477580070496, 'eval_accuracy': 0.8592184368737475, 'eval_f1': 0.8477789815817983, 'eval_auc': 0.8575721884634369, 'eval_precision': 0.8778982797307404, 'eval_recall': 0.8196578212290503, 'eval_runtime': 1.3533, 'eval_samples_per_second': 17698.386, 'eval_steps_per_second': 1106.149, 'epoch': 34.0}


                                                         
 35%|███▌      | 209599/598800 [11:42<1:57:45, 55.09it/s]

{'eval_loss': 0.33041054010391235, 'eval_accuracy': 0.8557114228456913, 'eval_f1': 0.8383989525858038, 'eval_auc': 0.8526673228170446, 'eval_precision': 0.9028197381671702, 'eval_recall': 0.7825593575418994, 'eval_runtime': 1.4138, 'eval_samples_per_second': 16941.873, 'eval_steps_per_second': 1058.867, 'epoch': 35.0}


 35%|███▌      | 210039/598800 [11:43<23:51, 271.59it/s] 

{'loss': 0.3347, 'learning_rate': 3.249206083904396e-05, 'epoch': 35.07}


                                                        
 36%|███▌      | 215587/598800 [12:03<1:52:51, 56.59it/s]

{'eval_loss': 0.32615506649017334, 'eval_accuracy': 0.8588844355377422, 'eval_f1': 0.8463217241065745, 'eval_auc': 0.8569505929048133, 'eval_precision': 0.8831846650218258, 'eval_recall': 0.8124127094972067, 'eval_runtime': 1.406, 'eval_samples_per_second': 17035.679, 'eval_steps_per_second': 1064.73, 'epoch': 36.0}


                                                         
 37%|███▋      | 221574/598800 [12:25<1:51:54, 56.18it/s]

{'eval_loss': 0.3291877210140228, 'eval_accuracy': 0.8588844355377422, 'eval_f1': 0.8446976658702443, 'eval_auc': 0.8565328614653895, 'eval_precision': 0.8917345750873108, 'eval_recall': 0.8023743016759777, 'eval_runtime': 1.4021, 'eval_samples_per_second': 17082.632, 'eval_steps_per_second': 1067.665, 'epoch': 37.0}


 38%|███▊      | 225047/598800 [12:36<18:57, 328.58it/s] 

{'loss': 0.3291, 'learning_rate': 3.1238509109142574e-05, 'epoch': 37.58}


                                                        
 38%|███▊      | 227553/598800 [12:45<1:38:14, 62.98it/s]

{'eval_loss': 0.31859561800956726, 'eval_accuracy': 0.8606796927187709, 'eval_f1': 0.8485453637725231, 'eval_auc': 0.8588200738202705, 'eval_precision': 0.8838044814219533, 'eval_recall': 0.8159916201117319, 'eval_runtime': 1.3825, 'eval_samples_per_second': 17324.921, 'eval_steps_per_second': 1082.808, 'epoch': 38.0}


                                                         
 39%|███▉      | 233542/598800 [13:04<1:32:13, 66.01it/s]

{'eval_loss': 0.31928059458732605, 'eval_accuracy': 0.8609719438877755, 'eval_f1': 0.8488562091503268, 'eval_auc': 0.8591110607908498, 'eval_precision': 0.8841717095310136, 'eval_recall': 0.8162534916201117, 'eval_runtime': 1.3153, 'eval_samples_per_second': 18209.938, 'eval_steps_per_second': 1138.121, 'epoch': 39.0}


                                                          
 40%|████      | 239543/598800 [13:24<1:44:15, 57.43it/s]

{'eval_loss': 0.31983593106269836, 'eval_accuracy': 0.8614311957247829, 'eval_f1': 0.8511325409284592, 'eval_auc': 0.8600488469159293, 'eval_precision': 0.8753575053049174, 'eval_recall': 0.8282122905027933, 'eval_runtime': 1.3655, 'eval_samples_per_second': 17540.696, 'eval_steps_per_second': 1096.294, 'epoch': 40.0}


 40%|████      | 240053/598800 [13:26<18:43, 319.26it/s] 

{'loss': 0.3269, 'learning_rate': 2.9984957379241185e-05, 'epoch': 40.08}


                                                        
 41%|████      | 245518/598800 [13:43<1:31:40, 64.23it/s]

{'eval_loss': 0.32272541522979736, 'eval_accuracy': 0.8606379425517702, 'eval_f1': 0.84683857942553, 'eval_auc': 0.8583441673402528, 'eval_precision': 0.8926291352292514, 'eval_recall': 0.8055167597765364, 'eval_runtime': 1.3798, 'eval_samples_per_second': 17358.422, 'eval_steps_per_second': 1084.901, 'epoch': 41.0}


                                                         
 42%|████▏     | 251518/598800 [14:02<1:33:03, 62.20it/s]

{'eval_loss': 0.31540951132774353, 'eval_accuracy': 0.863059452237809, 'eval_f1': 0.850909090909091, 'eval_auc': 0.8611443930214093, 'eval_precision': 0.8877086494688923, 'eval_recall': 0.8170391061452514, 'eval_runtime': 1.3645, 'eval_samples_per_second': 17554.277, 'eval_steps_per_second': 1097.142, 'epoch': 42.0}


 43%|████▎     | 255045/598800 [14:13<17:55, 319.70it/s] 

{'loss': 0.3228, 'learning_rate': 2.87314056493398e-05, 'epoch': 42.59}


                                                        
 43%|████▎     | 257493/598800 [14:22<1:28:58, 63.93it/s]

{'eval_loss': 0.3241121470928192, 'eval_accuracy': 0.8603456913827655, 'eval_f1': 0.8466932490031625, 'eval_auc': 0.8580967697372656, 'eval_precision': 0.8913442053459423, 'eval_recall': 0.806302374301676, 'eval_runtime': 1.3596, 'eval_samples_per_second': 17616.673, 'eval_steps_per_second': 1101.042, 'epoch': 43.0}


                                                         
 44%|████▍     | 263497/598800 [14:41<1:25:48, 65.12it/s]

{'eval_loss': 0.335957795381546, 'eval_accuracy': 0.8585921843687375, 'eval_f1': 0.8425603123692652, 'eval_auc': 0.8557841861350939, 'eval_precision': 0.9011633687978522, 'eval_recall': 0.7911138268156425, 'eval_runtime': 1.3289, 'eval_samples_per_second': 18023.615, 'eval_steps_per_second': 1126.476, 'epoch': 44.0}


                                                         
 45%|████▌     | 269469/598800 [15:00<1:23:05, 66.05it/s]

{'eval_loss': 0.3231247365474701, 'eval_accuracy': 0.8631012024048096, 'eval_f1': 0.8554933674143933, 'eval_auc': 0.8624412325910772, 'eval_precision': 0.8639074321317312, 'eval_recall': 0.8472416201117319, 'eval_runtime': 1.3283, 'eval_samples_per_second': 18032.123, 'eval_steps_per_second': 1127.008, 'epoch': 45.0}


 45%|████▌     | 270027/598800 [15:02<16:23, 334.32it/s] 

{'loss': 0.3181, 'learning_rate': 2.7477853919438413e-05, 'epoch': 45.09}


                                                        
 46%|████▌     | 275476/598800 [15:20<1:09:48, 77.19it/s]

{'eval_loss': 0.31528955698013306, 'eval_accuracy': 0.864562458249833, 'eval_f1': 0.8550621034760075, 'eval_auc': 0.8633440354544739, 'eval_precision': 0.8758008420281896, 'eval_recall': 0.8352828212290503, 'eval_runtime': 1.3554, 'eval_samples_per_second': 17671.035, 'eval_steps_per_second': 1104.44, 'epoch': 46.0}


                                                         
 47%|████▋     | 281454/598800 [15:40<1:20:30, 65.69it/s]

{'eval_loss': 0.3201102316379547, 'eval_accuracy': 0.8650217100868404, 'eval_f1': 0.8562663939892411, 'eval_auc': 0.8640057555848039, 'eval_precision': 0.872519706441968, 'eval_recall': 0.8406075418994413, 'eval_runtime': 1.3377, 'eval_samples_per_second': 17905.4, 'eval_steps_per_second': 1119.087, 'epoch': 47.0}


 48%|████▊     | 285035/598800 [15:51<16:43, 312.76it/s] 

{'loss': 0.315, 'learning_rate': 2.6224302189537027e-05, 'epoch': 47.6}


                                                        
 48%|████▊     | 287427/598800 [15:59<1:19:11, 65.53it/s]

{'eval_loss': 0.3239298164844513, 'eval_accuracy': 0.8618904475617902, 'eval_f1': 0.8463824649391659, 'eval_auc': 0.8591268200237484, 'eval_precision': 0.9042468743798373, 'eval_recall': 0.7954783519553073, 'eval_runtime': 1.359, 'eval_samples_per_second': 17624.341, 'eval_steps_per_second': 1101.521, 'epoch': 48.0}


                                                         
 49%|████▉     | 293423/598800 [16:18<1:17:07, 65.99it/s]

{'eval_loss': 0.32021644711494446, 'eval_accuracy': 0.864437207748831, 'eval_f1': 0.8565622653178425, 'eval_auc': 0.8636816854018985, 'eval_precision': 0.8670959663715231, 'eval_recall': 0.8462814245810056, 'eval_runtime': 1.3424, 'eval_samples_per_second': 17842.403, 'eval_steps_per_second': 1115.15, 'epoch': 49.0}


                                                         
 50%|█████     | 299404/598800 [16:38<1:10:39, 70.62it/s]

{'eval_loss': 0.3180707097053528, 'eval_accuracy': 0.865439211756847, 'eval_f1': 0.857723038891096, 'eval_auc': 0.8647146416462206, 'eval_precision': 0.8676431186925069, 'eval_recall': 0.8480272346368715, 'eval_runtime': 1.2427, 'eval_samples_per_second': 19274.182, 'eval_steps_per_second': 1204.636, 'epoch': 50.0}


 50%|█████     | 300035/598800 [16:40<14:18, 348.04it/s] 

{'loss': 0.3109, 'learning_rate': 2.4970750459635637e-05, 'epoch': 50.1}


                                                        
 51%|█████     | 305408/598800 [16:57<1:09:06, 70.76it/s]

{'eval_loss': 0.3156368136405945, 'eval_accuracy': 0.8655227120908484, 'eval_f1': 0.8560189531089357, 'eval_auc': 0.8642861246325081, 'eval_precision': 0.8772331653687586, 'eval_recall': 0.83580656424581, 'eval_runtime': 1.2402, 'eval_samples_per_second': 19312.26, 'eval_steps_per_second': 1207.016, 'epoch': 51.0}


                                                         
 52%|█████▏    | 311391/598800 [17:17<1:19:53, 59.96it/s]

{'eval_loss': 0.3191133737564087, 'eval_accuracy': 0.8619739478957916, 'eval_f1': 0.8467031438375219, 'eval_auc': 0.8592685972360319, 'eval_precision': 0.9030662710187932, 'eval_recall': 0.7969622905027933, 'eval_runtime': 1.3854, 'eval_samples_per_second': 17288.345, 'eval_steps_per_second': 1080.522, 'epoch': 52.0}


 53%|█████▎    | 315069/598800 [17:30<15:14, 310.31it/s] 

{'loss': 0.3079, 'learning_rate': 2.3717198729734248e-05, 'epoch': 52.61}


                                                        
 53%|█████▎    | 317383/598800 [17:39<1:20:04, 58.57it/s]

{'eval_loss': 0.32778626680374146, 'eval_accuracy': 0.8605961923847696, 'eval_f1': 0.8432762262379723, 'eval_auc': 0.857414204947818, 'eval_precision': 0.9120722916032085, 'eval_recall': 0.7841305865921788, 'eval_runtime': 1.3145, 'eval_samples_per_second': 18221.729, 'eval_steps_per_second': 1138.858, 'epoch': 53.0}


                                                         
 54%|█████▍    | 323360/598800 [18:00<1:14:51, 61.32it/s]

{'eval_loss': 0.3171077370643616, 'eval_accuracy': 0.8632264529058116, 'eval_f1': 0.8483614145528606, 'eval_auc': 0.8605924845671283, 'eval_precision': 0.9030350808040993, 'eval_recall': 0.7999301675977654, 'eval_runtime': 1.2818, 'eval_samples_per_second': 18685.968, 'eval_steps_per_second': 1167.873, 'epoch': 54.0}


                                                         
 55%|█████▌    | 329361/598800 [18:21<1:10:07, 64.04it/s]

{'eval_loss': 0.3217604160308838, 'eval_accuracy': 0.8618486973947895, 'eval_f1': 0.8453087747183395, 'eval_auc': 0.8588252710140989, 'eval_precision': 0.9100150981378963, 'eval_recall': 0.78919343575419, 'eval_runtime': 1.2714, 'eval_samples_per_second': 18839.27, 'eval_steps_per_second': 1177.454, 'epoch': 55.0}


 55%|█████▌    | 330056/598800 [18:24<14:16, 313.82it/s] 

{'loss': 0.3051, 'learning_rate': 2.246364699983286e-05, 'epoch': 55.11}


                                                        
 56%|█████▌    | 335328/598800 [18:43<14:42, 298.55it/s]


{'eval_loss': 0.31642377376556396, 'eval_accuracy': 0.8662324649298597, 'eval_f1': 0.8592267135325132, 'eval_auc': 0.8657037291039279, 'eval_precision': 0.8650035385704176, 'eval_recall': 0.8535265363128491, 'eval_runtime': 1.2752, 'eval_samples_per_second': 18783.159, 'eval_steps_per_second': 1173.947, 'epoch': 56.0}
{'train_runtime': 1123.1755, 'train_samples_per_second': 8529.834, 'train_steps_per_second': 533.131, 'train_loss': 0.35676768882002285, 'epoch': 56.0}


100%|██████████| 1497/1497 [00:01<00:00, 1141.13it/s]

{'eval_loss': 0.31528955698013306, 'eval_accuracy': 0.864562458249833, 'eval_f1': 0.8550621034760075, 'eval_auc': 0.8633440354544739, 'eval_precision': 0.8758008420281896, 'eval_recall': 0.8352828212290503, 'eval_runtime': 1.3139, 'eval_samples_per_second': 18230.29, 'eval_steps_per_second': 1139.393, 'epoch': 56.0}





In [8]:
trainer.evaluate(devdata)

100%|██████████| 313/313 [00:00<00:00, 1103.29it/s]


{'eval_loss': 1.1675323247909546,
 'eval_accuracy': 0.5634,
 'eval_f1': 0.3821115199547127,
 'eval_auc': 0.5634,
 'eval_precision': 0.6534365924491772,
 'eval_recall': 0.27,
 'eval_runtime': 0.2857,
 'eval_samples_per_second': 17501.089,
 'eval_steps_per_second': 1095.568,
 'epoch': 56.0}

In [9]:
trainer.save_model('SavedModels/'+'cnn_'+str(round(SAMPLES_TO_TRAIN/1000))+'k')

In [11]:
# All samples trained with training dataset
print(''''eval_loss': 0.31528955698013306, 'eval_accuracy': 0.864562458249833, 'eval_f1': 0.8550621034760075, 'eval_auc': 0.8633440354544739, 'eval_precision': 0.8758008420281896, 'eval_recall': 0.8352828212290503''')

'eval_loss': 0.31528955698013306, 'eval_accuracy': 0.864562458249833, 'eval_f1': 0.8550621034760075, 'eval_auc': 0.8633440354544739, 'eval_precision': 0.8758008420281896, 'eval_recall': 0.8352828212290503
