In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import networkx as nx
import time

torch.manual_seed(42)
np.random.seed(42)

In [3]:
# --- ATTENTION MODULES ---

class SoftmaxAttention(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: t.reshape(B, N, self.num_heads, -1).transpose(1, 2), qkv)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.to_out(out)

class LinearAttention(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)
        self.eps = 1e-6

    def feature_map(self, x):
        return torch.nn.functional.relu(x) + self.eps

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: t.reshape(B, N, self.num_heads, -1).transpose(1, 2), qkv)
        q = self.feature_map(q)
        k = self.feature_map(k)
        kv = k.transpose(-2, -1) @ v
        z = 1 / (q @ k.sum(dim=-2, keepdim=True).transpose(-2, -1) + self.eps)
        out = (q @ kv) * z
        out = out.transpose(1, 2).reshape(B, N, C)
        return self.to_out(out)

class GRFExactAttention(nn.Module):
    def __init__(self, dim, num_heads, num_patches, n_walks, p_halt, device):
        super().__init__()
        self.num_heads = num_heads
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)
        self.eps = 1e-6
        self.register_buffer('mask', self._generate_grf_mask(num_patches, n_walks, p_halt, device))

    def _generate_grf_mask(self, N, n_walks, p_halt, device):
        side = int(np.sqrt(N))
        G = nx.grid_2d_graph(side, side)
        mapping = {node: i for i, node in enumerate(sorted(list(G.nodes())))}
        G = nx.relabel_nodes(G, mapping)
        mask = torch.zeros(N, N)
        for start_node in range(N):
            for _ in range(n_walks):
                curr = start_node
                while True:
                    mask[start_node, curr] += 1.0
                    if np.random.rand() < p_halt: break
                    neighbors = sorted(list(G.neighbors(curr)))
                    if not neighbors: break
                    curr = np.random.choice(neighbors)
            mask[start_node] /= max(n_walks, 1)
        return mask.to(device)

    def feature_map(self, x):
        return torch.nn.functional.relu(x) + self.eps

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: t.reshape(B, N, self.num_heads, -1).transpose(1, 2), qkv)
        q = self.feature_map(q)
        k = self.feature_map(k)

        q_graph = (q.transpose(-2, -1) @ self.mask).transpose(-2, -1)
        k_graph = (k.transpose(-2, -1) @ self.mask).transpose(-2, -1)
        q = q + 0.1 * q_graph
        k = k + 0.1 * k_graph

        linear_kernel = q @ k.transpose(-2, -1)
        masked_kernel = linear_kernel * self.mask.unsqueeze(0).unsqueeze(0)
        z = 1 / (masked_kernel.sum(dim=-1, keepdim=True) + self.eps)
        out = (masked_kernel @ v) * z
        out = out.transpose(1, 2).reshape(B, N, C)
        return self.to_out(out)

class MAlphaAttention(nn.Module):
    def __init__(self, dim, num_heads, num_patches, device, order=5, decay=0.5):
        super().__init__()
        self.num_heads = num_heads
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)
        self.eps = 1e-6
        self.register_buffer('mask', self._generate_exact_mask(num_patches, order, decay, device))

    def _generate_exact_mask(self, N, order, decay, device):
        side = int(np.sqrt(N))
        G = nx.grid_2d_graph(side, side)
        mapping = {node: i for i, node in enumerate(sorted(list(G.nodes())))}
        G = nx.relabel_nodes(G, mapping)
        A = nx.to_numpy_array(G)
        D_inv = np.diag(1.0 / np.maximum(A.sum(axis=1), 1))
        W = D_inv @ A
        M = np.eye(N)
        W_k = np.eye(N)
        coeff = 1.0
        for _ in range(order):
            W_k = W_k @ W
            coeff *= decay
            M += coeff * W_k
        M = M / M.sum(axis=1, keepdims=True)
        return torch.tensor(M, dtype=torch.float32).to(device)

    def feature_map(self, x):
        return torch.nn.functional.relu(x) + self.eps

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: t.reshape(B, N, self.num_heads, -1).transpose(1, 2), qkv)
        q = self.feature_map(q)
        k = self.feature_map(k)
        q_graph = (q.transpose(-2, -1) @ self.mask).transpose(-2, -1)
        k_graph = (k.transpose(-2, -1) @ self.mask).transpose(-2, -1)
        q = q + 0.1 * q_graph
        k = k + 0.1 * k_graph
        linear_kernel = q @ k.transpose(-2, -1)
        masked_kernel = linear_kernel * self.mask.unsqueeze(0).unsqueeze(0)
        z = 1 / (masked_kernel.sum(dim=-1, keepdim=True) + self.eps)
        out = (masked_kernel @ v) * z
        out = out.transpose(1, 2).reshape(B, N, C)
        return self.to_out(out)

