In [22]:
import torch
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
from torch.utils.data import Dataset
from skimage import io
import numpy as np
import matplotlib.pyplot as plt
import math

# Root directory of the local dataset copy. CTLData.get_image(local=True) reads
# images from DATA_PATH + "/imgs/" + <signature> + ".png".
DATA_PATH = './data/ctl/'

  from .collection import imread_collection_wrapper


In [None]:
def crop_img(image, bbox):
    """Return the largest rectangular strip of ``image`` lying outside ``bbox``.

    The bounding box divides the image into four candidate strips: the full-width
    strip above the box, the full-width strip below it, the full-height strip to
    its left and the full-height strip to its right. The strip with the greatest
    area is cropped out and returned; on a tie the earliest candidate in that
    order (above, below, left, right) wins.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.
        bbox: Indexable ``(left, top, right, bottom)``. Each entry is multiplied
            by the image width or height, so the values are treated as fractions
            of the image size.

    Returns:
        PIL.Image.Image: The cropped strip.
    """
    # Imported here because no notebook cell imports PIL; without it the bare
    # ``Image`` name raises NameError on a fresh kernel.
    from PIL import Image

    image = Image.fromarray(image)
    width, height = image.size

    # Scale the fractional box to pixel coordinates.
    left, right = bbox[0] * width, bbox[2] * width
    top, bottom = bbox[1] * height, bbox[3] * height

    # Candidate (left, upper, right, lower) boxes surrounding the bounding box.
    regions = [
        (0, 0, width, top),          # above
        (0, bottom, width, height),  # below
        (0, 0, left, height),        # left
        (right, 0, width, height),   # right
    ]

    def _area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    # max() keeps the first maximal element, preserving the candidate order on ties.
    largest_region = max(regions, key=_area)

    return image.crop(largest_region)

In [None]:
class CTLData(Dataset):
    """Dataset yielding scene/product image pairs described by a DataFrame.

    Every row of ``data`` is read for the columns ``scene_id``, ``product_id``,
    ``bbox``, ``category`` and ``label``. Images are identified by a signature
    string and are either downloaded from the Pinterest image host or read
    from the local ``DATA_PATH`` directory.

    Args:
        data: DataFrame with the columns listed above.
        transform: Optional callable applied to each of the three images.
        local: When True, ``__getitem__`` reads images from disk instead of
            downloading them. Defaults to False (download).
    """

    def __init__(self, data, transform=None, local=False):
        self.data = data
        self.transform = transform
        self.local = local

        # Map each category value to an integer id, in order of first appearance.
        self.category_dict = {
            category: index
            for index, category in enumerate(self.data['category'].unique())
        }

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.data)

    def get_category(self, idx):
        """Return the integer category id of the row at position ``idx``."""
        return self.category_dict[self.data.iloc[idx]['category']]

    def convert_to_url(self, signature):
        """Build the image URL for ``signature``.

        The first three character pairs of the signature form the directory
        components of the URL path.
        """
        prefix = 'http://i.pinimg.com/400x/%s/%s/%s/%s.jpg'
        return prefix % (signature[0:2], signature[2:4], signature[4:6], signature)

    def get_image(self, signature, local=False):
        """Load the image for ``signature`` with ``skimage.io.imread``.

        Args:
            signature: Image signature string.
            local: Read ``<DATA_PATH>/imgs/<signature>.png`` when True,
                otherwise fetch the image from its URL.
        """
        if local:
            return io.imread(DATA_PATH + "/imgs/" + signature + ".png")
        # Must go through ``self``: convert_to_url is a method, not a global.
        return io.imread(self.convert_to_url(signature))

    def __getitem__(self, idx):
        """Return ``(scene, product, cropped_scene, category_id, label)`` for ``idx``.

        ``cropped_scene`` is produced by ``crop_img`` from the untransformed
        scene image and the row's ``bbox``.
        """
        row = self.data.iloc[idx]

        scene_img = self.get_image(row['scene_id'], local=self.local)
        product_img = self.get_image(row['product_id'], local=self.local)
        cropped_scene_img = crop_img(scene_img, row['bbox'])

        # ``is not None`` so a falsy-but-valid transform (e.g. an empty
        # container module) is still applied.
        if self.transform is not None:
            scene_img = self.transform(scene_img)
            product_img = self.transform(product_img)
            cropped_scene_img = self.transform(cropped_scene_img)

        return scene_img, product_img, cropped_scene_img, self.get_category(idx), row['label']