In [1]:
#!pip install timm

In [2]:
#!pip install albumentations

In [3]:
import os
import cv2
import math
import gc
import numpy as np
import pandas as pd
import itertools
from tqdm import tqdm
import albumentations as A
import matplotlib.pyplot as plt
import itertools
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import timm
import random
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer, ViTModel, ViTFeatureExtractor
from torch.nn import TransformerEncoder, TransformerEncoderLayer

In [4]:
# Load the merged im2recipe table: one row per (image id, recipe text id, ingredient line).
# NOTE(review): unpickling can execute arbitrary code — only load trusted files.
df = (
    pd.read_pickle("/filer/tmp1/gg676/im2recipe/df_merged.pkl")
    .reset_index(drop=True)
)
df.head()

Unnamed: 0,img_id,text_id,ingredients,id
0,48a9db19c2,f642f3cffc,"8 stalks rhubarb, cut into 3 inch lengths",387667
1,48a9db19c2,f642f3cffc,8 cups water,387667
2,48a9db19c2,f642f3cffc,"13 cup sugar, to taste",387667
3,48a9db19c2,f642f3cffc,"1 sprig mint, as garnish",387667
4,48a9c381f6,a8a884cafb,"1 small head of garlic, peeled and sliced thinly",265527


In [5]:
# Summarise column dtypes and memory footprint of the loaded frame.
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5650110 entries, 0 to 5650109
Data columns (total 4 columns):
 #   Column       Dtype 
---  ------       ----- 
 0   img_id       object
 1   text_id      object
 2   ingredients  object
 3   id           int64 
dtypes: int64(1), object(3)
memory usage: 172.4+ MB


In [6]:
class CFG:
    """Hyper-parameters and paths shared by every cell of this notebook."""

    # --- run control -------------------------------------------------------
    debug = False
    epochs = 3
    batch_size = 128
    num_workers = 48
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # --- data --------------------------------------------------------------
    image_path = "/filer/tmp1/gg676/im2recipe/img_data/train_flattened"
    size = 224  # images are resized to size x size

    # --- optimisation ------------------------------------------------------
    head_lr = 1e-3
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    cross_attn_lr = 1e-4
    weight_decay = 1e-3
    patience = 1  # ReduceLROnPlateau patience (in scheduler steps)
    factor = 0.8  # ReduceLROnPlateau decay factor

    # --- encoders ----------------------------------------------------------
    model_name = 'vit_base_patch16_224'
    image_embedding = 1000
    text_encoder_model = "distilbert-base-uncased"
    text_tokenizer = "distilbert-base-uncased"
    text_embedding = 768
    max_length = 200
    pretrained = True  # for both image encoder and text encoder
    trainable = True  # for both image encoder and text encoder
    temperature = 1.0

    # --- projection head (used for both image and text branches) -----------
    num_projection_layers = 1
    projection_dim = 256
    dropout = 0.1

