In [1]:
import json
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

model_name = "mjwong/gte-large-mnli-anli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

tokenizer_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/904 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.34G [00:00<?, ?B/s]

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,

In [2]:
dev_jsonl = "alphanli-train-dev/dev.jsonl"
dev_labels = "alphanli-train-dev/dev-labels.lst"

def load_data(jsonl_file, labels_file):
    data = []
    with open(jsonl_file, "r") as f_json, open(labels_file, "r") as f_labels:
        labels = [int(line.strip()) for line in f_labels.readlines()]
        for idx, line in enumerate(f_json):
            entry = json.loads(line.strip())
            data.append({
                "obs1": entry["obs1"],
                "obs2": entry["obs2"],
                "hyp1": entry["hyp1"],
                "hyp2": entry["hyp2"],
                "label": labels[idx]  
            })
    return data

dev_data = load_data(dev_jsonl, dev_labels)

def format_input(entry, hypothesis):
    return f"{entry['obs1']} {entry['obs2']}", hypothesis

def run_inference(model, tokenizer, dataset):
    model.eval()
    predictions = []
    ground_truths = []

    for entry in tqdm(dataset):
        inputs = []
        labels = []

        for i, hyp in enumerate([entry['hyp1'], entry['hyp2']]):
            premise, hypothesis = format_input(entry, hyp)
            inputs.append((premise, hypothesis))
            labels.append(i + 1)  

        encodings = tokenizer(
            [inp[0] for inp in inputs],  # obs
            [inp[1] for inp in inputs],  # hyp
            padding=True, truncation=True, return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            outputs = model(**encodings)
            logits = outputs.logits

        entailment_scores = logits[:, 2]  # "entailment" idx
        pred = torch.argmax(entailment_scores).item() + 1  # match label indexing 1,2

        predictions.append(pred)
        ground_truths.append(entry["label"])

    return predictions, ground_truths

preds, labels = run_inference(model, tokenizer, dev_data)

  0%|                                                  | 0/1532 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
100%|███████████████████████████████████████| 1532/1532 [02:33<00:00,  9.99it/s]


In [3]:
accuracy = accuracy_score(labels, preds)
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="binary", pos_label=1)

print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-score: {f1:.4f}")

Accuracy: 0.4099
Precision: 0.4179
Recall: 0.4008
F1-score: 0.4092
