In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import time
import thop  # For FLOPs profiling
import os
import pickle



Dataset Setup (MNIST)

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

mnist_train = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
mnist_test = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

fashion_train = datasets.FashionMNIST(root='./data', train=True, transform=transform, download=True)
fashion_test = datasets.FashionMNIST(root='./data', train=False, transform=transform, download=True)

cifar_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
cifar_train = datasets.CIFAR10(root='./data', train=True, transform=cifar_transform, download=True)
cifar_test = datasets.CIFAR10(root='./data', train=False, transform=cifar_transform, download=True)

mnist_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
mnist_test_loader = DataLoader(mnist_test, batch_size=64, shuffle=False)

fashion_loader = DataLoader(fashion_train, batch_size=64, shuffle=True)
fashion_test_loader = DataLoader(fashion_test, batch_size=64, shuffle=False)

cifar_loader = DataLoader(cifar_train, batch_size=64, shuffle=True)
cifar_test_loader = DataLoader(cifar_test, batch_size=64, shuffle=False)

Define Norm Alternatives

In [3]:
class DyTNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        return self.alpha * (x - mean) + self.beta

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = x.norm(dim=-1, keepdim=True) / (x.size(-1) ** 0.5)
        return self.weight * (x / (rms + self.eps))

NormSelector Module

In [4]:
class NormSelector(nn.Module):
    def __init__(self, dim, freeze_dyt=False, freeze_ln=False):
        super().__init__()
        self.norms = nn.ModuleList([
            DyTNorm(dim),
            nn.LayerNorm(dim),
            RMSNorm(dim)
        ])
        self.weights = nn.Parameter(torch.ones(3))
        self.freeze_dyt = freeze_dyt
        self.freeze_ln = freeze_ln
        self.last_grad_norm = torch.zeros(3)
        self.grad_history = []

    def forward(self, x):
        ws = F.softmax(self.weights, dim=0)
        norms_out = [norm(x) for norm in self.norms]

        if self.freeze_dyt:
            ws = ws.clone()
            ws[0] = 0.0
        if self.freeze_ln:
            ws = ws.clone()
            ws[1] = 0.0

        ws = ws / ws.sum()
        out = sum(w * norm for w, norm in zip(ws, norms_out))

        if self.training and x.requires_grad:
            with torch.no_grad():
                for i, norm in enumerate(norms_out):
                    g = norm.detach().norm()
                    self.last_grad_norm[i] = g
                self.grad_history.append(self.last_grad_norm.clone())

        return out

    def show_norm_weights(self):
        ws = F.softmax(self.weights, dim=0)
        return {'DyT': ws[0].item(), 'LN': ws[1].item(), 'RMS': ws[2].item()}

    def show_gradient_norms(self):
        return self.last_grad_norm.tolist()

    def plot_gradient_norms(self, tag=''):
        arr = torch.stack(self.grad_history).cpu().numpy()
        plt.figure()
        plt.plot(arr[:, 0], label='DyT GradNorm')
        plt.plot(arr[:, 1], label='LN GradNorm')
        plt.plot(arr[:, 2], label='RMS GradNorm')
        plt.legend()
        plt.title(f'Gradient Norms per Layer - {tag}')
        plt.xlabel('Step')
        plt.ylabel('Grad Norm')
        plt.savefig(f'plot_gradnorms_{tag}.png')
        plt.close()

Transformer Block with AutoNorm

In [5]:
class TransformerBlock(nn.Module):
    def __init__(self, dim, heads=2, ff_dim=128, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.norm1 = NormSelector(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, dim)
        )
        self.norm2 = NormSelector(dim)

    def forward(self, x):
        x2, _ = self.attn(x, x, x)
        x = self.norm1(x + x2)
        x2 = self.ff(x)
        x = self.norm2(x + x2)
        return x

Teacher Transformer

In [6]:
class TransformerBlockTeacher(nn.Module):
    def __init__(self, dim, heads=2, ff_dim=128, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, dim)
        )
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        x2, _ = self.attn(x, x, x)
        x = self.norm1(x + x2)
        x2 = self.ff(x)
        x = self.norm2(x + x2)
        return x

Transformer Classifier

