In [1]:
import sys
sys.path.insert(0, './../Models')

from mlp_mixer import MLPMixer
from imagenet1k_dataloader import get_imagenet_loaders

from tqdm import tqdm
import numpy as np

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
# B/16 architecture
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)
net = MLPMixer(in_channels = 3,
               dim = 768,
               num_classes = 1000,
               patch_size = 16,
               image_size = 224,
               depth = 12,
               token_dim = 384,
               channel_dim = 3072).to(device)
non_linearity = nn.Softmax(dim = 1)

cuda


In [3]:
def convert_keys(state_dict):
    keys = state_dict.keys()
    new_keys = []
    new_state_dict = {}

    for key in keys:
        new_key = key.replace("/", ".")
        new_key = new_key.replace("MixerBlock_", "MixerBlock.")
        new_key = new_key.replace("channel_mixing.Dense_0", "channel_mixing.1.net.0")
        new_key = new_key.replace("channel_mixing.Dense_1", "channel_mixing.1.net.3")
        new_key = new_key.replace("token_mixing.Dense_0", "token_mixing.2.net.0")
        new_key = new_key.replace("token_mixing.Dense_1", "token_mixing.2.net.3")
        new_key = new_key.replace("LayerNorm_0", "token_mixing.0")
        new_key = new_key.replace("LayerNorm_1", "channel_mixing.0")
        new_key = new_key.replace("scale", "weight")
        new_key = new_key.replace("kernel", "weight")
        new_key = new_key.replace("stem", "stem.0")
        new_key = new_key.replace("head", "head.0")
        new_key = new_key.replace("pre_head.0_layer_norm", "pre_head_layer_norm")
        new_keys.append(new_key)
    
    for (key, new_key) in zip(keys, new_keys):
        new_state_dict[new_key] = torch.tensor(state_dict[key], dtype = torch.float32).T
    return new_state_dict

In [4]:
google_weights = np.load("./../Weights/imagenet1k-Mixer-B_16.npz", allow_pickle = True)
new_weights = convert_keys(google_weights)
net.load_state_dict(new_weights, strict = False)

  new_state_dict[new_key] = torch.tensor(state_dict[key], dtype = torch.float32).T


<All keys matched successfully>

In [5]:
imagenet1k_data_dir = "./../Data/imagenet1k/"
test_size = 0.000001
batch_size = 32

train_loader, test_loader = get_imagenet_loaders(imagenet1k_data_dir, 
                                                 test_size = test_size, 
                                                 shuffle = True, 
                                                 batch_size = 32, 
                                                 device = device)

In [6]:
accuracy = 0
tqdm_loader = tqdm(train_loader, desc = "Inference", position = 0, leave = True)
for dat in tqdm_loader:
    image, label = dat[0], dat[1].cpu().detach()
    output = net(image).cpu().detach()
    output = non_linearity(output)
    predictions = torch.topk(output, k = 10, dim = 1)[1]
    
    for i in range(len(label)):
        if label[i] in predictions[i]:
            accuracy += 1.0 / (batch_size * len(train_loader))
    
    tqdm_loader.set_postfix(accuracy = 100 * accuracy)
    tqdm_loader.update(1)
tqdm_loader.close()
print(f"Accuracy: {accuracy * 100}%")

Inference: 100%|██████████| 3125/3125 [14:31<00:00,  3.58it/s, accuracy=36.1]

Accuracy: 36.13400000002206%