class ToeplitzAttention(nn.Module):
    def __init__(self, dim, num_heads, num_patches, device, decay=0.8):
        super().__init__()
        self.num_heads = num_heads
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)
        self.eps = 1e-6
        self.register_buffer('mask', self._generate_toeplitz_mask(num_patches, decay, device))

    def _generate_toeplitz_mask(self, N, decay, device):
        side = int(np.sqrt(N))
        mask = np.zeros((N, N))
        for i in range(N):
            for j in range(N):
                xi, yi = i // side, i % side
                xj, yj = j // side, j % side
                dist = abs(xi - xj) + abs(yi - yj)
                mask[i, j] = decay ** dist
        mask = mask / mask.sum(axis=1, keepdims=True)
        return torch.tensor(mask, dtype=torch.float32).to(device)

    def feature_map(self, x):
        return torch.nn.functional.relu(x) + self.eps

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: t.reshape(B, N, self.num_heads, -1).transpose(1, 2), qkv)
        q = self.feature_map(q)
        k = self.feature_map(k)
        q_graph = (q.transpose(-2, -1) @ self.mask).transpose(-2, -1)
        k_graph = (k.transpose(-2, -1) @ self.mask).transpose(-2, -1)
        q = q + 0.1 * q_graph
        k = k + 0.1 * k_graph
        linear_kernel = q @ k.transpose(-2, -1)
        masked_kernel = linear_kernel * self.mask.unsqueeze(0).unsqueeze(0)
        z = 1 / (masked_kernel.sum(dim=-1, keepdim=True) + self.eps)
        out = (masked_kernel @ v) * z
        out = out.transpose(1, 2).reshape(B, N, C)
        return self.to_out(out)

# --- 4. MODEL ---
class ViT(nn.Module):
    def __init__(self, patch_size, image_size, dim, depth, num_heads, dropout, mlp_dim, device, channels=3, attention_type='softmax', n_walks=50, p_halt=0.1, num_classes=10):
        super().__init__()
        self.patch_size = patch_size
        self.channels = channels
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
        self.patch_embed = nn.Linear(patch_dim, dim)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, dim))
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            if attention_type == 'softmax':
                attn = SoftmaxAttention(dim, num_heads)
            elif attention_type == 'linear':
                attn = LinearAttention(dim, num_heads)
            elif attention_type == 'grf':
                attn = GRFExactAttention(dim, num_heads, num_patches, n_walks, p_halt, device)
            elif attention_type == 'm_alpha':
                attn = MAlphaAttention(dim, num_heads, num_patches, device)
            elif attention_type == 'toeplitz':
                attn = ToeplitzAttention(dim, num_heads, num_patches, device)

            self.layers.append(nn.ModuleList([
                nn.LayerNorm(dim),
                attn,
                nn.LayerNorm(dim),
                nn.Sequential(
                    nn.Linear(dim, mlp_dim), nn.GELU(), nn.Dropout(dropout),
                    nn.Linear(mlp_dim, dim), nn.Dropout(dropout)
                )
            ]))
        self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def forward(self, img):
        p = self.patch_size
        x = img.unfold(2, p, p).unfold(3, p, p).reshape(img.shape[0], -1, self.channels * p * p)
        x = self.patch_embed(x)
        B, N, _ = x.shape
        x += self.pos_embed[:, :N]
        for norm1, attn, norm2, mlp in self.layers:
            x = x + attn(norm1(x))
            x = x + mlp(norm2(x))
        return self.mlp_head(x.mean(dim=1))

