In [1]:
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments, GlueDataset, default_data_collator
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import mlflow
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data = pd.read_csv("../data/train_df.csv")
data = data.dropna(subset=['words'])
data = data.sample(10000)

In [5]:
from transformers import RobertaTokenizer

In [6]:
tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True)
max_len = 128

In [7]:
def encode_text(texts, tokenizer, max_len):
    input_ids = []
    attention_masks = []

    for text in texts:
        encoded_dict = tokenizer.encode_plus(
                            text,
                            add_special_tokens=True,
                            max_length=max_len,
                            padding='max_length',
                            return_attention_mask=True,
                            return_tensors='pt',
                       )
        input_ids.append(encoded_dict['input_ids'])
        attention_masks.append(encoded_dict['attention_mask'])

    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)

    return input_ids, attention_masks

In [9]:
input_ids, attention_masks = encode_text(data['text'], tokenizer, max_len)
train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(input_ids, data['target'].values, test_size=0.1, random_state=42)
train_masks, validation_masks, _, _ = train_test_split(attention_masks, input_ids, test_size=0.1, random_state=42)

In [10]:
train_inputs = torch.tensor(train_inputs)
validation_inputs = torch.tensor(validation_inputs)
train_labels = torch.tensor(train_labels)
validation_labels = torch.tensor(validation_labels)
train_masks = torch.tensor(train_masks)
validation_masks = torch.tensor(validation_masks)

train_dataset = TensorDataset(train_inputs, train_masks, train_labels)
eval_dataset = TensorDataset(validation_inputs, validation_masks, validation_labels)

  train_inputs = torch.tensor(train_inputs)
  validation_inputs = torch.tensor(validation_inputs)
  train_masks = torch.tensor(train_masks)
  validation_masks = torch.tensor(validation_masks)


In [11]:
from transformers import RobertaForSequenceClassification

model = RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=2)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
def my_data_collator(features):
    batch = {}
    batch['input_ids'] = torch.stack([f[0] for f in features])
    batch['attention_mask'] = torch.stack([f[1] for f in features])
    batch['labels'] = torch.tensor([f[2] for f in features])
    return batch

In [13]:
training_args = TrainingArguments(
    output_dir='./results_roberta',
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    logging_dir='./logs_roberta',
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    gradient_accumulation_steps=1,
    learning_rate=2e-5,
    warmup_steps=0,
    weight_decay=0.01,
    seed=42,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True
)

In [14]:
from sklearn.metrics import accuracy_score

def compute_metrics(eval_predictions):
    labels = eval_predictions.label_ids
    preds = eval_predictions.predictions.argmax(-1)
    accuracy = accuracy_score(labels, preds)
    return {
        'eval_accuracy': accuracy,
    }

In [15]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=my_data_collator,
    compute_metrics=compute_metrics
)

In [16]:
trainer.train()

  1%|          | 10/846 [02:50<3:47:52, 16.36s/it]

{'loss': 0.6884, 'grad_norm': 1.1090093851089478, 'learning_rate': 1.9763593380614657e-05, 'epoch': 0.04}


  2%|▏         | 20/846 [05:29<3:42:47, 16.18s/it]

{'loss': 0.6939, 'grad_norm': 0.974604606628418, 'learning_rate': 1.9527186761229316e-05, 'epoch': 0.07}


  4%|▎         | 30/846 [08:03<3:26:20, 15.17s/it]

{'loss': 0.6685, 'grad_norm': 2.414374351501465, 'learning_rate': 1.929078014184397e-05, 'epoch': 0.11}


  5%|▍         | 40/846 [10:42<3:34:06, 15.94s/it]

{'loss': 0.5453, 'grad_norm': 22.607778549194336, 'learning_rate': 1.905437352245863e-05, 'epoch': 0.14}


  6%|▌         | 50/846 [13:24<3:28:44, 15.73s/it]

{'loss': 0.5625, 'grad_norm': 50.974918365478516, 'learning_rate': 1.881796690307329e-05, 'epoch': 0.18}


  7%|▋         | 60/846 [16:05<3:30:37, 16.08s/it]

{'loss': 0.4962, 'grad_norm': 7.707003593444824, 'learning_rate': 1.8581560283687945e-05, 'epoch': 0.21}


  8%|▊         | 70/846 [18:38<3:16:57, 15.23s/it]

