In [None]:
import time
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed the NumPy and PyTorch random number generators with one shared value.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
if device.type == 'cuda':
    # Also seed every CUDA device and release cached GPU memory.
    torch.cuda.manual_seed_all(seed)
    torch.cuda.empty_cache()

**All parameters you need or want to customize are in the following cell.**

In [None]:
''' Dataset args '''
# Dataset name; also used to pick the normalization statistics and the
# torchvision dataset class (via ds.upper()).
ds = 'cifar10' # Choose from ['cifar10', 'cifar100'] # @param {isTemplate:true}
# Batch size used by both the training and the test DataLoader.
args_batch_size = 512
# Lower bound of the RandomResizedCrop `scale` range (upper bound is 1.0).
args_scale = 1
# Probability `p` passed to RandomErasing.
args_reprob = 0
# RandAugment magnitude.
args_ra_m = 12
# RandAugment number of operations (`num_ops`).
args_ra_n = 2
# Passed as the first three ColorJitter arguments
# (brightness, contrast, saturation).
args_jitter = 0
# Image resolution (side length in pixels); CIFAR images are 32x32.
res = 32

''' Model args '''
# Name of the model factory to instantiate; it is looked up by name among the
# functions defined in this notebook ('ConvMixer' is also defined, as the baseline).
# Choose from 'SplitMixerI', 'SplitMixerI_channel_only', 'SplitMixerII', 'SplitMixerIII', 'SplitMixerIV', 'SplitMixerV'
args_name = 'SplitMixerV' # @param {isTemplate:true}
# Passed to the model factory as `dim` (number of channels in every block).
args_hdim = 256
# Passed as `depth` (number of mixer blocks).
args_depth = 8
# Passed as `patch_size` (kernel size and stride of the first convolution).
args_psize = 2
# Passed as `kernel_size` (size of the depthwise convolutions).
args_conv_ks = 5
# Size of the classifier output; any `ds` other than 'cifar10' yields 100.
num_classes = 10 if ds == 'cifar10' else 100
# Fraction of the channels fed to each 1x1 mixing convolution: int(dim * ratio).
args_ratio = 2/3 # need to provide this if using SplitMixerI, V # @param {isTemplate:true}
# Number of channel segments (`n_part`).
args_parts = 4 # need to provide this if using SplitMixerII, III, IV # @param {isTemplate:true}

''' Training args'''
# Peak learning rate of the piecewise-linear schedule.
args_lr_max = 0.05
# AdamW weight decay.
args_wd = 0.005
# Number of training epochs.
args_epochs = 100
# When True, gradients are clipped to a total norm of 1.0 before each step.
args_clip_norm = True
# Number of DataLoader worker processes.
args_workers = 2


# Model definitions
You can also find the model definitions [under timm](https://github.com/aliborji/splitmixer/blob/main/pytorch-image-models/timm/models/splitmixer.py). To keep this notebook self-contained, the definitions are repeated here rather than imported.

In [None]:
class Residual(nn.Module):
    ''' Skip connection: returns ``fn(x) + x`` for the wrapped module ``fn``. '''

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        transformed = self.fn(x)
        return transformed + x


class PartialChannelMixer(nn.Module):
    ''' 1x1 convolution applied to only a slice of the channels.

    ``int(dim * ratio)`` channels pass through the pointwise convolution and
    the remaining channels are forwarded untouched. When ``is_odd == 0`` the
    leading channels are mixed, otherwise the trailing ones.
    '''

    def __init__(self, dim, is_odd, ratio):
        super().__init__()
        self.dim = dim
        self.is_odd = is_odd
        self.partial_c = int(dim * ratio)
        self.mixer = nn.Conv2d(self.partial_c, self.partial_c, kernel_size=1)

    def forward(self, x):
        mix_leading = self.is_odd == 0
        # Channel index separating the two slices along dim 1.
        split = self.partial_c if mix_leading else self.dim - self.partial_c
        head, tail = x[:, :split], x[:, split:]
        if mix_leading:
            head = self.mixer(head)
        else:
            tail = self.mixer(tail)
        return torch.cat((head, tail), dim=1)


def SplitMixerI(dim, depth, kernel_size=5, patch_size=2, num_classes=10,
                ratio=2/3, img_size=32, **kwargs):
    ''' Partial overlap model, one segment per iteration.

    Every block applies a residual pair of depthwise convolutions (1 x k
    followed by k x 1) and then a PartialChannelMixer that mixes
    ``int(dim * ratio)`` channels; even blocks mix the leading channels and
    odd blocks the trailing ones.

    Args:
        dim: number of channels used throughout the network.
        depth: number of mixer blocks.
        kernel_size: length ``k`` of the depthwise convolution kernels.
        patch_size: kernel size and stride of the patch-embedding convolution.
        num_classes: size of the final linear layer's output.
        ratio: fraction of ``dim`` mixed by each block's 1x1 convolution.
        img_size: unused; kept so existing callers that pass it still work.
        **kwargs: ignored extra keyword arguments.

    Returns:
        An ``nn.Sequential`` producing ``num_classes`` outputs per sample.
    '''
    def mixer_block(block_idx):
        # Spatial mixing (with skip connection) followed by partial channel mixing.
        return nn.Sequential(
            Residual(nn.Sequential(
                nn.Conv2d(dim, dim, (1, kernel_size), groups=dim, padding="same"),
                nn.GELU(),
                nn.BatchNorm2d(dim),
                nn.Conv2d(dim, dim, (kernel_size, 1), groups=dim, padding="same"),
                nn.GELU(),
                nn.BatchNorm2d(dim),
                )),
            PartialChannelMixer(dim, block_idx % 2, ratio),
            nn.GELU(),
            nn.BatchNorm2d(dim),
            )

    return nn.Sequential(
        # Patch embedding.
        nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(dim),
        *[mixer_block(i) for i in range(depth)],
        # Classification head.
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, num_classes)
        )

