In [1]:
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments, GlueDataset, default_data_collator
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import mlflow
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the labelled corpus and keep a reproducible subsample for fine-tuning.
# Rows without text cannot be tokenized, so they are dropped first.
data = pd.read_csv("../data/train_df.csv")
data = data.dropna(subset=['words'])
# Fixed random_state so Restart & Run All reproduces the same training sample;
# min() avoids a ValueError when fewer than 10000 usable rows remain.
data = data.sample(n=min(10000, len(data)), random_state=42)

In [3]:
# Fixed sequence length (in tokens) used when encoding every text.
max_len = 128
# WordPiece tokenizer matching the pre-trained 'bert-base-uncased' checkpoint.
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    do_lower_case=True,
)



In [4]:
def encode_text(texts, tokenizer, max_len):
    """Tokenize texts into fixed-length BERT input tensors.

    Args:
        texts: Iterable of strings (e.g. a pandas Series).
        tokenizer: Object exposing a Hugging Face style ``encode_plus`` method.
        max_len: Number of tokens every encoded sequence is padded/truncated to.

    Returns:
        ``(input_ids, attention_masks)``: two tensors of shape
        ``(len(texts), max_len)``. When ``texts`` is empty, both are empty
        ``(0, max_len)`` integer tensors instead of raising.
    """
    input_ids = []
    attention_masks = []

    for text in texts:
        encoded_dict = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=max_len,
            padding='max_length',
            # Explicit truncation: without it, texts longer than max_len depend on an
            # implicit fallback (which only warns) to keep every row max_len long,
            # and torch.cat below needs equal-length rows.
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        input_ids.append(encoded_dict['input_ids'])
        attention_masks.append(encoded_dict['attention_mask'])

    if not input_ids:
        # torch.cat([]) raises; return correctly shaped empty tensors instead.
        empty = torch.empty((0, max_len), dtype=torch.long)
        return empty, empty.clone()

    return torch.cat(input_ids, dim=0), torch.cat(attention_masks, dim=0)

In [5]:
input_ids, attention_masks = encode_text(data['words'], tokenizer, max_len)
# Split ids, masks and labels in ONE call so the same shuffled indices are applied
# to all three (previously the masks were split separately, with input_ids as a
# dummy target). 10% of the rows are held out for validation.
(
    train_inputs, validation_inputs,
    train_masks, validation_masks,
    train_labels, validation_labels,
) = train_test_split(
    input_ids, attention_masks, data['target'].values,
    test_size=0.1, random_state=42,
)

In [6]:
# The split returns tensors for ids/masks (and a NumPy array for labels).
# torch.as_tensor converts only what still needs converting, which avoids the
# "copy construct from a tensor" UserWarning that torch.tensor(tensor) emits.
train_inputs = torch.as_tensor(train_inputs)
validation_inputs = torch.as_tensor(validation_inputs)
train_masks = torch.as_tensor(train_masks)
validation_masks = torch.as_tensor(validation_masks)
# Integer class ids for the 2-label classifier (num_labels=2).
train_labels = torch.as_tensor(train_labels, dtype=torch.long)
validation_labels = torch.as_tensor(validation_labels, dtype=torch.long)

# Each dataset item is an (input_ids, attention_mask, label) tuple, the layout
# my_data_collator indexes as f[0], f[1], f[2].
train_dataset = TensorDataset(train_inputs, train_masks, train_labels)
eval_dataset = TensorDataset(validation_inputs, validation_masks, validation_labels)

  train_inputs = torch.tensor(train_inputs)
  validation_inputs = torch.tensor(validation_inputs)
  train_masks = torch.tensor(train_masks)
  validation_masks = torch.tensor(validation_masks)


In [7]:
# Pre-trained BERT encoder with a 2-way classification head; the head's
# `classifier.*` weights start untrained and are learned during fine-tuning.
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=2,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
def my_data_collator(features):
    """Collate ``(input_ids, attention_mask, label)`` tuples into a Trainer batch.

    Stacks the per-example id and mask tensors along a new leading batch axis and
    gathers the labels into a 1-D tensor, returned under the keyword names the
    model's forward() expects.
    """
    stacked_ids = torch.stack([example[0] for example in features])
    stacked_masks = torch.stack([example[1] for example in features])
    label_tensor = torch.tensor([example[2] for example in features])
    return {
        'input_ids': stacked_ids,
        'attention_mask': stacked_masks,
        'labels': label_tensor,
    }