In [4]:
def get_dataloaders(dataset_name, batch_size, resize_image=32, manipulate_images=False):
    transform_compose_list = [
            transforms.Resize((resize_image, resize_image)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]

    if manipulate_images:
        transform_compose_list = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(resize_image, padding=4),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2)
        ] + transform_compose_list
        
    if 'mnist' in dataset_name.lower():
        # Resize to 32x32 to match patch logic easily
        transform_compose_list = transform_compose_list[:-1] + [transforms.Normalize((0.5,), (0.5,))]

    transform = transforms.Compose(transform_compose_list)
        
    if dataset_name.lower() == 'cifar10':
        trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
        testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
        num_classes = 10
        channels = 3
    elif dataset_name.lower() == 'cifar100':
        trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
        testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
        num_classes = 100
        channels = 3
    elif dataset_name.lower() == 'mnist':
        trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
        testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
        num_classes = 10
        channels = 1
    elif dataset_name.lower() == 'fashionmnist':
        trainset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
        testset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
        num_classes = 10
        channels = 1
    else:
        raise ValueError("Unknown dataset")

    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=0)

    return trainloader, testloader, num_classes, channels


# --- 1. CONFIGURATION ---
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
EPOCHS = 15
IMAGE_SIZE = 32
PATCH_SIZE = 4
DIM = 64
DEPTH = 2
NUM_HEADS = 4
MLP_DIM = 128
DROPOUT = 0.1

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

print(f"Using Device: {DEVICE}")

def train_and_evaluate(model_type, dataset_name='cifar10', n_walks=50, p_halt=0.1, manipulate_images=False):
    print(f"\n--- Training {model_type.upper()} on {dataset_name.upper()} ---")

    trainloader, testloader, num_classes, channels = get_dataloaders(dataset_name, BATCH_SIZE, manipulate_images=manipulate_images)

    model = ViT(
        patch_size=PATCH_SIZE,
        image_size=IMAGE_SIZE,
        dim=DIM,
        depth=DEPTH,
        num_heads=NUM_HEADS,
        dropout=DROPOUT,
        mlp_dim=MLP_DIM,
        device=DEVICE,
        channels=channels,
        attention_type=model_type,
        n_walks=n_walks,
        p_halt=p_halt,
        num_classes=num_classes
    )

    model = model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    start_time = time.time()

    # --- TRAINING LOOP WITH PER-EPOCH LOGGING ---
    final_acc = 0.0

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_train_loss = running_loss / len(trainloader)

        # Evaluate after every epoch
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        epoch_acc = 100 * correct / total
        final_acc = epoch_acc # Store last accuracy

        print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {avg_train_loss:.4f} | Test Acc: {epoch_acc:.2f}%")

    train_time = time.time() - start_time
    print(f"   -> Final Result: Acc = {final_acc:.2f}%")
    return final_acc

def replicate_table_1_complete(dataset_name, manipulate_images=False):
    print("\n" + "="*70)
    print(f"5.2 Visual transformer training on {dataset_name}")
    print("="*70)

    acc_softmax = train_and_evaluate('softmax', dataset_name=dataset_name, manipulate_images=manipulate_images)
    acc_toeplitz = train_and_evaluate('toeplitz', dataset_name=dataset_name, manipulate_images=manipulate_images)
    acc_m_alpha = train_and_evaluate('m_alpha', dataset_name=dataset_name, manipulate_images=manipulate_images)
    acc_grf = train_and_evaluate('grf', dataset_name=dataset_name, n_walks=50, p_halt=0.1, manipulate_images=manipulate_images)
    acc_linear = train_and_evaluate('linear', dataset_name=dataset_name, manipulate_images=manipulate_images)

    print(f"\nCOMPLETE RESULT - {dataset_name}")
    print(f"{'Method':<25} {'Accuracy':<10}")
    print("-" * 50)
    print(f"{'Unmasked Softmax':<25} {acc_softmax:<10.2f} ")
    print(f"{'Toeplitz-masked Linear':<25} {acc_toeplitz:<10.2f}")
    print(f"{'M_alpha(G)-masked':<25} {acc_m_alpha:<10.2f} ")
    print("-" * 50)
    print(f"{'GRF-masked Linear':<25} {acc_grf:<10.2f}")
    print(f"{'Unmasked Linear':<25} {acc_linear:<10.2f}")
    print("="*70)


Using Device: cuda


## CIFAR10 - 15 EPOCHS

In [None]:
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
EPOCHS = 15
IMAGE_SIZE = 32
PATCH_SIZE = 4
DIM = 64
DEPTH = 2
NUM_HEADS = 4
MLP_DIM = 128
DROPOUT = 0.1
DATASET_NAME = 'cifar10'

