In [292]:
import os
import cv2
import math
import gc
import numpy as np
import pandas as pd
import itertools
from tqdm import tqdm
import albumentations as A
import matplotlib.pyplot as plt
import itertools
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import timm
import random
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer, ViTModel, ViTFeatureExtractor
from torch.nn import TransformerEncoder, TransformerEncoderLayer

In [293]:
!nvidia-smi

Sun May  8 20:53:24 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A4000    On   | 00000000:1A:00.0 Off |                 Off* |
| 41%   40C    P8    15W / 140W |  16109MiB / 16117MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A4000    On   | 00000000:1B:00.0 Off |                 Off* |
| 41%   36C    P8    14W / 140W |   6230MiB / 16117MiB |      0%      Default |
|       

In [322]:
CKPT_PATH = '/common/home/gg676/536/notebooks/best_ilab2_scaled_new.pt'#best_ilab2.pt'#_epoch2_concatenated_ingr_620k.pt'
TEST_DATA = '/filer/tmp1/gg676/im2recipe/df_merged_TEST_appended_ingredients.pkl'#df_merged_TEST_appended_ingredients.pkl'#df_merged_VAL_appended_ingredients.pkl'#df_merged_VAL_ingredients.pkl'

In [323]:
df_val = pd.read_pickle(TEST_DATA)[:10]
df_val.reset_index(inplace=True, drop=True)

In [324]:
df_val.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10 entries, 0 to 9
Data columns (total 4 columns):
 #   Column       Non-Null Count  Dtype 
---  ------       --------------  ----- 
 0   img_id       10 non-null     object
 1   text_id      10 non-null     object
 2   ingredients  10 non-null     object
 3   id           10 non-null     int64 
dtypes: int64(1), object(3)
memory usage: 448.0+ bytes


In [297]:
df_val.head()

Unnamed: 0,img_id,text_id,ingredients,id
0,ff611e83ca,3a5be7f0f5,100 grams Sweet potato 3 tbsp Honey 3 tbsp Mi...,91864
1,7097aa2903,54755e3dd8,750 ml vodka 3 -4 jalapenos 3 -4 red chilies ...,133073
2,78e994f1d6,3e47ed2ffa,3 ounces dark rum (or to suit taste) 4 tables...,97999
3,3b1f774356,a82c3d3651,1 (18 1/4 ounce) package fudge cake mix 3 tab...,264757
4,3eaa3482d4,fd9fc0c152,1 orange Safeway 1 lb For $1.28 thru 02/09 bu...,399129


In [298]:
class CFG:
    debug = False
    image_path = "/filer/tmp1/gg676/im2recipe/img_data/train_flattened"
    val_path = "/common/users/gg676/test_flattened"#test_flattened"
    #captions_path = "."
    batch_size = 48
    num_workers = 16
    head_lr = 1e-3
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    cross_attn_lr = 1e-4
    weight_decay = 1e-3
    patience = 1
    factor = 0.8
    epochs = 3
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    model_name = 'vit_base_patch16_224'#'resnet50'
    image_embedding = 1000
    text_encoder_model = "distilbert-base-uncased"
    text_embedding = 768
    text_tokenizer = "distilbert-base-uncased"
    max_length = 200

    pretrained = True # for both image encoder and text encoder
    trainable = True # for both image encoder and text encoder
    temperature = 1.0

    # image size
    size = 224

    # for projection head; used for both image and text encoders
    num_projection_layers = 1
    projection_dim = 256 
    dropout = 0.1

In [299]:
class CLIPDataset(torch.utils.data.Dataset):
    def __init__(self, image_filenames, ingredients, tokenizer, transforms):
        """
        image_filenames and cpations must have the same length; so, if there are
        multiple captions for each image, the image_filenames must have repetitive
        file names 
        """

        self.image_filenames = image_filenames
        #self.tokenizer_need = tokenizer_need
        self.ingredients = [tokenizer(ingr, 
                               padding='max_length', max_length = 128, truncation=True,
                                return_tensors="pt") for ingr in ingredients]
        self.transforms = transforms

    def __getitem__(self, idx):
        item = {}
        image = cv2.imread(f"{CFG.val_path}/{self.image_filenames[idx]}.jpg")
        image = self.transforms(image=image)['image']
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()
        item['ingredients'] = self.ingredients[idx]
        return item['image'], self.ingredients[idx]['input_ids'], self.ingredients[idx]['attention_mask']
        #

    def __len__(self):
        return len(self.ingredients)


