In [39]:
# data manipulation
import numpy as np
import pandas as pd

# graphic representation
import matplotlib.pyplot as plt
import seaborn as sns

# file system management
import os

In [40]:
print(os.listdir("./Data"))

['cleaned_dataset.csv', 'test_df.csv', 'df_5000.csv', 'sample_dataset.csv', 'dataset.csv', 'train_df.csv', 'normalized_dataset.csv', 'tweets.csv']


In [41]:
data = pd.read_csv('./Data/tweets.csv', encoding='latin-1', header=None)
data = data.rename(columns={data.columns[0]: 'target'})
data = data.rename(columns={data.columns[1]: 'id'})
data = data.rename(columns={data.columns[2]: 'date'})
data = data.rename(columns={data.columns[3]: 'flag'})
data = data.rename(columns={data.columns[4]: 'user'})
data = data.rename(columns={data.columns[5]: 'text'})
data

Unnamed: 0,target,id,date,flag,user,text
0,0,1467810369,Mon Apr 06 22:19:45 PDT 2009,NO_QUERY,_TheSpecialOne_,"@switchfoot http://twitpic.com/2y1zl - Awww, t..."
1,0,1467810672,Mon Apr 06 22:19:49 PDT 2009,NO_QUERY,scotthamilton,is upset that he can't update his Facebook by ...
2,0,1467810917,Mon Apr 06 22:19:53 PDT 2009,NO_QUERY,mattycus,@Kenichan I dived many times for the ball. Man...
3,0,1467811184,Mon Apr 06 22:19:57 PDT 2009,NO_QUERY,ElleCTF,my whole body feels itchy and like its on fire
4,0,1467811193,Mon Apr 06 22:19:57 PDT 2009,NO_QUERY,Karoli,"@nationwideclass no, it's not behaving at all...."
...,...,...,...,...,...,...
1599995,4,2193601966,Tue Jun 16 08:40:49 PDT 2009,NO_QUERY,AmandaMarie1028,Just woke up. Having no school is the best fee...
1599996,4,2193601969,Tue Jun 16 08:40:49 PDT 2009,NO_QUERY,TheWDBoards,TheWDB.com - Very cool to hear old Walt interv...
1599997,4,2193601991,Tue Jun 16 08:40:49 PDT 2009,NO_QUERY,bpbabe,Are you ready for your MoJo Makeover? Ask me f...
1599998,4,2193602064,Tue Jun 16 08:40:49 PDT 2009,NO_QUERY,tinydiamondz,Happy 38th Birthday to my boo of alll time!!! ...


In [42]:
with pd.option_context('display.max_colwidth', None):
    print(data.iloc[0, :])

target                                                                                                                      0
id                                                                                                                 1467810369
date                                                                                             Mon Apr 06 22:19:45 PDT 2009
flag                                                                                                                 NO_QUERY
user                                                                                                          _TheSpecialOne_
text      @switchfoot http://twitpic.com/2y1zl - Awww, that's a bummer.  You shoulda got David Carr of Third Day to do it. ;D
Name: 0, dtype: object


In [43]:
import re
data['text'] = data['text'].apply(lambda x: re.sub(r'\S*@\S*\s?', '', x))
data['text'] = data['text'].apply(lambda x: re.sub(r'http\S+', '', x))
data['text'] = data['text'].apply(lambda x: re.sub(r'[^\w\s]', '', x))
data['text'] = data['text'].apply(lambda x: re.sub(r'\b\w\b', '', x))
data['text'] = data['text'].apply(lambda x: re.sub(r'\d', '', x))
data['text'] = data['text'].apply(lambda x: re.sub(r'\s+', ' ', x))
data['text'] = data['text'].apply(lambda x: x.lower())

In [44]:
with pd.option_context('display.max_colwidth', None):
    print(data.iloc[0, :])

target                                                                       0
id                                                                  1467810369
date                                              Mon Apr 06 22:19:45 PDT 2009
flag                                                                  NO_QUERY
user                                                           _TheSpecialOne_
text       awww thats bummer you shoulda got david carr of third day to do it 
Name: 0, dtype: object


In [45]:
df_neg = data[data['target']== 0].sample(5000)
df_pos = data[data['target']== 4].sample(5000)
df_pos['target'] = 1
liste_concat = [df_neg, df_pos]
df_sample = pd.concat([df_neg, df_pos], ignore_index=True)
df_sample = df_sample.sample(frac=1).reset_index(drop=True)
df_sample

Unnamed: 0,target,id,date,flag,user,text
0,0,2186835492,Mon Jun 15 19:24:24 PDT 2009,NO_QUERY,AmayaLove,still sitting under the dryer my neck hurts
1,1,2189647623,Tue Jun 16 00:32:41 PDT 2009,NO_QUERY,ronga,this is my nightmare even tho only have posts
2,0,1965889776,Fri May 29 16:43:56 PDT 2009,NO_QUERY,vvvracer,black is good tight or should say too tight no...
3,0,2056256086,Sat Jun 06 10:48:45 PDT 2009,NO_QUERY,Kayerodriguez,takes forever for everybody to get ready
4,0,1564138266,Mon Apr 20 01:58:22 PDT 2009,NO_QUERY,besz,omg all the la bad weather aura is trickling i...
...,...,...,...,...,...,...
9995,1,2001820518,Tue Jun 02 02:11:10 PDT 2009,NO_QUERY,meganthevegan,pounds are soo over rated think inches and bod...
9996,0,2065706900,Sun Jun 07 09:08:02 PDT 2009,NO_QUERY,meghanlynx3,im going to miss dancing this summer
9997,0,2203443747,Tue Jun 16 23:45:12 PDT 2009,NO_QUERY,alicerydon,need hugand less cynicism its making me depre...
9998,0,2067517993,Sun Jun 07 12:29:15 PDT 2009,NO_QUERY,laurelfairy,ugh have an upset stomachugh no feel good


In [46]:
sample_df = './Data/df_5000.csv'
df_sample.to_csv(sample_df, index=False)

In [47]:
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from torch.utils.data import TensorDataset

In [48]:
data = pd.read_csv("./data/df_5000.csv")
data = data.dropna(subset=['text'])

In [49]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
max_len = 128



In [50]:
def encode_text(texts, tokenizer, max_len):
    input_ids = []
    attention_masks = []

    for text in texts:
        encoded_dict = tokenizer.encode_plus(
                            text,
                            add_special_tokens=True,
                            max_length=max_len,
                            padding='max_length',
                            return_attention_mask=True,
                            return_tensors='pt',
                       )
        input_ids.append(encoded_dict['input_ids'])
        attention_masks.append(encoded_dict['attention_mask'])

    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)

    return input_ids, attention_masks

In [51]:
input_ids, attention_masks = encode_text(data['text'], tokenizer, max_len)
train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(input_ids, data['target'].values, test_size=0.1, random_state=42)
train_masks, validation_masks, _, _ = train_test_split(attention_masks, input_ids, test_size=0.1, random_state=42)

In [52]:
train_inputs = torch.tensor(train_inputs)
validation_inputs = torch.tensor(validation_inputs)
train_labels = torch.tensor(train_labels)
validation_labels = torch.tensor(validation_labels)
train_masks = torch.tensor(train_masks)
validation_masks = torch.tensor(validation_masks)

train_dataset = TensorDataset(train_inputs, train_masks, train_labels)
eval_dataset = TensorDataset(validation_inputs, validation_masks, validation_labels)

  train_inputs = torch.tensor(train_inputs)
  validation_inputs = torch.tensor(validation_inputs)
  train_masks = torch.tensor(train_masks)
  validation_masks = torch.tensor(validation_masks)


