# Testing datamodule

## 1. Workspace setup

In [None]:
%load_ext autoreload
%autoreload 2
%load_ext jupyter_black

In [None]:
import sys
import os
import json
import open3d as o3d
import numpy as np
import matplotlib.pyplot as plt
from spinediffusion.datamodule.datamodule import SpineDataModule

## 2. Instantiate datamodule

In [None]:
# --- Preprocessing configuration ---------------------------------------------
# One entry per transform; "transform_number" is the position of the transform
# in the pipeline (presumably how SpineDataModule orders them -- TODO confirm).
transform_args = {
    # "spine_length_normalize": {
    #     "transform_number": 0,
    #     "norm_length": 1,
    # },
    "resample_3d_curve": {
        "transform_number": 0,
        "n_points": 1024,
    },
    "resample_point_cloud": {
        "transform_number": 1,
        "n_points": 65536,
        "method": "uniform",
    },
    "project_to_plane": {
        "transform_number": 2,
        "height": 124,
        "width": 124,
        "spine_factor": 1,
        "intensity": 1,
        "method": "median",
        "z_lims": [-150, 100],
    },
    "close_depthmap": {"transform_number": 3, "se_size": (3, 3)},
    "tensorize": {"transform_number": 4},
}

# --- Augmentation configuration ----------------------------------------------
augment_args = {
    "random_rotation": {"transform_number": 0, "theta_range": [0, 0.5], "num_aug": 10}
}

# --- Spine-line generator statistics -----------------------------------------
num_control_points = 4

# Per-control-point (x, y, z) mean: the same mean is used for every control point.
isl_gen_mean = np.array([[2.06, 0.52, -58.01]] * num_control_points)

# Per-control-point (x, y, z) std: the first and last control points share one
# std, every control point in between shares another. Building the rows from
# num_control_points keeps this array the same length as isl_gen_mean.
isl_endpoint_std = [3.4, 0, 4.5]
isl_interior_std = [14.31, 115.69, 25.91]
isl_gen_std = np.array(
    [isl_endpoint_std]
    + [isl_interior_std] * (num_control_points - 2)
    + [isl_endpoint_std]
)

sl_args = {
    "sl_mean": isl_gen_mean,
    "sl_std": isl_gen_std,
    "sample_method": "normal",
    "length": 200,
    "num_spl_points": 100,
    # Reuse the projection settings of the back scans for the generated lines.
    "project_args": transform_args["project_to_plane"],
}

# Optional per-dataset exclusions; empty means nothing is excluded here.
exclude_patients = {
    # "balgrist": [1, 5],
    # "croatian": [4, 5],
    # "italian": [],
    # "ukbb": [],
}

# Location of the back scan database. Set the SPINE_DATA_DIR environment
# variable to run the notebook on another machine; the fallback is the
# original network-drive location.
data_dir = os.environ.get(
    "SPINE_DATA_DIR", "P:\\Projects\\LMB_4Dspine\\back_scan_database"
)

# --- Datamodule ---------------------------------------------------------------
datamodule = SpineDataModule(
    data_dir=data_dir,
    batch_size=32,
    train_fraction=0.8,
    val_fraction=0.1,
    test_fraction=0.1,
    transform_args=transform_args,
    augment_args=augment_args,
    num_subjects=10,
    exclude_patients=exclude_patients,
    use_cache=False,
    cache_dir="../cache/",
    predict_size=16,
    num_workers=1,
    conditional=True,
    sl_args=sl_args,
)
datamodule.setup(stage=None)

In [None]:
# Preview the depth maps stored under "balgrist_1_0" ... "balgrist_1_9"
# on a 2 x 5 grid, one sample per axis.
fig, axs = plt.subplots(2, 5, figsize=(20, 10))

preview_keys = [f"balgrist_1_{n}" for n in range(10)]

for ax, key in zip(axs.flat, preview_keys):
    depthmap = datamodule.data[key]["depth_map"]
    # squeeze(0) drops the leading axis so imshow receives a 2-D image
    ax.imshow(depthmap.squeeze(0), cmap="gray")
    ax.axis("off")

plt.show()

In [None]:
# Show, for a few randomly drawn samples, the back scan depth map next to the
# ESL and ISL depth maps (one sample per row).
num_samples = 3
sample_keys = np.random.choice(list(datamodule.data.keys()), num_samples, replace=False)

# squeeze=False guarantees a 2-D axes array, so axs[i, j] indexing also works
# when num_samples is set to 1.
fig, axs = plt.subplots(
    num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False
)

for i, key in enumerate(sample_keys):
    sample = datamodule.data[key]

    # The back scan depth map carries a leading axis that imshow cannot take.
    axs[i, 0].imshow(sample["depth_map"].squeeze(0), cmap="gray")
    axs[i, 0].set_title("Back scan depth map")
    axs[i, 1].imshow(sample["esl_depth_map"], cmap="gray")
    axs[i, 1].set_title("ESL depth map")
    axs[i, 2].imshow(sample["isl_depth_map"], cmap="gray")
    axs[i, 2].set_title("ISL depth map")

