In [None]:
import os

import torch
import torch.nn.functional as F
from omegaconf import OmegaConf

from app.utils import get_data, get_data_loader
from app.processings.post_processing import get_output_size, reconstruct, simple_nms

In [None]:
# state =  torch.load("/kaggle/working/resnet10/version_17/checkpoints/last.ckpt", map_location="cpu")

In [None]:
from numpy.typing import NDArray
from scipy.spatial import KDTree
from torchmetrics.utilities import dim_zero_cat

In [None]:
# Register an "eval" resolver so config values can use `${eval:...}` interpolations.
# NOTE(security): the resolver is Python's built-in `eval`, so every `${eval:...}`
# expression in a loaded config is executed as code — only load trusted configs.
OmegaConf.register_new_resolver("eval", resolver=eval, replace=True)

# Process-level runtime settings, written in one place.
# NOTE(review): these look like torch_xla / TPU runtime knobs (PJRT on CPU, v3-8
# topology on localhost); presumably they must be set before that runtime is
# initialised — confirm against the training entry point.
os.environ.update(
    {
        "ISTPUVM": "1",
        "PJRT_DEVICE": "CPU",
        "PT_XLA_DEBUG_LEVEL": "1",
        "TF_CPP_MIN_LOG_LEVEL": "2",
        "TPU_ACCELERATOR_TYPE": "v3-8",
        "TPU_CHIPS_PER_HOST_BOUNDS": "2,2,1",
        "TPU_HOST_BOUNDS": "1,1,1",
        "TPU_RUNTIME_METRICS_PORTS": "8431,8432,8433,8434",
        "TPU_SKIP_MDS_QUERY": "1",
        "TPU_WORKER_HOSTNAMES": "localhost",
        "TPU_WORKER_ID": "0",
        "XLA_TENSOR_ALLOCATOR_MAXSIZE": "100000000",
    }
)

In [None]:
# Load the project configuration; the path is relative to the notebook's
# working directory.
cfg = OmegaConf.load("../src/app/config/config.yaml")

In [None]:
# Session-only override of the loaded config.
# NOTE(review): presumably read by get_data_loader when it builds the
# validation loader — confirm in app.utils.
cfg.val_persistent_workers = True

In [None]:
# Fetch the train/validation tables in "fit" mode, then build one loader per split.
train_df, val_df = get_data(cfg, mode="fit")
train_loader = get_data_loader(cfg, train_df, mode="train")
val_loader = get_data_loader(cfg, val_df, mode="validation")

In [None]:
# Pack the validation annotations into a single tensor whose columns follow
# the order z, y, x, id, vxs.
targets = torch.from_numpy(val_df[["z", "y", "x", "id", "vxs"]].values)

In [None]:
# Load previously saved network outputs from an absolute Kaggle working path.
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects, so this file must come from a trusted source.
test_outputs = torch.load("/kaggle/working/logits38.pt", weights_only=False)

In [None]:
# One shared iterator over the validation loader. It is stateful: every
# next(tr_it) consumes a batch, so re-create it before searching again.
tr_it = iter(val_loader)

In [None]:
# Advance the validation iterator until the batch whose first id is 38 — the
# id that the cached file name "logits38.pt" refers to — and keep it in `batch`.
for batch in tr_it:
    if batch["id"][0] == 38:
        break
else:
    # Without this guard the search would end silently on the last batch (or
    # with `batch` undefined) and the cached logits would later be attached
    # to the wrong sample.
    raise LookupError(
        "No validation batch with id 38 was found; re-create `tr_it` "
        "if the iterator has already been consumed."
    )

In [None]:
# Attach the cached network outputs to the selected batch.
# NOTE(review): torch.from_numpy requires a NumPy array, so this assumes
# logits38.pt stores an ndarray rather than a tensor — confirm.
batch["logits"] = torch.from_numpy(test_outputs)

In [None]:
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs on hosts where CUDA is unavailable instead of failing on `.to("cuda:0")`.
batch_device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Move every tensor entry of the batch; non-tensor entries pass through unchanged.
batch = {
    key: value.to(batch_device) if isinstance(value, torch.Tensor) else value
    for key, value in batch.items()
}

In [None]:
# Alias only (no copy): the batch dict, with the cached logits attached,
# stands in for a network output in the cells that follow.
net_output = batch

In [None]:
# All tensors created from here on are placed on the device holding the logits.
device = net_output["logits"].device

# Config-provided sizes, materialised as tensors on that same device.
new_size = torch.tensor(cfg.new_size, device=device)
roi_size = torch.tensor(cfg.roi_size, device=device)

In [None]:
# Work on the logits detached from any autograd graph.
img: "torch.Tensor" = net_output["logits"].detach()

In [None]:
# Inspection only: display the shape of the batch's target entry.
batch["target"].shape

In [None]:
# Pull the per-sample metadata out of the batch.
locations: "torch.Tensor" = net_output["location"]
scales: "torch.Tensor" = net_output["scale"]
# NOTE(review): if net_output["id"] is already a tensor, torch.tensor() makes a
# copy and emits a copy-construct warning — confirm the type produced by the loader.
tomo_ids: "torch.Tensor" = torch.tensor(net_output["id"], device=device)

In [None]:
# Compute the output size from the logits, their locations and the ROI size,
# then hand everything to the project's reconstruct helper.
# NOTE(review): presumably this stitches per-crop logits back into whole
# volumes at `locations` — confirm in app.processings.post_processing.
out_size = get_output_size(img, locations, roi_size, device)
rec_img = reconstruct(
    img=img,
    locations=locations,
    out_size=out_size,
    crop_size=roi_size,
    device=device,
)

In [None]:
# Spatial extent (last three axes) of the reconstructed volume.
s = torch.tensor(rec_img.shape[-3:], device=device)

# delta to remove padding added during transforms: half of the surplus over
# new_size along each spatial axis.
delta = (s - new_size) // 2
dz, dy, dx = delta.tolist()
nz, ny, nx = new_size.tolist()

# Keep a new_size window starting at the per-axis offset.
# NOTE: rec_img is overwritten here, so re-running this cell would crop and
# resize the already processed volume.
rec_img = rec_img[:, :, dz : dz + nz, dy : dy + ny, dx : dx + nx]

# Trilinear resize to half of new_size along every spatial axis.
rec_img = F.interpolate(
    rec_img,
    size=[nz // 2, ny // 2, nx // 2],
    mode="trilinear",
    align_corners=False,
)

In [None]:
# Normalise the reconstructed logits with a softmax over dimension 1.
preds: "torch.Tensor" = torch.softmax(rec_img, dim=1)

In [None]:
# Select index 1 along dimension 1 and prepend a singleton axis in front.
preds0 = preds[:, 1].unsqueeze(0)

In [None]:
# Apply the project's NMS helper to the (1, B, D, H, W) score map with a radius
# of 100, then drop the leading singleton axis so the result is (B, D, H, W).
nms: "torch.Tensor" = simple_nms(preds0, nms_radius=100).squeeze(0)