In [1]:
'''Import libraries'''
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import wandb
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
import os
from torch.utils.data import Dataset
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from imblearn.under_sampling import RandomUnderSampler
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertForSequenceClassification, AdamW, Trainer, TrainingArguments
from tqdm import tqdm
from torch.nn import functional as F
import torch.nn as nn

wandb.login()

  from .autonotebook import tqdm as notebook_tqdm
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33malberto-rodero557[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
'''Variables and parameters'''

SAMPLES_TO_TRAIN=10000
DIMENSIONS=200

N_LABELS=2
MAX_LEN = 256
EPOCHS=100
PATIENCE=10
LEARNING_RATE=.00005
WEIGHT_DECAY=.01
BATCH_SIZE=16
METRIC_FOR_BEST_MODEL='eval_loss'
if METRIC_FOR_BEST_MODEL=='eval_loss':
    GREATER_IS_BETTER = False
else:
    GREATER_IS_BETTER = True

In [3]:
'''Preparing dataset'''

df = pd.read_json(os.getcwd()+'/datasets/subtaskA_train_monolingual.jsonl', lines=True)
df = df[['text', 'label']]

df=df.sample(round(SAMPLES_TO_TRAIN))
# test_train_df=df.sample(round(SAMPLES_TO_TRAIN*.2))

# df = pd.read_json(os.getcwd()+'/datasets/subtaskA_dev_monolingual.jsonl', lines=True)
# df = df[['text', 'label']]

# val_df= df.sample(round(SAMPLES_TO_TRAIN*.2))
# test_dev_df= df.sample(round(SAMPLES_TO_TRAIN*.2))

# we balance the training set
print(f'Dataset size before balancing: {df.shape}')
counts = df['label'].value_counts()
sampler = RandomUnderSampler(random_state=42)
x_text, y = sampler.fit_resample(df[['text']], df['label'])

print(f'Dataset size after balancing: {x_text.shape}')
print(f'Entried dropped: {df.shape[0]-x_text.shape[0]}')

# Create a new balanced DataFrame
df = pd.DataFrame({'text': x_text['text'], 'label': y})

# Print the balanced DataFrame
print("\nBalanced DataFrame:")
print(df['label'].value_counts())

Dataset size before balancing: (10000, 2)
Dataset size after balancing: (9264, 1)
Entried dropped: 736

Balanced DataFrame:
label
0    4632
1    4632
Name: count, dtype: int64


In [4]:
'''loading glove'''
embeddings_index={}
with open('../0 playground and indoor/OtherData/glove.6B.200d.txt','r',encoding='utf-8') as f:
    for line in f:
        values=line.split()
        word=values[0]
        vectors=np.asarray(values[1:],'float32')
        embeddings_index[word]=vectors
f.close()
print('Found %s word vectors.' % len(embeddings_index))


Found 400000 word vectors.


In [5]:
'''glove building'''

from nltk.tokenize import word_tokenize
from tqdm import tqdm 

def sent2vec(s):
    """ Function Creates a normalized vector for the whole sentence"""
    words = str(s).lower()
    words = word_tokenize(words)
    words = [w for w in words if w.isalpha()]
    M = []
    for w in words:
        try:
            M.append(embeddings_index[w])
        except:
            continue
    M = np.array(M)
    v = M.sum(axis=0)
    if type(v) != np.ndarray:
        return np.zeros(200)
    return v / np.sqrt((v ** 2).sum())

print('Training df:')
df_x = np.array([sent2vec(x) for x in tqdm(df['text'])])
print(df_x.shape)
train_y=df['label']


Training df:


100%|██████████| 9264/9264 [00:12<00:00, 726.20it/s] 

(9264, 200)





In [6]:
'''Preparing for training'''

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# Initialize the StandardScaler
scaler = StandardScaler()
# Fit the scaler to the training data and transform the data
train_x = scaler.fit_transform(df_x)

import pickle

# Save the trained scaler
with open('scaler.pkl', 'wb') as scaler_file:
    pickle.dump(scaler, scaler_file)


In [7]:
'''metrics'''

def compute_metrics(p):
    preds = np.argmax(p.predictions, axis=1)
    labels = p.label_ids
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    auc = roc_auc_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'auc': auc,
        'precision': precision,
        'recall': recall,
    }

In [8]:
class Data(Dataset):
    def __init__(self, X_train, y_train):
        self.X = torch.from_numpy(X_train.astype(np.float32))
        self.y = torch.from_numpy(y_train).type(torch.LongTensor)
        self.len = self.X.shape[0]

    def __getitem__(self, index):
        return {'input_ids': self.X[index], 'labels': self.y[index]}

    def __len__(self):
        return self.len

