In [1]:
import os
import cv2
import numpy as np
from PIL import Image
from scipy import ndimage, stats
from skimage.transform import rotate

from sklearn.metrics import mean_squared_error
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics import normalized_mutual_info_score

import nibabel as nib
from tqdm import tqdm
import glob
import matplotlib.pyplot as plt
import csv
import shutil
import argparse
import sys
sys.path.append("/datassd/PSMA_SEGMENTATION/Diffusion-based-Segmentation")
import random



%matplotlib inline

In [None]:
def read_nii_vol(target_filename):
    """Load a NIfTI image from disk and return its voxel data as a NumPy array.

    Parameters
    ----------
    target_filename : str
        Path to a ``.nii`` / ``.nii.gz`` file readable by ``nibabel.load``.

    Returns
    -------
    numpy.ndarray
        In-memory array built from the image's ``dataobj``.
    """
    return np.array(nib.load(target_filename).dataobj)

def visualize(img):
    """Min-max scale ``img`` into [0, 1] for display.

    The image's own minimum and maximum are used as the bounds.

    Parameters
    ----------
    img : numpy.ndarray
        Array of any shape.

    Returns
    -------
    numpy.ndarray
        Float array of the same shape with values in [0, 1]. A constant image
        (max == min, e.g. an empty segmentation mask) is returned as all zeros
        instead of the all-NaN result that 0/0 would produce.
    """
    _min = img.min()
    _max = img.max()
    value_range = _max - _min
    if value_range == 0:
        # Constant input: the scale is undefined; map everything to 0.
        return np.zeros(img.shape, dtype=float)
    return (img - _min) / value_range

def dice_score(pred, targs):
    """Dice coefficient between a prediction and a target mask.

    Parameters
    ----------
    pred : numpy.ndarray or torch.Tensor
        Prediction; every value > 0 is treated as foreground.
    targs : numpy.ndarray or torch.Tensor
        Target mask, broadcast-compatible with ``pred``. Assumed to be 0/1
        (not validated); non-binary targets give a meaningless score.

    Returns
    -------
    scalar
        ``2 * |pred ∩ targs| / (|pred| + |targs|)``. The type follows the inputs:
        a 0-d tensor for torch, a NumPy scalar for NumPy. Returns ``1.0`` when
        both masks are empty, where the formula would be 0/0.
    """
    foreground = pred > 0
    if hasattr(foreground, "float"):
        # torch.Tensor: keep the original tensor code path.
        pred_bin = foreground.float()
    else:
        # NumPy bool arrays have no `.float()`; cast explicitly.
        pred_bin = np.asarray(foreground, dtype=np.float32)
    denominator = (pred_bin + targs).sum()
    if denominator == 0:
        # Both masks empty: treat as perfect agreement rather than NaN.
        return 1.0
    return 2. * (pred_bin * targs).sum() / denominator

def normalize(img, _min, _max):
    """Linearly rescale ``img`` so that ``_min`` maps to 0 and ``_max`` maps to 1.

    Parameters
    ----------
    img : numpy.ndarray
        Input array.
    _min, _max : scalar or numpy.ndarray
        Lower/upper reference values. Arrays are broadcast against ``img``,
        e.g. for per-channel bounds.

    Returns
    -------
    numpy.ndarray
        ``(img - _min) / (_max - _min)``. Values outside [_min, _max] are NOT
        clipped. Wherever ``_max == _min`` the scale is undefined, and 0 is
        returned there instead of inf/NaN.
    """
    value_range = np.asarray(_max - _min)
    # Divide everywhere, silencing the 0-division warnings, then overwrite the
    # zero-range positions. Works for scalar and broadcast array bounds alike.
    with np.errstate(divide="ignore", invalid="ignore"):
        scaled = (img - _min) / value_range
    return np.where(value_range == 0, 0.0, scaled)