In [7]:
class AvgMeter:
    """Running, count-weighted average of a scalar metric such as a loss."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Forget every recorded value."""
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, count=1):
        """Record ``val`` as the mean of ``count`` samples and refresh ``avg``."""
        self.sum = self.sum + val * count
        self.count = self.count + count
        self.avg = self.sum / self.count

    def __repr__(self):
        return "{}: {:.4f}".format(self.name, self.avg)

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first param group, or None if it has none."""
    return next((group["lr"] for group in optimizer.param_groups), None)


In [8]:
class CLIPDataset(torch.utils.data.Dataset):
    """Pairs each image file with one ingredient line, tokenised on demand.

    ``image_filenames`` and ``ingredients`` must be aligned and of equal length;
    when an image has several ingredient lines, its file name is repeated once
    per line.
    """

    def __init__(self, image_filenames, ingredients, tokenizer, transforms, max_length=128):
        """
        Args:
            image_filenames: image ids; item ``i`` reads ``f"{CFG.image_path}/{id}.jpg"``.
            ingredients: raw ingredient strings, aligned with ``image_filenames``.
            tokenizer: HuggingFace-style tokenizer, called once per item.
            transforms: albumentations-style callable used as ``transforms(image=...)``.
            max_length: every item is padded/truncated to this many tokens.

        Raises:
            ValueError: if the two sequences differ in length.
        """
        if len(image_filenames) != len(ingredients):
            raise ValueError(
                f"image_filenames ({len(image_filenames)}) and ingredients "
                f"({len(ingredients)}) must have the same length"
            )
        self.image_filenames = image_filenames
        # Keep the raw strings and tokenise lazily in __getitem__: eagerly
        # tokenising millions of rows here is very slow and keeps one tensor
        # pair per row alive in memory (and in every DataLoader worker).
        self.ingredients = ingredients
        self.tokenizer = tokenizer
        self.transforms = transforms
        self.max_length = max_length

    def __getitem__(self, idx):
        """Return ``(image, input_ids, attention_mask)`` for row ``idx``.

        ``image`` is a float tensor in channel-first layout; the token tensors
        are whatever the tokenizer returns for a single string with
        ``return_tensors="pt"`` (a leading batch dim of 1).
        """
        path = f"{CFG.image_path}/{self.image_filenames[idx]}.jpg"
        image = cv2.imread(path)
        if image is None:  # cv2 returns None instead of raising on unreadable files
            raise FileNotFoundError(f"could not read image: {path}")
        # OpenCV decodes to BGR; pretrained vision encoders expect RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)['image']
        image = torch.tensor(image).permute(2, 0, 1).float()

        encoded = self.tokenizer(
            self.ingredients[idx],
            padding='max_length', max_length=self.max_length, truncation=True,
            return_tensors="pt",
        )
        return image, encoded['input_ids'], encoded['attention_mask']

    def __len__(self):
        return len(self.ingredients)



def get_transforms(mode="train", mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Build the image preprocessing pipeline.

    Train and evaluation currently share one deterministic pipeline: resize to
    ``CFG.size`` x ``CFG.size``, then normalise. ``mode`` is kept so that
    training-only augmentation can be added later.

    Args:
        mode: "train" or anything else (treated as evaluation).
        mean, std: per-channel normalisation statistics applied after scaling
            pixels by 1/255. The defaults equal albumentations' own defaults.
            NOTE(review): the HF "google/vit-base-patch16-224-in21k" processor
            config appears to use mean = std = (0.5, 0.5, 0.5); confirm and pass
            those if matching the ViT's pretraining preprocessing matters.
    """
    # No augmentation yet, so both modes get the same steps. p=1.0 always
    # applies a step (replaces the deprecated ``always_apply=True``).
    steps = [
        A.Resize(CFG.size, CFG.size, p=1.0),
        A.Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
    ]
    return A.Compose(steps)

In [9]:
class ImageEncoder(nn.Module):
    """
    Encode images with the HF ViT checkpoint "google/vit-base-patch16-224-in21k".

    NOTE(review): ``model_name`` and ``pretrained`` are accepted for interface
    compatibility but ignored — the HF checkpoint is always loaded.
    """

    def __init__(
        self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
    ):
        super().__init__()
        self.model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        # Honour ``trainable`` so the backbone can be frozen (True is a no-op).
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        """Return the raw ``ViTModel`` output for the pixel batch ``x``."""
        return self.model(x)

In [10]:
class TextEncoder(nn.Module):
    """DistilBERT text encoder whose weights are always frozen.

    ``model_name`` and ``trainable`` are accepted for interface compatibility
    but unused. Pretrained weights come from a fixed local path, and every
    parameter has ``requires_grad`` switched off.
    """

    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()
        self.model = (
            DistilBertModel.from_pretrained('/filer/tmp1/gg676/distilbert-base-uncased')
            if pretrained
            else DistilBertModel(config=DistilBertConfig())
        )
        for param in self.model.parameters():
            param.requires_grad_(False)

        # Index of the [CLS] token, for callers that want one pooled vector per
        # sentence. forward() itself returns the hidden state of every token.
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        """Return DistilBERT's ``last_hidden_state`` for the whole sequence."""
        encoded = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return encoded.last_hidden_state