X_train, X_test, y_train, y_test = train_test_split(train_x, train_y.values, test_size=0.2, random_state=42)
traindata = Data(X_train, y_train)
testdata = Data(X_test, y_test)

In [28]:
# number of features (len of X cols)
input_dim = train_x.shape[-1]

# number of classes (unique of y)
output_dim = 2

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.linear1 = nn.Linear(input_dim, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.dropout1 = nn.Dropout(0.5)
        
        self.linear2 = nn.Linear(512, 512)
        self.bn2 = nn.BatchNorm1d(512)
        self.dropout2 = nn.Dropout(0.5)
        
        self.linear3 = nn.Linear(512, 256)
        self.bn3 = nn.BatchNorm1d(256)
        self.dropout3 = nn.Dropout(0.5)
        
        self.linear4 = nn.Linear(256, output_dim)
        
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input_ids, labels=None):
        x1 = F.leaky_relu(self.linear1(input_ids))
        x1 = self.bn1(x1)
        x1 = self.dropout1(x1)
        
        x2 = F.leaky_relu(self.linear2(x1))
        x2 = self.bn2(x2)
        x2 = self.dropout2(x2)
        
        # Adding the first skip connection
        x2 += x1
        
        x3 = F.leaky_relu(self.linear3(x2))
        x3 = self.bn3(x3)
        x3 = self.dropout3(x3)
        
        x4 = self.linear4(x3)
        
        outputs = (x4,)
        if labels is not None:
            loss = self.loss(x4, labels)
            outputs = (loss,) + outputs
            
        return (outputs if len(outputs) > 1 else outputs[0])

# Create the model
model = Network()

In [29]:
from transformers import EarlyStoppingCallback

model = Network()

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    warmup_steps=500,
    weight_decay=WEIGHT_DECAY,
    metric_for_best_model=METRIC_FOR_BEST_MODEL,
    greater_is_better=GREATER_IS_BETTER,
    logging_dir='./logs',
    logging_steps=15000,
    do_train=True,
    do_eval=True,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    push_to_hub=False,
    logging_first_step=False,
    load_best_model_at_end=True,
    save_total_limit=2,
    report_to="wandb"
)

# Create trainer
trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=traindata,             # training dataset
    eval_dataset=testdata, 
    compute_metrics=compute_metrics,# training dataset
    callbacks=[EarlyStoppingCallback(early_stopping_patience=PATIENCE)]
)

# Train the model
print(trainer.evaluate())

trainer.train()

print(trainer.evaluate())

100%|██████████| 116/116 [00:00<00:00, 998.91it/s] 


{'eval_loss': 0.696049690246582, 'eval_accuracy': 0.49487317862924984, 'eval_f1': 0.11698113207547171, 'eval_auc': 0.49189500908709627, 'eval_precision': 0.44285714285714284, 'eval_recall': 0.06739130434782609, 'eval_runtime': 0.1171, 'eval_samples_per_second': 15820.313, 'eval_steps_per_second': 990.37}


  1%|          | 447/46400 [00:01<01:53, 404.60it/s]
  1%|          | 528/46400 [00:01<02:24, 317.11it/s]

{'eval_loss': 0.5957270860671997, 'eval_accuracy': 0.6729627630868862, 'eval_f1': 0.6464410735122521, 'eval_auc': 0.6724695931776876, 'eval_precision': 0.6977329974811083, 'eval_recall': 0.6021739130434782, 'eval_runtime': 0.1106, 'eval_samples_per_second': 16748.726, 'eval_steps_per_second': 1048.49, 'epoch': 1.0}


  2%|▏         | 894/46400 [00:02<01:55, 395.33it/s]
  2%|▏         | 976/46400 [00:02<02:24, 314.44it/s]

{'eval_loss': 0.536727249622345, 'eval_accuracy': 0.7339449541284404, 'eval_f1': 0.7168294083859851, 'eval_auc': 0.7335570157043665, 'eval_precision': 0.7600487210718636, 'eval_recall': 0.6782608695652174, 'eval_runtime': 0.1226, 'eval_samples_per_second': 15108.688, 'eval_steps_per_second': 945.822, 'epoch': 2.0}


  3%|▎         | 1354/46400 [00:03<01:48, 413.34it/s]
  3%|▎         | 1435/46400 [00:03<02:19, 321.72it/s]

{'eval_loss': 0.5083063244819641, 'eval_accuracy': 0.7555315704263357, 'eval_f1': 0.746218487394958, 'eval_auc': 0.7553112912996879, 'eval_precision': 0.7699421965317919, 'eval_recall': 0.7239130434782609, 'eval_runtime': 0.1106, 'eval_samples_per_second': 16747.752, 'eval_steps_per_second': 1048.429, 'epoch': 3.0}


  4%|▍         | 1847/46400 [00:04<01:46, 416.57it/s]
  4%|▍         | 1929/46400 [00:05<02:16, 326.72it/s]

