In [1]:
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.tree import plot_tree

In [2]:
raw_data      = pd.read_csv('../data/sentence-relations/train.csv', index_col='id')
raw_submissions = pd.read_csv('../data/sentence-relations/test.csv', index_col='id')

# Drop chinese and thai
raw_data = raw_data[raw_data['lang_abv'] != 'zh']
raw_data = raw_data[raw_data['lang_abv'] != 'th']

training_data, test_data = train_test_split(raw_data, test_size=0.2, random_state=42)

training_data.head()

Unnamed: 0_level_0,premise,hypothesis,lang_abv,language,label
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
4e5ad9e03a,اب اتنا خفیہ تھا یہ.,یہ عوامی معلومات تھی۔,ur,Urdu,2
fea0d3c7e8,oh that's accommodating,That is disruptive.,en,English,2
4ce586b61e,more than anything else in this day and age th...,"In your decisions age is a big factor, and I a...",en,English,1
0dbcd1012b,"Aie! les boucaniers en-dessous s'écriaient, et...",Les Buccaneers étaient bruyants quand ils disa...,fr,French,0
157895e59c,"一个采购战略组织，是壮大的力资本发展战略的一部分, 这是在原则VI上讨论的。",原则四涉及财富500强机构的资本发展战略。,zh,Chinese,1


In [3]:
from transformers import DataCollatorWithPadding
from transformers import AutoTokenizer

roberta_tokenizer = AutoTokenizer.from_pretrained("roberta-base", truncation=True)
data_collator = DataCollatorWithPadding(tokenizer=roberta_tokenizer)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from torch.utils.data import Dataset, DataLoader

class BertSentenceDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

        self.X = self.data['premise'] + '[SEP]' + self.data['hypothesis']
        self.y = self.data['label']

        self.encoded = self.tokenizer(self.X.tolist(), padding=True, truncation=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {
            'input_ids': self.encoded['input_ids'][idx],
            'attention_mask': self.encoded['attention_mask'][idx],
            'label': self.y.iloc[idx]
        }

bert_train_dataset = BertSentenceDataset(training_data, roberta_tokenizer)
bert_test_dataset = BertSentenceDataset(test_data, roberta_tokenizer)

In [5]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

id2label = {0: "entailment", 1: "neutral", 2: "contradiction"}
label2id = {"entailment": 0, "neutral": 1, "contradiction": 2}
num_labels = len(id2label)

model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=num_labels, id2label=id2label, label2id=label2id)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

training_args = TrainingArguments(
    output_dir="roberta-base-sentence-relation",
    learning_rate=1e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=10,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="macro")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}

In [7]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=bert_train_dataset,
    eval_dataset=bert_test_dataset,
    tokenizer=roberta_tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

  2%|▏         | 500/24240 [02:13<1:46:03,  3.73it/s]

{'loss': 1.1054, 'grad_norm': 4.7216796875, 'learning_rate': 9.793729372937294e-06, 'epoch': 0.21}


  4%|▍         | 1000/24240 [04:27<1:44:33,  3.70it/s]

{'loss': 1.1033, 'grad_norm': 7.050784587860107, 'learning_rate': 9.587458745874588e-06, 'epoch': 0.41}


  6%|▌         | 1500/24240 [06:42<1:41:47,  3.72it/s]

{'loss': 1.1035, 'grad_norm': 1.9336528778076172, 'learning_rate': 9.381188118811881e-06, 'epoch': 0.62}


  8%|▊         | 2000/24240 [08:56<1:39:36,  3.72it/s]

{'loss': 1.0552, 'grad_norm': 36.41044616699219, 'learning_rate': 9.174917491749176e-06, 'epoch': 0.83}


                                                      
 10%|█         | 2424/24240 [11:38<1:37:51,  3.72it/s]

{'eval_loss': 0.8987377285957336, 'eval_accuracy': 0.58003300330033, 'eval_f1': 0.58102141191718, 'eval_precision': 0.5983481660243127, 'eval_recall': 0.5814349427328223, 'eval_runtime': 47.8222, 'eval_samples_per_second': 50.688, 'eval_steps_per_second': 12.672, 'epoch': 1.0}


 10%|█         | 2500/24240 [12:00<1:37:13,  3.73it/s] 

{'loss': 0.9478, 'grad_norm': 20.630041122436523, 'learning_rate': 8.968646864686469e-06, 'epoch': 1.03}


 12%|█▏        | 3000/24240 [14:15<1:35:15,  3.72it/s]

{'loss': 0.8998, 'grad_norm': 5.476257801055908, 'learning_rate': 8.762376237623764e-06, 'epoch': 1.24}


 14%|█▍        | 3500/24240 [16:29<1:32:50,  3.72it/s]

