In [1]:
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.tree import plot_tree

In [2]:
# Load the labelled premise/hypothesis pairs and the separate test.csv set, both indexed by `id`.
raw_data        = pd.read_csv('../data/sentence-relations/train.csv', index_col='id')
raw_submissions = pd.read_csv('../data/sentence-relations/test.csv', index_col='id')

# Restrict the experiment to English pairs (lang_abv == 'en').
# An earlier variant instead kept every language except Chinese ('zh') and Thai ('th').
is_english = raw_data['lang_abv'].eq('en')
raw_data = raw_data.loc[is_english]
del is_english

# Reproducible 80/20 train/validation split (fixed random_state).
training_data, test_data = train_test_split(raw_data, test_size=0.2, random_state=42)

training_data.head()

Unnamed: 0_level_0,premise,hypothesis,lang_abv,language,label
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
b6e01c1a07,"Also, the Holy Family are said to have shelter...",The Holy family spent a total of three days here.,en,English,1
b8fa1a0044,Participants generally viewed the new internal...,Those organizations affected by the Sarbanes-O...,en,English,0
c828f51ef6,With a little practice almost anyone can flip ...,Practicing lets you do anything you put your m...,en,English,1
b2c98d5a99,More reserved and remote but a better administ...,The uncle had no match in administration; cert...,en,English,2
9cd35fee05,The company later told us that it had disconti...,The company later told us that it had enhanced...,en,English,2


In [3]:
from transformers import RobertaTokenizer, RobertaModel
from transformers import DataCollatorWithPadding

# Tokenizer for the same 'FacebookAI/roberta-base' checkpoint that the model cell loads.
roberta_tokenizer = RobertaTokenizer.from_pretrained('FacebookAI/roberta-base')
# Batch collator handed to the Trainer (data_collator=...); pads with roberta_tokenizer.
data_collator = DataCollatorWithPadding(tokenizer=roberta_tokenizer)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from torch.utils.data import Dataset, DataLoader

class SentenceRelationTransformerDataset(Dataset):
    """Map-style dataset of tokenized premise/hypothesis pairs with their labels.

    Each row of ``data`` is rendered as
    ``"Premise: <premise> | Hypothesis: <hypothesis>"`` and tokenized once, up
    front, in :meth:`perform_preprocessing`.

    Args:
        data: DataFrame with ``premise``, ``hypothesis`` and ``label`` columns.
        tokenizer: callable tokenizer returning a mapping that contains at least
            ``input_ids`` and ``attention_mask`` (e.g. a Hugging Face tokenizer).
        max_length: truncation length passed to the tokenizer (default 512).
        padding: padding strategy passed to the tokenizer. The default
            ``'max_length'`` pads every example to ``max_length``. Pass ``False``
            to leave padding to a batch collator such as ``DataCollatorWithPadding``,
            which only pads each batch to its longest member.
    """

    def __init__(self, data, tokenizer, max_length=512, padding='max_length'):
        self.data               = data
        self.sentence_relations = []
        self.labels             = []
        self.tokenizer          = tokenizer
        self.max_length         = max_length
        self.padding            = padding

        self.perform_preprocessing()

    def __len__(self):
        return len(self.data)

    def get_max_input_length(self):
        """Return the longest tokenized input, measured in real (non-padding) tokens.

        Counts the ``1`` entries of each example's ``attention_mask``. The result is
        therefore independent of ``max_length`` padding. Returns 0 for an empty
        dataset.
        """
        return max(
            (sum(encoding['attention_mask']) for encoding in self.sentence_relations),
            default=0,
        )

    def perform_preprocessing(self):
        """(Re)build ``sentence_relations`` and ``labels`` from ``self.data``."""
        self.sentence_relations = []
        self.labels             = []

        # Iterate the columns directly: per-row ``iloc`` builds a Series for every row.
        rows = zip(self.data['premise'], self.data['hypothesis'], self.data['label'])
        for premise, hypothesis, label in rows:
            input_string = f"Premise: {premise} | Hypothesis: {hypothesis}"

            input_tokenized = self.tokenizer(
                input_string,
                padding=self.padding,
                max_length=self.max_length,
                truncation=True,
            )

            self.labels.append(label)
            self.sentence_relations.append(input_tokenized)

    def __getitem__(self, idx):
        """Return the tokenized example at ``idx`` together with its ``label``.

        A fresh dict is built so the cached encoding in ``sentence_relations`` is
        never mutated.
        """
        return {**self.sentence_relations[idx], 'label': self.labels[idx]}


