In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchvision import transforms
import torchvision.models as models
import matplotlib.pyplot as plt
from tqdm import tqdm
import timm
import numpy as np
from torch.utils.data import random_split
import pickle

In [2]:
# Side length handed to Resize; the original author noted this as the minimum
# image size for inception resnet.
IMAGE_SIZE = 139

# Preprocessing shared by every dataset split: resize first, then convert to a tensor.
transform = transforms.Compose(
    [transforms.Resize(IMAGE_SIZE), transforms.ToTensor()]
)

In [3]:
# FashionMNIST train and test splits. Both are stored under ./data, requested
# with download=True, and passed through the shared `transform`.
training_data, test_data = (
    datasets.FashionMNIST(root="data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [4]:
# Hold out 20% of the training set for validation.
n_samples = len(training_data)
train_len = int(0.8 * n_samples)
val_len = n_samples - train_len

# Seed the global RNG immediately before splitting so the partition is the
# same every time this cell runs.
torch.manual_seed(42)
train_data, val_data = random_split(training_data, lengths=[train_len, val_len])

In [5]:
class Model(nn.Module):
    """timm ``inception_resnet_v2`` with a replaced input convolution and output layer.

    Args:
        in_channels: Number of channels of the input images (1 for grayscale,
            3 for RGB). Defaults to 1.
        num_classes: Number of outputs of the final linear layer. Defaults to 10.
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()

        # Backbone comes from timm and is created with pretrained=False.
        backbone = timm.create_model('inception_resnet_v2', pretrained=False)

        # Replace the first convolution so the network accepts `in_channels`
        # channels. Per the original author's note the stock layer is
        # Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False); only the
        # input channel count differs here.
        backbone.conv2d_1a.conv = nn.Conv2d(
            in_channels, 32, kernel_size=3, stride=2, bias=False
        )

        # Replace the output layer with a linear layer of the same input width
        # that produces `num_classes` values.
        backbone.classif = nn.Linear(backbone.classif.in_features, num_classes)

        self.model = backbone

    def forward(self, x):
        """Return the wrapped network's output for the batch ``x``."""
        return self.model(x)

In [6]:
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [None]:
# Rebuild the network and restore the saved state dict.
model = Model()

# map_location=device remaps the stored tensors onto the device selected above,
# so a checkpoint written from GPU tensors still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
model.load_state_dict(
    torch.load('results/classifier/model19.pth', map_location=device)
)

# Swap the output layer for a pass-through after the weights are loaded, so the
# state dict keys still match at load time. The model then returns whatever was
# fed into that layer (assumes timm applies `classif` as its last step -- TODO
# confirm).
model.model.classif = nn.Identity()

model.to(device)
# Inference mode for layers such as dropout / batch norm; embed() relies on
# this having been called.
model.eval()

In [8]:
# One loader per split, all using the same batch size and the DataLoader
# defaults for everything else.
batch_size = 64
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(split, batch_size=batch_size)
    for split in (train_data, val_data, test_data)
)

In [9]:
def embed(dataloader):
    """Run every batch of ``dataloader`` through the global ``model`` and collect the outputs.

    Args:
        dataloader: Iterable yielding ``(X, y)`` batches. ``X`` is moved to
            ``device`` and passed to ``model``; ``y`` must support ``len()``
            and integer indexing.

    Returns:
        list: One ``(output, label)`` tuple per sample, in iteration order.
        ``output`` is that sample's row of the model output as a CPU tensor;
        ``label`` is the matching entry of ``y``.

    Note:
        Reads the module-level ``model`` and ``device``. It does not switch the
        model to eval mode itself; the caller is expected to have done that.
    """
    res = []

    # Inference only: without no_grad() every forward pass records an autograd
    # graph that is discarded right away, costing memory and time for nothing.
    with torch.no_grad():
        for X, y in tqdm(dataloader):
            pred = model(X.to(device)).detach().cpu()

            for i in range(len(y)):
                # Indexing a batch tensor yields a view that shares the whole
                # batch's storage. Clone so each stored item owns only its own
                # data; otherwise serializing one item can carry the full
                # batch storage with it.
                label = y[i]
                if torch.is_tensor(label):
                    label = label.clone()
                res.append((pred[i].clone(), label))

    return res

In [10]:
# Embed the three splits in order: train, validation, test.
train_res, val_res, test_res = map(
    embed, (train_dataloader, val_dataloader, test_dataloader)
)

100%|██████████| 750/750 [02:24<00:00,  5.18it/s]
100%|██████████| 188/188 [00:35<00:00,  5.25it/s]
100%|██████████| 157/157 [00:29<00:00,  5.24it/s]


In [11]:
# Persist the training-split embeddings next to the notebook.
with open('train.pickle', mode='wb') as train_file:
    pickle.dump(train_res, train_file)

In [12]:
# Persist the validation-split embeddings next to the notebook.
with open('val.pickle', mode='wb') as val_file:
    pickle.dump(val_res, val_file)

In [13]:
# Persist the test-split embeddings next to the notebook.
with open('test.pickle', mode='wb') as test_file:
    pickle.dump(test_res, test_file)