In [53]:
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [54]:
def my_data_collator(features):
    batch = {}
    batch['input_ids'] = torch.stack([f[0] for f in features])
    batch['attention_mask'] = torch.stack([f[1] for f in features])
    batch['labels'] = torch.tensor([f[2] for f in features])
    return batch

In [55]:
import mlflow
import mlflow.pytorch
from transformers import Trainer, TrainingArguments, TrainerCallback
from sklearn.metrics import accuracy_score, f1_score, roc_curve, auc

In [56]:
training_args = TrainingArguments(
    output_dir='./models/bert-base',
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    logging_dir='./models/logs_bert',
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    gradient_accumulation_steps=1,
    learning_rate=2e-5,
    warmup_steps=0,
    weight_decay=0.01,
    seed=42,
    load_best_model_at_end=True,
    metric_for_best_model="eval_f1_score",
    greater_is_better=True
)

In [57]:
def compute_metrics(eval_predictions):
    labels = eval_predictions.label_ids
    preds = eval_predictions.predictions.argmax(-1)
    accuracy = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average='weighted')
    
    fpr, tpr, _ = roc_curve(labels, preds, pos_label=1)
    roc_auc = auc(fpr, tpr)
    plt.figure(figsize=(8, 8))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = {:.2f})'.format(roc_auc))
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.grid(True)
    roc_curve_path = "roc_bert-trained.png"
    plt.savefig(roc_curve_path)
    plt.close()
    
    return {
        'eval_accuracy': accuracy,
        'eval_f1_score': f1,
        'eval_roc_auc': roc_auc,
    }

In [58]:
class MLflowCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None:
            for key, value in logs.items():
                mlflow.log_metric(key, value, step=state.global_step)

In [59]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=my_data_collator,
    compute_metrics=compute_metrics,
    callbacks=[MLflowCallback()]
)

In [60]:
mlflow.set_experiment("bert-base-experiment")

mlflow.start_run()
trainer.train()
mlflow.end_run()

  1%|          | 10/843 [02:23<3:18:35, 14.30s/it]

{'loss': 0.6789, 'grad_norm': 5.204065799713135, 'learning_rate': 1.9762752075919338e-05, 'epoch': 0.04}


  2%|▏         | 20/843 [04:45<3:16:05, 14.30s/it]

{'loss': 0.6661, 'grad_norm': 4.287063121795654, 'learning_rate': 1.952550415183867e-05, 'epoch': 0.07}


  4%|▎         | 30/843 [06:59<2:59:36, 13.26s/it]

{'loss': 0.6445, 'grad_norm': 5.040689468383789, 'learning_rate': 1.9288256227758007e-05, 'epoch': 0.11}


  5%|▍         | 40/843 [09:15<2:57:46, 13.28s/it]

{'loss': 0.6325, 'grad_norm': 4.214048385620117, 'learning_rate': 1.9051008303677344e-05, 'epoch': 0.14}


  6%|▌         | 50/843 [11:34<3:04:04, 13.93s/it]

{'loss': 0.5903, 'grad_norm': 4.991639137268066, 'learning_rate': 1.881376037959668e-05, 'epoch': 0.18}


  7%|▋         | 60/843 [13:54<3:03:11, 14.04s/it]

{'loss': 0.5489, 'grad_norm': 7.306921005249023, 'learning_rate': 1.8576512455516017e-05, 'epoch': 0.21}


  8%|▊         | 70/843 [16:08<2:50:27, 13.23s/it]

{'loss': 0.5018, 'grad_norm': 6.073538780212402, 'learning_rate': 1.8339264531435353e-05, 'epoch': 0.25}


  9%|▉         | 80/843 [18:31<3:02:40, 14.37s/it]

{'loss': 0.5161, 'grad_norm': 5.415849685668945, 'learning_rate': 1.8102016607354686e-05, 'epoch': 0.28}


 11%|█         | 90/843 [21:00<3:06:25, 14.85s/it]

{'loss': 0.5043, 'grad_norm': 4.5463948249816895, 'learning_rate': 1.7864768683274022e-05, 'epoch': 0.32}


 12%|█▏        | 100/843 [23:28<3:06:18, 15.04s/it]

{'loss': 0.5106, 'grad_norm': 6.733071327209473, 'learning_rate': 1.762752075919336e-05, 'epoch': 0.36}



 12%|█▏        | 100/843 [25:21<3:06:18, 15.04s/it]

{'eval_accuracy': 0.7837837837837838, 'eval_f1_score': 0.7836437776937356, 'eval_roc_auc': 0.7842461634675518, 'eval_loss': 0.46684983372688293, 'eval_runtime': 112.7906, 'eval_samples_per_second': 8.857, 'eval_steps_per_second': 0.142, 'epoch': 0.36}


 13%|█▎        | 110/843 [27:48<3:07:56, 15.38s/it] 

{'loss': 0.5032, 'grad_norm': 3.864962339401245, 'learning_rate': 1.7390272835112695e-05, 'epoch': 0.39}


 14%|█▍        | 120/843 [30:11<2:55:19, 14.55s/it]

{'loss': 0.4847, 'grad_norm': 6.687708854675293, 'learning_rate': 1.715302491103203e-05, 'epoch': 0.43}


 15%|█▌        | 130/843 [32:37<2:51:17, 14.41s/it]

{'loss': 0.4561, 'grad_norm': 5.9492621421813965, 'learning_rate': 1.6915776986951368e-05, 'epoch': 0.46}


 17%|█▋        | 140/843 [35:13<3:07:29, 16.00s/it]

{'loss': 0.461, 'grad_norm': 7.0328545570373535, 'learning_rate': 1.66785290628707e-05, 'epoch': 0.5}


 18%|█▊        | 150/843 [37:42<2:52:34, 14.94s/it]

{'loss': 0.4444, 'grad_norm': 8.17874526977539, 'learning_rate': 1.6441281138790037e-05, 'epoch': 0.53}


 19%|█▉        | 160/843 [40:14<2:49:00, 14.85s/it]

{'loss': 0.5135, 'grad_norm': 4.680784225463867, 'learning_rate': 1.620403321470937e-05, 'epoch': 0.57}


 20%|██        | 170/843 [42:41<2:42:28, 14.49s/it]

{'loss': 0.4927, 'grad_norm': 4.038048267364502, 'learning_rate': 1.5966785290628707e-05, 'epoch': 0.6}


 21%|██▏       | 180/843 [45:03<2:37:41, 14.27s/it]

{'loss': 0.4375, 'grad_norm': 6.295567035675049, 'learning_rate': 1.5729537366548043e-05, 'epoch': 0.64}


 23%|██▎       | 190/843 [47:34<2:43:47, 15.05s/it]

{'loss': 0.4404, 'grad_norm': 6.281229496002197, 'learning_rate': 1.549228944246738e-05, 'epoch': 0.68}


 24%|██▎       | 200/843 [50:04<2:47:59, 15.68s/it]

{'loss': 0.443, 'grad_norm': 5.429309368133545, 'learning_rate': 1.5255041518386714e-05, 'epoch': 0.71}



 24%|██▎       | 200/843 [51:58<2:47:59, 15.68s/it]

{'eval_accuracy': 0.8028028028028028, 'eval_f1_score': 0.802247181084095, 'eval_roc_auc': 0.8036753740318469, 'eval_loss': 0.4451714754104614, 'eval_runtime': 114.856, 'eval_samples_per_second': 8.698, 'eval_steps_per_second': 0.139, 'epoch': 0.71}


 25%|██▍       | 210/843 [54:29<2:47:15, 15.85s/it]

{'loss': 0.4675, 'grad_norm': 4.611215591430664, 'learning_rate': 1.5017793594306052e-05, 'epoch': 0.75}


 26%|██▌       | 220/843 [56:56<2:33:47, 14.81s/it]