In [9]:
# Fine-tuning hyper-parameters, grouped by purpose.
training_args = TrainingArguments(
    # Output locations and logging cadence.
    output_dir='./results',
    logging_dir='./logs',
    logging_steps=10,
    # Optimisation schedule.
    num_train_epochs=3,
    learning_rate=2e-5,
    warmup_steps=0,
    weight_decay=0.01,
    gradient_accumulation_steps=1,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    seed=42,
    # Evaluate and checkpoint every 100 steps (keeping at most 2 checkpoints),
    # then restore the checkpoint with the highest validation accuracy.
    evaluation_strategy="steps",
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
)

In [10]:
from sklearn.metrics import accuracy_score

def compute_metrics(eval_predictions):
    """Trainer metric callback: accuracy of the arg-max class vs. the gold labels.

    ``eval_predictions`` must expose ``predictions`` (per-class scores along the
    last axis) and ``label_ids``. Returns ``{'eval_accuracy': <float>}``.
    """
    gold_labels = eval_predictions.label_ids
    predicted_classes = eval_predictions.predictions.argmax(-1)
    return {'eval_accuracy': accuracy_score(gold_labels, predicted_classes)}

In [11]:
# Wire the model, hyper-parameters, tensor datasets, custom collator and metric together.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=my_data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

In [12]:
# Fine-tune the model. With the training arguments configured earlier, this evaluates
# every 100 steps and (load_best_model_at_end=True) restores the checkpoint with
# the best validation accuracy when training finishes.
trainer.train()

  1%|          | 10/846 [02:23<3:14:27, 13.96s/it]

{'loss': 0.7017, 'grad_norm': 5.859342098236084, 'learning_rate': 1.9763593380614657e-05, 'epoch': 0.04}


  2%|▏         | 20/846 [04:49<3:13:45, 14.07s/it]

{'loss': 0.6969, 'grad_norm': 2.4730896949768066, 'learning_rate': 1.9527186761229316e-05, 'epoch': 0.07}


  4%|▎         | 30/846 [07:13<3:10:55, 14.04s/it]

{'loss': 0.6869, 'grad_norm': 3.1644699573516846, 'learning_rate': 1.929078014184397e-05, 'epoch': 0.11}


  5%|▍         | 40/846 [09:31<3:06:53, 13.91s/it]

{'loss': 0.6603, 'grad_norm': 4.728598117828369, 'learning_rate': 1.905437352245863e-05, 'epoch': 0.14}


  6%|▌         | 50/846 [11:45<2:58:00, 13.42s/it]

{'loss': 0.6257, 'grad_norm': 4.920960426330566, 'learning_rate': 1.881796690307329e-05, 'epoch': 0.18}


  7%|▋         | 60/846 [14:15<3:15:48, 14.95s/it]

{'loss': 0.6198, 'grad_norm': 6.188103675842285, 'learning_rate': 1.8581560283687945e-05, 'epoch': 0.21}


  8%|▊         | 70/846 [16:36<2:58:39, 13.81s/it]

{'loss': 0.5582, 'grad_norm': 5.564546585083008, 'learning_rate': 1.83451536643026e-05, 'epoch': 0.25}


  9%|▉         | 80/846 [18:57<2:58:14, 13.96s/it]

{'loss': 0.5831, 'grad_norm': 3.6173338890075684, 'learning_rate': 1.810874704491726e-05, 'epoch': 0.28}


 11%|█         | 90/846 [21:22<3:03:02, 14.53s/it]

{'loss': 0.5558, 'grad_norm': 3.4912805557250977, 'learning_rate': 1.7872340425531915e-05, 'epoch': 0.32}


 12%|█▏        | 100/846 [23:47<3:04:55, 14.87s/it]

{'loss': 0.5923, 'grad_norm': 8.15699291229248, 'learning_rate': 1.7635933806146574e-05, 'epoch': 0.35}


                                                   
 12%|█▏        | 100/846 [25:42<3:04:55, 14.87s/it]

{'eval_accuracy': 0.714, 'eval_loss': 0.5612783432006836, 'eval_runtime': 114.7588, 'eval_samples_per_second': 8.714, 'eval_steps_per_second': 0.139, 'epoch': 0.35}


 13%|█▎        | 110/846 [28:10<3:12:26, 15.69s/it] 

