In [None]:
from settings import *
from utils import SpatialTransform, load_data, CoordsImageTest
from INRMorph import INRMorph
import pandas as pd
from matplotlib.animation import FuncAnimation
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas



In [None]:
def set_wandb_logger():
    """Start a wandb run for the visualisations, named after the last trained model.

    The model/run name is the stripped last line of ``logger_name.txt``. If the file
    is missing, empty or unreadable, a message is printed and ``logger_name`` is None.

    Returns:
        tuple: ``(run, logger_name)`` — the object returned by ``wandb.init`` and the
        model name read from the file (or ``None``).
    """
    logger_name = None
    try:
        with open('logger_name.txt', 'r') as f:
            lines = f.readlines()
        if lines:
            logger_name = lines[-1].strip()
    except FileNotFoundError:
        print("The file 'logger_name.txt' does not exist.")
    except Exception as e:
        print(f"An error occurred while reading the file: {e}")

    # "visualizations_" + None used to raise a TypeError right after the message
    # above was printed; fall back to a plain run name instead.
    run_name = "visualizations_" + logger_name if logger_name is not None else "visualizations"
    run = wandb.init(
        project="INRMorph",
        name=run_name
        )
    return run, logger_name


def jacobian_determinanat(coords, deformation_field):
    """Return the per-point determinant |J| of d(deformation_field)/d(coords).

    The (misspelled) name is kept because the rest of the notebook calls it.
    """
    return torch.det(compute_jacobian_matrix(coords, deformation_field))

        

def compute_jacobian_matrix(coords, deformation_field):
    """Build the Jacobian of ``deformation_field`` with respect to ``coords``, point by point.

    Args:
        coords: (N, dim) tensor with ``requires_grad=True``.
        deformation_field: (N, dim) tensor computed from ``coords``.

    Returns:
        (N, dim, dim) tensor ``J`` with ``J[n, i, :]`` = gradient of
        ``deformation_field[:, i]`` w.r.t. ``coords`` at row n. It is allocated on the
        same device and with the same dtype as ``coords``.

    Row i is obtained from the gradient of the *summed* i-th output component. That
    equals the per-point Jacobian only if each output row depends solely on its own
    input row. This presumably holds for a point-wise INR; verify for other models.
    """
    n_points, dim = coords.shape[0], coords.shape[1]
    # Allocate next to the inputs. The previous CPU/default-dtype allocation forced a
    # device-to-host copy for every row when running on GPU.
    jacobian_matrix = torch.zeros(n_points, dim, dim, device=coords.device, dtype=coords.dtype)

    for i in range(dim):
        jacobian_matrix[:, i, :] = gradient(coords, deformation_field[:, i])
    return jacobian_matrix

    
def gradient(coords, output, grad_outputs=None):
    """Gradient of ``output`` with respect to ``coords`` via autograd, keeping the graph.

    Args:
        coords: tensor with ``requires_grad=True`` that ``output`` was computed from.
        output: tensor computed from ``coords``.
        grad_outputs: optional weights for the vector-Jacobian product. Defaults to
            ``torch.ones_like(output)``, which gives the gradient of ``output.sum()``.

    Returns:
        Tensor shaped like ``coords``. ``create_graph=True`` keeps it differentiable.
    """
    # Honour a caller-supplied grad_outputs; the old code always overwrote it.
    if grad_outputs is None:
        grad_outputs = torch.ones_like(output)

    return torch.autograd.grad(output, [coords], grad_outputs=grad_outputs, create_graph=True)[0]


In [None]:
transform = SpatialTransform()
validation_batch_size = 1
spatial_reg = 0.1
temporal_reg = 0.1
# Time points with an acquired scan; later cells load I{k}.nii for the k-th entry.
observed_time = [0, 13, 14, 24]
# Time grid handed to the model. NOTE(review): the evaluation cell redefines `time`
# with 30 instead of 27 as its last point — confirm which grid is intended.
time = [0, 1, 2, 5, 7, 10, 12, 13, 14, 16, 18, 20, 22, 24, 27]
time = torch.tensor(time, device=device)
normalized_time_points = time/12  # presumably months -> years; must match training
patch_size = 1250  # also used as the DataLoader batch size below
project_name = "INRMorph"

