In [1]:
#| default_exp infer

In [2]:
# Auto-reload edited modules (e.g. voxdet) before each cell executes, so library edits
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
#| export 
import torch
from torchvision.transforms import Compose 
from voxdet.networks.monai_retina3d import retina_detector
from munch import munchify
from voxdet.bbox_func.nms import monai_nms
import numpy as np

In [4]:
import imageio
import numpy as np
import fastcore.all as fc
from pathlib import Path
from IPython.display import Image as DisplayImage
from voxdet.utils import hu_to_lung_window
from voxdet.utils import vis, load_sitk_img

In [5]:
#| export 
def subset_cfg_for_infer(cfg):
    """Return a new plain dict with only the config entries inference needs.

    Keys of `cfg` outside the whitelist below are dropped; whitelisted keys that
    `cfg` does not contain are simply absent. Entries keep `cfg`'s iteration order.
    """
    keep = (
        "anchor_params", "resolution", "classes", "spatial_size", "roi_size",
        "infer_cfg", "infer_thr", "fe", "test_transforms", "fpn_params", "model_cfg",
    )
    subset = {}
    for key, value in cfg.items():
        if key in keep:
            subset[key] = value
    return subset

In [6]:
#| export 
def load_model(path, map_device=torch.device("cpu")):
    """Load a RetinaNet detector checkpoint for inference.

    Args:
        path: checkpoint file containing a "cfg" entry (the training config) and a
            "state_dict" entry (the model weights).
        map_device: device onto which stored tensors are mapped while loading.

    Returns:
        (model, cfg, transforms): the detector in eval mode, the config converted with
        `munchify`, and a `Compose` built from `cfg.test_transforms`.

    Warning:
        The checkpoint is fully unpickled (weights_only=False) because `cfg` stores
        transform objects, not only tensors. Only load checkpoints you trust.
    """
    import inspect
    import warnings

    load_kwargs = {"map_location": map_device}
    # torch >= 2.6 defaults to weights_only=True, which rejects the pickled transform
    # objects stored in cfg; torch < 1.13 does not accept the argument at all.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    data = torch.load(path, **load_kwargs)
    cfg = munchify(data["cfg"])

    transforms = Compose(list(cfg.test_transforms))
    model = retina_detector(cfg)
    # strict=False tolerates key mismatches; report them instead of hiding them, since
    # any layer that is skipped here keeps its freshly initialised weights.
    result = model.load_state_dict(data["state_dict"], strict=False)
    missing = list(getattr(result, "missing_keys", None) or [])
    unexpected = list(getattr(result, "unexpected_keys", None) or [])
    if missing or unexpected:
        warnings.warn(
            f"load_model({path}): {len(missing)} missing keys {missing[:5]}, "
            f"{len(unexpected)} unexpected keys {unexpected[:5]}",
            stacklevel=2,
        )
    model = model.eval()
    return model, cfg, transforms

In [None]:
# Single example study for a quick look. Bound to its own name rather than `series`:
# a later cell rebinds `series` to a list of paths, and sharing the name would make
# `series[0]` a single character if this cell were re-run afterwards.
sample_series_path = "/cache/datanas1/qct-nodules/studies_nifti/WCG/1.3.6.1.4.1.55648.166786657465154199470575722567012949663.3.nii.gz"
series_id = Path(sample_series_path).name[: -len(".nii.gz")]  # file name without the extension
oimg = load_sitk_img(sample_series_path, series_id)
oimg["images"].shape, oimg["spacing"]

In [None]:
# Preview the loaded volume with voxdet's `vis` helper (the meaning of the positional
# 64 and window=64 arguments is defined in voxdet.utils.vis, not visible here).
vis(oimg["images"], 64, window=64)

In [None]:
#| export 
class RetinaInfer:
    """Run a trained RetinaNet checkpoint on one volume at a time.

    Calling an instance applies `cfg.test_transforms` to the input dict, runs the model,
    optionally applies NMS, merges the prediction entries into the dict and maps the
    result back through each transform's `reverse_apply` (last transform first).
    With `inf_safe=True` the input is treated as already preprocessed: neither the
    forward transforms nor their reverse are applied.
    """

    def __init__(self, checkpoint_path: str, device: str = None, inf_safe: bool = False):
        """
        Args:
            checkpoint_path: checkpoint file readable by `load_model`.
            device: torch device string; None selects "cuda" if available, else "cpu".
            inf_safe: skip forward and reverse transforms (input already preprocessed).
        """
        self.model, self.cfg, self.transforms = load_model(path=checkpoint_path)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model.to(self.device)
        self.inf_safe = inf_safe

    @torch.no_grad()
    def __call__(self, img: dict, nms_thr=0.2, cnf_thr=0.05):
        """Detect objects in a single volume.

        Args:
            img: dict with an "images" array (3-D, or 4-D with a leading extra axis)
                plus whatever keys the test transforms need. It is not modified.
            nms_thr: IoU threshold forwarded to `monai_nms`; None skips NMS.
            cnf_thr: confidence threshold forwarded to `monai_nms`.

        Returns:
            A new dict: the (transformed) input entries updated with the prediction
            entries, mapped back through `reverse_apply` unless `inf_safe`.
        """
        # Shallow copy so neither the transforms nor `update` below touch the caller's dict.
        nimg = img.copy()
        if not self.inf_safe:
            nimg = self.transforms(nimg)

        images = torch.as_tensor(nimg["images"], dtype=torch.float32, device=self.device)
        # Prepend a batch axis; a 3-D volume also gets a channel axis, so the model always
        # receives 5-D input (presumably (B, C, D, H, W) — confirm against retina_detector).
        if images.ndim == 4:
            images = images.unsqueeze(0)
        else:
            images = images.unsqueeze(0).unsqueeze(0)

        # [0]: outputs for the first (only) batch item; the original note said
        # "hardcoded for one class" — TODO confirm which axis this selects.
        logits = self.model(images, None, use_inferer=True)[0]
        if nms_thr is not None:
            logits = monai_nms(logits, nms_thr, cnf_thr)
        nimg.update(logits)
        if self.inf_safe:
            return nimg
        return self.reverse_apply(nimg)

    def reverse_apply(self, img):
        """Undo the test transforms: call each transform's `reverse_apply`, last first.

        Works on a shallow copy of `img` and returns the final result.
        """
        out = img.copy()
        for tfsm in reversed(self.transforms.transforms):
            out = tfsm.reverse_apply(out)
        return out

