In [1]:
import sys
sys.path.append('..')
sys.path.append('../..')
sys.path.append('../beit2')
from datamodules import DATAMODULE_REGISTRY
from models import MODEL_REGISTRY
import torch
from pytorch_lightning import LightningModule
import pytorch_lightning as pl
from rich.progress import track
import matplotlib
import matplotlib.pyplot as plt
from transformers import BertTokenizer
# Global plot styling for the figures below: ticks/gridlines are not forced
# beneath plotted content, and the default font size is 12pt.
# (plt.rcParams and matplotlib.rcParams are the same object.)
matplotlib.rcParams.update({"axes.axisbelow": False, "font.size": 12})

[2024-08-24 11:54:12,084] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)


  def beit_base_patch16_224(pretrained=False, **kwargs):
  def beit_base_patch16_384(pretrained=False, **kwargs):
  def beit_large_patch16_224(pretrained=False, **kwargs):
  def beit_large_patch16_384(pretrained=False, **kwargs):
  def beit_large_patch16_512(pretrained=False, **kwargs):
2024-08-24 11:54:13 | INFO | datasets | PyTorch version 2.2.0 available.


In [2]:
pl.seed_everything(seed=0)

Seed set to 0


0

In [None]:
# Tokenizer used only to decode caption token ids back to text in the plots.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')

In [None]:
@torch.no_grad()
def get_embeddings(model, dataloader, device):
    """Encode every batch of ``dataloader`` with ``model`` (gradients disabled).

    Returns five lists with one entry per batch, in iteration order:
    image embeddings (``model.encode_image(...)['x']``), text embeddings
    (``model.encode_text(...)['x']``), raw images (``batch['image_raw']``, left
    on their original device), caption token ids (moved to ``device``) and
    image ids (``batch['id']``, moved to ``device``).
    """
    image_vectors, text_vectors, raw_images, token_ids, image_ids = [], [], [], [], []

    for batch in track(dataloader):
        pixels = batch['image'].to(device)
        tokens = batch['text'].to(device)
        if 'padding_mask' in batch:
            mask = batch['padding_mask'].to(device)
        else:
            mask = None

        # Per the original author's note, the encoders already normalize 'x'.
        image_vectors.append(model.encode_image(image=pixels)['x'])
        text_vectors.append(model.encode_text(text=tokens, padding_mask=mask)['x'])

        image_ids.append(batch['id'].to(device))
        raw_images.append(batch['image_raw'])
        token_ids.append(tokens)

    return image_vectors, text_vectors, raw_images, token_ids, image_ids

In [None]:
def get_scores(img_embeds, text_embeds, images, texts, img_ids):
    """Build the image x caption similarity matrix from the outputs of get_embeddings.

    The same image id can occur in several rows/batches (once per caption), so
    images are de-duplicated by id (first occurrence wins) and ordered by
    ascending id. Captions keep their dataloader order.

    Args:
        img_embeds: list of per-batch image embeddings, each (B, ...).
        text_embeds: list of per-batch text embeddings, each (B, D).
        images: list of per-batch raw images, each presumably (B, C, H, W).
        texts: list of per-batch caption token ids, each (B, L), or an
            already-concatenated tensor.
        img_ids: list of per-batch image ids, each (B,).

    Returns:
        scores: (num_images, num_texts) dot products image x caption.
        images: raw images stacked to (num_images, ...), row-aligned with scores.
        texts: (num_texts, L) token ids, column-aligned with scores.
        iids: (num_images,) image id of each row, on scores' device.
        tiids: (num_texts,) image id of each column.

    Raises:
        ValueError: if no batches were given (e.g. an exhausted dataloader).
    """
    if not img_ids:
        raise ValueError("get_scores received no batches; was the dataloader empty or already exhausted?")

    image_feats = {}  # image id -> embedding of its first occurrence
    raw_images = {}   # image id -> raw image of its first occurrence
    for feats, ids, imgs in zip(img_embeds, img_ids, images):
        # One host transfer per batch instead of one .item() sync per element.
        for i, idx in enumerate(ids.reshape(-1).tolist()):
            if idx not in image_feats:
                image_feats[idx] = feats[i]
                raw_images[idx] = imgs[i]

    tiids = torch.cat(img_ids, dim=0)
    iids = sorted(image_feats)

    img_embeds = torch.cat([image_feats[key].view(1, -1) for key in iids], dim=0)
    # stack, not cat: each entry is one image, so this yields (N, C, H, W)
    # instead of folding the images into the channel dimension.
    images = torch.stack([raw_images[key] for key in iids], dim=0)
    text_embeds = torch.cat(text_embeds, dim=0)
    # Concatenate token ids so callers can index them with column indices of
    # `scores`. Assumes every batch is padded to the same length
    # (num_max_bpe_tokens) — TODO confirm for the datamodule in use.
    if not torch.is_tensor(texts):
        texts = torch.cat(texts, dim=0)

    scores = img_embeds @ text_embeds.t()
    iids = torch.tensor(iids, dtype=torch.long, device=scores.device)

    return scores, images, texts, iids, tiids