{'eval_loss': 0.49483826756477356, 'eval_accuracy': 0.7560712358337831, 'eval_f1': 0.7280385078219013, 'eval_auc': 0.7553852695838575, 'eval_precision': 0.8153638814016172, 'eval_recall': 0.657608695652174, 'eval_runtime': 0.1083, 'eval_samples_per_second': 17106.533, 'eval_steps_per_second': 1070.889, 'epoch': 4.0}


  5%|▍         | 2283/46400 [00:06<01:53, 390.01it/s]
  5%|▌         | 2358/46400 [00:06<02:28, 295.61it/s]

{'eval_loss': 0.479798287153244, 'eval_accuracy': 0.7706422018348624, 'eval_f1': 0.7539085118702953, 'eval_auc': 0.7702030616524536, 'eval_precision': 0.8066914498141264, 'eval_recall': 0.7076086956521739, 'eval_runtime': 0.1188, 'eval_samples_per_second': 15599.531, 'eval_steps_per_second': 976.549, 'epoch': 5.0}


  6%|▌         | 2752/46400 [00:07<01:54, 382.69it/s]
  6%|▌         | 2830/46400 [00:07<02:25, 298.50it/s]

{'eval_loss': 0.4714539647102356, 'eval_accuracy': 0.7733405288720993, 'eval_f1': 0.7627118644067797, 'eval_auc': 0.773064331981919, 'eval_precision': 0.7941176470588235, 'eval_recall': 0.7336956521739131, 'eval_runtime': 0.1182, 'eval_samples_per_second': 15676.431, 'eval_steps_per_second': 981.363, 'epoch': 6.0}


  7%|▋         | 3224/46400 [00:08<01:49, 394.91it/s]
  7%|▋         | 3303/46400 [00:09<02:21, 305.65it/s]

{'eval_loss': 0.46420711278915405, 'eval_accuracy': 0.7749595250944414, 'eval_f1': 0.7559976594499709, 'eval_auc': 0.7744524441959085, 'eval_precision': 0.8187579214195184, 'eval_recall': 0.7021739130434783, 'eval_runtime': 0.1139, 'eval_samples_per_second': 16269.346, 'eval_steps_per_second': 1018.48, 'epoch': 7.0}


  8%|▊         | 3703/46400 [00:10<01:47, 397.33it/s]
  8%|▊         | 3743/46400 [00:10<02:34, 276.30it/s]

{'eval_loss': 0.4827921390533447, 'eval_accuracy': 0.7598488936859147, 'eval_f1': 0.722048719550281, 'eval_auc': 0.7589321496807867, 'eval_precision': 0.8487518355359766, 'eval_recall': 0.6282608695652174, 'eval_runtime': 0.121, 'eval_samples_per_second': 15315.763, 'eval_steps_per_second': 958.785, 'epoch': 8.0}


  9%|▉         | 4136/46400 [00:11<01:51, 379.84it/s]
  9%|▉         | 4216/46400 [00:11<02:22, 296.43it/s]

{'eval_loss': 0.45511260628700256, 'eval_accuracy': 0.7803561791689153, 'eval_f1': 0.7581699346405228, 'eval_auc': 0.7797509203597558, 'eval_precision': 0.836173001310616, 'eval_recall': 0.6934782608695652, 'eval_runtime': 0.126, 'eval_samples_per_second': 14708.362, 'eval_steps_per_second': 920.761, 'epoch': 9.0}


 10%|▉         | 4619/46400 [00:12<01:45, 395.11it/s]
 10%|█         | 4698/46400 [00:13<02:15, 307.93it/s]

{'eval_loss': 0.4568037688732147, 'eval_accuracy': 0.7787371829465731, 'eval_f1': 0.7524154589371981, 'eval_auc': 0.7780296146139148, 'eval_precision': 0.8464673913043478, 'eval_recall': 0.6771739130434783, 'eval_runtime': 0.1083, 'eval_samples_per_second': 17109.282, 'eval_steps_per_second': 1071.061, 'epoch': 10.0}


 11%|█         | 5096/46400 [00:14<01:44, 395.29it/s]
 11%|█         | 5136/46400 [00:14<02:31, 272.81it/s]