# Tokenize both splits up front; the Dataset does all tokenization in its constructor.
train_dataset, test_dataset = (
    SentenceRelationTransformerDataset(split, roberta_tokenizer)
    for split in (training_data, test_data)
)

In [5]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

# Three-way NLI label mapping, stored in the model config so predictions decode to names.
id2label   = {0: "entailment", 1: "neutral", 2: "contradiction"}
label2id   = {label: idx for idx, label in id2label.items()}
num_labels = len(id2label)

model = AutoModelForSequenceClassification.from_pretrained(
    'FacebookAI/roberta-base',
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)

# Freeze every parameter except the newly initialised classification head
# (dense + out_proj weights and biases), so only the head is trained.
modfication_layers = ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']

for name, param in model.named_parameters():
    if any(layer in name for layer in modfication_layers):
        continue  # classification-head parameter: leave requires_grad untouched
    param.requires_grad = False

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Only the classification head is trainable (the encoder was frozen above);
# presumably that is why the learning rate is higher than typical full fine-tuning.
training_args = TrainingArguments(
    output_dir="roberta-base-sentence-relation",
    learning_rate=1e-3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=30,
    weight_decay=0.01,
    # Evaluate and checkpoint once per epoch so every epoch's model can be compared.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # Restore the checkpoint with the best eval *accuracy* when training ends.
    # Without metric_for_best_model the Trainer ranks checkpoints by eval_loss.
    # In the recorded run that would select epoch 29, while the reload/evaluate
    # step picks the most accurate checkpoint (epoch 28).
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    fp16=True,  # mixed-precision (fp16) training
)

def compute_metrics(pred):
    """Compute accuracy and macro-averaged precision/recall/F1 for Trainer evaluation.

    Args:
        pred: evaluation output exposing ``label_ids`` and ``predictions``
            (per-class scores with the class axis last). ``predictions`` may also
            be a tuple whose first element holds those scores.

    Returns:
        dict with ``accuracy``, ``f1``, ``precision`` and ``recall`` (macro averages).
    """
    logits = pred.predictions[0] if isinstance(pred.predictions, tuple) else pred.predictions
    labels = pred.label_ids
    preds = logits.argmax(-1)
    # zero_division=0: a class that is never predicted (e.g. early epochs, see the
    # `_warn_prf` warnings in the training log) scores 0 precision. It contributes
    # the same value as before, just without a warning on every evaluation.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="macro", zero_division=0
    )
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}

In [7]:
# Train the classification head; the held-out split is evaluated after every epoch.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,      # held-out 20 % split, used for per-epoch metrics
    compute_metrics=compute_metrics,
    tokenizer=None,                 # no tokenizer handed to the Trainer itself
)

trainer.train()

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
  3%|▎         | 172/5160 [01:09<25:27,  3.26it/s]