def SplitMixerI_channel_only(dim, depth, kernel_size=5, patch_size=2,
                             num_classes=10, ratio=2/3, **kwargs):
    ''' SplitMixerI variant with a single k x k depthwise convolution.

    Spatial mixing uses one (kernel_size x kernel_size) depthwise convolution
    instead of the 1 x k / k x 1 pair; channel mixing still goes through a
    PartialChannelMixer that alternates between leading and trailing channels.
    Extra keyword arguments are ignored.
    '''
    layers = [
        nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(dim),
    ]
    for block_idx in range(depth):
        spatial = nn.Sequential(
            nn.Conv2d(dim, dim, (kernel_size, kernel_size), groups=dim,
                      padding="same"),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        )
        layers.append(nn.Sequential(
            Residual(spatial),
            PartialChannelMixer(dim, block_idx % 2, ratio),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        ))
    layers.append(nn.AdaptiveAvgPool2d((1, 1)))
    layers.append(nn.Flatten())
    layers.append(nn.Linear(dim, num_classes))
    return nn.Sequential(*layers)

# ------------------------------------------------------------------------------

class ChannelBin(nn.Module):
    ''' 1x1 convolution applied to one of ``n_part`` contiguous channel bins.

    The channels are divided into ``n_part`` bins of ``int(dim / n_part)``
    channels each; the last bin additionally takes whatever channels are left
    over. Only the bin selected by ``remainder`` is mixed, all other channels
    are passed through unchanged.
    '''

    def __init__(self, dim, remainder, n_part=3):
        super().__init__()
        self.dim = dim
        self.remainder = remainder
        self.n_part = n_part
        self.bin_dim = int(dim / n_part)
        if remainder == n_part - 1:
            # The last bin absorbs the channels left over by the division.
            self.c = dim - self.bin_dim * (n_part - 1)
        else:
            self.c = self.bin_dim
        self.mixer = nn.Conv2d(self.c, self.c, kernel_size=1)

    def forward(self, x):
        lo = self.remainder * self.bin_dim
        if self.remainder == self.n_part - 1:
            hi = self.dim
        else:
            hi = lo + self.bin_dim
        before, selected, after = x[:, :lo], x[:, lo:hi], x[:, hi:]
        return torch.cat((before, self.mixer(selected), after), dim=1)


def SplitMixerII(dim, depth, kernel_size=5, patch_size=2, num_classes=10,
                 n_part=3, **kwargs):
    ''' Non overlapping; in each iteration only one segment is convolved.

    Block ``i`` applies the residual 1 x k / k x 1 depthwise pair and then a
    ChannelBin that mixes bin ``i % n_part`` of the channels. Extra keyword
    arguments are ignored.
    '''
    row_kernel, col_kernel = (1, kernel_size), (kernel_size, 1)

    stem = [
        nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(dim),
    ]

    blocks = []
    for block_idx in range(depth):
        spatial_mixing = Residual(nn.Sequential(
            nn.Conv2d(dim, dim, row_kernel, groups=dim, padding="same"),
            nn.GELU(),
            nn.BatchNorm2d(dim),
            nn.Conv2d(dim, dim, col_kernel, groups=dim, padding="same"),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        ))
        blocks.append(nn.Sequential(
            spatial_mixing,
            ChannelBin(dim, block_idx % n_part, n_part),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        ))

    head = [
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, num_classes),
    ]
    return nn.Sequential(*stem, *blocks, *head)

