In [1]:
import datetime
import os
from glob import glob

import numpy as np
import torch
import torch.nn.functional as F

from omegaconf import OmegaConf
from torchinfo import summary
from app.config import root_dir
from app.metrics import BYUFbeta
from app.metrics.metrics import get_topk_by_id, thresholder, filter_negatives
from app.models.lightning import Net
from app.models import LNet
from app.utils import get_data, get_data_loader
from app.processings.post_processing import get_output_size, reconstruct, simple_nms

In [2]:
from numpy.typing import NDArray
from scipy.spatial import KDTree
from torchmetrics.utilities import dim_zero_cat

In [3]:
# Register the custom resolvers the project config interpolations rely on.
# replace=True lets this cell be re-run in the same kernel without tripping
# over the resolver registered by the previous run.
OmegaConf.register_new_resolver("root_dir", resolver=root_dir, replace=True)
# NOTE(review): this exposes the builtin `eval` to the config, so any
# `${eval:...}` interpolation executes arbitrary Python — only load trusted configs.
OmegaConf.register_new_resolver("eval", resolver=eval, replace=True)

# Process-wide environment for the XLA / TPU runtime. Set in one call so the
# whole runtime configuration is visible at a glance; keys and values are
# written in the same order as before.
os.environ.update(
    {
        "ISTPUVM": "1",
        "PJRT_DEVICE": "CPU",
        "PT_XLA_DEBUG_LEVEL": "1",
        "TF_CPP_MIN_LOG_LEVEL": "2",
        "TPU_ACCELERATOR_TYPE": "v3-8",
        "TPU_CHIPS_PER_HOST_BOUNDS": "2,2,1",
        "TPU_HOST_BOUNDS": "1,1,1",
        "TPU_RUNTIME_METRICS_PORTS": "8431,8432,8433,8434",
        "TPU_SKIP_MDS_QUERY": "1",
        "TPU_WORKER_HOSTNAMES": "localhost",
        "TPU_WORKER_ID": "0",
        "XLA_TENSOR_ALLOCATOR_MAXSIZE": "100000000",
    }
)

In [4]:
# Load the project configuration. The path is relative, so it resolves
# against the kernel's current working directory.
cfg = OmegaConf.load("../src/app/config/config.yaml")

In [5]:
# Fetch the two frames returned by get_data in "fit" mode and wrap each one
# in its own loader ("train" / "validation" modes respectively).
train_df, val_df = get_data(cfg, mode="fit")
train_loader = get_data_loader(cfg, train_df, mode="train")
val_loader = get_data_loader(cfg, val_df, mode="validation")

In [6]:
# Load the two per-rank validation dumps and stack them along dim 0.
# NOTE(review): the rank1 file is concatenated before the rank0 file — confirm
# downstream code does not depend on row order.
# NOTE(review): absolute Kaggle paths from two different run timestamps; these
# only resolve on the machine that produced them.
preds1 = torch.load("/kaggle/working/resnet50/seed_4294967295/fold0/20250513193109/val_epoch_7_end_step384_rank1.pt")
preds2 = torch.load("/kaggle/working/resnet50/seed_4294967295/fold0/20250513193056/val_epoch_7_end_step384_rank0.pt")
preds = torch.cat([preds1, preds2], dim=0)

In [7]:
# Ground-truth table as a tensor; columns are, in order: z, y, x, id, vxs.
targets = torch.from_numpy(val_df[["z", "y", "x", "id", "vxs"]].values)

In [8]:
# Reduce the stacked predictions with the project helper get_topk_by_id.
# presumably this keeps the best candidates per tomogram id — verify against the helper.
topk  = get_topk_by_id(preds, targets)

In [9]:
# Split predictions/targets with the project helper `thresholder`.
# NOTE(review): 0.998 is a hard-coded magic number, presumably the confidence
# threshold — confirm against thresholder's signature and consider a named constant.
ut_preds, candidates, ntargets, ptargets = thresholder(0.998, topk, targets)

In [10]:
# Instantiate the project metric; it is not called in the cells below that
# are still active (the calls that used it are commented out).
metric = BYUFbeta(cfg, compute_on_cpu=True, dist_sync_on_step=True)

In [11]:
# zyxic = torch.load("")

In [12]:
# targets = torch.from_numpy(val_df[val_df.id.isin(np.unique(zyxic[:, 3]))][["z", "y", "x", "id", "vxs"]].copy().values)

In [13]:
# results = metric(zyxic, targets)

In [14]:
# results

In [18]:
# Load previously saved logits. A later cell passes this object to
# torch.from_numpy, so the file is expected to hold a NumPy array.
# NOTE(review): weights_only=False performs full unpickling, which can execute
# arbitrary code — only load files you produced yourself.
test_outputs = torch.load("/kaggle/working/logits.pt", weights_only=False)

