In [1]:
import glob
import os
import torch
import sys
from metrics_cond import *
import slice_view
import torch.nn.functional as F
import nibabel as nib
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm


In [2]:
# Normalize tensors
def normalize_tensor(tensor, min_range, max_range):
    """Min-max rescale each (dim-0, dim-1) slice of ``tensor`` into [min_range, max_range].

    The minimum and maximum are taken over dims 2 and 3 (the same dims the
    original nested ``torch.min(dim=2)`` reductions covered), so every entry
    along dims 0 and 1 is scaled independently (presumably batch and channel
    for a (B, C, H, W) image tensor -- confirm against callers).

    Args:
        tensor: tensor with at least 4 dimensions.
        min_range: value each slice's minimum is mapped to.
        max_range: value each slice's maximum is mapped to.

    Returns:
        Tensor with the same shape as ``tensor``. Slices whose values are all
        equal (zero range) are mapped to ``min_range`` instead of NaN.
    """
    min_vals = torch.amin(tensor, dim=(2, 3), keepdim=True)
    max_vals = torch.amax(tensor, dim=(2, 3), keepdim=True)
    value_range = max_vals - min_vals
    # A constant slice has zero range; dividing by 1 there maps it to min_range
    # rather than producing 0/0 = NaN.
    value_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)
    normalized_tensor = (tensor - min_vals) / value_range
    return normalized_tensor * (max_range - min_range) + min_range

In [3]:
# Build the AD-DAE model through the project-local factory module.
import AD_DAE_model_call as model_call

model_factory = model_call.AD_DAE_model_call_func
model = model_factory()

Seed set to 0


Model params: 129.18 M


In [4]:
# Retrieve model configuration from the model object
conf = model.conf

# Device used for inference by later cells
device = 'cuda'

# Import DataLoader explicitly; it was previously only reachable through
# `from metrics_cond import *`, which hides where the name comes from.
from torch.utils.data import DataLoader


def _make_split_dataset(csv_path, h5_save_path, csv_file_name, csv_mask_name):
    """Build one dataset split via `conf.make_dataset`.

    The data config path and the ventricle-mask root path are shared by every
    split; only the CSV/HDF5 locations and file names differ.
    """
    return conf.make_dataset(path=conf.data_config_path,
                             csv_path=csv_path,
                             h5_save_path=h5_save_path,
                             csv_file_name=csv_file_name,
                             csv_mask_name=csv_mask_name,
                             ventricle_mask_root_path=conf.ventricle_mask_root_path)


# Training split (train-specific CSV / HDF5 / mask names)
model.train_data = _make_split_dataset(csv_path=conf.csv_path,
                                       h5_save_path=conf.h5_save_path_train,
                                       csv_file_name=conf.csv_file_name_train,
                                       csv_mask_name=conf.csv_mask_name_train)
print('train data:', len(model.train_data))

# Validation split (test-specific CSV / HDF5 / mask names)
model.val_data = _make_split_dataset(csv_path=conf.csv_path_test,
                                     h5_save_path=conf.h5_save_path_test,
                                     csv_file_name=conf.csv_file_name_test,
                                     csv_mask_name=conf.csv_mask_name_test)
print('val data:', len(model.val_data))

# Batch sizes for training and evaluation
batch_size = 50
batch_size_eval = 50

# shuffle=False keeps a fixed sample order; a single worker loads data and
# pin_memory=True speeds up host-to-GPU copies.
train_dataloader = DataLoader(model.train_data, shuffle=False, batch_size=batch_size, num_workers=1, pin_memory=True)
test_dataloader = DataLoader(model.val_data, shuffle=False, batch_size=batch_size_eval, num_workers=1, pin_memory=True)


self.data_name ADNI
['__image', '__baseline_image'] ['__label', '__baseline_label']
in fuction ['__image', '__baseline_image', '__label', '__baseline_label']
train data: 75350
self.data_name ADNI
['__image', '__baseline_image'] ['__label', '__baseline_label']
in fuction ['__image', '__baseline_image', '__label', '__baseline_label']
val data: 31100


