The goal of this notebook is to create a transform that randomly selects `N` slices (keeping bounding boxes in mind) to create a `Scan-crop`. We then pad the image to the minimum required number of slices and pass it to `RandPosCrop`.

In [None]:
#| default_exp tfsm/scan_crop

In [None]:
#| export 
import numpy as np 
import fastcore.all as fc

from typing import Optional
from voxdet.tfsm.voxt import RandPosCrop, pad3d
from voxdet.tfsm.standard import BaseT


## Solution
Our goal is to select slices such that at least one nodule is present, covering 75% of its slices (the fraction is configurable).
- Randomly select a nodule.
- Take all the slices of that nodule and pick 75% of them as a contiguous run. So if [1, 2, 3, 4] are its slices, we can take [1, 2, 3] or [2, 3, 4].
- Now select N slices (e.g. 10, 20, 30) such that [1, 2, 3] is included.
- After getting the final slices, realign the bboxes accordingly.

In [None]:
# Exploration-only imports (this cell is not marked for export).
from pathlib import Path
from safetensors.numpy import load_file
from voxdet.utils import vis

In [None]:
# Reload edited modules automatically while iterating in the notebook.
%load_ext autoreload
%autoreload 2

In [None]:
# Load one random sample from the fold_0 safetensors files.
# NOTE(review): hard-coded, user-specific data path — adjust for your machine.
imgs = list(Path("/home/users/vanapalli.prakash/safe_ds_new/fold_0/").glob("*.safetensors"))
img = load_file(imgs[np.random.randint(len(imgs))])
len(imgs), img.keys()

In [None]:
# Visual check of the raw volume (meaning of vis' second argument, 64, is not shown here — TODO confirm).
vis(img["images"], 64)

In [None]:
# A (32, 192, 192) crop via RandPosCrop.
rand = RandPosCrop(crop_size=(32, 192, 192))
rand

In [None]:
k = rand(img)

In [None]:
k["images"].shape, k["boxes"]

In [None]:
vis(k["images"], 64)

In [None]:
from voxdet.tfsm.voxt import PadIfNeeded

In [None]:
# Pad the crop up to (96, 192, 192), padding only on the right (end) side.
pad = PadIfNeeded(img_size=(96, 192, 192), side="right")
k2 = pad(k)

In [None]:
k2["images"].shape

In [None]:
vis(k2["images"], 64)

For some reason we can pad only on the "right". We can work around this as follows:
- Randomly select N.
- Use RandPosCrop to select N slices.
- (96-N) is the total number of slices we need to pad.
- Randomly choose how many of those (96-N) slices go on the left and "left pad" by that amount. The same offset is added to the z values of all bboxes.
- Use PadIfNeeded for the remaining slices.

In [None]:
k["images"].shape, k["boxes"]

In [None]:
# Random left pad for a 32-slice crop that should end up with 96 slices.
# NOTE(review): randint excludes its upper bound, so a left pad of exactly 96-32
# (crop flush against the end) is never drawn.
left_pad = np.random.randint(96-32)
print(left_pad)
# The boxes are passed to pad3d alongside the image so they can follow the left pad.
img, boxes = pad3d(k["images"], k["boxes"], pad=(left_pad, 0, 0), side="left")
img.shape, boxes

In [None]:
k["images"].shape[0]

In [None]:
vis(img, 64)

In [None]:
img.shape[0]

In [None]:
# Right-pad whatever is still missing to reach exactly 96 slices.
img2, boxes2 = pad3d(img, boxes, pad=(96-img.shape[0], 0, 0), side="right")
img2.shape

In [None]:
vis(img2, 64)