{'loss': 0.4909, 'grad_norm': 10.754965782165527, 'learning_rate': 1.83451536643026e-05, 'epoch': 0.25}


  9%|▉         | 80/846 [21:26<3:35:16, 16.86s/it]

{'loss': 0.5002, 'grad_norm': 11.216720581054688, 'learning_rate': 1.810874704491726e-05, 'epoch': 0.28}


 11%|█         | 90/846 [24:03<3:18:59, 15.79s/it]

{'loss': 0.5276, 'grad_norm': 18.873077392578125, 'learning_rate': 1.7872340425531915e-05, 'epoch': 0.32}


 12%|█▏        | 100/846 [26:43<3:15:40, 15.74s/it]

{'loss': 0.5032, 'grad_norm': 10.724949836730957, 'learning_rate': 1.7635933806146574e-05, 'epoch': 0.35}


                                                   
 12%|█▏        | 100/846 [28:37<3:15:40, 15.74s/it]

{'eval_accuracy': 0.8, 'eval_loss': 0.4566268026828766, 'eval_runtime': 114.1802, 'eval_samples_per_second': 8.758, 'eval_steps_per_second': 0.14, 'epoch': 0.35}


 13%|█▎        | 110/846 [31:24<3:40:43, 17.99s/it] 

{'loss': 0.531, 'grad_norm': 8.861882209777832, 'learning_rate': 1.7399527186761233e-05, 'epoch': 0.39}


 14%|█▍        | 120/846 [34:02<3:12:00, 15.87s/it]

{'loss': 0.4344, 'grad_norm': 9.610867500305176, 'learning_rate': 1.716312056737589e-05, 'epoch': 0.43}


 15%|█▌        | 130/846 [36:39<3:08:37, 15.81s/it]

{'loss': 0.4892, 'grad_norm': 10.273508071899414, 'learning_rate': 1.6926713947990544e-05, 'epoch': 0.46}


 17%|█▋        | 140/846 [39:10<2:57:51, 15.12s/it]

{'loss': 0.5077, 'grad_norm': 8.790712356567383, 'learning_rate': 1.6690307328605203e-05, 'epoch': 0.5}


 18%|█▊        | 150/846 [41:47<3:03:13, 15.80s/it]

{'loss': 0.4759, 'grad_norm': 8.421916007995605, 'learning_rate': 1.645390070921986e-05, 'epoch': 0.53}


 19%|█▉        | 160/846 [44:21<2:54:34, 15.27s/it]

{'loss': 0.4829, 'grad_norm': 21.467042922973633, 'learning_rate': 1.6217494089834514e-05, 'epoch': 0.57}


 20%|██        | 170/846 [46:54<2:50:21, 15.12s/it]

{'loss': 0.4292, 'grad_norm': 8.956393241882324, 'learning_rate': 1.5981087470449176e-05, 'epoch': 0.6}


 21%|██▏       | 180/846 [49:22<2:43:27, 14.73s/it]

{'loss': 0.4326, 'grad_norm': 6.977697372436523, 'learning_rate': 1.5744680851063832e-05, 'epoch': 0.64}


 22%|██▏       | 190/846 [52:03<2:58:16, 16.31s/it]

{'loss': 0.5117, 'grad_norm': 21.52483558654785, 'learning_rate': 1.5508274231678487e-05, 'epoch': 0.67}


 24%|██▎       | 200/846 [54:39<2:49:46, 15.77s/it]

{'loss': 0.423, 'grad_norm': 6.012157440185547, 'learning_rate': 1.5271867612293146e-05, 'epoch': 0.71}


                                                   
 24%|██▎       | 200/846 [56:35<2:49:46, 15.77s/it]

{'eval_accuracy': 0.817, 'eval_loss': 0.4185546040534973, 'eval_runtime': 115.5729, 'eval_samples_per_second': 8.653, 'eval_steps_per_second': 0.138, 'epoch': 0.71}


 25%|██▍       | 210/846 [59:13<2:56:03, 16.61s/it]

{'loss': 0.4394, 'grad_norm': 15.731389999389648, 'learning_rate': 1.5035460992907802e-05, 'epoch': 0.74}


 26%|██▌       | 220/846 [1:01:52<2:49:43, 16.27s/it]