In [11]:
class ProjectionHead(nn.Module):
    """Residual MLP mapping ``embedding_dim`` features to ``projection_dim``.

    out = LayerNorm(Dropout(FC(GELU(P x))) + P x), where P is the input projection.
    """

    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        base = self.projection(x)
        residual = self.dropout(self.fc(self.gelu(base)))
        return self.layer_norm(residual + base)

In [12]:
def create_binary_label_data(img_features, text_features, attn_mask):
    """Build a balanced batch of matching / non-matching (image, text) pairs.

    For every batch index ``i``, two rows are emitted, interleaved:
    ``(img_i, text_i, 1.0)`` followed by ``(img_i, text_j, 0.0)``. Here ``j`` is
    drawn uniformly from the *other* batch indices, so ``j != i``.

    Args:
        img_features: tensor whose dim 0 is the batch.
        text_features: tensor with the same batch size along dim 0.
        attn_mask: mask aligned with ``text_features`` along dim 0. A negative
            row reuses the mask of the text it borrows.

    Returns:
        ``(img, text, mask, labels)``, each with leading size ``2 * batch``.
        ``labels`` is a float tensor alternating 1.0 / 0.0, on
        ``img_features.device``.

    Raises:
        ValueError: if the batch has fewer than two items (no negative exists).
    """
    n = img_features.shape[0]
    if n < 2:
        raise ValueError(f"need at least 2 items to sample negatives, got {n}")
    device = img_features.device

    positives = torch.arange(n, device=device)
    # An offset drawn from [1, n-1], taken modulo n, gives a uniformly random index != i.
    offsets = torch.randint(1, n, (n,), device=device)
    negatives = (positives + offsets) % n
    # Interleave as [pos_0, neg_0, pos_1, neg_1, ...].
    text_idx = torch.stack((positives, negatives), dim=1).reshape(-1)

    img_pairs = img_features.repeat_interleave(2, dim=0)
    text_pairs = text_features[text_idx.to(text_features.device)]
    mask_pairs = attn_mask[text_idx.to(attn_mask.device)]
    labels = torch.tensor([1.0, 0.0], device=device).repeat(n)
    return img_pairs, text_pairs, mask_pairs, labels
    #for j in concatenated_pairs:
    #    print("j -> ", j[1])
    #return [i[0] for i in concatenated_pairs], [i[1] for i in concatenated_pairs]