#-------------------------------------------------------------------------------

class ChannelPatch(nn.Module):
    ''' Non-overlap; same layer.

    Splits the channels into ``n_part`` equal segments and runs every segment
    through the same 1x1 convolution, so the weights are shared by all
    segments. ``dim`` must be divisible by ``n_part``.
    '''

    def __init__(self, dim, n_part):
        super().__init__()
        assert dim % n_part == 0, (
            f'dim {dim} need to be divisible by n_part {n_part}')
        self.dim = dim
        self.n_part = n_part
        self.c = dim // n_part
        self.mixer = nn.Conv2d(self.c, self.c, kernel_size=1)

    def forward(self, x):
        width = self.c
        mixed = []
        for part in range(self.n_part):
            lo = width * part
            mixed.append(self.mixer(x[:, lo:lo + width]))
        return torch.cat(mixed, dim=1)


def SplitMixerIII(dim, depth, kernel_size=5, patch_size=2, num_classes=10,
                  n_part=3, **kwargs):
    ''' Non overlapping; in each iteration all segments are convolved;
        parameters are shared across segments.

    Every block applies the residual 1 x k / k x 1 depthwise pair followed by
    a ChannelPatch over ``n_part`` channel segments. Extra keyword arguments
    are ignored.
    '''
    def spatial_mixing():
        # Residual pair of depthwise convolutions along the two spatial axes.
        return Residual(nn.Sequential(
            nn.Conv2d(dim, dim, (1, kernel_size), groups=dim, padding="same"),
            nn.GELU(),
            nn.BatchNorm2d(dim),
            nn.Conv2d(dim, dim, (kernel_size, 1), groups=dim, padding="same"),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        ))

    modules = [
        nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(dim),
    ]
    for _ in range(depth):
        modules.append(nn.Sequential(
            spatial_mixing(),
            ChannelPatch(dim, n_part),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        ))
    modules.extend([
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, num_classes),
    ])
    return nn.Sequential(*modules)

#-------------------------------------------------------------------------------

class ChannelPatchUnshared(nn.Module):
    ''' Non-overlap; same layer, multiple convs.

    Splits the channels into ``n_part`` segments and gives every segment its
    own 1x1 convolution. The first ``n_part - 1`` segments hold
    ``dim // n_part`` channels each; the last one takes the remaining channels.
    '''

    def __init__(self, dim, n_part):
        super().__init__()
        self.dim = dim
        self.n_part = n_part
        self.c = dim // n_part
        self.last_c = dim - self.c * (n_part - 1)
        # One convolution per segment, sized to that segment's channel count.
        widths = [self.c] * (n_part - 1) + [self.last_c]
        self.mixer = nn.ModuleList(
            [nn.Conv2d(width, width, kernel_size=1) for width in widths])

    def forward(self, x):
        pieces = []
        for part in range(self.n_part - 1):
            lo = self.c * part
            pieces.append(self.mixer[part](x[:, lo:lo + self.c]))
        # The final segment is taken from the end of the channel axis.
        pieces.append(self.mixer[-1](x[:, -self.last_c:]))
        return torch.cat(pieces, dim=1)


def SplitMixerIV(dim, depth, kernel_size=5, patch_size=2, num_classes=10,
                 n_part=3, **kwargs):
    ''' Non overlapping; in each iteration all segments are convolved;
        no parameter sharing across segments.

    Every block applies the residual 1 x k / k x 1 depthwise pair followed by
    a ChannelPatchUnshared over ``n_part`` channel segments. Extra keyword
    arguments are ignored.
    '''
    horizontal, vertical = (1, kernel_size), (kernel_size, 1)

    network = nn.Sequential()
    position = 0

    def push(module):
        # Register under the same integer-string keys nn.Sequential(*mods) uses.
        nonlocal position
        network.add_module(str(position), module)
        position += 1

    push(nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size))
    push(nn.GELU())
    push(nn.BatchNorm2d(dim))
    for _ in range(depth):
        spatial = Residual(nn.Sequential(
            nn.Conv2d(dim, dim, horizontal, groups=dim, padding="same"),
            nn.GELU(),
            nn.BatchNorm2d(dim),
            nn.Conv2d(dim, dim, vertical, groups=dim, padding="same"),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        ))
        push(nn.Sequential(
            spatial,
            ChannelPatchUnshared(dim, n_part),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        ))
    push(nn.AdaptiveAvgPool2d((1, 1)))
    push(nn.Flatten())
    push(nn.Linear(dim, num_classes))
    return network

