In [10]:
import pandas as pd

In [11]:
# import data files
# young_train_df = pd.read_csv("young_train_df.csv")
young_test_df = pd.read_csv("young_test_df.csv")
# old_train_df = pd.read_csv("old_train_df.csv")
old_test_df = pd.read_csv("old_test_df.csv")

In [12]:
test_df = pd.concat([young_test_df, old_test_df], ignore_index=True)

In [13]:
test_df.columns

Index(['MRID', 'Age', 'Diagnosis', 'Sex', 'sth', 'B', 'Patient ID', 'Date',
       'Age-rounded', 'Age_Group', 'dataset'],
      dtype='object')

In [14]:
# Group by 'Patient ID'
grouped = test_df.groupby('Patient ID')

# Calculate the age range and count of scans for each patient
patient_scans = grouped.agg(
    age_range = ('Age', lambda x: x.max() - x.min()),  # Calculate the range of ages
    num_scans = ('Patient ID', 'size')  # Count the number of scans (rows) per group
).reset_index()

# Additional aggregation to get distinct diagnoses
diagnosis_data = test_df.groupby('Patient ID')['Diagnosis'].unique().reset_index()
diagnosis_data['Diagnosis'] = diagnosis_data['Diagnosis'].apply(lambda x: ', '.join(sorted(x)))

# Merge this back into the patient_scans DataFrame
patient_scans = patient_scans.merge(diagnosis_data, on='Patient ID', how='left')

# Sort the DataFrame by 'age_range' and 'num_scans', both in descending order
patient_scans = patient_scans.sort_values(by=['age_range', 'num_scans'], ascending=[False, False])

In [15]:
# Filter to exclude patients with only one diagnosis and age range less than two years
patient_scans_filtered = patient_scans[(patient_scans['Diagnosis'].str.contains(',')) & (patient_scans['age_range'] >= 2)]

In [16]:
patient_scans_filtered

Unnamed: 0,Patient ID,age_range,num_scans,Diagnosis
26,1190,12.783025,5,"AD, CN"
50,4100,7.340178,6,"CN, MCI"
44,4035,4.246407,5,"AD, MCI"
102,4715,4.21629,4,"AD, MCI"
106,4816,4.123203,5,"AD, MCI"
85,4446,4.087611,5,"CN, MCI"
12,604,3.055441,5,"AD, MCI"
13,638,3.033539,6,"AD, MCI"
19,1066,3.008898,3,"AD, MCI"
84,4432,2.12731,2,"AD, MCI"


In [17]:
# Get a list of unique Patient IDs from the filtered DataFrame
filtered_patient_ids = patient_scans_filtered['Patient ID'].unique()

# Filter the original DataFrame to include only rows with these Patient IDs
related_entries = test_df[test_df['Patient ID'].isin(filtered_patient_ids)]

# Select only the required columns
final_table = related_entries[['Patient ID', 'MRID', 'Age-rounded', 'Diagnosis']]

# Calculate the minimum 'Age-rounded' for each 'Patient ID'
min_age = final_table.groupby('Patient ID')['Age-rounded'].transform(min)

# Create a new column 'Type' based on the condition if 'Age-rounded' equals the minimum age for that 'Patient ID'
final_table['Type'] = ['input' if age == min_age else 'ground truth' for age, min_age in zip(final_table['Age-rounded'], min_age)]

# Add a new column 'age-difference' which is the difference between each 'Age-rounded' and the group minimum 'Age-rounded'
final_table['age-difference'] = final_table['Age-rounded'] - min_age

# Sort final_table to visually inspect the changes or prepare for output
final_table.sort_values(by=['Patient ID', 'Age-rounded'], inplace=True)

  min_age = final_table.groupby('Patient ID')['Age-rounded'].transform(min)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  final_table['Type'] = ['input' if age == min_age else 'ground truth' for age, min_age in zip(final_table['Age-rounded'], min_age)]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  final_table['age-difference'] = final_table['Age-rounded'] - min_age
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-ve

In [18]:
# For each entry, we load the initial scan, take the slices -> pass it through generator
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=1
import sys
sys.path.append('/home/aarongzy/class/src')
import torch
from torch.utils.data import DataLoader
import numpy as np
import nibabel as nib
import os
from models import Generator
from data_loader import CustomImageDataset