imagepath = "dataset/ad/005_S_0814/resampled/"
maskpath = "dataset_copy/affine_registered/masks/"

os.environ["WANDB_NOTEBOOK_NAME"] = "INRMorph"
I0 = load_data(imagepath + "I0.nii")

I0_mask = load_data(maskpath + "I0_fsl_mask.nii.gz")

image_vector = CoordsImageTest(I0.shape, scale_factor = 1)
test_generator = DataLoader(dataset = image_vector, batch_size=patch_size, shuffle = False)


# load model and artifact
run, logger_name = set_wandb_logger()
if logger_name is None:
    # Without a name the artifact lookup below would ask for "None:best".
    raise RuntimeError("No model name found in logger_name.txt; cannot select which artifact to test.")
print("Testing for run with model name: ", logger_name)
model = INRMorph(I0, I0, patch_size, spatial_reg, temporal_reg, 4, "siren", "NCC", "finite_difference", "L2",  normalized_time_points, observed_time).to(device)
artifact = run.use_artifact(f'aishalawal/INRMorph/{logger_name}:best', type='model')
artifact_dir = artifact.download()
# map_location lets a checkpoint saved on GPU load on the current device (e.g. CPU-only).
# torch.load unpickles the file: only load checkpoints from trusted artifacts.
checkpoint = torch.load(os.path.join(artifact_dir, "model.ckpt"), map_location=device)
model.load_state_dict(checkpoint["state_dict"])
model.eval()
print(model)

print("Number of parameters", sum(p.numel() for p in model.parameters() if p.requires_grad))
print("Model successfully loaded")

In [None]:
observed_time = [0, 13, 14, 24]
# NOTE(review): the last point differs from the model-loading cell (27 -> 30);
# presumably an extrapolation point — confirm.
time = [0, 1, 2, 5, 7, 10, 12, 13, 14, 16, 18, 20, 22, 24, 30]

# This cell converts the global I0 / I0_mask to numpy at the end for the plotting
# cells. Keep a tensor handle so re-running the cell still works.
I0_tensor = I0 if torch.is_tensor(I0) else torch.as_tensor(I0, device=device)

stack_total_deformation_field = []
stack_total_jac_det = []
stack_tdf = []
stack_moved_images = []
for selected_time in time:
    # presumably months -> years, matching normalized_time_points
    selected_time_normalised = torch.tensor(selected_time/12, device=device)

    # Collect per-batch results and concatenate once; torch.cat inside the loop
    # was quadratic in the number of batches.
    jac_det_batches = []
    deformation_batches = []
    for coords in test_generator:
        coords = coords.to(device, dtype=torch.float32)
        # reshape instead of squeeze: a final batch of one point keeps its batch axis
        coords = coords.reshape(-1, coords.shape[-1]).requires_grad_(True)
        displacement_vector = model.test_step(coords, selected_time_normalised).reshape(coords.shape).to(device)
        deformation_field = displacement_vector + coords

        # |J| needs the autograd graph from coords, so compute it before detaching.
        jac_det = jacobian_determinanat(coords, deformation_field)  # shape (batch,)

        jac_det_batches.append(jac_det.detach().cpu())
        deformation_batches.append(deformation_field.detach().cpu())

    total_jac_det = torch.cat(jac_det_batches)
    total_deformation_field = torch.cat(deformation_batches, 0).view(-1, 3).unsqueeze(0)
    stack_total_deformation_field.append(total_deformation_field)
    moved = transform.trilinear_interpolation(total_deformation_field.to(device), I0_tensor).view(I0_tensor.shape)
    stack_moved_images.append(moved.cpu().numpy().squeeze())

    stack_total_jac_det.append(total_jac_det.view(I0_tensor.shape))
    stack_tdf.append(total_deformation_field.view(*I0_tensor.shape, 3))
    print(f"Time step {selected_time} done")

