In [10]:
import io
import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import torch
from PIL import Image
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
from tqdm import tqdm

In [11]:
# Directory holding the dataset files; `data.csv` is read from here further down.
DATA_PATH = './data/ctl/'

In [12]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cpu')

In [13]:
class CTLData(Dataset):
    """Dataset yielding (product image, product-free scene crop, category, label).

    Every row of ``data`` is read for the columns ``scene_id``, ``product_id``,
    ``bbox``, ``category`` and ``label``. Images are downloaded on demand from
    the URL produced by :meth:`convert_to_url`.

    Args:
        data: pandas DataFrame with the columns listed above.
        transform: optional callable applied to each PIL image that is returned.
        category_dict: optional ``{category value: integer index}`` mapping.
            When omitted, categories are numbered in order of first appearance
            in ``data``. Pass one shared mapping when several datasets (e.g. a
            positive and a negative subset) must agree on category indices.
    """

    def __init__(self, data, transform=None, category_dict=None):
        self.data = data
        self.transform = transform

        if category_dict is None:
            # Number the categories in order of first appearance in this frame.
            category_dict = {
                category: index
                for index, category in enumerate(self.data['category'].unique())
            }
        self.category_dict = category_dict

    def __len__(self):
        """Return the number of rows in the underlying frame."""
        return len(self.data)

    def get_category(self, idx):
        """Return the integer category index of the row at position ``idx``."""
        return self.category_dict[self.data.iloc[idx]['category']]

    def convert_to_url(self, signature):
        """Build the image URL for ``signature`` (sharded by its first three character pairs)."""
        prefix = 'http://i.pinimg.com/400x/%s/%s/%s/%s.jpg'
        return prefix % (signature[0:2], signature[2:4], signature[4:6], signature)

    def get_image(self, signature, timeout=10):
        """Download the image for ``signature`` and return it as an RGB PIL image.

        Args:
            signature: image signature passed to :meth:`convert_to_url`.
            timeout: seconds to wait for the server before giving up.

        Raises:
            requests.HTTPError: if the server answers with an error status.
            requests.RequestException: on connection problems or timeout.
        """
        url = self.convert_to_url(signature)
        response = requests.get(url, timeout=timeout)
        # Fail with the HTTP status instead of letting PIL choke on an error page.
        response.raise_for_status()
        # Reading `.content` consumes the body so the connection is released;
        # converting to RGB guarantees three channels whatever the source mode.
        return Image.open(io.BytesIO(response.content)).convert('RGB')

    @staticmethod
    def _parse_bbox(bbox):
        """Turn a ``"[l, t, r, b]"`` string (or a sequence of numbers) into a list of floats."""
        if isinstance(bbox, str):
            bbox = bbox.strip()[1:-1].split(',')
        return [float(value) for value in bbox]

    def _crop_img(self, image, bbox):
        """Return the largest rectangular part of ``image`` that excludes ``bbox``.

        ``bbox`` is ``[left, top, right, bottom]`` expressed as fractions of the
        image width/height. The four full-width/full-height strips above, below,
        left of and right of the box are compared and the one with the largest
        area is cropped out. If no strip has a positive size (the box covers the
        whole image) the original image is returned unchanged instead of an
        empty crop.
        """
        width, height = image.size
        left = bbox[0] * width
        top = bbox[1] * height
        right = bbox[2] * width
        bottom = bbox[3] * height

        regions = [
            (0, 0, width, top),          # strip above the box
            (0, bottom, width, height),  # strip below the box
            (0, 0, left, height),        # strip left of the box
            (right, 0, width, height),   # strip right of the box
        ]
        largest_region = max(regions, key=lambda r: (r[2] - r[0]) * (r[3] - r[1]))

        # Round to whole pixels so the degenerate case can be detected reliably.
        x0, y0, x1, y1 = (int(round(value)) for value in largest_region)
        if x1 <= x0 or y1 <= y0:
            return image

        return image.crop((x0, y0, x1, y1))

    def __getitem__(self, idx):
        """Return ``(product_img, cropped_scene_img, category, label)`` for row ``idx``.

        The two images are passed through ``self.transform`` when one was given;
        ``category`` and ``label`` are returned as tensors.
        """
        row = self.data.iloc[idx]

        scene_img = self.get_image(row['scene_id'])
        product_img = self.get_image(row['product_id'])
        cropped_scene_img = self._crop_img(scene_img, self._parse_bbox(row['bbox']))

        # Only the images that are actually returned get transformed.
        if self.transform:
            product_img = self.transform(product_img)
            cropped_scene_img = self.transform(cropped_scene_img)

        category = torch.tensor(self.get_category(idx))
        label = torch.tensor(row['label'])
        return product_img, cropped_scene_img, category, label