def train_and_evaluate(model_type, dataset_name='cifar10', n_walks=50, p_halt=0.1, manipulate_images=False):
    print(f"\n--- Training {model_type.upper()} on {dataset_name.upper()} ---")

    trainloader, testloader, num_classes, channels = get_dataloaders(dataset_name, BATCH_SIZE,manipulate_images=manipulate_images)

    model = ViT(
        patch_size=PATCH_SIZE,
        image_size=IMAGE_SIZE,
        dim=DIM,
        depth=DEPTH,
        num_heads=NUM_HEADS,
        dropout=DROPOUT,
        mlp_dim=MLP_DIM,
        device=DEVICE,
        channels=channels,
        attention_type=model_type,
        n_walks=n_walks,
        p_halt=p_halt,
        num_classes=num_classes
    )

    model = model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    start_time = time.time()

    # --- TRAINING LOOP WITH PER-EPOCH LOGGING ---
    final_acc = 0.0

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_train_loss = running_loss / len(trainloader)

        # Evaluate after every epoch
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        epoch_acc = 100 * correct / total
        final_acc = epoch_acc # Store last accuracy

        print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {avg_train_loss:.4f} | Test Acc: {epoch_acc:.2f}%")

    train_time = time.time() - start_time
    print(f"   -> Final Result: Acc = {final_acc:.2f}%")
    return final_acc

if __name__ == "__main__":
    replicate_table_1_complete(DATASET_NAME)


5.2 Visual transformer training on cifar10

--- Training SOFTMAX on CIFAR10 ---


100%|██████████| 170M/170M [00:04<00:00, 42.2MB/s]


Epoch 1/15 | Loss: 1.8666 | Test Acc: 41.48%
Epoch 2/15 | Loss: 1.5630 | Test Acc: 45.51%
Epoch 3/15 | Loss: 1.4720 | Test Acc: 47.96%
Epoch 4/15 | Loss: 1.4135 | Test Acc: 49.56%
Epoch 5/15 | Loss: 1.3679 | Test Acc: 50.59%
Epoch 6/15 | Loss: 1.3241 | Test Acc: 51.24%
Epoch 7/15 | Loss: 1.2897 | Test Acc: 52.09%
Epoch 8/15 | Loss: 1.2538 | Test Acc: 53.14%
Epoch 9/15 | Loss: 1.2244 | Test Acc: 52.59%
Epoch 10/15 | Loss: 1.1910 | Test Acc: 53.89%
Epoch 11/15 | Loss: 1.1664 | Test Acc: 53.69%
Epoch 12/15 | Loss: 1.1357 | Test Acc: 55.49%
Epoch 13/15 | Loss: 1.1118 | Test Acc: 55.40%
Epoch 14/15 | Loss: 1.0927 | Test Acc: 55.33%
Epoch 15/15 | Loss: 1.0623 | Test Acc: 56.11%
   -> Final Result: Acc = 56.11%

--- Training TOEPLITZ on CIFAR10 ---
Epoch 1/15 | Loss: 1.8765 | Test Acc: 41.40%
Epoch 2/15 | Loss: 1.5699 | Test Acc: 43.49%
Epoch 3/15 | Loss: 1.4728 | Test Acc: 47.59%
Epoch 4/15 | Loss: 1.4046 | Test Acc: 49.48%
Epoch 5/15 | Loss: 1.3598 | Test Acc: 51.16%
Epoch 6/15 | Loss: 1.31

## CIFAR100 - 30 EPOCHS

In [None]:
# --- 1. CONFIGURATION ---
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
EPOCHS = 30
IMAGE_SIZE = 32
PATCH_SIZE = 4
DIM = 64
DEPTH = 2
NUM_HEADS = 4
MLP_DIM = 128
DROPOUT = 0.1

# --- DATASET SELECTION ---
# Switch to 'CIFAR100' for a harder task
DATASET_NAME = 'CIFAR100'  # Options: 'CIFAR10', 'CIFAR100'

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

print(f"Using Device: {DEVICE}")
print(f"Dataset: {DATASET_NAME}")

