In [None]:
import requests
import matplotlib.pyplot as plt
import numpy as np
import typing as tp
import time

import torch
from torch import (
    nn, Tensor
)
import torch.utils.data
import torchvision

In [None]:
# Preprocessing pipeline shared by the training and validation images.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),  # [0, 255] -> [0.0, 1.0]
    torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # [0.0, 1.0] -> [-1.0, 1.0]
])

# Training split: the dataset object carries the transform; the raw image
# array and the label list are also kept for plotting and for sizing the models.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms)
train_data = trainset.data
y_train = trainset.targets

# Validation split: created without a transform because every image is
# preprocessed once, right here, and stacked into a single tensor.
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True)
X_VAL = torch.stack([transforms(image) for image in testset.data])
y_val = torch.tensor(testset.targets)

NUM_CLASSES:int = len(np.unique(y_train))  # 10
FAN_IN:int = int(np.prod(train_data.shape[1:]))  # (3*32*32) = 3072

In [None]:
np.shape(testset.data[0])  # shape of a single raw validation image

In [None]:
def plot_image(X:np.ndarray, y:np.ndarray, labels=trainset.classes):
    """Render one image as a 1x1-inch figure titled with its class name.

    Args:
        X: image handed straight to `imshow`.
        y: index into `labels`; `labels[y]` becomes the title.
        labels: sequence of class names (defaults to the training set's classes).
    """
    _, ax = plt.subplots(figsize=(1, 1))
    ax.imshow(X)
    ax.set_title(labels[y])
    ax.axis('off')  # hide ticks and frame
    plt.show()

In [None]:
# Sanity check: show the first training image together with its label.
plot_image(X=train_data[0], y=y_train[0])

In [None]:
class DenseModel(nn.Module):
    """Fully connected classifier: Linear -> ReLU -> Linear -> ReLU -> Linear.

    `forward` returns the raw output of the last linear layer. The `softmax`
    submodule is constructed but is not applied inside `forward`.
    """

    def __init__(
        self,
        fan_in:int,
        hidden_units:int,
        fan_out:int
    ):
        super().__init__()
        # Submodules are registered in the order they are used in `forward`.
        self.dense1 = nn.Linear(fan_in, hidden_units)
        self.relu1 = nn.ReLU()
        self.dense2 = nn.Linear(hidden_units, hidden_units)
        self.relu2 = nn.ReLU()
        self.dense3 = nn.Linear(hidden_units, fan_out)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x:Tensor) -> Tensor: # (B, C=3, H=32, W=32)
        flat = torch.flatten(x, start_dim=1)         # (B, C*H*W)
        hidden = self.relu1(self.dense1(flat))       # (B, hidden_units)
        hidden = self.relu2(self.dense2(hidden))     # (B, hidden_units)
        logits = self.dense3(hidden)                 # (B, fan_out)
        return logits

print("DenseModel:", DenseModel(3*32*32, 100, 10))

