# MONAI whole-body CT segmentation test.

### Imports.

In [None]:
import os
import torch
import matplotlib.pyplot as plt
from monai.bundle import ConfigParser, download
from monai.transforms import LoadImage, LoadImaged, Orientation, Orientationd, EnsureChannelFirst, EnsureChannelFirstd, Compose
from tcia_utils import nbia

### Constants.

In [None]:
# Names and locations shared by the cells below.
model_name = "wholeBody_ct_segmentation"
download_model = True

data_dir = "./data"
# Folder holding the CT series; presumably the folder the TCIA download step
# writes this series UID to — confirm against the download output.
dicom_dir = os.path.join(
    data_dir,
    "1.3.6.1.4.1.14519.5.2.1.3320.3273.193828570195012288011029757668",
)

# Files inside the downloaded MONAI bundle directory.
bundle_root = os.path.join(data_dir, model_name)
model_path = os.path.join(bundle_root, "models", "model_lowres.pt")
config_path = os.path.join(bundle_root, "configs", "inference.json")

# Index along the second spatial axis used for the coronal previews.
slice_index = 256

### Download the CT data.

In [None]:
# Resolve the shared TCIA (NBIA) cart, then download the series it lists into
# data_dir. format="df" requests a DataFrame-style summary of the download.
cart_name = "nbia-56561691129779503"
cart_data = nbia.getSharedCart(cart_name)
df = nbia.downloadSeries(
    cart_data,
    path=data_dir,
    format="df",
)

### Load and inspect the CT volume.

In [None]:
# Get the volume with MONAI's array (non-dictionary) transforms: read the
# DICOM series, move the channel axis to the front, and reorient the voxel
# axes to LPS.
preprocessing_pipeline = Compose([
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    Orientation(axcodes="LPS"),
])
volume = preprocessing_pipeline(dicom_dir)

# Display one coronal slice: channel 0, fixed index along the second (P)
# spatial axis. Named `preview_slice` so the builtin `slice` is not shadowed
# for the rest of the notebook session.
preview_slice = volume[0, :, slice_index].cpu().numpy()
plt.figure(figsize=(3, 8))
plt.pcolormesh(preview_slice.T, cmap="Greys_r")
plt.colorbar(label="HU")
plt.axis("off")
plt.show()

# Repeat the same loading steps with the dictionary-based transforms, which
# read and replace the entry stored under the "image" key.
image_key = "image"
preprocessing_pipeline = Compose(
    [
        LoadImaged(keys=image_key, image_only=True),
        EnsureChannelFirstd(keys=image_key),
        Orientationd(keys=image_key, axcodes="LPS"),
    ]
)
data = preprocessing_pipeline({image_key: dicom_dir})
print(data)

### Load the bundle config.

In [None]:
# Parse the bundle's inference config. config_path lives inside the bundle
# directory, and the bundle is otherwise only downloaded later in the notebook
# (when the model is loaded). On a fresh run the file would not exist yet, so
# fetch the bundle here first when downloading is enabled and it is missing.
if download_model and not os.path.exists(config_path):
    download(name=model_name, bundle_dir=data_dir)

config = ConfigParser()
config.read_config(config_path)

### Preprocess the data with the bundle's pipeline.

In [None]:
# Build the bundle's own preprocessing transform chain from the config and
# apply it to the DICOM series directory.
preprocessing = config.get_parsed_content("preprocessing")
bundle_input = {"image": dicom_dir}
data = preprocessing(bundle_input)
print(data)

### Load the model.

In [None]:
# Fetch the bundle when requested and the weights are not already on disk.
if download_model and not os.path.exists(model_path):
    download(name=model_name, bundle_dir=data_dir)

# Build the network described by the config and load the trained weights.
model = config.get_parsed_content("network")
# map_location="cpu" lets a checkpoint saved from a GPU load on CPU-only
# machines; load_state_dict then copies the weights onto whatever device the
# config placed the network on.
# NOTE(review): torch.load unpickles the file, and it comes from a remote
# download; consider weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()

### Run segmentation.

In [None]:
# Run the bundle's inferer on the preprocessed image, then its postprocessing.
inferer = config.get_parsed_content("inferer")

# Add a leading batch axis for the single volume; index it back off the
# prediction straight away.
with torch.no_grad():
    data["pred"] = inferer(data["image"].unsqueeze(0), network=model)[0]

# NOTE(review): data["image"] itself was never batched (the unsqueeze above is
# not stored back), so this [0] drops its first (channel) axis — confirm the
# postprocessing chain really expects that.
data["image"] = data["image"][0]

postprocessing = config.get_parsed_content("postprocessing")
data = postprocessing(data)

# Take channel 0 of the processed prediction, flip its last spatial axis, and
# move it to NumPy for plotting.
# NOTE(review): the flip looks like an empirical alignment with the
# LPS-reoriented `volume`; confirm against the image metadata.
segmentation = torch.flip(data["pred"][0], dims=[2]).cpu().numpy()

### Show results.

In [None]:
# Compare the CT coronal slice with the same slice of the segmentation,
# side by side.
coronal_slice = volume[0, :, slice_index].cpu().numpy()
segmentation_coronal_slice = segmentation[:, slice_index]

fig, (ct_ax, seg_ax) = plt.subplots(1, 2, figsize=(6, 8))
for ax, panel, cmap in (
    (ct_ax, coronal_slice, "Greys_r"),
    (seg_ax, segmentation_coronal_slice, "nipy_spectral"),
):
    ax.pcolormesh(panel.T, cmap=cmap)
    ax.axis("off")
plt.show()