# Downstream cells index I0 / I0_mask as numpy arrays.
I0 = I0_tensor.cpu().numpy()
if torch.is_tensor(I0_mask):
    I0_mask = I0_mask.cpu().numpy()
    

In [None]:
# Snapshot the result lists. These are shallow copies: the lists are new, but the
# tensors/arrays inside are shared, so in-place edits to an element affect both.
stack_total_jac_det_copy = list(stack_total_jac_det)
stack_tdf_copy = list(stack_tdf)
stack_moved_images_copy = list(stack_moved_images)
stack_total_deformation_field_copy = list(stack_total_deformation_field)

In [None]:
#visualising residuals


def image_masking(img, mask, tdf = False, pad_before=(25, 0), pad_after=(20, 35)):
    """Crop a 2D slice to the bounding box of ``mask`` plus a margin.

    Args:
        img: 2D array, or 3D array (rows, cols, channels) when ``tdf`` is True.
        mask: 2D array; entries > 0 define the bounding box.
        tdf: True when ``img`` has a trailing channel axis (e.g. a deformation field).
            Only the first two axes are cropped either way.
        pad_before: (rows, cols) margin added before the first masked row/col.
        pad_after: (rows, cols) value added to the last masked row/col to form the
            exclusive slice end.

    Returns:
        The cropped slice of ``img``.

    Raises:
        ValueError: if ``mask`` has no positive entries.
    """
    rows, cols = np.where(mask > 0)
    if rows.size == 0:
        raise ValueError("image_masking: mask has no positive entries to crop to")

    # Clamp the start at 0: a negative start would wrap around (Python slicing) and
    # return an empty or wrong crop when the box touches the top/left border.
    row_start = max(int(rows.min()) - pad_before[0], 0)
    col_start = max(int(cols.min()) - pad_before[1], 0)
    # Ends past the array edge are clipped by numpy slicing.
    row_stop = int(rows.max()) + pad_after[0]
    col_stop = int(cols.max()) + pad_after[1]

    if tdf:
        return img[row_start:row_stop, col_start:col_stop, :]
    return img[row_start:row_stop, col_start:col_stop]


num_slice = 150  # index along axis 0 used for every 2D panel in this cell
image_mask_2d = I0_mask[num_slice, :, :]
I0_2d = I0[num_slice, :, :]
masked_I0 = image_masking(I0_2d, image_mask_2d, False)
video_images = []
video_titles = []
for idx, selected_time in enumerate(time):
    # `idx` indexes the per-time-point stacks; `obs_idx` selects the observed scan.
    # They used to share one name, so observed times displayed the prediction and
    # flow of a different time point.
    if selected_time in observed_time:
        obs_idx = observed_time.index(selected_time)
        It = load_data(imagepath + f"I{obs_idx}.nii").cpu().numpy()
    else:
        # no scan at this time point: all-zero placeholder
        It = np.zeros_like(I0)

    It_2d = It[num_slice, :, :]
    moved_2d = stack_moved_images[idx][num_slice, :, :]
    tdf_2d = stack_tdf[idx][num_slice, :, :].cpu().numpy()

    masked_It = image_masking(It_2d, image_mask_2d, False)
    masked_moved = image_masking(moved_2d, image_mask_2d, False)
    masked_tdf = image_masking(tdf_2d, image_mask_2d, True)

    residual = masked_moved - masked_It if selected_time in observed_time else masked_It

    images = [masked_I0, masked_It, masked_moved, residual, masked_tdf[..., 0]]
    titles = [
        r'Baseline $I_0$',
        r'Observed $I_{t=' + str(selected_time) + '}$',
        r"$I_{t}' = (I_0 \circ \phi_{t=" + str(selected_time) + "})$",
        r"$I_{t}' - I_t$",
        'flow'
    ]
    video_images.append(images)
    video_titles.append(titles)
    ne.plot.slices(images, titles=titles, cmaps=['gray'], do_colorbars=True);

    wandb.log({"Transformed images": [wandb.Image(image, caption=title) for image, title in zip(images, titles)]})



