In [None]:
import sys
sys.path.insert(0, './../Models')
from mlp_mixer import MLPMixer
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.datasets import CIFAR10
from torch.utils.data.dataloader import default_collate

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
def convert_keys(state_dict, map_classifier, num_classes = 10, seed = None):
    """Translate a Flax-style MLP-Mixer checkpoint into a PyTorch state dict.

    Every key of ``state_dict`` is renamed through a fixed, ordered list of
    substring substitutions and every array is converted to a ``float32``
    tensor whose axes are reordered for PyTorch.

    Args:
        state_dict: Mapping from checkpoint key to array-like (e.g. the object
            returned by ``np.load`` on an ``.npz`` file).
        map_classifier: If true, the classifier head is copied unchanged.
            If false, ``num_classes`` rows of the head are picked at random
            and used as the new head.
        num_classes: Number of head rows to keep when ``map_classifier`` is
            false.
        seed: Optional integer seed for the row selection. When ``None`` the
            global ``np.random`` state is used.

    Returns:
        dict mapping the renamed keys to ``torch.float32`` tensors.
    """
    # Applied in order: later rules rely on the output of earlier ones
    # (e.g. "head" -> "head.0" also rewrites "pre_head", which the final rule undoes).
    renames = (
        ("/", "."),
        ("MixerBlock_", "MixerBlock."),
        ("channel_mixing.Dense_0", "channel_mixing.1.net.0"),
        ("channel_mixing.Dense_1", "channel_mixing.1.net.3"),
        ("token_mixing.Dense_0", "token_mixing.2.net.0"),
        ("token_mixing.Dense_1", "token_mixing.2.net.3"),
        ("LayerNorm_0", "token_mixing.0"),
        ("LayerNorm_1", "channel_mixing.0"),
        ("scale", "weight"),
        ("kernel", "weight"),
        ("stem", "stem.0"),
        ("head", "head.0"),
        ("pre_head.0_layer_norm", "pre_head_layer_norm"),
    )

    rng = np.random if seed is None else np.random.RandomState(seed)
    head_rows = None  # drawn once so the head weight and bias keep the same classes
    new_state_dict = {}

    for key in state_dict.keys():
        new_key = key
        for old, new in renames:
            new_key = new_key.replace(old, new)

        weights = torch.tensor(state_dict[key], dtype = torch.float32)
        if weights.dim() == 4:
            # Assumed to be a convolution kernel stored as (H, W, in, out);
            # PyTorch wants (out, in, H, W). A plain axis reversal would swap H and W.
            weights = weights.permute(3, 2, 0, 1)
        elif weights.dim() >= 2:
            # Dense kernels are (in, out); nn.Linear stores (out, in).
            weights = weights.permute(*reversed(range(weights.dim())))

        if not map_classifier and "head." in new_key:
            if head_rows is None:
                available = weights.shape[0]
                # Sample without replacement whenever enough pretrained rows exist.
                picked = rng.choice(available, size = num_classes,
                                    replace = num_classes > available)
                head_rows = torch.as_tensor(picked, dtype = torch.long)
            weights = weights[head_rows]

        new_state_dict[new_key] = weights.contiguous()
    return new_state_dict

In [None]:
# B/16 architecture
net = MLPMixer(in_channels = 3,
               dim = 768,
               num_classes = 10,
               patch_size = 16,
               image_size = 224,
               depth = 12,
               token_dim = 384,
               channel_dim = 3072).to(device)

# Load the pretrained checkpoint. The archive is closed as soon as the arrays
# have been copied into tensors, and pickling is disabled because a weight
# archive only needs plain arrays (np.load will raise if it contains objects).
with np.load("./../Weights/imagenet1k-Mixer-B_16.npz", allow_pickle = False) as google_weights:
    new_weights = convert_keys(google_weights, map_classifier = False, num_classes = 10)

# strict=False tolerates name mismatches, so report them instead of silently
# leaving part of the network randomly initialised.
load_report = net.load_state_dict(new_weights, strict = False)
if load_report.missing_keys or load_report.unexpected_keys:
    print(f"Missing keys: {load_report.missing_keys}")
    print(f"Unexpected keys: {load_report.unexpected_keys}")