def get_transforms(mode="train"):
    if mode == "train":
        return A.Compose(
            [
                A.Resize(CFG.size, CFG.size, always_apply=True),
                A.Normalize(max_pixel_value=255.0, always_apply=True),
            ]
        )
    else:
        return A.Compose(
            [
                A.Resize(CFG.size, CFG.size, always_apply=True),
                A.Normalize(max_pixel_value=255.0, always_apply=True),
            ]
        )

In [300]:
class CLIPDatasetPreprocessed(torch.utils.data.Dataset):
    def __init__(self, image, input_ids, attn_mask, transforms):
        """
        image_filenames and cpations must have the same length; so, if there are
        multiple captions for each image, the image_filenames must have repetitive
        file names 
        """

        self.image = image
        #self.tokenizer_need = tokenizer_need
        #self.ingredients = [tokenizer(ingr, 
        #                       padding='max_length', max_length = 128, truncation=True,
        #                        return_tensors="pt") for ingr in ingredients]
        self.input_ids = input_ids
        self.attn_mask = attn_mask
        self.transforms = transforms

    def __getitem__(self, idx):
        #item = {}
        #image = cv2.imread(f"{CFG.val_path}/{self.image_filenames[idx]}.jpg")
        #image = self.transforms(image=image)['image']
        #item['image'] = torch.tensor(image).permute(2, 0, 1).float()
        #item['ingredients'] = self.ingredients[idx]
        return self.image[idx], self.input_ids[idx], self.attn_mask[idx]
        #

    def __len__(self):
        return len(self.input_ids)


def get_transforms(mode="train"):
    if mode == "train":
        return A.Compose(
            [
                A.Resize(CFG.size, CFG.size, always_apply=True),
                A.Normalize(max_pixel_value=255.0, always_apply=True),
            ]
        )
    else:
        return A.Compose(
            [
                A.Resize(CFG.size, CFG.size, always_apply=True),
                A.Normalize(max_pixel_value=255.0, always_apply=True),
            ]
        )

In [301]:
def make_train_valid_dfs(df, mode='train'):
    dataframe = df#pd.read_csv(f"{CFG.captions_path}/captions.csv")
    max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
    image_ids = np.arange(0, max_id)
    np.random.seed(42)
    if mode == 'train':
        train_dataframe = dataframe#[dataframe["id"].isin(train_ids)].reset_index(drop=True)
        return train_dataframe
    else:
        valid_dataframe = dataframe#[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
        return valid_dataframe
    #return train_dataframe, valid_dataframe


def build_loaders(dataframe, tokenizer, mode):
    transforms = get_transforms(mode=mode)
    dataset = CLIPDataset(
        dataframe["img_id"].values,
        dataframe["ingredients"].values,
        tokenizer=tokenizer,
        transforms=transforms,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=True if mode == "train" else False,
    )
    return dataloader

def build_preprocessed_loaders(img, input_ids, attn_mask, mode):
    transforms = get_transforms(mode=mode)
    dataset = CLIPDatasetPreprocessed(
        img,
        input_ids,
        attn_mask,
        transforms=transforms,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=True if mode == "train" else False,
    )
    return dataloader

In [302]:
class ProjectionHead(nn.Module):
    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)
    
    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x

In [303]:
class TextEncoder(nn.Module):
    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()
        if pretrained:
            self.model = DistilBertModel.from_pretrained('/filer/tmp1/gg676/distilbert-base-uncased')
        else:
            self.model = DistilBertModel(config=DistilBertConfig())
            
        for p in self.model.parameters():
            p.requires_grad = False

        # we are using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state #last_hidden_state[:, self.target_token_idx, :]

In [304]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


