In [None]:
import sys
sys.path.append('..')
sys.path.append('../..')
sys.path.append('../beit2')
from datamodules import DATAMODULE_REGISTRY
from models import MODEL_REGISTRY
import torch
from pytorch_lightning import LightningModule
import torch.nn as nn
import pytorch_lightning as pl
from data.imagenet_zeroshot_data import imagenet_classnames

import matplotlib.pyplot as plt
plt.rcParams.update({"axes.axisbelow": False})  # draw ticks/gridlines above bars and patches
import numpy as np

from bpe_encoder import get_bpe_encoder

In [None]:
def to_display_image(img):
    """Convert an image array/tensor into a float array that ``imshow`` accepts.

    A channel-first ``(C, H, W)`` input (C in 1/3/4) is moved to channel-last
    ``(H, W, C)``, and a single channel is squeezed to ``(H, W)``. Values outside
    [0, 1] (e.g. mean/std-normalized pixels) are min-max rescaled so ``imshow``
    does not clip them.
    """
    arr = np.asarray(img, dtype=float)
    # imshow rejects (C, H, W); only transpose when the layout is unambiguously channel-first.
    if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
        arr = np.moveaxis(arr, 0, -1)
    if arr.ndim == 3 and arr.shape[-1] == 1:
        arr = arr[..., 0]
    lo, hi = arr.min(), arr.max()
    if lo < 0 or hi > 1:
        arr = (arr - lo) / (hi - lo) if hi > lo else np.zeros_like(arr)
    return arr


def plot_prob_distribution(input, image, top_probs, top_classes):
    """Show one input (image or caption) next to a bar chart of its top-k predictions.

    Args:
        input: the sample to display. When ``image`` is True it is drawn with
            ``imshow`` (after ``to_display_image``); otherwise it is rendered as text
            inside a black frame.
        image: True to draw ``input`` as an image, False to draw it as text.
        top_probs: 1-D sequence/tensor of the top-k scores for this sample.
        top_classes: class names aligned with ``top_probs`` (same length), used as
            the y-tick labels of the bar chart.
    """
    fig, (ax_input, ax_bar) = plt.subplots(1, 2, figsize=(8, 4), gridspec_kw={"wspace": 0})

    if image:
        ax_input.imshow(to_display_image(input))
    else:
        ax_input.text(0.5, 0.5, input, fontsize=10, ha='center', va='center',
                      transform=ax_input.transAxes, wrap=True)
        # Frame around the caption so the text panel has a visible boundary.
        ax_input.add_patch(plt.Rectangle((0, 0), 1, 1, transform=ax_input.transAxes,
                                         facecolor="none", edgecolor="black", lw=1))
    ax_input.axis("off")

    # Plot the whole top-k vector (not just its first entry) with one bar per class.
    probs = np.asarray(top_probs, dtype=float).reshape(-1)
    positions = np.arange(len(probs))
    ax_bar.barh(positions, probs)
    ax_bar.set_xlim(0, 1)
    # Negative pad pulls the class-name labels inside the axes, next to the bars.
    ax_bar.tick_params(axis='y', direction='in', pad=-30)
    ax_bar.set_yticks(positions)
    ax_bar.set_yticklabels([str(name) for name in top_classes])

    plt.show()

In [None]:
# Checkpoint of the trained SHRe Lightning module; must be set before loading the model.
MODEL_PATH: str = ""

In [None]:
# BPE tokenizer; used further down to turn caption token ids back into text.
BPE_DIR = '../data'
encoder = get_bpe_encoder(BPE_DIR)

In [None]:
# Keyword arguments for the COCO captions datamodule (small, shuffled test batches).
coco_dm_kwargs = dict(
    data_path='../../data',
    num_max_bpe_tokens=64,
    color_jitter=None,
    beit_transforms=False,
    crop_scale=[1.0, 1.0],
    batch_size=4,
    num_workers=1,
    shuffle=True,
    drop_last=False,
)

In [None]:
# Seed the global RNGs first so the shuffled test batch drawn later is reproducible.
pl.seed_everything(seed=42)
datamodule_cls = DATAMODULE_REGISTRY['coco_captions']
coco_dm = datamodule_cls(**coco_dm_kwargs)

In [None]:
# Lightning datamodule lifecycle: prepare the data, then set up only the 'test' stage.
coco_dm.prepare_data()
coco_dm.setup('test')

In [None]:
# Keep an explicit iterator so each `next(dl)` below pulls a fresh test batch.
test_loader = coco_dm.test_dataloader()
dl = iter(test_loader)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fail fast with a clear message instead of an opaque checkpoint-loading error.
assert MODEL_PATH, "Set MODEL_PATH to a trained SHRe checkpoint before running this cell."

# Registry entry is the LightningModule *class*; we keep only its inner `.model` for inference.
model_cls = MODEL_REGISTRY['SHRe']['module']
model = model_cls.load_from_checkpoint(MODEL_PATH).model
model = model.to(device)
model.requires_grad_(False)
# Trailing ';' suppresses the (very long) module repr in the cell output.
model.eval();

In [None]:
# Run one test batch through both encoders and keep the top-k ImageNet classes per sample.
TOP_K = 5
batch = next(dl)
with torch.no_grad():
    # The model lives on `device` but the dataloader yields CPU tensors, so move the inputs.
    # `batch` itself stays on the CPU because it is plotted/decoded afterwards.
    img_out = model.encode_image(batch['image'].to(device))['encoder_out']
    text_out = model.encode_text(
        batch['text'].to(device), batch['padding_mask'].to(device)
    )['encoder_out']

# NOTE(review): topk is taken directly on `encoder_out`; if these are logits rather than
# probabilities, apply `.softmax(dim=-1)` first so the bars fit the [0, 1] x-axis — confirm.
img_probs, img_labels = img_out.cpu().topk(TOP_K, dim=-1)
text_probs, text_labels = text_out.cpu().topk(TOP_K, dim=-1)

# Labels have one row of TOP_K indices per sample; map each index (not whole rows)
# to its class name, giving one list of names per sample.
img_top_classes = [[imagenet_classnames[idx] for idx in row] for row in img_labels.tolist()]
text_top_classes = [[imagenet_classnames[idx] for idx in row] for row in text_labels.tolist()]

In [130]:
# For every sample: the image with its top-k classes, then the caption with its top-k classes.
# NOTE(review): the raw token row (possibly including padding) goes to `encoder.decode`;
# confirm the encoder accepts a tensor and skips padding tokens.
for i in range(len(img_out)):
    plot_prob_distribution(batch['image'][i], True, img_probs[i], img_top_classes[i])
    caption = encoder.decode(batch['text'][i])
    plot_prob_distribution(caption, False, text_probs[i], text_top_classes[i])