In [1]:
import torchvision
from torchvision import transforms
import torch
from torch import nn
from torch.nn import functional as F
from efficientnet_pytorch import EfficientNet




from scipy.spatial.distance import cosine as cos_dist
from PIL import Image

In [2]:
class SimilarityEstimator:
    """Estimate the visual similarity of two images from pretrained-CNN embeddings.

    A forward hook on the network's global average-pooling layer captures that
    layer's output for every forward pass. The similarity of two images is the
    cosine similarity of their two captured (flattened) outputs.
    """

    def __init__(self, model='resnet'):
        """Load a pretrained backbone and attach the embedding hook.

        Args:
            model: ``'resnet'`` selects torchvision's ResNet-34; any other
                value selects EfficientNet-B2 from ``efficientnet_pytorch``.
        """
        # Created up front so the hook never touches a missing attribute if the
        # model is run before `similarity` is called.
        self.activation = {}

        if model == 'resnet':
            self.key = 'avgpool'
            self.model = torchvision.models.resnet34(pretrained=True)
            self.model.avgpool.register_forward_hook(self.__get_activation(self.key))
        else:
            self.key = '_avg_pooling'
            self.model = EfficientNet.from_pretrained('efficientnet-b2')
            self.model._avg_pooling.register_forward_hook(self.__get_activation(self.key))

        # Inference only: put the network into evaluation mode.
        self.model.eval()

    def __get_activation(self, name):
        """Return a forward hook that stores the hooked layer's output under `name`."""
        def hook(model, input, output):
            self.activation[name] = output.detach()
        return hook

    def _embed(self, img):
        """Run one image through the network and return its hooked embedding.

        Args:
            img: image tensor without a batch dimension (one is added here).

        Returns:
            1-D numpy array with the flattened output of the hooked layer.
        """
        # Drop any previous capture so a stale embedding can never be returned.
        self.activation.pop(self.key, None)
        with torch.no_grad():
            self.model(img.unsqueeze(0))
        # `.cpu()` keeps `.numpy()` valid wherever the output lives;
        # `reshape(-1)` guarantees the 1-D vector scipy's cosine distance needs.
        return self.activation[self.key].cpu().numpy().reshape(-1)

    def similarity(self, pair):
        """Return the cosine similarity between the embeddings of two images.

        Args:
            pair: indexable holding two image tensors (``pair[0]``, ``pair[1]``),
                e.g. the result of `get_data_pair`.

        Returns:
            ``1 - cosine_distance`` of the two embeddings.
        """
        self.activation = {}
        first_embedding = self._embed(pair[0])
        second_embedding = self._embed(pair[1])
        # scipy's distance is 1 - dot(a, b) / (norm(a) * norm(b)), so
        # subtracting it from 1 recovers the cosine similarity itself.
        return 1 - cos_dist(first_embedding, second_embedding)

    def simmilarity(self, pair):
        """Backward-compatible alias of `similarity` (original, misspelled name)."""
        return self.similarity(pair)

    @staticmethod
    def get_data_pair(paths, size=(128, 128)):
        """Load two image files and preprocess them into normalized tensors.

        Args:
            paths: indexable with two image file paths (``paths[0]``, ``paths[1]``).
            size: ``(height, width)`` every image is resized to.

        Returns:
            Tuple ``(tensor1, tensor2)`` of the two preprocessed images.
        """
        data_transforms = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            # Per-channel mean / std (the ImageNet statistics documented for
            # torchvision's pretrained models).
            transforms.Normalize(
                [0.485, 0.456, 0.406],
                [0.229, 0.224, 0.225]),
        ])

        def load(path):
            # The context manager closes the file handle; `convert` returns an
            # independent in-memory RGB copy that outlives it.
            with Image.open(path) as img:
                return data_transforms(img.convert('RGB'))

        return (load(paths[0]), load(paths[1]))

In [3]:
# List the PNG files in the working directory (the images compared in the cells below).
ls *.png

[0m[01;35mair_purifier.png[0m  [01;35miPhone2.png[0m  [01;35miPhone.png[0m


In [4]:
# Compare the reference photo against a near-duplicate and an unrelated
# product, using ResNet-34 embeddings.
sim_estimator = SimilarityEstimator('resnet')

scores = []
for paths in (['iPhone2.png', 'iPhone.png'],
              ['iPhone2.png', 'air_purifier.png']):
    pair = SimilarityEstimator.get_data_pair(paths)
    scores.append(sim_estimator.simmilarity(pair))

# sim1: same product, different photo; sim2: different product.
sim1, sim2 = scores

print(sim1, sim2)

0.8892881274223328 0.45980799198150635


In [5]:
# Repeat the comparison with EfficientNet-B2 embeddings.
sim_estimator = SimilarityEstimator('efficientnet')

scores = []
for paths in (['iPhone2.png', 'iPhone.png'],
              ['iPhone2.png', 'air_purifier.png']):
    pair = SimilarityEstimator.get_data_pair(paths)
    scores.append(sim_estimator.simmilarity(pair))

# sim1: same product, different photo; sim2: different product.
sim1, sim2 = scores

print(sim1, sim2)

Loaded pretrained weights for efficientnet-b2
0.7053554058074951 0.04367216303944588


In [7]:
## TODO
## Add similarity based on product info (website description)