{'loss': 0.8848, 'grad_norm': 129.0258026123047, 'learning_rate': 8.556105610561057e-06, 'epoch': 1.44}


 17%|█▋        | 4000/24240 [18:44<1:30:42,  3.72it/s]

{'loss': 0.9045, 'grad_norm': 5.842039585113525, 'learning_rate': 8.34983498349835e-06, 'epoch': 1.65}


 19%|█▊        | 4500/24240 [20:58<1:28:42,  3.71it/s]

{'loss': 0.8757, 'grad_norm': 102.44477844238281, 'learning_rate': 8.143564356435644e-06, 'epoch': 1.86}


                                                      
 20%|██        | 4848/24240 [23:20<1:27:03,  3.71it/s]

{'eval_loss': 0.9238146543502808, 'eval_accuracy': 0.6283003300330033, 'eval_f1': 0.6240873010137932, 'eval_precision': 0.6445000128414363, 'eval_recall': 0.6262280747759402, 'eval_runtime': 47.9049, 'eval_samples_per_second': 50.6, 'eval_steps_per_second': 12.65, 'epoch': 2.0}


 21%|██        | 5000/24240 [24:18<1:26:15,  3.72it/s]  

{'loss': 0.8376, 'grad_norm': 13.971918106079102, 'learning_rate': 7.937293729372937e-06, 'epoch': 2.06}


 23%|██▎       | 5500/24240 [26:33<1:24:09,  3.71it/s]

{'loss': 0.7433, 'grad_norm': 2.1382336616516113, 'learning_rate': 7.73102310231023e-06, 'epoch': 2.27}


 25%|██▍       | 6000/24240 [28:47<1:21:43,  3.72it/s]

{'loss': 0.7879, 'grad_norm': 2.9458272457122803, 'learning_rate': 7.524752475247525e-06, 'epoch': 2.48}


 27%|██▋       | 6500/24240 [31:02<1:19:30,  3.72it/s]

{'loss': 0.7754, 'grad_norm': 54.041473388671875, 'learning_rate': 7.318481848184819e-06, 'epoch': 2.68}


 29%|██▉       | 7000/24240 [33:17<1:17:19,  3.72it/s]

{'loss': 0.7677, 'grad_norm': 167.6647186279297, 'learning_rate': 7.112211221122113e-06, 'epoch': 2.89}


                                                      
 30%|███       | 7272/24240 [35:18<1:16:04,  3.72it/s]

{'eval_loss': 1.0673168897628784, 'eval_accuracy': 0.6146864686468647, 'eval_f1': 0.6092101013976831, 'eval_precision': 0.681306093617447, 'eval_recall': 0.6161757202672046, 'eval_runtime': 47.8751, 'eval_samples_per_second': 50.632, 'eval_steps_per_second': 12.658, 'epoch': 3.0}


 31%|███       | 7500/24240 [36:20<1:14:50,  3.73it/s] 

{'loss': 0.695, 'grad_norm': 15.961825370788574, 'learning_rate': 6.905940594059406e-06, 'epoch': 3.09}


 33%|███▎      | 8000/24240 [38:35<1:12:47,  3.72it/s]

{'loss': 0.6804, 'grad_norm': 13.960051536560059, 'learning_rate': 6.6996699669967e-06, 'epoch': 3.3}


 35%|███▌      | 8500/24240 [40:49<1:10:26,  3.72it/s]

{'loss': 0.6676, 'grad_norm': 3.461760997772217, 'learning_rate': 6.493399339933993e-06, 'epoch': 3.51}


 37%|███▋      | 9000/24240 [43:04<1:08:09,  3.73it/s]

{'loss': 0.6445, 'grad_norm': 76.6299819946289, 'learning_rate': 6.287128712871288e-06, 'epoch': 3.71}


 39%|███▉      | 9500/24240 [45:18<1:05:52,  3.73it/s]

{'loss': 0.653, 'grad_norm': 19.976964950561523, 'learning_rate': 6.080858085808581e-06, 'epoch': 3.92}


                                                      
 40%|████      | 9696/24240 [46:58<1:05:08,  3.72it/s]

{'eval_loss': 1.1034132242202759, 'eval_accuracy': 0.6357260726072608, 'eval_f1': 0.6364762941542551, 'eval_precision': 0.6491785485339271, 'eval_recall': 0.6377629405377029, 'eval_runtime': 47.6486, 'eval_samples_per_second': 50.872, 'eval_steps_per_second': 12.718, 'epoch': 4.0}


 41%|████▏     | 10000/24240 [48:21<1:03:46,  3.72it/s]