In [7]:

# Convert a list of ages (anything float() accepts) into a stacked float tensor.
# Not used in this cell; kept because other cells may rely on it.
str_float_tensor = lambda age_list: torch.stack([torch.tensor(float(age)) for age in age_list])

# --- Generation settings ----------------------------------------------------
start_from_noise = True   # True: render from Gaussian noise; False: render from the stochastic encoding xT
mode = 'train'            # selects the '_train_*' keys of each batch
max_batches = 1           # batches to process before stopping (this demo inspects a single batch)
AD_DAE_save_path = './save_results/'

# Create the output directory once (including missing parents) instead of per batch.
os.makedirs(AD_DAE_save_path, exist_ok=True)

# Put the model and its inner model into eval mode on the GPU once, before the loop.
model.eval()
model.cuda()
model.model.eval()
model.model.cuda()

# Latent shift predictor, used to predict a latent shift from the shifted condition vector.
shift_predictor = model.model.latent_shift_predictor
shift_predictor.eval()

# Inference only: disable autograd for the whole loop. Previously only the
# .eval()/.cuda() calls were inside torch.no_grad(), so encoding and rendering
# built autograd graphs and the saved tensors carried grad history.
with torch.no_grad():
    for batch_num, batch in enumerate(tqdm(train_dataloader, total=len(train_dataloader))):
        image = batch['_' + mode + '_image']
        baseline_image = batch['_' + mode + '_baseline_image']
        # Label mask (as float32 on the GPU) applied to the outputs below.
        label_mask = batch['_' + mode + '_label'].to(torch.float32).to(device)

        # Ventricle mask handed to the sampler for rendering.
        ventricle_mask_batch = batch['ventricle_mask']

        # Min-max normalize follow-up and baseline images to [0, 1] and move them to the GPU.
        x_start = normalize_tensor(image, 0, 1).to(device)
        x_start_baseline = normalize_tensor(baseline_image, 0, 1).to(device)

        idxs = batch['idx']

        # Age gap between follow-up and baseline, in the image dtype.
        age_diff = (model.str_list_tensor(batch['Age'])
                    - model.str_list_tensor(batch['baseline Age'])).to(image.dtype)

        # 12-dim condition vector passed to get_data_elements, which returns the
        # condition vector, the shifts and the shifted condition vector.
        cond_vector = torch.zeros(image.shape[0], 12)
        cond_vector, shifts, cond_vector_shift = model.get_data_elements(batch, age_diff, cond_vector)

        # health_state is the shifted condition vector BEFORE the dtype/device cast below.
        health_state = cond_vector_shift
        cond_vector_shift = cond_vector_shift.to(age_diff.dtype).to(age_diff.device)
        cond_vector = cond_vector.to(age_diff.dtype).to(age_diff.device)

        # Single-channel Gaussian noise with the spatial size of x_start.
        noise = torch.randn(x_start.shape[0], 1, x_start.shape[2], x_start.shape[3], device=x_start.device)

        # Sampler state read during rendering.
        model.sampler.ventricle_mask_batch = ventricle_mask_batch
        model.sampler.age_shift = shifts.to(image.dtype).to(image.device)
        model.sampler.cond_shift_weight = 1

        # Predicted latent shift for the shifted condition vector.
        shift_new = shift_predictor.forward(cond_vector_shift.to(device))

        # Latent encoding of the (un-normalized) baseline image.
        cond_baseline = model.encode(baseline_image.to(device))

        # Stochastic encoding of the baseline image with 200 steps.
        # NOTE(review): xT is only used when start_from_noise is False, yet it is
        # always computed; gate it if encode_stochastic has no side effects.
        xT = model.encode_stochastic(baseline_image.to(device), cond_baseline, T=200,
                                     x_start_baseline=x_start_baseline, age_diff=age_diff.to(device),
                                     health_state=health_state.to(device))

        # Where follow-up age equals baseline age, copy dims 450:500 of the predicted
        # shift into dims 0:50 (vectorized form of a per-row loop).
        shift_new_modf = shift_new.clone()
        no_age_gap = (age_diff == 0).to(shift_new_modf.device)
        if no_age_gap.any():
            shift_new_modf[no_age_gap, 0:50] = shift_new_modf[no_age_gap, 450:500]

        # Render the follow-up image from noise or from the stochastic encoding.
        start_ = noise if start_from_noise else xT
        pred_followup_xT_shift_ = model.render(start_, {'cond': cond_baseline + shift_new_modf}, T=50,
                                               mask_mult=False,
                                               health_state=cond_vector_shift.to(device))

        # Apply the label mask to the prediction, the follow-up target and the baseline.
        pred_followup_xT_shift = pred_followup_xT_shift_ * label_mask
        x_start_ = x_start * label_mask
        x_start_baseline_ = x_start_baseline * label_mask

        # Ages of the first subject in the batch (used in the file name).
        starting_age = torch.tensor(float(batch['baseline Age'][0]))
        followup_age = torch.tensor(float(batch['Age'][0]))

        print('Subject:', batch['nii path'][0].split('/')[-1])
        print("Baseline Age: ", batch['baseline Age'][0])
        print('Follow-up Age: ', batch['Age'][0])
        print('Cognitive State: ', batch['Health status'][0])

        # Tensors hold the whole batch; scalar metadata is the first subject's,
        # except 'Health status', which keeps the whole batch's list.
        save_dict = {'pred_followup_xT_shift': pred_followup_xT_shift,
                     'x_start_': x_start_,
                     'x_start_baseline_': x_start_baseline_,
                     'baseline Age': batch['baseline Age'][0],
                     'Age': batch['Age'][0],
                     'Health status': batch['Health status'],
                     'baseline nii path': batch['baseline nii path'][0],
                     'nii path': batch['nii path'][0]
                     }

        # File name: <Subject>_<follow-up age>_<baseline age>.pt
        uniq_id = (batch['Subject'][0] + '_' + str(round(followup_age.item(), 2))
                   + '_' + str(round(starting_age.item(), 2)) + '.pt')
        torch.save(save_dict, os.path.join(AD_DAE_save_path, uniq_id))

        # Stop after max_batches without loading an extra batch first.
        if batch_num + 1 >= max_batches:
            break

  0%|                                                                                                                                                                                                                                                          | 0/1507 [01:09<?, ?it/s]