{'loss': 0.5625, 'grad_norm': 4.969337463378906, 'learning_rate': 1.7399527186761233e-05, 'epoch': 0.39}


 14%|█▍        | 120/846 [30:36<2:56:35, 14.59s/it]

{'loss': 0.562, 'grad_norm': 5.875092029571533, 'learning_rate': 1.716312056737589e-05, 'epoch': 0.43}


 15%|█▌        | 130/846 [33:05<2:56:57, 14.83s/it]

{'loss': 0.545, 'grad_norm': 4.597893238067627, 'learning_rate': 1.6926713947990544e-05, 'epoch': 0.46}


 17%|█▋        | 140/846 [35:33<2:50:22, 14.48s/it]

{'loss': 0.5687, 'grad_norm': 4.252157688140869, 'learning_rate': 1.6690307328605203e-05, 'epoch': 0.5}


 18%|█▊        | 150/846 [37:57<2:48:42, 14.54s/it]

{'loss': 0.5578, 'grad_norm': 7.474842548370361, 'learning_rate': 1.645390070921986e-05, 'epoch': 0.53}


 19%|█▉        | 160/846 [40:17<2:39:04, 13.91s/it]

{'loss': 0.5037, 'grad_norm': 3.815683603286743, 'learning_rate': 1.6217494089834514e-05, 'epoch': 0.57}


 20%|██        | 170/846 [42:39<2:39:25, 14.15s/it]

{'loss': 0.5014, 'grad_norm': 7.1823835372924805, 'learning_rate': 1.5981087470449176e-05, 'epoch': 0.6}


 21%|██▏       | 180/846 [45:08<2:45:22, 14.90s/it]

{'loss': 0.5378, 'grad_norm': 4.536266326904297, 'learning_rate': 1.5744680851063832e-05, 'epoch': 0.64}


 22%|██▏       | 190/846 [47:29<2:32:41, 13.97s/it]

{'loss': 0.4995, 'grad_norm': 4.914987564086914, 'learning_rate': 1.5508274231678487e-05, 'epoch': 0.67}


 24%|██▎       | 200/846 [49:49<2:30:03, 13.94s/it]

{'loss': 0.5125, 'grad_norm': 5.882621765136719, 'learning_rate': 1.5271867612293146e-05, 'epoch': 0.71}


                                                   
 24%|██▎       | 200/846 [51:43<2:30:03, 13.94s/it]

{'eval_accuracy': 0.735, 'eval_loss': 0.5377095937728882, 'eval_runtime': 114.5153, 'eval_samples_per_second': 8.732, 'eval_steps_per_second': 0.14, 'epoch': 0.71}


 25%|██▍       | 210/846 [54:17<2:51:41, 16.20s/it]

{'loss': 0.5009, 'grad_norm': 5.421145915985107, 'learning_rate': 1.5035460992907802e-05, 'epoch': 0.74}


 26%|██▌       | 220/846 [56:43<2:30:48, 14.45s/it]

{'loss': 0.4926, 'grad_norm': 8.922821044921875, 'learning_rate': 1.4799054373522459e-05, 'epoch': 0.78}


 27%|██▋       | 230/846 [59:09<2:40:26, 15.63s/it]

{'loss': 0.5474, 'grad_norm': 6.75559663772583, 'learning_rate': 1.4562647754137118e-05, 'epoch': 0.82}


 28%|██▊       | 240/846 [1:01:39<2:30:43, 14.92s/it]

{'loss': 0.4945, 'grad_norm': 5.767857074737549, 'learning_rate': 1.4326241134751775e-05, 'epoch': 0.85}


 30%|██▉       | 250/846 [1:04:09<2:29:41, 15.07s/it]

{'loss': 0.5037, 'grad_norm': 5.513686656951904, 'learning_rate': 1.4089834515366433e-05, 'epoch': 0.89}


 31%|███       | 260/846 [1:06:27<2:13:14, 13.64s/it]

{'loss': 0.5152, 'grad_norm': 4.3229756355285645, 'learning_rate': 1.3853427895981088e-05, 'epoch': 0.92}


 32%|███▏      | 270/846 [1:08:45<2:12:30, 13.80s/it]

{'loss': 0.5223, 'grad_norm': 4.262813091278076, 'learning_rate': 1.3617021276595745e-05, 'epoch': 0.96}


 33%|███▎      | 280/846 [1:11:08<2:11:02, 13.89s/it]

