In [1]:
import torch
import math
import random
import numpy as np

from datasets import Recipe1MDataset
from time import time
from torch import nn
from models import TextEncoder, ImageEncoder
from helper import calculate_metrics
from torch.utils.data import DataLoader
from transformers import BertTokenizer
from tqdm import tqdm
from pprint import pprint

In [5]:
def get_transformer_input(image_features, text_embedding, input_attention_mask):
    num_negative_to_positive_sample_ratio = 2

    image_features = image_features.clone()
    text_embedding = text_embedding.clone()
    input_attention_mask = input_attention_mask.clone()

    input_batch_size = image_features.shape[0]
    output_batch_size = (num_negative_to_positive_sample_ratio + 1) * input_batch_size
    ground_truths = torch.zeros(output_batch_size)
    ground_truths[:input_batch_size] = 1

    final_image_features = torch.zeros(output_batch_size, *image_features.shape[1:])
    final_text_embeddings = torch.zeros(output_batch_size, *text_embedding.shape[1:])
    output_attention_mask = torch.zeros(output_batch_size, *input_attention_mask.shape[1:])

    final_image_features[:input_batch_size] = image_features.clone()
    final_text_embeddings[:input_batch_size] = text_embedding.clone()

    for run_num in range(num_negative_to_positive_sample_ratio):
        a = torch.randperm(input_batch_size)
        b = torch.zeros(input_batch_size).to(dtype=torch.int64)
        for ind in range(input_batch_size):
            c = random.randint(0, input_batch_size - 1)
            while c == a[ind]:
                c = random.randint(0, input_batch_size - 1)
            b[ind] = c

#         print(a)
#         print(b)
        
        final_image_features[(1 + run_num) * input_batch_size : (2 + run_num) * input_batch_size] = image_features[a].clone()
        final_text_embeddings[(1 + run_num) * input_batch_size : (2 + run_num) * input_batch_size] = text_embedding[b].clone()
        output_attention_mask[(1 + run_num) * input_batch_size : (2 + run_num) * input_batch_size] = \
            input_attention_mask[b].clone()

    return final_image_features.clone(), final_text_embeddings.clone(), output_attention_mask.clone(), ground_truths.clone()


def save_model(model, fpath):
    torch.save(model, fpath)


def freeze_params(model):
    for param in model.parameters():
        param.requires_grad = False


def compute_ranks(sims):
    ranks = []
    preds = []
    # loop through the N similarities for images
    for ii in range(sims.shape[0]):
        # get a column of similarities for image ii
        sim = sims[ii, :]
        # sort indices in descending order
        sorting = np.argsort(sim)[::-1].tolist()
        # find where the index of the pair sample ended up in the sorting
        pos = sorting.index(ii)
        ranks.append(pos + 1.0)
        preds.append(sorting[0])
    # pdb.set_trace()
    return np.asarray(ranks), preds