In [None]:
def plot_text_retrievals(n_queries, n_retrievals, scores, images, text, iids, tiids):
    """Plot image->text retrievals for randomly chosen image queries.

    Each row shows the query image's COCO id, the query image, and its top
    ``n_retrievals`` captions ranked by ``scores``. A retrieved caption's border
    is green when it belongs to the query image (same COCO id), red otherwise.

    Args:
        n_queries: number of random image queries (rows); capped at the number of images.
        n_retrievals: captions retrieved per query; capped at the number of captions.
        scores: (num_images, num_texts) similarity matrix as returned by get_scores.
        images: (num_images, C, H, W) raw images, row-aligned with ``scores``.
        text: caption token ids, a (num_texts, seq_len) tensor or a list of
            per-batch tensors (concatenated here).
        iids: (num_images,) COCO id of each row of ``scores``.
        tiids: (num_texts,) COCO id of each column of ``scores``.
    """
    import textwrap

    # Work on CPU so every tensor can be indexed with every index tensor
    # (a CPU tensor cannot be indexed with CUDA indices).
    scores, iids, tiids = scores.cpu(), iids.cpu(), tiids.cpu()
    if not torch.is_tensor(text):
        text = torch.cat(text, dim=0)
    text = text.cpu()
    images = images.cpu()

    num_rows = min(n_queries, scores.shape[0])
    n_retrievals = min(n_retrievals, scores.shape[1])
    num_cols = n_retrievals + 2  # COCO id, query, then one column per retrieval

    query_indices = torch.randperm(scores.shape[0])[:num_rows]
    # Rank only the selected queries rather than the whole matrix.
    topk_retrievals = scores[query_indices].topk(n_retrievals, dim=1).indices
    query_coco_ids = iids[query_indices]
    query_samples = images[query_indices]

    # squeeze=False keeps `axes` 2-D even when there is a single row.
    _, axes = plt.subplots(num_rows, num_cols, figsize=(18, 3 * num_rows), squeeze=False)

    axes[0, 0].set_title("COCO ID")
    axes[0, 1].set_title("Query")
    for j in range(n_retrievals):
        axes[0, j + 2].set_title(f"Retrieval {j + 1}")

    for i, (coco_id, sample) in enumerate(zip(query_coco_ids, query_samples)):
        axes[i, 0].text(0.5, 0.5, str(coco_id.item()), ha='center', va='center', fontsize=12)
        axes[i, 0].axis('off')
        axes[i, 1].imshow(sample.permute(1, 2, 0))
        axes[i, 1].axis('off')

    for i, retrieved in enumerate(topk_retrievals):
        captions = tokenizer.batch_decode(text[retrieved], skip_special_tokens=True)
        retrieved_coco_ids = tiids[retrieved]
        for j in range(n_retrievals):
            ax = axes[i, j + 2]  # offset past the COCO id and query columns
            # Wrap so long captions stay inside their own cell.
            ax.text(0.5, 0.5, textwrap.fill(captions[j], width=25), ha='center', va='center', fontsize=10)
            # axis('off') would also hide the spines that carry the color code,
            # so only the ticks are removed.
            ax.set_xticks([])
            ax.set_yticks([])

            is_match = retrieved_coco_ids[j].item() == query_coco_ids[i].item()
            color = 'green' if is_match else 'red'
            for spine in ax.spines.values():
                spine.set_edgecolor(color)
                spine.set_linewidth(2)

    plt.subplots_adjust(wspace=0.1, hspace=0.1)

    plt.show()

