In [1]:
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.tree import plot_tree

In [2]:
raw_data      = pd.read_csv('../data/sentence-relations/train.csv', index_col='id')
raw_submissions = pd.read_csv('../data/sentence-relations/test.csv', index_col='id')

# Remove chinese
# raw_data = raw_data[raw_data['lang_abv'] != 'zh']
# raw_data = raw_data[raw_data['lang_abv'] != 'th']

# Only take english
raw_data = raw_data[raw_data['lang_abv'] == 'en']

training_data, test_data = train_test_split(raw_data, test_size=0.2, random_state=42)

training_data.head()

Unnamed: 0_level_0,premise,hypothesis,lang_abv,language,label
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
b6e01c1a07,"Also, the Holy Family are said to have shelter...",The Holy family spent a total of three days here.,en,English,1
b8fa1a0044,Participants generally viewed the new internal...,Those organizations affected by the Sarbanes-O...,en,English,0
c828f51ef6,With a little practice almost anyone can flip ...,Practicing lets you do anything you put your m...,en,English,1
b2c98d5a99,More reserved and remote but a better administ...,The uncle had no match in administration; cert...,en,English,2
9cd35fee05,The company later told us that it had disconti...,The company later told us that it had enhanced...,en,English,2


In [3]:
from transformers import RobertaTokenizer, RobertaModel
from transformers import DataCollatorWithPadding

roberta_tokenizer = RobertaTokenizer.from_pretrained('FacebookAI/roberta-base')
data_collator = DataCollatorWithPadding(tokenizer=roberta_tokenizer)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from torch.utils.data import Dataset, DataLoader

class SentenceRelationTransformerDataset(Dataset):

    def __init__(self, data, tokenizer):
        self.data               = data
        self.sentence_relations = []
        self.labels             = []
        self.tokenizer          = tokenizer

        self.perform_preprocessing()

    def __len__(self):
        return len(self.data)

    def get_max_input_length(self):

        max_length = 0

        for sentence_pair in self.sentence_relations:
            max_length = max(max_length, len(sentence_pair[0]), len(sentence_pair[1]))

        return max_length

    def perform_preprocessing(self):
        self.sentence_relations = []
        self.labels             = []

        for i in range(len(self.data)):
            premise    = self.data.iloc[i]['premise']
            hypothesis = self.data.iloc[i]['hypothesis']

            input_string = f"Premise: {premise} | Hypothesis: {hypothesis}"

            input_tokenized = self.tokenizer(input_string, padding='max_length', max_length=512, truncation=True)

            label = self.data.iloc[i]['label']

            self.labels.append(label)
            self.sentence_relations.append(input_tokenized)

    def __getitem__(self, idx):
        temp_dict = self.sentence_relations[idx]
        temp_dict['label'] = self.labels[idx]

        return temp_dict


train_dataset = SentenceRelationTransformerDataset(training_data, roberta_tokenizer)
test_dataset  = SentenceRelationTransformerDataset(test_data, roberta_tokenizer)

In [10]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

id2label = {0: "entailment", 1: "neutral", 2: "contradiction"}
label2id = {"entailment": 0, "neutral": 1, "contradiction": 2}
num_labels = len(id2label)

model = AutoModelForSequenceClassification.from_pretrained('FacebookAI/roberta-base', num_labels=num_labels, id2label=id2label, label2id=label2id)

# Freeze the model except
# modfication_layers = ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']

# for name, param in model.named_parameters():
#     if not any(layer in name for layer in modfication_layers):
#         param.requires_grad = False

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

training_args = TrainingArguments(
    output_dir="roberta-base-sentence-relation",
    learning_rate=1e-3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=5,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    fp16=True
)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="macro")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}

In [9]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=None,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

  3%|▎         | 766/27480 [01:20<38:32, 11.55it/s]

KeyboardInterrupt: 

In [None]:
# Get the best model
model = AutoModelForSequenceClassification.from_pretrained('roberta-base-sentence-relation/checkpoint-4816', num_labels=num_labels, id2label=id2label, label2id=label2id)

# Evaluate the model
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=None,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.evaluate()

100%|██████████| 43/43 [00:11<00:00,  3.61it/s]


{'eval_loss': 1.0233200788497925,
 'eval_accuracy': 0.5029112081513828,
 'eval_f1': 0.49280351950455037,
 'eval_precision': 0.5022752603866406,
 'eval_recall': 0.4993200260468844,
 'eval_runtime': 12.2102,
 'eval_samples_per_second': 112.529,
 'eval_steps_per_second': 3.522}