{'eval_loss': 0.44686198234558105, 'eval_accuracy': 0.7857528332433891, 'eval_f1': 0.7627017334130305, 'eval_auc': 0.785109977165758, 'eval_precision': 0.8472775564409031, 'eval_recall': 0.6934782608695652, 'eval_runtime': 0.1212, 'eval_samples_per_second': 15290.122, 'eval_steps_per_second': 957.18, 'epoch': 11.0}


 12%|█▏        | 5530/46400 [00:15<01:48, 377.49it/s]
 12%|█▏        | 5607/46400 [00:15<02:17, 296.01it/s]

{'eval_loss': 0.4366384446620941, 'eval_accuracy': 0.7889908256880734, 'eval_f1': 0.7748992515831894, 'eval_auc': 0.7885904515587866, 'eval_precision': 0.8237454100367197, 'eval_recall': 0.7315217391304348, 'eval_runtime': 0.1182, 'eval_samples_per_second': 15680.163, 'eval_steps_per_second': 981.597, 'epoch': 12.0}


 13%|█▎        | 6029/46400 [00:17<01:43, 391.83it/s]
 13%|█▎        | 6069/46400 [00:17<02:24, 278.89it/s]

{'eval_loss': 0.4387155771255493, 'eval_accuracy': 0.7889908256880734, 'eval_f1': 0.7756741250717155, 'eval_auc': 0.7886131692995946, 'eval_precision': 0.8213851761846902, 'eval_recall': 0.7347826086956522, 'eval_runtime': 0.1127, 'eval_samples_per_second': 16438.232, 'eval_steps_per_second': 1029.053, 'epoch': 13.0}


 14%|█▍        | 6468/46400 [00:18<01:41, 393.22it/s]
 14%|█▍        | 6550/46400 [00:18<02:07, 311.93it/s]

{'eval_loss': 0.4485681354999542, 'eval_accuracy': 0.7873718294657313, 'eval_f1': 0.7632211538461539, 'eval_auc': 0.7866949764667505, 'eval_precision': 0.853494623655914, 'eval_recall': 0.6902173913043478, 'eval_runtime': 0.1127, 'eval_samples_per_second': 16449.052, 'eval_steps_per_second': 1029.73, 'epoch': 14.0}


 15%|█▍        | 6958/46400 [00:19<01:39, 396.74it/s]
 15%|█▌        | 6998/46400 [00:19<02:20, 280.77it/s]

{'eval_loss': 0.43232131004333496, 'eval_accuracy': 0.7911494873178629, 'eval_f1': 0.7738164815897135, 'eval_auc': 0.7906507758982245, 'eval_precision': 0.8369152970922883, 'eval_recall': 0.7195652173913043, 'eval_runtime': 0.1178, 'eval_samples_per_second': 15723.559, 'eval_steps_per_second': 984.313, 'epoch': 15.0}


 16%|█▌        | 7396/46400 [00:20<01:38, 395.46it/s]
 16%|█▌        | 7478/46400 [00:21<02:04, 313.63it/s]

{'eval_loss': 0.43559813499450684, 'eval_accuracy': 0.7900701565029682, 'eval_f1': 0.7694131594546532, 'eval_auc': 0.7894805209935224, 'eval_precision': 0.8461538461538461, 'eval_recall': 0.7054347826086956, 'eval_runtime': 0.1081, 'eval_samples_per_second': 17148.505, 'eval_steps_per_second': 1073.517, 'epoch': 16.0}


 17%|█▋        | 7880/46400 [00:22<01:36, 399.95it/s]
 17%|█▋        | 7921/46400 [00:22<02:17, 280.29it/s]

{'eval_loss': 0.43891045451164246, 'eval_accuracy': 0.7916891527253103, 'eval_f1': 0.7680288461538461, 'eval_auc': 0.7910125122326297, 'eval_precision': 0.8588709677419355, 'eval_recall': 0.6945652173913044, 'eval_runtime': 0.1236, 'eval_samples_per_second': 14994.049, 'eval_steps_per_second': 938.645, 'epoch': 17.0}


 18%|█▊        | 8325/46400 [00:23<01:35, 400.69it/s]
 18%|█▊        | 8406/46400 [00:23<02:04, 305.74it/s]

{'eval_loss': 0.42274120450019836, 'eval_accuracy': 0.8057204533189423, 'eval_f1': 0.7982062780269059, 'eval_auc': 0.8054988582878979, 'eval_precision': 0.8240740740740741, 'eval_recall': 0.7739130434782608, 'eval_runtime': 0.1241, 'eval_samples_per_second': 14934.81, 'eval_steps_per_second': 934.937, 'epoch': 18.0}


 19%|█▉        | 8798/46400 [00:24<01:37, 384.87it/s]
 19%|█▉        | 8876/46400 [00:25<02:06, 297.13it/s]

