In [19]:
import os
from glob import glob
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import torch
from torchvision.io import decode_jpeg, read_file
from tqdm.auto import tqdm
from app.models.lightning import Net
from omegaconf import OmegaConf
from app.trainers import get_lightning_trainer
from app.utils import get_callbacks, get_data, get_data_loader, set_seed
from app.models import LNet
from app.processings import post_process_pipeline
from app.processings.post_processing import get_output_size, reconstruct, simple_nms

In [32]:
import torch.nn.functional as F

In [9]:
# TPU / XLA runtime settings for this session, applied in one place so the
# whole environment can be read (and tuned) as a single mapping.
os.environ.update(
    {
        "ISTPUVM": "1",
        "PJRT_DEVICE": "TPU",
        "PT_XLA_DEBUG_LEVEL": "1",
        "TF_CPP_MIN_LOG_LEVEL": "2",
        "TPU_ACCELERATOR_TYPE": "v3-8",
        "TPU_CHIPS_PER_HOST_BOUNDS": "2,2,1",
        "TPU_HOST_BOUNDS": "1,1,1",
        "TPU_RUNTIME_METRICS_PORTS": "8431,8432,8433,8434",
        "TPU_SKIP_MDS_QUERY": "1",
        "TPU_WORKER_HOSTNAMES": "localhost",
        "TPU_WORKER_ID": "0",
        "XLA_TENSOR_ALLOCATOR_MAXSIZE": "100000000",
    }
)

In [10]:
# Load the project configuration from its YAML file (path is relative to the
# notebook's working directory).
cfg = OmegaConf.load("../src/app/config/config.yaml")

In [11]:
# Fetch the train/validation frames, then wrap each one in its data loader.
train_df, val_df = get_data(cfg, mode="fit")
train_loader, val_loader = (
    get_data_loader(cfg, frame, mode=split)
    for frame, split in ((train_df, "train"), (val_df, "validation"))
)

In [14]:
import datetime

In [15]:
# Each run gets its own timestamped directory:
# <output_dir>/<backbone>/seed_<seed>/fold<fold>/<timestamp>
start_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
run_dir_parts = (
    cfg.output_dir,
    cfg.backbone,
    f"seed_{cfg.seed}",
    f"fold{cfg.fold}",
    f"{start_time}",
)
cfg.default_root_dir = os.path.join(*run_dir_parts)
os.makedirs(cfg.default_root_dir, exist_ok=True)

# Keyword arguments for the backbone, collected from the flat config entries.
cfg.backbone_args = {
    "spatial_dims": cfg.spatial_dims,
    "in_channels": cfg.in_channels,
    "out_channels": cfg.n_classes,
    "backbone": cfg.backbone,
    "pretrained": cfg.pretrained,
}

In [None]:
# NOTE: `batch` is only created further down via `next(iter(val_loader))`.
# Displaying it in this cell raised NameError on Restart & Run All, so the
# bare `batch` expression was removed; inspect the batch after it is created.

In [16]:
# Build the network from the configuration prepared above (including
# cfg.backbone_args).
model = Net(cfg)

In [12]:
# Take the first batch produced by the validation loader as a sample input.
batch = next(iter(val_loader))

In [22]:
# Inference only: switch the model to eval mode and disable autograd so the
# forward pass does not build (and keep alive) a computation graph.
model.eval()
with torch.no_grad():
    output = model(batch)

