In [1]:
import datetime
import os
from glob import glob

import numpy as np
import torch
import torch.nn.functional as F

from omegaconf import OmegaConf
from torchinfo import summary
from app.metrics import BYUFbeta
from app.metrics.metrics import get_topk_by_id, thresholder, filter_negatives
from app.models.lightning import Net
from app.models import LNet
from app.utils import get_data, get_data_loader
from app.processings.post_processing import get_output_size, reconstruct, simple_nms

In [2]:
# state =  torch.load("/kaggle/working/resnet10/version_17/checkpoints/last.ckpt", map_location="cpu")

In [3]:
from numpy.typing import NDArray
from scipy.spatial import KDTree
from torchmetrics.utilities import dim_zero_cat

In [4]:
OmegaConf.register_new_resolver("eval", resolver=eval, replace=True)

os.environ["ISTPUVM"] = "1"
os.environ["PJRT_DEVICE"] = "CPU"
os.environ["PT_XLA_DEBUG_LEVEL"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TPU_ACCELERATOR_TYPE"] = "v3-8"
os.environ["TPU_CHIPS_PER_HOST_BOUNDS"] = "2,2,1"
os.environ["TPU_HOST_BOUNDS"] = "1,1,1"
os.environ["TPU_RUNTIME_METRICS_PORTS"] = "8431,8432,8433,8434"
os.environ["TPU_SKIP_MDS_QUERY"] = "1"
os.environ["TPU_WORKER_HOSTNAMES"] = "localhost"
os.environ["TPU_WORKER_ID"] = "0"
os.environ["XLA_TENSOR_ALLOCATOR_MAXSIZE"] = "100000000"

In [5]:
cfg = OmegaConf.load("../src/app/config/config.yaml")

In [11]:
cfg.val_persistent_workers = True

In [12]:
train_df, val_df = get_data(cfg, mode="fit")
train_loader = get_data_loader(cfg, train_df, mode="train")
val_loader = get_data_loader(cfg, val_df, mode="validation")

In [7]:
targets = torch.from_numpy(val_df[["z", "y", "x", "id", "vxs"]].values)

In [9]:
test_outputs = torch.load("/kaggle/working/logits422.pt", weights_only=False)

In [48]:
tr_it = iter(val_loader)

In [80]:
for batch in range(len(val_loader)):
    batch = next(tr_it)
    if batch["id"][0] != 38:
        continue
    break

In [82]:
batch["logits"] = torch.from_numpy(test_outputs)

In [83]:
batch = {k:v.to("cuda:0") if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

In [95]:
net_output = batch

In [96]:
device = net_output["logits"].device
new_size = torch.tensor(cfg.new_size, device=net_output["logits"].device)
roi_size = torch.tensor(cfg.roi_size, device=net_output["logits"].device)

In [97]:
img: "torch.Tensor" = net_output["logits"].detach()

In [101]:
sm = img.softmax(1)

In [107]:
sm[:, 1, :].min()

tensor(0., device='cuda:0')

In [87]:
locations: "torch.Tensor" = net_output["location"]
scales: "torch.Tensor" = net_output["scale"]
tomo_ids: "torch.Tensor" = torch.tensor(net_output["id"], device=device)

In [88]:
img = F.interpolate(
    img,
    size=roi_size.tolist(),
    mode="trilinear",
    align_corners=False,
)

In [89]:
out_size = get_output_size(img, locations, roi_size, device)
rec_img = reconstruct(
    img=img,
    locations=locations,
    out_size=out_size,
    crop_size=roi_size,
    device=device,
)

In [90]:
s = torch.tensor(rec_img.shape[-3:], device=device)
delta = (s - new_size) // 2  # delta to remove padding added during transforms
dz, dy, dx = delta.tolist()
nz, ny, nx = new_size.tolist()

rec_img = rec_img[:, :, dz : nz + dz, dy : ny + dy, dx : nx + dx]

rec_img = F.interpolate(
    rec_img,
    size=[d // 2 for d in new_size.tolist()],
    mode="trilinear",
    align_corners=False,
)

In [91]:
preds: "torch.Tensor" = rec_img.softmax(1)

In [92]:
preds[:, 0, :].max()

tensor(1., device='cuda:0')

In [93]:
preds[:, 1, :].max()

tensor(1.3958e-19, device='cuda:0')

In [94]:
val_df[val_df["id"] == 38]

Unnamed: 0,row_id,tomo_id,z,y,x,axis0,axis1,axis2,vxs,n_motors,id,fold
45,45,tomo_0fe63f,197.0,362.0,265.0,300,960,928,13.1,1,38,0


In [79]:
val_df[val_df["n_motors"] == 2]

Unnamed: 0,row_id,tomo_id,z,y,x,axis0,axis1,axis2,vxs,n_motors,id,fold
127,127,tomo_2b3cdf,134.0,173.0,662.0,300,960,928,13.1,2,107,0
232,232,tomo_507b7a,451.0,561.0,366.0,500,928,960,13.1,2,200,0
233,233,tomo_507b7a,427.0,384.0,353.0,500,928,960,13.1,2,200,0
478,478,tomo_a84050,80.0,771.0,150.0,300,928,928,13.1,2,419,0
479,479,tomo_a84050,150.0,458.0,606.0,300,928,928,13.1,2,419,0


In [None]:
preds0 = preds[:, 1, :][None,]

In [77]:
preds[:, 0].min()

tensor(5.5334e-11, device='cuda:0')

In [88]:
nms: "torch.Tensor" = simple_nms(preds, nms_radius=100)  # (1,B, D, H, W)
nms = nms.squeeze(dim=0)  # (B, D, H, W)

In [None]:
nms.max()

In [87]:
nms[nms>0.99]

tensor([0.9918, 0.9989, 0.9984, 0.9913, 0.9997, 0.9918, 0.9900, 0.9958, 0.9979,
        0.9990, 0.9998, 0.9998, 0.9999, 0.9986, 1.0000, 0.9993, 0.9975, 1.0000,
        0.9960, 0.9985, 0.9990, 0.9989, 0.9995, 0.9993, 0.9996, 0.9921, 0.9991,
        0.9932, 0.9999, 0.9982, 0.9904, 0.9940, 0.9990, 0.9999, 0.9967, 0.9996,
        0.9995, 0.9981, 0.9987, 0.9943, 0.9952, 0.9988, 0.9968, 0.9992],
       device='cuda:0')