#### EVALUATING |J| MAPS

In [None]:


# from monai-> https://github.com/Project-MONAI/tutorials/blob/main/deep_atlas/utils.py


def preview_image(jacobian_det, mask, I0, selected_time, normalize_by="volume", cmap=None, figsize=(12, 12), threshold=None):
    """
    Plot the three central orthogonal slices of a |J| map above the matching slices of ``I0``.

    Adapted from the MONAI deep_atlas tutorial utilities (link above).

    Args:
        jacobian_det: 3D |J| map of shape (H, W, D); a detached CPU tensor or an ndarray.
        mask: currently unused; kept so existing calls keep working.
        I0: 3D image of the same shape, shown in the bottom row at the same slice indices.
        selected_time: time point printed in the summary text.
        normalize_by: "volume" uses the map's global min/max as the colour range for
            every |J| panel; "slice" lets each panel auto-scale.
        cmap: colormap for the |J| panels (defaults to "viridis").
        figsize: figure size passed to ``plt.subplots``.
        threshold: if not None, |J| values <= threshold are painted red on top.

    Returns:
        The matplotlib Figure, which is also logged to wandb under "Jacobian Maps".

    Raises:
        ValueError: if ``normalize_by`` is neither "slice" nor "volume".
    """
    # Work on numpy arrays so masks/indexing below are numpy and the summary prints
    # plain numbers (a CUDA or grad-tracking tensor is not supported here).
    jac = np.asarray(jacobian_det)
    base = np.asarray(I0)

    if normalize_by == "slice":
        vmin = vmax = None
    elif normalize_by == "volume":
        vmin, vmax = float(jac.min()), float(jac.max())
    else:
        raise ValueError(f"Invalid value '{normalize_by}' given for normalize_by")

    # central slice along each axis
    x, y, z = np.array(jac.shape) // 2
    jac_det_slices = (jac[x, :, :], jac[:, y, :], jac[:, :, z])
    I0_slices = (base[x, :, :], base[:, y, :], base[:, :, z])

    # One origin for every panel. The red overlay used origin="upper" on top of a |J|
    # image drawn with origin="lower", so the folds appeared vertically flipped.
    origin = "lower"
    fig, axs = plt.subplots(2, 3, figsize=figsize)
    for i, (img, I0_slice) in enumerate(zip(jac_det_slices, I0_slices)):
        ax = axs[0, i]
        gg = ax.imshow(img, origin=origin, vmin=vmin, vmax=vmax, cmap=cmap or "viridis")
        if threshold is not None:
            red = np.zeros(img.shape + (4,))  # RGBA overlay, transparent elsewhere
            red[img <= threshold] = [1, 0, 0, 1]
            ax.imshow(red, origin=origin)
        fig.colorbar(gg, ax=ax, fraction=0.046, pad=0.05)

        ax = axs[1, i]
        ax.axis("off")
        hh = ax.imshow(I0_slice, origin=origin, cmap="gray")
        cbar = fig.colorbar(hh, ax=ax, fraction=0.046, pad=0.04)
        cbar.ax.tick_params(labelsize=8)
    fig.tight_layout()

    # Whole-volume summary of the map passed in (previously read the global `det`).
    text = "\n".join([
        f"Time: {selected_time}",
        f"Number of folds: {int((jac <= 0).sum())}",
        f"Number of expansion: {int((jac > 1).sum())}",
        f"Number of contraction: {int(((jac > 0) & (jac < 1)).sum())}",
        f"Number of voxels without change: {int((jac == 1).sum())}",
        f"|J| min, max and mean: {float(jac.min()):.2f}, {float(jac.max()):.2f}, {float(jac.mean()):.3f}",
    ])
    plt.figtext(0.7, 0.90, text, ha="left", fontsize=10, bbox={"facecolor": "white", "edgecolor": "white", "alpha": 0, "pad": 5, "linewidth": 3})
    wandb.log({"Jacobian Maps": wandb.Image(fig)})
    plt.show()
    return fig