In [13]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    Args:
        d_model: feature size of the inputs (odd sizes are supported).
        dropout: dropout probability applied after the addition.
        max_len: longest sequence that can be encoded.
        batch_first: if False (default), inputs are [seq_len, batch, d_model].
            If True, inputs are [batch, seq_len, d_model].
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, batch_first: bool = False):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.batch_first = batch_first

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine column.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim], or
               [batch_size, seq_len, embedding_dim] when ``batch_first``.
        """
        if self.batch_first:
            # pe is [max_len, 1, d]; move positions to dim 1 so they broadcast over the batch.
            x = x + self.pe[: x.size(1)].transpose(0, 1)
        else:
            x = x + self.pe[: x.size(0)]
        return self.dropout(x)


In [14]:
class CrossAttention(nn.Module):
    """Joint transformer over [image tokens; SEP; text tokens] that scores a pair.

    Image and text token sequences are concatenated with a learned separator.
    The result runs through a ``batch_first`` ``nn.TransformerEncoder``. The
    output at position 0 (the first image token) is projected to
    ``num_classes`` sigmoid probabilities.

    Args:
        model_dim: feature size shared by image and text tokens.
        n_heads, n_layers: encoder configuration.
        num_image_patches: unused; kept for interface compatibility.
        num_classes: output width of the classifier.
        dropout: dropout used by the text positional encoding.
    """
    def __init__(self, model_dim=768, n_heads=2, n_layers=2, num_image_patches=197, num_classes=1, dropout=0.1):
        super().__init__()
        self.text_positional = PositionalEncoding(model_dim, dropout=dropout)
        self.sep_token = nn.Parameter(torch.zeros(1, 1, model_dim))
        layers = nn.TransformerEncoderLayer(d_model=model_dim, nhead=n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layers, num_layers=n_layers)
        self.cls_projection = nn.Linear(model_dim, num_classes)
        self.sigmoid = nn.Sigmoid()

    def forward(self, image_features, text_features, src_key_padding_mask=None):
        """
        Args:
            image_features: [batch, n_img_tokens, model_dim].
            text_features: [batch, n_text_tokens, model_dim].
            src_key_padding_mask: optional HF-style text attention mask of shape
                [batch, n_text_tokens]. Nonzero/True marks a real token;
                0/False marks padding.

        Returns:
            [batch, num_classes] probabilities.
        """
        batch_size, n_img_tokens = image_features.shape[:2]
        device = image_features.device
        # PositionalEncoding indexes positions along dim 0, so go to [seq, batch, dim]
        # and back. Otherwise every token of sample b would get position b's encoding.
        text_features = self.text_positional(text_features.transpose(0, 1)).transpose(0, 1)
        sep_token = self.sep_token.expand(batch_size, -1, -1)
        transformer_input = torch.cat((image_features, sep_token, text_features), dim=1)

        padding_mask = None
        if src_key_padding_mask is not None:
            # nn.TransformerEncoder wants a bool mask with True = "ignore this key",
            # the inverse of an HF attention mask. A float mask would be *added* to
            # the attention scores instead. Image tokens and SEP are never padding.
            text_padding = src_key_padding_mask.to(device) == 0
            prefix = torch.zeros(batch_size, n_img_tokens + 1, dtype=torch.bool, device=device)
            padding_mask = torch.cat((prefix, text_padding), dim=1)

        transformer_outputs = self.encoder(transformer_input, src_key_padding_mask=padding_mask)
        projected_output = transformer_outputs[:, 0, :]
        return self.sigmoid(self.cls_projection(projected_output))
        
        

In [15]:
class CLIPModel(nn.Module):
    """Image–ingredient matcher built from pretrained encoders and a trainable cross-attention head.

    The ViT image encoder and DistilBERT text encoder run under
    ``torch.no_grad``. ``forward`` builds one positive and one sampled negative
    (image, ingredient) pair per batch item. It returns the cross-attention
    match probabilities together with the 1.0/0.0 targets.
    """
    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        super().__init__()
        self.image_encoder = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.image_encoder.eval()
        self.text_encoder = TextEncoder()
        self.text_encoder.eval()

        # The projection heads are not used by forward(). They are kept so that
        # state_dict keys stay the same and the retrieval helpers can use them.
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

        # Leftover settings from an earlier TransformerModel experiment; not read by this class.
        self.ntokens = 709  # size of vocabulary
        self.emsize = 768  # embedding dimension
        self.d_hid = 200  # dimension of the feedforward network model
        self.nlayers = 2  # number of TransformerEncoderLayers
        self.nhead = 2  # number of attention heads
        self.dropout = 0.2  # dropout probability
        self.cross_attn = CrossAttention()

    def train(self, mode=True):
        """Set train/eval mode, but always keep the pretrained encoders in eval mode.

        ``nn.Module.train`` recurses into every child. Without this override it
        would re-enable dropout inside the encoders, whose features are only
        consumed (never trained), and make those features noisy.
        """
        super().train(mode)
        self.image_encoder.eval()
        self.text_encoder.eval()
        return self

    def forward(self, batch_img, input_ids, attn_mask):
        """
        Args:
            batch_img: pixel batch for the ViT.
            input_ids, attn_mask: token ids and attention mask with a singleton
                dim 1 (as produced per item by the dataset), squeezed here.

        Returns:
            ``(probs, labels)``: cross-attention probabilities and float targets
            of shape [2 * batch, 1], both on the output's device.
        """
        with torch.no_grad():
            image_features = self.image_encoder(batch_img, output_hidden_states=True)
            input_ids, attn_mask = input_ids.squeeze(1), attn_mask.squeeze(1)
            text_features = self.text_encoder(input_ids=input_ids, attention_mask=attn_mask)

        img_feature, text_feature, attn_mask_list, labels = create_binary_label_data(
            image_features.hidden_states[-1], text_features, attn_mask
        )
        output = self.cross_attn(img_feature, text_feature, attn_mask_list)
        return output, labels.unsqueeze(1).to(output.device)


def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits ``preds`` and soft target distributions.

    Args:
        preds: unnormalised logits, with classes along the last dim.
        targets: probabilities with the same shape as ``preds``.
        reduction: 'none' (per-row losses), 'mean' or 'sum'.

    Raises:
        ValueError: for any other ``reduction`` (previously returned None silently).
    """
    # Sum over the same (class) dim that log_softmax normalises over.
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(-1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"unsupported reduction: {reduction!r}")