{'eval_loss': 0.4201640188694, 'eval_accuracy': 0.8062601187263896, 'eval_f1': 0.7975183305132544, 'eval_auc': 0.8059969010671513, 'eval_precision': 0.8288393903868698, 'eval_recall': 0.7684782608695652, 'eval_runtime': 0.114, 'eval_samples_per_second': 16254.578, 'eval_steps_per_second': 1017.556, 'epoch': 19.0}


 20%|█▉        | 9278/46400 [00:26<01:35, 389.75it/s]
 20%|██        | 9318/46400 [00:26<02:15, 274.28it/s]

{'eval_loss': 0.4190221130847931, 'eval_accuracy': 0.8078791149487318, 'eval_f1': 0.7968036529680366, 'eval_auc': 0.8075364648865277, 'eval_precision': 0.8389423076923077, 'eval_recall': 0.758695652173913, 'eval_runtime': 0.1259, 'eval_samples_per_second': 14719.783, 'eval_steps_per_second': 921.476, 'epoch': 20.0}


 21%|██        | 9711/46400 [00:27<01:34, 386.57it/s]
 21%|██        | 9784/46400 [00:27<02:06, 288.58it/s]

{'eval_loss': 0.42643725872039795, 'eval_accuracy': 0.8084187803561792, 'eval_f1': 0.7932440302853815, 'eval_auc': 0.807943636702549, 'eval_precision': 0.8544542032622334, 'eval_recall': 0.7402173913043478, 'eval_runtime': 0.1196, 'eval_samples_per_second': 15497.938, 'eval_steps_per_second': 970.189, 'epoch': 21.0}


 22%|██▏       | 10177/46400 [00:28<01:34, 385.11it/s]
 22%|██▏       | 10257/46400 [00:29<01:58, 304.15it/s]

{'eval_loss': 0.4238094091415405, 'eval_accuracy': 0.806799784133837, 'eval_f1': 0.7925840092699884, 'eval_auc': 0.8063586374015566, 'eval_precision': 0.8486352357320099, 'eval_recall': 0.7434782608695653, 'eval_runtime': 0.1192, 'eval_samples_per_second': 15539.677, 'eval_steps_per_second': 972.802, 'epoch': 22.0}


 23%|██▎       | 10661/46400 [00:30<01:30, 395.21it/s]
 23%|██▎       | 10742/46400 [00:30<01:53, 315.04it/s]

{'eval_loss': 0.42730122804641724, 'eval_accuracy': 0.802482460874258, 'eval_f1': 0.7824019024970272, 'eval_auc': 0.8018745048697516, 'eval_precision': 0.863517060367454, 'eval_recall': 0.7152173913043478, 'eval_runtime': 0.1136, 'eval_samples_per_second': 16315.658, 'eval_steps_per_second': 1021.38, 'epoch': 23.0}


 24%|██▍       | 11102/46400 [00:31<01:30, 391.91it/s]
 24%|██▍       | 11183/46400 [00:31<01:54, 308.62it/s]

{'eval_loss': 0.4162127673625946, 'eval_accuracy': 0.8030221262817053, 'eval_f1': 0.7957470621152771, 'eval_auc': 0.8028117573046275, 'eval_precision': 0.8200692041522492, 'eval_recall': 0.7728260869565218, 'eval_runtime': 0.1146, 'eval_samples_per_second': 16174.205, 'eval_steps_per_second': 1012.524, 'epoch': 24.0}


 25%|██▍       | 11591/46400 [00:32<01:25, 407.08it/s]
 25%|██▌       | 11632/46400 [00:33<02:00, 289.72it/s]

{'eval_loss': 0.4135163426399231, 'eval_accuracy': 0.8084187803561792, 'eval_f1': 0.7960941987363584, 'eval_auc': 0.8080345076657813, 'eval_precision': 0.8440925700365408, 'eval_recall': 0.7532608695652174, 'eval_runtime': 0.1156, 'eval_samples_per_second': 16032.207, 'eval_steps_per_second': 1003.635, 'epoch': 25.0}


 26%|██▌       | 12043/46400 [00:34<01:26, 396.00it/s]
 26%|██▌       | 12125/46400 [00:34<01:51, 308.67it/s]

{'eval_loss': 0.4337846636772156, 'eval_accuracy': 0.8019427954668106, 'eval_f1': 0.7777104784978801, 'eval_auc': 0.8012174379048417, 'eval_precision': 0.8782489740082079, 'eval_recall': 0.6978260869565217, 'eval_runtime': 0.1202, 'eval_samples_per_second': 15420.847, 'eval_steps_per_second': 965.363, 'epoch': 26.0}


 27%|██▋       | 12525/46400 [00:35<01:25, 397.43it/s]
 27%|██▋       | 12565/46400 [00:35<01:59, 283.41it/s]

