In [1]:
import torch
import torchvision
import transformers
import timm

from transformers import ViTFeatureExtractor, ViTModel
from transformers import BertModel, BertTokenizer

from PIL import Image

# Image models

In [2]:
# Per-channel mean / std passed to Normalize (the usual ImageNet statistics).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Preprocessing applied before the DINOv2, torchvision and timm models:
# resize to 256x256, centre-crop to 224x224, convert to a tensor, normalise.
preprocessing_steps = [
    torchvision.transforms.Resize((256, 256)),
    torchvision.transforms.CenterCrop((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]
transforms = torchvision.transforms.Compose(preprocessing_steps)

In [3]:
# Load the sample image and force 3-channel RGB. The preprocessing pipeline
# normalises with three per-channel means/stds, so a grayscale, palette or RGBA
# file would otherwise fail (or be mis-normalised) in `transforms`.
# NOTE(review): the path is an absolute, machine-specific location — consider a
# configurable data directory.
image = Image.open("/data3/dataset/0a/01/0a01028bdb7383ba409036cdad89e0cc").convert("RGB")

In [4]:
# Sample product title (Russian for "pink sock"); it is passed to the BERT
# tokenizer in the text-model cells.
title = "Носок розовый"

In [5]:
# Sanity check: preprocess the image and add a leading batch axis.
image_batch = transforms(image).unsqueeze(0)
image_batch.shape

torch.Size([1, 3, 224, 224])

### DINO v2

In [6]:
# Fetch the DINOv2 ViT-B/14 backbone through torch.hub (cached after the first call).
dino = torch.hub.load(
    repo_or_dir='facebookresearch/dinov2',
    model='dinov2_vitb14',
)

Using cache found in /home/aruslantsev/.cache/torch/hub/facebookresearch_dinov2_main
xFormers not available
xFormers not available


In [7]:
# Inference only: disable autograd so no graph / activations are kept around
# just to inspect the embedding shape.
with torch.no_grad():
    dino_embedding = dino(transforms(image).unsqueeze(0))
dino_embedding.shape

torch.Size([1, 768])

### ViT PyTorch

In [8]:
# torchvision ViT-L/32 with its ImageNet-1k (V1) weights.
vit_weights = torchvision.models.ViT_L_32_Weights.IMAGENET1K_V1
vit = torchvision.models.vit_l_32(weights=vit_weights)

In [9]:
# Inference only: disable autograd so the forward pass of this large model
# does not retain activations for a backward pass that never happens.
with torch.no_grad():
    vit_logits = vit(transforms(image).unsqueeze(0))
vit_logits.shape

torch.Size([1, 1000])

### ViT TIMM

In [10]:
# Load the DINO ViT-B/16 checkpoint. Without `pretrained=True` timm builds the
# architecture with randomly initialised weights, so the resulting embeddings
# would not be comparable with the other (pretrained) models in this section.
model = timm.models.vision_transformer.vit_base_patch16_224_dino(pretrained=True)

In [11]:
# Inference only: disable autograd while computing the embedding whose shape
# is inspected below.
with torch.no_grad():
    timm_embedding = model(transforms(image).unsqueeze(0))
timm_embedding.shape

torch.Size([1, 768])

### ViT Transformers

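Note: `ViTModel` returns one hidden-state vector per token (the `[CLS]` token plus every image patch), not a single pooled image embedding.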

In [12]:
# Hugging Face port of the DINO ViT-B/16 checkpoint. The feature extractor does
# its own preprocessing, so the raw PIL image is passed in directly.
feature_extractor = ViTFeatureExtractor.from_pretrained('facebook/dino-vitb16')
# Stored under its own name so the timm `model` created earlier is not silently
# overwritten (re-running the timm cell afterwards would otherwise call this
# model instead).
vit_hf = ViTModel.from_pretrained('facebook/dino-vitb16')
inputs = feature_extractor(images=image, return_tensors="pt")
# Inference only: no autograd graph is needed to inspect the outputs.
with torch.no_grad():
    outputs = vit_hf(**inputs)
# Per-token hidden states of the final layer.
last_hidden_states = outputs.last_hidden_state

Some weights of ViTModel were not initialized from the model checkpoint at facebook/dino-vitb16 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
# (batch, tokens, hidden) layout of the per-token hidden states.
last_hidden_states.size()

torch.Size([1, 197, 768])

In [14]:
# Shape of the pooled output. The load warning reports the pooler weights as
# newly initialised, so only the shape is meaningful here, not the values.
outputs.pooler_output.size()

torch.Size([1, 768])

# Text models

### Bert

In [15]:
# Text encoder and its matching tokenizer, both from the same checkpoint.
# NOTE(review): this is an English checkpoint while `title` is Russian —
# confirm whether a multilingual checkpoint is intended.
BERT_CHECKPOINT = "bert-base-uncased"
bert = BertModel.from_pretrained(BERT_CHECKPOINT)
tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [16]:
# Tokenise the title and run BERT. Inference only, so autograd is disabled to
# avoid building a graph that is never back-propagated through.
inputs = tokenizer(title, return_tensors="pt")
with torch.no_grad():
    outputs = bert(**inputs)

In [17]:
# (batch, tokens, hidden) layout of BERT's final-layer hidden states.
outputs.last_hidden_state.size()

torch.Size([1, 13, 768])

# Combine text and image

In [18]:
# Encoders for the combined text + image experiment.
TEXT_CHECKPOINT = "bert-base-uncased"
bert = BertModel.from_pretrained(TEXT_CHECKPOINT)
tokenizer = BertTokenizer.from_pretrained(TEXT_CHECKPOINT)
dino = torch.hub.load(repo_or_dir='facebookresearch/dinov2', model='dinov2_vitb14')

# Second BERT instance; it is later fed the concatenated text + image sequence
# through `inputs_embeds`.
bert_multi = BertModel.from_pretrained(TEXT_CHECKPOINT)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Using cache found in /home/aruslantsev/.cache/torch/hub/facebookresearch_dinov2_main
Some weights of the model checkpoint at bert-base-uncased were not use

In [19]:
# Per-token text embeddings for the title: final-layer hidden states of `bert`.
inputs = tokenizer(title, return_tensors="pt")
bert_output = bert(**inputs)
text_inp = bert_output.last_hidden_state

In [20]:
# Image embedding from DINOv2, with a sequence axis inserted at position 1 so it
# can be concatenated with the text token sequence.
image_batch = transforms(image).unsqueeze(0)
image_embedding = dino(image_batch)
image_inp = image_embedding.unsqueeze(1)

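The image and text embeddings have noticeably different vector norms, so they may need to be rescaled (e.g. L2-normalised) before being combined into one sequence.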

In [21]:
# Norm of the image embedding.
torch.norm(image_inp)

tensor(46.9338, grad_fn=<LinalgVectorNormBackward0>)

In [22]:
# Norm of the first token's text embedding, for comparison with the image norm.
first_token_embedding = text_inp[0, 0]
first_token_embedding.norm()

tensor(14.1290, grad_fn=<LinalgVectorNormBackward0>)

### Custom tokens

In [23]:
def _one_hot_token(index, dim=768):
    """Return a (1, 1, dim) tensor that is zero everywhere except a 1 at `index`.

    Args:
        index: position along the last axis that is set to 1 (negative
            indices count from the end).
        dim: size of the last axis; defaults to 768, the width used by the
            embeddings in this notebook.

    Returns:
        A float tensor of shape (1, 1, dim).
    """
    token = torch.zeros(1, 1, dim)
    token[0, 0, index] = 1
    return token


# Hand-crafted special-token embeddings. Each token must use its own index so
# that no two tokens are identical.
CLS = _one_hot_token(0)
SEP = _one_hot_token(1)
SEP_IMG = _one_hot_token(2)
# Previously this also used index 2, making it indistinguishable from SEP_IMG.
SEP_SKU_IMG = _one_hot_token(3)
PAD = _one_hot_token(-1)

In [24]:
# Fixed length of the combined sequence fed to the multimodal BERT.
MAX_SEQ_LEN = 64

# [CLS] text tokens [SEP_IMG] image [SEP]; text and image vectors are
# L2-normalised along the hidden axis so both modalities share the same scale.
sequence_parts = [
    CLS,
    text_inp / text_inp.norm(keepdim=True, dim=2),
    SEP_IMG,
    image_inp / image_inp.norm(keepdim=True, dim=2),
    SEP,
]

# Derive the padding count from the actual sequence length instead of assuming
# a fixed token count, so titles of any length are padded to MAX_SEQ_LEN.
content_len = sum(part.shape[1] for part in sequence_parts)
if content_len > MAX_SEQ_LEN:
    raise ValueError(
        f"Combined sequence has {content_len} positions, "
        f"which exceeds MAX_SEQ_LEN={MAX_SEQ_LEN}"
    )

# For 3-D inputs hstack concatenates along axis 1, i.e. the sequence axis.
vec_inp = torch.hstack(sequence_parts + [PAD] * (MAX_SEQ_LEN - content_len))

In [25]:
# Shape of the combined, padded input sequence.
vec_inp.size()

torch.Size([1, 64, 768])

In [26]:
# Padding slots are exact copies of PAD, so they can be located by comparing
# every position with it. The normalised text / image vectors are dense, so
# they will not match the one-hot PAD vector in practice.
pad_positions = (vec_inp == PAD).all(dim=-1)
# 1 = attend, 0 = ignore. Without a mask BERT attends to the padding positions
# as if they were real tokens.
attention_mask = (~pad_positions).long()
bert_multi(inputs_embeds=vec_inp, attention_mask=attention_mask).pooler_output.shape

torch.Size([1, 768])