# Example notebook: U-Net segmentation model from the Bachelor project

Requirements:
* The computer needs a lot of memory (preferably the DTU HPC on a100sh)
* Plotly visualisations need to be viewed locally, as ThinLinc does not render them well
* Run the notebook from the repository's base folder (all relative paths start there)

In [2]:
###
### Original script from: Vedrana Andersen Dahl
###

import nibabel as nib
import numpy as np
import skimage.measure
import torch
from monai.data import DataLoader, Dataset, decollate_batch
from monai.inferers import sliding_window_inference
from monai.transforms import (Activationsd, AsDiscreted, Compose,
                              CropForegroundd, EnsureChannelFirstd, Invertd,
                              LoadImaged, Orientationd, SaveImaged,
                              ScaleIntensityRanged, Spacingd)
from skimage.transform import rescale
from tqdm.notebook import tqdm
from skimage.measure import label   

import notebooks.volvizplotly as vvp
from src.models.unet_model import load_unet


In [3]:
# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

## Predicting on rat kidney

Loads the image with the selected preprocessing and runs inference.
The prediction mask is then saved in the original image space (the preprocessing is inverted before saving).

In [5]:
def select_kidney(x):
    """Return an element-wise mask that is True where `x` equals label 1.

    Presumably label 1 marks the kidney in the mask volume (per the name) — confirm.
    """
    kidney_label = 1
    return x == kidney_label

# Inference-time preprocessing for the rat kidney volume. The same Compose object is
# passed to Invertd in post_transforms_rats so the prediction can be mapped back.
test_org_transforms_rats = Compose(
    [
        LoadImaged(keys=["image", "mask"]),
        EnsureChannelFirstd(keys=["image", "mask"]),
        Orientationd(keys=["image",'mask'], axcodes="RAS"),
        # Isotropic resampling: bilinear for the image, nearest for the mask so its labels stay discrete.
        # NOTE(review): spacing units come from the NIfTI header — confirm they match the training data.
        Spacingd(keys=["image", "mask"], pixdim=(0.0226, 0.0226, 0.0226), mode=("bilinear", "nearest")),
        # Crop only the image to the bounding box of the mask's foreground.
        # NOTE(review): select_kidney is not passed as select_fn, so MONAI's default
        # foreground test is used (any positive mask value) — confirm this is intended.
        CropForegroundd(keys=["image"], source_key="mask"), 
        # Linearly map intensities in [a_min, a_max] to [b_min, b_max], clipping values outside.
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57, # range the model was trained with; ideally retrain on the native 0-255 range, but it is too late now.
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
    ]
)

# Post-processing applied to each decollated sample: map the prediction back onto the
# original image grid, turn it into a label map and save it to disk.
post_transforms_rats = Compose(
    [
        # Undo test_org_transforms_rats on "pred", using the transforms recorded on "image".
        # nearest_interp=False: inversion uses the forward transforms' interpolation modes
        # instead of nearest-neighbour, so the raw output is interpolated before the argmax below.
        Invertd(
            keys="pred",
            transform=test_org_transforms_rats,
            orig_keys="image",
            meta_keys="pred_meta_dict",
            orig_meta_keys="image_meta_dict",
            meta_key_postfix="meta_dict",
            nearest_interp=False,
            to_tensor=True,
        ),
        # Softmax over the channel dimension, then argmax -> single-channel label map.
        Activationsd(keys="pred", softmax=True), 
        AsDiscreted(keys="pred", argmax=True),
        # Write the label map under 'notebooks' with the "seg" postfix; resample=False means no
        # extra resampling on save (presumably because Invertd already restored the original grid).
        SaveImaged(keys="pred", meta_keys="pred_meta_dict", output_dir='notebooks', output_postfix="seg", resample=False),
    ]
)

def save_prediction_masks(model, test_org_loader, post_transforms, device = 'cpu',
                          roi_size=(160, 160, 160), sw_batch_size=8):
    """
    Run sliding-window inference on every batch of test_org_loader and post-process it.

    For each batch, the network output is stored under the "pred" key. The batch is split
    into per-sample dicts with decollate_batch, and post_transforms is called on each
    sample. In this notebook, post_transforms inverts the preprocessing and saves the mask.

    Args:
        model: segmentation network; it is switched to eval mode here.
        test_org_loader: iterable of batch dicts containing an "image" entry.
        post_transforms: callable applied to every decollated sample dict.
        device: device the sliding windows are run on (passed as sw_device).
        roi_size: spatial size of each sliding window.
        sw_batch_size: number of windows per forward pass; lower it on GPUs with less memory.

    Returns:
        None. Results are produced only through post_transforms' side effects.
    """
    model.eval()

    with torch.no_grad():
        for test_data in test_org_loader:
            # The full volume is not moved to `device` (the .to(device) call was disabled);
            # sw_device moves only the windows, so the whole volume never has to fit on the GPU.
            test_inputs = test_data["image"]
            # Slow: around 20-30 min on a GPU for a full volume.
            print('Running inference on test data')
            test_data["pred"] = sliding_window_inference(
                test_inputs, roi_size, sw_batch_size, model, sw_device=device, progress=True
            )
            print('Saving prediction masks')
            # Called for its side effects; results are not kept, so each sample can be freed.
            for sample in decollate_batch(test_data):
                post_transforms(sample)