{'loss': 0.507, 'grad_norm': 9.592254638671875, 'learning_rate': 1.4799054373522459e-05, 'epoch': 0.78}


 27%|██▋       | 230/846 [1:04:29<2:39:44, 15.56s/it]

{'loss': 0.3866, 'grad_norm': 10.325312614440918, 'learning_rate': 1.4562647754137118e-05, 'epoch': 0.82}


 28%|██▊       | 240/846 [1:07:09<2:38:52, 15.73s/it]

{'loss': 0.4508, 'grad_norm': 17.865659713745117, 'learning_rate': 1.4326241134751775e-05, 'epoch': 0.85}


 30%|██▉       | 250/846 [1:09:40<2:31:11, 15.22s/it]

{'loss': 0.4019, 'grad_norm': 14.540098190307617, 'learning_rate': 1.4089834515366433e-05, 'epoch': 0.89}


 31%|███       | 260/846 [1:12:17<2:31:50, 15.55s/it]

{'loss': 0.4333, 'grad_norm': 8.176437377929688, 'learning_rate': 1.3853427895981088e-05, 'epoch': 0.92}


 32%|███▏      | 270/846 [1:14:52<2:30:01, 15.63s/it]

{'loss': 0.446, 'grad_norm': 8.045886039733887, 'learning_rate': 1.3617021276595745e-05, 'epoch': 0.96}


 33%|███▎      | 280/846 [1:17:25<2:23:43, 15.24s/it]

{'loss': 0.4006, 'grad_norm': 8.094914436340332, 'learning_rate': 1.3380614657210403e-05, 'epoch': 0.99}


 34%|███▍      | 290/846 [1:19:52<2:22:06, 15.33s/it]

{'loss': 0.3768, 'grad_norm': 9.639501571655273, 'learning_rate': 1.314420803782506e-05, 'epoch': 1.03}


 35%|███▌      | 300/846 [1:22:31<2:26:24, 16.09s/it]

{'loss': 0.3608, 'grad_norm': 10.651742935180664, 'learning_rate': 1.2907801418439719e-05, 'epoch': 1.06}


                                                     
 35%|███▌      | 300/846 [1:24:24<2:26:24, 16.09s/it]

{'eval_accuracy': 0.815, 'eval_loss': 0.4500351548194885, 'eval_runtime': 113.0467, 'eval_samples_per_second': 8.846, 'eval_steps_per_second': 0.142, 'epoch': 1.06}


 37%|███▋      | 310/846 [1:27:05<2:32:21, 17.06s/it]

{'loss': 0.3231, 'grad_norm': 7.238884449005127, 'learning_rate': 1.2671394799054376e-05, 'epoch': 1.1}


 38%|███▊      | 320/846 [1:29:41<2:14:29, 15.34s/it]

{'loss': 0.3224, 'grad_norm': 8.712160110473633, 'learning_rate': 1.2434988179669031e-05, 'epoch': 1.13}


 39%|███▉      | 330/846 [1:32:14<2:11:02, 15.24s/it]

{'loss': 0.3085, 'grad_norm': 12.457151412963867, 'learning_rate': 1.2198581560283689e-05, 'epoch': 1.17}


 40%|████      | 340/846 [1:34:43<2:04:41, 14.79s/it]

{'loss': 0.3629, 'grad_norm': 14.665185928344727, 'learning_rate': 1.1962174940898346e-05, 'epoch': 1.21}


 41%|████▏     | 350/846 [1:37:16<2:04:53, 15.11s/it]

{'loss': 0.3725, 'grad_norm': 9.554437637329102, 'learning_rate': 1.1725768321513003e-05, 'epoch': 1.24}


 43%|████▎     | 360/846 [1:39:47<2:03:23, 15.23s/it]

{'loss': 0.3544, 'grad_norm': 7.312509059906006, 'learning_rate': 1.1489361702127662e-05, 'epoch': 1.28}


 44%|████▎     | 370/846 [1:42:17<1:56:25, 14.67s/it]

{'loss': 0.3356, 'grad_norm': 12.506765365600586, 'learning_rate': 1.1252955082742318e-05, 'epoch': 1.31}


 45%|████▍     | 380/846 [1:44:45<1:54:40, 14.77s/it]

