In [1]:
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import os
import csv
import time

from MoE_based_transformer import MoE_ViT_classifier


# If available, CUDA (in this case we use CPU)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy


# Model parameters (same as baseline for fairest possible comparison)
size = 32
bs = 512 # Batch size
lr = 0.0001
epochs = 5 # 200 would be nice but it takes too long
patch = 4
use_amp = False

# Applying some transformation to both sets 
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


# Prepare CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True, num_workers=8)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)

# 10 classes
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


  warn(


Files already downloaded and verified
Files already downloaded and verified


In [2]:
# MoE-based ViT call
net = MoE_ViT_classifier(image_size = 32, 
                        patch_size = 4,
                        num_classes = 10,
                        dim=412, 
                        depth = 6,  
                        heads = 8, 
                        moe_input_size =  10, # MoE parameters
                        moe_output_dim = 412, 
                        moe_num_experts = 8, 
                        moe_hidden_dim = 512, # same as baseline MLP dim 
                        moe_noisy_gating = False,
                        moe_k = 4, 
                        num_experts_per_tok = 2, 
                        channels = 3, 
                        dim_head=412, 
                        emb_dropout=0,)

# Cross Entropy Loss
criterion = nn.CrossEntropyLoss()

# Adam optimizer
optimizer = optim.Adam(net.parameters(), lr=lr)

# use cosine scheduling
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

# Train and test functions 
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    return train_loss/(batch_idx+1)

def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Accuraccy 
    acc = 100.*correct/total
    if acc > best_acc:
        best_acc = acc
    
    # Save in a txt accurary and loss
    content = time.ctime() + ' ' + f'Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, val loss: {test_loss:.5f}, acc: {(acc):.5f}'
    print(content)

    # Create txt file and append content 
    with open('MoE-based_ViT.txt', 'w') as f:
        f.write(content)
        f.close()

    return test_loss, acc

list_loss = []
list_acc = []

for epoch in range(0, epochs):
    start = time.time()
    trainloss = train(epoch)
    val_loss, acc = test(epoch)
    
    scheduler.step(epoch-1) # step cosine scheduling
    
    list_loss.append(val_loss)
    list_acc.append(acc)


    print(list_loss)

num_patches 64
patch_dim 48


  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)



Epoch: 0
Sat Aug  3 01:42:22 2024 Epoch 0, lr: 0.0001000, val loss: 153.97998, acc: 44.27000
[153.97997963428497]

Epoch: 1




Sat Aug  3 03:31:53 2024 Epoch 1, lr: 0.0000905, val loss: 134.96651, acc: 51.49000
[153.97997963428497, 134.96650767326355]

Epoch: 2
Sat Aug  3 05:22:03 2024 Epoch 2, lr: 0.0001000, val loss: 128.28579, acc: 54.36000
[153.97997963428497, 134.96650767326355, 128.2857929468155]

Epoch: 3
Sat Aug  3 07:11:27 2024 Epoch 3, lr: 0.0000905, val loss: 119.47903, acc: 57.75000
[153.97997963428497, 134.96650767326355, 128.2857929468155, 119.47903341054916]

Epoch: 4
Sat Aug  3 08:28:45 2024 Epoch 4, lr: 0.0000655, val loss: 108.42445, acc: 61.19000
[153.97997963428497, 134.96650767326355, 128.2857929468155, 119.47903341054916, 108.42444515228271]
