In [None]:
import hydra
import matplotlib.pyplot as plt
import numpy as np
from omegaconf import OmegaConf


In [None]:
from mantis.configs.config import Config

# Compose the Hydra config named "config" from the ../conf/ directory.
# NOTE(review): Hydra resolves config_path relative to the calling file, so
# this presumably expects the notebook to live one level below the repo's
# conf/ folder -- confirm if the notebook is moved.
with hydra.initialize(version_base=None, config_path="../conf/"):
    cfg = hydra.compose(config_name="config")
    # Resolve interpolations on cfg in place before it is unpacked below.
    OmegaConf.resolve(cfg)
    # Top-level keys of the composed config become Config's keyword arguments.
    config = Config(**cfg)

In [None]:
from mantis.data.datasets import setup_dataset

# Build the datasets from the `dataset` section of the config. The two returned
# values are unpacked as the train and validation splits; only train_ds is
# used by the cells visible in this notebook.
train_ds, val_ds = setup_dataset(config.dataset)

In [None]:
from torch.utils.data import DataLoader

# Number of samples drawn per batch for the visualisation below.
B = 2

# Shuffled loader over the training split, so each batch is a random draw.
loader = DataLoader(
    dataset=train_ds,
    shuffle=True,
    batch_size=B,
)

In [None]:
import torch

from mantis.model.patch_extractor import extract_equal_zoom_patches


def _to_uint8_hwc(chw: torch.Tensor) -> np.ndarray:
    """Convert a channel-first tensor to a channel-last ``uint8`` array.

    The channel axis (third from the end) is moved to the last position, so
    both ``(C, H, W)`` and ``(B, C, H, W)`` inputs are accepted. Values are
    scaled by 255, rounded and clipped to ``[0, 255]`` *before* the cast so
    that out-of-range values saturate instead of wrapping around.

    Assumes the input holds intensities nominally in ``[0, 1]`` and lives on
    the CPU (``.numpy()`` is called directly) -- TODO confirm with the dataset.
    """
    scaled = (chw * 255).round().int().clip(0, 255)
    return scaled.movedim(-3, -1).numpy().astype(np.uint8)


def vis_patches(batch: dict[str, torch.Tensor], z: float, patch_size: int = 16) -> None:
    """Show every image of a batch with a random crop outlined, then the crops.

    For each image one location is drawn uniformly in ``[0, 1)^2``, rescaled
    to ``[-1, 1)`` and handed to ``extract_equal_zoom_patches`` together with
    ``z`` and ``patch_size``. Two figures are shown: the source images with
    the sampled centre (red dot) and a red square of side
    ``min(H, W) / 2**z`` around it, and the extracted patches.

    NOTE(review): the red square assumes the extractor crops a square of that
    side centred on the location, with ``locs[:, 0]`` as the x coordinate and
    ``locs[:, 1]`` as y -- confirm against ``extract_equal_zoom_patches``.

    Args:
        batch: Mapping with an ``"image"`` entry of shape ``(N, C, H, W)``.
            The grayscale check compares channels 0, 1 and 2, so at least
            three channels are required.
        z: Zoom level; the outlined square shrinks by a factor of ``2**z``.
        patch_size: Output patch resolution forwarded to the extractor.

    Returns:
        None. Figures are displayed with ``plt.show()`` and a message is
        printed for every image whose three channels are (almost) identical.
    """
    eps = 1e-6
    images = batch["image"]
    # Take sizes from the tensor itself rather than from notebook globals so a
    # short final batch or a differently sized image is handled correctly.
    n_images, _, height, width = images.shape

    # The first two labels reproduce the original two-image messages verbatim.
    names = ("First image", "Second image")
    for i in range(n_images):
        red, green, blue = images[i, 0], images[i, 1], images[i, 2]
        if (red - green).abs().max() < eps and (green - blue).abs().max() < eps:
            name = names[i] if i < len(names) else f"Image {i + 1}"
            print(f"{name} is grayscale")

    # One random centre per image in [0, 1); the extractor receives [-1, 1).
    locs = torch.rand((n_images, 2))
    patches = extract_equal_zoom_patches(
        images, locs=locs * 2 - 1, z=z, patch_size=patch_size
    )
    size = min(height, width) / (2 ** z)

    images_u8 = _to_uint8_hwc(images)
    # squeeze=False keeps a 2-D axes array even for a single image.
    _, axs = plt.subplots(1, n_images, figsize=(6 * n_images, 6), squeeze=False)
    for i, ax in enumerate(axs[0]):
        # Centre of the crop in pixel coordinates of the displayed image.
        cx = locs[i, 0].item() * width
        cy = locs[i, 1].item() * height
        ax.imshow(images_u8[i])
        ax.set_axis_off()
        ax.scatter(cx, cy, c="red")
        ax.add_patch(
            plt.Rectangle(
                (cx - size / 2, cy - size / 2),
                size,
                size,
                fill=False,
                edgecolor="red",
                linewidth=2,
            )
        )
    plt.show()

    _, axs = plt.subplots(1, n_images, figsize=(6 * n_images, 6), squeeze=False)
    for i, ax in enumerate(axs[0]):
        # Same clipped conversion as the full images, so resampling overshoot
        # cannot wrap around in uint8.
        ax.imshow(_to_uint8_hwc(patches[i]))
        ax.set_axis_off()
    plt.show()

In [None]:
# Rebind `loader` to an iterator over the DataLoader so that later cells can
# pull one batch at a time with next(). After this cell `loader` no longer
# refers to the DataLoader object itself.
loader = iter(loader)

In [None]:
# Pull the next batch from the iterator; re-running this cell advances to a
# new batch and raises StopIteration once the iterator is exhausted.
batch = next(loader)
# Visualise the batch at zoom level 4: the red outline drawn by vis_patches
# has a side of (shorter image side) / 2**4.
vis_patches(batch, z=4)