In [305]:
"""
class CrossAttention(nn.Module):
    def __init__(self, model_dim=768, n_heads=2, n_layers=2, num_image_patches=197, num_classes=1, dropout=0.1):
        super().__init__()
        self.text_positional = PositionalEncoding(model_dim, dropout=dropout)
        self.sep_token = nn.Parameter(torch.zeros(1, 1, model_dim))
        layers = nn.TransformerEncoderLayer(d_model=model_dim, nhead=n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layers, num_layers=n_layers)
        self.cls_projection = nn.Linear(model_dim, num_classes)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, image_features, text_features, src_key_padding_mask=None):
        #print(image_features.shape)
        batch_size = image_features.shape[0]
        text_features = self.text_positional(text_features)
        sep_token = self.sep_token.expand(batch_size, -1, -1)
        transformer_input = torch.cat((image_features, sep_token, text_features), dim=1)
        if src_key_padding_mask is not None:
            src_key_padding_mask = torch.cat((torch.zeros(image_features.shape[0], 
                                                          image_features.shape[1] + 1).to(CFG.device),
                                             src_key_padding_mask.to(CFG.device)), 1)
            
        transformer_outputs = self.encoder(transformer_input, src_key_padding_mask=src_key_padding_mask)
        projected_output = transformer_outputs[:, 0, :]
        return self.sigmoid(self.cls_projection(projected_output))
"""
"""
class CrossAttention(nn.Module):
    def __init__(self, model_dim=768, n_heads=8, n_layers=4, num_image_patches=197, num_classes=1, dropout=0.3):
        super().__init__()
        self.text_positional = PositionalEncoding(model_dim, dropout=dropout)
        self.sep_token = nn.Parameter(torch.zeros(1, 1, model_dim))
        layers = nn.TransformerEncoderLayer(d_model=model_dim, nhead=n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layers, num_layers=n_layers)
        self.cls_projection = nn.Linear(model_dim, num_classes)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, image_features, text_features, src_key_padding_mask=None):
        #print(image_features.shape)
        batch_size = image_features.shape[0]
        text_features = self.text_positional(text_features)
        sep_token = self.sep_token.expand(batch_size, -1, -1)
        transformer_input = torch.cat((image_features, sep_token, text_features), dim=1)
        if src_key_padding_mask is not None:
            src_key_padding_mask = torch.cat((torch.zeros(image_features.shape[0], 
                                                          image_features.shape[1] + 1).to(CFG.device),
                                             src_key_padding_mask.to(CFG.device)), 1)
            
        transformer_outputs = self.encoder(transformer_input, src_key_padding_mask=src_key_padding_mask)
        projected_output = transformer_outputs[:, 0, :]
        return self.sigmoid(self.cls_projection(projected_output))
        
"""       

'\nclass CrossAttention(nn.Module):\n    def __init__(self, model_dim=768, n_heads=8, n_layers=4, num_image_patches=197, num_classes=1, dropout=0.3):\n        super().__init__()\n        self.text_positional = PositionalEncoding(model_dim, dropout=dropout)\n        self.sep_token = nn.Parameter(torch.zeros(1, 1, model_dim))\n        layers = nn.TransformerEncoderLayer(d_model=model_dim, nhead=n_heads, batch_first=True)\n        self.encoder = nn.TransformerEncoder(layers, num_layers=n_layers)\n        self.cls_projection = nn.Linear(model_dim, num_classes)\n        self.sigmoid = nn.Sigmoid()\n    \n    def forward(self, image_features, text_features, src_key_padding_mask=None):\n        #print(image_features.shape)\n        batch_size = image_features.shape[0]\n        text_features = self.text_positional(text_features)\n        sep_token = self.sep_token.expand(batch_size, -1, -1)\n        transformer_input = torch.cat((image_features, sep_token, text_features), dim=1)\n        if sr

