In [None]:
# Setup
## Mount Google Drive at /content/drive. The workspace folder used by the
## cells below lives under "/content/drive/My Drive", so this cell has to run
## (and the interactive authorisation has to be accepted) before anything else.
from google.colab import drive

drive.mount('/content/drive')

In [None]:
!pip install SimpleITK

In [None]:
!pip install nnunetv2

In [None]:
# Root folder on the mounted Drive that holds the trained models, the test set
# and the results written by this notebook.
folder_path = "/content/drive/My Drive/Inference models balanced dataset"

In [None]:
# Import libraries
import torch
import networks
import dataloader
import numpy as np
import SimpleITK as sitk
from typing import Tuple
import pandas
import scipy.ndimage as snd
from pathlib import Path
from tqdm import tqdm

from nnunetv2.paths import nnUNet_results, nnUNet_raw
from batchgenerators.utilities.file_and_folder_operations import join
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO

# Helper functions
## Function to keep the central connected component in a patch
def keep_central_connected_component(
    prediction: sitk.Image,
    patch_size: Tuple = (128, 128, 64),
) -> sitk.Image:
    """Keep only the connected component whose centroid is closest to the patch center.

    Args:
        prediction (sitk.Image): segmentation mask. Every non-zero voxel is
            treated as foreground when the components are labelled.
        patch_size (Tuple, optional): patch size in (x, y, z) order; half of it
            is used as the patch center. Defaults to (128, 128, 64).

    Returns:
        sitk.Image: binary uint8 mask (values 0/1) containing only the most
            central connected component, carrying the origin, spacing and
            direction of the input. If the input has no foreground at all,
            an all-zero uint8 mask is returned.
    """

    origin = prediction.GetOrigin()
    spacing = prediction.GetSpacing()
    direction = prediction.GetDirection()

    mask = sitk.GetArrayFromImage(prediction)

    labelled, n_components = snd.label(mask)

    if n_components > 0:
        # The array is indexed (z, y, x) while patch_size is given as (x, y, z).
        center = np.array(list(reversed(patch_size))) // 2

        # Unweighted centroid of every component in a single pass over the
        # volume (instead of one full-volume scan per component). The values
        # are truncated to integer voxel indices before measuring distances.
        centroids = np.array(
            snd.center_of_mass(
                np.ones_like(labelled),
                labels=labelled,
                index=np.arange(1, n_components + 1),
            )
        ).astype(int)

        dists = np.sqrt(((centroids - center) ** 2).sum(axis=1))

        # Component labels start at 1, argmin positions at 0.
        keep_label = int(np.argmin(dists)) + 1
        mask = labelled == keep_label

    # Same pixel type on every path, including the "no foreground" one.
    result = sitk.GetImageFromArray(mask.astype(np.uint8))
    result.SetSpacing(spacing)
    result.SetOrigin(origin)
    result.SetDirection(direction)
    return result

