In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from pathlib import Path
import json
from itertools import product
from tqdm import tqdm
import numpy as np
import warnings
warnings.filterwarnings("ignore")

# ----------- Static config -----------
DATA_DIR       = Path("data")
TRAIN_JSON     = DATA_DIR / "train-claims.json"
DEV_JSON       = DATA_DIR / "dev-claims.json"
EVID_JSON      = DATA_DIR / "evidence.json"
BERT_MODEL     = "bert-base-uncased"
MAX_LEN        = 256
NUM_CLASSES    = 4
BATCH_SIZE     = 16
EPOCHS         = 5
LR             = 2e-4
DEVICE = "cpu"
if torch.cuda.is_available(): 
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"

label2idx = {
    "SUPPORTS": 0,
    "REFUTES": 1,
    "NOT_ENOUGH_INFO": 2,
    "DISPUTED": 3,
}

# ----------- Load data -----------
with open(TRAIN_JSON, "r", encoding="utf-8") as f:
    train_claims = json.load(f)
with open(DEV_JSON, "r", encoding="utf-8") as f:
    dev_claims = json.load(f)
with open(EVID_JSON, "r", encoding="utf-8") as f:
    evidence_dict = json.load(f)

  Referenced from: <CFED5F8E-EC3F-36FD-AAA3-2C6C7F8D3DD9> /Users/felikskong/anaconda3/envs/nlp/lib/python3.11/site-packages/torchvision/image.so
  warn(


In [9]:
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL)

class ClaimEvidenceDataset(Dataset):
    def __init__(self, claims, evidences, tokenizer, max_len):
        self.items = []
        for cid, obj in claims.items():
            claim_text = obj["claim_text"]
            ev_ids = obj.get("evidences", [])
            ev_texts = [evidences[e] for e in ev_ids if e in evidences]
            full_input = claim_text + " [SEP] " + " ".join(ev_texts)
            label = label2idx[obj["claim_label"]]
            self.items.append((full_input, label))
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        text, label = self.items[idx]
        enc = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt"
        )
        return (
            enc["input_ids"].squeeze(0),
            enc["attention_mask"].squeeze(0),
            torch.tensor(label, dtype=torch.long),
        )

def collate_batch(batch):
    ids, masks, labs = zip(*batch)
    return torch.stack(ids), torch.stack(masks), torch.stack(labs)

train_ds = ClaimEvidenceDataset(train_claims, evidence_dict, tokenizer, MAX_LEN)
dev_ds = ClaimEvidenceDataset(dev_claims, evidence_dict, tokenizer, MAX_LEN)

train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
dev_dl = DataLoader(dev_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

# ----------- Model definition -----------
class BiLSTMWithBertEncoder(nn.Module):
    def __init__(self, bert_name, lstm_hid, num_classes, dropout_prob, lstm_layers):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_name)
        for p in self.bert.parameters():
            p.requires_grad = False
        bert_dim = self.bert.config.hidden_size
        self.dropout_bert = nn.Dropout(dropout_prob)
        self.lstm = nn.LSTM(
            input_size=bert_dim,
            hidden_size=lstm_hid,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout_prob if lstm_layers > 1 else 0.0
        )
        self.attn_fc = nn.Linear(2 * lstm_hid, 1)
        self.dropout_pool = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(2 * lstm_hid, num_classes)

    def forward(self, input_ids, attention_mask):
        seq_emb = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        seq_emb = self.dropout_bert(seq_emb)
        lstm_out, _ = self.lstm(seq_emb)
        scores = self.attn_fc(lstm_out).squeeze(-1)
        scores = scores.masked_fill(attention_mask == 0, -1e9)
        alphas = torch.softmax(scores, dim=1)
        pooled = torch.sum(lstm_out * alphas.unsqueeze(-1), dim=1)
        pooled = self.dropout_pool(pooled)
        return self.classifier(pooled)


In [None]:
# ----------- Grid search -----------
param_grid = {
    "LSTM_HID_DIM": [256, 512],
    "DROPOUT_PROB": [0.1, 0.2, 0.3],
    "NUM_LAYERS": [2, 3],
}

best_global_acc = 0.0
best_global_path = "task2_best_model_grid.pt"
best_config = None