In [306]:
"""
class CrossAttention(nn.Module):
    def __init__(self, model_dim=768, n_heads=8, n_layers=4, num_image_patches=197, num_classes=1, dropout=0.1):
        super().__init__()
        #self.text_positional = PositionalEncoding(model_dim, dropout=dropout)
        self.pos_encoder = PositionalEncoding(model_dim, dropout)
        self.sep_token = nn.Parameter(torch.zeros(1, 1, model_dim))
        layers = nn.TransformerEncoderLayer(d_model=model_dim, nhead=n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layers, num_layers=n_layers)
        self.cls_projection = nn.Linear(model_dim, num_classes)
        self.sigmoid = nn.Sigmoid()
        
    
    
    def forward(self, image_features, text_features, src_key_padding_mask=None):
        #print(image_features.shape)
        src = self.encoder(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        
        
        batch_size = image_features.shape[0]
        #image_features = self.text_positional(image_features)
        #text_features = self.text_positional(text_features)
        
        sep_token = self.sep_token.expand(batch_size, -1, -1)
        src = torch.cat((image_features, sep_token, text_features), dim=1) * math.sqrt(768)
        src = self.pos_encoder(src)
        
        
        #sep_token = self.sep_token.expand(batch_size, -1, -1)
        #transformer_input = torch.cat((image_features, sep_token, text_features), dim=1)
        if src_key_padding_mask is not None:
            src_key_padding_mask = torch.cat((torch.zeros(image_features.shape[0], 
                                                          image_features.shape[1] + 1).to(CFG.device),
                                             src_key_padding_mask.to(CFG.device)), 1)
            
        transformer_outputs = self.encoder(src, src_key_padding_mask=src_key_padding_mask)
        projected_output = transformer_outputs[:, 0, :]
        return self.sigmoid(self.cls_projection(projected_output))
        
"""
        
class CrossAttention(nn.Module):
    def __init__(self, model_dim=768, n_heads=8, n_layers=4, num_image_patches=197, num_classes=1, dropout=0.1):
        super().__init__()
        self.text_positional = PositionalEncoding(model_dim, dropout=dropout)
        self.sep_token = nn.Parameter(torch.zeros(1, 1, model_dim))
        layers = nn.TransformerEncoderLayer(d_model=model_dim, nhead=n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layers, num_layers=n_layers)
        self.cls_projection = nn.Linear(model_dim, num_classes)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, image_features, text_features, src_key_padding_mask=None):
        #print(image_features.shape)
        batch_size = image_features.shape[0]
        text_features = self.text_positional(text_features)
        sep_token = self.sep_token.expand(batch_size, -1, -1)
        transformer_input = torch.cat((image_features, sep_token, text_features), dim=1)
        if src_key_padding_mask is not None:
            src_key_padding_mask = torch.cat((torch.zeros(image_features.shape[0], 
                                                          image_features.shape[1] + 1).to(CFG.device),
                                             src_key_padding_mask.to(CFG.device)), 1)
            
        transformer_outputs = self.encoder(transformer_input, src_key_padding_mask=src_key_padding_mask)
        projected_output = transformer_outputs[:, 0, :]
        return self.sigmoid(self.cls_projection(projected_output))
        
        