In [16]:
# Sanity check: the softmax over a Gram matrix of random high-dimensional vectors
# is close to the identity, because each vector's self-similarity dominates.
batch_size, dim = 4, 256
embeddings = torch.randn(batch_size, dim)
out = torch.matmul(embeddings, embeddings.t())
print(out.softmax(dim=-1))

tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])


In [17]:
def make_train_valid_dfs(dataframe=None, valid_frac=0.0, seed=42, debug=None):
    """Return the training frame, optionally with a held-out validation frame.

    The split is made by the ``id`` column, so all rows sharing an id land on
    the same side.

    Args:
        dataframe: source frame; defaults to the notebook-global ``df``.
        valid_frac: fraction of distinct ids to hold out. ``0`` (the default)
            returns only the training frame, like the original single return.
        seed: seed for choosing validation ids.
        debug: if true, keep only ids below 100. Defaults to ``CFG.debug``.

    Returns:
        ``train_df`` when ``valid_frac`` is 0, otherwise ``(train_df, valid_df)``.
    """
    if dataframe is None:
        dataframe = df
    if debug is None:
        debug = CFG.debug
    np.random.seed(seed)  # global seeding preserved from the original helper

    if debug:
        dataframe = dataframe[dataframe["id"] < 100].reset_index(drop=True)
    if not valid_frac:
        return dataframe

    rng = np.random.default_rng(seed)
    ids = np.sort(dataframe["id"].unique())
    valid_ids = rng.choice(ids, size=int(valid_frac * len(ids)), replace=False)
    # isin is vectorised, unlike a per-id Python membership test.
    in_valid = dataframe["id"].isin(valid_ids)
    train_df = dataframe[~in_valid].reset_index(drop=True)
    valid_df = dataframe[in_valid].reset_index(drop=True)
    return train_df, valid_df


def build_loaders(dataframe, tokenizer, mode):
    """Wrap ``dataframe`` in a CLIPDataset and DataLoader; shuffle only in train mode."""
    is_train = mode == "train"
    dataset = CLIPDataset(
        dataframe["img_id"].values,
        dataframe["ingredients"].values,
        tokenizer=tokenizer,
        transforms=get_transforms(mode=mode),
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=is_train,
    )