{'loss': 0.4891, 'grad_norm': 5.216915130615234, 'learning_rate': 1.3380614657210403e-05, 'epoch': 0.99}


 34%|███▍      | 290/846 [1:13:19<2:08:41, 13.89s/it]

{'loss': 0.4495, 'grad_norm': 4.099324703216553, 'learning_rate': 1.314420803782506e-05, 'epoch': 1.03}


 35%|███▌      | 300/846 [1:15:50<2:14:16, 14.75s/it]

{'loss': 0.4425, 'grad_norm': 6.34481143951416, 'learning_rate': 1.2907801418439719e-05, 'epoch': 1.06}


                                                     
 35%|███▌      | 300/846 [1:17:45<2:14:16, 14.75s/it]

{'eval_accuracy': 0.754, 'eval_loss': 0.5410595536231995, 'eval_runtime': 115.0089, 'eval_samples_per_second': 8.695, 'eval_steps_per_second': 0.139, 'epoch': 1.06}


 37%|███▋      | 310/846 [1:20:17<2:26:20, 16.38s/it]

{'loss': 0.4089, 'grad_norm': 4.115036487579346, 'learning_rate': 1.2671394799054376e-05, 'epoch': 1.1}


 38%|███▊      | 320/846 [1:22:44<2:10:54, 14.93s/it]

{'loss': 0.4707, 'grad_norm': 6.867002487182617, 'learning_rate': 1.2434988179669031e-05, 'epoch': 1.13}


 39%|███▉      | 330/846 [1:25:12<2:04:51, 14.52s/it]

{'loss': 0.4136, 'grad_norm': 5.67103910446167, 'learning_rate': 1.2198581560283689e-05, 'epoch': 1.17}


 40%|████      | 340/846 [1:27:40<2:05:35, 14.89s/it]

{'loss': 0.3994, 'grad_norm': 7.666390419006348, 'learning_rate': 1.1962174940898346e-05, 'epoch': 1.21}


 41%|████▏     | 350/846 [1:30:17<2:11:11, 15.87s/it]

{'loss': 0.4189, 'grad_norm': 6.728829860687256, 'learning_rate': 1.1725768321513003e-05, 'epoch': 1.24}


 43%|████▎     | 360/846 [1:32:45<2:02:36, 15.14s/it]

{'loss': 0.4385, 'grad_norm': 3.3561294078826904, 'learning_rate': 1.1489361702127662e-05, 'epoch': 1.28}


 44%|████▎     | 370/846 [1:35:14<1:56:06, 14.64s/it]

{'loss': 0.5113, 'grad_norm': 6.3342061042785645, 'learning_rate': 1.1252955082742318e-05, 'epoch': 1.31}


 45%|████▍     | 380/846 [1:37:46<1:59:19, 15.36s/it]

{'loss': 0.3842, 'grad_norm': 6.925277233123779, 'learning_rate': 1.1016548463356975e-05, 'epoch': 1.35}


 46%|████▌     | 390/846 [1:40:15<1:53:14, 14.90s/it]

{'loss': 0.4668, 'grad_norm': 8.00277042388916, 'learning_rate': 1.0780141843971632e-05, 'epoch': 1.38}


 47%|████▋     | 400/846 [1:42:44<1:52:24, 15.12s/it]

{'loss': 0.395, 'grad_norm': 8.526599884033203, 'learning_rate': 1.054373522458629e-05, 'epoch': 1.42}


                                                     
 47%|████▋     | 400/846 [1:44:41<1:52:24, 15.12s/it]

{'eval_accuracy': 0.75, 'eval_loss': 0.542919933795929, 'eval_runtime': 117.484, 'eval_samples_per_second': 8.512, 'eval_steps_per_second': 0.136, 'epoch': 1.42}


 48%|████▊     | 410/846 [1:47:12<1:56:54, 16.09s/it]

{'loss': 0.3898, 'grad_norm': 3.973543882369995, 'learning_rate': 1.0307328605200947e-05, 'epoch': 1.45}


 50%|████▉     | 420/846 [1:49:44<1:49:17, 15.39s/it]

{'loss': 0.4751, 'grad_norm': 9.341863632202148, 'learning_rate': 1.0070921985815602e-05, 'epoch': 1.49}


 51%|█████     | 430/846 [1:52:18<1:45:04, 15.15s/it]

{'loss': 0.4021, 'grad_norm': 6.956063747406006, 'learning_rate': 9.834515366430261e-06, 'epoch': 1.52}


 52%|█████▏    | 440/846 [1:54:52<1:42:49, 15.20s/it]