In [None]:
#| export 
def assymetric_z_pad(img, bbox, dim=96):
    """Pad `img` along z to exactly `dim` slices, at a random offset.

    A random number of slices is padded in front of the volume ("left"), and the
    remainder is padded at the end ("right"). The original slices can therefore
    land anywhere in the output, including flush against either end. Both pads go
    through `pad3d`, which also receives `bbox` so the boxes can be shifted to
    match (presumably by the left pad along z — see `pad3d`).

    Args:
        img: volume to pad. For 3D input, z is axis 0; otherwise z is axis 1
            (one leading axis, e.g. multi-view).
        bbox: boxes forwarded to `pad3d` together with the image.
        dim: target number of z slices.

    Returns:
        (img, bbox) as returned by the final `pad3d` call.

    Raises:
        AssertionError: if `img` already has more than `dim` slices along z.
    """
    def _z_len(arr):
        # 3D input is (z, y, x); otherwise z sits after one leading axis.
        return arr.shape[0] if len(arr.shape) == 3 else arr.shape[1]

    n_slices = _z_len(img)
    assert n_slices <= dim, f"img has {n_slices} slices along z, expected at most dim={dim}"
    # randint's upper bound is exclusive, hence +1. This makes every left pad in
    # 0..dim-n_slices reachable, including placing the crop flush at the far end.
    left_pad = np.random.randint(dim - n_slices + 1)
    img, bbox = pad3d(img, bbox, pad=(left_pad, 0, 0), side="left")
    # Whatever is still missing to reach `dim` goes on the right.
    img, bbox = pad3d(img, bbox, pad=(dim - _z_len(img), 0, 0), side="right")
    return img, bbox

In [None]:
k["images"].shape, k["boxes"]

In [None]:
# Same two-step (left then right) pad as above, via the exported helper.
img, bbox = assymetric_z_pad(k["images"], k["boxes"], dim=96)
img.shape, bbox

In [None]:
vis(img, 64)

In [None]:
# randint(15, 30) draws from [15, 29] — the same call pattern SliceCrop uses to pick a slice count.
x = np.random.randint(15, 30)
x

> Combining everything

In [None]:
#| export 
class SliceCrop(BaseT):
    """Keep a random number of z slices, then pad back to the original depth.

    `RandPosCrop` selects the slices (full y/x extent). `assymetric_z_pad` then pads
    the crop back to the input's z size at a random offset. Keys other than
    "images" and "boxes" are carried over unchanged from the cropped sample.
    """
    def __init__(self, min_slices=15, max_slices: Optional[int]=None, multi_view: bool=False):
        """Configure the slice crop.

        Args:
            min_slices: smallest number of slices to keep (inclusive).
            max_slices: exclusive upper bound on the number of slices kept. None,
                or any value above the volume's z size, means the z size.
            multi_view: if True, images carry one leading axis before (z, y, x).
        """
        fc.store_attr()
        super().__init__()
    __repr__ = fc.basic_repr("min_slices, max_slices, multi_view")

    def apply(self, img: dict):
        assert "images" in img.keys(), f"images not present in input [img]. Only: {img.keys()} present"
        if self.multi_view:
            _, zs, ys, xs = img["images"].shape
        else:
            zs, ys, xs = img["images"].shape
        # Cap the exclusive upper bound at zs. A crop as deep as the volume (or
        # deeper) cannot be padded back into zs slices. This matches the
        # max_slices=None default.
        high = zs if self.max_slices is None else min(self.max_slices, zs)
        slices = np.random.randint(self.min_slices, high)
        img = RandPosCrop(crop_size=(slices, ys, xs), multi_view=self.multi_view)(img)

        fimg = img["images"].copy()
        nimg = {}
        if "boxes" in img.keys():
            nimg["images"], nimg["boxes"] = assymetric_z_pad(fimg, img["boxes"].copy(), dim=zs)
        else:
            # assymetric_z_pad returns (img, bbox). Keep only the image, not the tuple.
            nimg["images"] = assymetric_z_pad(fimg, None, dim=zs)[0]
        # Carry every other key through from the cropped sample.
        for key, value in img.items():
            nimg.setdefault(key, value)
        return nimg