{'loss': 0.4812, 'grad_norm': 6.203141212463379, 'learning_rate': 1.4780545670225385e-05, 'epoch': 0.78}


 27%|██▋       | 230/843 [59:21<2:32:27, 14.92s/it]

{'loss': 0.4626, 'grad_norm': 4.973392009735107, 'learning_rate': 1.4543297746144722e-05, 'epoch': 0.82}


 28%|██▊       | 240/843 [1:01:48<2:25:15, 14.45s/it]

{'loss': 0.4501, 'grad_norm': 5.431426525115967, 'learning_rate': 1.4306049822064058e-05, 'epoch': 0.85}


 30%|██▉       | 250/843 [1:04:15<2:25:15, 14.70s/it]

{'loss': 0.4632, 'grad_norm': 4.297867298126221, 'learning_rate': 1.4068801897983393e-05, 'epoch': 0.89}


 31%|███       | 260/843 [1:06:49<2:26:55, 15.12s/it]

{'loss': 0.4505, 'grad_norm': 5.272849082946777, 'learning_rate': 1.383155397390273e-05, 'epoch': 0.93}


 32%|███▏      | 270/843 [1:09:15<2:24:02, 15.08s/it]

{'loss': 0.3489, 'grad_norm': 6.241794109344482, 'learning_rate': 1.3594306049822066e-05, 'epoch': 0.96}


 33%|███▎      | 280/843 [1:11:46<2:19:37, 14.88s/it]

{'loss': 0.428, 'grad_norm': 7.516247749328613, 'learning_rate': 1.33570581257414e-05, 'epoch': 1.0}


 34%|███▍      | 290/843 [1:14:09<2:15:09, 14.66s/it]

{'loss': 0.4063, 'grad_norm': 6.321103572845459, 'learning_rate': 1.3119810201660737e-05, 'epoch': 1.03}


 36%|███▌      | 300/843 [1:16:36<2:13:44, 14.78s/it]

{'loss': 0.3612, 'grad_norm': 4.46144962310791, 'learning_rate': 1.2882562277580073e-05, 'epoch': 1.07}



 36%|███▌      | 300/843 [1:18:31<2:13:44, 14.78s/it]

{'eval_accuracy': 0.8028028028028028, 'eval_f1_score': 0.8020644168792317, 'eval_roc_auc': 0.8037956415067109, 'eval_loss': 0.4314257502555847, 'eval_runtime': 114.7941, 'eval_samples_per_second': 8.703, 'eval_steps_per_second': 0.139, 'epoch': 1.07}


 37%|███▋      | 310/843 [1:21:15<2:35:45, 17.53s/it]

{'loss': 0.3791, 'grad_norm': 6.098171710968018, 'learning_rate': 1.2645314353499408e-05, 'epoch': 1.1}


 38%|███▊      | 320/843 [1:23:43<2:09:18, 14.84s/it]

{'loss': 0.3262, 'grad_norm': 5.810633659362793, 'learning_rate': 1.2408066429418744e-05, 'epoch': 1.14}


 39%|███▉      | 330/843 [1:26:13<2:05:24, 14.67s/it]

{'loss': 0.3461, 'grad_norm': 5.272178649902344, 'learning_rate': 1.217081850533808e-05, 'epoch': 1.17}


 40%|████      | 340/843 [1:28:44<2:05:39, 14.99s/it]

{'loss': 0.3196, 'grad_norm': 5.004500389099121, 'learning_rate': 1.1933570581257414e-05, 'epoch': 1.21}


 42%|████▏     | 350/843 [1:31:20<2:09:29, 15.76s/it]

{'loss': 0.3518, 'grad_norm': 5.152387619018555, 'learning_rate': 1.169632265717675e-05, 'epoch': 1.25}


 43%|████▎     | 360/843 [1:33:48<1:57:37, 14.61s/it]

{'loss': 0.3653, 'grad_norm': 5.33811616897583, 'learning_rate': 1.1459074733096086e-05, 'epoch': 1.28}


 44%|████▍     | 370/843 [1:36:22<2:04:15, 15.76s/it]

{'loss': 0.3703, 'grad_norm': 6.044694900512695, 'learning_rate': 1.1221826809015421e-05, 'epoch': 1.32}


 45%|████▌     | 380/843 [1:38:49<1:52:26, 14.57s/it]

{'loss': 0.3508, 'grad_norm': 6.112698554992676, 'learning_rate': 1.0984578884934757e-05, 'epoch': 1.35}


 46%|████▋     | 390/843 [1:41:23<1:54:05, 15.11s/it]

{'loss': 0.3397, 'grad_norm': 3.4489052295684814, 'learning_rate': 1.0747330960854094e-05, 'epoch': 1.39}


 47%|████▋     | 400/843 [1:43:52<1:50:18, 14.94s/it]

{'loss': 0.3219, 'grad_norm': 7.227147102355957, 'learning_rate': 1.0510083036773429e-05, 'epoch': 1.42}



 47%|████▋     | 400/843 [1:45:48<1:50:18, 14.94s/it]

{'eval_accuracy': 0.7897897897897898, 'eval_f1_score': 0.7870480104250037, 'eval_roc_auc': 0.7881488430268918, 'eval_loss': 0.48931199312210083, 'eval_runtime': 115.4121, 'eval_samples_per_second': 8.656, 'eval_steps_per_second': 0.139, 'epoch': 1.42}


 49%|████▊     | 410/843 [1:48:24<1:58:00, 16.35s/it]

{'loss': 0.3845, 'grad_norm': 9.887603759765625, 'learning_rate': 1.0272835112692765e-05, 'epoch': 1.46}


 50%|████▉     | 420/843 [1:50:56<1:49:17, 15.50s/it]

{'loss': 0.3081, 'grad_norm': 6.882087230682373, 'learning_rate': 1.0035587188612101e-05, 'epoch': 1.49}


 51%|█████     | 430/843 [1:53:27<1:42:47, 14.93s/it]

{'loss': 0.3678, 'grad_norm': 5.9178290367126465, 'learning_rate': 9.798339264531436e-06, 'epoch': 1.53}


 52%|█████▏    | 440/843 [1:55:56<1:41:25, 15.10s/it]

{'loss': 0.3065, 'grad_norm': 6.583693027496338, 'learning_rate': 9.561091340450772e-06, 'epoch': 1.57}


 53%|█████▎    | 450/843 [1:58:24<1:36:43, 14.77s/it]

{'loss': 0.3329, 'grad_norm': 4.873513698577881, 'learning_rate': 9.323843416370107e-06, 'epoch': 1.6}


 55%|█████▍    | 460/843 [2:00:56<1:34:27, 14.80s/it]

{'loss': 0.3133, 'grad_norm': 4.64373254776001, 'learning_rate': 9.086595492289444e-06, 'epoch': 1.64}


 56%|█████▌    | 470/843 [2:03:25<1:29:53, 14.46s/it]

{'loss': 0.2942, 'grad_norm': 8.172688484191895, 'learning_rate': 8.84934756820878e-06, 'epoch': 1.67}


 57%|█████▋    | 480/843 [2:05:59<1:35:34, 15.80s/it]

{'loss': 0.3078, 'grad_norm': 4.932862281799316, 'learning_rate': 8.612099644128115e-06, 'epoch': 1.71}


 58%|█████▊    | 490/843 [2:08:27<1:24:17, 14.33s/it]

{'loss': 0.373, 'grad_norm': 8.033852577209473, 'learning_rate': 8.37485172004745e-06, 'epoch': 1.74}


 59%|█████▉    | 500/843 [2:10:52<1:22:11, 14.38s/it]

