In [20]:
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.tree import plot_tree

In [21]:
# Read the labelled training file and the submission file, both indexed by the 'id' column.
raw_data = pd.read_csv(
    '../data/sentence-relations/train.csv',
    index_col='id',
)
raw_submissions = pd.read_csv(
    '../data/sentence-relations/test.csv',
    index_col='id',
)

# Keep only the rows whose language code is English.
raw_data = raw_data.loc[raw_data['lang_abv'].eq('en')]

# Hold out 20% of the English rows; the fixed seed keeps the split repeatable.
training_data, test_data = train_test_split(
    raw_data,
    test_size=0.2,
    random_state=42,
)

training_data.head()

Unnamed: 0_level_0,premise,hypothesis,lang_abv,language,label
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
b6e01c1a07,"Also, the Holy Family are said to have shelter...",The Holy family spent a total of three days here.,en,English,1
b8fa1a0044,Participants generally viewed the new internal...,Those organizations affected by the Sarbanes-O...,en,English,0
c828f51ef6,With a little practice almost anyone can flip ...,Practicing lets you do anything you put your m...,en,English,1
b2c98d5a99,More reserved and remote but a better administ...,The uncle had no match in administration; cert...,en,English,2
9cd35fee05,The company later told us that it had disconti...,The company later told us that it had enhanced...,en,English,2


In [22]:
from transformers import DataCollatorWithPadding
from transformers import AutoTokenizer

# Tokenizer for the DistilBERT checkpoint (the variable is named `roberta_tokenizer`,
# but the checkpoint loaded here is distilbert-base-uncased).
# NOTE(review): `truncation` is normally a call-time tokenizer argument; passing it to
# `from_pretrained` presumably has no effect -- confirm before relying on it.
roberta_tokenizer = AutoTokenizer.from_pretrained(
    "distilbert/distilbert-base-uncased",
    truncation=True,
)

# Collator that pads the examples of each batch with the tokenizer above.
data_collator = DataCollatorWithPadding(roberta_tokenizer)



In [23]:
from torch.utils.data import Dataset, DataLoader

class BertSentenceDataset(Dataset):
    """Map-style dataset of tokenized (premise, hypothesis) pairs with labels.

    The whole frame is tokenized once at construction time; ``__getitem__``
    only indexes into the cached encoding.

    Parameters
    ----------
    data : pandas.DataFrame
        Must provide ``'premise'``, ``'hypothesis'`` and ``'label'`` columns.
    tokenizer : callable
        Called as ``tokenizer(premises, hypotheses, padding=..., truncation=True)``
        and expected to return a mapping whose ``'input_ids'`` and
        ``'attention_mask'`` entries can be indexed by row position.
    padding : bool or str, optional
        Forwarded to the tokenizer. Defaults to ``True`` (the previous
        behaviour); pass ``False`` when batches are padded later by a collator.
    """

    def __init__(self, data, tokenizer, padding=True):
        self.data = data
        self.tokenizer = tokenizer

        # Joined text, kept as an attribute for inspection / backward
        # compatibility; it is no longer what gets tokenized.
        self.X = self.data['premise'] + '[SEP]' + self.data['hypothesis']
        self.y = self.data['label']

        # Encode premise and hypothesis as a sentence pair instead of one string
        # glued together with a literal '[SEP]': the tokenizer then inserts its
        # own separator token and can apply pair-aware truncation, so a long
        # premise does not simply push the hypothesis out of the window.
        self.encoded = self.tokenizer(
            self.data['premise'].tolist(),
            self.data['hypothesis'].tolist(),
            padding=padding,
            truncation=True,
        )

    def __len__(self):
        """Return the number of rows in the underlying frame."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the encoded inputs and the label for the row at position ``idx``."""
        return {
            'input_ids': self.encoded['input_ids'][idx],
            'attention_mask': self.encoded['attention_mask'][idx],
            # Positional lookup: the frame index holds ids, not 0..n-1.
            'label': self.y.iloc[idx]
        }

# Tokenize both splits with the same tokenizer (training split first, then the held-out split).
bert_train_dataset, bert_test_dataset = (
    BertSentenceDataset(split, roberta_tokenizer)
    for split in (training_data, test_data)
)

In [33]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

# Human-readable names for the three integer class ids, plus the inverse lookup.
id2label = dict(enumerate(("entailment", "neutral", "contradiction")))
label2id = {name: idx for idx, name in id2label.items()}
num_labels = len(id2label)

# DistilBERT encoder with a sequence-classification head sized for the labels above.
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert/distilbert-base-uncased",
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)

# Optional (currently disabled): freeze every parameter except the classification head.
# trainable_layers = ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
#
# for name, param in model.named_parameters():
#     if not any(layer in name for layer in trainable_layers):
#         param.requires_grad = False

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [34]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Fine-tuning configuration: evaluate and checkpoint once per epoch so that the
# best checkpoint can be restored when training finishes.
training_args = TrainingArguments(
    output_dir="roberta-base-sentence-relation",
    learning_rate=1e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=20,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # Rank checkpoints by the "accuracy" value returned from compute_metrics.
    # Without an explicit metric the Trainer falls back to evaluation loss,
    # which does not track the accuracy/F1 figures this notebook reports.
    metric_for_best_model="accuracy",
    greater_is_better=True,
    # fp16=True
)