{'loss': 0.4753, 'grad_norm': 9.563010215759277, 'learning_rate': 9.598108747044918e-06, 'epoch': 1.56}


 53%|█████▎    | 450/846 [1:57:33<1:43:34, 15.69s/it]

{'loss': 0.4198, 'grad_norm': 6.871347904205322, 'learning_rate': 9.361702127659576e-06, 'epoch': 1.6}


 54%|█████▍    | 460/846 [2:00:09<1:44:45, 16.28s/it]

{'loss': 0.4522, 'grad_norm': 4.896792411804199, 'learning_rate': 9.125295508274233e-06, 'epoch': 1.63}


 56%|█████▌    | 470/846 [2:02:36<1:29:16, 14.24s/it]

{'loss': 0.3948, 'grad_norm': 6.256485462188721, 'learning_rate': 8.888888888888888e-06, 'epoch': 1.67}


 57%|█████▋    | 480/846 [2:05:08<1:32:36, 15.18s/it]

{'loss': 0.4163, 'grad_norm': 4.25114631652832, 'learning_rate': 8.652482269503547e-06, 'epoch': 1.7}


 58%|█████▊    | 490/846 [2:07:38<1:27:52, 14.81s/it]

{'loss': 0.4197, 'grad_norm': 6.000930309295654, 'learning_rate': 8.416075650118204e-06, 'epoch': 1.74}


 59%|█████▉    | 500/846 [2:10:03<1:23:05, 14.41s/it]

{'loss': 0.4489, 'grad_norm': 6.54167366027832, 'learning_rate': 8.17966903073286e-06, 'epoch': 1.77}


                                                     
 59%|█████▉    | 500/846 [2:11:58<1:23:05, 14.41s/it]

{'eval_accuracy': 0.745, 'eval_loss': 0.527524471282959, 'eval_runtime': 114.7408, 'eval_samples_per_second': 8.715, 'eval_steps_per_second': 0.139, 'epoch': 1.77}


 60%|██████    | 510/846 [2:14:30<1:28:07, 15.74s/it]

{'loss': 0.4448, 'grad_norm': 4.266756057739258, 'learning_rate': 7.943262411347519e-06, 'epoch': 1.81}


 61%|██████▏   | 520/846 [2:17:02<1:21:59, 15.09s/it]

{'loss': 0.4444, 'grad_norm': 5.481785774230957, 'learning_rate': 7.706855791962176e-06, 'epoch': 1.84}


 63%|██████▎   | 530/846 [2:19:24<1:14:55, 14.23s/it]

{'loss': 0.4224, 'grad_norm': 8.982861518859863, 'learning_rate': 7.4704491725768326e-06, 'epoch': 1.88}


 64%|██████▍   | 540/846 [2:21:47<1:10:26, 13.81s/it]

{'loss': 0.4642, 'grad_norm': 5.107002258300781, 'learning_rate': 7.234042553191491e-06, 'epoch': 1.91}


 65%|██████▌   | 550/846 [2:24:15<1:11:18, 14.46s/it]

{'loss': 0.3715, 'grad_norm': 4.8276591300964355, 'learning_rate': 6.997635933806147e-06, 'epoch': 1.95}


 66%|██████▌   | 560/846 [2:26:45<1:12:08, 15.13s/it]

{'loss': 0.4208, 'grad_norm': 7.1624016761779785, 'learning_rate': 6.761229314420804e-06, 'epoch': 1.99}


 67%|██████▋   | 570/846 [2:29:00<1:03:28, 13.80s/it]

{'loss': 0.4282, 'grad_norm': 5.071322917938232, 'learning_rate': 6.524822695035462e-06, 'epoch': 2.02}


 69%|██████▊   | 580/846 [2:31:29<1:05:42, 14.82s/it]

{'loss': 0.3452, 'grad_norm': 3.8869118690490723, 'learning_rate': 6.288416075650119e-06, 'epoch': 2.06}


 70%|██████▉   | 590/846 [2:33:58<1:03:32, 14.89s/it]

{'loss': 0.3364, 'grad_norm': 6.410762786865234, 'learning_rate': 6.052009456264776e-06, 'epoch': 2.09}


 71%|███████   | 600/846 [2:36:24<1:00:25, 14.74s/it]