In [None]:
def plot_image_retrievals(n_queries, n_retrievals, scores, images, text, iids, tiids):
    """Plot text->image retrievals for randomly chosen caption queries.

    Each row shows the query caption's COCO id, the decoded caption, and the top
    ``n_retrievals`` images ranked by ``scores``. A retrieved image's border is
    green when it is the caption's own image (same COCO id), red otherwise.

    Args:
        n_queries: number of random caption queries (rows); capped at the number of captions.
        n_retrievals: images retrieved per query; capped at the number of images.
        scores: (num_images, num_texts) similarity matrix as returned by get_scores;
            transposed here so that captions are the queries.
        images: (num_images, C, H, W) raw images, row-aligned with ``scores``.
        text: caption token ids, a (num_texts, seq_len) tensor or a list of
            per-batch tensors (concatenated here).
        iids: (num_images,) COCO id of each row of ``scores``.
        tiids: (num_texts,) COCO id of each column of ``scores``.
    """
    import textwrap

    # Work on CPU: `images` lives on CPU and cannot be indexed with the CUDA
    # indices that `scores.topk` would otherwise produce.
    scores = scores.t().cpu()  # rows are now captions, columns images
    iids, tiids = iids.cpu(), tiids.cpu()
    if not torch.is_tensor(text):
        text = torch.cat(text, dim=0)
    text = text.cpu()
    images = images.cpu()

    num_rows = min(n_queries, scores.shape[0])
    n_retrievals = min(n_retrievals, scores.shape[1])
    num_cols = n_retrievals + 2  # COCO id, query, then one column per retrieval

    query_indices = torch.randperm(scores.shape[0])[:num_rows]
    # Rank only the selected queries rather than the whole matrix.
    topk_retrievals = scores[query_indices].topk(n_retrievals, dim=1).indices
    query_coco_ids = tiids[query_indices]
    query_captions = tokenizer.batch_decode(text[query_indices], skip_special_tokens=True)

    # squeeze=False keeps `axes` 2-D even when there is a single row.
    _, axes = plt.subplots(num_rows, num_cols, figsize=(18, 3 * num_rows), squeeze=False)

    axes[0, 0].set_title("COCO ID")
    axes[0, 1].set_title("Query")
    for j in range(n_retrievals):
        axes[0, j + 2].set_title(f"Retrieval {j + 1}")

    for i, (coco_id, caption) in enumerate(zip(query_coco_ids, query_captions)):
        axes[i, 0].text(0.5, 0.5, str(coco_id.item()), ha='center', va='center', fontsize=12)
        axes[i, 0].axis('off')
        # Wrap so long captions stay inside their own cell.
        axes[i, 1].text(0.5, 0.5, textwrap.fill(caption, width=25), ha='center', va='center', fontsize=10)
        axes[i, 1].axis('off')

    for i, retrieved in enumerate(topk_retrievals):
        retrieved_coco_ids = iids[retrieved]
        for j in range(n_retrievals):
            ax = axes[i, j + 2]  # offset past the COCO id and query columns
            ax.imshow(images[int(retrieved[j])].permute(1, 2, 0))
            # axis('off') would also hide the spines that carry the color code,
            # so only the ticks are removed.
            ax.set_xticks([])
            ax.set_yticks([])

            is_match = retrieved_coco_ids[j].item() == query_coco_ids[i].item()
            color = 'green' if is_match else 'red'
            for spine in ax.spines.values():
                spine.set_edgecolor(color)
                spine.set_linewidth(2)

    plt.subplots_adjust(wspace=0.1, hspace=0.1)

    plt.show()

In [4]:
# Checkpoint to evaluate, and the MODEL_REGISTRY key whose 'module' class loads it.
# NOTE(review): MODEL_NAME is empty, so the model-loading cell fails until it is set.
MODEL_PATH, MODEL_NAME = "/workspace/models/cluster.pt", ""