In [18]:
def train_epoch(model, train_loader, optimizer, criterion, lr_scheduler, step):
    """Run one optimisation pass over ``train_loader``.

    Each batch is ``(images, input_ids, attention_mask)``. The model returns
    ``(predictions, labels)``, which ``criterion`` scores. When ``step`` is
    ``"batch"``, the scheduler is stepped after every optimiser update.

    Returns:
        AvgMeter with the label-count-weighted mean training loss.
    """
    meter = AvgMeter()
    progress = tqdm(train_loader, total=len(train_loader))
    for images, token_ids, token_mask in progress:
        predictions, labels = model(
            images.to(CFG.device),
            token_ids.to(CFG.device),
            token_mask.to(CFG.device),
        )
        loss = criterion(predictions, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == "batch":
            lr_scheduler.step()

        meter.update(loss.item(), labels.size(0))
        progress.set_postfix(train_loss=meter.avg, lr=get_lr(optimizer))
    return meter


def valid_epoch(model, valid_loader, criterion=None):
    """Evaluate ``model`` on ``valid_loader`` without tracking gradients.

    Batches are ``(images, input_ids, attention_mask)`` tuples, and the model
    returns ``(predictions, labels)``, the same contract ``train_epoch`` uses.

    Args:
        model: the matcher to evaluate.
        valid_loader: DataLoader yielding the tuples above.
        criterion: loss applied to ``(predictions, labels)``. Defaults to
            ``nn.BCELoss()``, which suits the sigmoid probabilities the model returns.

    Returns:
        AvgMeter with the label-count-weighted mean validation loss.
    """
    if criterion is None:
        criterion = nn.BCELoss()
    loss_meter = AvgMeter()

    tqdm_object = tqdm(valid_loader, total=len(valid_loader))
    with torch.no_grad():
        for img, input_ids, attn_mask in tqdm_object:
            output, labels = model(
                img.to(CFG.device), input_ids.to(CFG.device), attn_mask.to(CFG.device)
            )
            loss = criterion(output, labels)

            loss_meter.update(loss.item(), labels.size(0))
            tqdm_object.set_postfix(valid_loss=loss_meter.avg)
    return loss_meter



In [19]:
import os
import time
#os.environ['CUDA_VISIBLE_DEVICES'] = '0,3'
def main():
    """Train the cross-attention matcher and checkpoint the best epoch.

    ``best_ilab2.pt`` is written whenever the epoch's mean *training* loss
    improves. There is no validation split in this run.
    """
    train_df = make_train_valid_dfs()
    start = time.time()
    tokenizer = DistilBertTokenizer.from_pretrained('/filer/tmp1/gg676/distilbert-base-uncased')
    train_loader = build_loaders(train_df, tokenizer, mode="train")
    print("Data loading time taken: ", time.time()-start)

    # Build the model exactly once: construction loads two pretrained encoders.
    model = CLIPModel().to(CFG.device)
    print("Time taken to load: ", time.time()-start)

    criterion = nn.BCELoss()
    # Only the cross-attention parameters are handed to the optimizer.
    optimizer = torch.optim.AdamW(model.cross_attn.parameters(), lr=CFG.cross_attn_lr, weight_decay=0.)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
    )
    step = "epoch"

    best_loss = float('inf')
    for epoch in range(CFG.epochs):
        print(f"Epoch: {epoch + 1}")
        model.train()
        train_loss = train_epoch(model, train_loader, optimizer, criterion, lr_scheduler, step)

        if train_loss.avg < best_loss:
            best_loss = train_loss.avg
            torch.save(model.state_dict(), "best_ilab2.pt")
            print("Saved Best Model!")

        # Plateau scheduling on the training loss, once per epoch.
        lr_scheduler.step(train_loss.avg)

In [20]:
# Launch training. Long-running; may write best_ilab2.pt after each improving epoch.
main()

KeyboardInterrupt: 

In [None]:
def get_image_embeddings(valid_df, model_path):
    """Embed every image of ``valid_df`` with the checkpoint at ``model_path``.

    Returns:
        ``(model, embeddings)``: the loaded model and the concatenated per-batch
        outputs of ``model.image_projection``.

    NOTE(review): this cell predates the current dataset and model and looks broken:
    - the loader built by ``build_loaders`` yields ``(image, input_ids,
      attention_mask)`` tuples, so ``batch["image"]`` will fail;
    - ``model.image_encoder`` is a HF ``ViTModel`` whose output is a model-output
      object (``forward`` reads ``.hidden_states`` from it), not a tensor;
    - ``image_projection`` was built for ``CFG.image_embedding`` input features.
      Confirm that this matches the ViT hidden size before relying on it.
    - the tokenizer is loaded by hub name (``CFG.text_tokenizer``), unlike the
      local path used in ``main``.
    """
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
    
    model = CLIPModel().to(CFG.device)
    model.load_state_dict(torch.load(model_path, map_location=CFG.device))
    model.eval()
    
    valid_image_embeddings = []
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            image_features = model.image_encoder(batch["image"].to(CFG.device))
            image_embeddings = model.image_projection(image_features)
            valid_image_embeddings.append(image_embeddings)
    return model, torch.cat(valid_image_embeddings)

In [None]:
# NOTE(review): make_train_valid_dfs() called with no arguments returns a single
# DataFrame, so unpacking it into two names will fail. Also, main() saves
# "best_ilab2.pt", not "best.pt". Confirm which checkpoint is intended.
_, valid_df = make_train_valid_dfs()
model, image_embeddings = get_image_embeddings(valid_df, "best.pt")