{'loss': 0.3301, 'grad_norm': 9.713872909545898, 'learning_rate': 5.815602836879432e-06, 'epoch': 2.13}


                                                     
 71%|███████   | 600/846 [2:38:18<1:00:25, 14.74s/it]

{'eval_accuracy': 0.76, 'eval_loss': 0.5928843021392822, 'eval_runtime': 113.6273, 'eval_samples_per_second': 8.801, 'eval_steps_per_second': 0.141, 'epoch': 2.13}


 72%|███████▏  | 610/846 [2:40:42<1:00:55, 15.49s/it]

{'loss': 0.3842, 'grad_norm': 3.46376371383667, 'learning_rate': 5.5791962174940904e-06, 'epoch': 2.16}


 73%|███████▎  | 620/846 [2:43:05<53:41, 14.26s/it]  

{'loss': 0.3175, 'grad_norm': 6.092831134796143, 'learning_rate': 5.342789598108748e-06, 'epoch': 2.2}


 74%|███████▍  | 630/846 [2:45:31<52:02, 14.46s/it]

{'loss': 0.3643, 'grad_norm': 6.376555442810059, 'learning_rate': 5.106382978723404e-06, 'epoch': 2.23}


 76%|███████▌  | 640/846 [2:48:01<50:12, 14.63s/it]

{'loss': 0.3146, 'grad_norm': 5.367424011230469, 'learning_rate': 4.869976359338061e-06, 'epoch': 2.27}


 77%|███████▋  | 650/846 [2:50:20<45:26, 13.91s/it]

{'loss': 0.321, 'grad_norm': 8.063112258911133, 'learning_rate': 4.633569739952719e-06, 'epoch': 2.3}


 78%|███████▊  | 660/846 [2:52:48<45:55, 14.81s/it]

{'loss': 0.3197, 'grad_norm': 8.359748840332031, 'learning_rate': 4.397163120567377e-06, 'epoch': 2.34}


 79%|███████▉  | 670/846 [2:55:12<41:46, 14.24s/it]

{'loss': 0.3329, 'grad_norm': 5.8735151290893555, 'learning_rate': 4.160756501182033e-06, 'epoch': 2.38}


 80%|████████  | 680/846 [2:57:35<38:59, 14.10s/it]

{'loss': 0.3431, 'grad_norm': 6.296497344970703, 'learning_rate': 3.924349881796691e-06, 'epoch': 2.41}


 82%|████████▏ | 690/846 [3:00:01<38:03, 14.64s/it]

{'loss': 0.3809, 'grad_norm': 7.341783046722412, 'learning_rate': 3.6879432624113475e-06, 'epoch': 2.45}


 83%|████████▎ | 700/846 [3:02:28<34:45, 14.28s/it]

{'loss': 0.3968, 'grad_norm': 6.800283908843994, 'learning_rate': 3.451536643026005e-06, 'epoch': 2.48}


                                                   
 83%|████████▎ | 700/846 [3:04:23<34:45, 14.28s/it]

{'eval_accuracy': 0.759, 'eval_loss': 0.5539574027061462, 'eval_runtime': 114.9533, 'eval_samples_per_second': 8.699, 'eval_steps_per_second': 0.139, 'epoch': 2.48}


 84%|████████▍ | 710/846 [3:06:54<36:23, 16.06s/it]  

{'loss': 0.3711, 'grad_norm': 6.49587869644165, 'learning_rate': 3.2151300236406624e-06, 'epoch': 2.52}


 85%|████████▌ | 720/846 [3:09:13<29:02, 13.83s/it]

{'loss': 0.3063, 'grad_norm': 6.1934380531311035, 'learning_rate': 2.978723404255319e-06, 'epoch': 2.55}


 86%|████████▋ | 730/846 [3:11:33<27:01, 13.98s/it]

{'loss': 0.2912, 'grad_norm': 15.227716445922852, 'learning_rate': 2.742316784869977e-06, 'epoch': 2.59}


 87%|████████▋ | 740/846 [3:13:53<24:57, 14.12s/it]

{'loss': 0.354, 'grad_norm': 7.026764869689941, 'learning_rate': 2.5059101654846336e-06, 'epoch': 2.62}


 89%|████████▊ | 750/846 [3:16:17<23:21, 14.60s/it]

{'loss': 0.3392, 'grad_norm': 7.367557525634766, 'learning_rate': 2.269503546099291e-06, 'epoch': 2.66}


 90%|████████▉ | 760/846 [3:18:44<20:46, 14.49s/it]