# --- 5. TRAINING UTILS ---
def train_and_evaluate(model_type, dataset_name='cifar10', n_walks=50, p_halt=0.1, manipulate_images=False):
    print(f"\n--- Training {model_type.upper()} on {dataset_name.upper()} ---")

    trainloader, testloader, num_classes, channels = get_dataloaders(dataset_name, BATCH_SIZE,manipulate_images=manipulate_images)

    model = ViT(
        patch_size=PATCH_SIZE,
        image_size=IMAGE_SIZE,
        dim=DIM,
        depth=DEPTH,
        num_heads=NUM_HEADS,
        dropout=DROPOUT,
        mlp_dim=MLP_DIM,
        device=DEVICE,
        channels=channels,
        attention_type=model_type,
        n_walks=n_walks,
        p_halt=p_halt,
        num_classes=num_classes
    )

    model = model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    start_time = time.time()

    # --- TRAINING LOOP WITH PER-EPOCH LOGGING ---
    final_acc = 0.0

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_train_loss = running_loss / len(trainloader)

        # Evaluate after every epoch
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        epoch_acc = 100 * correct / total
        final_acc = epoch_acc # Store last accuracy

        print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {avg_train_loss:.4f} | Test Acc: {epoch_acc:.2f}%")

    train_time = time.time() - start_time
    print(f"   -> Final Result: Acc = {final_acc:.2f}%")
    return final_acc
# TODO Rename title in output

if __name__ == "__main__":
    replicate_table_1_complete(DATASET_NAME)

Using Device: cuda
Dataset: CIFAR100

5.2 Visual transformer training on CIFAR100

--- Training SOFTMAX on CIFAR100 ---


100%|██████████| 169M/169M [00:02<00:00, 76.3MB/s]


Epoch 1/30 | Loss: 4.2431 | Test Acc: 10.79%
Epoch 2/30 | Loss: 3.6941 | Test Acc: 15.60%
Epoch 3/30 | Loss: 3.5025 | Test Acc: 18.18%
Epoch 4/30 | Loss: 3.3678 | Test Acc: 19.69%
Epoch 5/30 | Loss: 3.2686 | Test Acc: 20.41%
Epoch 6/30 | Loss: 3.1789 | Test Acc: 22.40%
Epoch 7/30 | Loss: 3.1008 | Test Acc: 23.15%
Epoch 8/30 | Loss: 3.0389 | Test Acc: 23.65%
Epoch 9/30 | Loss: 2.9835 | Test Acc: 23.89%
Epoch 10/30 | Loss: 2.9298 | Test Acc: 25.24%
Epoch 11/30 | Loss: 2.8821 | Test Acc: 26.12%
Epoch 12/30 | Loss: 2.8379 | Test Acc: 26.12%
Epoch 13/30 | Loss: 2.8026 | Test Acc: 25.98%
Epoch 14/30 | Loss: 2.7620 | Test Acc: 26.47%
Epoch 15/30 | Loss: 2.7242 | Test Acc: 26.71%
Epoch 16/30 | Loss: 2.6934 | Test Acc: 27.37%
Epoch 17/30 | Loss: 2.6640 | Test Acc: 27.02%
Epoch 18/30 | Loss: 2.6305 | Test Acc: 27.32%
Epoch 19/30 | Loss: 2.6065 | Test Acc: 27.49%
Epoch 20/30 | Loss: 2.5741 | Test Acc: 27.65%
Epoch 21/30 | Loss: 2.5537 | Test Acc: 28.24%
Epoch 22/30 | Loss: 2.5239 | Test Acc: 28.1

## CIFAR10 - 30 EPOCHS

In [None]:
# --- 1. CONFIGURATION ---
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
EPOCHS = 30
IMAGE_SIZE = 32
PATCH_SIZE = 4
DIM = 64
DEPTH = 2
NUM_HEADS = 4
MLP_DIM = 128
DROPOUT = 0.1

# --- DATASET SELECTION ---
# Switch to 'CIFAR100' for a harder task
DATASET_NAME = 'CIFAR10'  # Options: 'CIFAR10', 'CIFAR100'

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

print(f"Using Device: {DEVICE}")
print(f"Dataset: {DATASET_NAME}")

