In [1]:
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.tree import plot_tree

In [2]:
raw_data      = pd.read_csv('../data/sentence-relations/train.csv', index_col='id')
raw_submissions = pd.read_csv('../data/sentence-relations/test.csv', index_col='id')

# Only take english
raw_data = raw_data[raw_data['lang_abv'] == 'en']

training_data, test_data = train_test_split(raw_data, test_size=0.2, random_state=42)

training_data.head()

Unnamed: 0_level_0,premise,hypothesis,lang_abv,language,label
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
b6e01c1a07,"Also, the Holy Family are said to have shelter...",The Holy family spent a total of three days here.,en,English,1
b8fa1a0044,Participants generally viewed the new internal...,Those organizations affected by the Sarbanes-O...,en,English,0
c828f51ef6,With a little practice almost anyone can flip ...,Practicing lets you do anything you put your m...,en,English,1
b2c98d5a99,More reserved and remote but a better administ...,The uncle had no match in administration; cert...,en,English,2
9cd35fee05,The company later told us that it had disconti...,The company later told us that it had enhanced...,en,English,2


In [3]:
from transformers import DataCollatorWithPadding
from transformers import AutoTokenizer

roberta_tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased", truncation=True)
data_collator = DataCollatorWithPadding(tokenizer=roberta_tokenizer)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from torch.utils.data import Dataset, DataLoader

class BertSentenceDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

        self.X = self.data['premise'] + '[SEP]' + self.data['hypothesis']
        self.y = self.data['label']

        self.encoded = self.tokenizer(self.X.tolist(), padding=True, truncation=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {
            'input_ids': self.encoded['input_ids'][idx],
            'attention_mask': self.encoded['attention_mask'][idx],
            'label': self.y.iloc[idx]
        }

bert_train_dataset = BertSentenceDataset(training_data, roberta_tokenizer)
bert_test_dataset = BertSentenceDataset(test_data, roberta_tokenizer)

In [14]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

id2label = {0: "entailment", 1: "neutral", 2: "contradiction"}
label2id = {"entailment": 0, "neutral": 1, "contradiction": 2}
num_labels = len(id2label)

model = AutoModelForSequenceClassification.from_pretrained("distilbert/distilbert-base-uncased", num_labels=num_labels, id2label=id2label, label2id=label2id)

# Freeze the model except
modfication_layers = ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']

for name, param in model.named_parameters():
    if not any(layer in name for layer in modfication_layers):
        param.requires_grad = False

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [17]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

training_args = TrainingArguments(
    output_dir="roberta-base-sentence-relation",
    learning_rate=1e-3,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=256,
    num_train_epochs=30,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # fp16=True
)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="macro")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}

