In [1]:
!pip install timm

Collecting timm
  Downloading timm-0.6.11-py3-none-any.whl (548 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m548.7/548.7 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: timm
Successfully installed timm-0.6.11
[0m

In [2]:
import torch
import torch.nn.functional as F
import timm
import numpy as np
from torch.nn import Module
from torchvision import transforms
from torchvision.transforms.functional import resize, normalize, pil_to_tensor
from torch.utils.data import DataLoader, ConcatDataset
from datasets import load_from_disk
from scipy.linalg import eigh

In [3]:
# Number of principal components kept by weighted_PCA, i.e. the size of the
# embedding produced by the final model.
EMBEDDING_DIMENSION = 64
# timm model name handed to timm.create_model when the GUIE backbone is built.
BEIT_MODEL = 'beit_large_patch16_224_in22k'

In [4]:
# Directories read with datasets.load_from_disk in the cells below.
# Each domain has a train and a validation split; the loaded datasets expose
# a 'label' column and a 'beit' feature column.
IMAGENET_TRAIN = '../input/guie-imagenetmini-1k/imagenetmini_1k_train'
IMAGENET_VAL = '../input/guie-imagenetmini-1k/imagenetmini_1k_val'
LANDMARKS_TRAIN = '../input/guie-landmarks/landmarks_train'
LANDMARKS_VAL = '../input/guie-landmarks/landmarks_val'
FURNITURE_TRAIN = '../input/guie-furniture/furniture_train'
FURNITURE_VAL = '../input/guie-furniture/furniture_val'
PRODUCTS_TRAIN = '../input/beit-products10k/products_train'
PRODUCTS_VAL = '../input/beit-products10k/products_val'

In [5]:
# Landmarks: read both splits from disk and make them return torch tensors.
landmarks_train, landmarks_val = (
    load_from_disk(split_dir) for split_dir in (LANDMARKS_TRAIN, LANDMARKS_VAL)
)
landmarks_train.set_format(type='torch')
landmarks_val.set_format(type='torch')

# Pool the two splits; train rows come first, validation rows second.
landmarks_all = ConcatDataset([landmarks_train, landmarks_val])
print(landmarks_train, landmarks_val)
print(len(landmarks_all))

Dataset({
    features: ['label', 'beit'],
    num_rows: 70593
}) Dataset({
    features: ['label', 'beit'],
    num_rows: 7768
})
78361


In [6]:
# Furniture: read both splits from disk and make them return torch tensors.
furniture_train, furniture_val = (
    load_from_disk(split_dir) for split_dir in (FURNITURE_TRAIN, FURNITURE_VAL)
)
furniture_train.set_format(type='torch')
furniture_val.set_format(type='torch')

# Pool the two splits; train rows come first, validation rows second.
furniture_all = ConcatDataset([furniture_train, furniture_val])
print(furniture_train, furniture_val)
print(len(furniture_all))

Dataset({
    features: ['label', 'beit'],
    num_rows: 8415
}) Dataset({
    features: ['label', 'beit'],
    num_rows: 931
})
9346


In [7]:
# Products: read both splits from disk and make them return torch tensors.
products_train, products_val = (
    load_from_disk(split_dir) for split_dir in (PRODUCTS_TRAIN, PRODUCTS_VAL)
)
products_train.set_format(type='torch')
products_val.set_format(type='torch')

# Pool the two splits; train rows come first, validation rows second.
products_all = ConcatDataset([products_train, products_val])
print(products_train, products_val)
print(len(products_all))

Dataset({
    features: ['label', 'beit'],
    num_rows: 127890
}) Dataset({
    features: ['label', 'beit'],
    num_rows: 14041
})
141931


In [8]:
# ImageNet (mini): read both splits from disk and make them return torch tensors.
imagenet_train, imagenet_val = (
    load_from_disk(split_dir) for split_dir in (IMAGENET_TRAIN, IMAGENET_VAL)
)
imagenet_train.set_format(type='torch')
imagenet_val.set_format(type='torch')

# Pool the two splits; train rows come first, validation rows second.
imagenet_all = ConcatDataset([imagenet_train, imagenet_val])
print(imagenet_train, imagenet_val)
print(len(imagenet_all))

Dataset({
    features: ['label', 'beit'],
    num_rows: 34745
}) Dataset({
    features: ['label', 'beit'],
    num_rows: 3923
})
38668


In [9]:
# Pool all four domains into one dataset. The domain order used here
# (landmarks, furniture, products, imagenet) is the same order in which the
# per-example weights and the feature matrix X are assembled further down.
dataset = ConcatDataset([
    landmarks_all,
    furniture_all,
    products_all,
    imagenet_all,
])
print(len(dataset))

268306


The current share of each category in `dataset` (landmarks, furniture, products, imagenet) is:

In [10]:
# Fraction of the pooled rows contributed by each domain, printed in the
# order landmarks, furniture, products, imagenet.
for domain_rows in (landmarks_all, furniture_all, products_all, imagenet_all):
    print(len(domain_rows) / len(dataset))

0.29205832146877075
0.03483336190767258
0.5289892883498691
0.1441190282736875


We'd like to assign a weight to every example so that the weighted category distribution more closely resembles that of the dataset on which the model is evaluated. Let's try the following target shares:  
landmarks -> 0.196  
furniture -> 0.106  
products -> 0.199  
imagenet -> rest = 0.499

In [11]:
# Share of the total weight each domain should receive after re-weighting
# (imagenet takes whatever is left over by the other three).
target_shares = {
    'landmarks': 0.196,
    'furniture': 0.106,
    'products': 0.199,
    'imagenet': 0.499,
}
domain_sizes = {
    'landmarks': len(landmarks_all),
    'furniture': len(furniture_all),
    'products': len(products_all),
    'imagenet': len(imagenet_all),
}
# Per-example weight of a domain = target share / current share
#                                = len(dataset) / domain size * target share.
category_weights = {
    name: len(dataset) / domain_sizes[name] * share
    for name, share in target_shares.items()
}
category_weights

{'landmarks': 0.6710988374318858,
 'furniture': 3.043059704686497,
 'products': 0.3761890918826754,
 'imagenet': 3.462415796007034}

In [12]:
# One weight per row of the pooled data: every example gets the weight of its
# domain, with the domains laid out in the same order as in `dataset`.
domain_order = ('landmarks', 'furniture', 'products', 'imagenet')
per_domain_weight = torch.tensor([category_weights[name] for name in domain_order])
rows_per_domain = torch.tensor([
    len(landmarks_all),
    len(furniture_all),
    len(products_all),
    len(imagenet_all),
])
# Repeat each domain weight once for every row of that domain.
weights = torch.repeat_interleave(per_domain_weight, rows_per_domain)
print(len(weights))

268306


Compute weighted PCA down to 64 dimensions.

In [13]:
# Stack the 'beit' feature matrices of every split into one (rows, features)
# tensor. The split order mirrors the order used to build `weights`.
feature_splits = (
    landmarks_train, landmarks_val,
    furniture_train, furniture_val,
    products_train, products_val,
    imagenet_train, imagenet_val,
)
X = torch.cat([split['beit'] for split in feature_splits])
print(X.shape)

torch.Size([268306, 1024])


In [14]:
def weighted_PCA(X, weights, n_components=None):
    """Fit a PCA in which every row of ``X`` counts with its own weight.

    Parameters
    ----------
    X : torch.Tensor, shape (n, p)
        One example per row.
    weights : torch.Tensor, shape (n,)
        Weight of each row, same dtype as ``X``.
    n_components : int, optional
        Number of principal components to keep. When omitted, the notebook
        constant ``EMBEDDING_DIMENSION`` is used.

    Returns
    -------
    mean : torch.Tensor, shape (p,)
        Weighted mean of the rows of ``X``.
    vals : torch.Tensor, shape (k,)
        The ``k`` largest eigenvalues of the weighted covariance matrix, in
        ASCENDING order (the order in which ``scipy.linalg.eigh`` returns them).
    vecs : torch.Tensor, shape (p, k)
        Matching eigenvectors, one per column; ``(x - mean) @ vecs`` projects
        an example onto the kept components.

    Raises
    ------
    ValueError
        If the requested number of components is not in ``1..p`` or if the
        weights sum to zero.
    """
    _, p = X.shape
    k = EMBEDDING_DIMENSION if n_components is None else n_components
    if not 1 <= k <= p:
        raise ValueError(
            f"n_components must be between 1 and the feature dimension {p}, got {k}"
        )

    total_weight = weights.sum()
    if total_weight == 0:
        raise ValueError("weights must not sum to zero")

    mean = (weights @ X) / total_weight
    X_centered = X - mean
    # Weighted covariance: sum_i w_i * x_i x_i^T / sum_i w_i on centered rows.
    cov = ((weights * X_centered.T) @ X_centered) / total_weight

    # eigh returns eigenvalues in ascending order, so the k largest ones are
    # the index range [p - k, p - 1].
    vals, vecs = eigh(cov.detach().cpu().numpy(), subset_by_index=[p - k, p - 1])
    vals = torch.from_numpy(vals).to(cov.device)
    vecs = torch.from_numpy(vecs).to(cov.device)
    return mean, vals, vecs

In [15]:
mean, explained_variances, W = weighted_PCA(X, weights)

The actual model simply computes BEiT embeddings and projects them down to 64 dimensions using the weighted PCA fitted on the dataset above.

In [16]:
class GUIE(Module):
    """BEiT backbone followed by a fixed PCA projection.

    The constructor reads the notebook globals ``BEIT_MODEL``, ``mean`` and
    ``W``; the latter two are the outputs of ``weighted_PCA`` and must exist
    before the model is instantiated.
    """

    def __init__(self):
        super().__init__()
        self.beit = timm.create_model(BEIT_MODEL, pretrained=True, num_classes=0)
        # Registered as buffers (rather than plain tensor attributes) so that
        # they follow the module through .to()/.cuda() and stay on the same
        # device as the backbone output they are combined with in forward().
        self.register_buffer('pca_mean', mean)
        self.register_buffer('pca_projector', W)
        self.beit.requires_grad_(False)

    def forward(self, image_batch):
        '''
        Model consumes images as tensors (not PIL-images) and outputs embeddings as tensors.
        Expected shape: (batch, channels, height, width); pixel values are
        divided by 255 below, so they are expected in the 0..255 range.
        Returns a tensor with one projected embedding per image.
        '''
        image_tensors = image_batch / 255.
        image_tensors = resize(image_tensors, size=(224, 224))
        # Map the 0..1 range to -1..1 per channel.
        image_tensors = normalize(
            image_tensors,
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5]
        )
        embeddings = self.beit(image_tensors)
        # Center with the PCA mean, then project onto the kept components.
        embeddings = (embeddings - self.pca_mean) @ self.pca_projector
        return embeddings

In [17]:
# Build the model in inference mode, compile it with TorchScript and write
# the compiled module to disk.
model = GUIE().eval()
saved_model = torch.jit.script(model)
saved_model.save('saved_model.pt')

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Downloading: "https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth" to /root/.cache/torch/hub/checkpoints/beit_large_patch16_224_pt22k_ft22k.pth


In [18]:
from zipfile import ZipFile

# Package the scripted model for submission. The handle is deliberately not
# called `zip`, which would shadow the builtin for every later cell.
with ZipFile('submission.zip', 'w') as archive:
    archive.write('./saved_model.pt', arcname='saved_model.pt')

In [19]:
from io import BytesIO
from PIL import Image
from torchvision import transforms
import requests

# Sample image used to smoke-test the saved model.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# A timeout keeps the cell from hanging forever, and raise_for_status turns an
# HTTP error into an explicit exception instead of an opaque image-decoding
# failure further down.
response = requests.get(url, timeout=30)
response.raise_for_status()
input_image = Image.open(BytesIO(response.content)).convert('RGB')
saved_model_path = 'saved_model.pt'
# PIL image -> tensor, then add a leading batch dimension.
convert_to_tensor = transforms.Compose([transforms.PILToTensor()])
input_tensor = convert_to_tensor(input_image)
input_batch = input_tensor.unsqueeze(0)

In [20]:
from PIL import Image
import torch
from torchvision import transforms

# Reload the TorchScript file from disk and switch it to inference mode.
model = torch.jit.load(saved_model_path)
model.eval()
embedding_fn = model

# Embed the sample batch without tracking gradients and keep the first item
# of the batch as a flat NumPy vector.
with torch.no_grad():
    embedding = embedding_fn(input_batch)[0].flatten().cpu().data.numpy()
embedding

array([ 0.2632971 , -1.6655103 , -2.4893472 ,  3.7273684 ,  1.3950839 ,
        1.3776908 , -2.0331917 ,  0.85838693,  1.6138374 ,  2.3114119 ,
        2.5594916 , -4.367251  , -1.1665906 ,  0.17241156,  1.1958246 ,
       -1.2847953 , -2.8455603 , -2.202995  ,  1.9537129 , -1.8801165 ,
       -2.186744  ,  1.4249673 ,  1.2141016 ,  0.55171984, -1.0276257 ,
        0.09351698,  1.977908  ,  1.498049  , -2.2125213 ,  1.1878104 ,
       -2.5432067 , -1.2673784 , -0.5995904 ,  1.3039222 ,  0.20791796,
       -0.24558282,  0.26261878, -3.3305638 , -1.109017  , -0.13759184,
       -0.65043414, -2.0158453 ,  0.8314079 ,  2.3849723 ,  0.49997306,
       -1.1524358 ,  0.65752596,  1.7776992 , -0.82183826,  1.9799197 ,
        0.80306333,  1.3667268 ,  1.4504836 , -2.1804628 ,  2.1219397 ,
       -0.02977777,  1.3119843 , -2.319443  ,  1.876002  , -3.4972377 ,
       -2.4339538 ,  2.4844184 , -4.4004545 ,  0.3981835 ], dtype=float32)