In [None]:
class MixtureOfExperts(nn.Module):
    """Sparsely gated mixture of linear experts with noisy top-k routing.

    Every sample is routed to its `topk` highest-scoring experts; the layer
    output is the gate-weighted sum of those experts' outputs.

    Args:
        num_experts: number of `nn.Linear(fan_in, fan_out)` experts.
        fan_in: size of the input features.
        fan_out: size of the output features.
        topk: number of experts selected per sample (1 <= topk <= num_experts).

    Raises:
        ValueError: if `topk` is outside `[1, num_experts]`.
    """

    def __init__(self, num_experts:int, fan_in:int, fan_out:int, topk:int):
        super().__init__()
        # topk == 0 would make every gate logit -inf and the softmax NaN.
        if not 1 <= topk <= num_experts:
            raise ValueError(f"topk must be in [1, num_experts={num_experts}], got {topk}")
        self.num_experts = num_experts
        self.topk = topk
        self.fan_out = fan_out

        self.gate_values_func = nn.Linear(fan_in, num_experts, bias=False)
        self.load_balancing_noise = nn.Linear(fan_in, num_experts, bias=False)

        self.expert_modules = nn.ModuleList([
            nn.Linear(fan_in, fan_out) for _ in range(num_experts)
        ])

    def gate_network(self, x:Tensor, topk:int):
        """Return the gate values of shape (B, num_experts) for input x of shape (B, fan_in).

        Only the `topk` largest gate logits of each row are kept; all others
        are set to -inf before the softmax, so their gate value is exactly 0.
        """
        gate_logits = self.gate_values_func(x) # (B, num_experts)
        if self.training:
            # Routing noise is applied while training only, so predictions
            # made in eval mode are deterministic.
            noise = torch.normal(mean=0, std=1, size=(x.size(0), self.num_experts), device=x.device)
            gate_logits = gate_logits + noise * self.load_balancing_noise(x)

        topk_values, topk_indices = torch.topk(gate_logits, topk, dim=1)

        masked_logits = torch.full_like(gate_logits, -torch.inf)
        masked_logits = masked_logits.scatter(dim=1, index=topk_indices, src=topk_values)
        return torch.softmax(masked_logits, dim=1)

    def expert_gate_dot_product(
        self,
        x:Tensor,           # (B, fan_in)
        gate_values:Tensor  # (B, num_experts)
    ):
        """Return the gate-weighted sum of expert outputs, shape (B, fan_out).

        An expert is evaluated only on the samples whose gate value for it is
        non-zero, using one batched call per expert instead of one call per
        (sample, expert) pair.
        """
        mixed = x.new_zeros((x.size(0), self.fan_out)) # (B, fan_out)
        for expert_idx, expert_module in enumerate(self.expert_modules):
            expert_gates = gate_values[:, expert_idx]         # (B,)
            rows = expert_gates.nonzero(as_tuple=True)[0]      # samples routed to this expert
            if rows.numel() == 0:
                continue
            weighted = expert_module(x[rows]) * expert_gates[rows].unsqueeze(1) # (num_rows, fan_out)
            # Out-of-place accumulation keeps the autograd graph intact.
            mixed = mixed.index_add(0, rows, weighted.to(mixed.dtype))
        return mixed

    def forward(self, x:Tensor): # (B, fan_in)
        """Return `(output, gate_values)` with shapes (B, fan_out) and (B, num_experts)."""
        gate_values = self.gate_network(x, self.topk) # (B, num_experts)
        x = self.expert_gate_dot_product(x, gate_values)
        return x, gate_values


class MoEDenseModel(nn.Module):
    """Dense classifier whose middle layer is a `MixtureOfExperts`.

    Layout: Linear -> ReLU -> MixtureOfExperts -> ReLU -> Linear.

    `forward` returns unnormalised logits together with the gate values of
    the mixture layer. Logits are returned (rather than probabilities)
    because `nn.functional.cross_entropy` applies log-softmax itself; feeding
    it already-softmaxed values squashes the loss and its gradients. Use
    `self.softmax(logits)` when probabilities are needed.
    """

    def __init__(
        self,
        fan_in:int,
        hidden_units:int,
        fan_out:int,
        num_experts:int,
        topk:int
    ):
        super().__init__()
        self.num_experts = num_experts

        self.dense1 = nn.Linear(fan_in, hidden_units)
        self.relu1 = nn.ReLU()
        self.moe_dense2 = MixtureOfExperts(num_experts, hidden_units, hidden_units, topk)
        self.relu2 = nn.ReLU()
        self.dense3 = nn.Linear(hidden_units, fan_out)
        # Not used in `forward`; kept so callers can turn logits into probabilities.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x:Tensor):
        """Return `(logits, gate_values)` with shapes (B, fan_out) and (B, num_experts)."""
        x = x.flatten(1)                    # (B, fan_in=C*H*W)
        x = self.relu1(self.dense1(x))      # (B, hidden_units)
        x, gate_values = self.moe_dense2(x) # (B, hidden_units)
        x = self.relu2(x)
        logits = self.dense3(x)             # (B, fan_out), unnormalised
        return logits, gate_values          # (B, fan_out), (B, num_experts)