for hid_dim, dropout, layers in product(param_grid["LSTM_HID_DIM"], param_grid["DROPOUT_PROB"], param_grid["NUM_LAYERS"]):
    print(f"\n🔍 Training with LSTM_HID_DIM={hid_dim}, DROPOUT_PROB={dropout}, NUM_LAYERS={layers}")
    model = BiLSTMWithBertEncoder(BERT_MODEL, hid_dim, NUM_CLASSES, dropout, layers).to(DEVICE)
    optimizer = torch.optim.Adam(model.classifier.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()

    best_acc = 0.0
    for epoch in range(1, EPOCHS + 1):
        model.train()
        for input_ids, attn_mask, labels in tqdm(train_dl, desc=f"Train Epoch {epoch}", dynamic_ncols=True):
            input_ids = input_ids.to(DEVICE)
            attn_mask = attn_mask.to(DEVICE)
            labels = labels.to(DEVICE)
            logits = model(input_ids, attn_mask)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for input_ids, attn_mask, labels in dev_dl:
                input_ids = input_ids.to(DEVICE)
                attn_mask = attn_mask.to(DEVICE)
                labels = labels.to(DEVICE)
                preds = model(input_ids, attn_mask).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        acc = correct / total
        print(f"  Epoch {epoch} → Val Accuracy: {acc:.4%}")

        if acc > best_acc:
            best_acc = acc

    print(f"✅ Best val acc for this config: {best_acc:.4%}")
    if best_acc > best_global_acc:
        best_global_acc = best_acc
        best_config = (hid_dim, dropout, layers)
        torch.save(model.state_dict(), best_global_path)
        print(f"🔥 New global best model saved (acc {best_acc:.4%})")

print(f"\n🏆 Grid Search Done. Best Config: LSTM_HID_DIM={best_config[0]}, DROPOUT={best_config[1]}, LAYERS={best_config[2]}")
print(f"✅ Highest Validation Accuracy: {best_global_acc:.4%}")


🔍 Training with LSTM_HID_DIM=256, DROPOUT_PROB=0.1, NUM_LAYERS=2


Train Epoch 1: 100%|██████████| 77/77 [00:29<00:00,  2.65it/s]


  Epoch 1 → Val Accuracy: 44.1558%


Train Epoch 2: 100%|██████████| 77/77 [00:28<00:00,  2.71it/s]


  Epoch 2 → Val Accuracy: 44.1558%


Train Epoch 3: 100%|██████████| 77/77 [00:28<00:00,  2.69it/s]


  Epoch 3 → Val Accuracy: 44.8052%


Train Epoch 4: 100%|██████████| 77/77 [00:28<00:00,  2.70it/s]


  Epoch 4 → Val Accuracy: 46.1039%


Train Epoch 5: 100%|██████████| 77/77 [00:28<00:00,  2.71it/s]


  Epoch 5 → Val Accuracy: 47.4026%
✅ Best val acc for this config: 47.4026%
🔥 New global best model saved (acc 47.4026%)

🔍 Training with LSTM_HID_DIM=256, DROPOUT_PROB=0.1, NUM_LAYERS=3


Train Epoch 1: 100%|██████████| 77/77 [00:31<00:00,  2.46it/s]


  Epoch 1 → Val Accuracy: 44.8052%


Train Epoch 2: 100%|██████████| 77/77 [00:31<00:00,  2.46it/s]


  Epoch 2 → Val Accuracy: 44.1558%


Train Epoch 3: 100%|██████████| 77/77 [00:31<00:00,  2.46it/s]


  Epoch 3 → Val Accuracy: 46.7532%


Train Epoch 4: 100%|██████████| 77/77 [00:31<00:00,  2.47it/s]


  Epoch 4 → Val Accuracy: 44.8052%


Train Epoch 5: 100%|██████████| 77/77 [00:31<00:00,  2.47it/s]


  Epoch 5 → Val Accuracy: 44.8052%
✅ Best val acc for this config: 46.7532%

🔍 Training with LSTM_HID_DIM=256, DROPOUT_PROB=0.2, NUM_LAYERS=2


Train Epoch 1: 100%|██████████| 77/77 [00:28<00:00,  2.71it/s]


  Epoch 1 → Val Accuracy: 44.1558%


Train Epoch 2: 100%|██████████| 77/77 [00:28<00:00,  2.71it/s]


  Epoch 2 → Val Accuracy: 44.1558%


Train Epoch 3: 100%|██████████| 77/77 [00:28<00:00,  2.71it/s]


  Epoch 3 → Val Accuracy: 44.1558%


Train Epoch 4: 100%|██████████| 77/77 [00:28<00:00,  2.70it/s]


  Epoch 4 → Val Accuracy: 44.8052%


Train Epoch 5: 100%|██████████| 77/77 [00:28<00:00,  2.71it/s]


  Epoch 5 → Val Accuracy: 44.8052%
✅ Best val acc for this config: 44.8052%

🔍 Training with LSTM_HID_DIM=256, DROPOUT_PROB=0.2, NUM_LAYERS=3


Train Epoch 1: 100%|██████████| 77/77 [00:31<00:00,  2.47it/s]


  Epoch 1 → Val Accuracy: 44.1558%


Train Epoch 2: 100%|██████████| 77/77 [00:31<00:00,  2.47it/s]


  Epoch 2 → Val Accuracy: 44.1558%


Train Epoch 3: 100%|██████████| 77/77 [00:31<00:00,  2.46it/s]


  Epoch 3 → Val Accuracy: 44.1558%


Train Epoch 4: 100%|██████████| 77/77 [00:31<00:00,  2.47it/s]


  Epoch 4 → Val Accuracy: 44.1558%


Train Epoch 5: 100%|██████████| 77/77 [00:31<00:00,  2.47it/s]


  Epoch 5 → Val Accuracy: 43.5065%
✅ Best val acc for this config: 44.1558%

🔍 Training with LSTM_HID_DIM=256, DROPOUT_PROB=0.3, NUM_LAYERS=2


Train Epoch 1: 100%|██████████| 77/77 [00:28<00:00,  2.71it/s]


  Epoch 1 → Val Accuracy: 44.1558%


Train Epoch 2: 100%|██████████| 77/77 [00:28<00:00,  2.71it/s]


  Epoch 2 → Val Accuracy: 44.1558%


Train Epoch 3: 100%|██████████| 77/77 [00:28<00:00,  2.71it/s]


  Epoch 3 → Val Accuracy: 44.1558%


Train Epoch 4: 100%|██████████| 77/77 [00:28<00:00,  2.71it/s]


  Epoch 4 → Val Accuracy: 44.1558%


Train Epoch 5: 100%|██████████| 77/77 [00:28<00:00,  2.71it/s]


  Epoch 5 → Val Accuracy: 44.1558%
✅ Best val acc for this config: 44.1558%

🔍 Training with LSTM_HID_DIM=256, DROPOUT_PROB=0.3, NUM_LAYERS=3


Train Epoch 1: 100%|██████████| 77/77 [00:31<00:00,  2.47it/s]


  Epoch 1 → Val Accuracy: 44.1558%


Train Epoch 2: 100%|██████████| 77/77 [00:31<00:00,  2.47it/s]


  Epoch 2 → Val Accuracy: 44.1558%


Train Epoch 3: 100%|██████████| 77/77 [00:31<00:00,  2.47it/s]


  Epoch 3 → Val Accuracy: 44.1558%


Train Epoch 4: 100%|██████████| 77/77 [00:31<00:00,  2.47it/s]


  Epoch 4 → Val Accuracy: 44.1558%


Train Epoch 5: 100%|██████████| 77/77 [00:31<00:00,  2.47it/s]


  Epoch 5 → Val Accuracy: 44.1558%
✅ Best val acc for this config: 44.1558%

🔍 Training with LSTM_HID_DIM=512, DROPOUT_PROB=0.1, NUM_LAYERS=2


Train Epoch 1: 100%|██████████| 77/77 [00:34<00:00,  2.20it/s]


  Epoch 1 → Val Accuracy: 44.8052%


Train Epoch 2: 100%|██████████| 77/77 [00:34<00:00,  2.20it/s]


  Epoch 2 → Val Accuracy: 44.8052%


Train Epoch 3: 100%|██████████| 77/77 [00:34<00:00,  2.21it/s]


  Epoch 3 → Val Accuracy: 47.4026%


Train Epoch 4: 100%|██████████| 77/77 [00:34<00:00,  2.21it/s]


  Epoch 4 → Val Accuracy: 50.6494%


Train Epoch 5: 100%|██████████| 77/77 [00:34<00:00,  2.21it/s]


  Epoch 5 → Val Accuracy: 51.2987%
✅ Best val acc for this config: 51.2987%
🔥 New global best model saved (acc 51.2987%)

🔍 Training with LSTM_HID_DIM=512, DROPOUT_PROB=0.1, NUM_LAYERS=3


Train Epoch 1: 100%|██████████| 77/77 [00:40<00:00,  1.88it/s]


  Epoch 1 → Val Accuracy: 46.1039%


Train Epoch 2: 100%|██████████| 77/77 [00:40<00:00,  1.88it/s]


  Epoch 2 → Val Accuracy: 48.0519%


Train Epoch 3: 100%|██████████| 77/77 [00:40<00:00,  1.88it/s]


  Epoch 3 → Val Accuracy: 49.3506%


Train Epoch 4: 100%|██████████| 77/77 [00:40<00:00,  1.88it/s]


  Epoch 4 → Val Accuracy: 51.9481%


Train Epoch 5: 100%|██████████| 77/77 [00:40<00:00,  1.88it/s]


  Epoch 5 → Val Accuracy: 51.2987%
✅ Best val acc for this config: 51.9481%
🔥 New global best model saved (acc 51.9481%)

🔍 Training with LSTM_HID_DIM=512, DROPOUT_PROB=0.2, NUM_LAYERS=2


Train Epoch 1: 100%|██████████| 77/77 [00:34<00:00,  2.21it/s]


  Epoch 1 → Val Accuracy: 44.1558%


Train Epoch 2: 100%|██████████| 77/77 [00:34<00:00,  2.21it/s]


  Epoch 2 → Val Accuracy: 44.8052%


Train Epoch 3: 100%|██████████| 77/77 [00:34<00:00,  2.21it/s]


  Epoch 3 → Val Accuracy: 45.4545%


Train Epoch 4: 100%|██████████| 77/77 [00:34<00:00,  2.21it/s]


  Epoch 4 → Val Accuracy: 44.8052%


Train Epoch 5: 100%|██████████| 77/77 [00:34<00:00,  2.21it/s]


  Epoch 5 → Val Accuracy: 46.1039%
✅ Best val acc for this config: 46.1039%

🔍 Training with LSTM_HID_DIM=512, DROPOUT_PROB=0.2, NUM_LAYERS=3


Train Epoch 1: 100%|██████████| 77/77 [00:40<00:00,  1.88it/s]


  Epoch 1 → Val Accuracy: 44.1558%


Train Epoch 2: 100%|██████████| 77/77 [00:40<00:00,  1.88it/s]


  Epoch 2 → Val Accuracy: 45.4545%


Train Epoch 3: 100%|██████████| 77/77 [00:40<00:00,  1.88it/s]


  Epoch 3 → Val Accuracy: 46.7532%


Train Epoch 4: 100%|██████████| 77/77 [00:40<00:00,  1.88it/s]


  Epoch 4 → Val Accuracy: 46.7532%


Train Epoch 5: 100%|██████████| 77/77 [00:40<00:00,  1.88it/s]


  Epoch 5 → Val Accuracy: 48.0519%
✅ Best val acc for this config: 48.0519%

🔍 Training with LSTM_HID_DIM=512, DROPOUT_PROB=0.3, NUM_LAYERS=2


Train Epoch 1: 100%|██████████| 77/77 [00:34<00:00,  2.21it/s]


  Epoch 1 → Val Accuracy: 44.1558%


Train Epoch 2: 100%|██████████| 77/77 [00:34<00:00,  2.21it/s]


  Epoch 2 → Val Accuracy: 44.1558%


Train Epoch 3: 100%|██████████| 77/77 [00:34<00:00,  2.21it/s]


  Epoch 3 → Val Accuracy: 44.1558%


Train Epoch 4: 100%|██████████| 77/77 [00:34<00:00,  2.21it/s]


  Epoch 4 → Val Accuracy: 44.1558%


Train Epoch 5: 100%|██████████| 77/77 [00:34<00:00,  2.21it/s]


  Epoch 5 → Val Accuracy: 44.8052%
✅ Best val acc for this config: 44.8052%

🔍 Training with LSTM_HID_DIM=512, DROPOUT_PROB=0.3, NUM_LAYERS=3


Train Epoch 1: 100%|██████████| 77/77 [00:40<00:00,  1.88it/s]


  Epoch 1 → Val Accuracy: 44.1558%


Train Epoch 2: 100%|██████████| 77/77 [00:41<00:00,  1.88it/s]


  Epoch 2 → Val Accuracy: 44.1558%


Train Epoch 3: 100%|██████████| 77/77 [00:40<00:00,  1.88it/s]


  Epoch 3 → Val Accuracy: 44.1558%


Train Epoch 4: 100%|██████████| 77/77 [00:40<00:00,  1.88it/s]


  Epoch 4 → Val Accuracy: 44.1558%


Train Epoch 5: 100%|██████████| 77/77 [00:41<00:00,  1.88it/s]


  Epoch 5 → Val Accuracy: 44.1558%
✅ Best val acc for this config: 44.1558%

🏆 Grid Search Done. Best Config: LSTM_HID_DIM=512, DROPOUT=0.1, LAYERS=3
✅ Highest Validation Accuracy: 51.9481%


In [2]:
from collections import Counter
from nltk.tokenize import word_tokenize

def build_vocab(claims, evidences, min_freq=2):
    counter = Counter()
    for obj in claims.values():
        counter.update(word_tokenize(obj["claim_text"].lower()))
    for text in evidences.values():
        counter.update(word_tokenize(text.lower()))
    vocab = {word: i+2 for i, (word, freq) in enumerate(counter.items()) if freq >= min_freq}
    vocab["<PAD>"] = 0
    vocab["<UNK>"] = 1
    return vocab

vocab = build_vocab(train_claims, evidence_dict)

In [3]:
class TokenDataset(Dataset):
    def __init__(self, claims, evidences, vocab, max_len):
        self.items = []
        for obj in claims.values():
            claim_text = obj["claim_text"]
            ev_ids = obj.get("evidences", [])
            ev_texts = [evidences[e] for e in ev_ids if e in evidences]
            text = claim_text + " " + " ".join(ev_texts)
            tokens = word_tokenize(text.lower())
            token_ids = [vocab.get(w, vocab["<UNK>"]) for w in tokens]
            if len(token_ids) > max_len:
                token_ids = token_ids[:max_len]
            else:
                token_ids += [vocab["<PAD>"]] * (max_len - len(token_ids))
            label = label2idx[obj["claim_label"]]
            self.items.append((torch.tensor(token_ids), torch.tensor(label)))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

def collate_tokens(batch):
    tokens, labels = zip(*batch)
    return torch.stack(tokens), torch.tensor(labels)

In [4]:
class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hid_dim, num_classes, dropout):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hid_dim, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hid_dim, num_classes)

    def forward(self, input_ids, _):
        x = self.embed(input_ids)
        _, (h, _) = self.lstm(x)
        out = self.dropout(h[-1])
        return self.fc(out)

class BiLSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hid_dim, num_classes, dropout):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hid_dim, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(2 * hid_dim, num_classes)

    def forward(self, input_ids, _):
        x = self.embed(input_ids)
        _, (h, _) = self.lstm(x)
        h = torch.cat((h[-2], h[-1]), dim=1)
        out = self.dropout(h)
        return self.fc(out)
    
class BiLSTMWithBert(nn.Module):
    def __init__(self, bert_name, lstm_hid, num_classes, dropout_prob):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_name)
        for p in self.bert.parameters():
            p.requires_grad = False
        bert_dim = self.bert.config.hidden_size

        self.lstm = nn.LSTM(
            input_size=bert_dim,
            hidden_size=lstm_hid,
            batch_first=True,
            bidirectional=True,
            dropout=dropout_prob
        )
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(2 * lstm_hid, num_classes)

    def forward(self, input_ids, attention_mask):
        with torch.no_grad():
            bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        lstm_out, _ = self.lstm(bert_output)
        pooled = torch.mean(lstm_out, dim=1)
        return self.classifier(self.dropout(pooled))
    
class BiLSTMWithBertEncoder(nn.Module):
    def __init__(self, bert_name, lstm_hid, num_classes, dropout_prob, lstm_layers):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_name)
        for p in self.bert.parameters():
            p.requires_grad = False
        bert_dim = self.bert.config.hidden_size

        self.dropout_bert = nn.Dropout(dropout_prob)
        self.lstm = nn.LSTM(
            input_size=bert_dim,
            hidden_size=lstm_hid,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout_prob if lstm_layers > 1 else 0.0
        )
        self.attn_fc = nn.Linear(2 * lstm_hid, 1)
        self.dropout_pool = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(2 * lstm_hid, num_classes)

    def forward(self, input_ids, attention_mask):
        with torch.no_grad():
            seq_emb = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        seq_emb = self.dropout_bert(seq_emb)
        lstm_out, _ = self.lstm(seq_emb)
        scores = self.attn_fc(lstm_out).squeeze(-1)
        scores = scores.masked_fill(attention_mask == 0, -1e9)
        alphas = torch.softmax(scores, dim=1)
        pooled = torch.sum(lstm_out * alphas.unsqueeze(-1), dim=1)
        pooled = self.dropout_pool(pooled)
        return self.classifier(pooled)