video_jacobian_maps = []
for idx, selected_time in enumerate(time):
    # Other code in this notebook may read the global `det`, so keep assigning it.
    det = stack_total_jac_det[idx]
    # NOTE(review): the bottom row receives the *moved* image, not I0 — confirm intended.
    fig = preview_image(det, I0_mask, stack_moved_images[idx], selected_time, normalize_by="slice", threshold=0)
    video_jacobian_maps.append(fig)
    # Close each figure (not only the last) to release pyplot's reference. The Figure
    # objects stay usable for the video cell, which draws them on a fresh Agg canvas.
    plt.close(fig)


#### EVALUATING SEGMENTATION

In [None]:
# Warp the labelled I0 segmentation with each deformation field, then compute Dice between the moved I0 labels and It_seg.


def load_seg_data(path: str) -> torch.Tensor:
    """Load a NIfTI label volume as a float32 tensor on ``device``.

    Args:
        path: file path readable by ``nib.load``.

    Returns:
        Tensor of the voxel data from ``get_fdata()`` (label IDs become floats).
    """
    # get_fdata already returns an ndarray; torch.tensor copies it to `device`.
    data = nib.load(path).get_fdata()
    return torch.tensor(data, device=device, dtype=torch.float32)

stack_moved_seg = []
video_images_seg = []
video_titles_seg = []
I0_seg = load_seg_data("dataset/ad/005_S_0814/labels/I0_seg.nii")
z = 150  # index along axis 1 used for the panels in this cell
image_mask_2d_seg = I0_mask[:, z, :]
I0_2d_seg = I0_seg[:, z, :].cpu().numpy()


for idx, selected_time in enumerate(time):
    # `idx` indexes the per-time-point stacks; `obs_idx` selects the observed label
    # file. They used to share one name, so observed times were warped with another
    # time point's deformation field.
    if selected_time in observed_time:
        obs_idx = observed_time.index(selected_time)
        It_seg = load_seg_data(f"dataset/ad/005_S_0814/labels/I{obs_idx}_seg.nii")
    else:
        # no labels at this time point: all-zero placeholder
        It_seg = torch.zeros_like(I0_seg)

    # nearest-neighbour sampling, presumably so warped label IDs stay integral
    moved_seg = transform.nearest_neighbor_interpolation(stack_total_deformation_field[idx].to(device), I0_seg).view(I0_seg.shape)
    stack_moved_seg.append(moved_seg.cpu().numpy().squeeze())

    It_2d_seg = It_seg[:, z, :].cpu().numpy()
    moved_2d_label = moved_seg[:, z, :].cpu().numpy()

    residual = moved_2d_label-It_2d_seg if selected_time in observed_time else It_2d_seg

    tdf_2d = stack_tdf[idx][:, z, :].cpu().numpy()
    # Crop with the mask of this same axis-1 slice. It previously used the axis-0 mask
    # from the residual cell, whose bounding box belongs to a different plane.
    masked_tdf = image_masking(tdf_2d, image_mask_2d_seg, True)

    images = [I0_2d_seg, It_2d_seg, moved_2d_label,  residual, masked_tdf[..., 1]]
    titles = [
        r'Baseline $I_0$',
        r'Observed $S(I_{t=' + str(selected_time) + '})$',
        r"$S(I_{t}') = S(I_0)  \circ \phi_{t=" + str(selected_time) + "}$",
        r"$S(I_{t}') - S(I_t)$",
        'flow'
    ]
    video_images_seg.append(images)
    video_titles_seg.append(titles)

    ne.plot.slices(images, titles=titles, do_colorbars=True,  cmaps=['viridis']);

    wandb.log({"Transformed segmentation maps": [wandb.Image(image, caption=title) for image, title in zip(images, titles)]})




#### COMPUTING PER-STRUCTURE MEAN |J| AND DICE SCORES, AND SAVING RESULTS