{'loss': 0.3674, 'grad_norm': 11.866743087768555, 'learning_rate': 1.1016548463356975e-05, 'epoch': 1.35}


 46%|████▌     | 390/846 [1:47:10<1:50:37, 14.56s/it]

{'loss': 0.3047, 'grad_norm': 16.599979400634766, 'learning_rate': 1.0780141843971632e-05, 'epoch': 1.38}


 47%|████▋     | 400/846 [1:49:35<1:48:06, 14.54s/it]

{'loss': 0.3881, 'grad_norm': 11.969459533691406, 'learning_rate': 1.054373522458629e-05, 'epoch': 1.42}


                                                     
 47%|████▋     | 400/846 [1:51:26<1:48:06, 14.54s/it]

{'eval_accuracy': 0.825, 'eval_loss': 0.42732781171798706, 'eval_runtime': 110.6843, 'eval_samples_per_second': 9.035, 'eval_steps_per_second': 0.145, 'epoch': 1.42}


 48%|████▊     | 410/846 [1:54:06<2:00:53, 16.64s/it]

{'loss': 0.3487, 'grad_norm': 13.631902694702148, 'learning_rate': 1.0307328605200947e-05, 'epoch': 1.45}


 50%|████▉     | 420/846 [1:56:33<1:43:40, 14.60s/it]

{'loss': 0.3361, 'grad_norm': 16.819847106933594, 'learning_rate': 1.0070921985815602e-05, 'epoch': 1.49}


 51%|█████     | 430/846 [1:59:03<1:40:30, 14.50s/it]

{'loss': 0.3102, 'grad_norm': 14.549715995788574, 'learning_rate': 9.834515366430261e-06, 'epoch': 1.52}


 52%|█████▏    | 440/846 [2:01:30<1:38:17, 14.53s/it]

{'loss': 0.358, 'grad_norm': 9.97203540802002, 'learning_rate': 9.598108747044918e-06, 'epoch': 1.56}


 53%|█████▎    | 450/846 [2:03:56<1:37:38, 14.79s/it]

{'loss': 0.326, 'grad_norm': 6.144720554351807, 'learning_rate': 9.361702127659576e-06, 'epoch': 1.6}


 54%|█████▍    | 460/846 [2:06:22<1:32:24, 14.36s/it]

{'loss': 0.3218, 'grad_norm': 13.620101928710938, 'learning_rate': 9.125295508274233e-06, 'epoch': 1.63}


 56%|█████▌    | 470/846 [2:08:50<1:31:22, 14.58s/it]

{'loss': 0.3678, 'grad_norm': 11.17308235168457, 'learning_rate': 8.888888888888888e-06, 'epoch': 1.67}


 57%|█████▋    | 480/846 [2:11:20<1:32:22, 15.14s/it]

{'loss': 0.3857, 'grad_norm': 8.618144989013672, 'learning_rate': 8.652482269503547e-06, 'epoch': 1.7}


 58%|█████▊    | 490/846 [2:13:47<1:25:49, 14.47s/it]

{'loss': 0.3641, 'grad_norm': 9.569772720336914, 'learning_rate': 8.416075650118204e-06, 'epoch': 1.74}


 59%|█████▉    | 500/846 [2:16:19<1:28:05, 15.28s/it]

{'loss': 0.3854, 'grad_norm': 14.132160186767578, 'learning_rate': 8.17966903073286e-06, 'epoch': 1.77}


                                                     
 59%|█████▉    | 500/846 [2:18:06<1:28:05, 15.28s/it]

{'eval_accuracy': 0.811, 'eval_loss': 0.43015334010124207, 'eval_runtime': 107.4451, 'eval_samples_per_second': 9.307, 'eval_steps_per_second': 0.149, 'epoch': 1.77}


 60%|██████    | 510/846 [2:20:34<1:28:11, 15.75s/it]

{'loss': 0.3941, 'grad_norm': 18.142732620239258, 'learning_rate': 7.943262411347519e-06, 'epoch': 1.81}


 61%|██████▏   | 520/846 [2:23:07<1:21:22, 14.98s/it]

{'loss': 0.3823, 'grad_norm': 12.111699104309082, 'learning_rate': 7.706855791962176e-06, 'epoch': 1.84}


 63%|██████▎   | 530/846 [2:25:29<1:13:20, 13.92s/it]