In [307]:
class CLIPModel(nn.Module):
    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        super().__init__()
        self.image_encoder = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        #self.image_encoder = self.image_encoder.to(CFG.device)
        self.image_encoder.eval()
        self.text_encoder = TextEncoder()
        #self.text_encoder = DistilBertModel.from_pretrained('/filer/tmp1/gg676/distilbert-base-uncased')
        #self.text_encoder = self.text_encoder.to(CFG.device)
        self.text_encoder.eval()
        #classifier = classify(﻿128﻿,﻿100﻿,﻿17496﻿,﻿12﻿,﻿2﻿)
        
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature
        
        self.ntokens = 709  # size of vocabulary
        self.emsize = 768  # embedding dimension
        self.d_hid = 200  # dimension of the feedforward network model in nn.TransformerEncoder
        self.nlayers = 2  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
        self.nhead = 2  # number of heads in nn.MultiheadAttention
        self.dropout = 0.2  # dropout probability
        #self.transformer_model = TransformerModel(self.ntokens, self.emsize, self.nhead, self.d_hid, self.nlayers, self.dropout).to(CFG.device)
        self.cross_attn = CrossAttention()
        
        

    def forward(self, batch_img, input_ids, attn_mask):
        # Getting Image and Text Features
        with torch.no_grad():
            image_features = self.image_encoder(batch_img, output_hidden_states=True)
            #print(image_features)
            #print(input_ids.shape, attn_mask.shape)
            input_ids, attn_mask = input_ids.squeeze(1), attn_mask.squeeze(1)
            #print("image: ", image_features.hidden_states[-1].shape)
            #print("inp ids: ", input_ids.shape, "--> ", attn_mask.shape)
            text_features = self.text_encoder(
                input_ids=input_ids, attention_mask=attn_mask
        )
        #print(self.text_encoder(input_ids=input_ids, attention_mask=attn_mask, output_hidden_states=True))
        #print("text: ", text_features.shape)
        #print("image shape: ", image_features.shape)
        # Getting Image and Text Embeddings (with same dimension)
        #image_embeddings = self.image_projection(image_features)
        #text_embeddings = self.text_projection(text_features)
        #img_feature, text_feature, attn_mask_list, labels = create_binary_label_data(image_features.hidden_states[-1], text_features, attn_mask)
        #loss, outputs = classifier.forward(concatenated_img_text_pairs, labels)
        #print("labels: ", labels)
        #print(concatenated_img_text_pairs[0])
        #print(img_feature)
        #print(img_feature[0].shape)
        
        
        #src_mask = generate_square_subsequent_mask(CFG.batch_size).to(CFG.device)
        
        output = self.cross_attn(img_feature, text_feature, attn_mask)
        #print(output.shape)
        #print(output)
        
        
        
        
        
        
        """
        # Calculating the Loss
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss =  (images_loss + texts_loss) / 2.0 # shape: (batch_size)
        """
        #print("inside: ", loss.shape)
        return output, labels.unsqueeze(1)#.to(CFG.device)


def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()

In [325]:
tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
valid_loader = build_loaders(df_val, tokenizer, mode="valid")

In [326]:
def prep_image_textvalid_df(valid_loader):
    
    model = CLIPModel().to(CFG.device)
    model.load_state_dict(torch.load(CKPT_PATH, map_location=CFG.device))
    model.eval()
    model.image_encoder.eval()
    model.text_encoder.eval()
    img_all_list = []
    input_id_list = []
    attn_mask_list = []
    text_all_list = []
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            img, input_ids, attn_mask = batch
            input_ids, attn_mask = input_ids.squeeze(1), attn_mask.squeeze(1)
            image_features = model.image_encoder(img.to(CFG.device), output_hidden_states=True).hidden_states[-1]
            text_features = model.text_encoder(input_ids=input_ids.to(CFG.device), attention_mask=attn_mask.to(CFG.device))
            #print(attn_mask)
            img_all_list.extend(image_features)
            text_all_list.extend(text_features)
            #input_id_list.extend(input_ids)
            attn_mask_list.extend(attn_mask)
    #print(len(attn_mask_list))
    return img_all_list, text_all_list, attn_mask_list

In [327]:
img_all_list, text_all_list, attn_mask_list = prep_image_textvalid_df(valid_loader)

100%|██████████| 1/1 [00:06<00:00,  6.02s/it]


In [328]:
len(text_all_list)

10

In [329]:
def create_all_comb_img_text(img_all_list, text_all_list, attn_mask_list):
    img_final_list = []
    text_final_list = []
    attn_mask_final_list = []
    count = 0
    for i in img_all_list:
        #print(i)
        for text, attn in zip(text_all_list, attn_mask_list):
            img_final_list.append(i)
            text_final_list.append(text)
            attn_mask_final_list.append(attn)
        if count % 100 == 0:
            print(count)
        count += 1
    #print(img_final_list[0])
    return img_final_list, text_final_list, attn_mask_final_list

In [330]:
img_final_list,text_final_list, attn_mask_list = create_all_comb_img_text(img_all_list, text_all_list, attn_mask_list)

0


In [331]:
len(img_final_list)

100

