In [1]:
!pip install datasets transformers



In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
from tqdm import tqdm


In [3]:
train_raw = load_dataset("tau/commonsense_qa", split="train[:-1000]")
valid_raw = load_dataset("tau/commonsense_qa", split="train[-1000:]")
test_raw  = load_dataset("tau/commonsense_qa", split="validation")

In [4]:
class HybridDataset(Dataset):
    def __init__(self, hf_dataset, tokenizer, max_length=80):
        self.dataset = hf_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.label_map = {'A':0, 'B':1, 'C':2, 'D':3, 'E':4}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        question = item["question"]
        choices = item["choices"]["text"]
        input_texts = [f"{question} {choice}" for choice in choices]
        encoded = self.tokenizer(
            input_texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        return {
            "input_ids": encoded["input_ids"],           # [5, seq_len]
            "attention_mask": encoded["attention_mask"], # [5, seq_len]
            "label": torch.tensor(self.label_map[item["answerKey"]])
        }


In [5]:
class BiLSTMTransformerHybrid(nn.Module):
    def __init__(self, model_name="albert-base-v2", hidden_size=128, num_labels=5):
        super().__init__()
        self.transformer = AutoModel.from_pretrained(model_name)
        self.bilstm = nn.LSTM(
            input_size=self.transformer.config.hidden_size,
            hidden_size=hidden_size,
            num_layers=1,
            bidirectional=True,
            batch_first=True
        )
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size*2, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1)  # Output a score for each choice
        )

    def forward(self, input_ids, attention_mask):
        batch_size, num_choices, seq_len = input_ids.size()
        input_ids = input_ids.view(-1, seq_len)           # [batch*num_choices, seq_len]
        attention_mask = attention_mask.view(-1, seq_len) # [batch*num_choices, seq_len]

        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state       # [batch*num_choices, seq_len, hidden]
        lstm_out, _ = self.bilstm(sequence_output)        # [batch*num_choices, seq_len, hidden*2]
        pooled = lstm_out.mean(dim=1)                     # [batch*num_choices, hidden*2]
        logits = self.classifier(pooled)                  # [batch*num_choices, 1]
        logits = logits.view(batch_size, num_choices)     # [batch, num_choices]
        return logits


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "albert-base-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

train_dataset = HybridDataset(train_raw, tokenizer)
valid_dataset = HybridDataset(valid_raw, tokenizer)
test_dataset  = HybridDataset(test_raw, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=8)
test_loader  = DataLoader(test_dataset, batch_size=8)

model = BiLSTMTransformerHybrid(model_name=model_name).to(device)


In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=wandb.config.lr)
loss_fn = nn.CrossEntropyLoss()

def train(model, dataloader):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        loss = loss_fn(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * input_ids.size(0)
        preds = torch.argmax(logits, dim=1)
        correct += (preds == labels).sum().item()
        total += input_ids.size(0)

    avg_loss = total_loss / total
    accuracy = correct / total
    return avg_loss, accuracy

def evaluate(model, dataloader):
    model.eval()
    total_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            loss = loss_fn(logits, labels)

            total_loss += loss.item() * input_ids.size(0)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += input_ids.size(0)

    avg_loss = total_loss / total
    accuracy = correct / total
    return avg_loss, accuracy


In [11]:
for epoch in range(wandb.config.epochs):
    train_loss, train_acc = train(model, train_loader)
    val_loss, val_acc = evaluate(model, valid_loader)

    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Acc={train_acc:.4f} | Val Loss={val_loss:.4f}, Acc={val_acc:.4f}")

    wandb.log({
        "epoch": epoch+1,
        "train_loss": train_loss,
        "train_accuracy": train_acc,
        "val_loss": val_loss,
        "val_accuracy": val_acc
    })

wandb.finish()

Epoch 1: Train Loss=1.6097, Acc=0.2042 | Val Loss=1.6093, Acc=0.2350
Epoch 2: Train Loss=1.6096, Acc=0.2103 | Val Loss=1.6050, Acc=0.2270
Epoch 3: Train Loss=1.5807, Acc=0.2592 | Val Loss=1.5094, Acc=0.3170
Epoch 4: Train Loss=1.4009, Acc=0.4064 | Val Loss=1.3983, Acc=0.4180
Epoch 5: Train Loss=1.0675, Acc=0.5774 | Val Loss=1.4494, Acc=0.4270
Epoch 6: Train Loss=0.6351, Acc=0.7614 | Val Loss=1.8551, Acc=0.4040
Epoch 7: Train Loss=0.2524, Acc=0.9112 | Val Loss=2.8140, Acc=0.4030
Epoch 8: Train Loss=0.1524, Acc=0.9491 | Val Loss=2.9544, Acc=0.4030
Epoch 9: Train Loss=0.0616, Acc=0.9809 | Val Loss=3.8757, Acc=0.3970
Epoch 10: Train Loss=0.0368, Acc=0.9894 | Val Loss=4.1081, Acc=0.4130
Epoch 11: Train Loss=0.0559, Acc=0.9807 | Val Loss=3.2454, Acc=0.4140
Epoch 12: Train Loss=0.0806, Acc=0.9754 | Val Loss=3.9594, Acc=0.4010
Epoch 13: Train Loss=0.0585, Acc=0.9825 | Val Loss=4.6006, Acc=0.3930
Epoch 14: Train Loss=0.0269, Acc=0.9919 | Val Loss=4.6027, Acc=0.3920
Epoch 15: Train Loss=0.0372, 

0,1
epoch,▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
train_accuracy,▁▁▁▃▄▆▇██████████████████
train_loss,███▇▆▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▁▄██▇▇▇▇██▇▇▇▇▇▇▇▇▇▆▇▆▆▇
val_loss,▁▁▁▁▁▂▄▄▆▆▅▆▇▇█▇██▇█▅▇▇▆█

0,1
epoch,25.0
train_accuracy,0.97723
train_loss,0.0657
val_accuracy,0.396
val_loss,4.95431