plt.show()

In [None]:
# Overlay the ESL (red channel) and ISL (green channel) depth maps on the
# grayscale back scan for a random selection of samples, laid out on a grid
# with n_cols columns.

num_samples = 16
sample_keys = np.random.choice(list(datamodule.data.keys()), num_samples, replace=False)

n_cols = 4
n_rows = int(np.ceil(num_samples / n_cols))

# squeeze=False keeps axs a 2-D array for any grid size (including 1 x 1).
fig, axs = plt.subplots(
    n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False
)

for i, key in enumerate(sample_keys):
    sample = datamodule.data[key]

    back_scan = sample["depth_map"].squeeze(0)
    esl = sample["esl_depth_map"]
    isl = sample["isl_depth_map"]

    # Replicate the back scan over three channels and scale it by its maximum.
    back_scan = np.stack([back_scan, back_scan, back_scan], axis=-1)
    back_scan = back_scan / back_scan.max()

    # ESL goes into the first (red) channel, ISL into the second (green) one.
    esl = np.stack([esl, np.zeros_like(esl), np.zeros_like(esl)], axis=-1)
    isl = np.stack([np.zeros_like(isl), isl, np.zeros_like(isl)], axis=-1)

    superimposed = back_scan + esl + isl
    superimposed = superimposed / superimposed.max()

    axs.flat[i].imshow(superimposed)
    axs.flat[i].set_title(key)
    axs.flat[i].axis("off")

    # Draw the legend once, next to the last axis of the first row
    # (i.e. the top right corner of the figure).
    if i == n_cols - 1:
        axs.flat[i].text(125, 10, "ESL", color="red", fontsize=24)
        axs.flat[i].text(125, 30, "ISL", color="green", fontsize=24)

# Blank out grid cells that received no sample
# (only happens when num_samples is not a multiple of n_cols).
for unused_ax in axs.ravel()[len(sample_keys) :]:
    unused_ax.axis("off")

plt.show()

In [None]:
# Display the keys under which the samples are stored in datamodule.data
datamodule.data.keys()

In [None]:
# Stack the "isl" / "esl" arrays of all samples along the first axis so that
# statistics can be computed over every point at once.
isls = np.concatenate([sample["isl"] for sample in datamodule.data.values()])
esls = np.concatenate([sample["esl"] for sample in datamodule.data.values()])

# These are the shapes of the stacked arrays (the labels previously said "mean").
print(f"ISL shape: {isls.shape}")
print(f"ESL shape: {esls.shape}")

In [None]:
# Histograms of the x / y / z columns of the stacked spine-line arrays:
# top row ISL, bottom row ESL.
fig, axs = plt.subplots(2, 3, figsize=(30, 30))

for row, (label, points) in enumerate((("ISL", isls), ("ESL", esls))):
    for col, axis_name in enumerate("xyz"):
        ax = axs[row, col]
        ax.hist(points[:, col], bins=100)
        ax.set_title(f"{label} {axis_name} histogram")
        ax.grid()

plt.show()

In [None]:
# Column-wise mean and standard deviation of the stacked ISL and ESL arrays.
isl_mean, isl_std = np.mean(isls, axis=0), np.std(isls, axis=0)
esl_mean, esl_std = np.mean(esls, axis=0), np.std(esls, axis=0)

for label, value in (
    ("ISL mean", isl_mean),
    ("ISL std", isl_std),
    ("ESL mean", esl_mean),
    ("ESL std", esl_std),
):
    print(f"{label}: {value}")

In [None]:
# Rebuild the generator statistics from the values measured on the data.
# NOTE: this rebinds num_control_points / isl_gen_mean / isl_gen_std that
# were first set in the configuration cell.
num_control_points = 5

# Std used for the first and the last control point.
isl_endpoint_std = [3.4, 0, 4.5]

# Same measured mean for every control point.
isl_gen_mean = np.array([isl_mean] * num_control_points)

# Endpoints use the fixed std, every control point in between uses the
# measured one; deriving the row count from num_control_points keeps the
# array aligned with isl_gen_mean.
isl_gen_std = np.array(
    [isl_endpoint_std] + [isl_std] * (num_control_points - 2) + [isl_endpoint_std]
)
isl_gen_std.shape

In [None]:
from spinediffusion.datamodule.sl_generator import SLGenerator

# Projection settings shared with the back scan pipeline.
project_args = transform_args["project_to_plane"]

# Collect the generator settings first, then build the generator from them.
sl_generator_kwargs = {
    "sl_mean": isl_gen_mean,
    "sl_std": isl_gen_std,
    "sample_method": "normal",
    "length": 200,
    "num_spl_points": 100,
    "project_args": project_args,
}
sl_generator = SLGenerator(**sl_generator_kwargs)

In [None]:
# Draw 16 items from the generator and show them on a 4 x 4 grid.
splines = [next(sl_generator) for _ in range(16)]

