In [44]:
import os
import time
from pathlib import Path

import torch
from omegaconf import OmegaConf
import sys


SRC_ROOT = Path("..").resolve()
PROJECT_ROOT = SRC_ROOT.parent
if str(SRC_ROOT) not in sys.path:
    sys.path.append(str(SRC_ROOT))
from data_utils import get_data_module
from model.unet import get_unet_module
from model.score_adapter import get_score_prediction_module

# --- Config (mirrors train_score_predictor.sh defaults) ---
DATASET = "pmri"  # "mnmv2" or "pmri"
SPLIT_MAP = {"mnmv2": "scanner-symphonytim", "pmri": "promise12"}
PERTURBATION = "predictor"
PREDICTOR_PARAMETER = "location+concentration"
SWIVELS = "first-embedding-last"
LIKELIHOOD = "beta"
LOSS_FN = "dice"
WARMUP = 5
BATCH_SIZE = 15

repo_root = Path("/root/workspace/repos/QualityControl")
data_cfg = OmegaConf.load(repo_root / "src/configs/data" / f"{DATASET}.yaml")
unet_cfg = OmegaConf.load(repo_root / "src/configs/unet/monai_unet.yaml")
score_cfg = OmegaConf.load(repo_root / "src/configs/model/score_predictor.yaml")

data_cfg.dataset = DATASET
data_cfg.split = SPLIT_MAP[DATASET]
data_cfg.batch_size = BATCH_SIZE
data_cfg.non_empty_target = True
data_cfg.train_transforms = "global_transforms"

score_cfg.loss_fn = LOSS_FN
score_cfg.perturbations = PERTURBATION
score_cfg.predictor_parameter = PREDICTOR_PARAMETER
score_cfg.likelihood = LIKELIHOOD
score_cfg.swivels = SWIVELS

if DATASET == "mnmv2":
    unet_cfg.out_channels = 4
    score_cfg.num_classes = 4
elif DATASET == "pmri":
    unet_cfg.out_channels = 1
    score_cfg.num_classes = 2
else:
    raise ValueError(f"Unsupported dataset: {DATASET}")

if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required for score adapter shape inference.")
device = torch.device("cuda")



In [45]:
# --- Data ---
dm = get_data_module(data_cfg)
dm.prepare_data()
dm.setup("fit")
train_loader = dm.train_dataloader()
batch = next(iter(train_loader))
x = batch["input"]
if not torch.is_tensor(x):
    x = torch.as_tensor(x)
x = x.to(device)


In [53]:
x = batch["input"]
if not torch.is_tensor(x):
    x = torch.as_tensor(x)
x = x.to(device)

In [None]:
ITERS = 1000
# --- Models (untrained) ---
unet_module = get_unet_module(
    cfg=unet_cfg,
    metadata={"unet": OmegaConf.to_container(unet_cfg)},
    load_from_checkpoint=False,
)
unet = unet_module.model.to(device)
unet.eval()

score_module = get_score_prediction_module(
    data_cfg=data_cfg,
    model_cfg=score_cfg,
    unet=unet,
    metadata={"model": OmegaConf.to_container(score_cfg)},
    ckpt=None,
)
score_wrapper = score_module.wrapper.to(device)
score_wrapper.eval() 

def time_forward(fn, x, iters=30, warmup=5):
    with torch.inference_mode():
        for _ in range(warmup):
            _ = fn(x)
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(iters):
            _ = fn(x)
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters

t_unet = time_forward(unet, x, iters=ITERS, warmup=WARMUP)
t_score = time_forward(score_wrapper, x, iters=ITERS, warmup=WARMUP)
fps_unet = 1 / t_unet
fps_score = 1 / t_score

print(f"Dataset={DATASET}, batch={BATCH_SIZE}")
print(f"UNet only: {t_unet*1000:.2f} ms/iter | {fps_unet:.2f} fps")
print(f"UNet + score adapter: {t_score*1000:.2f} ms/iter | {fps_score:.2f} fps")

Location target score delta: 1.0
Dataset=pmri, batch=15
UNet only: 15.22 ms/iter | 65.71 fps
UNet + score adapter: 20.53 ms/iter | 48.72 fps
