In [None]:
import torch
import pandas as pd
import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from glob import glob
import nibabel as nib
import pickle
from tqdm import tqdm
import random
from datetime import datetime
from scipy import ndimage
import matplotlib.pyplot as plt

import torch
import torchio as tio

import random
import csv 
import os
import fnmatch
import SimpleITK as sitk

from utils.helpers import *

import torch.nn.functional as F
from torchsummary import summary

import mlflow
import mlflow.pytorch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


In [None]:
# Timestamp the start of the run so logs from different executions can be told apart.
now = datetime.now()
print(f"Started running the code at: {now}")
print("############################################################################\n\n")

In [None]:
from utils.model import *
from utils.helpers import *
from utils.postprocess import *
from utils.evaluation import *

In [None]:
def img_for_prediction(original_img):
    """Turn a SimpleITK image into the float tensor fed to the network.

    The image is cast to float64, reoriented to the 'RPS' DICOM orientation and
    passed through the project helper ``intensity_normalize`` (star-imported
    from ``utils``). Three leading singleton axes are then added to the voxel
    array, so (assuming a 3-D input) a (D, H, W) array becomes a
    (1, 1, 1, D, H, W) ``torch.float32`` tensor. ``predict`` selects ``[:, 0]``
    from it to obtain the 5-D model input.

    Cropping to a fixed shape is currently disabled; the whole volume is used.

    Args:
        original_img: SimpleITK image, e.g. as returned by ``sitk.ReadImage``.

    Returns:
        Contiguous float32 tensor with three leading singleton axes.
    """
    oriented = sitk.DICOMOrient(sitk.Cast(original_img, sitk.sitkFloat64), 'RPS')
    normalized = intensity_normalize(oriented)
    voxels = sitk.GetArrayFromImage(normalized)

    # Three new leading axes in one step (equivalent to expand_dims(axis=0) x3).
    voxels = voxels[np.newaxis, np.newaxis, np.newaxis, ...]

    # from_numpy shares memory with the contiguous array; .float() yields float32.
    return torch.from_numpy(np.ascontiguousarray(voxels)).float()
    

In [None]:
def target_preprocess(target, num_classes=3):
    """One-hot encode a label image in the orientation used for prediction.

    The label image is cast to float64 and reoriented to the 'RPS' DICOM
    orientation, the same orientation ``img_for_prediction`` applies to the
    input image. Its voxel array gets a leading batch axis and is converted to
    integer labels. It is then one-hot encoded with the class axis moved to
    position 1.

    Args:
        target: SimpleITK label image with integer-valued labels in
            ``[0, num_classes)``. Presumably 3-D, since the permute below
            assumes (1, D, H, W) labels.
        num_classes: number of classes to encode (default 3: background plus
            labels 1 and 2, matching the model's ``out_channels=3``).

    Returns:
        ``torch.int64`` tensor of shape (1, num_classes, D, H, W).

    Raises:
        RuntimeError: from ``F.one_hot`` if a label value is negative or
            ``>= num_classes``.
    """
    oriented = sitk.DICOMOrient(sitk.Cast(target, sitk.sitkFloat64), 'RPS')

    # Add a batch axis: (D, H, W) -> (1, D, H, W).
    label_array = np.expand_dims(sitk.GetArrayFromImage(oriented), axis=0)
    label = torch.from_numpy(label_array).long()

    # one_hot appends the class axis last: (1, D, H, W, C) -> (1, C, D, H, W).
    label = F.one_hot(label, num_classes=num_classes)
    return label.permute(0, 4, 1, 2, 3).contiguous()
    

In [None]:
fold = 0
# NOTE(review): in the visible cells crop_shape is only referenced by commented-out cropping calls.
crop_shape = (96, 96, 96)

# the model is the longitudinal one
best_model_path_dir = "exp_OURS"
best_model_path = f"{best_model_path_dir}/models/fold-{fold}/model_fold_{fold}_best_model.pth"

# Build the network and load this fold's best checkpoint.
model = nnUNet3D_Dropout(in_channels=1, out_channels=3)

# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only machine;
# the model is moved to `device` further down.
state_dict = torch.load(best_model_path, map_location='cpu')

# Checkpoints saved from an nn.DataParallel wrapper prefix every key with 'module.'.
# Strip it key by key, and only where it really is a prefix.
dp_prefix = 'module.'
new_state_dict = {
    (key[len(dp_prefix):] if key.startswith(dp_prefix) else key): value
    for key, value in state_dict.items()
}
model.load_state_dict(new_state_dict)

# Wrap the model with DataParallel to use multiple GPUs
if torch.cuda.device_count() > 1:
    print("multiple GPUs identified")
    model = nn.DataParallel(model)

# Use `device` rather than .cuda() so the notebook also runs when no GPU is present.
model = model.to(device)
model.eval()


## prediction done on the longitudinal dataset
test_list_file = "PATH_TO_DATASET_DIR/test_fold_" + str(fold) + ".txt"
print(test_list_file)
# One subject ID per line. Blank lines are skipped: an empty ID would make the
# startswith() filter below match every file in the image folder.
with open(test_list_file, "r") as f:
    test_subject_list = [line.strip() for line in f if line.strip()]

print("Number of subjects", len(test_subject_list))

nifti_img_path = "PATH_TO_DATASET_DIR/Dataset901_FSCropReg/imagesTr/"
nifti_label_path = "PATH_TO_DATASET_DIR/Dataset901_FSCropReg/labelsTr/"