In [None]:
def loss_fn(
    y_pred:Tensor, # (B, fan_out)
    y_true:Tensor, # (B,)
    gate_values:Tensor, # (B, num_experts),
    importance_weight:float,
    load_balancing_weight:float
):
    """Compute the classification loss and two expert-balancing penalties.

    Args:
        y_pred: model outputs passed to `nn.functional.cross_entropy`, which
            expects unnormalised logits.
        y_true: integer class labels.
        gate_values: per-sample gate values; entries equal to 0 mark experts
            that were not selected for that sample.
        importance_weight: scale of the importance penalty.
        load_balancing_weight: scale of the load-balancing penalty.

    Returns:
        Tuple `(Lce, Limportance, Lload)` of scalar tensors; the caller is
        expected to sum them.
    """
    Lce = nn.functional.cross_entropy(y_pred, y_true)

    num_experts = gate_values.size(1)

    # Balancing expert utilization/importance: squared coefficient of
    # variation of the total gate mass each expert received in the batch.
    importance = gate_values.sum(dim=0) # (num_experts,)
    if num_experts > 1:
        Limportance = importance_weight * importance.std().div(importance.mean()).square()
    else:
        # The std of a single value is NaN and would poison the total loss;
        # with one expert there is nothing to balance.
        Limportance = importance.new_zeros(())

    # Load balancing loss: encourage experts to receive an equal number of
    # samples (Switch Transformer style).
    importance_mean = gate_values.mean(dim=0)  # avg prob mass per expert
    load = (gate_values > 0).float().mean(dim=0)  # fraction of tokens sent to each expert
    Lload = load_balancing_weight * (importance_mean * load).mean() * (num_experts ** 2)

    return Lce, Limportance, Lload

In [None]:
def get_accuracy(y_true:Tensor, y_probs:Tensor):
    """Return the fraction of rows whose arg-max over the last axis equals the label.

    The result is a 0-dim float tensor: number of matches divided by `len(y_true)`.
    """
    predicted = y_probs.argmax(dim=-1)
    hits = torch.eq(y_true, predicted).to(torch.float32)
    return hits.sum() / len(y_true)

In [None]:
class config:
    # Hyperparameters read by the model construction and training cells.

    # -- optimiser / training loop --
    lr: float = 1e-3
    weight_decay: float = 0.0
    batch_size: int = 32
    num_epochs: int = 50

    # -- network width --
    HIDDEN_UNITS: int = 1024

    # -- mixture-of-experts layer and its auxiliary loss weights --
    NUM_EXPERTS: int = 10
    TOPK: int = 3
    IMPORTANCE_WEIGHT: float = 0.1
    LOAD_BALANCING_WEIGHT: float = 0.05

In [None]:
# Baseline: plain dense classifier.
naive_model = DenseModel(
    fan_in=FAN_IN,
    hidden_units=config.HIDDEN_UNITS,
    fan_out=NUM_CLASSES
)
print("Number of parameters in naive_model:", sum(p.numel() for p in naive_model.parameters())/1e6, "Million")

# Same layout, with the middle layer replaced by a mixture of experts.
moe_model = MoEDenseModel(
    fan_in=FAN_IN,
    hidden_units=config.HIDDEN_UNITS,
    fan_out=NUM_CLASSES,
    num_experts=config.NUM_EXPERTS,
    topk=config.TOPK
)
print("Number of parameters in moe_model:", sum(p.numel() for p in moe_model.parameters())/1e6, "Million")

# Reshuffle the training set every epoch; with shuffle=False each epoch would
# replay exactly the same batches in exactly the same order.
train_loader = torch.utils.data.DataLoader(trainset, batch_size=config.batch_size, shuffle=True)