{'loss': 0.6366, 'grad_norm': 78.17413330078125, 'learning_rate': 5.874587458745875e-06, 'epoch': 4.13}


 43%|████▎     | 10500/24240 [50:36<1:01:38,  3.71it/s]

{'loss': 0.5815, 'grad_norm': 25.864952087402344, 'learning_rate': 5.668316831683169e-06, 'epoch': 4.33}


 45%|████▌     | 11000/24240 [52:50<59:12,  3.73it/s]  

{'loss': 0.5991, 'grad_norm': 109.44344329833984, 'learning_rate': 5.4620462046204625e-06, 'epoch': 4.54}


 47%|████▋     | 11500/24240 [55:05<57:08,  3.72it/s]

{'loss': 0.6189, 'grad_norm': 9.01184368133545, 'learning_rate': 5.2557755775577555e-06, 'epoch': 4.74}


 50%|████▉     | 12000/24240 [57:19<54:49,  3.72it/s]

{'loss': 0.5876, 'grad_norm': 25.69721794128418, 'learning_rate': 5.04950495049505e-06, 'epoch': 4.95}


                                                     
 50%|█████     | 12120/24240 [58:39<54:10,  3.73it/s]

{'eval_loss': 1.2187976837158203, 'eval_accuracy': 0.6608910891089109, 'eval_f1': 0.658596634618647, 'eval_precision': 0.6689315286828724, 'eval_recall': 0.659285541095401, 'eval_runtime': 47.836, 'eval_samples_per_second': 50.673, 'eval_steps_per_second': 12.668, 'epoch': 5.0}


 52%|█████▏    | 12500/24240 [1:00:23<52:38,  3.72it/s] 

{'loss': 0.5533, 'grad_norm': 0.9487907886505127, 'learning_rate': 4.843234323432344e-06, 'epoch': 5.16}


 54%|█████▎    | 13000/24240 [1:02:37<50:22,  3.72it/s]

{'loss': 0.4938, 'grad_norm': 82.64549255371094, 'learning_rate': 4.636963696369637e-06, 'epoch': 5.36}


 56%|█████▌    | 13500/24240 [1:04:52<48:22,  3.70it/s]

{'loss': 0.5252, 'grad_norm': 53.53742218017578, 'learning_rate': 4.430693069306931e-06, 'epoch': 5.57}


 58%|█████▊    | 14000/24240 [1:07:07<45:57,  3.71it/s]

{'loss': 0.557, 'grad_norm': 8.554754257202148, 'learning_rate': 4.224422442244225e-06, 'epoch': 5.78}


 60%|█████▉    | 14500/24240 [1:09:21<43:35,  3.72it/s]

{'loss': 0.521, 'grad_norm': 57.56753158569336, 'learning_rate': 4.0181518151815185e-06, 'epoch': 5.98}


                                                       
 60%|██████    | 14544/24240 [1:10:21<43:30,  3.71it/s]

{'eval_loss': 1.2910977602005005, 'eval_accuracy': 0.6468646864686468, 'eval_f1': 0.6399207970791586, 'eval_precision': 0.6617806569735655, 'eval_recall': 0.6429784316348521, 'eval_runtime': 47.8726, 'eval_samples_per_second': 50.634, 'eval_steps_per_second': 12.659, 'epoch': 6.0}


 62%|██████▏   | 15000/24240 [1:12:25<41:19,  3.73it/s]   

{'loss': 0.4607, 'grad_norm': 0.027232948690652847, 'learning_rate': 3.8118811881188123e-06, 'epoch': 6.19}


 64%|██████▍   | 15500/24240 [1:14:40<39:11,  3.72it/s]

{'loss': 0.5025, 'grad_norm': 104.36022186279297, 'learning_rate': 3.605610561056106e-06, 'epoch': 6.39}


 66%|██████▌   | 16000/24240 [1:16:54<36:52,  3.72it/s]

{'loss': 0.4475, 'grad_norm': 0.053507450968027115, 'learning_rate': 3.3993399339933995e-06, 'epoch': 6.6}


 68%|██████▊   | 16500/24240 [1:19:09<34:36,  3.73it/s]

{'loss': 0.4835, 'grad_norm': 66.76116943359375, 'learning_rate': 3.1930693069306933e-06, 'epoch': 6.81}


                                                       
 70%|███████   | 16968/24240 [1:22:02<32:26,  3.74it/s]

{'eval_loss': 1.5261703729629517, 'eval_accuracy': 0.6674917491749175, 'eval_f1': 0.6665768108330803, 'eval_precision': 0.6667737118224704, 'eval_recall': 0.666575722750142, 'eval_runtime': 47.6783, 'eval_samples_per_second': 50.841, 'eval_steps_per_second': 12.71, 'epoch': 7.0}


 70%|███████   | 17000/24240 [1:22:12<32:24,  3.72it/s]   

