In [5]:
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from skimage import io
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights

DATA_PATH = './data/ctl/'

In [6]:
def crop_img(image, bbox):
    """Crop the largest part of a scene image that lies entirely outside a bounding box.

    The candidates are the full-width strip above the box, the full-width strip
    below it, the full-height strip left of it and the full-height strip right
    of it; the one covering the most pixels is returned.

    Args:
        image: the scene as a numpy array (e.g. from ``skimage.io.imread``) or an
            already-decoded PIL image.
        bbox: ``(left, top, right, bottom)`` with each coordinate given as a
            fraction of the image width/height (presumably in [0, 1]; values
            outside that range are clamped to the image).

    Returns:
        The cropped region as a PIL image.

    Raises:
        ValueError: if the box leaves no pixels outside it (e.g. it spans the
            whole image), since every candidate region would be empty.
    """
    if isinstance(image, np.ndarray):
        # torchvision's converter is used because PIL.Image is not imported here.
        image = transforms.functional.to_pil_image(image)
    width, height = image.size

    def to_pixel(fraction, size):
        # Snap to a whole pixel inside [0, size] so areas are compared on the
        # same integer grid the crop is performed on.
        return min(max(int(round(fraction * size)), 0), size)

    left = to_pixel(bbox[0], width)
    top = to_pixel(bbox[1], height)
    right = to_pixel(bbox[2], width)
    bottom = to_pixel(bbox[3], height)

    # (left, upper, right, lower) boxes; on ties max() keeps the first one.
    regions = [
        (0, 0, width, top),          # strip above the box
        (0, bottom, width, height),  # strip below the box
        (0, 0, left, height),        # strip left of the box
        (right, 0, width, height),   # strip right of the box
    ]

    def area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    largest_region = max(regions, key=area)
    if area(largest_region) == 0:
        raise ValueError(f"bbox {bbox!r} leaves no image area outside it")

    return image.crop(largest_region)

In [7]:
class CTLData(Dataset):
    """Complete-the-look dataset yielding scene, product and cropped-scene images.

    Each row of ``data`` is read positionally (``.iloc``) and must provide the
    ``scene_id``, ``product_id``, ``bbox``, ``category`` and ``label`` columns.

    Args:
        data: pandas DataFrame with one row per sample.
        transform: optional callable applied to each of the three images.
        local: if True, images are read from ``DATA_PATH/imgs/<id>.png``;
            otherwise they are fetched from the URL built by ``convert_to_url``.
    """

    def __init__(self, data, transform=None, local=False):
        self.data = data
        self.transform = transform
        self.local = local

        # Integer class ids, assigned in order of first appearance in the data.
        self.category_dict = {
            category: i for i, category in enumerate(self.data['category'].unique())
        }

    def __len__(self):
        return len(self.data)

    def get_category(self, idx):
        """Return the integer category id of row ``idx``."""
        return self.category_dict[self.data.iloc[idx]['category']]

    def convert_to_url(self, signature):
        """Build the Pinterest image URL for an image id (sharded by its first 6 chars)."""
        prefix = 'http://i.pinimg.com/400x/%s/%s/%s/%s.jpg'
        return prefix % (signature[0:2], signature[2:4], signature[4:6], signature)

    def get_image(self, signature, local=False):
        """Load the image with id ``signature`` via ``skimage.io.imread``."""
        if local:
            # os.path.join avoids the double slash produced by DATA_PATH + "/imgs/".
            return io.imread(os.path.join(DATA_PATH, "imgs", signature + ".png"))
        return io.imread(self.convert_to_url(signature))

    @staticmethod
    def _to_rgb(image):
        """Return ``image`` (numpy array or PIL image) as an RGB PIL image."""
        if isinstance(image, np.ndarray):
            image = transforms.functional.to_pil_image(image)
        # Drop alpha / expand grayscale so every image has exactly 3 channels.
        return image.convert("RGB")

    def __getitem__(self, idx):
        """Return ``(scene, product, cropped_scene, category_id, label)`` for row ``idx``."""
        row = self.data.iloc[idx]

        scene_array = self.get_image(row['scene_id'], local=self.local)
        product_array = self.get_image(row['product_id'], local=self.local)
        cropped_scene_img = crop_img(scene_array, row['bbox'])

        # imread returns arrays, but torchvision's Resize/CenterCrop accept only
        # PIL images or tensors, so normalise all three to RGB PIL images.
        scene_img = self._to_rgb(scene_array)
        product_img = self._to_rgb(product_array)
        cropped_scene_img = self._to_rgb(cropped_scene_img)

        if self.transform is not None:
            scene_img = self.transform(scene_img)
            product_img = self.transform(product_img)
            cropped_scene_img = self.transform(cropped_scene_img)

        return scene_img, product_img, cropped_scene_img, self.get_category(idx), row['label']