Subject: ADNI_033_S_1285_MR_MPR__GradWarp__B1_Correction__N3__Scaled_Br_20090317135732405_S63539_I139246.nii
Baseline Age:  81.7
Follow-up Age:  82.2
Cognitive State:  AD





In [9]:
# 3D viewer for the ground-truth (label-masked, normalized follow-up) images:
# channel 0 of every image in the saved batch, stacked and moved to the CPU
# (presumably consecutive slices of one scan -- confirm).
slc = slice_view.slicer(save_dict['x_start_'][:, 0, :, :].detach().cpu())
# slicer_view() displays its interactive widget itself; the trailing ';'
# suppresses the `<function ...>` repr of its return value.
slc.slicer_view();

interactive(children=(Dropdown(description='slice_view', options=('x', 'y', 'z'), value='x'), IntSlider(value=…

<function ipywidgets.widgets.interaction._InteractFactory.__call__.<locals>.<lambda>(*args, **kwargs)>

In [11]:
# 3D viewer for the generated (label-masked) follow-up images: channel 0 of
# every image in the saved batch, stacked and moved to the CPU.
slc = slice_view.slicer(save_dict['pred_followup_xT_shift'][:, 0, :, :].detach().cpu())
# slicer_view() displays its interactive widget itself; the trailing ';'
# suppresses the `<function ...>` repr of its return value.
slc.slicer_view();

interactive(children=(Dropdown(description='slice_view', options=('x', 'y', 'z'), value='x'), IntSlider(value=…

<function ipywidgets.widgets.interaction._InteractFactory.__call__.<locals>.<lambda>(*args, **kwargs)>