{'eval_loss': 0.4194309711456299, 'eval_accuracy': 0.809498111171074, 'eval_f1': 0.7924750146972369, 'eval_auc': 0.808962440001864, 'eval_precision': 0.8629961587708067, 'eval_recall': 0.7326086956521739, 'eval_runtime': 0.1122, 'eval_samples_per_second': 16508.027, 'eval_steps_per_second': 1033.422, 'epoch': 27.0}


 28%|██▊       | 12962/46400 [00:36<01:25, 389.46it/s]
 28%|██▊       | 13042/46400 [00:37<01:47, 310.75it/s]

{'eval_loss': 0.408995658159256, 'eval_accuracy': 0.8105774419859687, 'eval_f1': 0.8004548038658329, 'eval_auc': 0.8102614287711449, 'eval_precision': 0.8390941597139452, 'eval_recall': 0.7652173913043478, 'eval_runtime': 0.1127, 'eval_samples_per_second': 16440.562, 'eval_steps_per_second': 1029.199, 'epoch': 28.0}


 29%|██▉       | 13435/46400 [00:38<01:23, 394.77it/s]
 29%|██▉       | 13515/46400 [00:38<01:47, 306.54it/s]

{'eval_loss': 0.4219759702682495, 'eval_accuracy': 0.806799784133837, 'eval_f1': 0.7876631079478055, 'eval_auc': 0.8062071857961695, 'eval_precision': 0.8668407310704961, 'eval_recall': 0.7217391304347827, 'eval_runtime': 0.1156, 'eval_samples_per_second': 16023.944, 'eval_steps_per_second': 1003.118, 'epoch': 29.0}


 30%|██▉       | 13918/46400 [00:39<01:21, 398.48it/s]
 30%|███       | 13959/46400 [00:39<01:55, 280.74it/s]

{'eval_loss': 0.40534698963165283, 'eval_accuracy': 0.8111171073934161, 'eval_f1': 0.8013620885357547, 'eval_auc': 0.8108124796122839, 'eval_precision': 0.838479809976247, 'eval_recall': 0.7673913043478261, 'eval_runtime': 0.1168, 'eval_samples_per_second': 15865.461, 'eval_steps_per_second': 993.197, 'epoch': 30.0}


 31%|███       | 14355/46400 [00:40<01:21, 394.52it/s]
 31%|███       | 14435/46400 [00:41<01:42, 310.77it/s]

{'eval_loss': 0.40440523624420166, 'eval_accuracy': 0.8078791149487318, 'eval_f1': 0.8050383351588172, 'eval_auc': 0.8078166503564937, 'eval_precision': 0.8112582781456954, 'eval_recall': 0.7989130434782609, 'eval_runtime': 0.1107, 'eval_samples_per_second': 16732.5, 'eval_steps_per_second': 1047.474, 'epoch': 31.0}


 32%|███▏      | 14825/46400 [00:42<01:21, 389.14it/s]
 32%|███▏      | 14907/46400 [00:42<01:42, 306.21it/s]

{'eval_loss': 0.40695229172706604, 'eval_accuracy': 0.8143550998381004, 'eval_f1': 0.8036529680365295, 'eval_auc': 0.8140127685353464, 'eval_precision': 0.8461538461538461, 'eval_recall': 0.7652173913043478, 'eval_runtime': 0.1134, 'eval_samples_per_second': 16336.99, 'eval_steps_per_second': 1022.715, 'epoch': 32.0}


 32%|███▏      | 15058/46400 [00:42<01:28, 355.50it/s]

{'loss': 0.4883, 'learning_rate': 3.420479302832244e-05, 'epoch': 32.33}


 33%|███▎      | 15295/46400 [00:43<01:20, 388.59it/s]
 33%|███▎      | 15335/46400 [00:43<01:54, 270.74it/s]

{'eval_loss': 0.41219261288642883, 'eval_accuracy': 0.8148947652455477, 'eval_f1': 0.799062683069713, 'eval_auc': 0.814382077450021, 'eval_precision': 0.866581956797967, 'eval_recall': 0.741304347826087, 'eval_runtime': 0.1198, 'eval_samples_per_second': 15464.88, 'eval_steps_per_second': 968.12, 'epoch': 33.0}


 34%|███▍      | 15772/46400 [00:44<01:18, 391.24it/s]
 34%|███▍      | 15812/46400 [00:45<01:49, 279.94it/s]

