# Testing the datamodule

## 1. Workspace setup

In [None]:
%load_ext autoreload
%autoreload 2
%load_ext jupyter_black

In [None]:
import copy
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d

from spinediffusion.datamodule.datamodule import SpineDataModule

## 2. Instantiate the datamodule

In [None]:
# Preprocessing configuration passed to SpineDataModule: one entry per
# transform, keyed by transform name. "transform_number" presumably fixes the
# order in which the transforms are applied (0 first) — TODO confirm in
# SpineDataModule. The spine-length normalization step is disabled; if it is
# re-enabled, note that it and "resample_3d_curve" both use transform_number 0.
transform_args = {
    # "spine_length_normalize": {
    #     "transform_number": 0,
    #     "norm_length": 1,
    # },
    "resample_3d_curve": {
        "transform_number": 0,
        "n_points": 256,
    },
    "resample_point_cloud": {
        "transform_number": 1,
        "n_points": 65536,
        "method": "uniform",
    },
    # height/width presumably give the depth-map size in pixels; any
    # pixel-size analysis in this notebook should use these same values.
    "project_to_plane": {
        "transform_number": 2,
        "height": 124,
        "width": 124,
        "spine_factor": 1,
        "intensity": 1,
        "method": "median",
        "z_lims": [-150, 100],
    },
    "close_depthmap": {"transform_number": 3, "se_size": (3, 3)},
    "tensorize": {"transform_number": 4},
}

# Patients to drop, presumably keyed by dataset name with lists of patient
# IDs (verify against SpineDataModule). Every entry is commented out, so an
# empty dict is passed.
exclude_patients = {
    # "balgrist": [1, 5],
    # "croatian": [4, 5],
    # "italian": [],
    # "ukbb": [],
}

# data_dir is a hard-coded Windows path; change it when running elsewhere.
# With use_cache=True, results are presumably read from / written to
# cache_dir — TODO confirm.
datamodule = SpineDataModule(
    data_dir="P:\\Projects\\LMB_4Dspine\\back_scan_database",
    batch_size=32,
    train_fraction=0.8,
    val_fraction=0.1,
    test_fraction=0.1,
    transform_args=transform_args,
    n_subjects=None,
    exclude_patients=exclude_patients,
    use_cache=True,
    cache_dir="../cache/",
)
# Presumably runs the full loading/preprocessing pipeline for all stages;
# the following cells re-run individual steps manually to inspect them.
datamodule.setup(stage=None)

In [None]:
# Re-run the loading steps by hand so the pre-transform state can be captured.
# These are private SpineDataModule methods, so this cell is tied to the
# current implementation of the datamodule.
datamodule._parse_datapaths()
datamodule._load_data()
datamodule._reformat_data()

# Deep copy: a shallow dict.copy() would share the per-subject dicts (and the
# point-cloud objects inside them) with datamodule.data. Any in-place update
# made later by _preprocess_data() would then also appear in this "original"
# snapshot, and the original-vs-upsampled comparison would be meaningless.
data = copy.deepcopy(datamodule.data)

In [None]:
# Apply the configured transforms to the freshly loaded data.
datamodule._preprocess_data()

# Shallow snapshot of the top-level mapping: the per-subject entries are the
# same objects as in datamodule.data.
data_upsampled = dict(datamodule.data)

In [None]:
# Per-subject mean and standard deviation (per coordinate) of the back-scan
# point cloud, captured before (`data`) and after (`data_upsampled`)
# preprocessing.
means = {}
stds = {}

means_diff = []
stds_diff = []

for subject in data:
    clouds = {
        "original": data[subject]["backscan"],
        "upsampled": data_upsampled[subject]["backscan"],
    }

    means[subject] = {}
    stds[subject] = {}

    for label, cloud in clouds.items():
        # Convert once and reuse the array for both statistics.
        xyz = np.asarray(cloud.points)
        means[subject][label] = xyz.mean(axis=0)
        stds[subject][label] = xyz.std(axis=0)

In [None]:
# Per-subject differences (original - upsampled) of the point-cloud mean and
# std, split by coordinate axis (index 0 -> X, 1 -> Y, 2 -> Z).
x_mean_diff, y_mean_diff, z_mean_diff = [], [], []
x_std_diff, y_std_diff, z_std_diff = [], [], []

mean_buckets = (x_mean_diff, y_mean_diff, z_mean_diff)
std_buckets = (x_std_diff, y_std_diff, z_std_diff)

for subject in means:
    subject_means = means[subject]
    for axis, bucket in enumerate(mean_buckets):
        bucket.append(subject_means["original"][axis] - subject_means["upsampled"][axis])

    subject_stds = stds[subject]
    for axis, bucket in enumerate(std_buckets):
        bucket.append(subject_stds["original"][axis] - subject_stds["upsampled"][axis])

In [None]:
# 2 x 3 grid of histograms: rows are the statistic (mean, std), columns the
# coordinate axis (X, Y, Z) of the original - upsampled difference.
fig, axs = plt.subplots(2, 3, figsize=(15, 10))

panel_rows = (
    ("mean", (x_mean_diff, y_mean_diff, z_mean_diff)),
    ("std", (x_std_diff, y_std_diff, z_std_diff)),
)

for row, (stat_name, per_axis_diffs) in enumerate(panel_rows):
    for col, (axis_name, diffs) in enumerate(zip("XYZ", per_axis_diffs)):
        ax = axs[row, col]
        ax.hist(diffs, bins=20)
        ax.set_title(f"{axis_name} {stat_name} difference")
        ax.set_xlabel("Difference")
        ax.set_ylabel("Frequency")

# Render explicitly, like the pixel-area cell, instead of relying on the
# cell's last expression being echoed.
plt.show()

In [None]:
# Estimate, per subject, the area covered by one depth-map pixel. The C7 -> DM
# distance (DM = midpoint of DR and DL) is divided by `height * spine_factor`
# pixels. This assumes the projection maps that spine segment onto that many
# rows — TODO confirm against project_to_plane.
# Read the resolution from the configured projection so this estimate always
# matches the depth maps actually produced (previously hard-coded to 128,
# while project_to_plane uses 124).
projection_args = transform_args["project_to_plane"]
height = projection_args["height"]  # pixels
factor = projection_args["spine_factor"]

# Identical for every subject, so computed once outside the loop.
spine_pixels = height * factor

areas = []

for key in data:
    special_points = data[key]["special_points"]
    C7 = special_points["C7"]
    DR = special_points["DR"]
    DL = special_points["DL"]

    DM = (DR + DL) / 2

    spine_length = np.linalg.norm(C7 - DM)  # length units

    d_height = spine_length / spine_pixels

    areas.append(d_height**2)


plt.hist(areas, bins=20)
plt.title("Pixel area")
plt.xlabel("Area [$mm^2$]")
plt.ylabel("# subjects")

plt.show()

In [None]:
# Pick up to 16 distinct random subjects to preview in a 4 x 4 grid.
# np.random.choice(..., replace=False) raises ValueError when asked for more
# samples than exist, so clamp to the number of available subjects.
n_preview = min(16, len(datamodule.data))
indices = np.random.choice(len(datamodule.data), n_preview, replace=False)

fig, axs = plt.subplots(4, 4, figsize=(15, 15))

for i, idx in enumerate(indices):
    key = list(datamodule.data.keys())[idx]

    depth_map = datamodule.data[key]["depth_map"]
    axs.flat[i].imshow(depth_map[0], cmap="gray")