In [None]:
# One uniform draw in [0, 1) — the draw RandSliceCrop compares against `prob`.
np.random.sample(1)[0]

In [None]:
#| export
class RandSliceCrop(BaseT):
    """Apply `SliceCrop` with probability `prob`; otherwise return the sample untouched."""
    def __init__(self, min_slices=15, max_slices: Optional[int]=None, prob=0.9, multi_view: bool=False):
        """Configure the random slice crop.

        Args:
            min_slices: forwarded to `SliceCrop`.
            max_slices: forwarded to `SliceCrop`.
            prob: probability of applying the crop. 0 never applies it; 1 always does.
            multi_view: forwarded to `SliceCrop`.
        """
        fc.store_attr()
        super().__init__()
        self.func = SliceCrop(self.min_slices, self.max_slices, multi_view=self.multi_view)

    # Include multi_view, matching SliceCrop's repr.
    __repr__ = fc.basic_repr("min_slices, max_slices, prob, multi_view")

    def apply(self, img: dict):
        # The draw is uniform on [0, 1). A strict `<` makes the apply rate exactly
        # `prob`, so prob=0 can never fire.
        if np.random.sample(1)[0] < self.prob:
            return self.func(img)
        return img

In [None]:
# End-to-end check: a 96-slice crop, then keep 20..39 slices (max_slices is an
# exclusive bound) and pad back. prob=1.0, so the slice crop always fires.
rand = RandPosCrop(crop_size=(96, 192, 192))
sc = RandSliceCrop(max_slices=40, min_slices=20, prob=1.0)

In [None]:
%%time
# Time one load + crop + slice-crop round trip.
img = load_file(imgs[np.random.randint(len(imgs))])
img2 = rand(img)
print(img2["images"].shape)
img3 = sc(img2)

In [None]:
vis(img2["images"], 64)

In [None]:
img2["boxes"], img3["boxes"]

In [None]:
# Number of z slices with a nonzero in-plane sum, i.e. slices not blanked by
# padding (assumes padding fills with 0 — TODO confirm).
img3["images"].sum((1, 2)).nonzero()[0].shape

In [None]:
vis(img3["images"], 64)

In [None]:
import imageio
from IPython.display import Image as DisplayImage
from voxdet.utils import hu_to_lung_window


In [None]:
# First box ([:2, :][0] is simply row 0). The slicing below implies layout (z1, y1, x1, z2, y2, x2).
box = img3["boxes"][:2, :][0].astype(int)
# Crop around the box with a 10-voxel in-plane margin; the z end is inclusive (+1).
bimg = img3["images"][box[0]:box[3]+1, max(0, box[1]-10):box[4]+10, max(box[2]-10, 0):box[5]+10]
# Lung-window and scale to uint8 (assumes hu_to_lung_window returns values in [0, 1] — TODO confirm).
bimg = np.uint8(hu_to_lung_window(bimg)*255)
# Write one GIF frame per z slice and display it inline.
imageio.mimsave('sld_3.gif', [i for i in bimg])
DisplayImage(data='sld_3.gif', width=180, height=180) 

In [None]:
img3["images"].shape

In [None]:
## Checking whether this works for multi-view

In [None]:
rand = RandPosCrop(crop_size=(96, 192, 192), multi_view=True)
sc = RandSliceCrop(max_slices=40, min_slices=20, prob=0.5, multi_view=True)

In [None]:
%%time
img = load_file(imgs[np.random.randint(len(imgs))])
# Add a leading axis so the images are 4D, as multi_view=True expects.
img["images"] = np.expand_dims(img["images"], 0)
print(img["images"].shape, img["boxes"].shape)
img2 = rand(img)
print(img2["images"].shape)
img3 = sc(img2)

In [None]:
#| hide
# Write this notebook's `#| export` cells out to the module named by `#| default_exp`.
import nbdev; nbdev.nbdev_export()