In [332]:
#brute prediction so sad
def get_pred_all_comb(img_final_list, input_id_final_list, attn_mask_list):
    #valid_loader = build_preprocessed_loaders(img_final_list, text_final_list, attn_mask_list,  mode="valid")
    model = CLIPModel().to(CFG.device)
    model.load_state_dict(torch.load(CKPT_PATH, map_location=CFG.device))
    model.eval()
    model.image_encoder.eval()
    model.text_encoder.eval()
    model.cross_attn.eval()
    output_matrix = []
    with torch.no_grad():
        #for batch in tqdm(valid_loader):
            #img_features, text_features, attn_mask = batch
        for idx, img_features in enumerate(img_final_list):
            output_for_each_img = []
            for text_features, attn_mask in zip(input_id_final_list, attn_mask_list):
                #input_ids, attn_mask = input_ids.squeeze(1), attn_mask.squeeze(1)
                #image_features = model.image_encoder(img.to(CFG.device), output_hidden_states=True)
                #print("image features shape: ", image_features.hidden_states[-1].shape)
                #text_features = model.text_encoder(input_ids=input_ids.to(CFG.device), attention_mask=attn_mask.to(CFG.device))
                output_for_each_img.extend(model.cross_attn(img_features.unsqueeze(0), text_features.unsqueeze(0), attn_mask.unsqueeze(0)).detach().cpu().numpy()[0])
            print("idx: ", idx)
            output_matrix.append(output_for_each_img)
                
    return output_matrix     

In [333]:
output_matrix = get_pred_all_comb(img_final_list, text_final_list, attn_mask_list)      

idx:  0
idx:  1
idx:  2
idx:  3
idx:  4
idx:  5
idx:  6
idx:  7
idx:  8
idx:  9
idx:  10
idx:  11
idx:  12
idx:  13
idx:  14
idx:  15
idx:  16
idx:  17
idx:  18
idx:  19
idx:  20
idx:  21
idx:  22
idx:  23
idx:  24
idx:  25
idx:  26
idx:  27
idx:  28
idx:  29
idx:  30
idx:  31
idx:  32
idx:  33
idx:  34
idx:  35
idx:  36
idx:  37
idx:  38
idx:  39
idx:  40
idx:  41
idx:  42
idx:  43
idx:  44
idx:  45
idx:  46
idx:  47
idx:  48
idx:  49
idx:  50
idx:  51
idx:  52
idx:  53
idx:  54
idx:  55
idx:  56
idx:  57
idx:  58
idx:  59
idx:  60
idx:  61
idx:  62
idx:  63
idx:  64
idx:  65
idx:  66
idx:  67
idx:  68
idx:  69
idx:  70
idx:  71
idx:  72
idx:  73
idx:  74
idx:  75
idx:  76
idx:  77
idx:  78
idx:  79
idx:  80
idx:  81
idx:  82
idx:  83
idx:  84
idx:  85
idx:  86
idx:  87
idx:  88
idx:  89
idx:  90
idx:  91
idx:  92
idx:  93
idx:  94
idx:  95
idx:  96
idx:  97
idx:  98
idx:  99


In [335]:
def get_image_text_features(valid_df):
    #tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    valid_loader = build_preprocessed_loaders(img_final_list, input_id_final_list, attn_mask_list,  mode="valid")
    
    model = CLIPModel().to(CFG.device)
    model.load_state_dict(torch.load(CKPT_PATH, map_location=CFG.device))
    model.eval()
    model.image_encoder.eval()
    model.text_encoder.eval()
    model.cross_attn.eval()
    output_list = []
    #img_id_list = []
    img_feature_list = []
    text_feature_list = []
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            img, input_ids, attn_mask = batch
            #img_id_list.extend(img_id)
            input_ids, attn_mask = input_ids.squeeze(1), attn_mask.squeeze(1)
            image_features = model.image_encoder(img.to(CFG.device), output_hidden_states=True)
            #print("image features shape: ", image_features.hidden_states[-1].shape)
            text_features = model.text_encoder(input_ids=input_ids.to(CFG.device), attention_mask=attn_mask.to(CFG.device))
            output = model.cross_attn(image_features.hidden_states[-1], text_features, attn_mask)
            output_list.extend(output.detach().cpu().numpy())
            img_feature_list.extend(image_features.hidden_states[-1].detach().cpu().numpy())
            text_feature_list.extend(text_features.detach().cpu().numpy())
            #image_embeddings = model.image_projection(image_features)
            #valid_image_featur.append(image_features)
    return output_list, img_feature_list, text_feature_list