{'loss': 0.3084, 'grad_norm': 6.388322830200195, 'learning_rate': 2.033096926713948e-06, 'epoch': 2.7}


 91%|█████████ | 770/846 [3:21:14<19:00, 15.01s/it]

{'loss': 0.3463, 'grad_norm': 8.634504318237305, 'learning_rate': 1.7966903073286054e-06, 'epoch': 2.73}


 92%|█████████▏| 780/846 [3:23:42<15:57, 14.50s/it]

{'loss': 0.2961, 'grad_norm': 5.8350396156311035, 'learning_rate': 1.5602836879432626e-06, 'epoch': 2.77}


 93%|█████████▎| 790/846 [3:26:05<13:07, 14.06s/it]

{'loss': 0.3263, 'grad_norm': 6.1937575340271, 'learning_rate': 1.3238770685579196e-06, 'epoch': 2.8}


 95%|█████████▍| 800/846 [3:28:34<11:25, 14.90s/it]

{'loss': 0.3333, 'grad_norm': 8.846382141113281, 'learning_rate': 1.087470449172577e-06, 'epoch': 2.84}


                                                   
 95%|█████████▍| 800/846 [3:30:27<11:25, 14.90s/it]

{'eval_accuracy': 0.756, 'eval_loss': 0.5845742225646973, 'eval_runtime': 113.5017, 'eval_samples_per_second': 8.81, 'eval_steps_per_second': 0.141, 'epoch': 2.84}


 96%|█████████▌| 810/846 [3:33:00<09:44, 16.24s/it]

{'loss': 0.3309, 'grad_norm': 5.481354713439941, 'learning_rate': 8.510638297872341e-07, 'epoch': 2.87}


 97%|█████████▋| 820/846 [3:35:30<06:30, 15.03s/it]

{'loss': 0.3251, 'grad_norm': 5.8354082107543945, 'learning_rate': 6.146572104018913e-07, 'epoch': 2.91}


 98%|█████████▊| 830/846 [3:37:57<03:53, 14.59s/it]

{'loss': 0.3716, 'grad_norm': 5.203203201293945, 'learning_rate': 3.782505910165485e-07, 'epoch': 2.94}


 99%|█████████▉| 840/846 [3:40:18<01:25, 14.19s/it]

{'loss': 0.3891, 'grad_norm': 4.673222064971924, 'learning_rate': 1.4184397163120568e-07, 'epoch': 2.98}


100%|██████████| 846/846 [3:41:35<00:00, 15.72s/it]

{'train_runtime': 13296.0218, 'train_samples_per_second': 2.031, 'train_steps_per_second': 0.064, 'train_loss': 0.4437316661749044, 'epoch': 3.0}





TrainOutput(global_step=846, training_loss=0.4437316661749044, metrics={'train_runtime': 13296.0218, 'train_samples_per_second': 2.031, 'train_steps_per_second': 0.064, 'total_flos': 1775999623680000.0, 'train_loss': 0.4437316661749044, 'epoch': 3.0})

In [13]:
# Persist the fine-tuned classifier and its tokenizer in separate directories.
tokenizer_path = "./tokenizer"
model_path = "./model"

model.save_pretrained(save_directory=model_path)
tokenizer.save_pretrained(save_directory=tokenizer_path)

('./tokenizer/tokenizer_config.json',
 './tokenizer/special_tokens_map.json',
 './tokenizer/vocab.txt',
 './tokenizer/added_tokens.json')

In [14]:
# Reload both artifacts from disk to check that the saved copies are usable on their own.
model_path = "./model"
tokenizer_path = "./tokenizer"

model = BertForSequenceClassification.from_pretrained(model_path)
tokenizer = BertTokenizer.from_pretrained(tokenizer_path)

In [15]:
from transformers import pipeline

# Build an inference pipeline around the reloaded model and tokenizer, then score
# one sample sentence. Labels are reported as generic LABEL_0 / LABEL_1 names
# because no id2label mapping was given to the model.
classifier = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
)
text = "Hello, how are you?"
result = classifier(text)
print(result)

[{'label': 'LABEL_1', 'score': 0.8972613215446472}]


In [16]:
# Score a second sample sentence with the same `classifier` pipeline; the result is
# a list with one {'label', 'score'} dict, using generic LABEL_0 / LABEL_1 names.
text = "I want you to die"