{'eval_loss': 0.40308094024658203, 'eval_accuracy': 0.8143550998381004, 'eval_f1': 0.8038768529076395, 'eval_auc': 0.8140203411156158, 'eval_precision': 0.8453237410071942, 'eval_recall': 0.7663043478260869, 'eval_runtime': 0.1158, 'eval_samples_per_second': 16007.541, 'eval_steps_per_second': 1002.091, 'epoch': 34.0}


 35%|███▍      | 16214/46400 [00:46<01:16, 393.43it/s]
 35%|███▌      | 16294/46400 [00:46<01:37, 310.09it/s]

{'eval_loss': 0.40509119629859924, 'eval_accuracy': 0.8148947652455477, 'eval_f1': 0.8071950534007869, 'eval_auc': 0.8146546903397176, 'eval_precision': 0.8358556461001164, 'eval_recall': 0.7804347826086957, 'eval_runtime': 0.1108, 'eval_samples_per_second': 16728.538, 'eval_steps_per_second': 1047.226, 'epoch': 35.0}


 36%|███▌      | 16692/46400 [00:47<01:16, 387.60it/s]
 36%|███▌      | 16731/46400 [00:47<01:47, 276.01it/s]

{'eval_loss': 0.400961309671402, 'eval_accuracy': 0.8154344306529951, 'eval_f1': 0.8069977426636569, 'eval_auc': 0.8151678782795098, 'eval_precision': 0.8392018779342723, 'eval_recall': 0.7771739130434783, 'eval_runtime': 0.1123, 'eval_samples_per_second': 16494.749, 'eval_steps_per_second': 1032.591, 'epoch': 36.0}


 37%|███▋      | 17130/46400 [00:48<01:14, 394.24it/s]
 37%|███▋      | 17210/46400 [00:49<01:33, 311.01it/s]

{'eval_loss': 0.40877342224121094, 'eval_accuracy': 0.8132757690232056, 'eval_f1': 0.7995365005793742, 'eval_auc': 0.8128349410503751, 'eval_precision': 0.8560794044665012, 'eval_recall': 0.75, 'eval_runtime': 0.114, 'eval_samples_per_second': 16249.107, 'eval_steps_per_second': 1017.213, 'epoch': 37.0}


 38%|███▊      | 17610/46400 [00:50<01:11, 400.46it/s]
 38%|███▊      | 17691/46400 [00:50<01:32, 309.73it/s]

{'eval_loss': 0.40702366828918457, 'eval_accuracy': 0.8121964382083109, 'eval_f1': 0.7995391705069124, 'eval_auc': 0.8117934200102521, 'eval_precision': 0.8504901960784313, 'eval_recall': 0.7543478260869565, 'eval_runtime': 0.1145, 'eval_samples_per_second': 16184.882, 'eval_steps_per_second': 1013.193, 'epoch': 38.0}


 39%|███▉      | 18088/46400 [00:51<01:12, 388.67it/s]
 39%|███▉      | 18128/46400 [00:51<01:40, 281.96it/s]

{'eval_loss': 0.40971097350120544, 'eval_accuracy': 0.8105774419859687, 'eval_f1': 0.7950963222416813, 'eval_auc': 0.81008725942495, 'eval_precision': 0.8587641866330391, 'eval_recall': 0.7402173913043478, 'eval_runtime': 0.1102, 'eval_samples_per_second': 16808.384, 'eval_steps_per_second': 1052.225, 'epoch': 39.0}


 40%|███▉      | 18530/46400 [00:52<01:10, 397.77it/s]
 40%|████      | 18611/46400 [00:53<01:29, 311.38it/s]

{'eval_loss': 0.4043671786785126, 'eval_accuracy': 0.8105774419859687, 'eval_f1': 0.8004548038658329, 'eval_auc': 0.8102614287711449, 'eval_precision': 0.8390941597139452, 'eval_recall': 0.7652173913043478, 'eval_runtime': 0.1166, 'eval_samples_per_second': 15889.302, 'eval_steps_per_second': 994.689, 'epoch': 40.0}


 41%|████      | 19018/46400 [00:54<01:07, 405.68it/s]
 41%|████      | 19059/46400 [00:54<01:36, 281.96it/s]

{'eval_loss': 0.4063046872615814, 'eval_accuracy': 0.8175930922827847, 'eval_f1': 0.8132596685082873, 'eval_auc': 0.817470525187567, 'eval_precision': 0.8269662921348314, 'eval_recall': 0.8, 'eval_runtime': 0.1168, 'eval_samples_per_second': 15865.818, 'eval_steps_per_second': 993.219, 'epoch': 41.0}


 42%|████▏     | 19468/46400 [00:55<01:06, 404.49it/s]
 42%|████▏     | 19548/46400 [00:55<01:26, 308.73it/s]