In [None]:
"""
def get_image_text_features(valid_df):
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    valid_loader = build_loaders(valid_df, tokenizer, tokenizer_need=False, mode="valid")
    
    model = CLIPModel().to(CFG.device)
    model.load_state_dict(torch.load(CKPT_PATH, map_location=CFG.device))
    model.eval()
    model.image_encoder.eval()
    model.text_encoder.eval()
    model.cross_attn.eval()
    output_list = []
    img_feature_list = []
    text_feature_list = []
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            img, input_ids, attn_mask = batch
            input_ids, attn_mask = input_ids.squeeze(1), attn_mask.squeeze(1)
            image_features = model.image_encoder(img.to(CFG.device), output_hidden_states=True)
            #print("image features shape: ", image_features.hidden_states[-1].shape)
            text_features = model.text_encoder(input_ids=input_ids.to(CFG.device), attention_mask=attn_mask.to(CFG.device))
            output = model.cross_attn(image_features.hidden_states[-1], text_features, attn_mask)
            output_list.extend(output.detach().cpu().numpy())
            img_feature_list.extend(image_features.hidden_states[-1].detach().cpu().numpy())
            text_feature_list.extend(text_features.detach().cpu().numpy())
            #image_embeddings = model.image_projection(image_features)
            #valid_image_featur.append(image_features)
    return output_list, img_feature_list, text_feature_list
"""

In [None]:
#valid_df = make_train_valid_dfs(mode='valid')
output_list, img_feature_list, text_feature_list = get_image_text_features(df_val)

In [None]:
len(img_feature_list[0])

In [None]:
output_list_values = [value.item() for value in output_list]

In [None]:
output_list_values

In [None]:
output_list

In [336]:
def rank(output_matrix):
    results_dict = {}
    #projection_txt, projection_img = txt_data, img_data
    med_dict = {}
    idxs = range(100)
    
    glob_rank = []
    glob_recall = {1:0.0,5:0.0,10:0.0}
    
    for i in range(10):
        ids = random.sample(range(0,len(output_matrix)-1), 10)
        
        #txt_sample = projection_txt[ids,:]
        #img_sample = projection_img[ids,:]
        
        similarity = np.array(output_matrix)#np.dot(txt_sample.cpu().numpy(), img_sample.T.cpu().numpy())

        med_rank = []
        
        recall = {1:0.0,5:0.0,10:0.0}
        #print(idxs)
        for ii in range(10):
            #print(ii)
            # get a column of similarities
            sim = similarity[ii, :]
            # sort indices in descending order
            sorting = np.argsort(sim)[::-1].tolist()
            # find where the index of the pair sample ended up in the sorting
            pos = sorting.index(ii)  
            if (pos+1) == 1:
                recall[1]+=1
            if (pos+1) <=5:
                recall[5]+=1
            if (pos+1)<=10:
                recall[10]+=1
            # store the position
            med_rank.append(pos+1)
        for i in recall.keys():
            recall[i]=recall[i]/10
        med = np.median(med_rank)
        for i in recall.keys():
            glob_recall[i]+=recall[i]
        glob_rank.append(med)

    for i in glob_recall.keys():
        glob_recall[i] = glob_recall[i]/10
    
    med_dict["mean_median"] = np.average(glob_rank)
    med_dict["recall"] = glob_recall
    med_dict["median_all"] = glob_rank
    print("Result:",med_dict)
    return med_dict

In [337]:
rank(output_matrix)

Result: {'mean_median': 30.5, 'recall': {1: 0.0, 5: 0.0, 10: 0.0}, 'median_all': [30.5, 30.5, 30.5, 30.5, 30.5, 30.5, 30.5, 30.5, 30.5, 30.5]}


{'mean_median': 30.5,
 'recall': {1: 0.0, 5: 0.0, 10: 0.0},
 'median_all': [30.5, 30.5, 30.5, 30.5, 30.5, 30.5, 30.5, 30.5, 30.5, 30.5]}