In [1]:
import os
import numpy as np
from numpy.typing import NDArray
import torch
import pydicom
import matplotlib.pyplot as plt
from tcia_utils import nbia
from monai.bundle import ConfigParser, download
from monai.transforms import LoadImage, LoadImaged, Orientation, Orientationd, EnsureChannelFirst, EnsureChannelFirstd, Compose
from rt_utils import RTStructBuilder
from scipy.ndimage import label, measurements
import json

2024-05-11 11:26:14.163479: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
# Root folder holding the CT DICOM series and the downloaded MONAI bundle.
# Set the CT_SEG_DATADIR environment variable to point elsewhere instead of editing
# this machine-specific default.
datadir = os.environ.get('CT_SEG_DATADIR', '/Users/williamlee/Documents/Git Repos/CT_WholeBody_Segmentation/Data/')

# Part 1: Open CT Image

In [5]:
# Folder (named by a DICOM UID) containing the CT series to segment
CT_folder = os.path.join(datadir, '1.3.6.1.4.1.14519.5.2.1.3320.3273.193828570195012288011029757668')

# Load the CT as a channel-first volume for display in Part 6. The slicing cell there
# indexes `CT`, which was not defined anywhere else in the notebook, so Restart & Run All
# failed with a NameError.
# NOTE(review): axcodes='LPS' is an assumption about the orientation the later plots were
# built against (the segmentation is flipped along its last axis before plotting) —
# verify that the CT and segmentation slices line up.
CT = Compose([
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    Orientation(axcodes='LPS'),
])(CT_folder)

# Part 2: Download the Model

In [4]:
# Download the MONAI model-zoo bundle into `datadir`, but only when it is not already
# there. Re-runs then need no network access. A remote failure (this cell previously
# raised a 404 HTTPError from the default 'monaihosting' source) also cannot abort
# Run All once the bundle is on disk. If the download itself fails, try another `source=`.
model_name = "wholeBody_ct_segmentation"
if not os.path.isdir(os.path.join(datadir, model_name)):
    download(name=model_name, bundle_dir=datadir)

2024-05-11 09:35:46,024 - INFO - --- input summary of monai.bundle.scripts.download ---
2024-05-11 09:35:46,026 - INFO - > name: 'wholeBody_ct_segmentation'
2024-05-11 09:35:46,027 - INFO - > bundle_dir: '/Users/williamlee/Documents/Git Repos/Segmentation/Tutorial/Data/'
2024-05-11 09:35:46,028 - INFO - > source: 'monaihosting'
2024-05-11 09:35:46,031 - INFO - > remove_prefix: 'monai_'
2024-05-11 09:35:46,033 - INFO - > progress: True
2024-05-11 09:35:46,036 - INFO - ---




HTTPError: 404 Client Error: Not Found for url: https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting/wholeBody_ct_segmentation

In [6]:
# Everything the notebook needs lives inside the extracted bundle directory:
# the low-resolution pretrained weights and the inference config.
bundle_root = os.path.join(datadir, 'wholeBody_ct_segmentation')
model_path = os.path.join(bundle_root, 'models', 'model_lowres.pt')
config_path = os.path.join(bundle_root, 'configs', 'inference.json')

# Part 3: Pipeline Set-Up

## 3.1: Config Instance

In [7]:
# Parse the bundle's inference.json. Its named sections ("preprocessing", "inferer",
# "postprocessing", "network") are instantiated by key in the cells below.
config = ConfigParser()
config.read_config(config_path)

## 3.2: Preprocessing

From the parsed config we can extract the preprocessing pipeline specified in the `inference.json` file.
* These are all the operations applied to the data before it is fed to the model.

In [8]:
# Instantiate the transform pipeline declared under the "preprocessing" key
preprocessing = config.get_parsed_content("preprocessing")

In [9]:
# Run preprocessing on the CT folder; the result is a dict keyed by 'image'
data = preprocessing(dict(image=CT_folder))

## 3.3: Inferer

In [10]:
# Inferer from the config: given an input tensor and a network, it produces the model output
inferer = config.get_parsed_content("inferer")

## 3.4: Postprocessor

In [11]:
# Transforms applied to the network output, as declared under "postprocessing" in inference.json
postprocessing = config.get_parsed_content("postprocessing")

# Sanity check: the preprocessed image's shape once a leading batch axis is added
data['image'].unsqueeze(0).shape

torch.Size([1, 1, 167, 167, 650])

# Part 5: Load the Model

In [12]:
# Build the (not yet trained-weight-loaded) network declared under the "network" key
model = config.get_parsed_content("network")

In [13]:
# Restore the pretrained parameters into the network, mapping every tensor onto the CPU.
print(model_path)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code if the
# checkpoint is untrusted. The result is fed straight into load_state_dict, so it should be
# a plain state dict; consider passing weights_only=True on PyTorch versions that support it.
pretrained_weights = torch.load(model_path, map_location='cpu')

model.load_state_dict(pretrained_weights)

/Users/williamlee/Documents/Git Repos/Segmentation/Tutorial/Data/wholeBody_ct_segmentation/models/model_lowres.pt


<All keys matched successfully>

# Part 6: Prediction

In [None]:
# 1. Re-run preprocessing so this cell does not depend on the `data` dict left over
#    (and mutated) by earlier cells.
data = preprocessing({'image': CT_folder}) # returns a dictionary

# Nothing in this notebook switches the network out of training mode (torch modules are
# constructed in training mode). eval() makes any dropout / normalization layers with
# train-time behaviour act deterministically for inference.
model.eval()

# 2. Compute the mask prediction and add it to the dictionary
with torch.no_grad():
    # The inferer expects a batch axis, so add one to the image passed in
    data['pred'] = inferer(data['image'].unsqueeze(0), network=model, device='cpu')

# Drop the batch axis that the inferer output carries
data['pred'] = data['pred'][0]
# NOTE(review): only the inferer's *input* was unsqueezed; data['image'] itself never got a
# batch axis, so this drops its leading size-1 (channel) axis. Kept unchanged — confirm
# that postprocessing expects the image without that axis.
data['image'] = data['image'][0]

# 3. Apply postprocessing to data
data = postprocessing(data)

# Take the first entry along the leading (presumably channel) axis and flip the last axis.
# NOTE(review): the flip presumably aligns the labels with how the CT is oriented for
# display — confirm against the CT loading step.
segmentation = torch.flip(data['pred'][0], dims=[2])
segmentation = segmentation.cpu().numpy()

In [None]:
# Take the same slice (index along the second spatial axis) from the CT and the label volume.
slice_idx = 250
# CT is assumed channel-first (1, X, Y, Z); index 0 selects its single channel
CT_coronal_slice = CT[0,:,slice_idx].cpu().numpy()
segmentation_coronal_slice = segmentation[:,slice_idx]

# Guard against silently plotting two volumes that live on different grids
assert CT_coronal_slice.shape == segmentation_coronal_slice.shape, (
    f"CT slice {CT_coronal_slice.shape} and segmentation slice "
    f"{segmentation_coronal_slice.shape} differ; check orientation/resampling"
)

In [None]:
# Show the CT slice and its segmentation side by side. The transpose puts the slice's
# second index on the vertical axis of each panel.
fig, (ax_ct, ax_seg) = plt.subplots(1, 2, figsize=(6, 8))

ax_ct.pcolormesh(CT_coronal_slice.T, cmap='Greys_r')
ax_ct.set_title('CT')
ax_ct.axis('off')

ax_seg.pcolormesh(segmentation_coronal_slice.T, cmap='nipy_spectral')
ax_seg.set_title('Segmentation')
ax_seg.axis('off')

plt.show()