In [None]:
def find_matches(model, image_embeddings, query, image_filenames, n=9):
    """Display the images whose embeddings are most similar to a text query.

    Args:
        model: model exposing ``text_encoder`` (with ``target_token_idx``) and
            ``text_projection``.
        image_embeddings: 2-D tensor, one row per entry of ``image_filenames``.
        query: free-text query.
        image_filenames: image ids aligned with the embedding rows (ids may repeat).
        n: number of images to show.

    NOTE(review): ``n * 5`` candidates are ranked and every 5th is kept. This
    de-duplication heuristic assumes ~5 rows per image. Here the number of rows
    per image varies, so duplicates or skipped images are possible.
    """
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    encoded_query = tokenizer([query])
    batch = {
        key: torch.tensor(values).to(CFG.device)
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        token_states = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        # text_encoder returns every token's state. Pool with the [CLS] token so
        # the query becomes one vector comparable to each image embedding.
        text_features = token_states[:, model.text_encoder.target_token_idx, :]
        text_embeddings = model.text_projection(text_features)

    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ image_embeddings_n.T

    k = min(n * 5, dot_similarity.shape[-1])  # topk fails if k exceeds the candidate count
    _, indices = torch.topk(dot_similarity.squeeze(0), k)
    matches = [image_filenames[idx] for idx in indices[::5].tolist()]
    print(matches)

    n_cols = math.ceil(math.sqrt(n))
    n_rows = math.ceil(n / n_cols)
    _, axes = plt.subplots(n_rows, n_cols, figsize=(10, 10), squeeze=False)
    for ax in axes.flatten():
        ax.axis("off")
    for match, ax in zip(matches, axes.flatten()):
        image = cv2.imread(f"{CFG.image_path}/{match}.jpg")
        if image is None:  # missing or unreadable file: leave the panel blank
            continue
        # OpenCV loads BGR; matplotlib expects RGB.
        ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    plt.show()

In [None]:
# Rank validation images against a free-text query and show the top 9.
find_matches(model, 
             image_embeddings,
             query="Chicken thighs",
             image_filenames=valid_df['img_id'].values,
             n=9)

In [None]:
# Scratch: load the same HF ViT checkpoint on its own for inspection.
image_encoder = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

In [None]:
# Display the module tree of the loaded ViT.
image_encoder

In [None]:
# Preprocessor matching the ViT checkpoint.
# NOTE(review): newer transformers releases deprecate ViTFeatureExtractor in
# favour of ViTImageProcessor. Confirm for the installed version.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")


In [None]:
# NOTE(review): `img` is only defined in a later cell (the downloaded dog image).
# On a fresh kernel this raises NameError unless that cell runs first.
inputs = feature_extractor(img, return_tensors="pt")

In [None]:
# Run the ViT on the preprocessed image, keeping every layer's hidden states.
with torch.no_grad():
    outputs = image_encoder(**inputs, output_hidden_states=True)

In [None]:
# Shape of the last hidden state for the first image, without position 0
# (presumably the [CLS] token, leaving only patch tokens).
outputs.hidden_states[-1][0][1:].shape

In [None]:
# Scratch: a timm ViT for comparison (timm is already imported at the top).
# NOTE(review): this rebinds `model`, shadowing the CLIPModel returned earlier by
# get_image_embeddings, so later retrieval cells would break after this runs.
import timm
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()


In [None]:
# `import urllib` alone does not guarantee the `urllib.request` submodule is loaded.
import urllib.request
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# Build timm's preprocessing for `model` and apply it to a sample image.
config = resolve_data_config({}, model=model)
transform = create_transform(**config)

url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
urllib.request.urlretrieve(url, filename)
with Image.open(filename) as src:  # close the file handle once decoded
    img = src.convert('RGB')
tensor = transform(img).unsqueeze(0)  # transform and add batch dimension


In [None]:
# Forward the sample image through the timm ViT without tracking gradients.
with torch.no_grad():
    out = model(tensor)

In [None]:
# Shape of the timm model's output.
out.shape

In [None]:
# NOTE(review): `out` comes from a timm model and looks like a plain tensor
# (it exposes `.shape` above), not an HF output object, so this attribute
# access will likely raise AttributeError.
out.hidden_states[-1]

In [None]:
# Display the raw timm output.
out