# ------------------------------------------------------------------------------

class PartialChannelMixerOneBlock(nn.Module):
    ''' Two partial 1x1 channel mixers applied back to back.

    ``mixer1`` mixes the first ``int(dim * ratio)`` channels; ``mixer2`` then
    mixes the last ``int(dim * ratio)`` channels of that intermediate result.
    Channels outside each slice are passed through unchanged.
    '''

    def __init__(self, dim, ratio):
        super().__init__()
        self.dim = dim
        self.c = int(dim * ratio)
        self.mixer1 = nn.Conv2d(self.c, self.c, kernel_size=1)
        self.mixer2 = nn.Conv2d(self.c, self.c, kernel_size=1)

    def forward(self, x):
        width = self.c
        offset = self.dim - width
        # First pass: mix the leading `width` channels.
        stage1 = torch.cat((self.mixer1(x[:, :width]), x[:, width:]), dim=1)
        # Second pass: mix the trailing `width` channels of the first result.
        head, tail = stage1[:, :offset], stage1[:, offset:]
        return torch.cat((head, self.mixer2(tail)), dim=1)


def SplitMixerV(dim, depth, kernel_size=5, patch_size=2, num_classes=10,
                ratio=2/3, **kwargs):
    ''' Partial overlap model, both segments per iteration.

    Every block applies the residual 1 x k / k x 1 depthwise pair followed by
    a PartialChannelMixerOneBlock, which mixes the leading and then the
    trailing ``int(dim * ratio)`` channels. Extra keyword arguments are ignored.
    '''
    def build_block():
        depthwise_pair = nn.Sequential(
            nn.Conv2d(dim, dim, (1, kernel_size), groups=dim, padding="same"),
            nn.GELU(),
            nn.BatchNorm2d(dim),
            nn.Conv2d(dim, dim, (kernel_size, 1), groups=dim, padding="same"),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        )
        return nn.Sequential(
            Residual(depthwise_pair),
            PartialChannelMixerOneBlock(dim, ratio),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        )

    patch_embedding = [
        nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(dim),
    ]
    mixer_blocks = [build_block() for _ in range(depth)]
    classifier = [
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, num_classes),
    ]
    return nn.Sequential(*patch_embedding, *mixer_blocks, *classifier)


And, as the baseline, we use ConvMixer, as defined below:

In [None]:
def ConvMixer(dim, depth, kernel_size=5, patch_size=2, num_classes=10):
    ''' Original ConvMixer, used as the baseline.

    Every block applies a residual k x k depthwise convolution followed by a
    1x1 convolution over all ``dim`` channels.
    '''
    layers = [
        nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(dim),
    ]
    for _ in range(depth):
        depthwise = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        )
        layers.append(nn.Sequential(
            Residual(depthwise),
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        ))
    layers += [
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, num_classes),
    ]
    return nn.Sequential(*layers)


# CIFAR-{10,100} dataset and loader

In [None]:
# Per-channel normalization statistics, keyed by dataset name.
DATASET_MEAN = {
    'cifar10': (0.4914, 0.4822, 0.4465),
    'cifar100': (0.5071, 0.4867, 0.4408),
}
DATASET_STD = {
    'cifar10': (0.2023, 0.1994, 0.2010),
    'cifar100': (0.2675, 0.2565, 0.2761),
}

# `mean` / `std` are the statistics of the selected dataset. The lookup tables
# have their own names so these are not rebound from dict to tuple.
mean, std = DATASET_MEAN[ds], DATASET_STD[ds]

# Side length of the images as stored in the CIFAR datasets.
NATIVE_RES = 32

# Training augmentation. The crop size follows the `res` setting rather than a
# hard-coded 32, so the configured resolution takes effect.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(res, scale=(args_scale, 1.0), ratio=(1.0, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandAugment(num_ops=args_ra_n, magnitude=args_ra_m),
    transforms.ColorJitter(args_jitter, args_jitter, args_jitter),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    transforms.RandomErasing(p=args_reprob)
])