{'loss': 0.4347, 'grad_norm': 25.050125122070312, 'learning_rate': 2.986798679867987e-06, 'epoch': 7.01}


 72%|███████▏  | 17500/24240 [1:24:26<30:04,  3.73it/s]

{'loss': 0.4319, 'grad_norm': 62.67559051513672, 'learning_rate': 2.780528052805281e-06, 'epoch': 7.22}


 74%|███████▍  | 18000/24240 [1:26:40<27:57,  3.72it/s]

{'loss': 0.4345, 'grad_norm': 19.224254608154297, 'learning_rate': 2.5742574257425744e-06, 'epoch': 7.43}


 76%|███████▋  | 18500/24240 [1:28:55<25:41,  3.72it/s]

{'loss': 0.4294, 'grad_norm': 1.5037357807159424, 'learning_rate': 2.3679867986798682e-06, 'epoch': 7.63}


 78%|███████▊  | 19000/24240 [1:31:09<23:24,  3.73it/s]

{'loss': 0.4679, 'grad_norm': 5.236189842224121, 'learning_rate': 2.161716171617162e-06, 'epoch': 7.84}


                                                       
 80%|████████  | 19392/24240 [1:33:42<21:37,  3.74it/s]

{'eval_loss': 1.5748369693756104, 'eval_accuracy': 0.6596534653465347, 'eval_f1': 0.6552138641350916, 'eval_precision': 0.661931643759255, 'eval_recall': 0.6563418256126647, 'eval_runtime': 47.698, 'eval_samples_per_second': 50.82, 'eval_steps_per_second': 12.705, 'epoch': 8.0}


 80%|████████  | 19500/24240 [1:34:12<21:10,  3.73it/s]   

{'loss': 0.4644, 'grad_norm': 23.332002639770508, 'learning_rate': 1.9554455445544555e-06, 'epoch': 8.04}


 83%|████████▎ | 20000/24240 [1:36:27<19:03,  3.71it/s]

{'loss': 0.3775, 'grad_norm': 24.04757308959961, 'learning_rate': 1.7491749174917493e-06, 'epoch': 8.25}


 85%|████████▍ | 20500/24240 [1:38:42<16:46,  3.72it/s]

{'loss': 0.4079, 'grad_norm': 100.98265075683594, 'learning_rate': 1.5429042904290431e-06, 'epoch': 8.46}


 87%|████████▋ | 21000/24240 [1:40:56<14:28,  3.73it/s]

{'loss': 0.4263, 'grad_norm': 0.019555306062102318, 'learning_rate': 1.3366336633663367e-06, 'epoch': 8.66}


 89%|████████▊ | 21500/24240 [1:43:10<12:14,  3.73it/s]

{'loss': 0.3917, 'grad_norm': 0.08595249056816101, 'learning_rate': 1.1303630363036304e-06, 'epoch': 8.87}


                                                       
 90%|█████████ | 21816/24240 [1:45:22<10:49,  3.73it/s]

{'eval_loss': 1.7713961601257324, 'eval_accuracy': 0.6646039603960396, 'eval_f1': 0.6622007920893216, 'eval_precision': 0.6637151931563942, 'eval_recall': 0.6623837111198605, 'eval_runtime': 47.6334, 'eval_samples_per_second': 50.889, 'eval_steps_per_second': 12.722, 'epoch': 9.0}


 91%|█████████ | 22000/24240 [1:46:13<09:59,  3.74it/s]   

{'loss': 0.4291, 'grad_norm': 0.01362778153270483, 'learning_rate': 9.240924092409241e-07, 'epoch': 9.08}


 93%|█████████▎| 22475/24240 [1:48:21<07:54,  3.72it/s]

KeyboardInterrupt: 

In [None]:
# Get the best model
model = AutoModelForSequenceClassification.from_pretrained('roberta-base-sentence-relation/checkpoint-4816', num_labels=num_labels, id2label=id2label, label2id=label2id)

# Evaluate the model
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=bert_train_dataset,
    eval_dataset=bert_test_dataset,
    tokenizer=None,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.evaluate()

100%|██████████| 86/86 [00:08<00:00, 10.50it/s]


{'eval_loss': 1.6948167085647583,
 'eval_accuracy': 0.7991266375545851,
 'eval_f1': 0.7975053326129699,
 'eval_precision': 0.7997977316986625,
 'eval_recall': 0.7971155023486721,
 'eval_runtime': 8.2844,
 'eval_samples_per_second': 165.854,
 'eval_steps_per_second': 10.381}