def compute_metrics(pred):
    """Turn a Trainer prediction object into a dict of classification metrics.

    Reads ``pred.label_ids`` (true classes) and ``pred.predictions`` (per-class
    scores, reduced with an argmax over the last axis) and returns accuracy
    together with macro-averaged F1, precision and recall.
    """
    y_true = pred.label_ids
    y_hat = pred.predictions.argmax(-1)

    # The fourth return value (support) is not needed.
    macro_precision, macro_recall, macro_f1, _support = precision_recall_fscore_support(
        y_true, y_hat, average="macro"
    )

    return {
        "accuracy": accuracy_score(y_true, y_hat),
        "f1": macro_f1,
        "precision": macro_precision,
        "recall": macro_recall,
    }

In [36]:
# Fine-tune on the training split; the held-out split is used for the
# per-epoch evaluation configured in `training_args`.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=roberta_tokenizer,
    compute_metrics=compute_metrics,
    train_dataset=bert_train_dataset,
    eval_dataset=bert_test_dataset,
)

trainer.train()

 10%|█         | 707/6880 [02:55<25:34,  4.02it/s]
  5%|▌         | 344/6880 [01:11<19:36,  5.56it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                         
                                                  
  5%|▌         | 344/6880 [01:15<19:36,  5.56it/s]
[A

{'eval_loss': 0.9307237863540649, 'eval_accuracy': 0.5684133915574964, 'eval_f1': 0.5657818453572135, 'eval_precision': 0.5716503547543336, 'eval_recall': 0.5707506845655513, 'eval_runtime': 4.4673, 'eval_samples_per_second': 307.571, 'eval_steps_per_second': 19.251, 'epoch': 1.0}


  7%|▋         | 500/6880 [01:49<22:10,  4.79it/s]  
  7%|▋         | 501/6880 [01:49<22:26,  4.74it/s]

{'loss': 0.9845, 'grad_norm': 8.003182411193848, 'learning_rate': 9.273255813953488e-06, 'epoch': 1.45}


 10%|█         | 688/6880 [02:28<18:34,  5.56it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                         
                                                  
 10%|█         | 688/6880 [02:32<18:34,  5.56it/s]
[A

{'eval_loss': 0.852742075920105, 'eval_accuracy': 0.6120815138282387, 'eval_f1': 0.6120467797678604, 'eval_precision': 0.6118111721737428, 'eval_recall': 0.6141027460540083, 'eval_runtime': 4.499, 'eval_samples_per_second': 305.402, 'eval_steps_per_second': 19.115, 'epoch': 2.0}


 15%|█▍        | 1000/6880 [03:39<20:18,  4.83it/s] 
 15%|█▍        | 1001/6880 [03:39<20:24,  4.80it/s]

{'loss': 0.7425, 'grad_norm': 11.460726737976074, 'learning_rate': 8.546511627906978e-06, 'epoch': 2.91}


 15%|█▌        | 1032/6880 [03:45<17:32,  5.56it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                         
                                                   
 15%|█▌        | 1032/6880 [03:50<17:32,  5.56it/s]
[A

{'eval_loss': 0.8634228706359863, 'eval_accuracy': 0.6157205240174672, 'eval_f1': 0.6153631204335429, 'eval_precision': 0.6162466741429734, 'eval_recall': 0.6176825229858189, 'eval_runtime': 4.4835, 'eval_samples_per_second': 306.456, 'eval_steps_per_second': 19.181, 'epoch': 3.0}


 20%|██        | 1376/6880 [05:02<16:32,  5.55it/s]  
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                         
                                                   
 20%|██        | 1376/6880 [05:07<16:32,  5.55it/s]
[A

{'eval_loss': 0.9376757144927979, 'eval_accuracy': 0.6251819505094615, 'eval_f1': 0.623590811825844, 'eval_precision': 0.6263742633260176, 'eval_recall': 0.6257325686235223, 'eval_runtime': 4.4785, 'eval_samples_per_second': 306.796, 'eval_steps_per_second': 19.203, 'epoch': 4.0}


 22%|██▏       | 1500/6880 [05:33<18:41,  4.80it/s]  
 22%|██▏       | 1501/6880 [05:34<18:48,  4.77it/s]

{'loss': 0.5314, 'grad_norm': 20.382686614990234, 'learning_rate': 7.819767441860465e-06, 'epoch': 4.36}


 25%|██▌       | 1720/6880 [06:19<15:33,  5.53it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                         
                                                   
 25%|██▌       | 1720/6880 [06:24<15:33,  5.53it/s]
[A

{'eval_loss': 1.0134062767028809, 'eval_accuracy': 0.6149927219796215, 'eval_f1': 0.6132992210987466, 'eval_precision': 0.6201217364993419, 'eval_recall': 0.6180959059640686, 'eval_runtime': 4.4961, 'eval_samples_per_second': 305.596, 'eval_steps_per_second': 19.128, 'epoch': 5.0}


 29%|██▉       | 2000/6880 [07:23<17:38,  4.61it/s]  
 29%|██▉       | 2001/6880 [07:23<17:42,  4.59it/s]

{'loss': 0.3849, 'grad_norm': 14.614089012145996, 'learning_rate': 7.0930232558139545e-06, 'epoch': 5.81}


 30%|███       | 2064/6880 [07:36<14:33,  5.51it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                         
                                                   
 30%|███       | 2064/6880 [07:41<14:33,  5.51it/s]
[A

{'eval_loss': 1.1026487350463867, 'eval_accuracy': 0.6259097525473072, 'eval_f1': 0.626082351100019, 'eval_precision': 0.626807211430186, 'eval_recall': 0.6256684253879204, 'eval_runtime': 4.5575, 'eval_samples_per_second': 301.482, 'eval_steps_per_second': 18.87, 'epoch': 6.0}


 35%|███▌      | 2408/6880 [08:53<13:27,  5.54it/s]  
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                         
                                                   
 35%|███▌      | 2408/6880 [08:58<13:27,  5.54it/s]
[A

{'eval_loss': 1.2163230180740356, 'eval_accuracy': 0.62882096069869, 'eval_f1': 0.6288706742015587, 'eval_precision': 0.6301812330190839, 'eval_recall': 0.630616637725684, 'eval_runtime': 4.4836, 'eval_samples_per_second': 306.448, 'eval_steps_per_second': 19.181, 'epoch': 7.0}


 36%|███▋      | 2500/6880 [09:18<15:23,  4.74it/s]  
 36%|███▋      | 2501/6880 [09:18<15:31,  4.70it/s]

{'loss': 0.25, 'grad_norm': 17.826583862304688, 'learning_rate': 6.366279069767443e-06, 'epoch': 7.27}


 40%|████      | 2752/6880 [10:11<12:24,  5.55it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                         
                                                   
 40%|████      | 2752/6880 [10:15<12:24,  5.55it/s]
[A

{'eval_loss': 1.4262678623199463, 'eval_accuracy': 0.6251819505094615, 'eval_f1': 0.6254295002383253, 'eval_precision': 0.6261190923762895, 'eval_recall': 0.6262139907388856, 'eval_runtime': 4.5027, 'eval_samples_per_second': 305.15, 'eval_steps_per_second': 19.1, 'epoch': 8.0}


 44%|████▎     | 3000/6880 [11:08<13:38,  4.74it/s]  
 44%|████▎     | 3001/6880 [11:08<13:48,  4.68it/s]

{'loss': 0.1641, 'grad_norm': 8.158695220947266, 'learning_rate': 5.6395348837209305e-06, 'epoch': 8.72}


 45%|████▌     | 3096/6880 [11:28<11:21,  5.55it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                         
                                                   
 45%|████▌     | 3096/6880 [11:32<11:21,  5.55it/s]
[A

{'eval_loss': 1.653538703918457, 'eval_accuracy': 0.6251819505094615, 'eval_f1': 0.6203657204805063, 'eval_precision': 0.6287468047363198, 'eval_recall': 0.6242358467463657, 'eval_runtime': 4.5371, 'eval_samples_per_second': 302.84, 'eval_steps_per_second': 18.955, 'epoch': 9.0}


 50%|█████     | 3440/6880 [12:45<10:22,  5.53it/s]  
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                         
                                                   
 50%|█████     | 3440/6880 [12:50<10:22,  5.53it/s]
[A

{'eval_loss': 1.8067269325256348, 'eval_accuracy': 0.6106259097525473, 'eval_f1': 0.6100819830724215, 'eval_precision': 0.6164291894445189, 'eval_recall': 0.6090714675304437, 'eval_runtime': 4.5225, 'eval_samples_per_second': 303.813, 'eval_steps_per_second': 19.016, 'epoch': 10.0}


 50%|█████     | 3464/6880 [12:55<11:53,  4.78it/s]  

KeyboardInterrupt: 

In [37]:
# Reload the weights stored under checkpoint-2752 in the training output directory.
# NOTE(review): in the training log this step count presumably corresponds to epoch 8;
# epoch 7 (step 2408) logged slightly higher accuracy/F1 and epoch 2 (step 688) the
# lowest eval loss -- confirm which checkpoint is meant to be the "best" one.
model = AutoModelForSequenceClassification.from_pretrained(
    'roberta-base-sentence-relation/checkpoint-2752',
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)

# Wrap the reloaded model in a fresh Trainer and score it on the held-out split.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=None,
    compute_metrics=compute_metrics,
    train_dataset=bert_train_dataset,
    eval_dataset=bert_test_dataset,
)

trainer.evaluate()


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 86/86 [00:04<00:00, 19.88it/s]


{'eval_loss': 1.4262678623199463,
 'eval_accuracy': 0.6251819505094615,
 'eval_f1': 0.6254295002383253,
 'eval_precision': 0.6261190923762895,
 'eval_recall': 0.6262139907388856,
 'eval_runtime': 4.3755,
 'eval_samples_per_second': 314.019,
 'eval_steps_per_second': 19.655}