{'loss': 0.2888, 'grad_norm': 6.702187538146973, 'learning_rate': 8.137603795966786e-06, 'epoch': 1.78}



 59%|█████▉    | 500/843 [2:12:47<1:22:11, 14.38s/it]

{'eval_accuracy': 0.8058058058058059, 'eval_f1_score': 0.8051707219125229, 'eval_roc_auc': 0.8050103430028382, 'eval_loss': 0.44774115085601807, 'eval_runtime': 115.3064, 'eval_samples_per_second': 8.664, 'eval_steps_per_second': 0.139, 'epoch': 1.78}


 60%|██████    | 510/843 [2:15:13<1:27:47, 15.82s/it]

{'loss': 0.372, 'grad_norm': 5.82549524307251, 'learning_rate': 7.900355871886122e-06, 'epoch': 1.81}


 62%|██████▏   | 520/843 [2:17:44<1:21:23, 15.12s/it]

{'loss': 0.3111, 'grad_norm': 4.385293006896973, 'learning_rate': 7.663107947805457e-06, 'epoch': 1.85}


 63%|██████▎   | 530/843 [2:20:19<1:18:26, 15.04s/it]

{'loss': 0.3756, 'grad_norm': 6.554511070251465, 'learning_rate': 7.425860023724793e-06, 'epoch': 1.89}


 64%|██████▍   | 540/843 [2:22:52<1:15:52, 15.02s/it]

{'loss': 0.3506, 'grad_norm': 10.864510536193848, 'learning_rate': 7.188612099644129e-06, 'epoch': 1.92}


 65%|██████▌   | 550/843 [2:25:21<1:13:16, 15.00s/it]

{'loss': 0.3506, 'grad_norm': 5.636012554168701, 'learning_rate': 6.951364175563464e-06, 'epoch': 1.96}


 66%|██████▋   | 560/843 [2:27:51<1:10:48, 15.01s/it]

{'loss': 0.3357, 'grad_norm': 7.237118244171143, 'learning_rate': 6.7141162514828e-06, 'epoch': 1.99}


 68%|██████▊   | 570/843 [2:30:14<1:04:48, 14.25s/it]

{'loss': 0.3053, 'grad_norm': 7.175760269165039, 'learning_rate': 6.476868327402136e-06, 'epoch': 2.03}


 69%|██████▉   | 580/843 [2:32:46<1:07:37, 15.43s/it]

{'loss': 0.2712, 'grad_norm': 10.148428916931152, 'learning_rate': 6.239620403321471e-06, 'epoch': 2.06}


 70%|██████▉   | 590/843 [2:35:12<1:01:52, 14.67s/it]

{'loss': 0.2791, 'grad_norm': 7.762922763824463, 'learning_rate': 6.0023724792408065e-06, 'epoch': 2.1}


 71%|███████   | 600/843 [2:37:46<1:00:41, 14.98s/it]

{'loss': 0.2185, 'grad_norm': 3.5749611854553223, 'learning_rate': 5.765124555160143e-06, 'epoch': 2.14}



 71%|███████   | 600/843 [2:39:42<1:00:41, 14.98s/it]

{'eval_accuracy': 0.8028028028028028, 'eval_f1_score': 0.8027830356617115, 'eval_roc_auc': 0.8026831673642181, 'eval_loss': 0.46547678112983704, 'eval_runtime': 115.8458, 'eval_samples_per_second': 8.624, 'eval_steps_per_second': 0.138, 'epoch': 2.14}


 72%|███████▏  | 610/843 [2:42:14<1:03:10, 16.27s/it]

{'loss': 0.2401, 'grad_norm': 5.701139450073242, 'learning_rate': 5.5278766310794785e-06, 'epoch': 2.17}


 74%|███████▎  | 620/843 [2:44:40<54:15, 14.60s/it]  

{'loss': 0.2495, 'grad_norm': 3.870624542236328, 'learning_rate': 5.290628706998814e-06, 'epoch': 2.21}


 75%|███████▍  | 630/843 [2:47:10<53:23, 15.04s/it]

{'loss': 0.2139, 'grad_norm': 7.687706470489502, 'learning_rate': 5.05338078291815e-06, 'epoch': 2.24}


 76%|███████▌  | 640/843 [2:49:38<50:58, 15.07s/it]

{'loss': 0.3018, 'grad_norm': 6.998437404632568, 'learning_rate': 4.816132858837486e-06, 'epoch': 2.28}


 77%|███████▋  | 650/843 [2:52:15<52:05, 16.19s/it]

{'loss': 0.2264, 'grad_norm': 5.22311544418335, 'learning_rate': 4.5788849347568215e-06, 'epoch': 2.31}


 78%|███████▊  | 660/843 [2:54:43<44:19, 14.53s/it]

{'loss': 0.2423, 'grad_norm': 10.229194641113281, 'learning_rate': 4.341637010676157e-06, 'epoch': 2.35}


 79%|███████▉  | 670/843 [2:57:08<42:34, 14.76s/it]

{'loss': 0.2729, 'grad_norm': 4.209497928619385, 'learning_rate': 4.104389086595493e-06, 'epoch': 2.38}


 81%|████████  | 680/843 [2:59:31<39:15, 14.45s/it]

{'loss': 0.2852, 'grad_norm': 10.558135032653809, 'learning_rate': 3.867141162514828e-06, 'epoch': 2.42}


 82%|████████▏ | 690/843 [3:02:04<38:42, 15.18s/it]

{'loss': 0.2241, 'grad_norm': 8.01279354095459, 'learning_rate': 3.629893238434164e-06, 'epoch': 2.46}


 83%|████████▎ | 700/843 [3:04:28<33:30, 14.06s/it]

{'loss': 0.2531, 'grad_norm': 7.144690990447998, 'learning_rate': 3.3926453143535e-06, 'epoch': 2.49}



 83%|████████▎ | 700/843 [3:06:23<33:30, 14.06s/it]

{'eval_accuracy': 0.8098098098098098, 'eval_f1_score': 0.8098201014867681, 'eval_roc_auc': 0.8098571222398614, 'eval_loss': 0.4593583941459656, 'eval_runtime': 115.435, 'eval_samples_per_second': 8.654, 'eval_steps_per_second': 0.139, 'epoch': 2.49}


 84%|████████▍ | 710/843 [3:09:00<36:29, 16.46s/it]  

{'loss': 0.2424, 'grad_norm': 7.386525630950928, 'learning_rate': 3.155397390272835e-06, 'epoch': 2.53}


 85%|████████▌ | 720/843 [3:11:32<30:37, 14.94s/it]

{'loss': 0.2394, 'grad_norm': 5.884039402008057, 'learning_rate': 2.918149466192171e-06, 'epoch': 2.56}


 87%|████████▋ | 730/843 [3:13:58<27:47, 14.76s/it]

{'loss': 0.2061, 'grad_norm': 7.899090766906738, 'learning_rate': 2.680901542111507e-06, 'epoch': 2.6}


 88%|████████▊ | 740/843 [3:16:21<24:52, 14.49s/it]

{'loss': 0.1999, 'grad_norm': 5.784512996673584, 'learning_rate': 2.4436536180308423e-06, 'epoch': 2.63}


 89%|████████▉ | 750/843 [3:18:52<23:08, 14.94s/it]

{'loss': 0.2355, 'grad_norm': 4.666039943695068, 'learning_rate': 2.2064056939501782e-06, 'epoch': 2.67}


 90%|█████████ | 760/843 [3:21:19<20:17, 14.66s/it]

{'loss': 0.2167, 'grad_norm': 8.15806770324707, 'learning_rate': 1.9691577698695138e-06, 'epoch': 2.7}


 91%|█████████▏| 770/843 [3:23:49<18:07, 14.89s/it]

