In [1]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Hugging Face Hub id shared by the tokenizer and the classifier, so the
# two can never drift apart.
SNLI_CHECKPOINT = "textattack/bert-base-uncased-snli"

tokenizer = AutoTokenizer.from_pretrained(SNLI_CHECKPOINT)

# Class indices used by this checkpoint's output head:
#   0 = contradiction, 1 = entailment, 2 = neutral
model = AutoModelForSequenceClassification.from_pretrained(SNLI_CHECKPOINT)

In [4]:
import argparse


def _build_parser():
    """Create the notebook's option parser.

    Every option is optional and carries a default, so the parser can be
    driven with an empty argument list inside a notebook.
    """
    # (flag, value type, default) -- registered in this order.
    option_table = (
        ('--epochs', int, 5),
        ('--batch_size', int, 32),
        ('--eval_batch_size', int, 64),
        ('--lr', float, 5e-5),
        ('--model_name', str, 'bert-base-uncased'),
        ('--max_length', int, 64),
        ('--data_root', str, '../data/e-SNLI/dataset'),
    )
    built = argparse.ArgumentParser()
    for flag, value_type, fallback in option_table:
        built.add_argument(flag, type=value_type, default=fallback, required=False)
    return built


parser = _build_parser()
# An explicit empty list keeps argparse from reading the kernel's own
# sys.argv; the resulting namespace therefore holds only the defaults.
args = parser.parse_args(args=[])

In [5]:
from automatic_eval.dataloader import get_dataloaders

# Build the data splits from the tokenizer and the parsed options; the
# helper's result is unpacked into three objects, one per split.
(train, val, test) = get_dataloaders(tokenizer, args)

In [7]:
import torch

# GPU this notebook was originally run on. It is still used whenever it
# exists, but the notebook no longer crashes on machines with fewer GPUs
# (falls back to the first GPU) or without CUDA at all (falls back to CPU).
PREFERRED_GPU_INDEX = 3

if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU_INDEX if torch.cuda.device_count() > PREFERRED_GPU_INDEX else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')

# Move the classifier's parameters onto the selected device.
model = model.to(device)

In [16]:
from sklearn.metrics import classification_report
from tqdm.notebook import tqdm

# Per-batch predictions and ground-truth labels for the validation split.
all_pred = []
all_gt = []

model.eval()
# Inference only: disabling autograd avoids building a graph for every
# forward pass, which saves memory and time.
with torch.no_grad():
    for batch in tqdm(val):
        batch = {k: v.to(device) for k, v in batch.items()}
        # Remove the label so it is not forwarded to the model, and flip its
        # id (2 - label) to line it up with the class order used below.
        # NOTE(review): presumably the dataloader encodes classes in the
        # reverse order of the checkpoint -- confirm against the dataloader.
        label = 2 - batch.pop('label')
        pred = model(**batch).logits.argmax(dim=-1)

        all_pred.append(pred)
        all_gt.append(label)

all_pred = torch.cat(all_pred, dim=0).cpu().numpy()
all_gt = torch.cat(all_gt, dim=0).cpu().numpy()

# `labels` pins each id to its name so the report stays correct (and does
# not raise) even if some class never occurs in the labels or predictions.
print(classification_report(all_gt,
                            all_pred,
                            labels=[0, 1, 2],
                            target_names=['contradiction', 'entailment', 'neutral']))


  0%|          | 0/154 [00:00<?, ?it/s]

               precision    recall  f1-score   support

contradiction       0.93      0.94      0.93      3278
   entailment       0.92      0.91      0.92      3329
      neutral       0.88      0.88      0.88      3235

     accuracy                           0.91      9842
    macro avg       0.91      0.91      0.91      9842
 weighted avg       0.91      0.91      0.91      9842



In [17]:
# Per-batch predictions and ground-truth labels for the test split.
all_pred = []
all_gt = []

model.eval()
# Inference only: disabling autograd avoids building a graph for every
# forward pass, which saves memory and time.
with torch.no_grad():
    for batch in tqdm(test):
        batch = {k: v.to(device) for k, v in batch.items()}
        # Remove the label so it is not forwarded to the model, and flip its
        # id (2 - label) to line it up with the class order used below.
        # NOTE(review): presumably the dataloader encodes classes in the
        # reverse order of the checkpoint -- confirm against the dataloader.
        label = 2 - batch.pop('label')
        pred = model(**batch).logits.argmax(dim=-1)

        all_pred.append(pred)
        all_gt.append(label)

all_pred = torch.cat(all_pred, dim=0).cpu().numpy()
all_gt = torch.cat(all_gt, dim=0).cpu().numpy()

# `labels` pins each id to its name so the report stays correct (and does
# not raise) even if some class never occurs in the labels or predictions.
print(classification_report(all_gt,
                            all_pred,
                            labels=[0, 1, 2],
                            target_names=['contradiction', 'entailment', 'neutral']))

  0%|          | 0/154 [00:00<?, ?it/s]

               precision    recall  f1-score   support

contradiction       0.93      0.94      0.93      3237
   entailment       0.92      0.90      0.91      3368
      neutral       0.87      0.87      0.87      3219

     accuracy                           0.90      9824
    macro avg       0.90      0.90      0.90      9824
 weighted avg       0.90      0.90      0.90      9824

