In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import time
import numpy as np

In [13]:
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    
    def forward(self, x):
        return self.fn(x) + x

In [14]:
def ConvMixer(dim, depth, kernel_size=5, patch_size=2, n_classes=10):
    
    return nn.Sequential(
        nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(dim),
        *[nn.Sequential(
            Residual(nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size, groups=dim, padding='same'),
                nn.GELU(),
                nn.BatchNorm2d(dim)
            )),
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(dim)
        )for i in range(depth)],
        nn.AdaptiveAvgPool2d((1,1)),
        nn.Flatten(),
        nn.Linear(dim, n_classes)

    )

In [15]:
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32, scale=(0.75, 1.0), ratio=(1.0, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandAugment(num_ops=1, magnitude=8),
    transforms.ColorJitter(0.1, 0.1, 0.1),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
    transforms.RandomErasing(p=0.25)
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std)
])

epochs = 25
batch_size =512

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=4)

Files already downloaded and verified
Files already downloaded and verified


In [16]:
# check cuda and device
cuda = torch.cuda.is_available()

if cuda:
    print(torch.cuda.get_device_name())
    device = 'cuda'

NVIDIA GeForce RTX 3060


In [17]:
lr_schedule = lambda t: np.interp([t], [0, epochs*2//5, epochs*4//5, epochs], 
                                  [0, 0.01, 0.01/20.0, 0])[0]

depth = 10
hdim = 256
psize = 2
conv_ks = 5
clip_norm = True

In [18]:
model = ConvMixer(hdim, depth, patch_size=psize, kernel_size=conv_ks, n_classes=10)
model = nn.DataParallel(model, device_ids=[0]).cuda()

opt = optim.AdamW(model.parameters(), lr=0.01, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

In [19]:
for epoch in range(epochs):
    start = time.time()
    train_loss, train_acc, n = 0, 0, 0
    for i, (X, y) in enumerate(trainloader):
        model.train()
        X, y = X.cuda(), y.cuda()

        lr = lr_schedule(epoch + (i + 1)/len(trainloader))
        opt.param_groups[0].update(lr=lr)

        opt.zero_grad()
        with torch.cuda.amp.autocast():
            output = model(X)
            loss = criterion(output, y)

        scaler.scale(loss).backward()
        if clip_norm:
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt)
        scaler.update()
        
        train_loss += loss.item() * y.size(0)
        train_acc += (output.max(1)[1] == y).sum().item()
        n += y.size(0)
        
    model.eval()
    test_acc, m = 0, 0
    with torch.no_grad():
        for i, (X, y) in enumerate(testloader):
            X, y = X.cuda(), y.cuda()
            with torch.cuda.amp.autocast():
                output = model(X)
            test_acc += (output.max(1)[1] == y).sum().item()
            m += y.size(0)

    print(f'ConvMixer: Epoch: {epoch} | Train Acc: {train_acc/n:.4f}, Test Acc: {test_acc/m:.4f}, Time: {time.time() - start:.1f}, lr: {lr:.6f}')


ConvMixer: Epoch: 0 | Train Acc: 0.3489, Test Acc: 0.4903, Time: 38.6, lr: 0.001000
ConvMixer: Epoch: 1 | Train Acc: 0.5484, Test Acc: 0.5826, Time: 31.4, lr: 0.002000
ConvMixer: Epoch: 2 | Train Acc: 0.6435, Test Acc: 0.6497, Time: 31.5, lr: 0.003000
ConvMixer: Epoch: 3 | Train Acc: 0.7012, Test Acc: 0.7146, Time: 31.3, lr: 0.004000
ConvMixer: Epoch: 4 | Train Acc: 0.7353, Test Acc: 0.7432, Time: 31.3, lr: 0.005000
ConvMixer: Epoch: 5 | Train Acc: 0.7596, Test Acc: 0.7901, Time: 31.4, lr: 0.006000
ConvMixer: Epoch: 6 | Train Acc: 0.7776, Test Acc: 0.8177, Time: 31.4, lr: 0.007000
ConvMixer: Epoch: 7 | Train Acc: 0.7911, Test Acc: 0.7975, Time: 31.4, lr: 0.008000
ConvMixer: Epoch: 8 | Train Acc: 0.8059, Test Acc: 0.8186, Time: 31.6, lr: 0.009000
ConvMixer: Epoch: 9 | Train Acc: 0.8150, Test Acc: 0.8346, Time: 31.5, lr: 0.010000
ConvMixer: Epoch: 10 | Train Acc: 0.8249, Test Acc: 0.8293, Time: 31.6, lr: 0.009050
ConvMixer: Epoch: 11 | Train Acc: 0.8423, Test Acc: 0.8596, Time: 31.8, lr: