In [48]:
import pandas as pd
import numpy as np

import torch

from transformers import RobertaTokenizer, RobertaForSequenceClassification

In [49]:
# Load the training CSV into a DataFrame and preview its first five rows.
df = pd.read_csv(filepath_or_buffer='../data/esnli_train_1.csv')
df.head(n=5)

Unnamed: 0,pairID,gold_label,Sentence1,Sentence2,Explanation_1,WorkerId,Sentence1_marked_1,Sentence2_marked_1,Sentence1_Highlighted_1,Sentence2_Highlighted_1
0,3416050480.jpg#4r1n,neutral,A person on a horse jumps over a broken down a...,A person is training his horse for a competition.,the person is not necessarily training his horse,AF0PI3RISB5Q7,A person on a horse jumps over a broken down a...,A person is *training* *his* *horse* for a co...,{},345
1,3416050480.jpg#4r1c,contradiction,A person on a horse jumps over a broken down a...,"A person is at a diner, ordering an omelette.",One cannot be on a jumping horse cannot be a d...,A36ZT2WFIA2HMF,A person *on* *a* *horse* *jumps* over a brok...,"A person *is* *at* *a* *diner,* *ordering* an...",4235,25436
2,3416050480.jpg#4r1e,entailment,A person on a horse jumps over a broken down a...,"A person is outdoors, on a horse.",a broken down airplane is outdoors,A2GK75ZQTX2RDZ,A person on a horse jumps over *a* *broken* *...,"A person is *outdoors,* on a horse.",89107,3
3,2267923837.jpg#2r1n,neutral,Children smiling and waving at camera,They are smiling at their parents,Just because they are smiling and waving at a ...,A18TOIDG32QICP,Children smiling and waving at camera,They are smiling *at* *their* *parents*,{},534
4,2267923837.jpg#2r1e,entailment,Children smiling and waving at camera,There are children present,The children must be present to see them smili...,AEX0YE6TUZRHT,*Children* *smiling* *and* *waving* at camera,There are children *present*,0132,3


In [50]:
# Discard the worker-id and highlight columns, then give the sentence and
# explanation columns the names used in the rest of the notebook.
df_cleaned = (
    df
    .drop(columns=["WorkerId", "Sentence1_Highlighted_1", "Sentence2_Highlighted_1"])
    .rename(
        columns={
            'Sentence1': 'premise',
            'Sentence2': 'hypothesis',
            'Explanation_1': 'explanation',
        }
    )
)
df_cleaned

Unnamed: 0,pairID,gold_label,premise,hypothesis,explanation,Sentence1_marked_1,Sentence2_marked_1
0,3416050480.jpg#4r1n,neutral,A person on a horse jumps over a broken down a...,A person is training his horse for a competition.,the person is not necessarily training his horse,A person on a horse jumps over a broken down a...,A person is *training* *his* *horse* for a co...
1,3416050480.jpg#4r1c,contradiction,A person on a horse jumps over a broken down a...,"A person is at a diner, ordering an omelette.",One cannot be on a jumping horse cannot be a d...,A person *on* *a* *horse* *jumps* over a brok...,"A person *is* *at* *a* *diner,* *ordering* an..."
2,3416050480.jpg#4r1e,entailment,A person on a horse jumps over a broken down a...,"A person is outdoors, on a horse.",a broken down airplane is outdoors,A person on a horse jumps over *a* *broken* *...,"A person is *outdoors,* on a horse."
3,2267923837.jpg#2r1n,neutral,Children smiling and waving at camera,They are smiling at their parents,Just because they are smiling and waving at a ...,Children smiling and waving at camera,They are smiling *at* *their* *parents*
4,2267923837.jpg#2r1e,entailment,Children smiling and waving at camera,There are children present,The children must be present to see them smili...,*Children* *smiling* *and* *waving* at camera,There are children *present*
...,...,...,...,...,...,...,...
259994,21138719.jpg#0r4e,entailment,Man in black pants and vest balances between t...,There is a man holding fire.,Fire can be held on torches.,Man in black pants and vest balances between ...,There is a man holding *fire.*
259995,21138719.jpg#0r1c,contradiction,Man in black pants and vest balances between t...,man burning down a church,If he is holding flaming torches then he would...,Man in black pants and vest balances between ...,man *burning* *down* a *church*
259996,21138719.jpg#0r5c,contradiction,Man in black pants and vest balances between t...,man runing after being set on fire.,A man will not be holding torches if he was se...,Man in black pants and vest balances between ...,man runing after being *set* on *fire.*
259997,21138719.jpg#0r2e,entailment,Man in black pants and vest balances between t...,A man holds two lit torches.,The torches are apart of a show.,Man in black pants and vest balances between ...,A man holds two lit *torches.*


In [51]:
# Class name -> integer id, and the inverse lookup (integer id -> class name).
label_to_id = dict(entailment=0, neutral=1, contradiction=2)
id_to_label = dict(zip(label_to_id.values(), label_to_id.keys()))

Concatenate the premise, hypothesis, and explanation into a single model input

In [57]:
# Load the pretrained tokenizer and encoder. The classification head is newly
# initialised, so it must be sized for the three NLI classes: with the library
# default of two output labels, "contradiction" (id 2) could never be predicted.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForSequenceClassification.from_pretrained(
    'roberta-base',
    num_labels=len(label_to_id),
    id2label=id_to_label,
    label2id=label_to_id,
)

# Use the first training example as a smoke test.
premise = df_cleaned['premise'][0]
hypothesis = df_cleaned['hypothesis'][0]
explanation = df_cleaned['explanation'][0]
actual_label = df_cleaned['gold_label'][0]

# Tokenize input.
# The tokenizer takes at most two text segments (text, text_pair); a third
# positional argument is interpreted as `add_special_tokens`, which silently
# drops the explanation. Join hypothesis and explanation with the tokenizer's
# separator token so that all three texts actually reach the model.
hypothesis_with_explanation = f"{hypothesis} {tokenizer.sep_token} {explanation}"
encoded_input = tokenizer(
    premise,
    text_pair=hypothesis_with_explanation,
    padding=True,
    truncation=True,
    return_tensors='pt',
)

# Integer id of this example's gold label; looked up directly instead of
# converting the whole label column just to read its first element.
labels = torch.tensor(label_to_id[actual_label])

# Inference only: no gradients are needed for a forward pass.
with torch.no_grad():
    output = model(**encoded_input)

predicted_class = torch.argmax(output.logits, dim=1)

print(f"Premise: {premise}\nHypothesis: {hypothesis}\nExplanation: {explanation}\n")
print(f"True class: {actual_label}")
print(f"Predicted class: {id_to_label[predicted_class.item()]}")

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Premise: A person on a horse jumps over a broken down airplane.
Hypothesis: A person is training his horse for a competition.
Explanation: the person is not necessarily training his horse

True class: neutral
Predicted class: neutral


In [55]:
# Display the integer class id stored in `labels` for the first training example.
labels

tensor(1)

In [54]:
# Comparison run: encode only the premise/hypothesis pair (no explanation) and
# classify it with the same model.
encoded_input = tokenizer(
    premise,
    text_pair=hypothesis,
    padding=True,
    truncation=True,
    return_tensors='pt',
)

# `labels` is not recomputed here: it is not used by this cell, and rebuilding
# it meant mapping the entire label column again for an identical value.

# Inference only: no gradients are needed for a forward pass.
with torch.no_grad():
    output = model(**encoded_input)

predicted_class = torch.argmax(output.logits, dim=1)
print(f"Predicted class: {id_to_label[predicted_class.item()]}")

Predicted class: neutral