In [7]:
class SimpleTransformer(nn.Module):
    def __init__(self, input_dim=28, model_dim=64, num_classes=10, num_layers=2, dropout=0.1):
        super().__init__()
        self.embedding = nn.Linear(input_dim, model_dim)
        self.layers = nn.Sequential(*[
            TransformerBlock(model_dim, dropout=dropout) for _ in range(num_layers)
        ])
        self.classifier = nn.Linear(model_dim, num_classes)

    def forward(self, x):
        x = self.embedding(x)
        x = self.layers(x)
        return self.classifier(x.mean(dim=1))

class SimpleTransformerTeacher(nn.Module):
    def __init__(self, input_dim=28, model_dim=64, num_classes=10, num_layers=2, dropout=0.1):
        super().__init__()
        self.embedding = nn.Linear(input_dim, model_dim)
        self.layers = nn.Sequential(*[
            TransformerBlockTeacher(model_dim, dropout=dropout) for _ in range(num_layers)
        ])
        self.classifier = nn.Linear(model_dim, num_classes)

    def forward(self, x):
        x = self.embedding(x)
        x = self.layers(x)
        return self.classifier(x.mean(dim=1))


Save + Load Model

In [8]:
def save_model(model, path):
    torch.save(model.state_dict(), path)

def load_model(model_class, path, *args, **kwargs):
    model = model_class(*args, **kwargs)
    model.load_state_dict(torch.load(path))
    model.eval()
    return model

Train & Evaluate Pipeline

In [11]:
def get_loader(dataset_name, batch_size=64):
    transform = transforms.ToTensor()
    if dataset_name == 'MNIST':
        dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
        testset = datasets.MNIST('data', train=False, transform=transform)
    elif dataset_name == 'FashionMNIST':
        dataset = datasets.FashionMNIST('data', train=True, download=True, transform=transform)
        testset = datasets.FashionMNIST('data', train=False, transform=transform)
    elif dataset_name == 'CIFAR10':
        transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
        dataset = datasets.CIFAR10('data', train=True, download=True, transform=transform)
        testset = datasets.CIFAR10('data', train=False, transform=transform)
    else:
        raise ValueError("Unsupported dataset")
    return DataLoader(dataset, batch_size=batch_size, shuffle=True), DataLoader(testset, batch_size=batch_size)

def train_model(model, dataloader, optimizer, epochs=5, device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    model.train()
    loss_list = []
    for epoch in range(epochs):
        total_loss = 0
        for img, label in dataloader:
            img, label = img.to(device), label.to(device)
            img = img.squeeze(1)
            optimizer.zero_grad()
            out = model(img)
            loss = criterion(out, label)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        loss_list.append(avg_loss)
        print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
    return loss_list

def evaluate_model(model, dataloader, device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
):
    model.to(device)
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for img, label in dataloader:
            img, label = img.to(device), label.to(device)
            img = img.squeeze(1)
            out = model(img)
            pred = out.argmax(dim=1)
            correct += (pred == label).sum().item()
            total += label.size(0)
    return correct / total

def run_experiment(model_class, save_path, dataset_train='MNIST', dataset_test='FashionMNIST'):
    print(f"Training on {dataset_train}, testing on {dataset_test}")
    train_loader, _ = get_loader(dataset_train)
    _, test_loader = get_loader(dataset_test)

    model = model_class()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    train_model(model, train_loader, optimizer, epochs=5)
    save_model(model, save_path)
    acc = evaluate_model(model, test_loader)
    print(f"Transfer Accuracy on {dataset_test}: {acc:.4f}")

In [12]:
run_experiment(SimpleTransformer, 'autonorm_mnist.pth', 'MNIST', 'FashionMNIST')

Training on MNIST, testing on FashionMNIST
Epoch 1, Loss: 0.8084
Epoch 2, Loss: 0.4795
Epoch 3, Loss: 0.4080
Epoch 4, Loss: 0.3646
Epoch 5, Loss: 0.3340
Transfer Accuracy on FashionMNIST: 0.0427


In [13]:
run_experiment(SimpleTransformerTeacher, 'teacher_mnist.pth', 'MNIST', 'CIFAR10')

Training on MNIST, testing on CIFAR10
Epoch 1, Loss: 0.8117
Epoch 2, Loss: 0.4567
Epoch 3, Loss: 0.3883
Epoch 4, Loss: 0.3543
Epoch 5, Loss: 0.3270


RuntimeError: mat1 and mat2 shapes cannot be multiplied (2048x32 and 28x64)