In [None]:
pip install transformers datasets torch sentencepiece

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import AlbertModel, AlbertTokenizer, get_scheduler
from datasets import load_dataset

# Configuration
MODEL_NAME = "albert-base-v2"
BATCH_SIZE = 16
EPOCHS = 3
LR = 2e-5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load Tokenizer
tokenizer = AlbertTokenizer.from_pretrained(MODEL_NAME)

# Load TREC Dataset
dataset = load_dataset("trec")
labels_list = dataset["train"].features["coarse_label"].names  # Class names
NUM_CLASSES = len(labels_list)  # 6 classes

# Custom Dataset Class
class TrecDataset(Dataset):
    def __init__(self, split):
        self.data = dataset[split]
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        text = self.data[idx]["text"]
        label = self.data[idx]["coarse_label"]
        inputs = tokenizer(text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.long)
        }

# Dataloaders
train_loader = DataLoader(TrecDataset("train"), batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(TrecDataset("test"), batch_size=BATCH_SIZE)

# Custom ALBERT Model
class CustomAlbertClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.albert = AlbertModel.from_pretrained(MODEL_NAME)
        self.classifier = nn.Linear(self.albert.config.hidden_size, num_classes)
    
    def forward(self, input_ids, attention_mask):
        outputs = self.albert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0]  # [CLS] token output
        return self.classifier(pooled_output)

# Initialize Model
model = CustomAlbertClassifier(NUM_CLASSES).to(DEVICE)
optimizer = optim.AdamW(model.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss()
scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_loader) * EPOCHS)

# Training Loop
def train():
    model.train()
    for epoch in range(EPOCHS):
        total_loss, correct, total = 0, 0, 0
        for batch in train_loader:
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["label"].to(DEVICE)
            
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)

        scheduler.step()
        print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}, Accuracy={correct/total:.4f}")

# Evaluation Function
def evaluate():
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["label"].to(DEVICE)
            outputs = model(input_ids, attention_mask)
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)
    print(f"Test Accuracy: {correct / total:.4f}")

# Run Training and Evaluation
train()
evaluate()


Downloading spiece.model:   0%|          | 0.00/760k [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Downloading tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/684 [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/5.09k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/10.6k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/336k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/23.4k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/5452 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/500 [00:00<?, ? examples/s]

Downloading model.safetensors:   0%|          | 0.00/47.4M [00:00<?, ?B/s]

Epoch 1: Loss=0.4550, Accuracy=0.8465
Epoch 2: Loss=0.4757, Accuracy=0.8208
Epoch 3: Loss=0.2379, Accuracy=0.9382
Test Accuracy: 0.9540