{'eval_loss': 0.410773366689682, 'eval_accuracy': 0.8057204533189423, 'eval_f1': 0.7919075144508669, 'eval_auc': 0.8052943986206252, 'eval_precision': 0.845679012345679, 'eval_recall': 0.7445652173913043, 'eval_runtime': 0.113, 'eval_samples_per_second': 16391.879, 'eval_steps_per_second': 1026.151, 'epoch': 42.0}


 43%|████▎     | 19952/46400 [00:56<01:04, 407.20it/s]
 43%|████▎     | 19993/46400 [00:56<01:32, 285.12it/s]

{'eval_loss': 0.4156760275363922, 'eval_accuracy': 0.8127361036157582, 'eval_f1': 0.797667638483965, 'eval_auc': 0.8122535998881588, 'eval_precision': 0.8603773584905661, 'eval_recall': 0.7434782608695653, 'eval_runtime': 0.1156, 'eval_samples_per_second': 16023.547, 'eval_steps_per_second': 1003.093, 'epoch': 43.0}


 44%|████▍     | 20391/46400 [00:57<01:06, 391.05it/s]
 44%|████▍     | 20470/46400 [00:58<01:25, 302.79it/s]

{'eval_loss': 0.4061633050441742, 'eval_accuracy': 0.8154344306529951, 'eval_f1': 0.8080808080808081, 'eval_auc': 0.8152057411808565, 'eval_precision': 0.8352668213457076, 'eval_recall': 0.782608695652174, 'eval_runtime': 0.1163, 'eval_samples_per_second': 15937.818, 'eval_steps_per_second': 997.726, 'epoch': 44.0}


 45%|████▍     | 20863/46400 [00:59<01:06, 386.00it/s]
 45%|████▌     | 20942/46400 [00:59<01:25, 298.39it/s]

{'eval_loss': 0.4016591012477875, 'eval_accuracy': 0.8192120885051268, 'eval_f1': 0.8135781858653311, 'eval_auc': 0.819040379328021, 'eval_precision': 0.8335233751425314, 'eval_recall': 0.7945652173913044, 'eval_runtime': 0.1133, 'eval_samples_per_second': 16355.039, 'eval_steps_per_second': 1023.845, 'epoch': 45.0}


 46%|████▌     | 21330/46400 [01:00<01:04, 387.13it/s]
 46%|████▌     | 21344/46400 [01:00<01:11, 350.88it/s]


{'eval_loss': 0.4066920876502991, 'eval_accuracy': 0.8202914193200216, 'eval_f1': 0.8214477211796246, 'eval_auc': 0.8203772309986486, 'eval_precision': 0.8105820105820106, 'eval_recall': 0.8326086956521739, 'eval_runtime': 0.1147, 'eval_samples_per_second': 16158.334, 'eval_steps_per_second': 1011.531, 'epoch': 46.0}
{'train_runtime': 60.83, 'train_samples_per_second': 12183.142, 'train_steps_per_second': 762.782, 'train_loss': 0.45767735791528064, 'epoch': 46.0}


100%|██████████| 116/116 [00:00<00:00, 943.12it/s]

{'eval_loss': 0.400961309671402, 'eval_accuracy': 0.8154344306529951, 'eval_f1': 0.8069977426636569, 'eval_auc': 0.8151678782795098, 'eval_precision': 0.8392018779342723, 'eval_recall': 0.7771739130434783, 'eval_runtime': 0.1245, 'eval_samples_per_second': 14884.328, 'eval_steps_per_second': 931.777, 'epoch': 46.0}





In [33]:
# 200 100 100 2
# bn and dp 0.3
# 'eval_loss': 0.41982632875442505, 'eval_accuracy': 0.8111171073934161, 'eval_f1': 0.8046875 epoch 79

# 100 to 256
# eval_loss': 0.40763354301452637, 'eval_accuracy': 0.813815434430653, 'eval_f1': 0.807585052983826

# leaky relu
# 'eval_loss': 0.41205745935440063, 'eval_accuracy': 0.8235294117647058, 'eval_f1': 0.8229561451001625

# 512 and dp 0.5
# 

100%|██████████| 250/250 [00:00<00:00, 1069.73it/s]


{'eval_loss': 0.6834174990653992,
 'eval_accuracy': 0.522,
 'eval_f1': 0.0020876826722338207,
 'eval_auc': 0.5005224660397074,
 'eval_precision': 1.0,
 'eval_recall': 0.0010449320794148381,
 'eval_runtime': 0.2357,
 'eval_samples_per_second': 8485.118,
 'eval_steps_per_second': 1060.64,
 'epoch': 9.0}