In [1]:
# Use a pretrained model to compute an embedding for each image, store and index the embeddings,
# and query for the top 5 most similar images for a given query image.

import pandas as pd
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
import torch.utils.data as data
from tqdm import tqdm
from PIL import Image
import glob
from ImageClass import CustomImageClass
import numpy as np
import CNNModel
import Config

In [4]:
# Build the embedding network and put it in inference mode on the Apple-silicon GPU ('mps').
# NOTE(review): the import cell binds the name with a plain `import CNNModel`; calling it as below only
# works if `CNNModel` refers to the class itself (e.g. `from CNNModel import CNNModel`) — confirm.
# NOTE(review): no checkpoint is loaded in this cell; presumably the constructor pulls in the
# pretrained weights — verify.
model = CNNModel(Config.embedding_size)

# eval() makes the BatchNorm layers (visible in the printed summary) use their stored running
# statistics instead of per-batch statistics, so an image's embedding does not depend on which other
# images share its batch. Both .to() and .eval() return the module, so the cell still displays it.
model.to('mps').eval()



CNNModel(
  (resnet_module): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv

In [5]:
# Load the article table that ships with the dataset.
articles_csv_path = r'../data/raw/h-and-m-personalized-fashion-recommendations/articles.csv'
articles = pd.read_csv(articles_csv_path)

In [6]:
# Collect every image file (one sub-directory level below images/).
# glob returns matches in arbitrary order; sorting makes the order reproducible across runs and
# machines. This matters because the embedding rows are later tied to these paths by position only.
img_list = sorted(glob.glob('../data/raw/h-and-m-personalized-fashion-recommendations/images/*/*'))

In [7]:
# Compare how many image files were found with how many article rows were loaded.
image_count = len(img_list)
article_count = len(articles)
image_count, article_count

(105100, 105542)

In [8]:
from torchvision import transforms

# Per-channel mean / std handed to Normalize; these are the values torchvision documents for its
# ImageNet-pretrained models.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Preprocessing applied to every image: resize to a Config.img_size square, centre-crop to 224x224,
# convert to a tensor, then normalise each channel.
# NOTE(review): if Config.img_size is already 224 the crop is a no-op — confirm the intended size.
preprocessing_steps = [
    transforms.Resize((Config.img_size, Config.img_size)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
]
transform = transforms.Compose(preprocessing_steps)

# Dataset over the globbed image paths, applying the preprocessing pipeline to each item.
image_dataset = CustomImageClass(data_path=img_list, transform=transform)

In [9]:
# Batch the images for the embedding pass.
# shuffle must be False: the batches carry no image ids, and the embeddings are later paired with
# `img_list` purely by row position, so they have to be produced in `img_list` order. Shuffling
# would attach every embedding to the wrong image.
custom_data_loader = torch.utils.data.DataLoader(dataset=image_dataset,
                                                 batch_size=Config.batch_size,
                                                 shuffle=False,
                                                 num_workers=2)

In [10]:
# Run every batch through the network once and collect the embeddings as NumPy arrays.

# Inference must not use per-batch BatchNorm statistics (nor update the running ones).
model.eval()
# Send inputs to whatever device the model's weights live on, so the two cannot drift apart.
device = next(model.parameters()).device

embeddings = []

with torch.no_grad():
    # The loop variable is deliberately not called `data`: that name is the alias of
    # torch.utils.data from the import cell and would be silently overwritten.
    for image_batch in tqdm(custom_data_loader):
        batch_embeddings = model(image_batch.to(device))
        # No .detach() needed: nothing computed under no_grad() is attached to a graph.
        embeddings.append(batch_embeddings.cpu().numpy())

100%|█████████████████████████████████████| 1643/1643 [2:29:55<00:00,  5.47s/it]


In [13]:
# Stack the per-batch arrays into a single (num_images, embedding_dim) matrix.
# The isinstance guard keeps this cell safe to re-run: calling np.concatenate on the already-stacked
# 2-D array would flatten it row by row into a 1-D vector.
if isinstance(embeddings, list):
    embeddings = np.concatenate(embeddings)

# Rows are matched to image paths by position only, so the counts have to agree.
assert len(embeddings) == len(img_list), (
    f"{len(embeddings)} embeddings for {len(img_list)} images"
)

img_embeddings = pd.DataFrame(embeddings)
img_embeddings['image_id'] = img_list  # holds the globbed file path of each image

# save the embeddings
img_embeddings.to_csv(f"h&m_emb_img_{Config.embedding_size}.csv", index = False)

In [15]:
# Sanity check: dimensions of the stacked embedding matrix.
np.shape(embeddings)

(105100, 128)

In [17]:
# Persist the embedding matrix on its own (without image ids) as a NumPy binary file.
embeddings_npy_path = '../models/H&M-Embeddings.npy'
np.save(embeddings_npy_path, embeddings)