In [None]:
import sys
sys.path.insert(0, './../Models')

from imagenet1k_dataloader import get_imagenet_loaders

import torch
import torch.nn as nn
from torchvision import transforms as T
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

import timm
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using {device}")

# Pretrained timm MLP-Mixer checkpoint; moved onto the selected device.
net = timm.create_model('mixer_b16_224.miil_in21k_ft_in1k', pretrained=True)
net = net.to(device)

In [None]:
class Identity(nn.Module):
    """Pass-through layer whose forward returns its input unchanged.

    It is installed as ``net.head`` below, so the network's output is whatever
    reaches the head (the features) instead of the head's projection.
    """

    # No __init__ override needed: nn.Module.__init__ is invoked implicitly.

    def forward(self, x):
        return x


# Swap the classification head for the pass-through layer.
net.head = Identity()

In [None]:
# Data location and loader settings passed to the project-local helper.
imagenet1k_data_dir = "./../Data/imagenet1k/"
test_size = 0.1   # presumably the held-out fraction for test_loader — confirm in imagenet1k_dataloader
batch_size = 64

# `device` is forwarded to the helper; presumably it places batches there — TODO confirm.
train_loader, test_loader = get_imagenet_loaders(
    imagenet1k_data_dir,
    test_size=test_size,
    shuffle=True,
    batch_size=batch_size,
    device=device,
)

In [None]:
@torch.no_grad()  # inference only: don't build autograd graphs for every batch
def extract_features(model, loader, desc, device):
    """Run ``model`` over every batch of ``loader`` and collect outputs and labels.

    Args:
        model: network whose forward output is collected (here ``net`` with its
            head replaced, so the outputs are features).
        loader: iterable yielding batches where ``batch[0]`` is the input and
            ``batch[1]`` the target (assumed from the original indexing).
        desc: progress-bar description.
        device: device the inputs are moved to before the forward pass.

    Returns:
        ``(outputs, labels)`` — CPU tensors concatenated along dim 0, or two
        empty tensors if the loader yields no batches.
    """
    # Disable train-mode behavior (e.g. dropout) for deterministic features.
    model.eval()
    output_chunks, label_chunks = [], []
    # Iterating the tqdm wrapper advances and finally closes the bar by itself,
    # so no manual update()/close() calls are needed.
    for batch in tqdm(loader, desc=desc, position=0, leave=True):
        images, targets = batch[0], batch[1]
        # .to() is a no-op if the loader already placed the batch on `device`.
        output_chunks.append(model(images.to(device)).cpu())
        label_chunks.append(targets.cpu())
    if not output_chunks:
        return torch.Tensor(), torch.Tensor()
    # Concatenate once at the end instead of re-copying on every batch.
    return torch.cat(output_chunks, dim=0), torch.cat(label_chunks, dim=0)


train_outputs, train_labels = extract_features(net, train_loader, "Inference Train Data", device)
test_outputs, test_labels = extract_features(net, test_loader, "Inference Test Data", device)

# Train and test samples are embedded together below, train first.
# NOTE(review): running t-SNE on every image of both splits may be very slow and
# memory-heavy; consider subsampling before the t-SNE cell.
outputs = torch.cat((train_outputs, test_outputs), dim=0)
labels = torch.cat((train_labels, test_labels), dim=0)

In [None]:
# 2-D t-SNE embedding of the collected outputs. random_state makes the layout
# reproducible across Restart & Run All.
tsne_kwargs = dict(n_components=2, perplexity=30, learning_rate=60, random_state=0)
try:
    # scikit-learn >= 1.5 calls the iteration budget `max_iter` (`n_iter` was removed in 1.7).
    m = TSNE(max_iter=1000, **tsne_kwargs)
except TypeError:
    # Older scikit-learn only accepts `n_iter`.
    m = TSNE(n_iter=1000, **tsne_kwargs)
# `outputs` is a detached CPU tensor; hand sklearn a plain NumPy array.
features = m.fit_transform(outputs.numpy())

In [None]:
# Scatter the 2-D t-SNE embedding, colored by each sample's label.
cmap = sns.hls_palette(as_cmap=True)
f, ax = plt.subplots(figsize=(12, 8))
points = ax.scatter(features[:, 0], features[:, 1], c=labels.numpy(), s=30, cmap=cmap)
ax.set(
    title="t-SNE of network outputs (train + test)",
    xlabel="t-SNE dimension 1",
    ylabel="t-SNE dimension 2",
)
f.colorbar(points, ax=ax, label="label")
plt.show()