# Testing the SpineDataModule

## 1. Workspace setup

In [None]:
# Automatically reload imported modules (e.g. edited project code) before
# executing each cell; mode 2 reloads all modules.
%load_ext autoreload
%autoreload 2
# Load the jupyter_black extension, which formats notebook cells with black.
%load_ext jupyter_black

In [None]:
import sys
import os
import open3d as o3d
import numpy as np
import matplotlib.pyplot as plt
from spinediffusion.datamodule.datamodule import SpineDataModule

## 2. Instantiate the datamodule

In [None]:
# Preprocessing pipeline handed to SpineDataModule. Each key names a
# transform and maps to its keyword arguments; "transform_number" gives the
# transform's position in the pipeline (presumably applied in ascending
# order -- confirm in SpineDataModule).
transform_args = dict(
    spine_length_normalize=dict(transform_number=0, norm_length=1),
    resample_3d_curve=dict(transform_number=1, n_points=256),
    resample_point_cloud=dict(transform_number=2, n_points=65536, method="uniform"),
    # Depth-map transforms, currently disabled. NOTE(review): if re-enabled,
    # "tensorize" below must be renumbered, since it also uses number 3.
    # project_to_plane=dict(transform_number=3, height=124, width=124, intensity=1),
    # close_depthmap=dict(transform_number=4, se_size=(3, 3)),
    tensorize=dict(transform_number=3),
)

# Per-dataset lists of patient identifiers to leave out (interpreted by
# SpineDataModule; presumably subject IDs -- verify against the loader).
exclude_patients = dict(balgrist=[1, 5], croatian=[4, 5], italian=[], ukbb=[])

# Build the datamodule on the back-scan database, caching processed data in
# ../cache/, then run its setup with stage=None.
datamodule = SpineDataModule(
    data_dir="P:\\Projects\\LMB_4Dspine\\back_scan_database",
    cache_dir="../cache/",
    use_cache=True,
    n_subjects=5,
    exclude_patients=exclude_patients,
    transform_args=transform_args,
    train_fraction=0.8,
    val_fraction=0.1,
    test_fraction=0.1,
    batch_size=32,
)
datamodule.setup(stage=None)

In [None]:
# Display slice 0 (along the first axis) of the input of training sample #3
# as a grayscale image.
# NOTE(review): assumes train_data[i][0] is an image-like tensor such as
# (C, H, W) -- with project_to_plane disabled in transform_args this may be
# a point cloud instead; confirm.
sample_tensor = datamodule.train_data[3][0]
sample_array = sample_tensor.detach().numpy()
plt.imshow(sample_array[0], cmap="gray")

In [None]:
# Call the datamodule's train_dataloader() hook to actually build the
# training DataLoader, and display it as the cell output. Referencing the
# attribute without parentheses would only show the bound method object.
train_loader = datamodule.train_dataloader()
train_loader

In [None]:
# List the field names stored for every entry in datamodule.data
# (one dict_keys line per entry).
for sample_record in datamodule.data.values():
    field_names = sample_record.keys()
    print(field_names)

In [None]:
# Display the first training sample as the cell output for inspection.
first_train_sample = datamodule.train_data[0]
first_train_sample

In [None]:
# Plot up to 16 randomly chosen back scans as 3D scatter plots (one view
# each) on a 4x4 grid. `keys` and `indices` stay at module level because the
# depth-map cells reuse them.
keys = list(datamodule.data.keys())

# Draw from the samples that actually exist rather than a hard-coded
# population of 30: an index >= len(keys) would raise IndexError, and with
# fewer than 16 samples replace=False could not draw 16 distinct indices.
n_show = min(16, len(keys))
indices = np.random.choice(len(keys), n_show, replace=False)

fig = plt.figure(figsize=(16, 16))

for i, idx in enumerate(indices):
    ax = fig.add_subplot(4, 4, i + 1, projection="3d")
    points = datamodule.data[keys[idx]]["backscan"]
    # assumes points is an (N, >=3) array of xyz coordinates -- TODO confirm
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=1)
    ax.axis("off")

plt.show()

In [None]:
# Show the "depth_map" of each sample selected in `indices` on a 4x4 grid.
# Axes are paired with indices via zip, so fewer than 16 selected samples
# no longer raises IndexError; unused axes are simply left blank.
fig, axs = plt.subplots(4, 4, figsize=(16, 16))

for ax in axs.flat:
    ax.axis("off")

for ax, idx in zip(axs.flat, indices):
    depthmap = datamodule.data[keys[idx]]["depth_map"]
    ax.imshow(depthmap, cmap="gray")

plt.show()

In [None]:
import numpy as np
from scipy.ndimage import grey_closing

# Apply a greyscale morphological closing to each selected depth map and
# plot the results on a 4x4 grid. Closing fills dark gaps smaller than the
# structuring element.
# The flat 3x3 square is passed as `footprint`. SciPy treats `structure` as a
# non-flat structuring element, whose values are added during dilation and
# subtracted during erosion; that is not the plain square intended here.
structuring_element = np.ones((3, 3), dtype=bool)

fig, axs = plt.subplots(4, 4, figsize=(16, 16))

for ax in axs.flat:
    ax.axis("off")

# Pair axes with the selected indices so fewer than 16 samples is handled.
for ax, idx in zip(axs.flat, indices):
    depthmap = datamodule.data[keys[idx]]["depth_map"]
    closed_depthmap = grey_closing(depthmap, footprint=structuring_element)
    ax.imshow(closed_depthmap, cmap="gray")

plt.show()