In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder, FashionMNIST

from PIL import Image, ImageFile

# the following setting lets PIL load truncated image files instead of raising an error
ImageFile.LOAD_TRUNCATED_IMAGES = True

%matplotlib inline

In [None]:
# Notebook-wide parameters
DATA = "./data"        # root folder handed to ImageFolder
RESIZE = (224, 224)    # target size given to transforms.Resize
BATCH_SIZE = 1         # images per DataLoader batch
# Use the GPU when one is available, otherwise fall back to the CPU
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Fetch ResNet18 with its pre-trained weights
resnet18 = models.resnet18(pretrained=True)

# Pre-processing only: resize, convert to tensor, normalise each channel.
# The mean/std values are presumably the ImageNet statistics the
# pre-trained weights expect — confirm against the torchvision docs.
imageTransformations = transforms.Compose([
    transforms.Resize(size=RESIZE),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# A handful of images downloaded from the internet, read from DATA
dataset = ImageFolder(root=DATA, transform=imageTransformations)

# Batches of BATCH_SIZE images, served with the DataLoader defaults
imageLoader = DataLoader(dataset, batch_size=BATCH_SIZE)

In [None]:
class Img2Vec():
    """Extract a feature vector from an image via a model's ``avgpool`` layer.

    The wrapped model must expose a direct child module named ``avgpool``.
    During a forward pass the output of that layer is captured with a
    forward hook and returned as a flat NumPy vector.
    """
    def __init__(self, model, nFeatures):
        """
        Args:
            model: network to run; its ``avgpool`` child is the tapped layer.
            nFeatures: number of channels produced by the ``avgpool`` layer,
                i.e. the length of the returned vector.
        """
        self.model = model
        self.nFeatures = nFeatures
        # None when the model has no direct child called 'avgpool'.
        self.avgPool = model._modules.get('avgpool')

    def getVec(self, image, device):
        """Run ``image`` through the model and return the avgpool embedding.

        Side effects: the model is switched to eval mode and moved to
        ``device``.

        Args:
            image: image tensor; dimension 1 is treated as the channel axis
                (assumes a (batch, channels, height, width) layout). A
                single-channel image is expanded to three channels. The
                capture buffer holds one sample, so a batch of one is
                expected.
            device: torch device on which the forward pass is executed.

        Returns:
            1-D ``numpy.ndarray`` of length ``nFeatures``.
        """
        model = self.model
        # CPU buffer filled in by the hook during the forward pass.
        embedding = torch.zeros(1, self.nFeatures, 1, 1)

        def copyData(m, i, o):
            embedding.copy_(o.detach())

        with torch.no_grad():
            model.eval()

            if image.shape[1] == 1:
                image = self._repeat_grayscale(image)

            # move to the requested device (GPU when available)
            model, image = model.to(device), image.to(device)

            h = self.avgPool.register_forward_hook(copyData)
            try:
                model(image)
            finally:
                # Always detach the hook, even if the forward pass fails,
                # so stale hooks never pile up on the shared model.
                h.remove()

        return embedding.numpy()[0, :, 0, 0]

    @staticmethod
    def _repeat_grayscale(img):
        """Replicate a single-channel image along the channel axis (dim 1).

        Accepts a tensor or anything ``torch.as_tensor`` understands and
        returns a tensor with three identical channels.
        """
        return torch.as_tensor(img).repeat_interleave(3, dim=1)

In [None]:
# Wrap the network in the embedding extractor.
# 512 is presumably the channel count of ResNet18's avgpool output — confirm.
img2vec = Img2Vec(model=resnet18, nFeatures=512)

# Collect one embedding per batch yielded by the loader; labels are ignored.
# NOTE(review): the dict key is the batch tensor itself. Torch tensors hash by
# object identity, so entries cannot be looked up by file name or content —
# consider keying by image path (e.g. from dataset.samples) if that is needed.
vectors = dict()
for img, _ in imageLoader:
    vec = img2vec.getVec(image=img, device=DEVICE)
    vectors[img] = vec