In [None]:
import hydra
import matplotlib.pyplot as plt
import numpy as np

In [None]:
from omegaconf import OmegaConf

from mantis.configs.config import Config

# Compose the Hydra config named "config" from ../conf/ and build the typed `Config`
# used by every cell below. `config_path` is relative, so this cell presumably
# expects to be run from the notebook's own directory — confirm if it fails to find it.
with hydra.initialize(version_base=None, config_path="../conf/"):
    cfg = hydra.compose(config_name="config")
    # Resolve ${...} interpolations in place so Config receives concrete values.
    OmegaConf.resolve(cfg)
    # Unpacks only the top-level keys as kwargs; nested values are still DictConfig
    # objects (not plain dicts) — NOTE(review): confirm Config accepts them as-is.
    config = Config(**cfg)

In [None]:
from mantis.data.datasets import get_distributed_dataloader, setup_dataset

# Build the train/val datasets on device 0 and wrap the training split in a loader.
# The positional `0` passed to get_distributed_dataloader is presumably the process
# rank for this single-process notebook session — confirm against its signature.
train_ds, val_ds = setup_dataset(config.dataset, device=0)

# `train_loader` is deliberately bound to an *iterator* over the loader, so each
# `next(train_loader)` in later cells pulls the following batch.
train_loader = iter(
    get_distributed_dataloader(train_ds, 0, config, is_train=True)
)

In [None]:
# Pull one batch from the training iterator. Re-running this cell advances the
# iterator and replaces `batch` with the next one.
batch = next(train_loader)

In [None]:
from torch import Tensor

from mantis.model.patch_extractor import extract_grid_patches


def vis_batch(batch: dict[str, Tensor], ind: int) -> None:
    """Visualise grid-patch extraction for one sample of a batch.

    Prints the shape, dtype and value range of the ``patches`` and ``locs``
    returned by ``extract_grid_patches``, then shows two figures:

    1. the image ``batch["image"][ind]`` with the patch ``locs`` overlaid as red dots;
    2. an ``Hp x Wp`` grid of that sample's individual patches.

    Args:
        batch: Batch dict; ``batch["image"]`` must be 4-D ``(B, C, H, W)``.
            Presumably float RGB in [0, 1] (values are scaled by 255 and clipped
            for display) — TODO confirm against the dataset.
        ind: Index of the sample within the batch to display.

    Uses the notebook-global ``config.patch_size`` as the patch side length.
    """
    patch_size = config.patch_size
    images = batch["image"]
    _, _, H, W = images.shape
    Hp, Wp = H // patch_size, W // patch_size

    patches, locs = extract_grid_patches(images, patch_size)
    print("patches:", patches.shape, patches.dtype, patches.min().item(), patches.max().item())
    print("locs:", locs.shape, locs.dtype, locs.min().item(), locs.max().item())

    # Tensors may live on the GPU (the dataset is set up with device=0); `.numpy()`
    # raises on CUDA tensors, so move everything used for plotting to the CPU first.
    img = (images[ind].detach().cpu() * 255).round().int().clip(0, 255).numpy().astype(np.uint8)

    # Assumes locs[..., :2] holds (x, y) in [-1, 1] normalised coordinates — TODO
    # confirm against extract_grid_patches. Map them to pixel coordinates.
    locs_vis = (locs[ind, ..., :2].detach().cpu() + 1) / 2
    locs_x = (locs_vis[..., 0] * W).flatten().numpy()
    locs_y = (locs_vis[..., 1] * H).flatten().numpy()

    fig, ax = plt.subplots()
    ax.imshow(img.transpose(1, 2, 0))  # (C, H, W) -> (H, W, C) for imshow
    ax.scatter(locs_x, locs_y, c="r")
    ax.set_title(f"sample {ind}: patch locations")
    ax.axis("off")
    plt.show()

    # Assumes each sample's patches are stored row-major over the (Hp, Wp) grid with
    # channel-first pixels, i.e. reshapeable to (Hp, Wp, 3, patch_size, patch_size)
    # — TODO confirm against extract_grid_patches.
    patch_vis = patches[ind].detach().cpu().reshape(Hp, Wp, 3, patch_size, patch_size)
    patch_vis = (patch_vis * 255).round().int().clip(0, 255)
    # squeeze=False keeps `axes` 2-D even when Hp == 1 or Wp == 1; with the default
    # squeeze=True, `axes[i, j]` fails for single-row/column (or 1x1) grids.
    fig, axes = plt.subplots(Hp, Wp, squeeze=False)
    fig.suptitle(f"sample {ind}: {Hp}x{Wp} patches of {patch_size}px")
    for i in range(Hp):
        for j in range(Wp):
            ax = axes[i, j]
            ax.imshow(patch_vis[i, j].permute(1, 2, 0).numpy().astype(np.uint8))
            ax.axis("off")
    plt.show()

In [None]:
# Visualise the first sample of the current batch.
vis_batch(batch, ind=0)