In [14]:
class CTLModel(torch.nn.Module):
    """Scores scene/product compatibility from a global and a patch-level similarity.

    The scene, the product and every scene patch are passed through
    ``feature_extractor`` and then through a shared two-layer feed-forward head.
    The returned score is ``0.5 * (global_similarity + local_similarity)`` where
    the local term is the attention-weighted sum of patch/product similarities.

    Args:
        feature_extractor: callable mapping an image batch to 1000-dim features
            (the first linear layer of the head has 1000 inputs).
        attention_mechanism: module called as
            ``attention_mechanism(patch_embeddings, product_embedding, category)``
            returning one weight (per batch element) for each patch.
        similarity: callable ``similarity(a, b)`` comparing two embedding batches.
        embed_dim: size of the embeddings produced by the head.
        patch_size: side length, in pixels, of the square non-overlapping scene
            patches used for the local similarity.
        device: device for the head and the attention module. Defaults to CUDA
            when available and CPU otherwise.
    """

    def __init__(self, feature_extractor, attention_mechanism, similarity,
                 embed_dim=128, patch_size=142, device=None):
        super(CTLModel, self).__init__()
        if patch_size <= 0:
            raise ValueError(f"patch_size must be positive, got {patch_size}")
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.feature_extractor = feature_extractor
        self.patch_size = patch_size

        # Two-layer feed-forward head: Linear-ReLU-Dropout-Linear-ReLU.
        # NOTE(review): an earlier comment described Linear-BN-ReLU-Dropout-Linear-Norm
        # with unit-length output; the layers below apply neither batch norm nor
        # normalization, so embeddings are non-negative and not unit length.
        self.feed_forward = torch.nn.Sequential(
            torch.nn.Linear(1000, embed_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(embed_dim, embed_dim),
            torch.nn.ReLU(),
        )
        self.feed_forward.to(device)

        self.attention_mechanism = attention_mechanism
        self.attention_mechanism.to(device)

        self.similarity = similarity

    def _embed(self, images):
        """Extract features for ``images`` and project them with the feed-forward head."""
        head_device = self.feed_forward[0].weight.device
        features = self.feature_extractor(images).to(head_device)
        return self.feed_forward(features)

    def forward(self, scene_img, product_img, category, verbose=False):
        """Return the compatibility score of each scene/product pair in the batch.

        Args:
            scene_img: image batch indexed as (batch, channel, height, width).
            product_img: image batch with the same layout.
            category: category indices forwarded to the attention mechanism.
            verbose: when True, show the first scene/product pair and each patch.

        Raises:
            ValueError: if the scene is smaller than one patch in either dimension.
        """
        if verbose:
            fig, ax = plt.subplots(1, 2)
            ax[0].imshow(scene_img[0].permute(1, 2, 0).cpu().detach().numpy())
            ax[1].imshow(product_img[0].permute(1, 2, 0).cpu().detach().numpy())
            plt.show()

        # Global embeddings of the whole scene and the product.
        scene_embedding = self._embed(scene_img)
        product_embedding = self._embed(product_img)

        # Local embeddings: one per non-overlapping patch of the scene. The
        # `+ 1` keeps the last patch when a dimension is an exact multiple of
        # the patch size.
        patch_size = self.patch_size
        height, width = scene_img.shape[2], scene_img.shape[3]
        patch_embeddings = []
        for top in range(0, height - patch_size + 1, patch_size):
            for left in range(0, width - patch_size + 1, patch_size):
                patch = scene_img[:, :, top:top + patch_size, left:left + patch_size]

                if verbose:
                    # visualize patch
                    print("patch number", len(patch_embeddings))
                    plt.imshow(patch[0].permute(1, 2, 0).cpu().detach().numpy())
                    plt.show()

                patch_embeddings.append(self._embed(patch))

        if not patch_embeddings:
            raise ValueError(
                f"scene of size {height}x{width} is smaller than patch_size={patch_size}"
            )

        # Attention weight of every patch.
        attention = self.attention_mechanism(patch_embeddings, product_embedding, category)

        global_similarity = self.similarity(scene_embedding, product_embedding)

        # Attention-weighted sum of the patch/product similarities.
        local_similarity = sum(
            self.similarity(patch_embedding, product_embedding) * weight
            for patch_embedding, weight in zip(patch_embeddings, attention)
        )

        compatibility = 0.5 * (global_similarity + local_similarity)
        return compatibility



In [15]:
# Load an ImageNet-pretrained ResNet-50 to use as a fixed feature extractor.
# The full network is kept, so it outputs the 1000-dim classifier logits that
# CTLModel's first linear layer expects.
resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
resnet.eval()
# Freeze the backbone: it is registered as a sub-module of CTLModel, so without
# this its weights would be returned by `model.parameters()` and updated by the
# optimizer together with the embedding head.
resnet.requires_grad_(False)
resnet = resnet.to(device)


# Resize the shorter side to 320 px, centre-crop to 288x288 and convert to a
# float tensor in [0, 1].
# NOTE(review): ImageNet mean/std normalization is not applied here, although the
# pretrained weights were presumably trained on normalized inputs. CTLModel's
# verbose mode imshow()s these tensors directly, which may be why — confirm.
transform = transforms.Compose([
    transforms.Resize(320, antialias=True),
    transforms.CenterCrop(288),
    transforms.ToTensor(),
])

class Attention(torch.nn.Module):
    """Category-conditioned attention over a list of patch embeddings.

    One embedding vector is learned per category. A patch's score is the
    negative L2 distance between its embedding and the embedding of the given
    category; scores are turned into weights with a softmax across patches.
    """

    def __init__(self, embed_dim=128, num_categories=10):
        super().__init__()
        self.embed_dim = embed_dim
        self.category_embeddings = torch.nn.Embedding(num_categories, embed_dim)

    def forward(self, patch_embeddings, product_embedding, category):
        """Return the attention weights, stacked along dim 0 (one entry per patch).

        ``product_embedding`` is accepted for interface compatibility but is not
        used in the computation.
        """
        target = self.category_embeddings(category)

        scores = []
        for patch in patch_embeddings:
            distance = torch.norm(patch - target, dim=-1)
            scores.append(-distance)

        stacked_scores = torch.stack(scores)
        return torch.nn.functional.softmax(stacked_scores, dim=0)


def cos_sim(scene_features, product_features):
    """Cosine similarity of the two inputs, computed along the last dimension."""
    similarity = torch.nn.functional.cosine_similarity(
        scene_features,
        product_features,
        dim=-1,
    )
    return similarity

def l2_sim(scene_features, product_features):
    """Negative L2 distance along the last dimension (larger means more similar)."""
    difference = scene_features - product_features
    distance = torch.norm(difference, dim=-1)
    return -distance


In [16]:
# Assemble the compatibility model: ResNet features, category-conditioned
# attention (default settings) and negative-L2 similarity.
model = CTLModel(
    feature_extractor=resnet,
    attention_mechanism=Attention(),
    similarity=l2_sim,
)

In [17]:
df = pd.read_csv(DATA_PATH + 'data.csv')

# Fraction of the rows used for training; the remainder is held out for testing.
split = 0.8
split_idx = int(split * len(df))
train_df = df[:split_idx]
test_df = df[split_idx:]
assert len(train_df) + len(test_df) == len(df)

# One category -> index mapping for every dataset. Left to themselves, the
# positive and negative datasets would each number the categories in their own
# order of first appearance and could disagree on which embedding row a
# category maps to.
# NOTE: the model above is built with Attention() and its default
# num_categories=10, so indices must stay below 10.
category_dict = {category: index for index, category in enumerate(df['category'].unique())}

positive_df = train_df[train_df['label'] == 1]
negative_df = train_df[train_df['label'] == 0]

positive_dataset = CTLData(positive_df, transform)
negative_dataset = CTLData(negative_df, transform)
positive_dataset.category_dict = category_dict
negative_dataset.category_dict = category_dict

positive_loader = DataLoader(positive_dataset, batch_size=1, shuffle=True)
negative_loader = DataLoader(negative_dataset, batch_size=1, shuffle=True)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

num_epochs = 10
margin = 1.0
log_every = 100  # print/record the loss once every `log_every` training steps
losses = []      # loss values recorded every `log_every` steps
epochs = []      # index into `losses` at which each epoch ended
i = 0            # global training step, counted across epochs
for epoch in range(num_epochs):
    # Pair every positive sample with a negative one; zip stops at the shorter loader.
    paired_batches = zip(tqdm(positive_loader), negative_loader)
    for (pos_product_img, pos_cropped_scene_img, pos_category, _), (neg_product_img, neg_cropped_scene_img, neg_category, _) in paired_batches:
        pos_product_img = pos_product_img.to(device)
        pos_cropped_scene_img = pos_cropped_scene_img.to(device)
        pos_category = pos_category.to(device)

        neg_product_img = neg_product_img.to(device)
        neg_cropped_scene_img = neg_cropped_scene_img.to(device)
        neg_category = neg_category.to(device)

        # Forward pass for both the positive and the negative pair.
        pos_output = model(pos_cropped_scene_img, pos_product_img, pos_category)
        neg_output = model(neg_cropped_scene_img, neg_product_img, neg_category)

        # Hinge loss max(0, margin - pos + neg), averaged over the batch so that
        # backward() always receives a scalar.
        loss = torch.clamp(margin - pos_output + neg_output, min=0).mean()

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % log_every == 0:
            print(f"Loss: {loss.item():.4f}")
            losses.append(loss.item())
        i += 1

    print(f'Epoch [{epoch + 1}/{num_epochs}]')
    epochs.append(len(losses))


# Plot the recorded losses; dashed red lines mark the end of each epoch.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_title('Training loss')
ax.set_xlabel(f'Logged step (one every {log_every} iterations)')
ax.set_ylabel('Loss')
for x in epochs:
    ax.axvline(x=x, color='r', linestyle='--')
plt.show()