{'loss': 0.3545, 'grad_norm': 10.79580307006836, 'learning_rate': 7.4704491725768326e-06, 'epoch': 1.88}


 64%|██████▍   | 540/846 [2:27:54<1:14:40, 14.64s/it]

{'loss': 0.3382, 'grad_norm': 9.140817642211914, 'learning_rate': 7.234042553191491e-06, 'epoch': 1.91}


 65%|██████▌   | 550/846 [2:30:21<1:14:09, 15.03s/it]

{'loss': 0.3712, 'grad_norm': 15.024503707885742, 'learning_rate': 6.997635933806147e-06, 'epoch': 1.95}


 66%|██████▌   | 560/846 [2:32:50<1:10:07, 14.71s/it]

{'loss': 0.3784, 'grad_norm': 9.64501953125, 'learning_rate': 6.761229314420804e-06, 'epoch': 1.99}


 67%|██████▋   | 570/846 [2:34:56<1:00:26, 13.14s/it]

{'loss': 0.2801, 'grad_norm': 6.224301338195801, 'learning_rate': 6.524822695035462e-06, 'epoch': 2.02}


 69%|██████▊   | 580/846 [2:37:23<1:05:07, 14.69s/it]

{'loss': 0.2931, 'grad_norm': 19.75403594970703, 'learning_rate': 6.288416075650119e-06, 'epoch': 2.06}


 70%|██████▉   | 590/846 [2:39:47<1:00:16, 14.13s/it]

{'loss': 0.2786, 'grad_norm': 9.090082168579102, 'learning_rate': 6.052009456264776e-06, 'epoch': 2.09}


 71%|███████   | 600/846 [2:42:15<58:59, 14.39s/it]  

{'loss': 0.2375, 'grad_norm': 14.861852645874023, 'learning_rate': 5.815602836879432e-06, 'epoch': 2.13}


                                                   
 71%|███████   | 600/846 [2:44:04<58:59, 14.39s/it]

{'eval_accuracy': 0.823, 'eval_loss': 0.4972337782382965, 'eval_runtime': 108.7097, 'eval_samples_per_second': 9.199, 'eval_steps_per_second': 0.147, 'epoch': 2.13}


 72%|███████▏  | 610/846 [2:46:35<1:02:53, 15.99s/it]

{'loss': 0.3022, 'grad_norm': 21.797039031982422, 'learning_rate': 5.5791962174940904e-06, 'epoch': 2.16}


 73%|███████▎  | 620/846 [2:48:56<52:43, 14.00s/it]  

{'loss': 0.3379, 'grad_norm': 16.574377059936523, 'learning_rate': 5.342789598108748e-06, 'epoch': 2.2}


 74%|███████▍  | 630/846 [2:51:24<53:41, 14.91s/it]

{'loss': 0.2617, 'grad_norm': 13.426061630249023, 'learning_rate': 5.106382978723404e-06, 'epoch': 2.23}


 76%|███████▌  | 640/846 [2:53:52<49:25, 14.39s/it]

{'loss': 0.2944, 'grad_norm': 8.820116996765137, 'learning_rate': 4.869976359338061e-06, 'epoch': 2.27}


 77%|███████▋  | 650/846 [2:56:15<47:12, 14.45s/it]

{'loss': 0.2848, 'grad_norm': 14.37353515625, 'learning_rate': 4.633569739952719e-06, 'epoch': 2.3}


 78%|███████▊  | 660/846 [2:58:39<44:48, 14.45s/it]

{'loss': 0.2801, 'grad_norm': 14.839713096618652, 'learning_rate': 4.397163120567377e-06, 'epoch': 2.34}


 79%|███████▉  | 670/846 [3:01:04<44:22, 15.13s/it]

{'loss': 0.2336, 'grad_norm': 14.682151794433594, 'learning_rate': 4.160756501182033e-06, 'epoch': 2.38}


 80%|████████  | 680/846 [3:03:25<37:45, 13.65s/it]

{'loss': 0.2737, 'grad_norm': 13.095479011535645, 'learning_rate': 3.924349881796691e-06, 'epoch': 2.41}


 82%|████████▏ | 690/846 [3:06:07<40:18, 15.50s/it]

