In [None]:
import datasets
import torch
import numpy as np
from sklearn.metrics import accuracy_score
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments, DataCollatorWithPadding, Trainer
from transformers import AutoConfig

In [None]:
# Split name -> CSV path, consumed by `datasets.load_dataset` below.
data_files = dict(
    train='data/pnli_train.csv',
    validation='data/pnli_dev.csv',
)

In [None]:
# Load the CSV files listed in `data_files`, naming the three columns explicitly.
raw_datasets = datasets.load_dataset(
    'csv',
    data_files=data_files,
    column_names=['sentence1', 'sentence2', 'labels'],
)

In [None]:
# `checkpoint` is the model identifier handed to `from_pretrained`;
# `checkpoint_folder` is the variant used when building output paths.
#
# Disabled alternative:
#   checkpoint        = 'sentence-transformers/all-MiniLM-L12-v2'
#   checkpoint_folder = 'sentence-transformers-all-MiniLM-L12-v2'
checkpoint, checkpoint_folder = (
    'facebook/bart-large-mnli',
    'facebook-bart-large-mnli',
)

In [None]:
# Tokenizer loaded from the same identifier as the model.
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint,
)

In [None]:
# Collator built around the tokenizer; handed to the Trainer further down.
data_collator = DataCollatorWithPadding(tokenizer)

In [None]:
def tokenize_function(example):
    """Tokenize sentence pairs with the module-level ``tokenizer``.

    Args:
        example: Mapping that provides ``"sentence1"`` and ``"sentence2"``
            entries (this function is applied via ``map(..., batched=True)``).

    Returns:
        Whatever ``tokenizer`` returns for the pair, with ``truncation=True``.
    """
    first, second = example["sentence1"], example["sentence2"]
    return tokenizer(first, second, truncation=True)

In [None]:
# Apply the pair tokenizer to every split, in batches.
tokenized_datasets = raw_datasets.map(
    tokenize_function,
    batched=True,
)

In [None]:
# Training configuration: evaluate every 500 steps and, at the end, reload the
# checkpoint that scored best on the 'accuracy' metric.
# NOTE(review): the output directory has no separator between the folder name
# and the literal 'checkpoint' (e.g. '...-mnlicheckpoint') — confirm intended.
training_args = TrainingArguments(
    output_dir=f'checkpoints/{checkpoint_folder}checkpoint',
    num_train_epochs=10,
    warmup_steps=200,
    evaluation_strategy='steps',
    eval_steps=500,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
)

In [None]:
# Load the pretrained classifier with a 2-label head; `ignore_mismatched_sizes`
# lets loading proceed when checkpoint weight shapes differ from that head.
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=2,
    ignore_mismatched_sizes=True,
)

In [None]:
def compute_metrics(eval_preds):
    """Compute classification accuracy for one evaluation pass.

    Args:
        eval_preds: A ``(predictions, labels)`` pair. ``predictions`` is
            either a single array of logits with classes along axis 1, or a
            tuple/list of arrays whose first element is that logits array.

    Returns:
        dict: ``{'accuracy': <fraction of correctly predicted labels>}``.
    """
    logits, labels = eval_preds
    # Some models return extra outputs next to the logits, in which case the
    # predictions arrive as a tuple with the logits first. Only unwrap in that
    # case: indexing a plain logits array with [0] would select its first row.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    predictions = np.argmax(logits, axis=1)
    return {'accuracy': accuracy_score(labels, predictions)}

In [None]:
# Wire model, arguments, data splits, collator and metric into one Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Run training with the Trainer configured above; the call's return value is
# the cell's displayed output.
trainer.train()

In [None]:
# Score the current model on the validation split.
trainer.evaluate(eval_dataset=tokenized_datasets["validation"])

In [None]:
# Persist the trainer's current model under models/<checkpoint_folder>.
trainer.save_model(output_dir=f'models/{checkpoint_folder}')

In [None]:
# Load the test CSV using the same three column names as the training data.
# NOTE(review): the file name says "unlabeled" — confirm how the 'labels'
# column is populated for this file; it is dropped again before prediction.
raw_test_datasets = datasets.load_dataset(
    'csv',
    data_files={'test': 'data/pnli_test_unlabeled.csv'},
    column_names=['sentence1', 'sentence2', 'labels'],
)

In [None]:
# Tokenize the test split with the same function used for train/validation.
tokenized_test_datasets = raw_test_datasets.map(
    tokenize_function,
    batched=True,
)

In [None]:
# Drop the raw text and label columns from the test split.
# Not re-runnable as-is: a second run fails because the columns are gone.
tokenized_test_datasets['test'] = (
    tokenized_test_datasets['test']
    .remove_columns(['sentence1', 'sentence2', 'labels'])
)

In [None]:
# Switch the dataset's output format to 'torch'.
tokenized_test_datasets.set_format('torch')

In [None]:
import gc

# Run a garbage-collection pass, then clear PyTorch's CUDA cache, before
# running prediction.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Run inference on the test split; only the first element of the returned
# triple (the predictions) is kept.
predictions, _, _ = trainer.predict(
    test_dataset=tokenized_test_datasets['test'],
)

In [None]:
# `trainer.predict` can hand back a tuple of arrays (logits first) when the
# model returns extra outputs; argmax over that whole tuple would fail or be
# meaningless, so pick out the logits before reducing over the class axis.
logits = predictions[0] if isinstance(predictions, (tuple, list)) else predictions
# Replace the raw scores with the predicted class index for each example.
predictions = np.argmax(logits, axis=1)

In [None]:
# Sanity check: display how many predictions were produced.
len(predictions)