# Google image embeddings

In [1]:
!pip install timm
import timm

In [2]:
import torch
from torch.utils.data import Dataset 
from torch.utils.data import DataLoader

import torchvision.models as models
from torchvision import transforms as tf

import cv2
import numpy as np
import matplotlib.pyplot as plt

import os
from pathlib import Path
from zipfile import ZipFile
from tqdm.notebook import tqdm

# torch.cuda.is_available must be *called*: the bare function object is always
# truthy, which would select "cuda" even on CPU-only machines.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [7]:
class EmbeddingsDataset(Dataset):
    """Eagerly load the images of every class folder under ``root``.

    Each image is read with OpenCV, converted from BGR to RGB, scaled to
    [0, 1], resized to ``size`` and normalized with the ImageNet RGB mean/std.

    Args:
        size: Target ``[height, width]`` passed to ``tf.Resize``. Defaults to
            ``[256, 256]``.
        aug_transforms: Optional callable applied to an image tensor each time
            it is fetched with ``__getitem__`` (e.g. random augmentations).
        max_elem_per_class: If given, load at most this many images from each
            class folder.
        root: Directory whose sub-folders are the classes.

    Attributes:
        images: list of normalized ``(3, H, W)`` float tensors.
        labels: class index of each entry in ``images``.
        classes: sorted class-folder names that ``labels`` index into.
    """

    def __init__(self, size=None, aug_transforms=None, max_elem_per_class=None,
                 root="../input/130k-images-512x512-universal-image-embeddings"):
        self.images = []
        self.labels = []
        self.augmentation = aug_transforms
        self.max_elem_per_class = max_elem_per_class

        # A None default avoids sharing one mutable list between instances.
        self.basic_transforms = tf.Compose([
            tf.Resize([256, 256] if size is None else list(size)),
            tf.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        # sorted() makes class indices reproducible; os.listdir order is arbitrary.
        self.classes = sorted(
            d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
        )

        for label, folder in enumerate(self.classes):
            folder_path = os.path.join(root, folder)
            # List the real image files instead of constructing imageNNNN names,
            # so a wrong extension or a too-large max_elem_per_class cannot
            # produce paths that do not exist.
            files = sorted(
                f for f in os.listdir(folder_path)
                if f.lower().endswith((".jpg", ".jpeg"))
            )
            if self.max_elem_per_class is not None:
                files = files[:self.max_elem_per_class]

            for name in tqdm(files, desc=folder):
                self.images.append(self._load_image(os.path.join(folder_path, name)))
                self.labels.append(label)

    def _load_image(self, img_path):
        """Read one image file and return a normalized ``(3, H, W)`` float tensor."""
        image = cv2.imread(img_path)
        if image is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # OpenCV decodes to BGR; the normalization statistics are RGB-ordered.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # No batch dimension here: DataLoader adds it when it stacks samples.
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
        return self.basic_transforms(image)

    def __getitem__(self, idx):
        """Return image ``idx``, passed through ``aug_transforms`` if one was given."""
        image = self.images[idx]
        if self.augmentation is not None:
            image = self.augmentation(image)
        return image

    def __len__(self):
        """Number of loaded images."""
        return len(self.images)

In [3]:
# List the artwork images without os.chdir(): changing the working directory
# here would break every later cell that uses paths relative to the notebook
# directory (the dataset root, saved_model.pt and submission.zip).
import glob
import os

artwork_dir = os.path.join("../input/130k-images-512x512-universal-image-embeddings", "artwork")
for img_file in sorted(glob.glob(os.path.join(artwork_dir, "*.jpg"))):
    print(os.path.basename(img_file))

In [8]:
# Build the dataset from every class folder with no augmentation and no
# per-class cap, i.e. every image is loaded into memory.
dataset = EmbeddingsDataset(aug_transforms=None, max_elem_per_class=None)

In [12]:
"image" + "{:04d}".format(100) + ".jpg"

tree=os.walk("../input/130k-images-512x512-universal-image-embeddings/apparel", topdown=True)
print(len(next(tree)[2])) 

path = "../input/130k-images-512x512-universal-image-embeddings"
dirs = [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]
dirs

artwork_path = "../input/130k-images-512x512-universal-image-embeddings"+"/"+dirs[0]+"/"
imgs = [d for d in os.listdir(artwork_path) if os.path.isdir(os.path.join(artwork_path, d))]
len(imgs)

In [23]:
def weights_init(m):
    """Xavier-initialize the weights of linear and convolution layers.

    Intended for ``module.apply(weights_init)``. ``Linear`` weights get
    ``xavier_uniform_`` and ``Conv2d``/``ConvTranspose2d`` weights get
    ``xavier_normal_``. Biases and all other module types are left untouched.

    Args:
        m: Any ``torch.nn.Module``; modules of other types are ignored.
    """
    # torch.nn.init functions already run under no_grad, so the Parameter can
    # be passed directly (no need to reach for .data).
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
    elif isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
        torch.nn.init.xavier_normal_(m.weight)

In [36]:
class Embedder(torch.nn.Module):
    """Pretrained ``timm`` backbone (frozen) followed by a linear projection head.

    ``forward`` embeds a single channels-last image of shape ``(H, W, 3)``
    (numpy array or tensor) and returns a ``(1, emb_size)`` tensor.

    Args:
        model_name: ``timm`` model name. The model is created with pretrained
            weights and ``num_classes=0`` so it returns pooled features.
        target_size: ``[height, width]`` the image is resized to before the
            backbone. Defaults to ``[384, 384]``.
        emb_size: Size of the output embedding.
    """

    def __init__(self, model_name, target_size=None, emb_size=64):
        super().__init__()

        # Fresh list per instance (no shared mutable default); kept as a list
        # of ints because that is what tf.functional.resize receives.
        self.target_size = [384, 384] if target_size is None else list(target_size)
        self.emb_size = emb_size
        self.backbone = timm.create_model(model_name, pretrained=True, num_classes=0)
        # Freeze the backbone. Assigning `self.backbone.requires_grad = False`
        # only sets an ordinary attribute on the Module; requires_grad_()
        # clears the flag on every backbone parameter.
        self.backbone.requires_grad_(False)

        # The input width comes from the backbone (1536 for swin_large) rather
        # than being hard-coded, so other timm backbones also work.
        # NOTE(review): there is no activation between the two Linear layers,
        # so the head is equivalent to a single linear map — confirm intended.
        self.head = torch.nn.Sequential(
            torch.nn.Linear(self.backbone.num_features, 1024),
            torch.nn.Linear(1024, self.emb_size),
        )
        self.head.apply(weights_init)

    def forward(self, x):
        """Embed one ``(H, W, 3)`` image; returns a ``(1, emb_size)`` tensor."""
        # as_tensor accepts numpy arrays and tensors without the copy warning
        # torch.tensor() emits for tensor input. HWC -> NCHW with a batch of 1,
        # converted to float so the backbone never receives uint8 data.
        x = torch.as_tensor(x).permute((2, 0, 1))[None, :, :, :].float()

        x = tf.functional.resize(x, size=self.target_size)

        # NOTE(review): scaling to [0, 1] and ImageNet normalization happen only
        # in eval mode; training-mode calls feed raw pixel values to the
        # backbone. Confirm that is intended.
        if not self.training:
            x = x / 255.
            x = tf.functional.normalize(x,
                                        mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])

        x = self.backbone(x)
        x = self.head(x)

        return x

    def params_count(self):
        """Print the number of parameters in the backbone and in the head."""
        print("Params in backbone: ", sum(p.numel() for p in self.backbone.parameters()))
        print("Params in head: ", sum(p.numel() for p in self.head.parameters()))
        

In [37]:
# Embedder on top of the timm `swin_large_patch4_window12_384_in22k` backbone,
# then report how many parameters live in the backbone vs. the head.
backbone_name = "swin_large_patch4_window12_384_in22k"
model = Embedder(model_name=backbone_name)
model.params_count()

In [28]:
# Load one sample image. OpenCV decodes images as BGR, so convert to RGB: that
# is what matplotlib displays correctly and the channel order the model's
# ImageNet RGB normalization statistics assume.
img_path = "../input/130k-images-512x512-universal-image-embeddings/landmark/image0000.jpg"
img = cv2.imread(img_path)
if img is None:  # cv2.imread returns None instead of raising on failure
    raise FileNotFoundError(f"Could not read image: {img_path}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(img)

In [29]:
# Inference: eval() makes forward() apply the /255 scaling and normalization
# and puts layers such as dropout in inference mode; no_grad() avoids building
# an autograd graph.
model.eval()
with torch.no_grad():
    rez = model(img)

In [31]:
# Display the embedding computed above (the Embedder returns a (1, emb_size) tensor).
rez

In [64]:
# Export: compile the model with TorchScript in eval mode, save it, and package
# it as submission.zip containing saved_model.pt.
model.eval()
saved_model = torch.jit.script(model)
saved_model.save('saved_model.pt')

# `zf` rather than `zip`, so the builtin zip() is not shadowed for later cells.
with ZipFile('submission.zip', 'w') as zf:
    zf.write('./saved_model.pt', arcname='saved_model.pt')