In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import joblib as jb
from glob import glob
from typing import Union, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.datasets import ImageFolder
from sklearn.metrics.pairwise import linear_kernel

from PIL import Image, ImageFile
from tqdm import tqdm
import joblib as jb
# the following setting is required for training to be robust to truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

%matplotlib inline

## Re-train ResNet18 on DeepFashion Data

In [None]:
# root folder read by ImageFolder: one sub-folder of images per class
DATADIR = './data/img'

# size handed to transforms.Resize, followed by the size of the centre crop
RESIZE = (256, 256)
CROP = (224, 224)

# number of images per mini-batch
BATCHSIZE = 64

# fraction of the dataset held out for validation
VALID = 0.1

# run on the GPU when one is available, otherwise fall back to the CPU
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# download pre-trained ResNet18
resnet18 = models.resnet18(pretrained=True)

# freeze the pre-trained weights; parameters whose name contains 'fc' or 'avgpool'
# are left untouched (the old fc layer is discarded below anyway)
for name, param in resnet18.named_parameters():
    if 'fc' not in name and 'avgpool' not in name:
        param.requires_grad = False

# define a new output layer to match the number of classes in the dataset.
# The trailing separator makes glob return directories only, so stray files in
# DATADIR are not counted as classes (ImageFolder builds its classes from the
# sub-folders).
outputs = len(glob(DATADIR + '/*/'))  # the number of class folders in the data folder
if outputs == 0:
    raise FileNotFoundError(f'no class folders found in {DATADIR!r}')
inputs = resnet18.fc.in_features  # number of inputs of the original final layer
output_layer = nn.Linear(inputs, outputs)

# replace the classifier with the new output layer; being freshly created, its
# parameters are trainable
resnet18.fc = output_layer

In [None]:
# seed of the train/validation split, so that the same images end up in the
# validation set after every kernel restart
SPLIT_SEED = 42

# no need to perform any augmentation other than resizing
imageTransformations = transforms.Compose([
        transforms.Resize(RESIZE),
        transforms.CenterCrop(CROP),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
])

# load all images; the class of an image is given by the sub-folder it lives in
fashionDataset = ImageFolder(DATADIR, transform=imageTransformations)