result = classifier(text)
print(result)

[{'label': 'LABEL_0', 'score': 0.7834637761116028}]


In [None]:
# Evaluate the saved fine-tuned model on rows it was NOT fine-tuned on, and log the
# metrics, confusion matrix and ROC curve to MLflow.
import os

import pandas as pd
import torch
import matplotlib.pyplot as plt
import mlflow
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc, confusion_matrix
from transformers import BertTokenizer, BertForSequenceClassification

EVAL_DATA_PATH = "../data/train_df.csv"  # same file (and path casing) as the loading cell
EVAL_MODEL_PATH = "./model"              # directory written by model.save_pretrained
EVAL_TOKENIZER_PATH = "./tokenizer"      # directory written by tokenizer.save_pretrained
EVAL_SAMPLE_SIZE = 1000
EVAL_MAX_LEN = 128                       # must match the max_len used for training
EVAL_BATCH_SIZE = 64                     # bounds memory of each forward pass
artifact_path = './artifacts/'


def _predict_proba(texts, tokenizer, model, max_len, batch_size):
    """Return an (n_texts, n_classes) NumPy array of softmax class probabilities.

    Texts are encoded like during training (special tokens, padded and truncated
    to ``max_len``) and scored in mini-batches of ``batch_size`` without gradients.
    """
    model.eval()
    batch_probs = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            encoded = tokenizer(
                texts[start:start + batch_size],
                add_special_tokens=True,
                max_length=max_len,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                return_tensors='pt',
            )
            logits = model(encoded['input_ids'], attention_mask=encoded['attention_mask']).logits
            batch_probs.append(torch.softmax(logits, dim=-1))
    return torch.cat(batch_probs).numpy()


# Exclude every row of the fine-tuning sample (`data` from the loading cell, which
# covers both the train and validation splits): scoring rows the model was trained
# on would inflate every metric. Index labels survive dropna/sample, so they match.
full_df = pd.read_csv(EVAL_DATA_PATH).dropna(subset=['words'])
unseen_df = full_df.drop(index=data.index, errors='ignore')
eval_df = unseen_df.sample(n=min(EVAL_SAMPLE_SIZE, len(unseen_df)), random_state=42)
assert len(eval_df) > 0, "No rows left for evaluation after excluding the training sample."

tokenizer = BertTokenizer.from_pretrained(EVAL_TOKENIZER_PATH)
model = BertForSequenceClassification.from_pretrained(EVAL_MODEL_PATH)

# Predict on the SAME rows whose labels are scored below (one prediction per row).
probabilities = _predict_proba(
    eval_df['words'].astype(str).tolist(), tokenizer, model, EVAL_MAX_LEN, EVAL_BATCH_SIZE
)
predictions = probabilities.argmax(axis=1)
# ROC/AUC need a continuous score, not hard 0/1 predictions; use P(class 1).
positive_scores = probabilities[:, 1]
# NOTE(review): assumes 'target' holds 0/1 ints with 1 as the positive class — confirm.
y_true = eval_df['target'].to_numpy()

accuracy = accuracy_score(y_true, predictions)
precision = precision_score(y_true, predictions)
recall = recall_score(y_true, predictions)
f1 = f1_score(y_true, predictions)
fpr, tpr, thresholds = roc_curve(y_true, positive_scores)
roc_auc = auc(fpr, tpr)
conf_matrix = confusion_matrix(y_true, predictions)

# to_csv / savefig fail if the artifact directory does not exist yet.
os.makedirs(artifact_path, exist_ok=True)

with mlflow.start_run(run_name="bert-base-uncased"):
    mlflow.log_metric("accuracy", accuracy)
    mlflow.log_metric("Precision", precision)
    mlflow.log_metric("Recall", recall)
    mlflow.log_metric("F1_Score", f1)
    mlflow.log_metric("AUC", roc_auc)

    conf_matrix_path = os.path.join(artifact_path, "confusion_matrix.csv")
    pd.DataFrame(conf_matrix).to_csv(conf_matrix_path, index=False, header=False)
    mlflow.log_artifact(conf_matrix_path, "metrics")

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = {:.2f})'.format(roc_auc))
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax.legend(loc="lower right")
    ax.grid(True)
    roc_curve_path = os.path.join(artifact_path, "roc_curve.png")
    fig.savefig(roc_curve_path)
    plt.close(fig)
    mlflow.log_artifact(roc_curve_path, "plots")
