In [13]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm 

In [30]:
# --- 1. Model Configuration ---

# Hyperparameters for our toy model. These are kept small to ensure
# it runs quickly on a standard laptop CPU.
VOCAB_SIZE = 30522  # Standard for bert-base-uncased tokenizer
EMBED_DIM = 128     # Dimension of token embeddings
HIDDEN_DIM = 256    # Dimension of the expert's hidden layer
NUM_EXPERTS = 4     # The number of experts in our MoE layer
TOP_K = 2           # Number of experts to route each token to
NUM_CLASSES = 4    # For the 20 Newsgroups dataset
NUM_EPOCHS = 5
LEARNING_RATE = 1e-3
BATCH_SIZE = 32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

Using device: cpu


In [19]:
# --- 2. MoE Model Definition ---

class Expert(nn.Module):
    """A simple feed-forward network to be used as an expert."""
    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim) # Project back to the original dimension
        )

    def forward(self, x):
        return self.net(x)

class MoELayer(nn.Module):
    """
    The core Mixture-of-Experts layer.
    This layer takes a batch of tokens and routes each token to the top-k experts.
    """
    def __init__(self, embed_dim, hidden_dim, num_experts, top_k):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        
        # Create the pool of experts
        self.experts = nn.ModuleList([Expert(embed_dim, hidden_dim) for _ in range(num_experts)])
        
        # The gating network (router) is a simple linear layer that outputs
        # a logit for each expert.
        self.router = nn.Linear(embed_dim, num_experts)

    def forward(self, x):
        # x shape: (batch_size, sequence_length, embed_dim)
        batch_size, seq_len, embed_dim = x.shape

        # Reshape the input to be a flat list of tokens for easier processing
        x_flat = x.view(-1, embed_dim) # Shape: (batch_size * seq_len, embed_dim)
        num_tokens = x_flat.shape[0]

        # 1. Gating / Routing
        # Get the logits from the router for each token
        router_logits = self.router(x_flat) # Shape: (num_tokens, num_experts)

        # Find the top-k experts and their corresponding routing weights
        routing_weights, chosen_expert_indices = torch.topk(router_logits, self.top_k, dim=-1)
        routing_weights = F.softmax(routing_weights, dim=-1) # Softmax over the top-k logits

        # 2. Expert Processing
        # Initialize the final output tensor
        final_output = torch.zeros_like(x_flat)
        
        # Create a flat index to map tokens to their chosen experts
        flat_expert_indices = chosen_expert_indices.view(-1)
        
        # Create a tensor that maps each token to its position in the batch
        token_batch_map = torch.arange(num_tokens, device=x.device).repeat_interleave(self.top_k)

        # Get the expert outputs for all tokens and all chosen experts
        # This is a more complex but efficient way to handle batching for top-k
        expert_outputs = torch.zeros_like(x_flat)
        for i in range(self.num_experts):
            # Find which tokens have this expert in their top-k list
            mask = (chosen_expert_indices == i).any(dim=-1)
            if mask.any():
                expert_outputs[mask] = self.experts[i](x_flat[mask])

        # Combine the expert outputs using the routing weights
        # We need to gather the correct expert outputs for each token
        # and multiply by the corresponding routing weight.
        weighted_outputs = torch.zeros_like(x_flat)
        for i in range(self.top_k):
            expert_idx = chosen_expert_indices[:, i]
            weight = routing_weights[:, i].unsqueeze(1)
            
            # Gather the outputs from the correct experts
            # This is a bit complex, but it avoids a slow loop
            current_expert_outputs = torch.zeros_like(x_flat)
            for j in range(self.num_experts):
                mask = (expert_idx == j)
                if mask.any():
                    current_expert_outputs[mask] = self.experts[j](x_flat[mask])
            
            weighted_outputs += weight * current_expert_outputs

        # Reshape the output back to the original input shape
        return weighted_outputs.view(batch_size, seq_len, embed_dim)