# --- 5. TRAINING UTILS ---
def train_and_evaluate(model_type, dataset_name='cifar10', n_walks=50, p_halt=0.1, manipulate_images=False):
    print(f"\n--- Training {model_type.upper()} on {dataset_name.upper()} ---")

    trainloader, testloader, num_classes, channels = get_dataloaders(dataset_name, BATCH_SIZE,manipulate_images=manipulate_images)

    model = ViT(
        patch_size=PATCH_SIZE,
        image_size=IMAGE_SIZE,
        dim=DIM,
        depth=DEPTH,
        num_heads=NUM_HEADS,
        dropout=DROPOUT,
        mlp_dim=MLP_DIM,
        device=DEVICE,
        channels=channels,
        attention_type=model_type,
        n_walks=n_walks,
        p_halt=p_halt,
        num_classes=num_classes
    )

    model = model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    start_time = time.time()

    # --- TRAINING LOOP WITH PER-EPOCH LOGGING ---
    final_acc = 0.0

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_train_loss = running_loss / len(trainloader)

        # Evaluate after every epoch
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        epoch_acc = 100 * correct / total
        final_acc = epoch_acc # Store last accuracy

        print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {avg_train_loss:.4f} | Test Acc: {epoch_acc:.2f}%")

    train_time = time.time() - start_time
    print(f"   -> Final Result: Acc = {final_acc:.2f}%")
    return final_acc
# TODO Rename title in output

if __name__ == "__main__":
    replicate_table_1_complete(DATASET_NAME)

Using Device: cuda
Dataset: CIFAR10

5.2 Visual transformer training on CIFAR10

--- Training SOFTMAX on CIFAR10 ---
Epoch 1/30 | Loss: 1.9151 | Test Acc: 41.59%
Epoch 2/30 | Loss: 1.5766 | Test Acc: 46.01%
Epoch 3/30 | Loss: 1.4707 | Test Acc: 48.28%
Epoch 4/30 | Loss: 1.4069 | Test Acc: 48.28%
Epoch 5/30 | Loss: 1.3553 | Test Acc: 51.22%
Epoch 6/30 | Loss: 1.3190 | Test Acc: 51.52%
Epoch 7/30 | Loss: 1.2815 | Test Acc: 52.23%
Epoch 8/30 | Loss: 1.2509 | Test Acc: 52.95%
Epoch 9/30 | Loss: 1.2222 | Test Acc: 52.23%
Epoch 10/30 | Loss: 1.1965 | Test Acc: 54.76%
Epoch 11/30 | Loss: 1.1643 | Test Acc: 54.48%
Epoch 12/30 | Loss: 1.1359 | Test Acc: 54.84%
Epoch 13/30 | Loss: 1.1124 | Test Acc: 55.10%
Epoch 14/30 | Loss: 1.0890 | Test Acc: 56.20%
Epoch 15/30 | Loss: 1.0642 | Test Acc: 55.21%
Epoch 16/30 | Loss: 1.0398 | Test Acc: 55.20%
Epoch 17/30 | Loss: 1.0212 | Test Acc: 56.55%
Epoch 18/30 | Loss: 0.9947 | Test Acc: 57.13%
Epoch 19/30 | Loss: 0.9782 | Test Acc: 56.57%
Epoch 20/30 | Loss

## FashionMNIST - 10 EPOCHS

In [None]:
# --- 1. CONFIGURATION ---
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
EPOCHS = 10
IMAGE_SIZE = 32
PATCH_SIZE = 4
DIM = 64
DEPTH = 2
NUM_HEADS = 4
MLP_DIM = 128
DROPOUT = 0.1

# --- DATASET SELECTION ---
# Switch to 'CIFAR100' for a harder task
DATASET_NAME = 'FashionMNIST'  # Options: 'CIFAR10', 'CIFAR100', 'FashionMNIST', 'MNIST'

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

print(f"Using Device: {DEVICE}")
print(f"Dataset: {DATASET_NAME}")

# --- 5. TRAINING UTILS ---
def train_and_evaluate(model_type, dataset_name='cifar10', n_walks=50, p_halt=0.1, manipulate_images=False):
    print(f"\n--- Training {model_type.upper()} on {dataset_name.upper()} ---")

    trainloader, testloader, num_classes, channels = get_dataloaders(dataset_name, BATCH_SIZE,manipulate_images=manipulate_images)

    model = ViT(
        patch_size=PATCH_SIZE,
        image_size=IMAGE_SIZE,
        dim=DIM,
        depth=DEPTH,
        num_heads=NUM_HEADS,
        dropout=DROPOUT,
        mlp_dim=MLP_DIM,
        device=DEVICE,
        channels=channels,
        attention_type=model_type,
        n_walks=n_walks,
        p_halt=p_halt,
        num_classes=num_classes
    )


    model = model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    start_time = time.time()

    # --- TRAINING LOOP WITH PER-EPOCH LOGGING ---
    final_acc = 0.0

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_train_loss = running_loss / len(trainloader)

        # Evaluate after every epoch
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        epoch_acc = 100 * correct / total
        final_acc = epoch_acc # Store last accuracy

        print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {avg_train_loss:.4f} | Test Acc: {epoch_acc:.2f}%")

    train_time = time.time() - start_time
    print(f"   -> Final Result: Acc = {final_acc:.2f}%")
    return final_acc