In [8]:
class CTLModel(torch.nn.Module):
    """Scores scene/product compatibility from self-attended image features.

    Both images go through the same ``feature_extractor`` and
    ``attention_mechanism``. The attended tokens are mean-pooled per image and
    the two pooled vectors are compared with ``similarity``.

    Args:
        feature_extractor: module mapping an image batch to features shaped
            ``(B, C, H, W)``, ``(B, N, C)`` or ``(B, C)``.
        attention_mechanism: callable taking ``(B, N, C)`` tokens and returning
            ``(attended_tokens, attention_weights)``.
        similarity: callable comparing two ``(B, C)`` feature batches.
    """

    def __init__(self, feature_extractor, attention_mechanism, similarity):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.attention_mechanism = attention_mechanism
        self.similarity = similarity

    @staticmethod
    def _to_tokens(features):
        """Reshape extractor output into a per-image token sequence ``(B, N, C)``."""
        if features.dim() == 4:
            # (B, C, H, W) -> (B, H*W, C): one token per spatial location, with
            # channels last so Linear layers act on the feature dimension.
            return features.flatten(2).transpose(1, 2)
        if features.dim() == 3:
            return features
        if features.dim() == 2:
            # (B, C) -> (B, 1, C): keeps attention within each image instead of
            # attending across the batch.
            return features.unsqueeze(1)
        raise ValueError(
            f"expected 2-D, 3-D or 4-D features, got shape {tuple(features.shape)}"
        )

    def _encode(self, img):
        """Extract, attend and mean-pool one image batch into ``(B, C)`` features."""
        tokens = self._to_tokens(self.feature_extractor(img))
        attended, attn_weights = self.attention_mechanism(tokens)
        return attended.mean(dim=1), attn_weights

    def forward(self, scene_img, product_img):
        """Return ``(compatibility_scores, scene_attn_weights, product_attn_weights)``."""
        scene_features, scene_attn_weights = self._encode(scene_img)
        product_features, product_attn_weights = self._encode(product_img)

        compatibility_scores = self.similarity(scene_features, product_features)

        return compatibility_scores, scene_attn_weights, product_attn_weights



In [9]:
# Feature-extraction backbone: ImageNet-pretrained ResNet-50 with its last
# child module (the `fc` classifier in torchvision's ResNet) dropped, so it
# returns feature maps rather than class logits. train(False) is the same as
# eval(): it switches batch-norm/dropout to inference behaviour.
resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
resnet = torch.nn.Sequential(*list(resnet.children())[:-1])
resnet.train(False)

# ImageNet-style preprocessing: resize the shorter side to 256, take the
# central 224x224 patch, convert to a tensor and normalise per channel.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

class SelfAttention(torch.nn.Module):
    """Single-head scaled dot-product self-attention.

    Projections act on the last axis and attention is taken over the
    second-to-last axis, so inputs look like ``(..., seq_len, input_dim)``.

    Args:
        input_dim: size of the last input axis; queries, keys and values keep it.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Created in query, key, value order so parameter initialisation and
        # state_dict keys are the same as before.
        self.query = torch.nn.Linear(input_dim, input_dim)
        self.key = torch.nn.Linear(input_dim, input_dim)
        self.value = torch.nn.Linear(input_dim, input_dim)

    def forward(self, x):
        """Return ``(attended_values, attention_weights)``; weight rows sum to 1."""
        queries, keys, values = self.query(x), self.key(x), self.value(x)

        scale = math.sqrt(keys.shape[-1])
        scores = (queries @ keys.transpose(-2, -1)) / scale
        weights = torch.softmax(scores, dim=-1)

        return weights @ values, weights

def cos_sim(scene_features, product_features):
    """Cosine similarity of paired feature vectors, reducing over the last axis."""
    cosine = torch.nn.functional.cosine_similarity
    return cosine(scene_features, product_features, dim=-1)

# 2048 is the channel width of ResNet-50's final stage, i.e. the per-image
# feature size produced by the truncated backbone above.
model = CTLModel(
    feature_extractor=resnet,
    attention_mechanism=SelfAttention(input_dim=2048),
    similarity=cos_sim,
)