# Trailing "/" is required: file names are appended to this path by plain concatenation.
nifti_pred_path = best_model_path_dir + "/PREDS/fold-" + str(fold) + "/"

# Keep every image whose file name starts with one of the subject IDs.
# NOTE(review): plain prefix matching also accepts longer IDs sharing a prefix
# (e.g. "sub-1" matches "sub-10_..."); append the file-name delimiter to each ID if that can occur.
subject_prefixes = tuple(test_subject_list)
match_files = sorted(
    fname for fname in os.listdir(nifti_img_path)
    if fname.startswith(subject_prefixes)
)

print("Total matched files", len(match_files))

# Full paths of the images to run inference on.
all_files = [os.path.join(nifti_img_path, fname) for fname in match_files]
print("Total files renamed", len(all_files))



In [None]:

# Separator between the setup logs and the per-subject prediction logs.
print("############################################################################")


@torch.no_grad()
def predict(model):
    """Segment every image in ``all_files``, save the predictions and report Dice.

    For each image path in the notebook-level ``all_files`` list:
      1. load the image, preprocess it with ``img_for_prediction`` and run
         ``model`` on it. The model returns a pair; only the first element, the
         per-class scores, is used. The label map is the argmax over the class
         axis.
      2. map the label map back to the original array shape with
         ``reverse_cropping_function``. Write it to ``nifti_pred_path`` using
         the original image's spacing, origin and direction.
      3. load the matching label from ``nifti_label_path``, one-hot encode it
         with ``target_preprocess`` and print the Dice score for labels 1 and 2.

    Finally, print the number of processed images and the unweighted mean Dice
    over images for each label.

    Args:
        model: segmentation network, already on ``device`` and in eval mode.
    """
    # sitk.WriteImage does not create missing parent directories.
    os.makedirs(nifti_pred_path, exist_ok=True)

    overall_dice_1 = 0.0
    overall_dice_2 = 0.0
    total_test = 0

    for img_path in all_files:

        # Prediction/label files share the image name minus the "_0000" channel suffix.
        just_name = os.path.basename(img_path).replace("_0000.nii.gz", ".nii.gz")
        print(just_name)

        # --- inference ---------------------------------------------------------
        original_img = sitk.ReadImage(img_path)
        img_torch = img_for_prediction(original_img)
        # img_for_prediction returns a 6-D tensor; [:, 0] gives the 5-D model input.
        logits, _ = model(img_torch[:, 0].to(device))
        # softmax is monotonic, so argmax over the raw scores gives the same labels.
        predicted = torch.argmax(logits, dim=1)

        # --- ground truth ------------------------------------------------------
        target = sitk.ReadImage(nifti_label_path + just_name)
        target_one_hot = target_preprocess(target)

        # --- back to the original array shape ----------------------------------
        predicted_full_res = reverse_cropping_function(predicted, sitk.GetArrayFromImage(original_img).shape)
        print("shape and unique in pred full res: ", predicted_full_res.cpu().numpy().shape, np.unique(predicted_full_res.cpu().numpy()))

        # One-hot encode (N, D, H, W) -> (N, 3, D, H, W) for the Dice computation.
        predicted_one_hot = F.one_hot(predicted_full_res.long(), num_classes=3)
        predicted_one_hot = predicted_one_hot.permute(0, 4, 1, 2, 3).contiguous()
        print("Prediction full res shape", predicted_full_res.shape)

        # --- save --------------------------------------------------------------
        # Index away only the batch axis; squeeze() would also drop a spatial axis
        # of size 1 and write an image of the wrong dimensionality.
        predicted_full_res_arr = predicted_full_res[0].cpu().numpy()
        # NOTE(review): the network ran on the image reoriented to 'RPS', but the output
        # is written with the original image's geometry. Flipping numpy axis 2 (the
        # image x axis, since SimpleITK arrays are ordered z, y, x) undoes that only if
        # the original orientation differs from RPS by a left/right flip alone (e.g.
        # LPS). Confirm for this dataset, or reorient the output with sitk.DICOMOrient.
        predicted_full_res_arr = np.flip(predicted_full_res_arr, axis=2)

        out_image = sitk.GetImageFromArray(predicted_full_res_arr)
        out_image = sitk.Cast(out_image, sitk.sitkFloat32)
        out_image.SetSpacing(original_img.GetSpacing())
        out_image.SetOrigin(original_img.GetOrigin())
        out_image.SetDirection(original_img.GetDirection())
        sitk.WriteImage(out_image, nifti_pred_path + just_name)

        # --- Dice (both volumes are in the 'RPS' array orientation here) -------
        target_one_hot = target_one_hot.to(device)
        predicted_one_hot = predicted_one_hot.to(device)

        dice1 = dice_score(predicted_one_hot[0, 1].float(), target_one_hot[0, 1].float())
        dice2 = dice_score(predicted_one_hot[0, 2].float(), target_one_hot[0, 2].float())

        print("Dice score for label=1 :", dice1.item()*100, "%")
        print("Dice score for label=2 :", dice2.item()*100, "%")
        overall_dice_1 += dice1.item()
        overall_dice_2 += dice2.item()
        total_test += 1

        print("\n")

    print("Total sample in the folder for fold {0} is {1}".format(fold, total_test))
    # Guard against an empty file list (no matching subjects).
    if total_test > 0:
        print("Overall Dice for Label 1 -", (overall_dice_1 / total_test) * 100)
        print("Overall Dice for Label 2 -", (overall_dice_2 / total_test) * 100)


# Run inference over every file in all_files with the loaded model.
predict(model)

print("Prediction DONE")

