In [1]:
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.tree import plot_tree

In [2]:
raw_data      = pd.read_csv('../data/sentence-relations/train.csv', index_col='id')
raw_submissions = pd.read_csv('../data/sentence-relations/test.csv', index_col='id')

# Remove chinese
# raw_data = raw_data[raw_data['lang_abv'] != 'zh']
# raw_data = raw_data[raw_data['lang_abv'] != 'th']

# Only take english
raw_data = raw_data[raw_data['lang_abv'] == 'en']

training_data, test_data = train_test_split(raw_data, test_size=0.2, random_state=42)

training_data.head()

Unnamed: 0_level_0,premise,hypothesis,lang_abv,language,label
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
b6e01c1a07,"Also, the Holy Family are said to have shelter...",The Holy family spent a total of three days here.,en,English,1
b8fa1a0044,Participants generally viewed the new internal...,Those organizations affected by the Sarbanes-O...,en,English,0
c828f51ef6,With a little practice almost anyone can flip ...,Practicing lets you do anything you put your m...,en,English,1
b2c98d5a99,More reserved and remote but a better administ...,The uncle had no match in administration; cert...,en,English,2
9cd35fee05,The company later told us that it had disconti...,The company later told us that it had enhanced...,en,English,2


In [3]:
from transformers import RobertaTokenizer, RobertaModel
from transformers import DataCollatorWithPadding

roberta_tokenizer = RobertaTokenizer.from_pretrained('FacebookAI/roberta-base')
data_collator = DataCollatorWithPadding(tokenizer=roberta_tokenizer)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from torch.utils.data import Dataset, DataLoader

class SentenceRelationTransformerDataset(Dataset):

    def __init__(self, data, tokenizer):
        self.data               = data
        self.sentence_relations = []
        self.labels             = []
        self.tokenizer          = tokenizer

        self.perform_preprocessing()

    def __len__(self):
        return len(self.data)

    def get_max_input_length(self):

        max_length = 0

        for sentence_pair in self.sentence_relations:
            max_length = max(max_length, len(sentence_pair[0]), len(sentence_pair[1]))

        return max_length

    def perform_preprocessing(self):
        self.sentence_relations = []
        self.labels             = []

        for i in range(len(self.data)):
            premise    = self.data.iloc[i]['premise']
            hypothesis = self.data.iloc[i]['hypothesis']

            input_string = f"Premise: {premise} | Hypothesis: {hypothesis}"

            input_tokenized = self.tokenizer(input_string, padding='max_length', max_length=512, truncation=True)

            label = self.data.iloc[i]['label']

            self.labels.append(label)
            self.sentence_relations.append(input_tokenized)

    def __getitem__(self, idx):
        temp_dict = self.sentence_relations[idx]
        temp_dict['label'] = self.labels[idx]

        return temp_dict


train_dataset = SentenceRelationTransformerDataset(training_data, roberta_tokenizer)
test_dataset  = SentenceRelationTransformerDataset(test_data, roberta_tokenizer)

In [10]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

id2label = {0: "entailment", 1: "neutral", 2: "contradiction"}
label2id = {"entailment": 0, "neutral": 1, "contradiction": 2}
num_labels = len(id2label)

model = AutoModelForSequenceClassification.from_pretrained('FacebookAI/roberta-base', num_labels=num_labels, id2label=id2label, label2id=label2id)

# Freeze the model except
# modfication_layers = ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']

# for name, param in model.named_parameters():
#     if not any(layer in name for layer in modfication_layers):
#         param.requires_grad = False

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

training_args = TrainingArguments(
    output_dir="roberta-base-sentence-relation",
    learning_rate=1e-3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=5,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    fp16=True
)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="macro")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}

In [12]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=None,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

  3%|▎         | 766/27480 [01:31<53:17,  8.35it/s]
                                                  
  7%|▋         | 501/6870 [01:22<17:44,  5.98it/s]

{'loss': 1.1838, 'grad_norm': 3.1822421550750732, 'learning_rate': 0.0009275109170305677, 'epoch': 0.36}


                                                   
 15%|█▍        | 1001/6870 [02:45<16:23,  5.97it/s]