## Function to perform inference on the test set
def perform_inference_on_test_set(workspace: Path):
    """Run segmentation, nodule-type and malignancy inference on the test set.

    Every `*.mha` file in `<workspace>/data/test_set/images` is processed:

    * the nnU-Net predictor segments the full image,
    * a 64x64x64 patch (50 mm wide) around the image center is extracted,
      clipped/scaled with `dataloader.clip_and_scale` and fed to the
      malignancy and nodule-type CNNs.

    Outputs are written below `<workspace>/results/test_set_predictions`:

    * `segmentations/<noduleid>.mha`: binary mask reduced to its central
      connected component,
    * `predictions.csv`: one row per image with the malignancy output, the
      arg-max nodule type and the four per-class nodule-type outputs.

    A CUDA device is required: the models and the input tensors are moved to
    the GPU with `.cuda()`.

    Args:
        workspace (Path): root folder containing the nnU-Net results folder,
            the two `best_model.pth` checkpoints and `data/test_set/images`.
            A plain string is accepted as well.

    Raises:
        ValueError: if `workspace` is None.
    """
    # check that the workspace is well-defined
    if workspace is None:
        raise ValueError("workspace no puede ser None")
    # Every path below is built with the `/` operator, which a str does not support.
    workspace = Path(workspace)

    # load the model architectures for both malignancy and noduletype
    malignancy_model = networks.CNN3D(1, 1, task="malignancy").cuda()
    noduletype_model = networks.CNN3D(1, 4, task="noduletype").cuda()

    # set the malignancy and noduletype models in evaluation mode
    malignancy_model.eval()
    noduletype_model.eval()

    # instantiate the nnUNetPredictor
    segmentation_model = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=True,
        perform_everything_on_device=True,
        device=torch.device('cuda', 0),
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=True
    )
    # initializes the network architecture, loads the model parameters
    segmentation_model.initialize_from_trained_model_folder(
        (workspace / 'nnUNet/nnUNet_results/Dataset001_LUNA/nnUNetTrainer__nnUNetResEncUNetMPlans__3d_fullres'),
        use_folds=("all",),
        checkpoint_name='checkpoint_best.pth',
    )

    # load the parameters for malignancy and noduletype models
    # NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
    ckpt = torch.load(workspace / "20240526_0_malignancy_ORIGINAL/fold0/best_model.pth")
    malignancy_model.load_state_dict(ckpt)

    ckpt = torch.load(workspace / "20240526_0_noduletype_ORIGINAL/fold0/best_model.pth")
    noduletype_model.load_state_dict(ckpt)

    # paths to the test set and to the results folder that stores the predictions
    test_set_path = Path(workspace / "data" / "test_set" / "images")
    save_path = workspace / "results" / "test_set_predictions"

    segmentation_save_path = save_path / "segmentations"
    segmentation_save_path.mkdir(exist_ok=True, parents=True)

    # (z, y, x) size of the test-set images, and geometry of the CNN input patch
    patch_size = np.array([64, 128, 128])
    size_mm = 50
    size_px = 64

    predictions = []

    # reader used to feed nnU-Net; created once and reused for every image
    image_reader = SimpleITKIO()

    # glob order depends on the filesystem; sorting keeps the CSV row order reproducible
    image_paths = sorted(test_set_path.glob("*.mha"))

    # iterate over the images in the test set
    for image_path in tqdm(image_paths):

        # load and pre-process input image
        image = sitk.ReadImage(str(image_path))

        # load the image and its properties in the form expected by nnU-Net
        image_nnunet, props = image_reader.read_images([str(image_path)])

        noduleid = image_path.stem

        # SimpleITK reports metadata in (x, y, z) order; flip to the (z, y, x) order of the array
        metad = {
            "origin": np.flip(image.GetOrigin()),
            "spacing": np.flip(image.GetSpacing()),
            "transform": np.array(np.flip(image.GetDirection())).reshape(3, 3),
            "shape": np.flip(image.GetSize()),
        }
        image = sitk.GetArrayFromImage(image)

        image = image.squeeze()

        # patch centered on the middle voxel of a `patch_size` image
        image = dataloader.extract_patch(
            CTData=image,
            coord=tuple(patch_size // 2),
            srcVoxelOrigin=(0, 0, 0),
            srcWorldMatrix=metad["transform"],
            srcVoxelSpacing=metad["spacing"],
            output_shape=(size_px, size_px, size_px),
            voxel_spacing=(
                size_mm / size_px,
                size_mm / size_px,
                size_mm / size_px,
            ),
            coord_space_world=False,
        )

        # add batch and channel axes so the patch is compatible with dataloader.clip_and_scale
        image = image.reshape(1, 1, size_px, size_px, size_px).astype(np.float32)

        image = dataloader.clip_and_scale(image)

        # move the patch to the GPU for the malignancy and noduletype models
        image_tensor = torch.from_numpy(image).cuda()

        # get the outputs of the models
        with torch.no_grad():
            outputs = {
                "segmentation": segmentation_model.predict_single_npy_array(
                    image_nnunet, props, None, None, False
                ),
                "noduletype": noduletype_model(image_tensor)["noduletype"],
                "malignancy": malignancy_model(image_tensor)["malignancy"],
            }
        print("Outputs computed by the models")

        # nnU-Net returns an array-like, the CNNs return tensors on the GPU
        outputs = {
            k: (
                np.array(outputs[k]).squeeze()
                if k == "segmentation"
                else outputs[k].data.cpu().numpy().squeeze()
            )
            for k in outputs.keys()
        }

        # bring the mask onto the original image grid and binarise it
        segmentation = _segmentation_to_original_grid(
            outputs["segmentation"],
            original_spacing=metad["spacing"],
            original_shape=metad["shape"],
            patch_size=patch_size,
            resampled_spacing=size_mm / size_px,
        )

        # set metadata (flipped back to SimpleITK's (x, y, z) order)
        segmentation = sitk.GetImageFromArray(segmentation)
        segmentation.SetOrigin(np.flip(metad["origin"]))
        segmentation.SetSpacing(np.flip(metad["spacing"]))
        segmentation.SetDirection(np.flip(metad["transform"].reshape(-1)))

        # keep central connected component
        segmentation = keep_central_connected_component(segmentation)
        print(segmentation.GetSize())

        # write as simpleitk image (True = use compression)
        sitk.WriteImage(
            segmentation,
            str(segmentation_save_path / f"{noduleid}.mha"),
            True,
        )

        # combine predictions from other task models
        # NOTE(review): the "*_probability" columns hold the raw noduletype
        # outputs; whether the network already applies a softmax is not visible
        # here - confirm in networks.CNN3D.
        prediction = {
            "noduleid": noduleid,
            "malignancy": outputs["malignancy"],
            "noduletype": outputs["noduletype"].argmax(),
            "ggo_probability": outputs["noduletype"][0],
            "partsolid_probability": outputs["noduletype"][1],
            "solid_probability": outputs["noduletype"][2],
            "calcified_probability": outputs["noduletype"][3],
        }

        predictions.append(pandas.Series(prediction))

    predictions = pandas.DataFrame(predictions)
    predictions.to_csv(save_path / "predictions.csv", index=False)


def _segmentation_to_original_grid(
    segmentation: np.ndarray,
    original_spacing: np.ndarray,
    original_shape: np.ndarray,
    patch_size: np.ndarray,
    resampled_spacing: float,
) -> np.ndarray:
    """Return `segmentation` as a binary uint8 mask on the original image grid.

    Args:
        segmentation: predicted mask, indexed (z, y, x).
        original_spacing: voxel spacing of the original image, (z, y, x).
        original_shape: shape of the original image, (z, y, x).
        patch_size: (z, y, x) size used by the crop step of the resampling path.
        resampled_spacing: isotropic spacing assumed for a prediction that is
            NOT on the original grid.

    Returns:
        np.ndarray: uint8 array with value 1 where the (resampled) prediction
            exceeds 0.5 and 0 elsewhere.
    """
    # nnUNetPredictor.predict_single_npy_array receives the image properties and
    # returns the segmentation on the grid of the input image. Zooming such a
    # mask by `resampled_spacing / original_spacing` and then padding/cropping it
    # back to the same shape would rescale it, so it is only thresholded.
    if segmentation.shape == tuple(int(s) for s in original_shape):
        return (segmentation > 0.5).astype(np.uint8)

    # Prediction lives on the isotropic patch grid: resample to original spacing.
    segmentation = snd.zoom(
        segmentation,
        resampled_spacing / original_spacing,
        order=1,
    )

    # pad towards the original shape
    diff = original_shape - segmentation.shape
    pad_widths = [
        (np.round(before), np.round(after))
        for before, after in zip(
            diff // 2.0 + 1,
            diff - diff // 2.0 - 1,
        )
    ]

    # For an axis that already matches, the formula above yields (1, -1), which
    # would grow the axis by one voxel after clipping - force it to no padding.
    for axis in range(3):
        if diff[axis] == 0:
            pad_widths[axis] = (0.0, 0.0)

    # Negative widths belong to axes that are too large; those are cropped below.
    pad_widths = np.clip(np.array(pad_widths).astype(int), 0, None)
    segmentation = np.pad(
        segmentation,
        pad_width=pad_widths,
        mode="constant",
        constant_values=0,
    )

    # crop, if necessary
    # NOTE(review): the crop targets `patch_size`, not `original_shape`; the two
    # only agree when the input image has exactly the patch size.
    if diff.min() < 0:
        center = np.array(segmentation.shape) // 2

        segmentation = segmentation[
            center[0] - patch_size[0] // 2 : center[0] + patch_size[0] // 2,
            center[1] - patch_size[1] // 2 : center[1] + patch_size[1] // 2,
            center[2] - patch_size[2] // 2 : center[2] + patch_size[2] // 2,
        ]

    # apply threshold
    return (segmentation > 0.5).astype(np.uint8)

In [None]:
# Perform inference on the test set
if __name__ == "__main__":

    # Reuse the folder configured at the top of the notebook instead of keeping
    # a second hard-coded copy of the same path that could drift out of sync.
    workspace = Path(folder_path)

    perform_inference_on_test_set(workspace=workspace)