# TODO Rename title in output

if __name__ == "__main__":
    replicate_table_1_complete(DATASET_NAME)

Using Device: cuda
Dataset: FashionMNIST

5.2 Visual transformer training on FashionMNIST

--- Training SOFTMAX on FASHIONMNIST ---


100%|██████████| 26.4M/26.4M [00:01<00:00, 14.0MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 300kB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 5.53MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 12.2MB/s]


Epoch 1/10 | Loss: 0.8049 | Test Acc: 80.18%
Epoch 2/10 | Loss: 0.4555 | Test Acc: 84.50%
Epoch 3/10 | Loss: 0.4091 | Test Acc: 84.87%
Epoch 4/10 | Loss: 0.3803 | Test Acc: 85.40%
Epoch 5/10 | Loss: 0.3616 | Test Acc: 86.37%
Epoch 6/10 | Loss: 0.3420 | Test Acc: 86.12%
Epoch 7/10 | Loss: 0.3305 | Test Acc: 86.64%
Epoch 8/10 | Loss: 0.3205 | Test Acc: 87.07%
Epoch 9/10 | Loss: 0.3080 | Test Acc: 87.49%
Epoch 10/10 | Loss: 0.2999 | Test Acc: 87.49%
   -> Final Result: Acc = 87.49%

--- Training TOEPLITZ on FASHIONMNIST ---
Epoch 1/10 | Loss: 0.8267 | Test Acc: 80.97%
Epoch 2/10 | Loss: 0.4768 | Test Acc: 83.00%
Epoch 3/10 | Loss: 0.4220 | Test Acc: 84.06%
Epoch 4/10 | Loss: 0.3903 | Test Acc: 84.90%
Epoch 5/10 | Loss: 0.3683 | Test Acc: 86.25%
Epoch 6/10 | Loss: 0.3511 | Test Acc: 86.53%
Epoch 7/10 | Loss: 0.3347 | Test Acc: 87.00%
Epoch 8/10 | Loss: 0.3236 | Test Acc: 85.69%
Epoch 9/10 | Loss: 0.3145 | Test Acc: 86.91%
Epoch 10/10 | Loss: 0.3041 | Test Acc: 87.92%
   -> Final Result: Ac

## CIFAR10 - 15 EPOCHS - different parameters (batch_size, image_size, patch_size, depth)

In [5]:
# --- 1. CONFIGURATION ---
BATCH_SIZE = 64
LEARNING_RATE = 1e-3
EPOCHS = 15
IMAGE_SIZE = 42
PATCH_SIZE = 4
DIM = 64
DEPTH = 4
NUM_HEADS = 4
MLP_DIM = 128
DROPOUT = 0.1
N_WALKS = 100
P_HALT = 0.1

# --- DATASET SELECTION ---
# Switch to 'CIFAR100' for a harder task
DATASET_NAME = "CIFAR10"  # Options: 'CIFAR10', 'CIFAR100', 'FashionMNIST', 'MNIST'

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

print(f"Using Device: {DEVICE}")
print(f"Dataset: {DATASET_NAME}")


# --- 5. TRAINING UTILS ---
def train_and_evaluate(
    model_type, dataset_name="cifar10", n_walks=50, p_halt=0.1, manipulate_images=False
):
    print(f"\n--- Training {model_type.upper()} on {dataset_name.upper()} ---")

    trainloader, testloader, num_classes, channels = get_dataloaders(
        dataset_name,
        BATCH_SIZE,
        manipulate_images=manipulate_images,
        resize_image=IMAGE_SIZE,
    )

    model = ViT(
        patch_size=PATCH_SIZE,
        image_size=IMAGE_SIZE,
        dim=DIM,
        depth=DEPTH,
        num_heads=NUM_HEADS,
        dropout=DROPOUT,
        mlp_dim=MLP_DIM,
        device=DEVICE,
        channels=channels,
        attention_type=model_type,
        n_walks=n_walks,
        p_halt=p_halt,
        num_classes=num_classes,
    )


    model = model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    start_time = time.time()

    # --- TRAINING LOOP WITH PER-EPOCH LOGGING ---
    final_acc = 0.0

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_train_loss = running_loss / len(trainloader)

        # Evaluate after every epoch
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        epoch_acc = 100 * correct / total
        final_acc = epoch_acc  # Store last accuracy

        print(
            f"Epoch {epoch+1}/{EPOCHS} | Loss: {avg_train_loss:.4f} | Test Acc: {epoch_acc:.2f}%"
        )

    train_time = time.time() - start_time
    print(f"   -> Final Result: Acc = {final_acc:.2f}%")
    return final_acc