def rank(rcps: list, imgs: list, attention_masks: list, model=None, retrieved_type='recipe', retrieved_range=100,
         verbose=False, device='cuda'):
    t1 = time()
    N = retrieved_range
    data_size = len(imgs)
    glob_rank = []
    glob_recall = {1: 0.0, 5: 0.0, 10: 0.0}
    softmax = nn.Softmax(dim=-1)
    # average over 10 sets
    for i in range(2):
        ids_sub = np.random.choice(data_size, N, replace=False)
        # imgs_sub = imgs[ids_sub, :]
        # rcps_sub = rcps[ids_sub, :]
        imgs_sub = [imgs[ind] for ind in ids_sub]
        rcps_sub = [rcps[ind] for ind in ids_sub]
        attention_masks_sub = [attention_masks[ind] for ind in ids_sub]
        probs = np.zeros((N, N))
        for x in tqdm(range(N)):
            for y in range(N):
                # if retrieved_type == 'recipe':
                #     probs[x] = model(imgs_sub[x].repeat(N, 1, 1), rcps_sub)[:, 1]
                # else:
                #     probs[x] = model(imgs_sub, rcps_sub[x].repeat(N, 1, 1))[:, 1]
                try:
                    if retrieved_type == 'recipe':
                        probs[x][y] = softmax(model(imgs_sub[x].unsqueeze(0).to(device), rcps_sub[y].unsqueeze(0).to(device),
                                                    ~attention_masks_sub[y].bool().unsqueeze(0).to(device)))[0, 1]
                    else:
                        probs[x][y] = softmax(model(imgs_sub[y].unsqueeze(0).to(device), rcps_sub[x].unsqueeze(0).to(device),
                                                    ~attention_masks_sub[y].bool().unsqueeze(0).to(device)))[0, 1]
                except RuntimeError as e:
                    print(imgs_sub[x].unsqueeze(0).shape, rcps_sub[y].unsqueeze(0).shape, attention_masks_sub[y].unsqueeze(0).shape)
                    print(attention_masks_sub)
                    print(ids_sub, x, y)
                    raise(RuntimeError(str(e)))
        # loop through the N similarities for images
        ranks, _ = compute_ranks(probs)

        recall = {1: 0.0, 5: 0.0, 10: 0.0}
        for ii in recall.keys():
            recall[ii] = (ranks <= ii).sum() / ranks.shape[0]
        med = int(np.median(ranks))
        for ii in recall.keys():
            glob_recall[ii] += recall[ii]
        glob_rank.append(med)

    for i in glob_recall.keys():
        glob_recall[i] = glob_recall[i] / 10

    medR = np.mean(glob_rank)
    medR_std = np.std(glob_rank)
    t2 = time()
    if verbose:
        print(f'=>retrieved_range={retrieved_range}, MedR={medR:.4f}({medR_std:.4f}), time={t2 - t1:.4f}s')
        print(f'Global recall: 1: {glob_recall[1]:.4f}, 5: {glob_recall[5]:.4f}, 10: {glob_recall[10]:.4f}')
    return medR, medR_std, glob_recall

In [4]:
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class CrossModalAttention(nn.Module):
    def __init__(self, model_dim=768, n_heads=2, n_layers=2, num_image_patches=197, num_classes=2, drop_rate=0.1):
        super().__init__()
        self.text_pos_embed = SinusoidalPositionalEncoding(model_dim, dropout=drop_rate)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, model_dim))
        self.sep_token = nn.Parameter(torch.zeros(1, 1, model_dim))
        self.image_pos_embed = nn.Parameter(torch.zeros(1, num_image_patches + 1, model_dim))
        self.image_pos_drop = nn.Dropout(p=drop_rate)
        layers = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=n_heads,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layers, num_layers=n_layers)
        self.cls_projection = nn.Linear(model_dim, num_classes)
        
    def forward(self, image_features, text_features, src_key_padding_mask=None):
        batch_size = image_features.shape[0]
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        image_features = torch.cat((cls_token, image_features), dim=1)
        image_features = image_features + self.image_pos_embed
        image_features = self.image_pos_drop(image_features)
        
        text_features = self.text_pos_embed(text_features)
        
        sep_token = self.sep_token.expand(batch_size, -1, -1)
        transformer_input = torch.cat((image_features, sep_token, text_features), dim=1)
        if src_key_padding_mask is not None:
            src_key_padding_mask = torch.cat((torch.zeros(batch_size, image_features.shape[1] + 1).to(transformer_input.device), src_key_padding_mask), dim=1)
        transformer_input = transformer_input.transpose(1, 0)
        src_key_padding_mask = src_key_padding_mask.transpose(1, 0)
        transformer_outputs = self.encoder(transformer_input, src_key_padding_mask=src_key_padding_mask)
        cls_outputs = transformer_outputs[0, :, :]
        return self.cls_projection(cls_outputs)
        # return transformer_outputs

In [6]:
# Change paths here.
saved_model_path = '/common/home/as3503/as3503/courses/cs536/final_project/final_project/saved_models/model.pt'
transformer_model_path = '/common/home/as3503/as3503/courses/cs536/final_project/final_project/saved_models/1b1huuko/model_train_encoders_False_epoch_0.pt'

saved_weights = torch.load(saved_model_path, map_location='cpu')
# transformer_weights = torch.load(transformer_model_path, map_location='cpu')

device = 'cuda:7'
text_encoder = TextEncoder(2, 2)
text_encoder.load_state_dict(saved_weights['txt_encoder'])
text_encoder = text_encoder.to(device)

image_encoder = ImageEncoder()
image_encoder.load_state_dict(saved_weights['img_encoder'])
image_encoder = image_encoder.to(device)