{'loss': 0.2104, 'grad_norm': 9.864715576171875, 'learning_rate': 1.7319098457888495e-06, 'epoch': 2.74}


 93%|█████████▎| 780/843 [3:26:18<15:44, 14.99s/it]

{'loss': 0.2611, 'grad_norm': 5.831902503967285, 'learning_rate': 1.494661921708185e-06, 'epoch': 2.78}


 94%|█████████▎| 790/843 [3:28:41<12:38, 14.32s/it]

{'loss': 0.3231, 'grad_norm': 6.469383239746094, 'learning_rate': 1.257413997627521e-06, 'epoch': 2.81}


 95%|█████████▍| 800/843 [3:31:06<10:08, 14.15s/it]

{'loss': 0.2967, 'grad_norm': 11.555684089660645, 'learning_rate': 1.0201660735468566e-06, 'epoch': 2.85}



 95%|█████████▍| 800/843 [3:33:03<10:08, 14.15s/it]

{'eval_accuracy': 0.8078078078078078, 'eval_f1_score': 0.8077835323594977, 'eval_roc_auc': 0.807674267571078, 'eval_loss': 0.479906290769577, 'eval_runtime': 116.58, 'eval_samples_per_second': 8.569, 'eval_steps_per_second': 0.137, 'epoch': 2.85}


 96%|█████████▌| 810/843 [3:35:33<08:43, 15.87s/it]

{'loss': 0.2929, 'grad_norm': 7.132215976715088, 'learning_rate': 7.829181494661923e-07, 'epoch': 2.88}


 97%|█████████▋| 820/843 [3:38:10<06:06, 15.92s/it]

{'loss': 0.2301, 'grad_norm': 3.9687066078186035, 'learning_rate': 5.456702253855279e-07, 'epoch': 2.92}


 98%|█████████▊| 830/843 [3:40:40<03:10, 14.69s/it]

{'loss': 0.2654, 'grad_norm': 8.514796257019043, 'learning_rate': 3.084223013048636e-07, 'epoch': 2.95}


100%|█████████▉| 840/843 [3:43:08<00:44, 14.73s/it]

{'loss': 0.2408, 'grad_norm': 8.040596008300781, 'learning_rate': 7.117437722419929e-08, 'epoch': 2.99}


100%|██████████| 843/843 [3:43:49<00:00, 15.93s/it]

{'train_runtime': 13429.0671, 'train_samples_per_second': 2.007, 'train_steps_per_second': 0.063, 'train_loss': 0.3646720484482436, 'epoch': 3.0}





In [61]:
model_path = "./models/model_bert"
tokenizer_path = "./models/tokenizer_bert"

model.save_pretrained(model_path)
tokenizer.save_pretrained(tokenizer_path)

('./models/tokenizer_bert/tokenizer_config.json',
 './models/tokenizer_bert/special_tokens_map.json',
 './models/tokenizer_bert/vocab.txt',
 './models/tokenizer_bert/added_tokens.json')

# roBERTa

In [66]:
data = pd.read_csv("./data/df_5000.csv")
data = data.dropna(subset=['text'])

In [67]:
from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True)
max_len = 128



In [68]:
def encode_text(texts, tokenizer, max_len):
    input_ids = []
    attention_masks = []

    for text in texts:
        encoded_dict = tokenizer.encode_plus(
                            text,
                            add_special_tokens=True,
                            max_length=max_len,
                            padding='max_length',
                            return_attention_mask=True,
                            return_tensors='pt',
                       )
        input_ids.append(encoded_dict['input_ids'])
        attention_masks.append(encoded_dict['attention_mask'])

    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)

    return input_ids, attention_masks

def encode_text(texts, tokenizer, max_len):
    input_ids = []
    attention_masks = []

    for i, text in enumerate(texts):
        encoded_dict = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        input_ids.append(encoded_dict['input_ids'])
        attention_masks.append(encoded_dict['attention_mask'])

    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)

    return input_ids, attention_masks

In [71]:
input_ids, attention_masks = encode_text(data['text'], tokenizer, max_len)
train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(input_ids, data['target'].values, test_size=0.1, random_state=42)
train_masks, validation_masks, _, _ = train_test_split(attention_masks, input_ids, test_size=0.1, random_state=42)

In [72]:
train_inputs = torch.tensor(train_inputs)
validation_inputs = torch.tensor(validation_inputs)
train_labels = torch.tensor(train_labels)
validation_labels = torch.tensor(validation_labels)
train_masks = torch.tensor(train_masks)
validation_masks = torch.tensor(validation_masks)

train_dataset = TensorDataset(train_inputs, train_masks, train_labels)
eval_dataset = TensorDataset(validation_inputs, validation_masks, validation_labels)

  train_inputs = torch.tensor(train_inputs)
  validation_inputs = torch.tensor(validation_inputs)
  train_masks = torch.tensor(train_masks)
  validation_masks = torch.tensor(validation_masks)


In [73]:
from transformers import RobertaForSequenceClassification

model = RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=2)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [74]:
def my_data_collator(features):
    batch = {}
    batch['input_ids'] = torch.stack([f[0] for f in features])
    batch['attention_mask'] = torch.stack([f[1] for f in features])
    batch['labels'] = torch.tensor([f[2] for f in features])
    return batch

In [75]:
training_args = TrainingArguments(
    output_dir='./models/roberta-base',
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    logging_dir='./models/logs_roberta',
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    gradient_accumulation_steps=1,
    learning_rate=2e-5,
    warmup_steps=0,
    weight_decay=0.01,
    seed=42,
    load_best_model_at_end=True,
    metric_for_best_model="eval_f1_score",
    greater_is_better=True
)

In [76]:
def compute_metrics(eval_predictions):
    labels = eval_predictions.label_ids
    preds = eval_predictions.predictions.argmax(-1)
    accuracy = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average='weighted')
    
    fpr, tpr, _ = roc_curve(labels, preds, pos_label=1)
    roc_auc = auc(fpr, tpr)
    plt.figure(figsize=(8, 8))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = {:.2f})'.format(roc_auc))
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.grid(True)
    roc_curve_path = "roc_roberta-trained.png"
    plt.savefig(roc_curve_path)
    plt.close()
    
    return {
        'eval_accuracy': accuracy,
        'eval_f1_score': f1,
        'eval_roc_auc': roc_auc,
    }

In [77]:
class MLflowCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None:
            for key, value in logs.items():
                mlflow.log_metric(key, value, step=state.global_step)

In [78]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=my_data_collator,
    compute_metrics=compute_metrics,
    callbacks=[MLflowCallback()]
)

In [79]:
mlflow.set_experiment("roberta-base-experiment")

mlflow.start_run()
trainer.train()
mlflow.end_run()

  1%|          | 10/843 [02:23<3:12:05, 13.84s/it]

{'loss': 0.6965, 'grad_norm': 0.6777755618095398, 'learning_rate': 1.9762752075919338e-05, 'epoch': 0.04}


  2%|▏         | 20/843 [04:36<2:57:22, 12.93s/it]

{'loss': 0.687, 'grad_norm': 1.0352787971496582, 'learning_rate': 1.952550415183867e-05, 'epoch': 0.07}


  4%|▎         | 30/843 [06:53<3:07:48, 13.86s/it]

{'loss': 0.6814, 'grad_norm': 2.3286373615264893, 'learning_rate': 1.9288256227758007e-05, 'epoch': 0.11}


  5%|▍         | 40/843 [09:11<3:01:30, 13.56s/it]

{'loss': 0.6518, 'grad_norm': 7.0107245445251465, 'learning_rate': 1.9051008303677344e-05, 'epoch': 0.14}


  6%|▌         | 50/843 [11:24<2:57:58, 13.47s/it]