def normalize3(volume):
    """Min-max scale a volume of any dimensionality into [0, 1].

    The global minimum and maximum of ``volume`` are used as the bounds.

    Parameters
    ----------
    volume : numpy.ndarray
        Input array.

    Returns
    -------
    numpy.ndarray
        Float array of the same shape. A constant volume is returned as all
        zeros rather than the NaNs that 0/0 would produce.
    """
    min_val = volume.min()
    max_val = volume.max()
    if max_val == min_val:
        # No dynamic range to stretch; avoid 0/0.
        return np.zeros(volume.shape, dtype=float)
    return (volume - min_val) / (max_val - min_val)


In [2]:
dataset_root = "/datassd/PSMA_SEGMENTATION/DATA"
test_dir = os.path.join(dataset_root, "testing")
predictions_dir = os.path.join(dataset_root, "inferences")
gt_projections_save_dir = os.path.join(dataset_root, "gt_projections")
save_dir = os.path.join(dataset_root, "prediction_projections")

index = 0
# test_patients_dir_list = os.listdir(test_dir)
patients_list = glob.glob(os.path.join(predictions_dir, "*-000-01"))

for patient_dir in patients_list:
    patient_id = os.path.basename(patient_dir)[:-7]
    print(patient_id)
    for bodypart in ["top", "bottom"]:
        top_bottom = "01" if bodypart == "top" else "02"
        mean_images_list = list()
        gt_images_list = list()
        for rotation in range(0, 360, 5):
            study_case = f"{patient_id}-{str(rotation).zfill(3)}-{top_bottom}"
            # print(study_case)

            # '/datassd/PSMA_SEGMENTATION/DATA/testing/PSMA-01-097-000-02/PSMA-01-097_PETD_000-02.nii.gz'
            gt_filename = f"/datassd/PSMA_SEGMENTATION/DATA/testing/{study_case}/{patient_id}_SEG_{str(rotation).zfill(3)}-{top_bottom}.nii.gz"
            gt = read_nii_vol(gt_filename)
            gt = np.uint8(gt)
            gt_images_list.append(gt)
            
            pt_filename = f"/datassd/PSMA_SEGMENTATION/DATA/testing/{study_case}/{patient_id}_PET_{str(rotation).zfill(3)}-{top_bottom}.nii.gz"
            pt = read_nii_vol(pt_filename)
            pt = np.float32(pt)
            
            predicted_seg_filename = f"{patient_id}_Predicted-Seg_{str(rotation).zfill(3)}-{top_bottom}.npy"
            predicted_seg_np = np.load(os.path.join(predictions_dir, study_case, predicted_seg_filename))
            if predicted_seg_np.max() == 0:
                mean_image = np.zeros_like(gt)
                mean_images_list.append(mean_image)
            else:
                # print(study_case)
                
                # plt.figure(figsize=(10, 7))
                # images_cmap = 'bone'
                # plt.pcolormesh(pt.T, cmap=images_cmap)
                # plt.axis('off')
                # plt.title('MIP')
                # plt.colorbar()
                # plt.show()
                
                # plt.figure(figsize=(10, 7))
                # images_cmap = 'bone'
                # plt.pcolormesh(gt.T, cmap=images_cmap)
                # plt.axis('off')
                # plt.title('GT')
                # plt.colorbar()
                # plt.show()
                
                pred_list = list()
                # plt.figure(figsize=(10, 7))
                for i in range(predicted_seg_np.shape[0]):
                    
                    pred = predicted_seg_np[i,0,:,:]
                    
                    pred[pred[...] >= 0.4] = 1
                    pred[pred[...] < 0.4] = 0
                    pred_list.append(pred)
                    
                #     plt.subplot(4, 3, i+1)
                #     images_cmap = 'bone'
                #     plt.pcolormesh(pred.T, cmap=images_cmap)
                #     plt.axis('off')
                #     plt.title(f"prediction no: {i+1} ")
                #     plt.colorbar()
                    
                # plt.show()
                
                stacked_preds = np.stack(pred_list)
                    
                mean_image = np.mean(stacked_preds, axis=0)
                mean_image[mean_image[...] < 0.3] = 0
                mean_image[mean_image[...] >=0.3] = 1
                
                mean_images_list.append(mean_image)
                
                # asghar = "agha"
                # std_image = np.std(stacked_preds, axis=0)
                # var_image = np.var(stacked_preds, axis=0)

                # plt.figure(figsize=(20, 5))
                # plt.subplot(1, 3, 1)
                # plt.pcolormesh(mean_image[:,:].T, cmap='viridis', vmin=0, vmax=1)
                # plt.axis('off')
                # plt.title('mean')
                # plt.colorbar()

                # plt.subplot(1, 3, 2)
                # plt.pcolormesh(std_image[:,:].T, cmap='viridis', vmin=0, vmax=1)
                # plt.axis('off')
                # plt.title('std')
                # plt.colorbar()
                
                # plt.subplot(1, 3, 3)
                # plt.pcolormesh(var_image[:,:].T, cmap='viridis', vmin=0, vmax=1)
                # plt.axis('off')
                # plt.title('var')
                # plt.colorbar()
                # plt.show()



                # plt.figure(figsize=(20, 5))
                # plt.subplot(1, 3, 1)
                # plt.pcolormesh(pt.T, cmap='bone')
                # plt.axis('off')
                # plt.title('MIP')
                # plt.colorbar()

                # plt.subplot(1, 3, 2)
                # plt.pcolormesh(gt.T, cmap='bone')
                # plt.axis('off')
                # plt.title('GT')
                # plt.colorbar()

                # plt.subplot(1, 3, 3)
                # plt.pcolormesh(gt.T - mean_image[:,:].T, cmap='bwr', vmin=-1, vmax=1)
                # plt.axis('off')
                # plt.title('GT - Mean')
                # plt.colorbar()
                # plt.show()
                
        stacked_mean_preds = np.stack(mean_images_list)
        stacked_gt_projs = np.stack(gt_images_list)
        savedir = patient_id
        save_fulldir = os.path.join(save_dir, savedir)
        if not os.path.exists(save_fulldir):
            os.makedirs(save_fulldir)
        np_filename = f"{patient_id}-{top_bottom}.npy"
        np.save(os.path.join(save_fulldir, np_filename), stacked_mean_preds)
        gt_save_fulldir = os.path.join(gt_projections_save_dir, savedir)
        if not os.path.exists(gt_save_fulldir):
            os.makedirs(gt_save_fulldir)
        np.save(os.path.join(gt_save_fulldir, np_filename), stacked_gt_projs)
        
        asghar = "agha"
        

        
    asghar = "agha"
    index += 1

PSMA-01-097
PSMA-01-234
PSMA-01-371
PSMA-01-145
PSMA-01-453
PSMA-01-982
PSMA-01-512
PSMA-01-420
PSMA-01-225
PSMA-01-296
PSMA-01-111
PSMA-01-128
PSMA-01-032
PSMA-01-775
PSMA-01-240
PSMA-01-042
PSMA-01-170
PSMA-01-035
PSMA-01-492
PSMA-01-707
PSMA-01-161
PSMA-01-520
PSMA-01-286
PSMA-01-123
PSMA-01-605
PSMA-01-186
PSMA-01-367
PSMA-01-545
PSMA-01-377
PSMA-01-767
PSMA-01-457
PSMA-01-318
PSMA-01-188
PSMA-01-155
PSMA-01-364
PSMA-01-644
PSMA-01-429
PSMA-01-502
PSMA-01-560
PSMA-01-119
PSMA-01-276
PSMA-01-382
PSMA-01-787
PSMA-01-281
PSMA-01-282
PSMA-01-120
PSMA-01-193
PSMA-01-013
PSMA-01-534
PSMA-01-478