class TinyMoEForClassification(nn.Module):
    """The main model that uses the MoE layer for classification."""
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_experts, top_k, num_classes):
        super().__init__()
        # The embedding layer turns token IDs into dense vectors
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        
        # Our custom MoE layer
        self.moe_layer = MoELayer(embed_dim, hidden_dim, num_experts, top_k)
        
        # A simple linear layer to map the output of the MoE layer to class predictions
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, input_ids, attention_mask=None):
        # input_ids shape: (batch_size, sequence_length)
        embedded = self.embedding(input_ids) # Shape: (batch_size, seq_len, embed_dim)
        
        moe_output = self.moe_layer(embedded) # Shape: (batch_size, seq_len, embed_dim)

        # We use the representation of the first token ([CLS] token) for classification.
        # This is a common practice in models like BERT.
        cls_token_output = moe_output[:, 0] # Shape: (batch_size, embed_dim)

        # Get the final logits for each class
        logits = self.classifier(cls_token_output) # Shape: (batch_size, num_classes)
        return logits


In [28]:
# --- 3. Data Preparation ---

print("Preparing data...")
# Load a small, fast tokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Load the 20 Newsgroups dataset
# OLD
# dataset = load_dataset("SetFit/20_newsgroups")

# NEW
dataset = load_dataset("ag_news")
train_dataset = dataset["train"]
test_dataset = dataset["test"]
# OLD
# class_names = ['alt.atheism', 'comp.graphics', ...] # (20 names)

# NEW
class_names = ['World', 'Sports', 'Business', 'Sci/Tech']

# Define a collate function to process batches of data for the DataLoader
def collate_fn(batch):
    texts = [item['text'] for item in batch]
    labels = [item['label'] for item in batch]
    
    # The tokenizer handles padding, truncation, and tensor conversion
    inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding="max_length", # Pad to a fixed length
        truncation=True,
        max_length=256 # Use a smaller max_length for faster training
    )
    
    inputs['labels'] = torch.tensor(labels)
    return inputs

# Create DataLoaders for training and testing
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)

print("Data preparation complete.")



Preparing data...
Data preparation complete.


In [31]:
model = TinyMoEForClassification(
    vocab_size=VOCAB_SIZE,
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    num_experts=NUM_EXPERTS,
    top_k=TOP_K,
    num_classes=NUM_CLASSES
).to(DEVICE)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

print("\n--- Starting Training ---")
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0
    
    # Use tqdm for a nice progress bar during training
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        # Move batch to the device
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        # Forward pass
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} | Average Training Loss: {avg_train_loss:.4f}")

print("\n--- Training Complete ---")



--- Starting Training ---


Epoch 1/5: 100%|██████████| 3750/3750 [10:29<00:00,  5.95it/s]


Epoch 1 | Average Training Loss: 1.3868


Epoch 2/5: 100%|██████████| 3750/3750 [13:54<00:00,  4.49it/s]


Epoch 2 | Average Training Loss: 1.3865


Epoch 3/5: 100%|██████████| 3750/3750 [12:32<00:00,  4.98it/s]


Epoch 3 | Average Training Loss: 1.3865


Epoch 4/5: 100%|██████████| 3750/3750 [12:23<00:00,  5.04it/s]


Epoch 4 | Average Training Loss: 1.3865


Epoch 5/5: 100%|██████████| 3750/3750 [12:14<00:00,  5.11it/s]

Epoch 5 | Average Training Loss: 1.3864

--- Training Complete ---





In [32]:
# --- Basic Analysis Example ---
# Now you can use this trained model as the subject for your MoE-Diag toolkit.
# Here's a quick example of how you might inspect the router for a single batch.
print("\n--- Running Basic Analysis Example ---")
model.eval()
with torch.no_grad():
    # Get one batch from the test set
    sample_batch = next(iter(test_loader))
    input_ids = sample_batch['input_ids'].to(DEVICE)
    
    # Manually perform the first few steps of the forward pass to get router logits
    embedded = model.embedding(input_ids)
    x_flat = embedded.view(-1, EMBED_DIM)
    router_logits = model.moe_layer.router(x_flat)
    routing_weights = F.softmax(router_logits, dim=1)
    chosen_expert_indices = torch.argmax(routing_weights, dim=1)

    # Count how many tokens were assigned to each expert in this batch
    expert_counts = torch.bincount(chosen_expert_indices, minlength=NUM_EXPERTS)
    
    print("Expert utilization for one sample batch:")
    for i, count in enumerate(expert_counts):
        print(f"  Expert {i}: {count.item()} tokens")




--- Running Basic Analysis Example ---
Expert utilization for one sample batch:
  Expert 0: 422 tokens
  Expert 1: 6700 tokens
  Expert 2: 535 tokens
  Expert 3: 535 tokens


In [33]:
# --- 4. Ablation Analysis Function ---

