In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from MoE_base import MoE

In [2]:
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer

In [3]:
# Dataset Class Labels:
# 1: World | 2: Sports | 3: Business | 4: Sci/Tech
# Inputs are news article titles
dataset = load_dataset("ag_news")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

In [4]:
input_dim = 128
def tokenize(example):
    return tokenizer(example["text"], 
                     padding="max_length", 
                     truncation=True,
                     max_length=input_dim)

dataset = dataset.map(tokenize, batched=True)
dataset.set_format(type="torch", columns=["input_ids", "label"])

Map:   0%|          | 0/120000 [00:00<?, ? examples/s]

Map:   0%|          | 0/7600 [00:00<?, ? examples/s]

In [5]:
batch_size = 32
train_loader = DataLoader(dataset['train'], batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset['test'], batch_size=batch_size)
num_classes = 4

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MoE(num_experts=num_classes, input_dim=input_dim, output_dim=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 
                             lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    'min',
    patience=2,
    threshold=1e-2
)


### Training Loop

In [None]:
epochs = 10
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in train_loader:
        inputs, labels = batch["input_ids"].to(device), batch["label"].to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs.float())  # Convert input_ids to float for dense layers
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    scheduler.step(loss)
    print(f"{scheduler.get_last_lr()=}")
        

    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(train_loader):.4f}")

scheduler.get_last_lr()=[0.1]
Epoch 1/10, Loss: 71.3735
scheduler.get_last_lr()=[0.1]
Epoch 2/10, Loss: 1.3937
scheduler.get_last_lr()=[0.1]
Epoch 3/10, Loss: 1.3934
scheduler.get_last_lr()=[0.1]
Epoch 4/10, Loss: 1.3933


### Evaluation

In [61]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in test_loader:
        inputs, labels = batch["input_ids"].to(device), batch["label"].to(device)
        outputs = model(inputs.float())
        predictions = torch.argmax(outputs, dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total
print(f"Test Accuracy: {accuracy:.4f}")

Test Accuracy: 0.2500