# Evaluation uses no augmentation. Test images are resized only when `res`
# differs from the native size, so the default pipeline is unchanged.
test_steps = [transforms.ToTensor(), transforms.Normalize(mean, std)]
if res != NATIVE_RES:
    test_steps.insert(0, transforms.Resize((res, res)))
test_transform = transforms.Compose(test_steps)

# torchvision exposes the datasets as CIFAR10 / CIFAR100.
dataset_cls = getattr(torchvision.datasets, ds.upper())
trainset = dataset_cls(
    root='./data', train=True, download=True, transform=train_transform)
testset = dataset_cls(
    root='./data', train=False, download=True, transform=test_transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=args_batch_size,
                                          shuffle=True, num_workers=args_workers)
testloader = torch.utils.data.DataLoader(testset, batch_size=args_batch_size,
                                         shuffle=False, num_workers=args_workers)


# Model creation and training

In [None]:
# Create and print the model based on args.

# SplitMixer variants that take `ratio`; the other SplitMixer variants take
# `n_part`. Previously only 'SplitMixerI' received `args_ratio`, so the setting
# was silently ignored for the other two ratio-based variants.
RATIO_BASED_MODELS = ('SplitMixerI', 'SplitMixerI_channel_only', 'SplitMixerV')

model_kwargs = {'dim': args_hdim, 'depth': args_depth, 'kernel_size': args_conv_ks,
                'patch_size': args_psize, 'num_classes': num_classes}
if args_name.startswith('SplitMixer'):
    if args_name in RATIO_BASED_MODELS:
        model_kwargs['ratio'] = args_ratio
    else:
        model_kwargs['n_part'] = args_parts

# The factory function is looked up by name among this notebook's globals.
model = globals()[args_name](**model_kwargs).to(device)
print(model)

In [None]:
# DataParallel and mixed precision are used only on CUDA. On CPU the loop runs
# in full precision on `device` instead of failing on hard-coded .cuda() calls.
use_cuda = device.type == 'cuda'
if use_cuda and not isinstance(model, nn.DataParallel):
    # The isinstance guard keeps a re-run of this cell from nesting wrappers.
    model = nn.DataParallel(model).cuda()


def lr_schedule(t):
    ''' Piecewise-linear learning rate at (fractional) epoch ``t``.

    Rises from 0 to ``args_lr_max`` at 2/5 of training, falls to
    ``args_lr_max / 20`` at 4/5 of training, and reaches 0 at the final epoch.
    '''
    return np.interp(
        [t], [0, args_epochs*2//5, args_epochs*4//5, args_epochs],
        [0, args_lr_max, args_lr_max/20.0, 0])[0]


opt = optim.AdamW(model.parameters(), lr=args_lr_max, weight_decay=args_wd)
criterion = nn.CrossEntropyLoss()
# With enabled=False the scaler's scale/unscale_/step/update calls fall back
# to plain optimizer behavior.
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

for epoch in range(args_epochs):
    start = time.time()
    train_loss, train_acc, n = 0, 0, 0

    # Evaluation below switches to eval mode, so re-enter train mode each epoch.
    model.train()
    for i, (X, y) in enumerate(trainloader):
        X, y = X.to(device), y.to(device)

        # The learning rate is updated every step, not once per epoch.
        lr = lr_schedule(epoch + (i + 1)/len(trainloader))
        for group in opt.param_groups:
            group['lr'] = lr

        opt.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_cuda):
            output = model(X)
            loss = criterion(output, y)

        scaler.scale(loss).backward()
        if args_clip_norm:
            # Gradients must be unscaled before their norm is clipped.
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt)
        scaler.update()

        train_loss += loss.item() * y.size(0)
        train_acc += (output.max(1)[1] == y).sum().item()
        n += y.size(0)

    model.eval()
    test_acc, m = 0, 0
    with torch.no_grad():
        for i, (X, y) in enumerate(testloader):
            X, y = X.to(device), y.to(device)
            with torch.cuda.amp.autocast(enabled=use_cuda):
                output = model(X)
            test_acc += (output.max(1)[1] == y).sum().item()
            m += y.size(0)

    # The mean training loss was accumulated but never shown; report it too.
    print(f'[{args_name}] Epoch: {epoch} | Train Loss: {train_loss/n:.4f}, Train Acc: {train_acc/n:.4f}, Test Acc: {test_acc/m:.4f}, Time: {time.time() - start:.1f}, lr: {lr:.6f}')