{'loss': 0.6212, 'grad_norm': 5.792379856109619, 'learning_rate': 1.881376037959668e-05, 'epoch': 0.18}


  7%|▋         | 60/843 [13:36<2:49:00, 12.95s/it]

{'loss': 0.5403, 'grad_norm': 9.932394027709961, 'learning_rate': 1.8576512455516017e-05, 'epoch': 0.21}


  8%|▊         | 70/843 [15:51<2:58:01, 13.82s/it]

{'loss': 0.5512, 'grad_norm': 11.256462097167969, 'learning_rate': 1.8339264531435353e-05, 'epoch': 0.25}


  9%|▉         | 80/843 [18:05<2:50:04, 13.37s/it]

{'loss': 0.5285, 'grad_norm': 9.882940292358398, 'learning_rate': 1.8102016607354686e-05, 'epoch': 0.28}


 11%|█         | 90/843 [20:21<2:48:06, 13.40s/it]

{'loss': 0.5042, 'grad_norm': 13.092741012573242, 'learning_rate': 1.7864768683274022e-05, 'epoch': 0.32}


 12%|█▏        | 100/843 [22:42<2:51:56, 13.89s/it]

{'loss': 0.5093, 'grad_norm': 23.72054672241211, 'learning_rate': 1.762752075919336e-05, 'epoch': 0.36}



 12%|█▏        | 100/843 [24:33<2:51:56, 13.89s/it]

{'eval_accuracy': 0.7647647647647647, 'eval_f1_score': 0.7607655238588281, 'eval_roc_auc': 0.7667712993697985, 'eval_loss': 0.5473454594612122, 'eval_runtime': 110.7192, 'eval_samples_per_second': 9.023, 'eval_steps_per_second': 0.145, 'epoch': 0.36}


 13%|█▎        | 110/843 [27:01<3:12:34, 15.76s/it] 

{'loss': 0.5309, 'grad_norm': 6.934708118438721, 'learning_rate': 1.7390272835112695e-05, 'epoch': 0.39}


 14%|█▍        | 120/843 [29:24<2:50:07, 14.12s/it]

{'loss': 0.488, 'grad_norm': 11.60666561126709, 'learning_rate': 1.715302491103203e-05, 'epoch': 0.43}


 15%|█▌        | 130/843 [31:48<2:47:41, 14.11s/it]

{'loss': 0.5061, 'grad_norm': 14.096639633178711, 'learning_rate': 1.6915776986951368e-05, 'epoch': 0.46}


 17%|█▋        | 140/843 [34:04<2:37:20, 13.43s/it]

{'loss': 0.4612, 'grad_norm': 11.18200969696045, 'learning_rate': 1.66785290628707e-05, 'epoch': 0.5}


 18%|█▊        | 150/843 [36:25<2:49:07, 14.64s/it]

{'loss': 0.4427, 'grad_norm': 13.073321342468262, 'learning_rate': 1.6441281138790037e-05, 'epoch': 0.53}


 19%|█▉        | 160/843 [38:52<2:43:45, 14.39s/it]

{'loss': 0.4992, 'grad_norm': 14.03603744506836, 'learning_rate': 1.620403321470937e-05, 'epoch': 0.57}


 20%|██        | 170/843 [41:13<2:40:19, 14.29s/it]

{'loss': 0.4877, 'grad_norm': 10.274236679077148, 'learning_rate': 1.5966785290628707e-05, 'epoch': 0.6}


 21%|██▏       | 180/843 [43:33<2:29:27, 13.52s/it]

{'loss': 0.4094, 'grad_norm': 10.197854995727539, 'learning_rate': 1.5729537366548043e-05, 'epoch': 0.64}


 23%|██▎       | 190/843 [45:53<2:32:58, 14.06s/it]

{'loss': 0.4678, 'grad_norm': 19.87918472290039, 'learning_rate': 1.549228944246738e-05, 'epoch': 0.68}


 24%|██▎       | 200/843 [48:13<2:32:43, 14.25s/it]

{'loss': 0.4518, 'grad_norm': 13.995081901550293, 'learning_rate': 1.5255041518386714e-05, 'epoch': 0.71}



 24%|██▎       | 200/843 [50:06<2:32:43, 14.25s/it]

{'eval_accuracy': 0.8118118118118118, 'eval_f1_score': 0.8115426587624293, 'eval_roc_auc': 0.8124609130706691, 'eval_loss': 0.428683876991272, 'eval_runtime': 112.8908, 'eval_samples_per_second': 8.849, 'eval_steps_per_second': 0.142, 'epoch': 0.71}


 25%|██▍       | 210/843 [52:39<2:54:06, 16.50s/it]

{'loss': 0.4272, 'grad_norm': 8.333745002746582, 'learning_rate': 1.5017793594306052e-05, 'epoch': 0.75}


 26%|██▌       | 220/843 [55:05<2:28:40, 14.32s/it]

{'loss': 0.4378, 'grad_norm': 16.992353439331055, 'learning_rate': 1.4780545670225385e-05, 'epoch': 0.78}


 27%|██▋       | 230/843 [57:30<2:31:17, 14.81s/it]

{'loss': 0.4439, 'grad_norm': 15.441617012023926, 'learning_rate': 1.4543297746144722e-05, 'epoch': 0.82}


 28%|██▊       | 240/843 [59:59<2:29:24, 14.87s/it]

{'loss': 0.4692, 'grad_norm': 17.03740119934082, 'learning_rate': 1.4306049822064058e-05, 'epoch': 0.85}


 30%|██▉       | 250/843 [1:02:26<2:27:42, 14.95s/it]

{'loss': 0.4651, 'grad_norm': 8.447989463806152, 'learning_rate': 1.4068801897983393e-05, 'epoch': 0.89}


 31%|███       | 260/843 [1:04:52<2:22:54, 14.71s/it]

{'loss': 0.4137, 'grad_norm': 10.320759773254395, 'learning_rate': 1.383155397390273e-05, 'epoch': 0.93}


 32%|███▏      | 270/843 [1:07:24<2:31:16, 15.84s/it]

{'loss': 0.332, 'grad_norm': 11.852431297302246, 'learning_rate': 1.3594306049822066e-05, 'epoch': 0.96}


 33%|███▎      | 280/843 [1:09:50<2:18:17, 14.74s/it]

{'loss': 0.4578, 'grad_norm': 13.674153327941895, 'learning_rate': 1.33570581257414e-05, 'epoch': 1.0}


 34%|███▍      | 290/843 [1:12:07<2:07:36, 13.85s/it]

{'loss': 0.456, 'grad_norm': 9.730764389038086, 'learning_rate': 1.3119810201660737e-05, 'epoch': 1.03}


 36%|███▌      | 300/843 [1:14:34<2:11:50, 14.57s/it]

{'loss': 0.4118, 'grad_norm': 5.4632673263549805, 'learning_rate': 1.2882562277580073e-05, 'epoch': 1.07}



 36%|███▌      | 300/843 [1:16:28<2:11:50, 14.57s/it]

{'eval_accuracy': 0.8318318318318318, 'eval_f1_score': 0.8318341909043312, 'eval_roc_auc': 0.832004377736085, 'eval_loss': 0.39102914929389954, 'eval_runtime': 114.3291, 'eval_samples_per_second': 8.738, 'eval_steps_per_second': 0.14, 'epoch': 1.07}


 37%|███▋      | 310/843 [1:18:55<2:17:34, 15.49s/it]

{'loss': 0.4182, 'grad_norm': 11.368463516235352, 'learning_rate': 1.2645314353499408e-05, 'epoch': 1.1}


 38%|███▊      | 320/843 [1:21:19<2:06:51, 14.55s/it]