# TODO Rename title in output

def replicate_table_1_complete(dataset_name, n_walks, p_halt, manipulate_images=False):
    print("\n" + "="*70)
    print(f"5.2 Visual transformer training on {dataset_name}")
    print("="*70)

    acc_softmax = train_and_evaluate('softmax', dataset_name=dataset_name, manipulate_images=manipulate_images)
    acc_toeplitz = train_and_evaluate('toeplitz', dataset_name=dataset_name, manipulate_images=manipulate_images)
    acc_m_alpha = train_and_evaluate('m_alpha', dataset_name=dataset_name, manipulate_images=manipulate_images)
    acc_grf = train_and_evaluate('grf', dataset_name=dataset_name, n_walks=n_walks, p_halt=p_halt, manipulate_images=manipulate_images)
    acc_linear = train_and_evaluate('linear', dataset_name=dataset_name, manipulate_images=manipulate_images)

    print(f"\nCOMPLETE RESULT - {dataset_name}")
    print(f"{'Method':<25} {'Accuracy':<10}")
    print("-" * 50)
    print(f"{'Unmasked Softmax':<25} {acc_softmax:<10.2f} ")
    print(f"{'Toeplitz-masked Linear':<25} {acc_toeplitz:<10.2f}")
    print(f"{'M_alpha(G)-masked':<25} {acc_m_alpha:<10.2f} ")
    print("-" * 50)
    print(f"{'GRF-masked Linear':<25} {acc_grf:<10.2f}")
    print(f"{'Unmasked Linear':<25} {acc_linear:<10.2f}")
    print("="*70)

if __name__ == "__main__":
    replicate_table_1_complete(DATASET_NAME, N_WALKS, P_HALT)

Using Device: cuda
Dataset: CIFAR10

5.2 Visual transformer training on CIFAR10

--- Training SOFTMAX on CIFAR10 ---


100%|██████████| 170M/170M [00:05<00:00, 28.7MB/s] 


Epoch 1/15 | Loss: 1.8496 | Test Acc: 42.40%
Epoch 2/15 | Loss: 1.5206 | Test Acc: 46.95%
Epoch 3/15 | Loss: 1.4078 | Test Acc: 50.05%
Epoch 4/15 | Loss: 1.3361 | Test Acc: 49.56%
Epoch 5/15 | Loss: 1.2800 | Test Acc: 52.41%
Epoch 6/15 | Loss: 1.2350 | Test Acc: 53.77%
Epoch 7/15 | Loss: 1.1895 | Test Acc: 52.02%
Epoch 8/15 | Loss: 1.1520 | Test Acc: 54.61%
Epoch 9/15 | Loss: 1.1141 | Test Acc: 55.36%
Epoch 10/15 | Loss: 1.0784 | Test Acc: 55.44%
Epoch 11/15 | Loss: 1.0444 | Test Acc: 55.81%
Epoch 12/15 | Loss: 1.0096 | Test Acc: 56.54%
Epoch 13/15 | Loss: 0.9797 | Test Acc: 56.33%
Epoch 14/15 | Loss: 0.9469 | Test Acc: 56.55%
Epoch 15/15 | Loss: 0.9225 | Test Acc: 56.42%
   -> Final Result: Acc = 56.42%

--- Training TOEPLITZ on CIFAR10 ---
Epoch 1/15 | Loss: 1.8154 | Test Acc: 43.01%
Epoch 2/15 | Loss: 1.5184 | Test Acc: 47.63%
Epoch 3/15 | Loss: 1.4132 | Test Acc: 49.49%
Epoch 4/15 | Loss: 1.3396 | Test Acc: 51.25%
Epoch 5/15 | Loss: 1.2820 | Test Acc: 54.27%
Epoch 6/15 | Loss: 1.23