In [None]:
import os
from pathlib import Path

import hydra
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import colormaps as cm

# Move from the notebook's directory to its parent (the project root) so that
# the relative paths used below resolve there. The starting directory is
# captured only on the first execution, so re-running this cell does not keep
# walking further up the filesystem (a bare os.chdir("../") is not idempotent).
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = Path.cwd()
os.chdir(NOTEBOOK_DIR.parent)

cmap = cm.get_cmap("viridis")

# Clear any previously initialised Hydra instance so this cell can be re-run
# without Hydra complaining that it is already initialised.
hydra.core.global_hydra.GlobalHydra.instance().clear()
hydra.initialize(version_base="1.3", config_path="../configs/")
cfg = hydra.compose(config_name="eval.yaml")

In [None]:
root_dir = Path("logs/calibrate/runs")
# TODO: replace the placeholder with the calibration run folder to inspect.
ckpt_path = root_dir / "<experiment folder>/cmodel.ckpt"

model = hydra.utils.instantiate(cfg.model)
# torch.load unpickles the file: only load checkpoints from trusted sources.
checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))
model.load_state_dict(checkpoint["state_dict"])
# nn.Module instances start in training mode; switch to eval mode so layers
# such as dropout / batch-norm use their inference behaviour in the forward
# pass below instead of training-time behaviour.
model.eval()
# Non-conformity curves saved in the same checkpoint (presumably by the
# calibration run) — indexed per class further below.
nc_curves = checkpoint["nc_curves"]

In [None]:
# Only the first test batch is used, so the batch size sets how many images
# are processed and visualised.
cfg.data.batch_size = 6
dm = hydra.utils.instantiate(cfg.data)
dm.setup("test")

dl = dm.test_dataloader()
batch = next(iter(dl))
print(batch["image"].shape)

# Pure inference: disable autograd so no computation graph is built for the
# forward pass (saves memory; the outputs will not require grad).
with torch.no_grad():
    out = model(batch["image"])

In [None]:
# Per-pixel probabilities: an independent sigmoid over each class channel.
seg_pprobs = torch.sigmoid(out["seg_logits"])

class_idx = 1
fg_pprobs = seg_pprobs[:, class_idx]
# Non-conformity score of each pixel for the chosen class: 1 - probability.
fg_ncs = 1 - fg_pprobs
fg_nc_curves = nc_curves[:, class_idx]

alpha = 0.1  # target miscoverage level
assert 0 < alpha < 1, "alpha must lie in (0, 1)"
# Index of the (1 - alpha) threshold on the stored curve.
# NOTE(review): assumes the curve has one entry per integer percent (so alpha
# should be a multiple of 0.01) — confirm against the calibration code.
# round() instead of int(): (1 - alpha) * 100 can land just below an integer in
# floating point (e.g. 0.29 * 100 == 28.999999999999996), and int() would then
# truncate to the previous index.
quantile_idx = round((1 - alpha) * 100)
assert quantile_idx < len(fg_nc_curves), "quantile index outside the stored curve"
# A pixel belongs to the conformal prediction set if its non-conformity score
# does not exceed the calibrated threshold.
mask = fg_ncs <= fg_nc_curves[quantile_idx]

for i in range(batch["image"].shape[0]):
    # assumes channel-first images; move channels last for imshow
    img = batch["image"][i].permute(1, 2, 0).numpy()
    gt_mask = batch["target_segmentation"][i][class_idx]
    conformal_predicted_mask = mask[i]

    fig, ax = plt.subplots(1, 3)
    ax[0].imshow(img)
    ax[0].set_title("Image")

    ax[1].imshow(gt_mask, cmap="gray")
    ax[1].set_title("Ground truth")

    ax[2].imshow(conformal_predicted_mask, cmap="gray")
    ax[2].set_title(f"Conformal mask (alpha={alpha})")

    for a in ax:
        a.axis("off")
    plt.show()