In [None]:
import torch
from rich.progress import track
from bpe_encoder import get_bpe_encoder
import PIL
from torchvision.transforms import v2 as transforms
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams.update({'font.size': 12})
import textwrap
import os
from datamodules import COCOCaptionsDataModule
from models.SHRe import SHRePreTrainingLightningModule
import logging

logger = logging.getLogger(__name__)

In [None]:
# Keyword arguments for COCOCaptionsDataModule: deterministic evaluation setup
# (no shuffling, no dropped samples, no color jitter, full-image crops).
coco_args = dict(
    data_path="/workspace",
    num_max_bpe_tokens=64,
    task="captioning",
    color_jitter=None,
    beit_transforms=False,
    crop_scale=[1.0, 1.0],
    batch_size=256,
    num_workers=4,
    shuffle=False,
    drop_last=False,
)

In [None]:
# Sub-directory of /workspace/models holding the run to evaluate; with the
# empty string the checkpoint is read directly from /workspace/models.
version = ""
model_path = os.path.join(
    "/workspace/models",
    version,  # an empty component adds no extra path separator
    'fp32_last.ckpt',
)

In [None]:
# BPE tokenizer; used later to decode caption token ids back into strings
# for the retrieval visualisation.
encoder = get_bpe_encoder(
    '/workspace',
)