cm_transformer = CrossModalAttention().to(device)
# cm_transformer.load_state_dict(transformer_weights['cm_transformer'])
cm_transformer = cm_transformer.to(device)

val_dataset = Recipe1MDataset(part='val')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

batch_size = 8
val_loader = DataLoader(val_dataset, batch_size=batch_size)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [52]:
cm_transformer = CrossModalAttention()

In [93]:
cm_transformer.load_state_dict(torch.load('saved_models/k89llnjz/model_train_encoders_False_num_its_1000.pt', map_location='cpu')['cm_transformer'])

<All keys matched successfully>

In [94]:
# text_encoder = text_encoder.to(device)
# image_encoder = image_encoder.to(device)
# cm_transformer = cm_transformer.to(device)

In [95]:
cm_transformer.eval()
with torch.no_grad():
    for text, image in val_loader:
        text_inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt").to(device)
        text_outputs = text_encoder(**text_inputs)
        image_outputs = image_encoder(image.to(device))
        # outputs = []
        # for i in range(8):
        #     for j in range(8):
        #         output = cm_transformer(image_outputs[i].unsqueeze(0), text_outputs[j].unsqueeze(0), ~text_inputs.attention_mask[j].bool().unsqueeze(0))
                # outputs.append(torch.softmax(output, dim=1).detach().cpu())
        transformer_image_inputs, transformer_text_inputs, output_attention_mask, ground_truth = \
            get_transformer_input(image_outputs, text_outputs, text_inputs.attention_mask)
        text_padding_mask = ~output_attention_mask.bool()
        indices = torch.randperm(transformer_image_inputs.size()[0])
        outputs = cm_transformer(transformer_image_inputs[indices].to(device), transformer_text_inputs[indices].to(device), text_padding_mask[indices].to(device))
        # print(transformer_image_inputs)
        # print(transformer_text_inputs)
        # output1 = cm_transformer(transformer_image_inputs[1, :, :].unsqueeze(0).to(device), transformer_text_inputs[6, :, :].unsqueeze(0).to(device), text_padding_mask[6, :].unsqueeze(0).to(device))
        # output2 = cm_transformer(transformer_image_inputs[10].unsqueeze(0).to(device), transformer_text_inputs[10].unsqueeze(0).to(device), text_padding_mask[10].unsqueeze(0).to(device))
        # outputs = cm_transformer(image_outputs[6].unsqueeze(0), text_outputs[1].unsqueeze(0), ~text_inputs.attention_mask[1].bool().unsqueeze(0))
        # print(outputs)
        print(torch.softmax(outputs, dim=1))
        print(ground_truth)
        # pprint(output1)
        # pprint(output2)
        break

tensor([[9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06]], device='cuda:7')
tensor([1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.])


In [101]:
a = torch.randn(8, 197, 768)
b = torch.randn(8, 200, 768)
# c = torch.randn(8, 200).bool()
c = torch.zeros(8, 200)
with torch.no_grad():
    temp_outputs = cm_transformer(a.to(device), b.to(device), c.to(device))
    print(torch.softmax(outputs, dim=1))

tensor([[9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06],
        [9.9999e-01, 9.6588e-06]], device='cuda:7')


In [8]:
rank(text_embeddings, image_features, attention_masks, model=cm_transformer, device=device, verbose=True)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [02:06<00:00,  1.27s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [02:09<00:00,  1.29s/it]

=>retrieved_range=100, MedR=56.5000(4.5000), time=255.9803s
Global recall: 1: 0.0020, 5: 0.0170, 10: 0.0250





(56.5, 4.5, {1: 0.002, 5: 0.016999999999999998, 10: 0.025})

In [None]:
print('Calculating Metrics')
image_encoder.eval()
text_encoder.eval()
cm_transformer.eval()

text_embeddings = list()
image_features = list()
attention_masks = list()
with torch.no_grad():
    for text, image in tqdm(val_loader):
        text_inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt").to(device)
        text_outputs = text_encoder(**text_inputs)
        image_outputs = image_encoder(image.to(device))

        for text_output, image_feature, attention_mask in zip(text_outputs, image_outputs, text_inputs.attention_mask):
            text_embeddings.append(text_output.cpu())
            image_features.append(image_feature.cpu())
            attention_masks.append(attention_mask.cpu())

Calculating Metrics


 22%|███████████▌                                        | 2695/12148 [16:56<59:31,  2.65it/s]