In [5]:
# Constructor arguments for the 'coco_captions' datamodule.
# crop_scale=[1.0, 1.0] presumably means no random cropping — confirm against the datamodule.
coco_dm_kwargs = dict(
    data_path='/workspace',
    num_max_bpe_tokens=64,
    color_jitter=None,
    beit_transforms=False,
    crop_scale=[1.0, 1.0],
    batch_size=256,
    num_workers=8,
    shuffle=True,
    drop_last=False,
)

In [6]:
pl.seed_everything(seed=42)
coco_dm_cls = DATAMODULE_REGISTRY['coco_captions']
coco_dm = coco_dm_cls(**coco_dm_kwargs)

Seed set to 42


In [7]:
# Prepare the data under data_path (the log below shows it already exists), then
# build the 'fit' stage, which per the log loads the train and val caption files.
# NOTE(review): the evaluation below uses coco_dm.test_dataloader() although only
# the 'fit' stage is set up — confirm the test loader is served by this stage.
coco_dm.prepare_data()
coco_dm.setup('fit')

2024-08-24 11:54:20 | INFO | datasets_.base_datasets | [COCOCaptions]: Data already exists under: /workspace/coco
2024-08-24 11:54:20 | INFO | datasets_.base_datasets | [COCOCaptions]: Data already exists under: /workspace/coco
2024-08-24 11:54:20 | INFO | datasets_.base_datasets | [COCOCaptions]: Data already exists under: /workspace/coco
2024-08-24 11:54:24 | INFO | datasets_.base_datasets | [COCOCaptions]: Load 566747 image-text pairs from /workspace/coco/coco_captioning.train.jsonl. 
2024-08-24 11:54:24 | INFO | datasets_.base_datasets | [COCOCaptions]: Load 25010 image-text pairs from /workspace/coco/coco_captioning.val.jsonl. 


In [8]:
# Keep the DataLoader itself rather than a one-shot iterator: an iterator is
# exhausted after one pass, so re-running the embedding cell would silently get
# no batches and fail.
dl = coco_dm.test_dataloader()

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fail with an actionable message instead of a bare KeyError when MODEL_NAME is
# empty or not registered.
try:
    registry_entry = MODEL_REGISTRY[MODEL_NAME]
except KeyError as err:
    raise KeyError(
        f"MODEL_NAME={MODEL_NAME!r} is not a key of MODEL_REGISTRY; "
        "set MODEL_NAME in the config cell to the model stored at MODEL_PATH"
    ) from err
model_cls = registry_entry['module']  # a LightningModule subclass (not an instance)

# The loaded LightningModule wraps the network in `.model`; that is the object
# whose encode_image / encode_text are used for retrieval.
model = model_cls.load_from_checkpoint(MODEL_PATH).model
model = model.to(device)
model.requires_grad_(False)
model = model.eval()  # assignment keeps the cell from printing the full module repr

KeyError: 'vq_image'

In [None]:
# Encode the whole evaluation loader and turn the per-batch outputs into one
# image x caption similarity matrix. The calls are nested on purpose: the
# per-batch lists from get_embeddings (which include every raw image) are never
# bound to a name, so they can be released once get_scores returns.
scores, images, texts, iids, tiids = get_scores(*get_embeddings(model, dl, device))

In [None]:

N_SELECTED_IMAGES = 1000  # size of the random image subset to evaluate on

# mask[i, t] is True when caption t belongs to image i.
mask = iids.unsqueeze(1) == tiids.unsqueeze(0)
selected_image_ids = torch.randperm(mask.shape[0])[:N_SELECTED_IMAGES]

selected_mask = mask[selected_image_ids]
if not selected_mask.any(dim=1).all():
    raise ValueError("No matching indices found")

# Pick one matching caption per selected image uniformly at random: give every
# matching caption a uniform score in [0, 1), every other caption -1, and take
# the argmax — vectorized instead of a per-row loop with host syncs.
random_scores = torch.rand(selected_mask.shape, device=selected_mask.device).masked_fill(~selected_mask, -1.0)
selected_indices = random_scores.argmax(dim=1).cpu()

# Accept either a concatenated token-id tensor or a list of per-batch tensors.
all_texts = texts if torch.is_tensor(texts) else torch.cat(texts, dim=0)
selected_texts = all_texts[selected_indices]
selected_images = images[selected_image_ids]
iids_ = tiids_ = torch.arange(selected_images.shape[0])