In [2]:
import json
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

train_jsonl = "alphanli-train-dev/train.jsonl"
train_labels = "alphanli-train-dev/train-labels.lst"  
dev_jsonl = "alphanli-train-dev/dev.jsonl"
dev_labels = "alphanli-train-dev/dev-labels.lst"

def load_data(jsonl_file, labels_file):
    data = []
    with open(jsonl_file, "r") as f_json, open(labels_file, "r") as f_labels:
        labels = [int(line.strip()) for line in f_labels.readlines()]
        for idx, line in enumerate(f_json):
            entry = json.loads(line.strip())
            data.append({
                "obs1": entry["obs1"],
                "obs2": entry["obs2"],
                "hyp1": entry["hyp1"],
                "hyp2": entry["hyp2"],
                "label": labels[idx] 
            })
    return data

train_data = load_data(train_jsonl, train_labels)
dev_data = load_data(dev_jsonl, dev_labels)

class ANLIDataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        premise = f"{entry['obs1']} {entry['obs2']}"

        encoding1 = self.tokenizer(
            premise, entry['hyp1'],
            truncation=True, padding='max_length', max_length=self.max_length,
            return_tensors="pt"
        )
        encoding2 = self.tokenizer(
            premise, entry['hyp2'],
            truncation=True, padding='max_length', max_length=self.max_length,
            return_tensors="pt"
        )

        encoding1 = {k: v.squeeze(0) for k, v in encoding1.items()}
        encoding2 = {k: v.squeeze(0) for k, v in encoding2.items()}
        
        label = entry["label"] - 1
        return {"input1": encoding1, "input2": encoding2, "label": label}

def collate_fn(batch):
    batch_input1 = {}
    batch_input2 = {}
    keys = list(batch[0]["input1"].keys())
    for k in keys:
        batch_input1[k] = torch.stack([item["input1"][k] for item in batch])
        batch_input2[k] = torch.stack([item["input2"][k] for item in batch])
    labels = torch.tensor([item["label"] for item in batch])
    return {"input1": batch_input1, "input2": batch_input2, "labels": labels}

train_dataset = ANLIDataset(train_data, tokenizer=None)  
dev_dataset = ANLIDataset(dev_data, tokenizer=None)

model_name = "mjwong/gte-large-mnli-anli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

train_dataset.tokenizer = tokenizer
dev_dataset.tokenizer = tokenizer

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
dev_loader = DataLoader(dev_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=2e-5)
num_epochs = 5
total_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
loss_fn = torch.nn.CrossEntropyLoss()

scaler = torch.cuda.amp.GradScaler()

model.train()
for epoch in range(num_epochs):
    total_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
        optimizer.zero_grad()

        batch_input1 = {k: v.to(device) for k, v in batch["input1"].items()}
        batch_input2 = {k: v.to(device) for k, v in batch["input2"].items()}
        labels = batch["labels"].to(device)
        
        with torch.cuda.amp.autocast():
            outputs1 = model(**batch_input1)
            outputs2 = model(**batch_input2)

            score1 = outputs1.logits[:, 2]
            score2 = outputs2.logits[:, 2]
            logits = torch.stack([score1, score2], dim=1)
            loss = loss_fn(logits, labels)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        total_loss += loss.item()
        
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")
    
    output_dir = f"gte-aNLI_epoch_{epoch+1}"
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Saved checkpoint to {output_dir}")

  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
Training Epoch 1: 100%|██████████| 5302/5302 [15:15<00:00,  5.79it/s]


Epoch 1/5, Loss: 0.2615
Saved checkpoint to gte-aNLI_epoch_1


Training Epoch 2: 100%|██████████| 5302/5302 [15:15<00:00,  5.79it/s]


Epoch 2/5, Loss: 0.0890
Saved checkpoint to gte-aNLI_epoch_2


Training Epoch 3: 100%|██████████| 5302/5302 [15:16<00:00,  5.79it/s]


Epoch 3/5, Loss: 0.0327
Saved checkpoint to gte-aNLI_epoch_3


Training Epoch 4: 100%|██████████| 5302/5302 [15:16<00:00,  5.79it/s]


Epoch 4/5, Loss: 0.0162
Saved checkpoint to gte-aNLI_epoch_4


Training Epoch 5: 100%|██████████| 5302/5302 [15:17<00:00,  5.78it/s]


Epoch 5/5, Loss: 0.0090
Saved checkpoint to gte-aNLI_epoch_5


# Eval

In [3]:
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(dev_loader, desc="Evaluating"):

        batch_input1 = {k: v.to(device) for k, v in batch["input1"].items()}
        batch_input2 = {k: v.to(device) for k, v in batch["input2"].items()}
        labels = batch["labels"].to(device)
        
        outputs1 = model(**batch_input1)
        outputs2 = model(**batch_input2)
        
        score1 = outputs1.logits[:, 2]
        score2 = outputs2.logits[:, 2]
        logits = torch.stack([score1, score2], dim=1)
        
        preds = torch.argmax(logits, dim=1) + 1
        
        all_preds.extend(preds.cpu().numpy().tolist())
        all_labels.extend((labels.cpu().numpy() + 1).tolist())

accuracy = accuracy_score(all_labels, all_preds)
precision, recall, f1, _ = precision_recall_fscore_support(
    all_labels, all_preds, average="binary", pos_label=1
)

print(f"Dev Accuracy: {accuracy:.4f}")
print(f"Dev Precision: {precision:.4f}")
print(f"Dev Recall: {recall:.4f}")
print(f"Dev F1-score: {f1:.4f}")

Evaluating: 100%|██████████| 48/48 [00:15<00:00,  3.09it/s]

Dev Accuracy: 0.6704
Dev Precision: 0.6816
Dev Recall: 0.6633
Dev F1-score: 0.6723