In [None]:
def load_data(path: str) -> torch.Tensor:
    """Load a NIfTI volume as a float32 tensor on ``device``.

    Volumes here are presumably (256, 256, 166) — verify.

    NOTE(review): this redefinition shadows ``load_data`` imported from ``utils`` for
    every later cell (and for earlier cells if they are re-run); confirm the two agree.

    Args:
        path: file path readable by ``nib.load``.

    Returns:
        Tensor of the voxel data from ``get_fdata()``.
    """
    data = nib.load(path).get_fdata()  # already an ndarray
    return torch.tensor(data, device=device, dtype=torch.float32)

structures = [
    "lateral_ventricle",
    "thalamus",
    "caudate",
    "putamen",
    "pallidum",
    "hippocampus",
    "amygdala"
]

# Left/right label IDs per structure, aligned by position with `structures`.
# These appear to be FreeSurfer aseg label IDs — verify against the label files.
left_structures = [4, 10, 11, 12, 13, 17, 18]
right_structures = [43, 49, 50, 51, 52, 53, 54]
assert len(structures) == len(left_structures) == len(right_structures), "label tables out of sync"

# Create the output folder up front; opening result/<name>.csv fails if it is missing.
os.makedirs("result", exist_ok=True)
result_path = os.path.join("result", logger_name + ".csv")

def dice_score(label1, label2):
    """Dice overlap 2|A ∩ B| / (|A| + |B|) between two binary (0/1) masks.

    The intersection counts positions where ``label1 > 0`` and ``label2`` holds the
    same value. The denominator sums the raw values, so inputs are expected to be 0/1.

    Returns:
        float; NaN when both masks are empty (Dice is undefined there).
    """
    label1 = np.asarray(label1)
    label2 = np.asarray(label2)
    denominator = np.sum(label1) + np.sum(label2)
    if denominator == 0:
        # Explicit NaN instead of a 0/0 RuntimeWarning.
        return float("nan")
    foreground = label1 > 0
    intersection = np.sum(label1[foreground] == label2[foreground])
    return float(2 * intersection / denominator)

def combine_labels(img1, img2, selected_time, state = "Dice Score"):
    """Per-structure Dice between two label volumes, merging left and right IDs.

    For every name in ``structures``, voxels carrying that structure's left or right
    ID (from ``left_structures``/``right_structures``) become a 0/1 mask in each
    volume, and the two masks are compared with ``dice_score``.

    Returns:
        (dict of structure name -> Dice, mean Dice over ``structures``). The mean is
        also printed when ``selected_time`` is one of ``observed_time``.
    """
    def _as_binary(volume, ids):
        return np.where(np.isin(volume, ids), 1, 0)

    structure_dice = {
        name: dice_score(_as_binary(img1, [left_id, right_id]), _as_binary(img2, [left_id, right_id]))
        for name, left_id, right_id in zip(structures, left_structures, right_structures)
    }
    mean_dice = sum(structure_dice.values()) / len(structures)
    if selected_time in observed_time:
        print(f"Time = {selected_time}, {state} : {mean_dice}")
    return structure_dice, mean_dice

def compute_mean_jac_det(moved_label, jacobian_det=None):
    """Mean |J| inside each structure of a (warped) label volume.

    Args:
        moved_label: label volume with the same shape as the |J| map. The left/right
            IDs of each entry in ``structures`` define that structure's region.
        jacobian_det: |J| map to average. Defaults to the notebook-global ``det`` for
            backward compatibility; pass it explicitly to avoid hidden state.

    Returns:
        dict of structure name -> mean |J| as a Python number. Expect NaN when a
        structure is absent (mean of an empty selection).
    """
    if jacobian_det is None:
        jacobian_det = det  # legacy behaviour: the caller sets the global `det`
    mean_jacobian_determinants = {}
    for structure, left_id, right_id in zip(structures, left_structures, right_structures):
        mask = np.isin(moved_label, [left_id, right_id])
        jacobian_values = jacobian_det[mask]
        mean_jacobian_determinants[structure] = jacobian_values.mean().item()
    return mean_jacobian_determinants