# With projection settings the items are shown as images on plain axes;
# without them they are plotted as 3-D curves, which needs 3-D axes.
subplot_kw = None if project_args else {"projection": "3d"}
fig, axs = plt.subplots(4, 4, figsize=(20, 20), subplot_kw=subplot_kw)

for i, (ax, spline) in enumerate(zip(axs.flat, splines)):
    if not project_args:
        # columns 0, 1, 2 of the item are used as the x, y, z coordinates
        ax.plot(spline[:, 0], spline[:, 1], spline[:, 2])
        ax.set_title(f"Spline {i}")
        ax.set_aspect("equal")
        continue

    # image branch: grayscale with a fixed [0, 1] intensity range, no axes
    ax.imshow(spline, cmap="gray", vmin=0, vmax=1)
    ax.set_title(f"Spline {i}")
    ax.axis("off")

plt.show()

In [None]:
# Per-sample mean and std of the back scan points, for the original and the
# upsampled version of each scan.
# NOTE(review): `data` and `data_upsampled` are not defined anywhere in the
# visible part of this notebook -- confirm where they come from.
means = {}
stds = {}

means_diff = []
stds_diff = []

for key in data.keys():
    backscan = data[key]["backscan"]
    backscan_upsampled = data_upsampled[key]["backscan"]

    # Convert each point set to an array once and reuse it for both statistics.
    points_original = np.asarray(backscan.points)
    points_upsampled = np.asarray(backscan_upsampled.points)

    means[key] = {
        "original": points_original.mean(axis=0),
        "upsampled": points_upsampled.mean(axis=0),
    }
    stds[key] = {
        "original": points_original.std(axis=0),
        "upsampled": points_upsampled.std(axis=0),
    }

In [None]:
# Difference (original - upsampled) of the per-sample means and stds,
# collected separately for the x, y and z components.
mean_pairs = [(entry["original"], entry["upsampled"]) for entry in means.values()]
std_pairs = [(stds[key]["original"], stds[key]["upsampled"]) for key in means.keys()]

x_mean_diff = [original[0] - upsampled[0] for original, upsampled in mean_pairs]
y_mean_diff = [original[1] - upsampled[1] for original, upsampled in mean_pairs]
z_mean_diff = [original[2] - upsampled[2] for original, upsampled in mean_pairs]

x_std_diff = [original[0] - upsampled[0] for original, upsampled in std_pairs]
y_std_diff = [original[1] - upsampled[1] for original, upsampled in std_pairs]
z_std_diff = [original[2] - upsampled[2] for original, upsampled in std_pairs]

In [None]:
# Histograms of the original-vs-upsampled differences:
# top row means, bottom row stds; columns are the x, y, z components.
fig, axs = plt.subplots(2, 3, figsize=(15, 10))

diff_panels = [
    [
        ("X mean difference", x_mean_diff),
        ("Y mean difference", y_mean_diff),
        ("Z mean difference", z_mean_diff),
    ],
    [
        ("X std difference", x_std_diff),
        ("Y std difference", y_std_diff),
        ("Z std difference", z_std_diff),
    ],
]

for ax_row, panel_row in zip(axs, diff_panels):
    for ax, (title, values) in zip(ax_row, panel_row):
        ax.hist(values, bins=20)
        ax.set_title(title)
        ax.set_xlabel("Difference")
        ax.set_ylabel("Frequency")

# Render the figure explicitly, like the other plotting cells, instead of
# leaving the repr of the last set_ylabel() call as the cell output.
plt.show()

In [None]:
# Distribution of the area covered by one pixel when the C7-to-DM distance
# of each subject is mapped onto `height * factor` pixels.
# NOTE(review): `data` is not defined in the visible part of this notebook,
# and height = 128 differs from the 124 used in transform_args["project_to_plane"]
# -- confirm both are intended.
height = 128  # pixels
factor = 1

# Number of pixels spanned by the spine; identical for every subject.
spine_pixels = height * factor

areas = []
for key in data.keys():
    special_points = data[key]["special_points"]
    C7, DR, DL = (special_points[name] for name in ("C7", "DR", "DL"))

    # DM is the midpoint between DR and DL
    DM = (DR + DL) / 2

    spine_length = np.linalg.norm(C7 - DM)  # length units

    # side length of one pixel, squared to get its area
    d_height = spine_length / spine_pixels
    areas.append(d_height**2)

plt.hist(areas, bins=20)
plt.title("Pixel area")
plt.xlabel("Area [$mm^2$]")
plt.ylabel("# subjects")

plt.show()

In [None]:
# Show the depth maps of 16 randomly chosen samples on a 4 x 4 grid.
# Materialize the key list once instead of rebuilding it on every iteration.
data_keys = list(datamodule.data.keys())
indices = np.random.choice(len(data_keys), 16, replace=False)

fig, axs = plt.subplots(4, 4, figsize=(15, 15))

for i, idx in enumerate(indices):
    key = data_keys[idx]

    depth_map = datamodule.data[key]["depth_map"]
    # index 0 selects the first entry along the leading axis of the depth map
    axs.flat[i].imshow(depth_map[0], cmap="gray")