{'loss': 0.3065, 'grad_norm': 9.561685562133789, 'learning_rate': 1.2408066429418744e-05, 'epoch': 1.14}


 39%|███▉      | 330/843 [1:23:45<2:03:58, 14.50s/it]

{'loss': 0.3977, 'grad_norm': 11.988395690917969, 'learning_rate': 1.217081850533808e-05, 'epoch': 1.17}


 40%|████      | 340/843 [1:26:08<1:57:42, 14.04s/it]

{'loss': 0.3436, 'grad_norm': 9.007534980773926, 'learning_rate': 1.1933570581257414e-05, 'epoch': 1.21}


 42%|████▏     | 350/843 [1:28:31<1:56:50, 14.22s/it]

{'loss': 0.4068, 'grad_norm': 8.532970428466797, 'learning_rate': 1.169632265717675e-05, 'epoch': 1.25}


 43%|████▎     | 360/843 [1:30:56<1:57:19, 14.57s/it]

{'loss': 0.379, 'grad_norm': 9.266630172729492, 'learning_rate': 1.1459074733096086e-05, 'epoch': 1.28}


 44%|████▍     | 370/843 [1:33:21<1:51:07, 14.10s/it]

{'loss': 0.3866, 'grad_norm': 8.347237586975098, 'learning_rate': 1.1221826809015421e-05, 'epoch': 1.32}


 45%|████▌     | 380/843 [1:35:44<1:52:21, 14.56s/it]

{'loss': 0.3923, 'grad_norm': 12.892433166503906, 'learning_rate': 1.0984578884934757e-05, 'epoch': 1.35}


 46%|████▋     | 390/843 [1:38:16<1:59:04, 15.77s/it]

{'loss': 0.3389, 'grad_norm': 8.26994514465332, 'learning_rate': 1.0747330960854094e-05, 'epoch': 1.39}


 47%|████▋     | 400/843 [1:40:39<1:43:31, 14.02s/it]

{'loss': 0.3538, 'grad_norm': 16.288429260253906, 'learning_rate': 1.0510083036773429e-05, 'epoch': 1.42}



 47%|████▋     | 400/843 [1:42:32<1:43:31, 14.02s/it]

{'eval_accuracy': 0.7917917917917918, 'eval_f1_score': 0.789360381608771, 'eval_roc_auc': 0.7902414970895272, 'eval_loss': 0.4981069266796112, 'eval_runtime': 113.8363, 'eval_samples_per_second': 8.776, 'eval_steps_per_second': 0.141, 'epoch': 1.42}


 49%|████▊     | 410/843 [1:45:01<1:53:55, 15.79s/it]

{'loss': 0.4324, 'grad_norm': 12.664071083068848, 'learning_rate': 1.0272835112692765e-05, 'epoch': 1.46}


 50%|████▉     | 420/843 [1:47:23<1:38:49, 14.02s/it]

{'loss': 0.33, 'grad_norm': 12.60415267944336, 'learning_rate': 1.0035587188612101e-05, 'epoch': 1.49}


 51%|█████     | 430/843 [1:49:49<1:38:48, 14.35s/it]

{'loss': 0.3727, 'grad_norm': 15.280821800231934, 'learning_rate': 9.798339264531436e-06, 'epoch': 1.53}


 52%|█████▏    | 440/843 [1:52:15<1:40:35, 14.98s/it]

{'loss': 0.3587, 'grad_norm': 8.308649063110352, 'learning_rate': 9.561091340450772e-06, 'epoch': 1.57}


 53%|█████▎    | 450/843 [1:54:47<1:37:02, 14.82s/it]

{'loss': 0.3407, 'grad_norm': 9.496369361877441, 'learning_rate': 9.323843416370107e-06, 'epoch': 1.6}


 55%|█████▍    | 460/843 [1:57:15<1:33:58, 14.72s/it]

{'loss': 0.2855, 'grad_norm': 16.663564682006836, 'learning_rate': 9.086595492289444e-06, 'epoch': 1.64}


 56%|█████▌    | 470/843 [1:59:48<1:35:05, 15.30s/it]

{'loss': 0.3197, 'grad_norm': 14.33398723602295, 'learning_rate': 8.84934756820878e-06, 'epoch': 1.67}


 57%|█████▋    | 480/843 [2:02:17<1:30:31, 14.96s/it]

{'loss': 0.362, 'grad_norm': 5.565577983856201, 'learning_rate': 8.612099644128115e-06, 'epoch': 1.71}


 58%|█████▊    | 490/843 [2:04:50<1:31:56, 15.63s/it]

{'loss': 0.3731, 'grad_norm': 11.0617094039917, 'learning_rate': 8.37485172004745e-06, 'epoch': 1.74}


 59%|█████▉    | 500/843 [2:07:26<1:27:56, 15.38s/it]

{'loss': 0.3287, 'grad_norm': 20.22356605529785, 'learning_rate': 8.137603795966786e-06, 'epoch': 1.78}



 59%|█████▉    | 500/843 [2:09:23<1:27:56, 15.38s/it]

{'eval_accuracy': 0.8078078078078078, 'eval_f1_score': 0.8069880589464098, 'eval_roc_auc': 0.8068925289844614, 'eval_loss': 0.4548109173774719, 'eval_runtime': 116.5759, 'eval_samples_per_second': 8.57, 'eval_steps_per_second': 0.137, 'epoch': 1.78}


 60%|██████    | 510/843 [2:12:00<1:34:41, 17.06s/it]

{'loss': 0.3952, 'grad_norm': 20.607398986816406, 'learning_rate': 7.900355871886122e-06, 'epoch': 1.81}


 62%|██████▏   | 520/843 [2:14:29<1:19:19, 14.73s/it]

{'loss': 0.2922, 'grad_norm': 8.793346405029297, 'learning_rate': 7.663107947805457e-06, 'epoch': 1.85}


 63%|██████▎   | 530/843 [2:16:56<1:14:21, 14.25s/it]

{'loss': 0.3861, 'grad_norm': 12.750994682312012, 'learning_rate': 7.425860023724793e-06, 'epoch': 1.89}


 64%|██████▍   | 540/843 [2:19:23<1:13:56, 14.64s/it]

{'loss': 0.3782, 'grad_norm': 25.37036895751953, 'learning_rate': 7.188612099644129e-06, 'epoch': 1.92}


 65%|██████▌   | 550/843 [2:21:53<1:13:34, 15.07s/it]

{'loss': 0.3974, 'grad_norm': 10.068746566772461, 'learning_rate': 6.951364175563464e-06, 'epoch': 1.96}


 66%|██████▋   | 560/843 [2:24:24<1:11:52, 15.24s/it]

{'loss': 0.3525, 'grad_norm': 20.87345314025879, 'learning_rate': 6.7141162514828e-06, 'epoch': 1.99}


 68%|██████▊   | 570/843 [2:26:49<1:06:22, 14.59s/it]

{'loss': 0.3343, 'grad_norm': 13.886741638183594, 'learning_rate': 6.476868327402136e-06, 'epoch': 2.03}


 69%|██████▉   | 580/843 [2:29:19<1:05:11, 14.87s/it]

{'loss': 0.2588, 'grad_norm': 24.944786071777344, 'learning_rate': 6.239620403321471e-06, 'epoch': 2.06}


 70%|██████▉   | 590/843 [2:31:48<1:01:55, 14.69s/it]

{'loss': 0.3071, 'grad_norm': 10.863167762756348, 'learning_rate': 6.0023724792408065e-06, 'epoch': 2.1}


 71%|███████   | 600/843 [2:34:14<57:24, 14.17s/it]  

{'loss': 0.2649, 'grad_norm': 14.140105247497559, 'learning_rate': 5.765124555160143e-06, 'epoch': 2.14}



 71%|███████   | 600/843 [2:36:09<57:24, 14.17s/it]

