In [45]:
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

In [46]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
model = AutoModelForSequenceClassification.from_pretrained("facebook/bart-large-mnli").to(device)
classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, device=device)


In [47]:
mnli_test_set = pd.read_csv("/home/kyle/repos/Parameter-Free-LM-Editing/datasets/boss_benchmark/NaturalLanguageInference/mnli/test.tsv", sep="\t")
mnli_test_set.head()

Unnamed: 0,Premise,Hypothesis,Label
0,"It focuses on desktop, client/server, and ente...",It lacks focus on desktop and enterprise compu...,2
1,huh do you have your own kiln or do you do you,Did somebody give you your kiln?,1
2,no chemicals and plus then you can use it as a...,You can use those chemicals as a fertilizer,2
3,"In Texas, the legislature was instrumental in ...",The benefit program in place already had littl...,2
4,"In 1654 Oliver Cromwell, Lord Protector of Eng...",Cromwell sent nobody to the Caribbean.,2


In [48]:
test_entry = mnli_test_set.iloc[9810]
classifier(test_entry["Premise"] + " - " + test_entry["Hypothesis"])

[{'label': 'entailment', 'score': 0.7289925813674927}]

In [54]:
classifier("It focuses on desktop, client/server, and enterprisewide computing. / It lacks focus on desktop and enterprise computing sector.")[0]["label"]

'contradiction'

### PyTorch Native

In [80]:
predicitons = []
labels = []
for i in range(len(mnli_test_set)):
    current_entry = mnli_test_set.iloc[i]
    input_text = current_entry["Premise"] + " / " + current_entry["Hypothesis"]
    label = current_entry["Label"]
    
    tokenized_input = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).to(device)
    logits = model(**tokenized_input).logits
    prediction = model.config.id2label[logits.argmax().item()].lower()
    
    token_label_map = {
        "entailment": 0,
        "neutral": 1,
        "contradiction": 2
    }   
    predicitons.append(token_label_map[prediction])
    labels.append(label)
     
    break


display(predicitons)
display(labels)

[2]

[2]