from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM


import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Do a nested for loop with two loops. The first iterate through Patient ID, get the first scan based on its MRID, slice it and run each slice through the generator
# the second inner loops reads the scans from the same patient, 
model_path = "/home/aarongzy/class/src/trained_models/generator.pth"
gen = Generator().to(device)
gen.load_state_dict(torch.load(model_path, map_location=device))

# Function to encode the diagnosis based on predefined category mapping
def encode_diagnosis(diagnosis):
    category_mapping = {"CN": [0, 0], "MCI": [0, 1], "AD": [1, 1]}
    # Fetch the list from the mapping or default to [None, None] if diagnosis is not found
    encoded_list = category_mapping.get(diagnosis, [None, None])
    # Convert list to a 2x1 NumPy array
    return np.array(encoded_list).reshape(2, 1)

def age_vector(age):
    vector = np.zeros(100, dtype=int)
    if age < 100:  # If age is 100 or more, the vector will be all 1s
        vector[100-age:] = 1
    else:
        vector[:] = 1
    return vector


# Function to process each slice
def process_slice(slice_data):
    # Assuming your model expects a certain input shape and tensor type
    # Convert slice_data to tensor, unsqueeze to add batch dimension, etc.
    tensor = torch.from_numpy(slice_data).float().unsqueeze(0)  # Example, adjust dimensions as needed
    with torch.no_grad():  # Disable gradient computation
        output = generator(tensor)
    return output