# Softmax turns logits into probabilities for inspection; nn.CrossEntropyLoss
# itself expects raw logits.
non_linearity = nn.Softmax(dim = 1)
# Small step size: the network starts from pretrained weights, and a large Adam
# learning rate (the previous 0.1) would wipe them out within a few updates.
optimizer = torch.optim.Adam(net.parameters(), lr = 1e-4)
loss = nn.CrossEntropyLoss()

In [None]:
data_dir = "./../Data"
batch_size = 32
shuffle = True

# Upscale CIFAR-10 images to the 224x224 resolution the model is built for,
# then rescale pixel values from [0, 1] to [-1, 1].
# NOTE(review): the [-1, 1] range is assumed to match the preprocessing used
# when the pretrained checkpoint was trained - confirm against its source.
transform = T.Compose([
            T.Resize(224),
            T.ToTensor(),
            T.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))])

train_dataset = CIFAR10(root = data_dir, train = True, transform = transform, download = True)
test_dataset = CIFAR10(root = data_dir, train = False, transform = transform, download = True)


def collate_to_device(batch):
    """Collate a list of samples with the default rules and move every tensor to `device`."""
    return tuple(item.to(device) for item in default_collate(batch))


train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = shuffle,
                          collate_fn = collate_to_device)
# Evaluation metrics do not depend on sample order, so the test set is not shuffled.
test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = False,
                         collate_fn = collate_to_device)

In [None]:
epochs = 10
train_losses, test_losses = [], []
train_accs, test_accs = [], []

for epoch in range(epochs):
    # ---- training pass ----
    tqdm_train_loader = tqdm(train_loader, desc = f"Train Epoch {epoch + 1}")
    net.train()
    train_loss, train_acc = 0, 0
    for (x, y) in tqdm_train_loader:
        optimizer.zero_grad()
        # CrossEntropyLoss applies log-softmax itself, so it must receive the
        # raw logits; feeding it softmax output flattens the gradients.
        logits = net(x)
        loss_ = loss(logits, y)
        loss_.backward()
        optimizer.step()

        n_batch = y.size(0)
        t_loss = loss_.item()
        # argmax of the logits equals argmax of their softmax.
        t_correct = (logits.argmax(1) == y).sum().item()
        # Weight each batch by its size so a short final batch is not over-counted.
        train_loss += t_loss * n_batch / len(train_dataset)
        train_acc += t_correct / len(train_dataset)
        tqdm_train_loader.set_postfix(train_loss = t_loss, train_acc = t_correct / n_batch)
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    # ---- evaluation pass ----
    tqdm_test_loader = tqdm(test_loader, desc = f"Test Epoch {epoch + 1}")
    net.eval()
    test_loss, test_acc = 0, 0
    # No gradients are needed for evaluation; skipping them saves memory and time.
    with torch.no_grad():
        for (x, y) in tqdm_test_loader:
            logits = net(x)
            loss_ = loss(logits, y)

            n_batch = y.size(0)
            t_loss = loss_.item()
            t_correct = (logits.argmax(1) == y).sum().item()
            test_loss += t_loss * n_batch / len(test_dataset)
            test_acc += t_correct / len(test_dataset)
            tqdm_test_loader.set_postfix(test_loss = t_loss, test_acc = t_correct / n_batch)
    test_losses.append(test_loss)
    test_accs.append(test_acc)
    print(f"Epoch {epoch + 1} | Train Loss {train_loss:.4f} | Train Acc {train_acc:.4f} | Test Loss {test_loss:.4f} | Test Acc {test_acc:.4f}")

In [None]:
# Side-by-side learning curves: loss on the left, accuracy on the right,
# both plotted against the 1-based epoch number.
epoch_axis = np.arange(1, epochs + 1)
fig, (loss_ax, acc_ax) = plt.subplots(1, 2)

panels = (
    (loss_ax, "Loss", ((train_losses, "Train Loss"), (test_losses, "Test Loss"))),
    (acc_ax, "Accuracy", ((train_accs, "Train Accuracy"), (test_accs, "Test Accuracy"))),
)
for ax, y_label, curves in panels:
    for values, curve_label in curves:
        ax.plot(epoch_axis, values, label = curve_label)
    ax.set_xlabel("Epochs")
    ax.set_ylabel(y_label)
    ax.legend()

fig.tight_layout()
plt.show()