In [None]:
def train_model(model:DenseModel|MoEDenseModel, loss_fn:tp.Callable):
    """Train `model` on `train_loader` and evaluate it on the validation set after every epoch.

    The model may return either a single tensor of predictions or a tuple
    `(predictions, gate_values)`. In the tuple case the gate values are
    forwarded to `loss_fn` as the keyword argument `gate_values`.

    Args:
        model: network to optimise in place.
        loss_fn: callable invoked as
            `loss_fn(y_pred, y_true, [gate_values=...], importance_weight=..., load_balancing_weight=...)`
            returning a sequence of loss terms; the terms are summed for the
            backward pass and the first one is logged.

    Returns:
        `(loss_history, accuracy_history)`: dicts keyed by the global step,
        holding the first loss term of every step and the validation accuracy
        at the end of every epoch, both as Python floats. On
        `KeyboardInterrupt` whatever was collected so far is returned.
    """
    # Created before the try block so they exist even if training is
    # interrupted right away.
    loss_history:dict[int, float] = {}
    accuracy_history:dict[int, float] = {}
    try:
        optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

        total_steps = 0
        for epoch in range(1, config.num_epochs+1):
            model.train()  # validation below switches the model to eval mode
            t0 = time.time()
            for X, y_true in train_loader:
                total_steps += 1
                outputs = model(X)
                if isinstance(outputs, tuple):
                    y_pred, gate_values = outputs # (B, num_classes), (B, num_experts)
                    extra_inputs = {"gate_values": gate_values}
                else:
                    y_pred, extra_inputs = outputs, {} # (B, num_classes)
                losses:tp.Sequence[Tensor] = loss_fn(
                    y_pred, y_true, **extra_inputs,
                    importance_weight=config.IMPORTANCE_WEIGHT,
                    load_balancing_weight=config.LOAD_BALANCING_WEIGHT
                )
                loss:Tensor = sum(losses, start=torch.tensor(0.0, device=X.device))
                # Clear gradients first so leftovers from an interrupted
                # earlier run cannot leak into this step.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_history[total_steps] = losses[0].item()

            t1 = time.time()

            # Validation: no autograd graph, and any train-only behaviour of
            # the model is switched off.
            model.eval()
            with torch.no_grad():
                val_outputs = model(X_VAL)
            val_pred = val_outputs[0] if isinstance(val_outputs, tuple) else val_outputs
            accuracy_history[total_steps] = get_accuracy(y_val, val_pred).item()
            last_loss = loss_history.get(total_steps, float("nan"))
            print(f"|| Epoch: {epoch} || Loss: {last_loss:.4f} || Accuracy: {accuracy_history[total_steps]:.4f} || dt: {(t1-t0):.4f}s ||")
    except KeyboardInterrupt:
        print("Training interrupted by user.")
    return loss_history, accuracy_history

In [None]:
def naive_loss_fn(pred, true, **kwargs):
    """Cross-entropy only; the two trailing zeros stand in for the balancing terms."""
    return nn.functional.cross_entropy(pred, true), 0, 0

naivemodel_losses, naivemodel_accuracies = train_model(naive_model, loss_fn=naive_loss_fn)

In [None]:
# One figure per curve: training loss per step, then validation accuracy per epoch.
for history, ylabel, title in (
    (naivemodel_losses, 'Loss', 'Naive Model Loss'),
    (naivemodel_accuracies, 'Accuracy', 'Naive Model Accuracy'),
):
    fig, ax = plt.subplots()
    ax.plot(list(history.keys()), list(history.values()))
    ax.set_xlabel('Steps')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True)
    plt.show()

In [None]:
# Train the mixture-of-experts model with the full loss (cross-entropy + balancing terms).
moemodel_losses, moemodel_accuracies = train_model(moe_model, loss_fn=loss_fn)

In [None]:
# One figure per curve: training loss per step, then validation accuracy per epoch.
for history, ylabel, title in (
    (moemodel_losses, 'Loss', 'MoE Model Loss'),
    (moemodel_accuracies, 'Accuracy', 'MoE Model Accuracy'),
):
    fig, ax = plt.subplots()
    ax.plot(list(history.keys()), list(history.values()))
    ax.set_xlabel('Steps')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True)
    plt.show()