{'loss': 0.2092, 'grad_norm': 17.125879287719727, 'learning_rate': 3.6879432624113475e-06, 'epoch': 2.45}


 83%|████████▎ | 700/846 [3:08:43<37:34, 15.44s/it]

{'loss': 0.2621, 'grad_norm': 13.172910690307617, 'learning_rate': 3.451536643026005e-06, 'epoch': 2.48}


                                                   
 83%|████████▎ | 700/846 [3:10:32<37:34, 15.44s/it]

{'eval_accuracy': 0.815, 'eval_loss': 0.5109070539474487, 'eval_runtime': 109.5937, 'eval_samples_per_second': 9.125, 'eval_steps_per_second': 0.146, 'epoch': 2.48}


 84%|████████▍ | 710/846 [3:13:13<37:44, 16.65s/it]  

{'loss': 0.2855, 'grad_norm': 8.693255424499512, 'learning_rate': 3.2151300236406624e-06, 'epoch': 2.52}


 85%|████████▌ | 720/846 [3:15:46<32:12, 15.34s/it]

{'loss': 0.2466, 'grad_norm': 9.269671440124512, 'learning_rate': 2.978723404255319e-06, 'epoch': 2.55}


 86%|████████▋ | 730/846 [3:18:14<28:38, 14.82s/it]

{'loss': 0.3047, 'grad_norm': 12.865771293640137, 'learning_rate': 2.742316784869977e-06, 'epoch': 2.59}


 87%|████████▋ | 740/846 [3:20:44<26:22, 14.93s/it]

{'loss': 0.295, 'grad_norm': 29.080312728881836, 'learning_rate': 2.5059101654846336e-06, 'epoch': 2.62}


 89%|████████▊ | 750/846 [3:23:20<25:53, 16.19s/it]

{'loss': 0.3128, 'grad_norm': 10.607925415039062, 'learning_rate': 2.269503546099291e-06, 'epoch': 2.66}


 90%|████████▉ | 760/846 [3:25:53<22:05, 15.41s/it]

{'loss': 0.3132, 'grad_norm': 17.867778778076172, 'learning_rate': 2.033096926713948e-06, 'epoch': 2.7}


 91%|█████████ | 770/846 [3:28:28<19:47, 15.63s/it]

{'loss': 0.2255, 'grad_norm': 21.018362045288086, 'learning_rate': 1.7966903073286054e-06, 'epoch': 2.73}


 92%|█████████▏| 780/846 [3:31:06<17:37, 16.03s/it]

{'loss': 0.272, 'grad_norm': 17.45999526977539, 'learning_rate': 1.5602836879432626e-06, 'epoch': 2.77}


 93%|█████████▎| 790/846 [3:33:39<14:03, 15.07s/it]

{'loss': 0.2654, 'grad_norm': 13.217137336730957, 'learning_rate': 1.3238770685579196e-06, 'epoch': 2.8}


 95%|█████████▍| 800/846 [3:36:06<11:23, 14.85s/it]

{'loss': 0.2713, 'grad_norm': 16.339519500732422, 'learning_rate': 1.087470449172577e-06, 'epoch': 2.84}


                                                   
 95%|█████████▍| 800/846 [3:37:58<11:23, 14.85s/it]

{'eval_accuracy': 0.815, 'eval_loss': 0.48287707567214966, 'eval_runtime': 112.0428, 'eval_samples_per_second': 8.925, 'eval_steps_per_second': 0.143, 'epoch': 2.84}


 96%|█████████▌| 810/846 [3:40:36<09:48, 16.34s/it]

{'loss': 0.2385, 'grad_norm': 13.355484962463379, 'learning_rate': 8.510638297872341e-07, 'epoch': 2.87}


 97%|█████████▋| 820/846 [3:43:08<06:29, 15.00s/it]

{'loss': 0.2478, 'grad_norm': 13.705936431884766, 'learning_rate': 6.146572104018913e-07, 'epoch': 2.91}


 98%|█████████▊| 830/846 [3:45:35<03:52, 14.53s/it]

{'loss': 0.2291, 'grad_norm': 8.59310531616211, 'learning_rate': 3.782505910165485e-07, 'epoch': 2.94}


 99%|█████████▉| 840/846 [3:48:04<01:27, 14.63s/it]