# reserve a fraction VALID of the images for validation purposes
n_val = int(np.floor(VALID * len(fashionDataset)))
n_train = len(fashionDataset) - n_val
trainSet, validSet = random_split(
    fashionDataset,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# define the data loaders; only the training loader is reshuffled every epoch
trainFashion = DataLoader(trainSet, batch_size=BATCHSIZE, shuffle=True)
validFashion = DataLoader(validSet, batch_size=BATCHSIZE)

## Image to Vector

In [None]:
# keep every child module of resnet18 except the last one (the fc layer), so
# that the network ends at the avgpool layer; the modules are shared with
# resnet18, not copied
backbone_layers = list(resnet18.children())[:-1]
img2vec = nn.Sequential(*backbone_layers)

# sanity check
print(img2vec)

In [None]:
def ouput_embeddings(
        model: Union[nn.Sequential, nn.Module],
        data: Union[torch.Tensor, np.ndarray],
        device: torch.device
    ) -> torch.Tensor:
    """
    Embed a single image or a batch of images with ``model``.

    Parameters
    ----------
    model : nn.Module
        Network whose forward output is used as the embedding.
    data : torch.Tensor or np.ndarray
        Input images. Anything that is not a tensor is converted to a float32
        tensor. An input with exactly three dimensions is treated as a single
        image and receives a leading batch dimension.
    device : torch.device
        Device on which the forward pass is executed.

    Returns
    -------
    torch.Tensor
        ``model(data)``, computed without gradient tracking and stored on the CPU.
    """

    if not isinstance(data, torch.Tensor):
        # torch.from_numpy() takes no dtype argument, so convert with as_tensor
        data = torch.as_tensor(data, dtype=torch.float32)

    # a single image: add the batch dimension the model expects
    if data.dim() == 3:
        data = data.unsqueeze(0)

    return _ouput_embeddings(model, data, device)


def _ouput_embeddings(
        model: Union[nn.Sequential, nn.Module],
        data: torch.Tensor,
        device: torch.device
    ) -> torch.Tensor:
    """
    Run ``model`` on an already batched tensor and return the result on the CPU.

    The model is switched to evaluation mode, moved to ``device`` for the
    forward pass and moved back to the CPU afterwards.
    """

    # turn off dropout/batch norm updates
    model.eval()
    # ensure the model and the images are on the same device
    model = model.to(device)

    with torch.no_grad():
        # pass the batch through the model; `data` itself is left where it was
        outputs = model(data.to(device))

    # bring the result to the CPU so that no embeddings pile up on the GPU
    outputs = outputs.cpu()
    # move the model back as well (nn.Module.to works in place)
    model.to("cpu")
    return outputs


# correctly spelled alias; the original (misspelled) name is kept for existing callers
output_embeddings = ouput_embeddings

In [None]:
# scratch: stack two 1-D arrays as the columns of a 2-D array
np.column_stack((np.array([1, 2, 3]), np.array([4, 5, 6])))

In [None]:
%%time

embeddings = {}
for batch, _ in tqdm(fashionLoader):
    
    outputs = output_embeddings(img2vec, batch, DEVICE)
    temp = {img: out for img, out in zip(batch, outputs)}
    embeddings.update(temp)
    break

In [None]:
# number of entries stored in the embeddings dictionary
len(embeddings)

In [None]:
# scratch: pair every image of the batch with a key of `embeddings`
# (iterating a dict yields its keys)
a = {image: key for image, key in zip(batch, embeddings)}

In [None]:
# scratch: one set holding the images of the batch and the keys of `embeddings`
a = set(batch).union(embeddings)

In [None]:
# turn the collection into a list so that it can be indexed by position,
# then show its first element
a = [*a]
a[0]

In [None]:
# NOTE(review): `batch_embeddings` was not defined anywhere in the notebook;
# presumably the embeddings of the last processed batch were meant - confirm
outputs.size()

In [None]:
# dimensions of the most recent batch of images
batch.shape

In [None]:
# write the embeddings to disk with compression level 3
jb.dump(value=embeddings, filename='deep_fashion_batched_embeddings.pkl', compress=3)

In [None]:
# second copy of the embeddings, written to a separate backup file
jb.dump(value=embeddings, filename='deep_fashion_batched_embeddings_backup.pkl')

In [None]:
# read the embeddings back from disk; joblib files are pickle based, so only
# load files that were produced by this notebook
embeddings = jb.load(filename='deep_fashion_batched_embeddings.pkl')

In [None]:
# display the batch left over from the embedding loop
batch

In [None]:
# scratch: dictionary keyed by the images of the batch, every value set to None
a = dict.fromkeys(batch)

In [None]:
# number of keys held in `a`
len(a)

In [None]:
# take the first image of the batch directly; `a` is a dictionary keyed by
# tensors whose values are None, so it cannot be indexed by position
first_plane = batch[0]

# move the first dimension last for matplotlib (assumes a channels-first image,
# TODO confirm) and copy, so that the rescaling below does not write into `batch`
image = first_plane.permute(1, 2, 0).clone()

# rescale to [0, 1] to avoid clipping; a constant image is left at zero
image -= image.min()
peak = image.max()
if peak > 0:
    image /= peak

fig, ax = plt.subplots()
ax.imshow(image.numpy())
ax.axis('off')
plt.show()

In [None]:
# NOTE(review): `a` is a dictionary keyed by tensors, so `a[0]` raises KeyError;
# presumably the size of the first dimension of the first image was meant - confirm
len(batch[0])

In [None]:
# number of entries currently held in `a`
len(a)