def run_ablation_analysis(model, test_loader, device, num_classes, num_experts, class_names):
    """
    Performs counterfactual analysis by ablating each expert one by one and
    measuring the impact on per-class accuracy.
    """
    print("\n--- Starting Ablation Analysis ---")
    model.eval()
    
    def evaluate_per_class_accuracy():
        """Helper function to evaluate the model and return per-class accuracy."""
        class_correct = [0] * num_classes
        class_totals = [0] * num_classes
        
        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Evaluating", leave=False):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                outputs = model(input_ids, attention_mask)
                _, predicted = torch.max(outputs.data, 1)
                
                for i in range(len(labels)):
                    label = labels[i]
                    class_correct[label] += (predicted[i] == label).item()
                    class_totals[label] += 1
        
        # Calculate accuracy for each class, avoiding division by zero
        accuracies = [(class_correct[i] / class_totals[i]) * 100 if class_totals[i] > 0 else 0 
                      for i in range(num_classes)]
        return accuracies

    # 1. Get baseline accuracy with the full model
    print("Calculating baseline accuracy...")
    baseline_accuracies = evaluate_per_class_accuracy()
    
    # --- NEW: Print baseline accuracies for context ---
    print("\nBaseline Per-Class Accuracy (%):")
    baseline_df = pd.DataFrame([baseline_accuracies], 
                               columns=[name[:15] for name in class_names], 
                               index=['Baseline Acc.'])
    print(baseline_df.round(2))


    # 2. Iterate through each expert, ablate it, and re-evaluate
    ablation_results = []
    for expert_to_ablate in range(num_experts):
        print(f"\nAblating Expert {expert_to_ablate}...")
        
        # Store original weights to restore them later
        original_weights = [p.clone().detach() for p in model.moe_layer.experts[expert_to_ablate].parameters()]

        # Zero out the weights of the current expert
        with torch.no_grad():
            for param in model.moe_layer.experts[expert_to_ablate].parameters():
                param.data.fill_(0)
        
        # Evaluate the model with the ablated expert
        ablated_accuracies = evaluate_per_class_accuracy()
        ablation_results.append(ablated_accuracies)

        # IMPORTANT: Restore the original weights
        with torch.no_grad():
            for i, param in enumerate(model.moe_layer.experts[expert_to_ablate].parameters()):
                param.data.copy_(original_weights[i])
        print(f"Restored Expert {expert_to_ablate}.")

    # 3. Present the results in a clear table
    print("\n--- Ablation Analysis Results ---")
    
    # Calculate the accuracy drop (Baseline - Ablated)
    accuracy_drops = []
    for expert_accuracies in ablation_results:
        drop = [base - ablated for base, ablated in zip(baseline_accuracies, expert_accuracies)]
        accuracy_drops.append(drop)

    # Use pandas to create a DataFrame for nice formatting
    df = pd.DataFrame(accuracy_drops,
                      columns=[name[:15] for name in class_names], # Truncate long class names
                      index=[f"Ablate Expert {i}" for i in range(num_experts)])

    # Display the accuracy drop. A high positive number means ablating that expert
    # significantly hurt performance for that class, indicating specialization.
    print("\nAccuracy Drop (%) After Ablating Each Expert:")
    pd.set_option('display.width', 1000) # Widen pandas output
    print(df.round(2))

In [34]:
run_ablation_analysis(model, test_loader, DEVICE, NUM_CLASSES, NUM_EXPERTS, class_names)



--- Starting Ablation Analysis ---
Calculating baseline accuracy...


                                                             


Baseline Per-Class Accuracy (%):
               World  Sports  Business  Sci/Tech
Baseline Acc.    0.0     0.0       0.0     100.0

Ablating Expert 0...


                                                             

Restored Expert 0.

Ablating Expert 1...


                                                             

Restored Expert 1.

Ablating Expert 2...


                                                             

Restored Expert 2.

Ablating Expert 3...


                                                             

Restored Expert 3.

--- Ablation Analysis Results ---

Accuracy Drop (%) After Ablating Each Expert:
                 World  Sports  Business  Sci/Tech
Ablate Expert 0    0.0     0.0       0.0       0.0
Ablate Expert 1    0.0     0.0       0.0       0.0
Ablate Expert 2    0.0     0.0       0.0       0.0
Ablate Expert 3    0.0     0.0       0.0       0.0