In [None]:
def post_process_pipeline(
    cfg: "DictConfig", net_output: "dict[str, Any]"
) -> "torch.Tensor":
    """Post-process the output of the model to get the final coordinates and confidence scores.

    NOTE(review): this notebook-local definition shadows the
    ``post_process_pipeline`` imported from ``app.processings`` at the top of
    the notebook. Port the fixes back to the package once they are validated.

    Pipeline:
        1. Resize the logits to ``cfg.roi_size`` and stitch them with
           ``reconstruct`` using the patch ``location`` entries.
        2. Crop the stitched volume to ``cfg.new_size`` and resample it to
           half that size.
        3. Softmax over channels, keep channel 0, apply ``simple_nms``.
        4. Take the ``cfg.topk`` strongest voxels per volume and map their
           indices back to coordinates (x2 for the half-size resample, then
           divided by the per-volume ``scale``).

    Args:
        cfg (DictConfig): Configuration object; reads ``new_size``,
            ``roi_size``, ``nms_radius`` and ``topk``.
        net_output (dict): The output of the model; reads the keys
            ``"logits"``, ``"location"``, ``"scale"`` and ``"id"``.

    Returns:
        torch.Tensor: One row per detection, shape ``(B * topk, 5)`` with
        columns ``z, y, x, id, confidence``, where ``B`` is the leading
        dimension of the reconstructed volume. The dtype is whatever
        ``torch.cat`` promotes the int coordinates/ids and float32
        confidences to.
    """
    logits: "torch.Tensor" = net_output["logits"]
    device = logits.device
    new_size = torch.tensor(cfg.new_size, device=device)
    roi_size = torch.tensor(cfg.roi_size, device=device)

    img: "torch.Tensor" = logits.detach()

    locations: "torch.Tensor" = net_output["location"]
    scales: "torch.Tensor" = net_output["scale"]
    tomo_ids: "torch.Tensor" = torch.tensor(net_output["id"], device=device)

    img = F.interpolate(
        img,
        size=roi_size.tolist(),
        mode="trilinear",
        align_corners=False,
    )

    out_size = get_output_size(img, locations, roi_size, device)
    rec_img = reconstruct(
        img=img,
        locations=locations,
        out_size=out_size,
        crop_size=roi_size,
        device=device,
    )

    s = torch.tensor(rec_img.shape[-3:], device=device)
    delta = (s - new_size) // 2  # delta to remove padding added during transforms
    dz, dy, dx = delta.tolist()
    nz, ny, nx = new_size.tolist()

    rec_img = rec_img[:, :, dz : nz + dz, dy : ny + dy, dx : nx + dx]

    # Peaks are searched at half resolution; the factor 2 is undone below.
    rec_img = F.interpolate(
        rec_img,
        size=[d // 2 for d in new_size.tolist()],
        mode="trilinear",
        align_corners=False,
    )

    preds: "torch.Tensor" = rec_img.softmax(1)
    # Keep channel 0 but preserve the channel axis -> (B, 1, D, H, W).
    # The previous `preds[:, 0, :][None,]` produced (1, B, D, H, W), which is
    # only equivalent when B == 1.
    preds = preds[:, 0:1]
    nms: "torch.Tensor" = simple_nms(preds, nms_radius=cfg.nms_radius)  # (B,1, D, H, W)
    nms = nms.squeeze(dim=1)  # (B, D, H, W)

    flat_nms = nms.reshape(nms.shape[0], -1)  # (B, D*H*W)
    conf, indices = torch.topk(flat_nms, k=cfg.topk, dim=1)
    zyx = torch.stack(torch.unravel_index(indices, nms.shape[-3:]), dim=-1)  # (B, K, 3)

    n_batch, n_peaks = indices.shape
    # Batch index of every detection as a flat (B*K,) vector. Indexing with a
    # 1-D tensor keeps exactly one row per detection; the previous (B*K, 1)
    # index added an extra axis that made per-axis scales broadcast the
    # coordinates to (B*K, B*K, 3).
    b = (
        torch.arange(n_batch, device=device)
        .unsqueeze(1)
        .expand(n_batch, n_peaks)
        .reshape(-1)
    )
    zyx = zyx.reshape(-1, 3)

    # Flatten any trailing dims so scalar, (1,) or (3,) per-volume scales all
    # broadcast against the (B*K, 3) coordinates.
    per_detection_scale = scales[b].reshape(b.shape[0], -1)
    zyx = ((zyx * 2) / per_detection_scale).round().to(torch.int)

    ids: "torch.Tensor" = tomo_ids[b].reshape(-1, 1)
    conf = conf.to(torch.float32).reshape(-1, 1)

    output: "torch.Tensor" = torch.cat([zyx, ids, conf], dim=1)
    return output

In [25]:
# Everything downstream lives on the same device as the logits.
device = output["logits"].device
new_size, roi_size = (
    torch.tensor(size, device=device) for size in (cfg.new_size, cfg.roi_size)
)

In [26]:
# Work on a copy of the logits that is detached from the autograd graph.
img: "torch.Tensor" = output["logits"].detach()

In [27]:
# Per-sample metadata returned alongside the logits.
locations, scales = output["location"], output["scale"]
# The ids are converted to a tensor on `device` so they can be indexed later.
tomo_ids: "torch.Tensor" = torch.tensor(output["id"], device=device)

In [33]:
# Resample the logits to the ROI size with trilinear interpolation.
roi_shape = roi_size.tolist()
img = F.interpolate(img, size=roi_shape, mode="trilinear", align_corners=False)

In [34]:
# Work out the size of the stitched volume, then rebuild it from the patches.
out_size = get_output_size(img, locations, roi_size, device)
reconstruct_kwargs = dict(
    img=img, locations=locations, out_size=out_size, crop_size=roi_size, device=device
)
rec_img = reconstruct(**reconstruct_kwargs)

In [36]:
# Offsets that remove the padding added during transforms: half of the size
# surplus along each spatial axis.
s = torch.tensor(rec_img.shape[-3:], device=device)
delta = (s - new_size) // 2
(dz, dy, dx), (nz, ny, nx) = delta.tolist(), new_size.tolist()

z_win = slice(dz, nz + dz)
y_win = slice(dy, ny + dy)
x_win = slice(dx, nx + dx)
rec_img = rec_img[:, :, z_win, y_win, x_win]

In [38]:
# Resample the cropped volume to half of `new_size` along every axis.
half_size = [dim // 2 for dim in new_size.tolist()]
rec_img = F.interpolate(rec_img, size=half_size, mode="trilinear", align_corners=False)

In [39]:
# Class probabilities over the channel axis.
preds: "torch.Tensor" = rec_img.softmax(1)
# Keep channel 0 while preserving the channel axis -> (B, 1, D, H, W).
# The previous `preds[:, 0, :][None,]` gave (1, B, D, H, W), which only matches
# this layout when B == 1.
preds = preds[:, 0:1]

In [41]:
# Non-maximum suppression on the (B, 1, D, H, W) map, then drop the singleton
# channel axis to get (B, D, H, W).
nms: "torch.Tensor" = simple_nms(preds, nms_radius=cfg.nms_radius).squeeze(dim=1)

In [68]:
# Flatten each volume, pick its `cfg.topk` strongest voxels, and turn the flat
# indices back into (z, y, x) triplets.
flat_nms = nms.reshape(len(nms), -1)  # (B, D*H*W)
conf, indices = flat_nms.topk(cfg.topk, dim=1)
zyx_components = torch.unravel_index(indices, nms.shape[-3:])
zyx = torch.stack(zyx_components, dim=-1)  # (B, K, 3)

In [69]:
# Sanity check: expected layout is (B, K, 3).
zyx.shape

torch.Size([1, 20, 3])

In [60]:
# One index per volume along the leading axis of `zyx`.
b = torch.arange(zyx.shape[0], device=device)

In [66]:
# Undo the half-resolution resampling (x2) and the per-volume scale.
# `scales[b]` is reshaped to (B, 1, -1) so it broadcasts over the K detections
# of each volume; the previous `scales[b.squeeze()]` only lined up for B == 1.
# NOTE: this cell overwrites `zyx`, so re-running it rescales the coordinates
# again. Re-run the top-k cell first to start from fresh indices.
zyx = ((zyx * 2) / scales[b].reshape(b.shape[0], 1, -1)).round().to(torch.int)

In [67]:
# Sanity check: rescaling must leave the (B, K, 3) layout unchanged.
zyx.shape

torch.Size([1, 20, 3])

In [None]:
# Repeat each volume index once per detection -> (B, topk).
volume_index = torch.arange(zyx.shape[0], device=device)
b = volume_index[:, None].expand(-1, cfg.topk)

In [None]:

# Flatten everything to one row per detection and join the columns as
# (z, y, x, id, confidence).
b = b.to(torch.long)
ids: "torch.Tensor" = tomo_ids[b].reshape(-1, 1)
conf = conf.to(torch.float32).reshape(-1, 1)
zyx = zyx.reshape(-1, 3)

# NOTE(review): this rebinds `output` from the model's output dict to the
# final tensor, so the earlier cells that read output["logits"] cannot be
# re-run without first re-running the forward pass.
output: "torch.Tensor" = torch.cat([zyx, ids, conf], dim=1)

torch.Size([400, 3])