{'loss': 0.2581, 'grad_norm': 14.951705932617188, 'learning_rate': 1.4184397163120568e-07, 'epoch': 2.98}


100%|██████████| 846/846 [3:49:23<00:00, 16.27s/it]

{'train_runtime': 13763.1391, 'train_samples_per_second': 1.962, 'train_steps_per_second': 0.061, 'train_loss': 0.3720646234268838, 'epoch': 3.0}





TrainOutput(global_step=846, training_loss=0.3720646234268838, metrics={'train_runtime': 13763.1391, 'train_samples_per_second': 1.962, 'train_steps_per_second': 0.061, 'total_flos': 1775999623680000.0, 'train_loss': 0.3720646234268838, 'epoch': 3.0})

In [17]:
model_path = "./model_roberta"
tokenizer_path = "./tokenizer_roberta"

model.save_pretrained(model_path)
tokenizer.save_pretrained(tokenizer_path)

('./tokenizer_roberta/tokenizer_config.json',
 './tokenizer_roberta/special_tokens_map.json',
 './tokenizer_roberta/vocab.json',
 './tokenizer_roberta/merges.txt',
 './tokenizer_roberta/added_tokens.json')

In [18]:
tokenizer_path = "./tokenizer_roberta"
model_path = "./model_roberta"

tokenizer = RobertaTokenizer.from_pretrained(tokenizer_path)
model = RobertaForSequenceClassification.from_pretrained(model_path)

In [19]:
from transformers import pipeline

classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
text = "Hello, how are you?"

result = classifier(text)
print(result)

[{'label': 'LABEL_1', 'score': 0.9570664763450623}]


In [20]:
text = "I want you to die"

result = classifier(text)
print(result)

[{'label': 'LABEL_0', 'score': 0.9770116209983826}]


In [None]:
import pandas as pd
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc, confusion_matrix
import matplotlib.pyplot as plt
import torch
from transformers import BertTokenizer, BertForSequenceClassification

data = pd.read_csv("../Data/train_df.csv")
data = data.dropna(subset=['words'])
data = data.sample(1000)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained("chemin/vers/votre/modele")

# Prétraiter les données de test
test_texts = ["Hello, how are you?"]  # Mettez vos propres données de test ici
test_input_ids = []
test_attention_masks = []
for text in test_texts:
    encoded_dict = tokenizer.encode_plus(
                        text,
                        add_special_tokens=True,
                        max_length=128,
                        padding='max_length',
                        return_attention_mask=True,
                        return_tensors='pt',
                   )
    test_input_ids.append(encoded_dict['input_ids'])
    test_attention_masks.append(encoded_dict['attention_mask'])

test_input_ids = torch.cat(test_input_ids, dim=0)
test_attention_masks = torch.cat(test_attention_masks, dim=0)

# Effectuer les prédictions
with torch.no_grad():
    outputs = model(test_input_ids, attention_mask=test_attention_masks)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=1)

# Calculer les métriques de performance
accuracy = accuracy_score(data['target'], predictions)
precision = precision_score(data['target'], predictions)
recall = recall_score(data['target'], predictions)
f1 = f1_score(data['target'], predictions)
fpr, tpr, thresholds = roc_curve(data['target'], predictions)
roc_auc = auc(fpr, tpr)
conf_matrix = confusion_matrix(data['target'], predictions)

# Enregistrer les résultats dans MLflow
import mlflow

artifact_path = './artifacts/'

with mlflow.start_run(run_name="bert-base-uncased"):
    mlflow.log_metric("accuracy", accuracy)
    mlflow.log_metric("Precision", precision)
    mlflow.log_metric("Recall", recall)
    mlflow.log_metric("F1_Score", f1)
    mlflow.log_metric("AUC", roc_auc)

    conf_matrix_path = f"{artifact_path}confusion_matrix.csv"
    pd.DataFrame(conf_matrix).to_csv(conf_matrix_path, index=False, header=False)
    mlflow.log_artifact(conf_matrix_path, "metrics")

    plt.figure(figsize=(8, 8))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = {:.2f})'.format(roc_auc))
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.grid(True)
    roc_curve_path = f"{artifact_path}roc_curve.png"
    plt.savefig(roc_curve_path)
    plt.close()
    mlflow.log_artifact(roc_curve_path, "plots")
