In [None]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go

import tensorflow as tf
import warnings

# Silence every Python warning so the notebook output stays compact.
# NOTE(review): this also hides library deprecation / weight-initialisation
# warnings emitted later (transformers, datasets); consider narrowing the
# filter (e.g. by `category=`) so those messages stay visible.
warnings.filterwarnings('ignore')

In [None]:
# TensorFlow GPU setup: enable on-demand memory growth and pin TF to the first GPU.
# NOTE(review): training below uses PyTorch/transformers and never reads
# `strategy`; this cell mainly stops TensorFlow from reserving all GPU memory.
gpus = tf.config.experimental.list_physical_devices('GPU')
strategy = None
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
    except RuntimeError as e:
        # Device settings can only be changed before the GPUs are initialised
        # (e.g. when this cell is re-run); fall back to the default strategy below.
        print(e)

# Guarantee `strategy` is always defined, whichever branch ran above.
if strategy is None:
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

In [None]:
import os

# Never hardcode credentials in a notebook. Training uses report_to="none", so
# Weights & Biases needs no API key here; if you re-enable it, provide
# WANDB_API_KEY through the environment (e.g. Kaggle Secrets) instead of source.
os.environ["WANDB_DISABLED"] = "true"

In [None]:
# Load the competition's labelled train split and unlabelled test split.
train, test = (
    pd.read_csv(f"../input/contradictory-my-dear-watson/{split}.csv")
    for split in ("train", "test")
)

In [None]:
# First rows of each split, rendered one after the other.
display(train.head(), test.head())

In [None]:
# DataFrame.info() prints its summary itself and returns None, so wrapping it
# in display() only added a stray "None" to the output.
train.info()

In [None]:
# DataFrame.info() prints its summary itself and returns None, so wrapping it
# in display() only added a stray "None" to the output.
test.info()

In [None]:
import plotly.express as px

# Share of each language among the training pairs.
labels, frequencies = np.unique(train['language'].to_numpy(), return_counts=True)
fig = px.pie(
    names=labels,
    values=frequencies,
    title='train: languages distribution',
)
fig.show()

In [None]:
import plotly.express as px

# Share of each language among the test pairs.
labels, frequencies = np.unique(test['language'].to_numpy(), return_counts=True)
fig = px.pie(
    names=labels,
    values=frequencies,
    title='test: languages distribution',
)
fig.show()

In [None]:
import plotly.express as px

# Share of each label id among the training pairs.
labels, frequencies = np.unique(train['label'].to_numpy(), return_counts=True)
fig = px.pie(
    names=labels,
    values=frequencies,
    title='train: label distribution',
)
fig.show()

In [None]:
# Bar chart of how many training pairs carry each label id.
ax = sns.countplot(x=train['label'])
# Titled and labelled so the figure stands alone; `;` hides the return-value repr.
ax.set(title='train: label counts', xlabel='label', ylabel='count');


In [None]:
# Number of training pairs per label id, ordered by id rather than by frequency.
label_count = (
    train['label']
    .value_counts()
    .sort_index()
)
label_count

In [None]:
# Names for label ids 0, 1 and 2, in that order.
label_names = ['entailment', 'neutral', 'contradiction']
# Rename by id instead of overwriting the index positionally: this still works
# if a label id is absent and gives the same result when the cell is re-run.
label_count = label_count.rename(index=dict(enumerate(label_names)))
label_count

In [None]:
!pip install evaluate 


In [None]:
import evaluate 
import torch

from transformers import AdamW, AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, TrainingArguments, Trainer

In [None]:
# Multilingual XLM-R checkpoint; the name suggests it was already tuned on NLI data.
model_name = "symanto/xlm-roberta-base-snli-mnli-anli-xnli"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

In [None]:
# Language metadata is not fed to the model (only premise/hypothesis are tokenized).
# errors='ignore' makes the cell safe to re-run: it no longer raises KeyError
# once the columns have already been dropped.
train = train.drop(columns=['language', 'lang_abv'], errors='ignore')
test = test.drop(columns=['language', 'lang_abv'], errors='ignore')


In [None]:
from datasets import Dataset, DatasetDict

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the labelled pairs for validation, keeping the label
# proportions the same in both parts.
train_df, val_df = train_test_split(
    train, test_size=0.2, random_state=42, stratify=train['label']
)

# preserve_index=False keeps the shuffled pandas index from being stored as an
# extra `__index_level_0__` column in the resulting Datasets.
train_ds = Dataset.from_pandas(train_df, preserve_index=False)
val_ds = Dataset.from_pandas(val_df, preserve_index=False)
test_ds = Dataset.from_pandas(test, preserve_index=False)

ds = DatasetDict(train=train_ds, validation=val_ds, test=test_ds)

In [None]:
def tokenizer_sentence(data):
    """Tokenize a batch of NLI examples as (premise, hypothesis) sentence pairs.

    `data` is a batch passed by `map(..., batched=True)`: a mapping from column
    name to a list of values. Over-long pairs are truncated.
    """
    premises = data['premise']
    hypotheses = data['hypothesis']
    return tokenizer(premises, hypotheses, truncation=True)

In [None]:
# Tokenize every split (train / validation / test) batch by batch.
tokenized_ds = ds.map(function=tokenizer_sentence, batched=True)