In [None]:
def train_and_eval(model, train_dl, dev_dl):
    model = model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()
    best_acc = 0.0

    for epoch in range(1, EPOCHS + 1):
        model.train()
        total_loss = 0.0
        loop = tqdm(train_dl, desc=f"Train Epoch {epoch}", leave=False, dynamic_ncols=True)
        for input_ids, *rest in loop:
            labels = rest[-1].to(DEVICE)
            input_ids = input_ids.to(DEVICE)
            attention_mask = rest[0].to(DEVICE) if len(rest) == 2 else None

            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            loop.set_postfix(loss=loss.item())

        # ----- Evaluation after each epoch -----
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for input_ids, *rest in dev_dl:
                labels = rest[-1].to(DEVICE)
                input_ids = input_ids.to(DEVICE)
                attention_mask = rest[0].to(DEVICE) if len(rest) == 2 else None

                preds = model(input_ids, attention_mask).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        acc = correct / total
        best_acc = max(best_acc, acc)

    return best_acc

In [None]:
# Prepare token-level data
train_token_ds = TokenDataset(train_claims, evidence_dict, vocab, MAX_LEN)
dev_token_ds = TokenDataset(dev_claims, evidence_dict, vocab, MAX_LEN)
train_token_dl = DataLoader(train_token_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_tokens)
dev_token_dl = DataLoader(dev_token_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_tokens)