for idx, selected_time in enumerate(time):
    # `idx` indexes the per-time-point stacks; `obs_idx` selects the observed label
    # file. `idx` used to be assigned only for observed times, so interpolated and
    # extrapolated times silently reused a stale index.
    is_observed = selected_time in observed_time
    if is_observed:
        obs_idx = observed_time.index(selected_time)
        It_seg = load_data(f"dataset/ad/005_S_0814/labels/I{obs_idx}_seg.nii")
    else:
        It_seg = torch.zeros_like(I0_seg)
    It_seg = It_seg.cpu().numpy()

    structure_dice_affine, total_dice_affine = combine_labels(I0_seg.cpu().numpy(), It_seg, selected_time,"Dice between I0 and It (affine)")
    structure_dice_target_predicted, total_dice_target_predicted = combine_labels(stack_moved_seg[idx], It_seg, selected_time, "Dice at It and It\'(target vs predicted)")

    # compute_mean_jac_det reads the global `det` by default, so assign it first.
    det = stack_total_jac_det[idx]
    mean_jac_det = compute_mean_jac_det(stack_moved_seg[idx])
    number_of_folds = (det<=0).sum().item()
    number_of_expansions = (det>1).sum().item()
    number_of_contractions = ((det > 0) & (det < 1)).sum().item()
    no_volume_change = (det==1).sum().item()

    # NOTE(review): rows are appended on every run, so re-running this cell
    # duplicates them in result_path.
    if not os.path.exists(result_path):
        with open(result_path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(["subjectID", "selected_time", "structure", "structure_mean_jac_det", "jac_det_mean","jac_det_min", "jac_det_max","structure_dice_affine", "structure_dice_target_predicted",
                            "total_dice_affine", "total_dice_target_predicted", "number_of_folds", "number_of_expansions", "number_of_contractions", "no_volume_change"])

    subject_id = imagepath.split("/")[2]
    try:
        with open(result_path, 'a', newline='') as f:
            writer = csv.writer(f)
            for structure in structures:
                if is_observed:
                    dice_columns = [structure_dice_affine[structure], structure_dice_target_predicted[structure],
                                    total_dice_affine, total_dice_target_predicted]
                else:
                    # No ground-truth labels when interpolating/extrapolating, so Dice is
                    # undefined. The old test `selected_time not in time` was never true,
                    # so these rows got Dice against the all-zero placeholder instead.
                    dice_columns = [np.nan] * 4
                writer.writerow([subject_id, selected_time, structure, mean_jac_det[structure], det.mean().item(),
                                 det.min().item(), det.max().item(), *dice_columns,
                                 number_of_folds, number_of_expansions, number_of_contractions, no_volume_change])
    except FileNotFoundError:
        print(f"Error: Could not write to file {result_path}")




In [None]:
# plotting graphs
df = pd.read_csv(result_path)
# Rows are appended on every run of the results cell: keep the latest row per
# (subject, time, structure) and plot in time order.
df = (
    df.drop_duplicates(subset=["subjectID", "selected_time", "structure"], keep="last")
      .sort_values("selected_time", kind="stable")
)

fig, ax = plt.subplots(figsize=(10, 4))
# One line per structure, in order of first appearance.
for structure, structure_data in df.groupby("structure", sort=False):
    ax.plot(structure_data['selected_time'], structure_data['structure_mean_jac_det'], label=structure, marker='*', linestyle='-')

ax.set_xlabel('Time Point')
ax.set_ylabel('mean |J|')
ax.set_title('Mean |J| over Time')
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
ax.grid()
# Log before show: some backends release the figure once it has been shown.
wandb.log({"Mean Jacobian Determinant Over time": wandb.Image(fig)})
plt.show()


# Dice exists only at observed times. Aggregate once and draw real lines; the old
# per-point plot calls could not connect markers despite linestyle='-'.
dice_by_time = (
    df[df['selected_time'].isin(observed_time)]
    .groupby('selected_time')[['total_dice_affine', 'total_dice_target_predicted']]
    .mean()
)
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(dice_by_time.index, dice_by_time['total_dice_affine'], marker='*', linestyle='-', color='blue', label="Affine")
ax.plot(dice_by_time.index, dice_by_time['total_dice_target_predicted'], marker='*', linestyle='-', color='orange', label="Non-linear")
ax.set_xlabel('Time Point')
ax.set_ylabel('Dice Coefficient')
ax.set_title('Comparing affine and non-linear Dice Coefficient over Time')
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
wandb.log({"Comparing affine and non-linear Dice Coefficient over Time": wandb.Image(fig)})
plt.show()



#### GENERATING JACOBIAN MAP VIDEO

In [None]:

fig, ax = plt.subplots(figsize=(10, 7))


def init_jacobian_frame():
    """Blank, axis-free frame shown before the first |J| map."""
    ax.clear()
    ax.axis('off')
    return []


def update_jacobian_frame(frame):
    """Rasterise the stored |J| figure for time index ``frame`` and show it in ``ax``."""
    ax.clear()
    ax.axis('off')

    canvas = FigureCanvas(video_jacobian_maps[frame])
    canvas.draw()
    # buffer_rgba replaces tostring_rgb (deprecated in Matplotlib 3.8, removed in 3.10).
    # np.array copies the buffer; drop the alpha channel to get (height, width, 3).
    img = np.array(canvas.buffer_rgba())[..., :3]

    ax.imshow(img)
    return []


# blit=False: the update function returns no artists, so blitting has nothing to redraw.
ani = FuncAnimation(fig, update_jacobian_frame, frames=len(video_jacobian_maps), init_func=init_jacobian_frame, blit=False, interval=1000)

video_filename = 'jacobian_maps_video.mp4'
ani.save(video_filename, writer='ffmpeg', fps=1,  codec='vp9', dpi=300)  # 1 frame per second

wandb.log({"Jacobian Maps Video": wandb.Video(video_filename, format="mp4")})

plt.show()
plt.close(fig)
# plt.fig(video_jacobian_maps[0])

In [None]:

# One composite frame per time point: image panels on top, segmentation panels below.
num_time_points = len(video_images)
num_images = len(video_images[0])
# Both rows come from cells that loop over the same `time` list.
assert len(video_images_seg) == num_time_points, "image and segmentation frames are out of sync"

composite_images = []
panel_rows = [(video_images, video_titles), (video_images_seg, video_titles_seg)]
for t in range(num_time_points):
    fig, axs = plt.subplots(2, num_images, figsize=(10, 7))

    for row, (row_images, row_titles) in enumerate(panel_rows):
        for i in range(num_images):
            axs[row, i].imshow(row_images[t][i], cmap='gray')
            axs[row, i].set_title(row_titles[t][i])
            axs[row, i].axis('off')

    plt.tight_layout()

    # Rasterise the figure. buffer_rgba replaces tostring_rgb (deprecated in
    # Matplotlib 3.8, removed in 3.10); drop alpha to get (height, width, 3).
    canvas = FigureCanvas(fig)
    canvas.draw()
    composite_images.append(np.array(canvas.buffer_rgba())[..., :3])
    plt.close(fig)

# Animate the stored frames.
fig, ax = plt.subplots(figsize=(15, 7))


def init_composite_frame():
    """Blank, axis-free frame shown before the first composite."""
    ax.clear()
    ax.axis('off')
    return []


def update_composite_frame(frame):
    """Show the pre-rendered composite image for time index ``frame``."""
    ax.clear()
    ax.axis('off')
    ax.imshow(composite_images[frame])
    return []


# blit=False: the update function returns no artists, so blitting has nothing to redraw.
ani = FuncAnimation(fig, update_composite_frame, frames=len(composite_images), init_func=init_composite_frame, blit=False, interval=1000)

video_filename = 'time_lapse_transition.mp4'
ani.save(video_filename, writer='ffmpeg', fps=1, dpi=300, codec='vp9')
wandb.log({"Time Lapse Transition Video": wandb.Video(video_filename, format="mp4")})
plt.show()
plt.close(fig)