In [6]:
# Load the fine-tuned U-Net checkpoint onto `device`.
model_load_path = 'models/finetune-kfold/model_16742264.pth'
model, params = load_unet(model_load_path, device=device)

# Volume to segment and the mask used as the crop source in test_org_transforms_rats.
# NOTE(review): the `train_` prefix looks historical — here these are the inference inputs.
train_images = ['/dtu/3d-imaging-center/projects/2020_QIM_22_rat_kidney/analysis/analysis_rat37/rat37_reorient.nii.gz']
train_masks = ['/dtu/3d-imaging-center/projects/2020_QIM_22_rat_kidney/analysis/study_diabetic/aligned/rat37_aligned_rigid.nii']
data_dicts = [{"image": image_name, "mask": train_mask} for image_name, train_mask in zip(train_images, train_masks)]

# monai.data.Dataset applies its transforms lazily in __getitem__: the slow loading and
# preprocessing actually happen when the loader is iterated inside save_prediction_masks.
print('Creating dataset (volumes are loaded and preprocessed lazily during inference)')
test_org_ds = Dataset(data=data_dicts, transform=test_org_transforms_rats)
test_org_loader = DataLoader(test_org_ds, batch_size=1)

save_prediction_masks(model, test_org_loader, post_transforms_rats, device)

Loading data (slow)
Finished loading data
Running inference on test data


100%|██████████████████████████████████████████████████████████████████████████████████████████| 81/81 [18:10<00:00, 13.46s/it]


2023-06-17 10:26:16,510 INFO image_writer.py:194 - writing: notebooks/rat37_reorient/rat37_reorient_seg.nii.gz


## Visualise the prediction
Open the HTML output in Firefox on a local computer;
rendering is slow over ThinLinc.

In [15]:
# Camera settings passed to vvp.show_mesh for the overlay figure.
camera = {
    'up': {'x': 0, 'y': 0, 'z': 1},
    'center': {'x': 0, 'y': 0, 'z': 0},
    'eye': {'x': -.6, 'y': 2.1, 'z': .3},
}

# Downscale factor for the plotted volumes (performance); try 0.25 for a lighter figure.
downscale_coeff = .5
anti_aliasing = False
# True: use a plain intensity threshold (> 120) on the CT volume instead of the saved prediction.
threshold = False

In [10]:
# Load the original CT volume, the kidney mask, the reference vessel labels
# and the prediction written by save_prediction_masks.
data_dir = '/dtu/3d-imaging-center/projects/2020_QIM_22_rat_kidney/analysis'
train_images = f'{data_dir}/analysis_rat37/rat37_reorient.nii.gz'
train_labels = f'{data_dir}/analysis_rat37/vessel_zoom_ground_truth-ish_rat37_v2.nii.gz'
train_masks = f'{data_dir}/study_diabetic/maskKidney/rat37_kidneyMaskProc.nii.gz'

img = nib.load(train_images)
img_mask = nib.load(train_masks)
lab_true = nib.load(train_labels)
lab_pred = nib.load('notebooks/rat37_reorient/rat37_reorient_seg.nii.gz')


def fdata_axes_reversed(nifti_image):
    """Return the image's voxel data (float, via get_fdata) with its three axes reversed.

    np.asarray avoids the extra full-volume copy that np.array() made; the result is a
    transposed view of the array returned by get_fdata, so it must not be modified in place.
    """
    return np.asarray(nifti_image.get_fdata()).transpose((2, 1, 0))


# convert to arrays
vol_fdata = fdata_axes_reversed(img)
vol_mask_fdata = fdata_axes_reversed(img_mask)
seg_true_fdata = fdata_axes_reversed(lab_true)

if threshold:
    # Simple baseline instead of the network output: intensity threshold at 120.
    lab_pred_fdata = vol_fdata > 120
else:
    lab_pred_fdata = fdata_axes_reversed(lab_pred)

In [25]:
# Downscale all volumes for visualisation. The CT volume uses the default interpolation.
# The mask and label volumes use nearest-neighbour (order=0) without anti-aliasing so they
# keep discrete label values. Linear interpolation would create fractional values at
# boundaries, which the later `!= 0` (grows) and `== 1` (shrinks) tests treat inconsistently.
vol = rescale(vol_fdata, downscale_coeff, anti_aliasing=anti_aliasing)
print('1/4 rescale')
vol_mask = rescale(vol_mask_fdata, downscale_coeff, order=0, anti_aliasing=False)
print('2/4 rescale')
seg_true = rescale(seg_true_fdata, downscale_coeff, order=0, anti_aliasing=False)
print('3/4 rescale')
lab_pred = rescale(lab_pred_fdata, downscale_coeff, order=0, anti_aliasing=False)
print('4/4 rescale')

1/4 rescale
2/4 rescale
3/4 rescale
4/4 rescale


In [26]:
# Surface meshes (marching cubes) of the reference labels, the prediction and their overlap.
# All inputs are boolean masks. The iso-level 0.5 lies halfway between background (0) and
# foreground (1), and the same level is used for all three so the surfaces are comparable.
mesh_level = 0.5
true_mask = seg_true != 0
pred_mask = lab_pred == 1
verts1, faces1, _, _ = skimage.measure.marching_cubes(true_mask, mesh_level)
verts2, faces2, _, _ = skimage.measure.marching_cubes(pred_mask, mesh_level)
verts3, faces3, _, _ = skimage.measure.marching_cubes(true_mask & pred_mask, mesh_level)

### Save prediction, label, and slice to file

In [27]:
# Overlay figure: the middle CT slice plus three surfaces — reference labels (red) and
# prediction (green), both semi-transparent, and their overlap (blue, opaque).
fig = vvp.volume_slicer(
    vol, [None, 'mid', None],
    title='middle ct-slice', width=1200, height=1200, show=False,
)
fig = vvp.show_mesh(
    verts1, faces1, fig=fig, show=False, camera=camera,
    surface_color='red', surface_opacity=.5, wireframe_opacity=.5,
)
fig = vvp.show_mesh(
    verts2, faces2, fig=fig, show=False,
    surface_color='green', surface_opacity=.5, wireframe_opacity=.5,
)
fig = vvp.show_mesh(
    verts3, faces3, fig=fig, show=False,
    surface_color='blue', surface_opacity=1, wireframe_opacity=1,
)

In [28]:
# Open the HTML file in Firefox on a local computer; rendering is slow over ThinLinc.
full_html_path = 'notebooks/rat37_reorient_pred_full.html'
fig.write_html(full_html_path)

### Save prediction to file

In [18]:
# Prediction-only figure: one opaque green surface, no CT slice or reference labels.
fig = vvp.show_mesh(
    verts2, faces2,
    fig=None, show=False,
    surface_color='green', surface_opacity=1, wireframe_opacity=1,
)

In [19]:
# Open the HTML file in Firefox on a local computer; rendering is slow over ThinLinc.
pred_html_path = 'notebooks/rat37_reorient_pred.html'
fig.write_html(pred_html_path)

In [31]:

def getLargestCC(segmentation, connectivity=None):
    """Return a boolean mask of the largest connected component of `segmentation`.

    Nonzero voxels are foreground. Components are found with skimage.measure.label;
    `connectivity` is forwarded to it (None keeps label's default, full connectivity).

    Raises:
        ValueError: if `segmentation` contains no foreground voxels.
    """
    labels = label(segmentation, connectivity=connectivity)
    if labels.max() == 0:
        # Explicit error rather than `assert`, which is stripped when Python runs with -O.
        raise ValueError('segmentation contains no foreground voxels')
    # Index 0 of the bincount is the background: skip it, then shift the argmax back by one.
    component_sizes = np.bincount(labels.ravel())
    largestCC = labels == np.argmax(component_sizes[1:]) + 1
    return largestCC

# Keep only the largest connected component of the full-resolution prediction, then
# downscale it. It goes into a new name so the earlier `lab_pred` is not overwritten and
# earlier cells stay re-runnable. anti_aliasing=False because the mask is boolean.
lab_pred_lcc = getLargestCC(lab_pred_fdata)
lab_pred_lcc = rescale(lab_pred_lcc, downscale_coeff, anti_aliasing=False)

# False: show only the largest-component prediction.
# True: also overlay the CT slice, the reference labels (red) and the overlap (yellow).
show_all_meshes = False

# Surface mesh of the largest-component prediction.
verts2, faces2, _, _ = skimage.measure.marching_cubes(lab_pred_lcc == 1, 0.5)

if show_all_meshes:
    # These meshes are only needed for the full overlay, so they are computed only here.
    verts1, faces1, _, _ = skimage.measure.marching_cubes(seg_true != 0, 0.5)
    verts3, faces3, _, _ = skimage.measure.marching_cubes((seg_true != 0) & (lab_pred_lcc == 1), 0.5)
    fig = vvp.volume_slicer(vol, [None, 'mid', None], show=False, title='middle ct-slice', width=1200, height=1100)
    fig = vvp.show_mesh(verts1, faces1, fig=fig, show=False, surface_color='red', wireframe_opacity=1, surface_opacity=1)
else:
    fig = None

fig = vvp.show_mesh(verts2, faces2, fig=fig, show=False, surface_color='green', wireframe_opacity=1, surface_opacity=1)

if show_all_meshes:
    fig = vvp.show_mesh(verts3, faces3, fig=fig, show=False, surface_color='yellow', wireframe_opacity=1, surface_opacity=1)

In [32]:
# Open the HTML file in Firefox on a local computer; rendering is slow over ThinLinc.
lcc_html_path = 'notebooks/rat37_reorient_pred_lcc.html'
fig.write_html(lcc_html_path)