# 1. LSTM
lstm_model = LSTMClassifier(len(vocab), 300, 512, NUM_CLASSES, 0.1)
acc_lstm = train_and_eval(lstm_model, train_token_dl, dev_token_dl)
print(f"[LSTM] Dev Accuracy: {acc_lstm:.4%}")

# 2. BiLSTM
bilstm_model = BiLSTMClassifier(len(vocab), 300, 512, NUM_CLASSES, 0.1)
acc_bilstm = train_and_eval(bilstm_model, train_token_dl, dev_token_dl)
print(f"[BiLSTM] Dev Accuracy: {acc_bilstm:.4%}")

# 3. BiLSTM + BERT
bert_model = BiLSTMWithBert(BERT_MODEL, 512, NUM_CLASSES, 0.1)
acc_bert = train_and_eval(bert_model, train_dl, dev_dl)
print(f"[BiLSTM + BERT] Dev Accuracy: {acc_bert:.4%}")

# 4. BiLSTM + BERT + Attention
bert_attn_model = BiLSTMWithBertEncoder(BERT_MODEL, 512, NUM_CLASSES, 0.1, 3)
acc_bert_attn = train_and_eval(bert_attn_model, train_dl, dev_dl)
print(f"[BiLSTM + BERT + Attn] Dev Accuracy: {acc_bert_attn:.4%}")


[LSTM] Dev Accuracy: 47.4026%
[BiLSTM] Dev Accuracy: 48.0519%  
[BiLSTM + BERT] Dev Accuracy: 51.6883%
[BiLSTM + BERT + Attn] Dev Accuracy: 52.6364% 