{'eval_accuracy': 0.8208208208208209, 'eval_f1_score': 0.8208297986241891, 'eval_roc_auc': 0.8208555828161832, 'eval_loss': 0.4433671236038208, 'eval_runtime': 115.0926, 'eval_samples_per_second': 8.68, 'eval_steps_per_second': 0.139, 'epoch': 2.14}


 72%|███████▏  | 610/843 [2:38:45<1:07:46, 17.45s/it]

{'loss': 0.2941, 'grad_norm': 6.9376726150512695, 'learning_rate': 5.5278766310794785e-06, 'epoch': 2.17}


 74%|███████▎  | 620/843 [2:41:15<55:33, 14.95s/it]  

{'loss': 0.2174, 'grad_norm': 7.291998863220215, 'learning_rate': 5.290628706998814e-06, 'epoch': 2.21}


 75%|███████▍  | 630/843 [2:43:40<50:29, 14.22s/it]

{'loss': 0.2181, 'grad_norm': 23.945011138916016, 'learning_rate': 5.05338078291815e-06, 'epoch': 2.24}


 76%|███████▌  | 640/843 [2:46:06<49:41, 14.69s/it]

{'loss': 0.3713, 'grad_norm': 25.103435516357422, 'learning_rate': 4.816132858837486e-06, 'epoch': 2.28}


 77%|███████▋  | 650/843 [2:48:34<47:28, 14.76s/it]

{'loss': 0.2883, 'grad_norm': 18.5189151763916, 'learning_rate': 4.5788849347568215e-06, 'epoch': 2.31}


 78%|███████▊  | 660/843 [2:51:00<44:51, 14.71s/it]

{'loss': 0.2661, 'grad_norm': 5.989052772521973, 'learning_rate': 4.341637010676157e-06, 'epoch': 2.35}


 79%|███████▉  | 670/843 [2:53:29<42:51, 14.86s/it]

{'loss': 0.3079, 'grad_norm': 8.983159065246582, 'learning_rate': 4.104389086595493e-06, 'epoch': 2.38}


 81%|████████  | 680/843 [2:56:07<40:57, 15.08s/it]

{'loss': 0.3039, 'grad_norm': 12.372864723205566, 'learning_rate': 3.867141162514828e-06, 'epoch': 2.42}


 82%|████████▏ | 690/843 [2:58:40<38:44, 15.19s/it]

{'loss': 0.2892, 'grad_norm': 13.05830192565918, 'learning_rate': 3.629893238434164e-06, 'epoch': 2.46}


 83%|████████▎ | 700/843 [3:01:12<35:50, 15.04s/it]

{'loss': 0.281, 'grad_norm': 9.757838249206543, 'learning_rate': 3.3926453143535e-06, 'epoch': 2.49}



 83%|████████▎ | 700/843 [3:03:07<35:50, 15.04s/it]

{'eval_accuracy': 0.8298298298298298, 'eval_f1_score': 0.8298035698431709, 'eval_roc_auc': 0.8301221917544619, 'eval_loss': 0.4442916512489319, 'eval_runtime': 115.0442, 'eval_samples_per_second': 8.684, 'eval_steps_per_second': 0.139, 'epoch': 2.49}


 84%|████████▍ | 710/843 [3:05:43<36:24, 16.42s/it]  

{'loss': 0.3043, 'grad_norm': 15.06360912322998, 'learning_rate': 3.155397390272835e-06, 'epoch': 2.53}


 85%|████████▌ | 720/843 [3:08:10<30:05, 14.68s/it]

{'loss': 0.3127, 'grad_norm': 17.668485641479492, 'learning_rate': 2.918149466192171e-06, 'epoch': 2.56}


 87%|████████▋ | 730/843 [3:10:45<28:31, 15.15s/it]

{'loss': 0.2563, 'grad_norm': 14.702408790588379, 'learning_rate': 2.680901542111507e-06, 'epoch': 2.6}


 88%|████████▊ | 740/843 [3:13:17<25:09, 14.65s/it]

{'loss': 0.2191, 'grad_norm': 12.221165657043457, 'learning_rate': 2.4436536180308423e-06, 'epoch': 2.63}


 89%|████████▉ | 750/843 [3:15:44<23:08, 14.93s/it]

{'loss': 0.263, 'grad_norm': 6.7711358070373535, 'learning_rate': 2.2064056939501782e-06, 'epoch': 2.67}


 90%|█████████ | 760/843 [3:18:14<20:36, 14.90s/it]

{'loss': 0.2433, 'grad_norm': 12.67586898803711, 'learning_rate': 1.9691577698695138e-06, 'epoch': 2.7}


 91%|█████████▏| 770/843 [3:20:46<18:22, 15.10s/it]

{'loss': 0.2662, 'grad_norm': 15.90100383758545, 'learning_rate': 1.7319098457888495e-06, 'epoch': 2.74}


 93%|█████████▎| 780/843 [3:23:15<15:25, 14.69s/it]

{'loss': 0.3253, 'grad_norm': 7.780243873596191, 'learning_rate': 1.494661921708185e-06, 'epoch': 2.78}


 94%|█████████▎| 790/843 [3:25:52<13:26, 15.21s/it]

{'loss': 0.3185, 'grad_norm': 9.452860832214355, 'learning_rate': 1.257413997627521e-06, 'epoch': 2.81}


 95%|█████████▍| 800/843 [3:28:26<11:17, 15.75s/it]

{'loss': 0.3397, 'grad_norm': 16.044979095458984, 'learning_rate': 1.0201660735468566e-06, 'epoch': 2.85}



 95%|█████████▍| 800/843 [3:30:23<11:17, 15.75s/it]

{'eval_accuracy': 0.8268268268268268, 'eval_f1_score': 0.8267486723290097, 'eval_roc_auc': 0.8265622744984846, 'eval_loss': 0.4538688063621521, 'eval_runtime': 116.7521, 'eval_samples_per_second': 8.557, 'eval_steps_per_second': 0.137, 'epoch': 2.85}


 96%|█████████▌| 810/843 [3:32:57<08:55, 16.23s/it]

{'loss': 0.291, 'grad_norm': 18.725431442260742, 'learning_rate': 7.829181494661923e-07, 'epoch': 2.88}


 97%|█████████▋| 820/843 [3:35:26<05:41, 14.83s/it]

{'loss': 0.222, 'grad_norm': 10.88500690460205, 'learning_rate': 5.456702253855279e-07, 'epoch': 2.92}


 98%|█████████▊| 830/843 [3:37:56<03:15, 15.05s/it]

{'loss': 0.3089, 'grad_norm': 16.558542251586914, 'learning_rate': 3.084223013048636e-07, 'epoch': 2.95}


100%|█████████▉| 840/843 [3:40:31<00:46, 15.48s/it]

{'loss': 0.3196, 'grad_norm': 12.470589637756348, 'learning_rate': 7.117437722419929e-08, 'epoch': 2.99}


100%|██████████| 843/843 [3:41:14<00:00, 15.75s/it]

{'train_runtime': 13274.5988, 'train_samples_per_second': 2.031, 'train_steps_per_second': 0.064, 'train_loss': 0.38610544111380796, 'epoch': 3.0}





In [80]:
model_path = "./models/model_roberta"
tokenizer_path = "./models/tokenizer_roberta"

model.save_pretrained(model_path)
tokenizer.save_pretrained(tokenizer_path)

('./models/tokenizer_roberta/tokenizer_config.json',
 './models/tokenizer_roberta/special_tokens_map.json',
 './models/tokenizer_roberta/vocab.json',
 './models/tokenizer_roberta/merges.txt',
 './models/tokenizer_roberta/added_tokens.json')