In [18]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=bert_train_dataset,
    eval_dataset=bert_test_dataset,
    tokenizer=roberta_tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

  3%|▎         | 22/660 [00:22<09:30,  1.12it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                
  3%|▎         | 22/660 [00:27<09:30,  1.12it/s]
[A

{'eval_loss': 1.063098430633545, 'eval_accuracy': 0.4366812227074236, 'eval_f1': 0.4070744816125707, 'eval_precision': 0.4326339325367261, 'eval_recall': 0.4288465070460162, 'eval_runtime': 4.3319, 'eval_samples_per_second': 317.179, 'eval_steps_per_second': 1.385, 'epoch': 1.0}


  7%|▋         | 44/660 [00:50<09:12,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                
  7%|▋         | 44/660 [00:54<09:12,  1.11it/s]
[A

{'eval_loss': 1.0937672853469849, 'eval_accuracy': 0.3937409024745269, 'eval_f1': 0.3253555493041426, 'eval_precision': 0.4178565600979394, 'eval_recall': 0.3884380774282598, 'eval_runtime': 4.3309, 'eval_samples_per_second': 317.256, 'eval_steps_per_second': 1.385, 'epoch': 2.0}


 10%|█         | 66/660 [01:18<08:52,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                
 10%|█         | 66/660 [01:22<08:52,  1.11it/s]
[A

{'eval_loss': 1.0500534772872925, 'eval_accuracy': 0.4621542940320233, 'eval_f1': 0.45716249226518185, 'eval_precision': 0.45835402168458544, 'eval_recall': 0.45901219417173134, 'eval_runtime': 4.3299, 'eval_samples_per_second': 317.326, 'eval_steps_per_second': 1.386, 'epoch': 3.0}


 13%|█▎        | 88/660 [01:46<08:32,  1.12it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                
 13%|█▎        | 88/660 [01:50<08:32,  1.12it/s]
[A

{'eval_loss': 1.0821384191513062, 'eval_accuracy': 0.4264919941775837, 'eval_f1': 0.41301469176676914, 'eval_precision': 0.4501154941593068, 'eval_recall': 0.4316256762171909, 'eval_runtime': 4.329, 'eval_samples_per_second': 317.393, 'eval_steps_per_second': 1.386, 'epoch': 4.0}


 17%|█▋        | 110/660 [02:13<08:13,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 17%|█▋        | 110/660 [02:18<08:13,  1.11it/s]
[A

{'eval_loss': 1.0493252277374268, 'eval_accuracy': 0.46142649199417757, 'eval_f1': 0.4553160563757958, 'eval_precision': 0.4585908973367971, 'eval_recall': 0.45790060998686527, 'eval_runtime': 4.3298, 'eval_samples_per_second': 317.336, 'eval_steps_per_second': 1.386, 'epoch': 5.0}


 20%|██        | 132/660 [02:41<07:53,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 20%|██        | 132/660 [02:45<07:53,  1.11it/s]
[A

{'eval_loss': 1.086621642112732, 'eval_accuracy': 0.43231441048034935, 'eval_f1': 0.41802477202291016, 'eval_precision': 0.4548607288212178, 'eval_recall': 0.4356052004719606, 'eval_runtime': 4.3339, 'eval_samples_per_second': 317.033, 'eval_steps_per_second': 1.384, 'epoch': 6.0}


 23%|██▎       | 154/660 [03:09<07:33,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 23%|██▎       | 154/660 [03:13<07:33,  1.11it/s]
[A

{'eval_loss': 1.061726689338684, 'eval_accuracy': 0.45924308588064044, 'eval_f1': 0.4526271201432288, 'eval_precision': 0.47662143444861, 'eval_recall': 0.46322895100069017, 'eval_runtime': 4.3326, 'eval_samples_per_second': 317.129, 'eval_steps_per_second': 1.385, 'epoch': 7.0}


 27%|██▋       | 176/660 [03:36<07:14,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 27%|██▋       | 176/660 [03:40<07:14,  1.11it/s]
[A

{'eval_loss': 1.053725004196167, 'eval_accuracy': 0.46870451237263466, 'eval_f1': 0.46714115736038514, 'eval_precision': 0.4691819726867383, 'eval_recall': 0.468719806763285, 'eval_runtime': 4.3316, 'eval_samples_per_second': 317.205, 'eval_steps_per_second': 1.385, 'epoch': 8.0}


 30%|███       | 198/660 [04:04<06:54,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 30%|███       | 198/660 [04:08<06:54,  1.11it/s]
[A

{'eval_loss': 1.056142807006836, 'eval_accuracy': 0.44905385735080056, 'eval_f1': 0.4406340349896436, 'eval_precision': 0.44869727869145454, 'eval_recall': 0.44558190854649476, 'eval_runtime': 4.3402, 'eval_samples_per_second': 316.577, 'eval_steps_per_second': 1.382, 'epoch': 9.0}


 33%|███▎      | 220/660 [04:31<06:34,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 33%|███▎      | 220/660 [04:36<06:34,  1.11it/s]
[A

{'eval_loss': 1.0637173652648926, 'eval_accuracy': 0.4417758369723435, 'eval_f1': 0.40803691800142383, 'eval_precision': 0.44730550698014, 'eval_recall': 0.43283591019390455, 'eval_runtime': 4.3408, 'eval_samples_per_second': 316.532, 'eval_steps_per_second': 1.382, 'epoch': 10.0}


 37%|███▋      | 242/660 [04:59<06:15,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 37%|███▋      | 242/660 [05:03<06:15,  1.11it/s]
[A

{'eval_loss': 1.0489882230758667, 'eval_accuracy': 0.45487627365356625, 'eval_f1': 0.44338010446211645, 'eval_precision': 0.4534807833277525, 'eval_recall': 0.44982371045659963, 'eval_runtime': 4.3421, 'eval_samples_per_second': 316.435, 'eval_steps_per_second': 1.382, 'epoch': 11.0}


 40%|████      | 264/660 [05:27<05:55,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 40%|████      | 264/660 [05:31<05:55,  1.11it/s]
[A

{'eval_loss': 1.0729222297668457, 'eval_accuracy': 0.44614264919941776, 'eval_f1': 0.4395796134538534, 'eval_precision': 0.45263251540465094, 'eval_recall': 0.44514208909370195, 'eval_runtime': 4.3391, 'eval_samples_per_second': 316.657, 'eval_steps_per_second': 1.383, 'epoch': 12.0}


 43%|████▎     | 286/660 [05:54<05:35,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 43%|████▎     | 286/660 [05:59<05:35,  1.11it/s]
[A

{'eval_loss': 1.0700676441192627, 'eval_accuracy': 0.44759825327510916, 'eval_f1': 0.4275121366540675, 'eval_precision': 0.46016121895912837, 'eval_recall': 0.44325076804915514, 'eval_runtime': 4.3361, 'eval_samples_per_second': 316.877, 'eval_steps_per_second': 1.384, 'epoch': 13.0}


 47%|████▋     | 308/660 [06:22<05:16,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 47%|████▋     | 308/660 [06:26<05:16,  1.11it/s]
[A

{'eval_loss': 1.0709553956985474, 'eval_accuracy': 0.44250363901018924, 'eval_f1': 0.40956402125578717, 'eval_precision': 0.45275840886088276, 'eval_recall': 0.4351247523319753, 'eval_runtime': 4.3425, 'eval_samples_per_second': 316.407, 'eval_steps_per_second': 1.382, 'epoch': 14.0}


 50%|█████     | 330/660 [06:49<04:56,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 50%|█████     | 330/660 [06:54<04:56,  1.11it/s]
[A

{'eval_loss': 1.0653513669967651, 'eval_accuracy': 0.44395924308588064, 'eval_f1': 0.42619873279591464, 'eval_precision': 0.4498588996577055, 'eval_recall': 0.4390803446203166, 'eval_runtime': 4.3436, 'eval_samples_per_second': 316.331, 'eval_steps_per_second': 1.381, 'epoch': 15.0}


 53%|█████▎    | 352/660 [07:17<04:36,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 53%|█████▎    | 352/660 [07:21<04:36,  1.11it/s]
[A

{'eval_loss': 1.0718618631362915, 'eval_accuracy': 0.44468704512372637, 'eval_f1': 0.4298022231538982, 'eval_precision': 0.44811073781674865, 'eval_recall': 0.44033343351365795, 'eval_runtime': 4.3434, 'eval_samples_per_second': 316.34, 'eval_steps_per_second': 1.381, 'epoch': 16.0}


 57%|█████▋    | 374/660 [07:45<04:17,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 57%|█████▋    | 374/660 [07:49<04:17,  1.11it/s]
[A

{'eval_loss': 1.057857632637024, 'eval_accuracy': 0.47016011644832606, 'eval_f1': 0.46505337088707077, 'eval_precision': 0.4670310877831563, 'eval_recall': 0.46710217279992877, 'eval_runtime': 4.3433, 'eval_samples_per_second': 316.349, 'eval_steps_per_second': 1.381, 'epoch': 17.0}


 60%|██████    | 396/660 [08:12<03:57,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 60%|██████    | 396/660 [08:17<03:57,  1.11it/s]
[A

{'eval_loss': 1.070281982421875, 'eval_accuracy': 0.438136826783115, 'eval_f1': 0.4265033249882892, 'eval_precision': 0.4383403246387146, 'eval_recall': 0.4340249949909838, 'eval_runtime': 4.3453, 'eval_samples_per_second': 316.203, 'eval_steps_per_second': 1.381, 'epoch': 18.0}


 63%|██████▎   | 418/660 [08:40<03:37,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 63%|██████▎   | 418/660 [08:45<03:37,  1.11it/s]
[A

{'eval_loss': 1.0614190101623535, 'eval_accuracy': 0.4606986899563319, 'eval_f1': 0.454876631884608, 'eval_precision': 0.4604079871644487, 'eval_recall': 0.4577649491306574, 'eval_runtime': 4.3445, 'eval_samples_per_second': 316.261, 'eval_steps_per_second': 1.381, 'epoch': 19.0}


 67%|██████▋   | 440/660 [09:08<03:17,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 67%|██████▋   | 440/660 [09:12<03:17,  1.11it/s]
[A

{'eval_loss': 1.069000005722046, 'eval_accuracy': 0.44541484716157204, 'eval_f1': 0.4355451816444724, 'eval_precision': 0.445080142988411, 'eval_recall': 0.44136014270130675, 'eval_runtime': 4.3473, 'eval_samples_per_second': 316.056, 'eval_steps_per_second': 1.38, 'epoch': 20.0}


 70%|███████   | 462/660 [09:36<02:58,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 70%|███████   | 462/660 [09:40<02:58,  1.11it/s]
[A

{'eval_loss': 1.066681146621704, 'eval_accuracy': 0.45269286754002913, 'eval_f1': 0.4449794104650979, 'eval_precision': 0.45332419345560254, 'eval_recall': 0.4488859113070193, 'eval_runtime': 4.3473, 'eval_samples_per_second': 316.056, 'eval_steps_per_second': 1.38, 'epoch': 21.0}


 73%|███████▎  | 484/660 [10:03<02:38,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 73%|███████▎  | 484/660 [10:08<02:38,  1.11it/s]
[A

{'eval_loss': 1.071885585784912, 'eval_accuracy': 0.450509461426492, 'eval_f1': 0.4428626843373387, 'eval_precision': 0.45043628973475186, 'eval_recall': 0.44706109886684925, 'eval_runtime': 4.3464, 'eval_samples_per_second': 316.127, 'eval_steps_per_second': 1.38, 'epoch': 22.0}


 76%|███████▌  | 500/660 [10:25<02:52,  1.08s/it]
 76%|███████▌  | 500/660 [10:25<02:52,  1.08s/it]

{'loss': 0.9822, 'grad_norm': 0.5809879302978516, 'learning_rate': 0.00024242424242424245, 'epoch': 22.73}


 77%|███████▋  | 506/660 [10:31<02:18,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 77%|███████▋  | 506/660 [10:35<02:18,  1.11it/s]
[A

{'eval_loss': 1.0680989027023315, 'eval_accuracy': 0.4606986899563319, 'eval_f1': 0.4522100643116218, 'eval_precision': 0.4569407526604022, 'eval_recall': 0.4563890002003606, 'eval_runtime': 4.36, 'eval_samples_per_second': 315.14, 'eval_steps_per_second': 1.376, 'epoch': 23.0}


 80%|████████  | 528/660 [10:59<01:58,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 80%|████████  | 528/660 [11:03<01:58,  1.11it/s]
[A

{'eval_loss': 1.0777695178985596, 'eval_accuracy': 0.44614264919941776, 'eval_f1': 0.4412043869002779, 'eval_precision': 0.4494080364388955, 'eval_recall': 0.4442570237538681, 'eval_runtime': 4.3507, 'eval_samples_per_second': 315.811, 'eval_steps_per_second': 1.379, 'epoch': 24.0}


 83%|████████▎ | 550/660 [11:26<01:39,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 83%|████████▎ | 550/660 [11:31<01:39,  1.11it/s]
[A

{'eval_loss': 1.077375888824463, 'eval_accuracy': 0.4344978165938865, 'eval_f1': 0.4177115606714504, 'eval_precision': 0.4377741528608694, 'eval_recall': 0.42931471092410783, 'eval_runtime': 4.354, 'eval_samples_per_second': 315.571, 'eval_steps_per_second': 1.378, 'epoch': 25.0}


 87%|████████▋ | 572/660 [11:54<01:19,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 87%|████████▋ | 572/660 [11:58<01:19,  1.11it/s]
[A

{'eval_loss': 1.0751562118530273, 'eval_accuracy': 0.4497816593886463, 'eval_f1': 0.4425483468051336, 'eval_precision': 0.4513580145550608, 'eval_recall': 0.4464704523698213, 'eval_runtime': 4.3607, 'eval_samples_per_second': 315.088, 'eval_steps_per_second': 1.376, 'epoch': 26.0}


 90%|█████████ | 594/660 [12:22<00:59,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 90%|█████████ | 594/660 [12:26<00:59,  1.11it/s]
[A

{'eval_loss': 1.07709801197052, 'eval_accuracy': 0.44468704512372637, 'eval_f1': 0.43546428938947224, 'eval_precision': 0.4445888386145704, 'eval_recall': 0.4408732941516953, 'eval_runtime': 4.3609, 'eval_samples_per_second': 315.075, 'eval_steps_per_second': 1.376, 'epoch': 27.0}


 93%|█████████▎| 616/660 [12:50<00:39,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 93%|█████████▎| 616/660 [12:54<00:39,  1.11it/s]
[A

{'eval_loss': 1.0738681554794312, 'eval_accuracy': 0.4556040756914119, 'eval_f1': 0.4536942463343352, 'eval_precision': 0.4549871353052511, 'eval_recall': 0.45398133306618577, 'eval_runtime': 4.37, 'eval_samples_per_second': 314.415, 'eval_steps_per_second': 1.373, 'epoch': 28.0}


 97%|█████████▋| 638/660 [13:17<00:19,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
 97%|█████████▋| 638/660 [13:22<00:19,  1.11it/s]
[A

{'eval_loss': 1.0795832872390747, 'eval_accuracy': 0.4468704512372635, 'eval_f1': 0.44042858038163873, 'eval_precision': 0.4488820566550586, 'eval_recall': 0.4440342616710078, 'eval_runtime': 4.3689, 'eval_samples_per_second': 314.499, 'eval_steps_per_second': 1.373, 'epoch': 29.0}


100%|██████████| 660/660 [13:45<00:00,  1.11it/s]
[A
[A
[A
[A
[A
[A

[A[A                                       
                                                 
100%|██████████| 660/660 [13:50<00:00,  1.11it/s]
[A

{'eval_loss': 1.0762273073196411, 'eval_accuracy': 0.44905385735080056, 'eval_f1': 0.44523288906185643, 'eval_precision': 0.44846041987413204, 'eval_recall': 0.44673857387742383, 'eval_runtime': 4.3647, 'eval_samples_per_second': 314.796, 'eval_steps_per_second': 1.375, 'epoch': 30.0}



100%|██████████| 660/660 [13:50<00:00,  1.26s/it]

{'train_runtime': 830.6518, 'train_samples_per_second': 198.495, 'train_steps_per_second': 0.795, 'train_loss': 0.9703804247307055, 'epoch': 30.0}





TrainOutput(global_step=660, training_loss=0.9703804247307055, metrics={'train_runtime': 830.6518, 'train_samples_per_second': 198.495, 'train_steps_per_second': 0.795, 'total_flos': 1.011027844695312e+16, 'train_loss': 0.9703804247307055, 'epoch': 30.0})

In [None]:
# Get the best model
model = AutoModelForSequenceClassification.from_pretrained('roberta-base-sentence-relation/checkpoint-374', num_labels=num_labels, id2label=id2label, label2id=label2id)

# Evaluate the model
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=bert_train_dataset,
    eval_dataset=bert_test_dataset,
    tokenizer=None,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.evaluate()

100%|██████████| 43/43 [00:11<00:00,  3.61it/s]


{'eval_loss': 1.0233200788497925,
 'eval_accuracy': 0.5029112081513828,
 'eval_f1': 0.49280351950455037,
 'eval_precision': 0.5022752603866406,
 'eval_recall': 0.4993200260468844,
 'eval_runtime': 12.2102,
 'eval_samples_per_second': 112.529,
 'eval_steps_per_second': 3.522}