In [None]:
import sys
sys.path.append('..')
sys.path.append('../..')
sys.path.append('../beit2')
from datamodules import DATAMODULE_REGISTRY
from models import MODEL_REGISTRY
from models.image_vq import ImageVQ
import torch
from pytorch_lightning import LightningModule
import torch.nn as nn
import pytorch_lightning as pl
from rich.progress import track
import matplotlib
import matplotlib.pyplot as plt
# Plot defaults for the whole notebook, applied in a single update:
# `axes.axisbelow` is switched off and the base font size is set to 12.
matplotlib.rcParams.update({"axes.axisbelow": False, "font.size": 12})

In [None]:
# Fix the global seed through Lightning's helper before anything else runs.
pl.seed_everything(seed=42)

In [None]:
def plot_vq_classes(images, classes, classes_to_plot, images_per_row=6):
    """Plot a grid of example images for selected VQ class indices.

    Each grid row corresponds to one entry of ``classes_to_plot`` and shows up
    to ``images_per_row`` images whose entry in ``classes`` equals that class
    index. Cells for which no matching image exists are left blank.

    Args:
        images: Tensor of images indexed along dim 0. Every image is permuted
            with ``(1, 2, 0)`` before display, so it is assumed to be
            channel-first ``(C, H, W)`` -- TODO confirm against the datamodule.
        classes: Tensor of class indices, assumed to hold one entry per image
            along dim 0 -- TODO confirm against ``quantize_image``.
        classes_to_plot: Sequence of class indices; one grid row is drawn for
            each entry, in order.
        images_per_row: Maximum number of example images shown per class.

    Returns:
        None. The figure is shown via ``plt.show()``.
    """
    num_rows = len(classes_to_plot)
    if num_rows == 0:
        # plt.subplots rejects a grid with zero rows; there is nothing to draw.
        return

    # squeeze=False keeps `axes` two-dimensional even when only one class is
    # requested, so the `axes[row, col]` indexing below always works.
    _, axes = plt.subplots(
        num_rows,
        images_per_row,
        figsize=(3 * images_per_row, 3 * num_rows),
        squeeze=False,
    )

    for col in range(images_per_row):
        axes[0, col].set_title(f"Example {col+1}")

    for row, class_idx in enumerate(classes_to_plot):
        # Hide ticks and frames for the whole row, including cells that stay
        # empty because the class has fewer than `images_per_row` examples.
        for col in range(images_per_row):
            axes[row, col].axis('off')

        # Row label, placed to the left of the first cell.
        axes[row, 0].text(-0.2, 0.5, str(class_idx), transform=axes[row, 0].transAxes, va='center', ha='right')

        # Dim-0 positions of the images assigned to this class. Taking the
        # first component of the tuple form always yields a 1-D index tensor,
        # even for a single match (a squeezed 0-d index would select one image
        # and make the loop below iterate over its channels instead).
        indices = (classes == class_idx).nonzero(as_tuple=True)[0][:images_per_row]
        # `classes` may live on a different device than `images`.
        indices = indices.to(images.device)

        for col, img in enumerate(images[indices]):
            axes[row, col].imshow(img.permute(1, 2, 0).cpu())

    plt.subplots_adjust(wspace=0.1, hspace=0.1)

    plt.show()

In [None]:
# Checkpoint path handed to `load_from_checkpoint` in the model-loading cell.
# It is left empty here; set it to a real checkpoint file before running that cell.
MODEL_PATH = ""

In [None]:
# Keyword arguments for the COCO captions datamodule, gathered in one place so
# they can be tweaked without touching the construction cell.
coco_dm_kwargs = dict(
    data_path='../../data',
    num_max_bpe_tokens=64,
    color_jitter=None,
    beit_transforms=False,
    crop_scale=[1.0, 1.0],
    batch_size=256,
    num_workers=8,
    shuffle=True,
    drop_last=False,
)

In [None]:
# Seed again with the same value as the earlier seeding cell, presumably so
# this cell gives the same result when it is re-run on its own.
pl.seed_everything(seed=42)

# Look up the datamodule class registered as 'coco_captions', then build it
# from the keyword arguments defined above.
coco_dm_cls = DATAMODULE_REGISTRY['coco_captions']
coco_dm = coco_dm_cls(**coco_dm_kwargs)

In [None]:
# Run the datamodule's preparation step first, then its 'fit'-stage setup.
# The validation dataloader is requested from this datamodule in a later cell.
coco_dm.prepare_data()
coco_dm.setup('fit')

In [None]:
# Keep the validation DataLoader itself rather than a one-shot iterator over
# it: a loader can be iterated again when the inference cell is re-run, and it
# exposes its length so the progress bar can show a total.
dl = coco_dm.val_dataloader()

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fetch the Lightning wrapper registered for 'vq_image', restore it from the
# checkpoint, and keep only the wrapped `.model`, moved to the chosen device.
model_cls: LightningModule = MODEL_REGISTRY['vq_image']['module']
model: ImageVQ = model_cls.load_from_checkpoint(MODEL_PATH).model.to(device)

# Inference only: disable gradients on all parameters and switch to eval mode.
model.requires_grad_(False).eval()

In [None]:
# Quantize every batch produced by `dl`, collecting the 'embed_ind' output of
# `quantize_image` alongside the raw images of the same batch (presumably one
# index per image -- TODO confirm against the model).
class_index_batches = []
image_batches = []
with torch.no_grad():
    for batch in track(dl):
        images = batch['image']
        images_raw = batch['image_raw']
        class_indices = model.quantize_image(images.to(device))['embed_ind']
        # Move the indices to the CPU right away: this avoids piling results
        # up in accelerator memory and keeps them on the same device as the
        # raw images they are later used to select.
        class_index_batches.append(class_indices.cpu())
        image_batches.append(images_raw)

if not image_batches:
    # torch.cat on an empty list fails with a far less helpful message.
    raise RuntimeError(
        "No batches were read from `dl`. If `dl` is an already-consumed "
        "iterator, re-run the cell that creates it before this one."
    )

all_class_indices = torch.cat(class_index_batches)
all_images = torch.cat(image_batches)

# The per-batch lists are no longer needed once concatenated.
del class_index_batches, image_batches

In [None]:
# Show example images for class indices 0 through 9.
plot_vq_classes(
    images=all_images,
    classes=all_class_indices,
    classes_to_plot=[*range(10)],
)