In [None]:
# Pads each batch to the length of its longest sequence at collation time.
data_collator = DataCollatorWithPadding(tokenizer)


In [None]:
import torch.nn as nn
from transformers import XLMRobertaModel

class CustomXLMRobertaModel(nn.Module):
    """XLM-RoBERTa encoder followed by a small MLP classification head.

    Args:
        num_labels: number of output classes.
        model_name: Hugging Face checkpoint the encoder is loaded from.
        dropout: dropout probability applied before and inside the head.

    `forward` always returns a dict containing "logits" (plus "loss" when labels
    are given), which is the output format transformers.Trainer unpacks for
    training, evaluation and prediction alike.
    """

    def __init__(self, num_labels,
                 model_name='symanto/xlm-roberta-base-snli-mnli-anli-xnli',
                 dropout=0.2):
        super().__init__()
        self.roberta = XLMRobertaModel.from_pretrained(model_name)
        # Read the encoder width from its config instead of hardcoding 768
        # (the base-size value), so other checkpoint sizes also work.
        hidden_size = self.roberta.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, num_labels),
        )
        self.loss = nn.CrossEntropyLoss()
        self.num_labels = num_labels

    def forward(self, input_ids, attention_mask, labels=None):
        """Classify token-id batches; add the cross-entropy loss when labels are given.

        Returns:
            {"logits": ...} without labels, {"loss": ..., "logits": ...} with them.
            A dict is returned in both cases: Trainer unwraps length-1 outputs,
            so a bare logits tensor from a single-example batch would be squeezed.
        """
        output = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # NOTE(review): if the checkpoint was saved from a sequence-classification
        # model it may lack pooler weights, making the pooler freshly initialised
        # (the load warning is hidden by the global warnings filter) — confirm.
        pooled = self.dropout(output.pooler_output)
        logits = self.classifier(pooled)

        if labels is None:
            return {"logits": logits}
        loss = self.loss(logits.view(-1, self.num_labels), labels.view(-1))
        return {"loss": loss, "logits": logits}

In [None]:
# Seed before building the model so the randomly initialised head is
# reproducible; Trainer applies its own seed only when it is constructed later.
torch.manual_seed(42)
model = CustomXLMRobertaModel(num_labels=3)  # entailment / neutral / contradiction

In [None]:
from sklearn.metrics import accuracy_score, f1_score

# NOTE(review): "/content" is a Colab path; on Kaggle, /kaggle/working is the
# persisted output directory. Newer transformers releases rename
# `evaluation_strategy` to `eval_strategy` — adjust to the installed version.
training_args = TrainingArguments("/content",
                                  optim="adamw_torch",
                                  num_train_epochs=5,
                                  evaluation_strategy="epoch",
                                  logging_dir='./logs',
                                  logging_steps=10,
                                  report_to="none")


def compute_metrics(eval_preds):
    """Metrics reported by Trainer after each evaluation.

    Args:
        eval_preds: (logits, labels) pair produced by Trainer's evaluation loop.

    Returns:
        dict of scalar metrics:
        - 'accuracy'
        - 'f1': micro-F1, kept for continuity; it equals accuracy for
          single-label multiclass data
        - 'f1_macro': per-class F1 averaged, so minority classes count equally
    """
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    # sklearn returns plain scalars. The removed `datasets.load_metric` route
    # returned a nested {'f1': ...} dict under the 'f1' key.
    return {
        'accuracy': accuracy_score(labels, predictions),
        'f1': f1_score(labels, predictions, average='micro'),
        'f1_macro': f1_score(labels, predictions, average='macro'),
    }

In [None]:
from transformers import Trainer

# Wire together model, arguments, data splits, collator and metrics.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds['train'],
    eval_dataset=tokenized_ds['validation'],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Keep W&B disabled, consistent with report_to="none" in TrainingArguments.
# Trainer reads its reporting setup when it is constructed, so environment
# changes made after the Trainer cell have no effect on the current run.
os.environ["WANDB_DISABLED"] = "true"

In [None]:
# Fine-tune; with evaluation_strategy="epoch" the validation split is scored each epoch.
train_result = trainer.train()
train_result

In [None]:
# Run inference on the unlabelled test split.
predictions = trainer.predict(test_dataset=tokenized_ds["test"])
predictions

In [None]:
# Convert raw logits into per-class probabilities (each row sums to 1).
logits = torch.from_numpy(predictions.predictions)
probs = logits.softmax(dim=-1).tolist()
probs[:5]

In [None]:
# Pair each test id with its most probable label id (first max on ties).
# Fetch the id column once: indexing ds['test']['id'] inside the loop rebuilt
# the whole column for every row, making this quadratic in the test size.
test_ids = list(ds['test']['id'])
outputs = [
    (element_id, prob.index(max(prob)))
    for element_id, prob in zip(test_ids, probs)
]

In [None]:
# Competition's sample submission, shown for reference.
submission = pd.read_csv(
    "/kaggle/input/contradictory-my-dear-watson/sample_submission.csv"
)
submission

In [None]:
# One row per test pair: its id and the predicted label id.
outputs = pd.DataFrame(
    data=outputs,
    columns=['id', 'prediction'],
)
outputs

In [None]:
# Write the submission file without the pandas index column.
outputs.to_csv(path_or_buf='submission.csv', index=False)