In [33]:
# Iterator over the *validation* loader (despite the `tr_` prefix in the name).
tr_it = iter(val_loader)

In [34]:
# Pull the next batch. Not idempotent: every re-run of this cell advances the
# iterator and yields a different batch.
batch = next(tr_it)

In [36]:
# Attach the saved logits to the batch under the key the post-processing reads.
# NOTE(review): assumes the saved logits were produced for exactly this batch — confirm.
batch["logits"] = torch.from_numpy(test_outputs)

In [37]:
# Move every tensor entry to the first CUDA device; non-tensor entries are kept as-is.
# NOTE(review): the device string is hard-coded, so this cell fails on a machine without CUDA.
batch = {k:v.to("cuda:0") if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

In [38]:
import copy

In [39]:
# Alias (not a copy): net_output and batch refer to the same dict object.
net_output = batch

In [40]:
# Everything below lives on the same device as the logits; look it up once
# and reuse it when materialising the configured sizes as tensors.
device = net_output["logits"].device
new_size = torch.tensor(cfg.new_size, device=device)
roi_size = torch.tensor(cfg.roi_size, device=device)

In [41]:
# Work on a detached view of the logits so no autograd graph is tracked.
img: "torch.Tensor" = net_output["logits"].detach()

In [42]:
# Pull the remaining fields used by the post-processing out of the batch.
locations: "torch.Tensor" = net_output["location"]
scales: "torch.Tensor" = net_output["scale"]
# "id" is converted with torch.tensor, so it is presumably a plain (non-tensor)
# sequence in the batch — TODO confirm.
tomo_ids: "torch.Tensor" = torch.tensor(net_output["id"], device=device)

In [43]:
# Resample the logits to the configured ROI size with trilinear interpolation.
# Rebinds `img`, so re-running the cell interpolates an already resized tensor.
img = F.interpolate(
    img,
    size=roi_size.tolist(),
    mode="trilinear",
    align_corners=False,
)

In [44]:
# Ask the project helper for the output size, then hand the patches and their
# locations to `reconstruct`.
# presumably this stitches the ROI patches back into full volumes — verify against the helper.
out_size = get_output_size(img, locations, roi_size, device)
rec_img = reconstruct(
    img=img,
    locations=locations,
    out_size=out_size,
    crop_size=roi_size,
    device=device,
)

In [45]:
# Target spatial extent (z, y, x), unpacked once and reused below.
nz, ny, nx = new_size.tolist()

# Symmetric offset that strips the padding added during transforms.
s = torch.tensor(rec_img.shape[-3:], device=device)
delta = (s - new_size) // 2
dz, dy, dx = delta.tolist()

# Central crop back to the target extent ...
rec_img = rec_img[:, :, dz : dz + nz, dy : dy + ny, dx : dx + nx]

# ... then halve each spatial dimension (integer division) with trilinear resampling.
rec_img = F.interpolate(
    rec_img,
    size=[nz // 2, ny // 2, nx // 2],
    mode="trilinear",
    align_corners=False,
)

In [46]:
# Softmax over dim 1 of the reconstructed volume.
# NOTE: this rebinds `preds`, shadowing the stacked validation predictions loaded earlier.
preds: "torch.Tensor" = rec_img.softmax(1)

In [64]:
# Median of the index-0 slice along dim 1 (sanity check on the score distribution).
preds[:, 0].median()

tensor(0.1270, device='cuda:0')

In [67]:
# 99th percentile of the same index-0 slice.
preds[:, 0].quantile(q=0.99)

tensor(0.8122, device='cuda:0')

In [None]:
# Keep only index 0 along dim 1 and add a leading singleton dimension.
# NOTE(review): `preds0` is not referenced by any later cell visible here.
preds0 = preds[:, 0, :][None,]

In [77]:
# Smallest value in the index-0 slice along dim 1.
preds[:, 0].min()

tensor(5.5334e-11, device='cuda:0')

In [88]:
# Non-maximum suppression via the project helper, then drop dim 0 if it has size 1.
# NOTE(review): this passes the full `preds`, whereas post_process_pipeline
# below passes `preds[:, 0, :][None,]` (the tensor built as `preds0` above);
# the shape comments only hold for that sliced input — confirm which is intended.
# NOTE(review): nms_radius=100 is hard-coded here; the function uses cfg.nms_radius.
nms: "torch.Tensor" = simple_nms(preds, nms_radius=100)  # (1,B, D, H, W)
nms = nms.squeeze(dim=0)  # (B, D, H, W)

In [None]:
# Largest response left after suppression.
nms.max()

In [87]:
# Boolean-mask selection of every response above 0.99 (returned as a flat 1-D tensor).
nms[nms>0.99]

tensor([0.9918, 0.9989, 0.9984, 0.9913, 0.9997, 0.9918, 0.9900, 0.9958, 0.9979,
        0.9990, 0.9998, 0.9998, 0.9999, 0.9986, 1.0000, 0.9993, 0.9975, 1.0000,
        0.9960, 0.9985, 0.9990, 0.9989, 0.9995, 0.9993, 0.9996, 0.9921, 0.9991,
        0.9932, 0.9999, 0.9982, 0.9904, 0.9940, 0.9990, 0.9999, 0.9967, 0.9996,
        0.9995, 0.9981, 0.9987, 0.9943, 0.9952, 0.9988, 0.9968, 0.9992],
       device='cuda:0')

In [None]:
def post_process_pipeline(
    cfg: "DictConfig", net_output: "dict[str, Any]"
) -> "torch.Tensor":
    """Convert raw model logits into candidate coordinates with ids and scores.

    The logits are resized to ``cfg.roi_size``, passed through ``reconstruct``,
    centrally cropped to ``cfg.new_size``, downsampled to half that size and
    turned into probabilities with a softmax over dim 1. Index 0 of that
    dimension is then run through ``simple_nms`` and the ``cfg.topk`` strongest
    responses per volume are kept. Their indices are doubled (to undo the
    half-size downsampling) and divided by ``net_output["scale"]``.

    Args:
        cfg (DictConfig): Project configuration; ``new_size``, ``roi_size``,
            ``nms_radius`` and ``topk`` are read.
        net_output (dict): Model output; ``"logits"``, ``"location"``,
            ``"scale"`` and ``"id"`` are read.

    Returns:
        torch.Tensor: One row per kept candidate with the columns
        ``z, y, x, id, confidence`` (``topk`` rows per volume).
    """
    logits: "torch.Tensor" = net_output["logits"]
    device = logits.device
    new_size = torch.tensor(cfg.new_size, device=device)
    roi_size = torch.tensor(cfg.roi_size, device=device)

    locations: "torch.Tensor" = net_output["location"]
    scales: "torch.Tensor" = net_output["scale"]
    tomo_ids: "torch.Tensor" = torch.tensor(net_output["id"], device=device)

    # Bring the (detached) logits to the ROI size before reconstruction.
    patches = F.interpolate(
        logits.detach(),
        size=roi_size.tolist(),
        mode="trilinear",
        align_corners=False,
    )

    volume = reconstruct(
        img=patches,
        locations=locations,
        out_size=get_output_size(patches, locations, roi_size, device),
        crop_size=roi_size,
        device=device,
    )

    # Central crop that removes the padding added during transforms.
    nz, ny, nx = new_size.tolist()
    full_shape = torch.tensor(volume.shape[-3:], device=device)
    dz, dy, dx = ((full_shape - new_size) // 2).tolist()
    volume = volume[:, :, dz : dz + nz, dy : dy + ny, dx : dx + nx]

    # Half-resolution volume; the factor 2 is undone on the coordinates below.
    volume = F.interpolate(
        volume,
        size=[nz // 2, ny // 2, nx // 2],
        mode="trilinear",
        align_corners=False,
    )

    # Probabilities of index 0 along dim 1, with a leading singleton dimension.
    scores: "torch.Tensor" = volume.softmax(1)[:, 0].unsqueeze(0)

    peaks: "torch.Tensor" = simple_nms(scores, nms_radius=cfg.nms_radius)  # (1, B, D, H, W)
    peaks = peaks.squeeze(dim=0)  # (B, D, H, W)

    n_volumes = peaks.shape[0]
    spatial_shape = peaks.shape[-3:]

    # Strongest responses per volume and their (z, y, x) positions.
    conf, flat_idx = torch.topk(peaks.reshape(n_volumes, -1), k=cfg.topk, dim=1)
    zyx = torch.stack(torch.unravel_index(flat_idx, spatial_shape), dim=-1)  # (B, K, 3)

    batch_idx = torch.arange(n_volumes, device=device).unsqueeze(1)
    # The expand below only succeeds when unique() leaves a single id column
    # per row (or already cfg.topk columns).
    ids = torch.unique(tomo_ids.reshape(n_volumes, -1), dim=1).expand(
        n_volumes, cfg.topk
    )

    # Undo the half-size downsampling, then divide by the per-row scale.
    zyx = ((zyx * 2) / scales[batch_idx]).round().to(torch.int)

    return torch.cat(
        [
            zyx.reshape(-1, 3),
            ids.reshape(-1, 1),
            conf.to(torch.float32).reshape(-1, 1),
        ],
        dim=1,
    )