{'eval_loss': 1.0925101041793823, 'eval_accuracy': 0.33478893740902477, 'eval_f1': 0.16721192293711376, 'eval_precision': 0.11159631246967493, 'eval_recall': 0.3333333333333333, 'eval_runtime': 12.2983, 'eval_samples_per_second': 111.723, 'eval_steps_per_second': 3.496, 'epoch': 1.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
  7%|▋         | 344/5160 [02:19<24:40,  3.25it/s]

{'eval_loss': 1.0978069305419922, 'eval_accuracy': 0.3777292576419214, 'eval_f1': 0.26338009806944, 'eval_precision': 0.28906787402200246, 'eval_recall': 0.3733997584541063, 'eval_runtime': 12.3264, 'eval_samples_per_second': 111.468, 'eval_steps_per_second': 3.488, 'epoch': 2.0}


 10%|▉         | 500/5160 [03:12<25:53,  3.00it/s]  

{'loss': 1.1061, 'grad_norm': 1.6466537714004517, 'learning_rate': 0.0009031007751937985, 'epoch': 2.91}


                                                  
 10%|█         | 516/5160 [03:29<23:50,  3.25it/s]

{'eval_loss': 1.0871872901916504, 'eval_accuracy': 0.3500727802037846, 'eval_f1': 0.25730843989087887, 'eval_precision': 0.465828796721849, 'eval_recall': 0.363346941160756, 'eval_runtime': 12.3821, 'eval_samples_per_second': 110.966, 'eval_steps_per_second': 3.473, 'epoch': 3.0}


                                                    
 13%|█▎        | 688/5160 [04:40<23:08,  3.22it/s]

{'eval_loss': 1.0768485069274902, 'eval_accuracy': 0.4148471615720524, 'eval_f1': 0.3632733743042728, 'eval_precision': 0.45411326667921975, 'eval_recall': 0.4229169449453461, 'eval_runtime': 12.5038, 'eval_samples_per_second': 109.887, 'eval_steps_per_second': 3.439, 'epoch': 4.0}


                                                    
 17%|█▋        | 860/5160 [05:51<22:24,  3.20it/s]

{'eval_loss': 1.0721232891082764, 'eval_accuracy': 0.39446870451237265, 'eval_f1': 0.33303149552900113, 'eval_precision': 0.45031999692794294, 'eval_recall': 0.39432061043211114, 'eval_runtime': 12.5026, 'eval_samples_per_second': 109.897, 'eval_steps_per_second': 3.439, 'epoch': 5.0}


 19%|█▉        | 1000/5160 [06:39<23:25,  2.96it/s] 

{'loss': 1.0779, 'grad_norm': 0.4983506500720978, 'learning_rate': 0.0008062015503875969, 'epoch': 5.81}


                                                   
 20%|██        | 1032/5160 [07:02<21:24,  3.21it/s]

{'eval_loss': 1.0860074758529663, 'eval_accuracy': 0.3791848617176128, 'eval_f1': 0.2563625122688727, 'eval_precision': 0.4588531875384465, 'eval_recall': 0.3645023542376277, 'eval_runtime': 12.5446, 'eval_samples_per_second': 109.529, 'eval_steps_per_second': 3.428, 'epoch': 6.0}


                                                     
 23%|██▎       | 1204/5160 [08:13<20:27,  3.22it/s]

{'eval_loss': 1.0504542589187622, 'eval_accuracy': 0.462882096069869, 'eval_f1': 0.46211445081064073, 'eval_precision': 0.4760687879885362, 'eval_recall': 0.4659423072641866, 'eval_runtime': 12.5018, 'eval_samples_per_second': 109.904, 'eval_steps_per_second': 3.44, 'epoch': 7.0}


                                                     
 27%|██▋       | 1376/5160 [09:24<19:31,  3.23it/s]

{'eval_loss': 1.0468919277191162, 'eval_accuracy': 0.4606986899563319, 'eval_f1': 0.4195922481468528, 'eval_precision': 0.4568755169561621, 'eval_recall': 0.45245747901778754, 'eval_runtime': 12.4975, 'eval_samples_per_second': 109.942, 'eval_steps_per_second': 3.441, 'epoch': 8.0}


 29%|██▉       | 1500/5160 [10:06<20:44,  2.94it/s]  

{'loss': 1.0653, 'grad_norm': 1.9449361562728882, 'learning_rate': 0.0007093023255813954, 'epoch': 8.72}


                                                   
 30%|███       | 1548/5160 [10:35<18:45,  3.21it/s]

{'eval_loss': 1.0493340492248535, 'eval_accuracy': 0.44905385735080056, 'eval_f1': 0.3902055237178801, 'eval_precision': 0.473864358002182, 'eval_recall': 0.43736072152096, 'eval_runtime': 12.5139, 'eval_samples_per_second': 109.798, 'eval_steps_per_second': 3.436, 'epoch': 9.0}


                                                     
 33%|███▎      | 1720/5160 [11:46<17:44,  3.23it/s]

{'eval_loss': 1.0529671907424927, 'eval_accuracy': 0.44468704512372637, 'eval_f1': 0.406555960345998, 'eval_precision': 0.484369141708671, 'eval_recall': 0.43588751419221267, 'eval_runtime': 12.4761, 'eval_samples_per_second': 110.131, 'eval_steps_per_second': 3.447, 'epoch': 10.0}


                                                     
 37%|███▋      | 1892/5160 [12:57<16:57,  3.21it/s]

{'eval_loss': 1.036966323852539, 'eval_accuracy': 0.46797671033478894, 'eval_f1': 0.4474594136016785, 'eval_precision': 0.47094498422865133, 'eval_recall': 0.46086483447984145, 'eval_runtime': 12.5184, 'eval_samples_per_second': 109.759, 'eval_steps_per_second': 3.435, 'epoch': 11.0}


 39%|███▉      | 2000/5160 [13:34<17:46,  2.96it/s]  

{'loss': 1.0537, 'grad_norm': 2.1915581226348877, 'learning_rate': 0.0006124031007751938, 'epoch': 11.63}


                                                   
 40%|████      | 2064/5160 [14:08<16:00,  3.22it/s]

{'eval_loss': 1.0472009181976318, 'eval_accuracy': 0.44759825327510916, 'eval_f1': 0.4128674529931333, 'eval_precision': 0.4653665039974812, 'eval_recall': 0.44398848482824643, 'eval_runtime': 12.5049, 'eval_samples_per_second': 109.877, 'eval_steps_per_second': 3.439, 'epoch': 12.0}


                                                     
 43%|████▎     | 2236/5160 [15:19<15:14,  3.20it/s]

{'eval_loss': 1.0354187488555908, 'eval_accuracy': 0.4745269286754003, 'eval_f1': 0.4735930714295771, 'eval_precision': 0.4775199929150886, 'eval_recall': 0.4752074567554932, 'eval_runtime': 12.4849, 'eval_samples_per_second': 110.053, 'eval_steps_per_second': 3.444, 'epoch': 13.0}


                                                     
 47%|████▋     | 2408/5160 [16:30<14:14,  3.22it/s]

{'eval_loss': 1.0347530841827393, 'eval_accuracy': 0.47161572052401746, 'eval_f1': 0.4363177123470831, 'eval_precision': 0.4696875720321015, 'eval_recall': 0.46458583784144797, 'eval_runtime': 12.4881, 'eval_samples_per_second': 110.024, 'eval_steps_per_second': 3.443, 'epoch': 14.0}


 48%|████▊     | 2500/5160 [17:02<14:58,  2.96it/s]  

{'loss': 1.0471, 'grad_norm': 1.0434789657592773, 'learning_rate': 0.0005156976744186047, 'epoch': 14.53}


                                                   
 50%|█████     | 2580/5160 [17:41<13:19,  3.23it/s]

{'eval_loss': 1.040223479270935, 'eval_accuracy': 0.4759825327510917, 'eval_f1': 0.47102585955217896, 'eval_precision': 0.4932364167040862, 'eval_recall': 0.47956502259622885, 'eval_runtime': 12.4912, 'eval_samples_per_second': 109.997, 'eval_steps_per_second': 3.442, 'epoch': 15.0}


                                                     
 53%|█████▎    | 2752/5160 [18:52<12:26,  3.22it/s]

{'eval_loss': 1.030613899230957, 'eval_accuracy': 0.48398835516739447, 'eval_f1': 0.475678757626213, 'eval_precision': 0.48147915507346317, 'eval_recall': 0.48100594959816556, 'eval_runtime': 12.4693, 'eval_samples_per_second': 110.191, 'eval_steps_per_second': 3.448, 'epoch': 16.0}


                                                     
 57%|█████▋    | 2924/5160 [20:03<11:31,  3.23it/s]

{'eval_loss': 1.0451891422271729, 'eval_accuracy': 0.4570596797671033, 'eval_f1': 0.4269580104367123, 'eval_precision': 0.49243113553902756, 'eval_recall': 0.4494369030922327, 'eval_runtime': 12.4505, 'eval_samples_per_second': 110.357, 'eval_steps_per_second': 3.454, 'epoch': 17.0}


 58%|█████▊    | 3000/5160 [20:29<12:03,  2.98it/s]  

{'loss': 1.0382, 'grad_norm': 1.8171579837799072, 'learning_rate': 0.0004187984496124031, 'epoch': 17.44}


                                                   
 60%|██████    | 3096/5160 [21:14<10:39,  3.23it/s]

{'eval_loss': 1.0360091924667358, 'eval_accuracy': 0.4606986899563319, 'eval_f1': 0.4145372375101699, 'eval_precision': 0.46069939286019684, 'eval_recall': 0.45242533783031674, 'eval_runtime': 12.4848, 'eval_samples_per_second': 110.054, 'eval_steps_per_second': 3.444, 'epoch': 18.0}


                                                     
 63%|██████▎   | 3268/5160 [22:25<09:50,  3.21it/s]

{'eval_loss': 1.0360023975372314, 'eval_accuracy': 0.4606986899563319, 'eval_f1': 0.4361381898990193, 'eval_precision': 0.4814743567678062, 'eval_recall': 0.45297883412364476, 'eval_runtime': 12.4979, 'eval_samples_per_second': 109.939, 'eval_steps_per_second': 3.441, 'epoch': 19.0}


                                                     
 67%|██████▋   | 3440/5160 [23:35<08:51,  3.24it/s]

{'eval_loss': 1.0307884216308594, 'eval_accuracy': 0.47889374090247455, 'eval_f1': 0.4688028185550139, 'eval_precision': 0.4871895721453243, 'eval_recall': 0.47466370021594423, 'eval_runtime': 12.39, 'eval_samples_per_second': 110.896, 'eval_steps_per_second': 3.471, 'epoch': 20.0}


 68%|██████▊   | 3500/5160 [23:56<09:16,  2.98it/s]  

{'loss': 1.0369, 'grad_norm': 1.2760347127914429, 'learning_rate': 0.00032189922480620154, 'epoch': 20.35}


                                                   
 70%|███████   | 3612/5160 [24:46<07:58,  3.23it/s]

{'eval_loss': 1.0306378602981567, 'eval_accuracy': 0.48326055312954874, 'eval_f1': 0.47329111409835556, 'eval_precision': 0.48453698348435187, 'eval_recall': 0.4813909481511165, 'eval_runtime': 12.4757, 'eval_samples_per_second': 110.134, 'eval_steps_per_second': 3.447, 'epoch': 21.0}


                                                     
 73%|███████▎  | 3784/5160 [25:57<07:06,  3.23it/s]

{'eval_loss': 1.0299369096755981, 'eval_accuracy': 0.4861717612809316, 'eval_f1': 0.48616905031536706, 'eval_precision': 0.49138947099799096, 'eval_recall': 0.4877963667935617, 'eval_runtime': 12.4926, 'eval_samples_per_second': 109.985, 'eval_steps_per_second': 3.442, 'epoch': 22.0}


                                                     
 77%|███████▋  | 3956/5160 [27:08<06:12,  3.23it/s]

{'eval_loss': 1.0281200408935547, 'eval_accuracy': 0.48326055312954874, 'eval_f1': 0.48364937708086636, 'eval_precision': 0.4900169242489952, 'eval_recall': 0.4844675972305706, 'eval_runtime': 12.4956, 'eval_samples_per_second': 109.958, 'eval_steps_per_second': 3.441, 'epoch': 23.0}


 78%|███████▊  | 4000/5160 [27:23<06:32,  2.95it/s]  

{'loss': 1.0341, 'grad_norm': 0.6226695775985718, 'learning_rate': 0.00022500000000000002, 'epoch': 23.26}


                                                   
 80%|████████  | 4128/5160 [28:19<05:21,  3.21it/s]

{'eval_loss': 1.025191068649292, 'eval_accuracy': 0.49417758369723436, 'eval_f1': 0.4896147613741488, 'eval_precision': 0.4966456950460794, 'eval_recall': 0.4913003127852356, 'eval_runtime': 12.5708, 'eval_samples_per_second': 109.301, 'eval_steps_per_second': 3.421, 'epoch': 24.0}


                                                     
 83%|████████▎ | 4300/5160 [29:30<04:29,  3.20it/s]

{'eval_loss': 1.0250341892242432, 'eval_accuracy': 0.4861717612809316, 'eval_f1': 0.4640038651191438, 'eval_precision': 0.48830244112605303, 'eval_recall': 0.479307197399764, 'eval_runtime': 12.5284, 'eval_samples_per_second': 109.671, 'eval_steps_per_second': 3.432, 'epoch': 25.0}


                                                     
 87%|████████▋ | 4472/5160 [30:41<03:34,  3.20it/s]

{'eval_loss': 1.0240832567214966, 'eval_accuracy': 0.49417758369723436, 'eval_f1': 0.4892720487727302, 'eval_precision': 0.49179695581847244, 'eval_recall': 0.492330500456377, 'eval_runtime': 12.5242, 'eval_samples_per_second': 109.708, 'eval_steps_per_second': 3.433, 'epoch': 26.0}


 87%|████████▋ | 4500/5160 [30:51<03:42,  2.97it/s]

{'loss': 1.0276, 'grad_norm': 1.1962478160858154, 'learning_rate': 0.00012810077519379844, 'epoch': 26.16}


                                                   
 90%|█████████ | 4644/5160 [31:52<02:40,  3.22it/s]

{'eval_loss': 1.022760033607483, 'eval_accuracy': 0.49053857350800584, 'eval_f1': 0.4857373881114888, 'eval_precision': 0.4867092806999551, 'eval_recall': 0.4879467764197778, 'eval_runtime': 12.4491, 'eval_samples_per_second': 110.37, 'eval_steps_per_second': 3.454, 'epoch': 27.0}


                                                   
 93%|█████████▎| 4816/5160 [33:03<01:47,  3.21it/s]

{'eval_loss': 1.0233200788497925, 'eval_accuracy': 0.5029112081513828, 'eval_f1': 0.49280351950455037, 'eval_precision': 0.5022752603866406, 'eval_recall': 0.4993200260468844, 'eval_runtime': 12.4947, 'eval_samples_per_second': 109.967, 'eval_steps_per_second': 3.441, 'epoch': 28.0}


                                                   
 97%|█████████▋| 4988/5160 [34:14<00:53,  3.22it/s]

{'eval_loss': 1.021627426147461, 'eval_accuracy': 0.4883551673944687, 'eval_f1': 0.4840031746885845, 'eval_precision': 0.48508767703587913, 'eval_recall': 0.4857087201406977, 'eval_runtime': 12.4874, 'eval_samples_per_second': 110.031, 'eval_steps_per_second': 3.443, 'epoch': 29.0}


 97%|█████████▋| 5000/5160 [34:19<01:06,  2.42it/s]

{'loss': 1.0262, 'grad_norm': 1.2978962659835815, 'learning_rate': 3.12015503875969e-05, 'epoch': 29.07}


                                                   
100%|██████████| 5160/5160 [35:25<00:00,  3.21it/s]

{'eval_loss': 1.0224545001983643, 'eval_accuracy': 0.4949053857350801, 'eval_f1': 0.4887884569894043, 'eval_precision': 0.4910328928749433, 'eval_recall': 0.49249649368863957, 'eval_runtime': 12.4512, 'eval_samples_per_second': 110.351, 'eval_steps_per_second': 3.453, 'epoch': 30.0}


100%|██████████| 5160/5160 [35:25<00:00,  2.43it/s]

{'train_runtime': 2125.9899, 'train_samples_per_second': 77.554, 'train_steps_per_second': 2.427, 'train_loss': 1.0503149904946023, 'epoch': 30.0}





TrainOutput(global_step=5160, training_loss=1.0503149904946023, metrics={'train_runtime': 2125.9899, 'train_samples_per_second': 77.554, 'train_steps_per_second': 2.427, 'total_flos': 4.338214031499264e+16, 'train_loss': 1.0503149904946023, 'epoch': 30.0})

In [12]:
import json
from pathlib import Path


def find_best_checkpoint(output_dir, metric='eval_accuracy'):
    """Return the path of the checkpoint under ``output_dir`` with the highest ``metric``.

    The evaluation history is read from ``trainer_state.json`` in the newest
    ``checkpoint-<step>`` directory. The Trainer's saved state carries the
    ``log_history`` of all evaluations so far, so the choice follows the actual
    training run instead of a hard-coded step number.

    Raises:
        FileNotFoundError: if there are no ``checkpoint-<step>`` directories, or the
            best step was not checkpointed.
        ValueError: if no logged evaluation reports ``metric``.
    """
    root = Path(output_dir)
    checkpoints = [p for p in root.glob('checkpoint-*') if p.name.rsplit('-', 1)[-1].isdigit()]
    if not checkpoints:
        raise FileNotFoundError(f"no checkpoint-<step> directories found in {root}")
    latest = max(checkpoints, key=lambda p: int(p.name.rsplit('-', 1)[-1]))

    log_history = json.loads((latest / 'trainer_state.json').read_text())['log_history']
    evaluations = [entry for entry in log_history if metric in entry]
    if not evaluations:
        raise ValueError(f"no logged evaluation in {latest} reports {metric!r}")

    best = max(evaluations, key=lambda entry: entry[metric])
    best_path = root / f"checkpoint-{best['step']}"
    if not best_path.exists():
        raise FileNotFoundError(f"best step {best['step']} has no saved checkpoint at {best_path}")
    return str(best_path)


# Reload the checkpoint with the best eval accuracy
# (checkpoint-4816, i.e. epoch 28, in the recorded run).
best_checkpoint = find_best_checkpoint(training_args.output_dir)
model = AutoModelForSequenceClassification.from_pretrained(best_checkpoint, num_labels=num_labels, id2label=id2label, label2id=label2id)

# Evaluate the reloaded model on the held-out split.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=None,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.evaluate()

100%|██████████| 43/43 [00:11<00:00,  3.61it/s]


{'eval_loss': 1.0233200788497925,
 'eval_accuracy': 0.5029112081513828,
 'eval_f1': 0.49280351950455037,
 'eval_precision': 0.5022752603866406,
 'eval_recall': 0.4993200260468844,
 'eval_runtime': 12.2102,
 'eval_samples_per_second': 112.529,
 'eval_steps_per_second': 3.522}