In [None]:
# Detector checkpoint, relative to the notebook's working directory. Note the "/" inside
# "val/AP=0.638.ckpt": "epoch=224-step=22950-val" is a directory, not part of the file name.
CHECKPOINT_DET = (
    "../lightning_logs/v150/version_6/checkpoints/"
    "epoch=224-step=22950-val/AP=0.638.ckpt"
)
infer = RetinaInfer(CHECKPOINT_DET)

In [None]:
from voxdet.tfsm.med import AddLungCache
# TODO: machine-specific absolute paths; move them to a config cell / environment variable.
LUNG_MASK_CACHE_DIR = "/cache/datanas1/qct-nodules/nifti_with_annots/lung_mask_cache/"
LUNG_MASK_CKPT = "/home/users/vanapalli.prakash/repos/qct_nodule_detection/resources/unet_r231-d5d2fc3d_v0.0.1.pth"
# Insert lung-mask caching as the second test transform (index 1). The guard keeps the
# cell idempotent: without it every re-run adds another AddLungCache to the live pipeline.
if not any(isinstance(t, AddLungCache) for t in infer.transforms.transforms):
    infer.transforms.transforms.insert(
        1, AddLungCache(cache_dir=LUNG_MASK_CACHE_DIR, model_ckpt=LUNG_MASK_CKPT)
    )

In [None]:
# Display the test-transform pipeline that `infer` currently applies.
infer.transforms.transforms

In [None]:
path = Path("/cache/datanas1/qct-nodules/nifti_with_annots/medframe/")
# Sorted because glob order depends on the filesystem, and later cells pick `series[0]`.
series = fc.L(sorted(path.glob("*.nii.gz")))
series

In [None]:
# Disabled batch run: count detections with score > 0.9 for every path in `series`.
# nodules = []
# for s_ in series:
#     img = load_sitk_img(s_, s_.name[:-7])
#     nimg = infer(img)
#     nodules.append((s_.name[:-7], len(nimg["boxes"][nimg["scores"]>0.9])))

In [None]:
%%time
# Run `infer` on the first listed series: forward transforms, model, NMS (default
# nms_thr=0.2) and reverse transforms. The id is the file name minus ".nii.gz" (7 chars).
img = load_sitk_img(series[0], series[0].name[:-7])
nimg = infer(img)

In [None]:
# Compare the loaded input (img) with the inference output (nimg): mask/image shapes and
# spacings ("lung_mask" presumably comes from the inserted AddLungCache transform).
nimg["lung_mask"].shape, img["images"].shape, nimg["images"].shape, img["spacing"], nimg["spacing"]

In [None]:
# Output image shape, number/layout of predicted boxes, and their scores.
nimg["images"].shape, nimg["boxes"].shape, nimg["scores"]

In [None]:
from voxdet.retina_test import convert2int, draw_on_ct
from qct_utils.ctvis.viewer import plot_scans

In [None]:
# View the output volume; [0] takes the first entry of the leading axis
# (presumably the channel axis — confirm).
plot_scans([nimg["images"][0]], ["scan"])

In [None]:
#nimg["images"].shape, nimg["boxes"].shape, nimg["scores"]

In [None]:
# Keep detections scoring above 0.9, convert their boxes to integers and draw them on the
# scan as loaded (img). The boxes come back from `infer` after reverse_apply, so they are
# presumably in img's coordinate frame — confirm.
boxes = convert2int(nimg["boxes"][nimg["scores"]>0.9])
timg = img["images"]
dimg = draw_on_ct(timg, boxes)

In [None]:
#boxes

In [None]:
# View the scan with the high-confidence boxes drawn in.
plot_scans([dimg], ["scan"])

In [None]:
# Crop a padded window around one detection (row 7 of the model output) and save /
# display it as a slice-by-slice GIF. Columns 0-2 of a box are start coordinates and
# columns 3-5 the matching end coordinates, as the slicing below uses them.
box = nimg["boxes"][7].astype(int)
pad = 10  # in-plane context (voxels) added on each side of the box
# Clamp starts at 0: a negative start would wrap to the far end of the axis (numpy
# negative indexing) and give a wrong or empty crop for boxes near the border.
bimg = img["images"][
    max(box[0], 0):box[3],
    max(box[1] - pad, 0):box[4] + pad,
    max(box[2] - pad, 0):box[5] + pad,
]
bimg = np.uint8(hu_to_lung_window(bimg) * 255)
imageio.mimsave("sld_3.gif", list(bimg))
DisplayImage(data="sld_3.gif", width=180, height=180)

In [None]:
# Shape of the image returned by `infer`.
nimg["images"].shape

In [None]:
# Preview the returned image, scaled by 255 (presumably because it is normalised to
# [0, 1] at this point — confirm); window=False per voxdet.utils.vis.
vis(nimg["images"]*255, 64, window=False)

In [None]:
#| hide
import nbdev
nbdev.nbdev_export()  # write the `#| export` cells to the `infer` module