{'loss': 1.1409, 'grad_norm': 3.164424419403076, 'learning_rate': 0.0008547307132459971, 'epoch': 0.73}


 20%|██        | 1374/6870 [03:47<15:04,  6.07it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                   
[A                                              

 20%|██        | 1374/6870 [04:02<15:04,  6.07it/s]
[A
[A

{'eval_loss': 1.1158868074417114, 'eval_accuracy': 0.3158660844250364, 'eval_f1': 0.16002949852507375, 'eval_precision': 0.10528869480834546, 'eval_recall': 0.3333333333333333, 'eval_runtime': 14.752, 'eval_samples_per_second': 93.14, 'eval_steps_per_second': 23.319, 'epoch': 1.0}


                                                     
 22%|██▏       | 1501/6870 [04:24<15:00,  5.96it/s]

{'loss': 1.14, 'grad_norm': 1.4449799060821533, 'learning_rate': 0.0007819505094614265, 'epoch': 1.09}


                                                   
 29%|██▉       | 2001/6870 [05:47<13:33,  5.98it/s]

{'loss': 1.1373, 'grad_norm': 2.5436692237854004, 'learning_rate': 0.000709170305676856, 'epoch': 1.46}


                                                   
 36%|███▋      | 2501/6870 [07:10<12:07,  6.01it/s]

{'loss': 1.12, 'grad_norm': 1.5093458890914917, 'learning_rate': 0.0006363901018922853, 'epoch': 1.82}


 40%|████      | 2748/6870 [07:51<11:21,  6.05it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                   

[A[A                                           
 40%|████      | 2748/6870 [08:05<11:21,  6.05it/s]
[A
[A

{'eval_loss': 1.110573649406433, 'eval_accuracy': 0.33478893740902477, 'eval_f1': 0.16721192293711376, 'eval_precision': 0.11159631246967493, 'eval_recall': 0.3333333333333333, 'eval_runtime': 14.7196, 'eval_samples_per_second': 93.345, 'eval_steps_per_second': 23.37, 'epoch': 2.0}


                                                     
 44%|████▎     | 3001/6870 [08:49<10:47,  5.98it/s]

{'loss': 1.1194, 'grad_norm': 3.369414806365967, 'learning_rate': 0.0005636098981077147, 'epoch': 2.18}


                                                   
 51%|█████     | 3501/6870 [10:12<09:24,  5.97it/s]

{'loss': 1.115, 'grad_norm': 1.584072470664978, 'learning_rate': 0.0004908296943231441, 'epoch': 2.55}


                                                   
 58%|█████▊    | 4001/6870 [11:35<08:01,  5.96it/s]

{'loss': 1.1092, 'grad_norm': 0.5623304843902588, 'learning_rate': 0.0004180494905385735, 'epoch': 2.91}


 60%|██████    | 4122/6870 [11:55<07:35,  6.03it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                   

[A[A                                           
 60%|██████    | 4122/6870 [12:10<07:35,  6.03it/s]
[A
[A

{'eval_loss': 1.1004438400268555, 'eval_accuracy': 0.3158660844250364, 'eval_f1': 0.16002949852507375, 'eval_precision': 0.10528869480834546, 'eval_recall': 0.3333333333333333, 'eval_runtime': 14.8586, 'eval_samples_per_second': 92.472, 'eval_steps_per_second': 23.152, 'epoch': 3.0}


                                                     
 66%|██████▌   | 4501/6870 [13:14<06:38,  5.95it/s]

{'loss': 1.1091, 'grad_norm': 2.0627851486206055, 'learning_rate': 0.00034541484716157206, 'epoch': 3.28}


                                                   
 73%|███████▎  | 5001/6870 [14:37<05:11,  5.99it/s]

{'loss': 1.1047, 'grad_norm': 0.9143522381782532, 'learning_rate': 0.0002726346433770014, 'epoch': 3.64}


 80%|████████  | 5496/6870 [15:59<03:48,  6.03it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                   
[A                                              

 80%|████████  | 5496/6870 [16:14<03:48,  6.03it/s]
[A
[A

{'eval_loss': 1.1059762239456177, 'eval_accuracy': 0.33478893740902477, 'eval_f1': 0.16721192293711376, 'eval_precision': 0.11159631246967493, 'eval_recall': 0.3333333333333333, 'eval_runtime': 14.8269, 'eval_samples_per_second': 92.669, 'eval_steps_per_second': 23.201, 'epoch': 4.0}


                                                     
 80%|████████  | 5501/6870 [16:17<30:59,  1.36s/it]

{'loss': 1.1009, 'grad_norm': 1.6319918632507324, 'learning_rate': 0.00019985443959243086, 'epoch': 4.0}


                                                   
 87%|████████▋ | 6001/6870 [17:40<02:25,  5.97it/s]

{'loss': 1.1032, 'grad_norm': 0.5795202851295471, 'learning_rate': 0.00012707423580786027, 'epoch': 4.37}


                                                   
 95%|█████████▍| 6501/6870 [19:03<01:01,  5.96it/s]

{'loss': 1.1012, 'grad_norm': 0.6619698405265808, 'learning_rate': 5.429403202328967e-05, 'epoch': 4.73}


100%|██████████| 6870/6870 [20:04<00:00,  6.01it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                   
[A                                              

100%|██████████| 6870/6870 [20:19<00:00,  6.01it/s]
[A
[A

{'eval_loss': 1.099313735961914, 'eval_accuracy': 0.33478893740902477, 'eval_f1': 0.16721192293711376, 'eval_precision': 0.11159631246967493, 'eval_recall': 0.3333333333333333, 'eval_runtime': 14.7994, 'eval_samples_per_second': 92.842, 'eval_steps_per_second': 23.244, 'epoch': 5.0}


                                                   
100%|██████████| 6870/6870 [20:21<00:00,  5.63it/s]

{'train_runtime': 1221.2733, 'train_samples_per_second': 22.501, 'train_steps_per_second': 5.625, 'train_loss': 1.120514235850505, 'epoch': 5.0}





TrainOutput(global_step=6870, training_loss=1.120514235850505, metrics={'train_runtime': 1221.2733, 'train_samples_per_second': 22.501, 'train_steps_per_second': 5.625, 'total_flos': 7230356719165440.0, 'train_loss': 1.120514235850505, 'epoch': 5.0})

In [None]:
# Get the best model
model = AutoModelForSequenceClassification.from_pretrained('roberta-base-sentence-relation/checkpoint-4816', num_labels=num_labels, id2label=id2label, label2id=label2id)

# Evaluate the model
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=None,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.evaluate()

100%|██████████| 43/43 [00:11<00:00,  3.61it/s]


{'eval_loss': 1.0233200788497925,
 'eval_accuracy': 0.5029112081513828,
 'eval_f1': 0.49280351950455037,
 'eval_precision': 0.5022752603866406,
 'eval_recall': 0.4993200260468844,
 'eval_runtime': 12.2102,
 'eval_samples_per_second': 112.529,
 'eval_steps_per_second': 3.522}