In [None]:
@torch.no_grad()
def zero_shot_retrieval(model, dataloader, device):
    """Embed every sample of ``dataloader`` and score all images against all captions.

    Args:
        model: Object exposing ``encode_image(image=...)`` and
            ``encode_text(text=..., padding_mask=...)``, each returning a dict whose
            ``'x'`` entry is the embedding. It is moved to ``device`` and switched to
            eval mode before encoding.
        dataloader: Iterable of batch dicts with ``'image'``, ``'text'`` and ``'id'``
            entries and an optional ``'padding_mask'`` entry.
        device: Device on which the forward passes run.

    Returns:
        scores: ``img_embeds @ text_embeds.t()`` (on ``device``); entry ``[i, j]``
            scores the image of sample ``i`` against the caption of sample ``j``.
        coco_to_id: Maps each id in ``batch['id']`` to the list of *global* sample
            indices (positions in the concatenated outputs) carrying that id.
        imgs: All ``batch['image']`` tensors concatenated along dim 0 (not moved).
        texts: All ``batch['text']`` tensors concatenated along dim 0 (not moved).

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    # Inputs are sent to `device`, so the weights must be there too; eval mode
    # disables train-only behaviour (e.g. dropout) during inference.
    model = model.to(device)
    model.eval()

    img_embeds, text_embeds = [], []
    imgs, texts = [], []
    coco_to_id = {}
    offset = 0  # global index of the first sample of the current batch

    # Iterate the loader itself (not enumerate(...)) so rich can read its length.
    for batch in track(dataloader, description="Encoding"):
        image = batch['image'].to(device)
        text = batch['text'].to(device)
        padding_mask = batch['padding_mask'].to(device) if 'padding_mask' in batch else None
        # encoding also normalizes the output
        img_emb = model.encode_image(image=image)['x']
        text_emb = model.encode_text(text=text, padding_mask=padding_mask)['x']
        img_embeds.append(img_emb)
        text_embeds.append(text_emb)
        imgs.append(batch['image'])
        texts.append(batch['text'])
        # Store offset + in-batch position so the indices address rows of the
        # concatenated tensors below, not positions inside a single batch.
        for local_idx, img_id in enumerate(batch['id']):
            coco_to_id.setdefault(int(img_id), []).append(offset + local_idx)
        offset += img_emb.shape[0]

    if not img_embeds:
        raise ValueError("dataloader yielded no batches")

    img_embeds = torch.cat(img_embeds, dim=0)
    text_embeds = torch.cat(text_embeds, dim=0)
    imgs = torch.cat(imgs, dim=0)
    texts = torch.cat(texts, dim=0)
    scores = img_embeds @ text_embeds.t()
    return scores, coco_to_id, imgs, texts

In [None]:
def retrieve_for_candidate(scores:torch.Tensor, idx_mapper, coco_id, from_modality, n=5):
    """Return the indices of the ``n`` best matches for the query ``coco_id``.

    ``scores[i, j]`` is the similarity of the image of sample ``i`` with the caption
    of sample ``j`` (rows = images, columns = texts).

    Args:
        scores: 2-D image-by-text similarity matrix.
        idx_mapper: Maps a COCO id to the list of sample indices carrying that id.
        coco_id: Query id; its first sample index is used as the query.
        from_modality: ``'image'`` retrieves captions for the query image; any other
            value retrieves images for the query caption.
        n: Number of matches requested (clamped to the number of candidates).

    Returns:
        1-D index tensor, sorted by decreasing similarity: indices into the text axis
        for an image query, into the image axis for a text query.
    """
    query_idx = idx_mapper[coco_id][0]
    if from_modality == 'image':
        # Row of the matrix: the query image against every caption.
        sims = scores[query_idx]
        return sims.topk(min(n, sims.numel()), largest=True, sorted=True).indices
    # Text query: compare the caption against one representative row per id, so an
    # image stored in several samples (e.g. once per caption) cannot fill every slot.
    candidates = torch.tensor(
        [rows[0] for rows in idx_mapper.values()], dtype=torch.long, device=scores.device
    )
    sims = scores[candidates, query_idx]
    top = sims.topk(min(n, sims.numel()), largest=True, sorted=True).indices
    return candidates[top]

In [None]:
def plot_matches(data, from_modality, n, figsize=None):
    """Draw one row per query: the query in column 0, then up to ``n`` retrieved items.

    Args:
        data: Mapping ``coco_id -> (query, retrieved)``. For ``from_modality == 'image'``
            the query is a channel-first image tensor and ``retrieved`` a list of
            strings; otherwise the query is text and ``retrieved`` a list of
            channel-first image tensors.
        from_modality: ``'image'``, or any other value for text queries.
        n: Number of retrieved columns to draw.
        figsize: Optional figure size; defaults to 3 inches per grid cell, i.e.
            ``(18, 15)`` for 5 queries with ``n=5``.
    """
    rows = len(data)
    if rows == 0:
        # plt.subplots rejects a zero-row grid.
        logger.warning("plot_matches: nothing to plot (no queries given)")
        return
    cols = n + 1
    if figsize is None:
        figsize = (3 * cols, 3 * rows)

    # squeeze=False keeps `axes` 2-D even for a single row or a single column.
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)

    axes[0, 0].set_title("Query")
    for j in range(1, cols):
        axes[0, j].set_title(f"Retrieval #{j}")

    for i, key in enumerate(data.keys()):
        query, retrieved = data[key][0], data[key][1]
        ax_query = axes[i, 0]

        ax_query.text(-0.2, 0.5, f"COCO #{key}", transform=ax_query.transAxes, va='center', ha='right')

        # NOTE(review): image tensors are shown as-is; if they are normalized they may
        # need de-normalizing for faithful colours — confirm against the datamodule.
        if from_modality == 'image':
            ax_query.imshow(query.permute(1, 2, 0))
        else:
            ax_query.text(0.5, 0.5, textwrap.fill(str(query), 15), fontsize=18, ha='center', va='center')
        ax_query.axis('off')

        for j in range(1, cols):
            ax = axes[i, j]
            # Fewer candidates than requested: leave the remaining cells blank.
            if j - 1 < len(retrieved):
                item = retrieved[j - 1]
                if from_modality == 'image':
                    ax.text(0.5, 0.5, textwrap.fill(item, 15), fontsize=18, ha='center', va='center')
                else:
                    ax.imshow(item.permute(1, 2, 0))
            ax.axis('off')

    fig.subplots_adjust(wspace=0.1, hspace=0.1)

    if cols > 1:
        # Vertical separator between the query column and the retrieved columns.
        line_x_position = (axes[0, 0].get_position().x1 + axes[0, 1].get_position().x0) / 2
        line = matplotlib.lines.Line2D([line_x_position, line_x_position], [0, 1], transform=fig.transFigure, color='black', linewidth=1)
        fig.add_artist(line)

    plt.show()

In [None]:
def visualize_retrievals(scores:torch.Tensor, idx_mapper, coco_ids, imgs, texts, from_modality, n=5):
    """Retrieve the top-``n`` matches for every id in ``coco_ids`` and plot them.

    For ``from_modality == 'image'`` each query is the image of the id's first sample
    and the matches are captions decoded with the module-level ``encoder``; otherwise
    the query is the id's first caption (decoded) and the matches are images.

    Args:
        scores: Image-by-text similarity matrix (rows = images, columns = texts).
        idx_mapper: Maps a COCO id to the list of sample indices carrying that id.
        coco_ids: Ids to use as queries, one plot row each.
        imgs: Stacked image tensors, indexed by sample index along dim 0.
        texts: Stacked caption token tensors, indexed by sample index along dim 0.
        from_modality: ``'image'``, or any other value for text queries.
        n: Number of matches per query.
    """
    result_dict = dict()
    for coco_id in coco_ids:
        # `scores` may live on the GPU while imgs/texts stay where the loader left
        # them; CPU indices can index tensors on any device.
        indices = retrieve_for_candidate(scores, idx_mapper, coco_id, from_modality, n).cpu()
        query_idx = idx_mapper[coco_id][0]
        if from_modality == 'image':
            # A single channel-first image (not the stack of all samples with this id).
            query = imgs[query_idx]
            matches = [encoder.decode(sample) for sample in texts[indices]]
        else:
            # NOTE(review): decoded captions may include padding/special tokens,
            # depending on encoder.decode — confirm.
            query = encoder.decode(texts[query_idx])
            matches = list(imgs[indices])
        result_dict[coco_id] = query, matches

    plot_matches(result_dict, from_modality, n)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Lazy %-formatting: the message is only built if INFO is enabled.
logger.info("Using device: %s", device)

# Build the COCO captions test split with the evaluation settings in `coco_args`.
dm = COCOCaptionsDataModule(**coco_args)
dm.prepare_data()
dm.setup('test')
dataloaders = dm.test_dataloader()

In [None]:
# Load the checkpoint straight onto `device`, keep only the wrapped model, and
# switch it to eval mode: inputs are moved to `device` during retrieval, so the
# weights must be there too, and inference should not use train-only behaviour.
model = (
    SHRePreTrainingLightningModule
    .load_from_checkpoint(model_path, map_location=device)
    .model
    .to(device)
    .eval()
)

In [None]:
# Encode the whole test split once; all later cells reuse these results.
scores, coco_to_id, imgs, texts = zero_shot_retrieval(
    model=model,
    dataloader=dataloaders,
    device=device,
)

In [None]:
# COCO ids to use as queries; leave empty to fall back to the first
# N_DEFAULT_QUERIES ids of the test split so Restart & Run All still plots.
coco_ids = []
# 'image': retrieve captions for each image; 'text': retrieve images for each caption.
from_modality = 'image'
n_to_retrieve = 5
N_DEFAULT_QUERIES = 5

if not coco_ids:
    coco_ids = list(coco_to_id)[:N_DEFAULT_QUERIES]
assert from_modality in ('image', 'text'), f"from_modality must be 'image' or 'text', got {from_modality!r}"

In [None]:
# Plot each query next to its top `n_to_retrieve` matches.
visualize_retrievals(
    scores=scores,
    idx_mapper=coco_to_id,
    coco_ids=coco_ids,
    imgs=imgs,
    texts=texts,
    from_modality=from_modality,
    n=n_to_retrieve,
)