def process(baseline_mrid, mrid, diagnosis, age_difference):
    MSE_v = []
    SSIM_v = []
    # load baseline image, ground truth image based on mrid
    data_folder_path = "/home/aarongzy/class/adni_data"
    baseline_file = os.path.join(data_folder_path, baseline_mrid, baseline_mrid+"_MNI152_registered.nii.gz")
    ground_truth_file = os.path.join(data_folder_path, mrid, mrid+"_MNI152_registered.nii.gz")
    ad = age_vector(int(age_difference))
    ho = encode_diagnosis(diagnosis).squeeze(-1)

    # read images
    baseline_img = nib.load(baseline_file).get_fdata()
    ground_truth_img = nib.load(ground_truth_file).get_fdata()

    # Compute the 99.5th percentile intensity value
    baseline_image_percentile = np.percentile(baseline_img, 99.5)
    ground_truth_image_percentile = np.percentile(ground_truth_img, 99.5)
    
    # Rescale the intensities
    baseline_img = np.clip(baseline_img, 0, baseline_image_percentile)
    ground_truth_img = np.clip(ground_truth_img, 0, ground_truth_image_percentile)
    
    # Normalize pixel values to range [-1, 1]
    baseline_img_normalized = (baseline_img / np.abs(np.max(baseline_img))) * 2 - 1
    ground_truth_img_normalized = (ground_truth_img / np.abs(np.max(ground_truth_img))) * 2 - 1

    baseline_img_normalized = baseline_img
    # ground_truth_img_normalized = ground_truth_img
    
    # Convert numpy arrays to torch tensors and send them to CUDA device
    baseline_img_tensor = torch.tensor(baseline_img_normalized, device=device, dtype=torch.float32)
    ground_truth_img_tensor = torch.tensor(ground_truth_img_normalized, device=device, dtype=torch.float32)

    # convert age_difference and diagnosis to torch tensors and send them to CUDA device
    ad = torch.tensor(ad, device=device, dtype=torch.float32)
    ho = torch.tensor(ho, device=device, dtype=torch.float32)
    
    # Slice and process image tensors on GPU
    slice_count = baseline_img_tensor.shape[2]
    central_start = max(slice_count // 2 - 2, 0)
    central_end = min(central_start + 2, slice_count)
    ground_truth_img_central_slices = ground_truth_img_tensor[:, :, central_start:central_end]
    baseline_img_central_slices = baseline_img_tensor[:, :, central_start:central_end]
    
    resize = torch.nn.functional.interpolate
    for i in range(ground_truth_img_central_slices.shape[2]):
        # print(".") 
        ground_truth_slice = ground_truth_img_central_slices[:, :, i].unsqueeze(0).unsqueeze(0)
        baseline_slice = baseline_img_central_slices[:, :, i].unsqueeze(0).unsqueeze(0)
    
        # Resize images on GPU
        ground_truth_img_reshaped = resize(ground_truth_slice, size=(208, 160), mode='bilinear', align_corners=False).squeeze(0)
        ground_truth_img_reshaped = torch.unsqueeze(ground_truth_img_reshaped, dim=1)
        baseline_img_reshaped = resize(baseline_slice, size=(208, 160), mode='bilinear', align_corners=False).squeeze(0)

        # send age_difference and diagnosis to cuda
        # pass baseline through generator
        ho = ho.view(1,2)
        ad = ad.view(1,100)
        baseline_img_reshaped = baseline_img_reshaped.view(1,1,208,160)
        # print(baseline_img_reshaped.shape)
        # print(ho.shape)
        # print(ad.shape)
        with torch.no_grad():
            pred_img = gen(baseline_img_reshaped, ho, ad) 
        
        # do a min-max normalization on both pred_img and ground_truth_slice
        pred_img_norm = (pred_img - torch.min(pred_img)) / (torch.max(pred_img) - torch.min(pred_img))
        # pred_img_norm = pred_img
        ground_truth_norm = (ground_truth_img_reshaped - torch.min(ground_truth_img_reshaped)) / (torch.max(ground_truth_img_reshaped) - torch.min(ground_truth_img_reshaped))
        print("Checking dimensions")
        print("pred", pred_img_norm.shape)
        print("gt:", ground_truth_norm.shape)
        print("Checking intensity range")
        print("pred min:", torch.min(pred_img_norm).item(), "pred max:", torch.max(pred_img_norm).item())
        print("gt min:", torch.min(ground_truth_norm).item(), "gt max:", torch.max(ground_truth_norm).item())
        print("pred var:", torch.var(pred_img_norm))
        print("gt var:", torch.var(ground_truth_norm))
        

        # calculate MSE
        with torch.no_grad():
            mse = torch.nn.functional.mse_loss(pred_img_norm, ground_truth_norm)
            MSE_v.append(mse.item())

        # calculate SSIM
        ssim = SSIM(data_range=1.0).to(device)
        ssim_val = ssim(pred_img_norm, ground_truth_norm)
        SSIM_v.append(ssim_val.item())

        print("MSE:", mse.item())
        print("SSIM:", ssim_val.item())
        print("--------")

        # Convert tensors to numpy arrays
        pred_img_np = pred_img_norm.cpu().squeeze(0).numpy()
        ground_truth_np = ground_truth_norm.cpu().squeeze(0).numpy()
        raw_output_np = pred_img.cpu().squeeze(0).numpy()

        # Create NIfTI image objects
        pred_img_nifti = nib.Nifti1Image(pred_img_np, affine=np.eye(4))  # Assuming identity affine matrix
        ground_truth_nifti = nib.Nifti1Image(ground_truth_np, affine=np.eye(4))  # Assuming identity affine matrix
        raw_output_np = nib.Nifti1Image(raw_output_np, affine=np.eye(4))

        # Save the NIfTI images as .nii.gz files
        nib.save(pred_img_nifti, "predicted_image.nii.gz")
        nib.save(ground_truth_nifti, "ground_truth_image.nii.gz")
        nib.save(raw_output_np, "raw_output.nii.gz")

        # Convert tensors to PIL images
        pred_img_pil = TF.to_pil_image(pred_img_norm.cpu().squeeze(0))
        ground_truth_pil = TF.to_pil_image(ground_truth_norm.cpu().squeeze(0))

        # # Plotting
        # fig, axes = plt.subplots(1, 2, figsize=(10, 5))

        # # Plot pred_img_norm
        # axes[0].imshow(pred_img_pil, cmap='gray')
        # axes[0].set_title('Predicted Image (Normalized)')
        # axes[0].axis('off')

        # # Plot ground_truth_norm
        # axes[1].imshow(ground_truth_pil, cmap='gray')
        # axes[1].set_title('Ground Truth Image (Normalized)')
        # axes[1].axis('off')

        # plt.show()
    # return metric lists
    return MSE_v, SSIM_v   

"""
for patient-ID:
    # get 1st scan of the patient
    # get the slices
    for mrid # of the same patient
        # get ho, ad encoding
        # slice ground truth
        for slices: # of the same mrid
            # call generator on corresponding 1st slice with ho, ad parameter
            # calculate MSE and SSI, append it to a list for the current health state and paitnet (1D array)
            # export the predicted image
        # concatenate all MSE and SSI arrays for the same patient (2D matrix)
    # concatenate by patient (3D matrix)
"""
mse_all = []
ssim_all = []
for patient_id in final_table['Patient ID'].unique():
    patient_data = final_table[final_table['Patient ID'] == patient_id]
    # Grab the first entry as the baseline
    baseline_scan = patient_data.iloc[0]
    baseline_mrid = baseline_scan['MRID']
    baseline_age_rounded = baseline_scan['Age-rounded']
    baseline_diagnosis = baseline_scan['Diagnosis']
    baseline_age_difference = baseline_scan['age-difference']
    assert baseline_age_difference==0.0
    

    # Iterate over each scan for the current patient
    counter = 0
    mse_cur_patient = []
    ssim_cur_patient = []
    for index, scan in patient_data.iterrows():
        counter += 1
        # if counter == 6:
        #     break
        mrid = scan['MRID']
        diagnosis = scan['Diagnosis']
        age_difference = scan['age-difference']

        if (age_difference != 0.0):
            # process
            mse, ssim = process(baseline_mrid, mrid, diagnosis, age_difference)
            mse_cur_patient.append(mse)
            ssim_cur_patient.append(ssim)
    mse_all.append(mse_cur_patient)
    ssim_all.append(ssim_cur_patient)



















env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=1
Using device: cuda




Checking dimensions
pred torch.Size([1, 1, 208, 160])
gt: torch.Size([1, 1, 208, 160])
Checking intensity range
pred min: 0.0 pred max: 1.0
gt min: 0.0 gt max: 1.0
pred var: tensor(0.1118, device='cuda:0')
gt var: tensor(0.1102, device='cuda:0')
MSE: 0.0019878267776221037
SSIM: 0.9295975565910339
--------
Checking dimensions
pred torch.Size([1, 1, 208, 160])
gt: torch.Size([1, 1, 208, 160])
Checking intensity range
pred min: 0.0 pred max: 1.0
gt min: 0.0 gt max: 1.0
pred var: tensor(0.1123, device='cuda:0')
gt var: tensor(0.1103, device='cuda:0')
MSE: 0.0019771938677877188
SSIM: 0.9306945204734802
--------
Checking dimensions
pred torch.Size([1, 1, 208, 160])
gt: torch.Size([1, 1, 208, 160])
Checking intensity range
pred min: 0.0 pred max: 1.0
gt min: 0.0 gt max: 1.0
pred var: tensor(0.1118, device='cuda:0')
gt var: tensor(0.1162, device='cuda:0')
MSE: 0.0032336043659597635
SSIM: 0.905687153339386
--------
Checking dimensions
pred torch.Size([1, 1, 208, 160])
gt: torch.Size([1, 1, 208,

In [19]:
# final_table

In [20]:
import statistics

# Flatten the nested list

flattened_mse = [item for sublist in mse_all for item in sublist]
flattened_flattened_mse = [item for sublist in flattened_mse for item in sublist]

# Calculate the mean
mean_mse = statistics.mean(flattened_flattened_mse)
print("Mean MSE:", mean_mse)
# Calculate the standard deviation for MSE
std_dev_mse = statistics.stdev(flattened_flattened_mse)
print("Standard Deviation MSE:", std_dev_mse)


flattened_ssim = [item for sublist in ssim_all for item in sublist]
flattened_flattened_ssim = [item for sublist in flattened_ssim for item in sublist]

# calculate mean SSMI
mean_ssim = statistics.mean(flattened_flattened_ssim)
print("Mean SSMI:", mean_ssim)
# Calculate the standard deviation for SSIM
std_dev_ssim = statistics.stdev(flattened_flattened_ssim)
print("Standard Deviation SSIM:", std_dev_ssim)


Mean MSE: 0.0028943189307548372
Standard Deviation MSE: 0.0034772926404405776
Mean SSMI: 0.9107203973191125
Standard Deviation SSIM: 0.05110082324056769
