In [None]:
# Exhaustive set of imports

from collections import Counter
from collections import defaultdict
from glob import glob
from matplotlib import rcParams
from matplotlib import rcParamsDefault
from matplotlib.colors import ListedColormap
from mpl_toolkits.mplot3d import Axes3D
from pathlib import Path
from scipy.interpolate import CubicSpline
from scipy.ndimage import affine_transform
from scipy.ndimage import binary_erosion
from scipy.ndimage import zoom
from scipy.optimize import minimize
from scipy.spatial import cKDTree
from skimage import measure
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier, StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.manifold import MDS
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.svm import SVC
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, random_split, Dataset
from tqdm import tqdm
from xgboost import XGBClassifier
import json
import matplotlib.cm as cm
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import plotly.graph_objects as go
import random
import scipy as sp
import seaborn as sns
import shutil
import sys
import tifffile
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import umap
import umap.umap_ as umap
import umap.umap_ as umap_module

<p style="height: 100px;"></p>


## Pre-Processing

In [None]:
# Must run once to create and save required data / metadata
# - Walks base_folder for 2-channel '.ome.tif' segmentations (channel 0 = nucleus, channel 1 = cell body)
# - Resizes each channel to TARGET_SHAPE, binarizes it and saves the (2, Z, Y, X) mask as <CellId>.npy
# - Writes one JSON record per processed image (id, label, paths, original per-channel sizes)


TARGET_SHAPE = (64, 64, 64)

base_folder = 'source_data/crop_seg'
metadata_path = 'source_data/metadata.csv'
mask_save_folder = 'result_data/masks_64'
metadata_save_path = 'result_data/cell_data_64.json'

os.makedirs(mask_save_folder, exist_ok=True)
df = pd.read_csv(metadata_path)


def _to_builtin(value):
    """Return the plain-Python equivalent of a numpy scalar; pass any other value through.

    Values read from a pandas row can be numpy scalars (e.g. numpy.int64), which
    json.dump cannot serialize.
    """
    return value.item() if isinstance(value, np.generic) else value


cell_data = []
i = 0  # number of images successfully processed

for root, dirs, files in os.walk(base_folder):
    for file in files:
        if not file.endswith('.ome.tif'):
            continue

        image_path = os.path.join(root, file)

        try:
            im = tifffile.imread(image_path)

            # Handle possible channel axis arrangement
            if im.shape[0] > 2:
                im = im.swapaxes(1, 0)

            if im.shape[0] != 2:
                print(f"Skipping {file}: wrong number of channels ({im.shape[0]})")
                continue

            # Resolve the metadata row first so files without one are skipped
            # before paying for the resize.
            relative_path = os.path.relpath(image_path, base_folder)
            seg_file_name = 'crop_seg/' + relative_path

            matches = df[df['crop_seg'] == seg_file_name]
            if matches.empty:
                print(f"Warning: No metadata entry found for {seg_file_name}")
                continue

            row = matches.iloc[0]
            cell_id = _to_builtin(row['CellId'])
            label = _to_builtin(row['label'])

            nucleus_shape = im[0].shape  # (z, y, x)
            cell_shape = im[1].shape     # (z, y, x)

            # Linear interpolation to the target grid, one channel at a time
            resized = np.zeros((2, *TARGET_SHAPE), dtype=np.float32)
            for ch in range(2):
                factors = tuple(TARGET_SHAPE[axis] / im[ch].shape[axis] for axis in range(3))
                resized[ch] = zoom(im[ch], factors, order=1)

            binary_mask = (resized > 0).astype(np.float32)

            # Existing masks are kept as-is, so a re-run does not overwrite them
            mask_path = os.path.join(mask_save_folder, f'{cell_id}.npy')
            if not os.path.exists(mask_path):
                np.save(mask_path, binary_mask)

            cell_entry = {
                'id': cell_id,
                'label': label,
                'seg_file': seg_file_name,
                'mask_path': mask_path,
                'size_nucleus_1': int(nucleus_shape[0]),
                'size_nucleus_2': int(nucleus_shape[1]),
                'size_nucleus_3': int(nucleus_shape[2]),
                'size_cell_1': int(cell_shape[0]),
                'size_cell_2': int(cell_shape[1]),
                'size_cell_3': int(cell_shape[2]),
            }

            cell_data.append(cell_entry)
            i += 1
            if i % 100 == 0:
                print(f"Processed {i} images")

        except Exception as e:
            # Best effort: report the failing file and carry on with the rest
            print(f"Error processing {file}: {e}")

with open(metadata_save_path, 'w') as f:
    json.dump(cell_data, f, indent=2)

print(f"Finished processing {i} images. Metadata saved to: {metadata_save_path}")

In [None]:
# ICP-based rigid alignment of 3D binary masks (align on cell body, apply same transform to nucleus)
# - Uses scipy.spatial.cKDTree and scipy.ndimage.affine_transform
# - Input npy files should be (2, Z, Y, X) with binary-like masks
# - Choose reference_file or it will use the first file in the folder as template


# Folder the (2, Z, Y, X) masks are read from, and folder the aligned copies are written to
input_dir = Path("result_data/masks_64")
output_dir = Path("result_data/masks_64_rotated")
output_dir.mkdir(parents=True, exist_ok=True)

reference_file = None   # eg "example.npy"; None means the first file (sorted order) is the template
max_icp_points = 5000  # upper bound on the number of points sampled from the reference and from each source
icp_max_iter = 50      # passed to icp_rigid as max_iterations
icp_tol = 1e-4         # passed to icp_rigid as tol (stop when the mean error changes by less than this)
binarize_output = True  # if True, threshold at threshold_val before saving to keep binary masks
threshold_val = 0.5


def kabsch(src: np.ndarray, tgt: np.ndarray):
    """Least-squares rigid fit between two paired point sets (Kabsch algorithm).

    Parameters
    ----------
    src, tgt : (N, 3) arrays; row k of `src` corresponds to row k of `tgt`.

    Returns
    -------
    (R, t) : a proper 3x3 rotation (determinant +1) and a length-3 translation
        such that ``R @ src[k] + t`` approximates ``tgt[k]``.

    Raises
    ------
    AssertionError if the shapes differ or the points are not 3-dimensional.
    """
    assert src.shape == tgt.shape and src.shape[1] == 3

    mu_src = src.mean(axis=0)
    mu_tgt = tgt.mean(axis=0)

    # Cross-covariance of the centred clouds
    cross_cov = (src - mu_src).T @ (tgt - mu_tgt)
    left, _, right_t = np.linalg.svd(cross_cov)

    rotation = right_t.T @ left.T
    if np.linalg.det(rotation) < 0:
        # A negative determinant is a reflection; flip the last singular direction
        right_t[-1, :] *= -1
        rotation = right_t.T @ left.T

    translation = mu_tgt - rotation @ mu_src
    return rotation, translation

    
def icp_rigid(src_pts, tgt_pts, max_iterations=50, tol=1e-5, max_points=None):
    """Simple point-to-point ICP: returns R, t such that approximately tgt ≈ R @ src + t.

    Parameters
    ----------
    src_pts, tgt_pts : (N, 3) and (M, 3) coordinate arrays.
    max_iterations : maximum number of match/fit rounds.
    tol : stop once the mean nearest-neighbour distance changes by less than
        this between two consecutive rounds.
    max_points : if given and `src_pts` has more rows, a random subset of this
        size (drawn with numpy's global RNG, without replacement) is used.

    Returns
    -------
    (R, t, error) : accumulated 3x3 rotation, length-3 translation, and the mean
        nearest-neighbour distance of the (sub-sampled) source points after the
        returned transform has been applied.

    Raises
    ------
    ValueError if either point set has fewer than 3 points.
    """
    if src_pts.shape[0] < 3 or tgt_pts.shape[0] < 3:
        raise ValueError("Not enough points for ICP.")

    if max_points is not None and src_pts.shape[0] > max_points:
        idx = np.random.choice(src_pts.shape[0], size=max_points, replace=False)
        src_sub = src_pts[idx].astype(np.float64)
    else:
        src_sub = src_pts.astype(np.float64)

    tgt = tgt_pts.astype(np.float64)
    tgt_kdtree = cKDTree(tgt)

    R_total = np.eye(3, dtype=np.float64)
    t_total = np.zeros(3, dtype=np.float64)

    src_trans = src_sub.copy()
    prev_error = np.inf

    for _ in range(max_iterations):
        # Match every current source point to its nearest target point
        dists, idxs = tgt_kdtree.query(src_trans, k=1)
        matched_tgt = tgt[idxs]

        # Fit the incremental rigid motion and fold it into the running total
        R_i, t_i = kabsch(src_trans, matched_tgt)
        R_total = R_i @ R_total
        t_total = (R_i @ t_total) + t_i
        src_trans = (R_total @ src_sub.T).T + t_total

        mean_error = np.mean(dists)
        if np.abs(prev_error - mean_error) < tol:
            break
        prev_error = mean_error

    # Report the residual of the transform that is actually returned. The
    # per-iteration error above is measured before that iteration's update, and
    # was not refreshed on the convergence break, so it lags behind.
    final_dists, _ = tgt_kdtree.query(src_trans, k=1)
    final_error = float(np.mean(final_dists))

    return R_total, t_total, final_error

    
def mask_to_points(mask, min_pts=10):
    """Return the index coordinates of all voxels with value above 0.5.

    The result is an (N, 3) float64 array for a 3D mask, or None when fewer
    than `min_pts` voxels are set.
    """
    voxel_idx = np.argwhere(mask > 0.5)
    return voxel_idx.astype(np.float64) if voxel_idx.shape[0] >= min_pts else None


def apply_rigid_to_volume(vol, R, t, order=1, cval=0.0):
    """Resample `vol` under the forward rigid motion x -> R @ x + t (index coordinates).

    scipy's affine_transform pulls each output voxel from the input position
    ``matrix @ out + offset``, so the inverse motion is passed: matrix = R^-1
    and offset = -R^-1 @ t. Voxels that map outside the input are filled with
    `cval`; `order` is the spline interpolation order.
    """
    R_inv = np.linalg.inv(R)
    pull_offset = -R_inv @ t
    return affine_transform(
        vol,
        R_inv,
        offset=pull_offset,
        order=order,
        mode='constant',
        cval=cval,
    )


# Choose reference template (target)
files = sorted([p for p in input_dir.glob("*.npy")])
if len(files) == 0:
    raise RuntimeError(f"No .npy files found in {input_dir}")

if reference_file is None:
    ref_path = files[0]
else:
    ref_path = input_dir / reference_file
    if not ref_path.exists():
        raise FileNotFoundError(f"Reference file {ref_path} not found in {input_dir}")

print(f"Reference (target) file: {ref_path}")

ref_arr = np.load(ref_path)
if ref_arr.ndim != 4 or ref_arr.shape[0] < 2:
    raise ValueError(f"Reference array has unexpected shape: {ref_arr.shape}")

ref_cell = ref_arr[1].astype(np.float32)  # Use cell body (channel 1) of reference
ref_pts = mask_to_points(ref_cell, min_pts=10)
if ref_pts is None:
    raise ValueError("Reference cell body has too few points for ICP.")


# Cap the size of the target cloud that every file is matched against
if ref_pts.shape[0] > max_icp_points:
    sel = np.random.choice(ref_pts.shape[0], max_icp_points, replace=False)
    ref_pts_sub = ref_pts[sel]
else:
    ref_pts_sub = ref_pts


# Align every file (including the reference itself) to the reference cell body.
# A failure on one file is recorded and does not stop the run.
errors = []
for file_path in tqdm(files, desc="Aligning (ICP)"):
    try:
        arr = np.load(file_path)
        if arr.ndim != 4 or arr.shape[0] < 2:
            raise ValueError(f"Unexpected array shape {arr.shape} in {file_path.name}")

        nucleus = arr[0].astype(np.float32)
        cell_body = arr[1].astype(np.float32)

        src_pts = mask_to_points(cell_body, min_pts=10)
        if src_pts is None:
            raise ValueError("Too few cell-body voxels for ICP; skipping.")

        # icp_rigid sub-samples the source itself through max_points, so no
        # separate sub-sampling is done here.
        R, t, final_err = icp_rigid(src_pts, ref_pts_sub,
                                    max_iterations=icp_max_iter,
                                    tol=icp_tol,
                                    max_points=max_icp_points)

        # The transform is estimated on the cell body and applied to both channels
        transformed_cell = apply_rigid_to_volume(cell_body, R, t, order=1, cval=0.0)
        transformed_nucleus = apply_rigid_to_volume(nucleus, R, t, order=1, cval=0.0)

        if binarize_output:
            transformed_cell = (transformed_cell >= threshold_val).astype(np.uint8)
            transformed_nucleus = (transformed_nucleus >= threshold_val).astype(np.uint8)

        # Keep the original channel order: nucleus first, cell body second
        aligned = np.stack([transformed_nucleus, transformed_cell], axis=0)

        out_path = output_dir / file_path.name
        np.save(out_path, aligned)

    except Exception as e:
        errors.append((file_path.name, str(e)))
        print(f"[WARN] {file_path.name}: {e}")


print(f"Done. Saved aligned files to: {output_dir}")
if errors:
    print(f"Completed with {len(errors)} warnings/errors. Example: {errors[:5]}")
else:
    print("No errors.")

In [None]:
# Force every saved mask to strict 0/1 (uint8) values by thresholding the files in place
# NOTE(review): the original header mentions rotation, but the active `dirr` is the
# un-rotated folder (the rotated one is commented out) - confirm which is intended


# dirr = "result_data/masks_64_rotated"
dirr = "result_data/masks_64"

threshold = 0.2

npy_files = list(filter(lambda entry: entry.endswith(".npy"), os.listdir(dirr)))

print(f"Thresholding {len(npy_files)} files...")

for filename in npy_files:
    file_path = os.path.join(dirr, filename)

    # A file that cannot be read is reported and skipped
    try:
        arr = np.load(file_path)
    except Exception as e:
        print(f"Error loading {filename}: {e}")
        continue

    thresholded_arr = (arr >= threshold).astype(np.uint8)

    # The thresholded array overwrites the file it was loaded from
    try:
        np.save(file_path, thresholded_arr)
    except Exception as e:
        print(f"Error saving {filename}: {e}")
    else:
        print(f"Thresholded and saved: {filename}")

print("Done.")

In [None]:
# Display objects before and after rotation for sanity check


# Folder holding the un-rotated masks; candidate file names are listed from here
original_dir = "result_data/masks_64"
# Folder holding the aligned masks; files are looked up here under the same name
rotated_dir = "result_data/masks_64_rotated"

# All .npy file names found in the original folder
all_files = [f for f in os.listdir(original_dir) if f.endswith('.npy')]
# Pick one file at random (no seed is set in this cell, so the choice can differ between runs)
sample_files = random.sample(all_files, 1)


def plot_3d_volume(vol, title):
    """Show a 3D array as a semi-transparent Plotly volume rendering.

    `vol` is unpacked as (z, y, x); it is transposed to (x, y, z) so that its
    flattened values line up with the flattened coordinate grids. Iso-surfaces
    are drawn for values between 0.5 and 1.0. `title` is the figure title.
    """
    depth, height, width = vol.shape
    grid_x, grid_y, grid_z = np.mgrid[0:width, 0:height, 0:depth]

    values_xyz = np.transpose(vol, (2, 1, 0))

    volume_trace = go.Volume(
        x=grid_x.flatten(),
        y=grid_y.flatten(),
        z=grid_z.flatten(),
        value=values_xyz.flatten(),
        isomin=0.5,
        isomax=1.0,
        opacity=0.1,
        surface_count=15,
        colorscale='Viridis'
    )
    fig = go.Figure(data=volume_trace)

    # Axis titles are kept, tick labels are hidden on all three axes
    scene_layout = dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z')
    for axis_key in ('xaxis', 'yaxis', 'zaxis'):
        scene_layout[axis_key] = dict(showticklabels=False)

    fig.update_layout(scene=scene_layout, title=title, width=500, height=500)
    fig.show()


# For each sampled file, load the un-rotated and rotated masks and plot the same channel of both
for file in sample_files:
    print(f"Processing file: {file}")

    original, rotated = (
        np.load(os.path.join(folder, file)) for folder in (original_dir, rotated_dir)
    )

    print(f"Original shape: {original.shape}")
    print(f"Rotated shape: {rotated.shape}")

    channel = 0  # 0 = nucleus or 1 = cell body
    for tag, stack in (("Original", original), ("Rotated", rotated)):
        plot_3d_volume(stack[channel], f"{tag} - {file}")

<p style="height: 200px;"></p>

## Quantile-Embedding Creation

In [None]:
# Returns coordinates of surface voxels


def extract_surface_points(volume):
    """Return the index coordinates of the boundary voxels of a binary volume.

    A voxel is on the boundary when it is set in `volume` but removed by one
    pass of binary erosion (scipy's default structuring element). `volume` must
    support `&` with a boolean array, e.g. a boolean mask. The result has one
    row per boundary voxel and one column per axis.
    """
    interior = binary_erosion(volume)
    shell = volume & np.invert(interior)  # set voxels that erosion strips away
    return np.argwhere(shell)

In [None]:
# Populates data (nucleus) and data_cell (cell body) with extracted surface points
# - one entry per 2-channel '.ome.tif' under `directory` that has a metadata row
# - data_full, labels and file_names are filled in the same order
# - `counter` ends up as the number of shapes stored


df = pd.read_csv('source_data/metadata.csv')

np.random.seed(42)
cell_types = ['M0',
 'M1M2',
 'M3',
 'M4M5',
 'M6M7_early',
 'M6M7_half',
 'blob',
 'dead',
 'wrong']

import open3d as o3d  # not used in this cell; left in place in case later cells rely on `o3d`

directory = 'source_data/crop_seg'

# Legacy constant: the lists used to be pre-allocated to this length, which
# overflowed on larger datasets. They now grow as needed; N is only kept in
# case other cells reference it.
N = 5767

# crop_seg -> label lookup; the first metadata row wins for duplicated crop_seg values
label_by_seg = {}
for seg_name, seg_label in zip(df['crop_seg'], df['label']):
    label_by_seg.setdefault(seg_name, seg_label)

data = []        # nucleus surface points, one array per shape
data_full = []   # sub-sampled 2-channel volumes
data_cell = []   # cell-body surface points, one array per shape
labels = []
file_names = []
counter = 0


for root, dirs, files in os.walk(directory):

    print(root)

    for file in files:
        if not file.endswith('.ome.tif'):
            continue

        image_path = os.path.join(root, file)

        try:
            im = tifffile.imread(image_path)
        except Exception as e:
            # Same best-effort policy as the preprocessing cell
            print(f"Error processing {file}: {e}")
            continue

        if im.shape[0] > 2:
            im = im.swapaxes(1, 0)

        # Check the channel count BEFORE sub-sampling. Selecting channels [0, 1]
        # first would make any later "== 2" test trivially true and would keep
        # files that the preprocessing cell skips.
        if im.ndim != 4 or im.shape[0] != 2:
            print(f"Skipping {file}: wrong number of channels ({im.shape[0]})")
            continue

        name = 'crop_seg/' + os.path.relpath(image_path, directory)
        if name not in label_by_seg:
            print(f"Warning: No metadata entry found for {name}")
            continue

        # Cell body: downsizes by (8, 20, 20), change to speed up Wasserstein distance computation
        cell_mask = im[1, ::8, ::20, ::20] > 0

        # Both channels downsized by (4, 10, 10). Copy so the stored array does
        # not keep the full-resolution image alive.
        im_small = im[:, ::4, ::10, ::10].copy()
        nucleus_mask = im_small[0] > 0

        data_full.append(im_small)
        data.append(extract_surface_points(nucleus_mask))
        data_cell.append(extract_surface_points(cell_mask))
        labels.append(label_by_seg[name])
        file_names.append(image_path)

        counter += 1
        if counter % 100 == 0:
            print(f"Processed {counter} images")

print(f"Stored {counter} shapes")

In [None]:
# Sanity check


# Convert every stored point set to a float array and print its size
N_shapes = counter
for i, (cell_surface, nucleus_surface) in enumerate(zip(data_cell[:N_shapes], data[:N_shapes])):
    data_cell[i] = np.array(cell_surface, dtype=float)
    print(data_cell[i].shape[0])
    data[i] = np.array(nucleus_surface, dtype=float)
    print(data[i].shape)

In [None]:
# Computes quantiles of Wasserstein distances


def compute_WassKernel(data,n_pt=10000, metric='Euclidean',normalize=False, return_distance=False):
    """Quantile embedding of each point set and the pairwise distances between embeddings.

    For every point set, the full matrix of pairwise distances (cdist with its
    default metric) is reduced to `n_pt` evenly spaced quantiles. An empty point
    set gets an all-infinite quantile vector.

    Parameters
    ----------
    data : sequence of (N_i, d) arrays.
    n_pt : number of quantile levels between 0 and 1 (inclusive); also printed.
    metric : accepted for API compatibility; not used by the computation.
    normalize : if True, each distance matrix is divided by its median first.
    return_distance : if it compares equal to False only the kernel is
        returned, otherwise the quantile vectors are returned as well.

    Returns
    -------
    kernel, or (kernel, quantiles) : `kernel` is the square matrix of Euclidean
        distances between quantile vectors divided by `n_pt`; `quantiles` is the
        list of per-shape quantile vectors.
    """
    print(n_pt)
    levels = np.linspace(0,1,n_pt,endpoint=True)

    def _quantile_profile(points):
        # Empty shapes are marked with an all-infinite profile
        if points.shape[0] == 0:
            return np.inf + np.zeros( (n_pt))
        pairwise = sp.spatial.distance.cdist(points, points)
        if normalize:
            pairwise = pairwise/np.median(pairwise)
        return np.quantile(pairwise.ravel(), levels)

    quantiles_all = [_quantile_profile(points) for points in data]
    kernel = sp.spatial.distance.squareform(sp.spatial.distance.pdist(quantiles_all))/n_pt

    if return_distance == False:
        return kernel
    return kernel, quantiles_all

In [None]:
# Compute the quantile embedding on full entire dataset


# Same settings for both runs: 100 quantile levels, no normalization, quantiles returned too
wass_kwargs = dict(n_pt=100, metric='Euclidean', normalize=False, return_distance=True)

start = time.time()

# Cell-body shapes first, then nuclei; both elapsed times are measured from the same start
K_W_cell, embed_W_cell = compute_WassKernel(data_cell[0:N_shapes], **wass_kwargs)
print(time.time()-start)

K_W, embed_W = compute_WassKernel(data[0:N_shapes], **wass_kwargs)
print(time.time()-start)

# np.save('result_data/All_3DShape_W.npy',K_W_cell)
# np.save('result_data/All_3DShape_W_embed.npy',embed_W_cell)

# np.save('result_data/All_3DShape_Nucleus_W.npy',K_W)
# np.save('result_data/All_3DShape_Nucleus_W_embed.npy',embed_W)

In [None]:
# Visualize UMAP of embeddings across labels


out_path = "result_data/figures"
os.makedirs(out_path, exist_ok=True)

class_names = cell_types
labels = labels[0:N_shapes]

plt.rcParams["axes.prop_cycle"] = rcParamsDefault["axes.prop_cycle"]

# The last three classes are drawn as outliers with their own markers and colours;
# the first six take evenly spaced viridis colours.
outlier_markers = ['x', '^', 's']
viridis = plt.get_cmap('viridis', 6)
continuous_colors = [viridis(i) for i in range(6)]

outlier_colors = ['#e41a1c', '#000000', '#f0027f']  # red, black, magenta
color_list = continuous_colors + outlier_colors


def _scatter_by_class(ax, coords, title):
    """Scatter the 2D `coords` on `ax`, one series per class, and label the axes."""
    n_regular = len(class_names) - 3
    for k, name in enumerate(class_names):
        members = [j for j, lab in enumerate(labels) if lab == name]
        style = dict(s=30, label=name, color=color_list[k])
        if k >= n_regular:
            style["marker"] = outlier_markers[k - n_regular]
        ax.scatter(coords[members, 0], coords[members, 1], **style)

    ax.set_title(title)
    ax.set_xlabel("UMAP 1")
    ax.set_ylabel("UMAP 2")


fig = plt.figure(figsize=(18, 5))

# Combined chart: nucleus and cell embeddings side by side
fit = umap.UMAP(n_components=2, random_state=42)
mapper_w = fit.fit_transform(np.hstack([embed_W, embed_W_cell]))
ax0 = fig.add_subplot(131)
_scatter_by_class(ax0, mapper_w, 'Combined Wasserstein')

# Nucleus Only Chart
fit = umap.UMAP(n_components=2, random_state=52)
mapper_ws = fit.fit_transform(embed_W)
ax1 = fig.add_subplot(132)
_scatter_by_class(ax1, mapper_ws, 'Nucleus only')

# Cell Only Chart (re-fits the same UMAP object used for the nucleus chart)
mapper_cell = fit.fit_transform(embed_W_cell)
ax2 = fig.add_subplot(133)
_scatter_by_class(ax2, mapper_cell, 'Cell only')

# One shared legend: keep the first handle seen for each label across the three panels
unique = dict()
for ax in [ax0, ax1, ax2]:
    panel_handles, panel_labels = ax.get_legend_handles_labels()
    for h, l in zip(panel_handles, panel_labels):
        unique.setdefault(l, h)

fig.legend(unique.values(), unique.keys(), loc='center left', bbox_to_anchor=(1.01, 0.5), borderaxespad=0.)
plt.subplots_adjust(right=0.97)
plt.savefig("result_data/figures/WASS_UMAP.png", format="png", bbox_inches="tight")
plt.show()

<p style="height: 200px;"></p>

## Quantile-Embeddings for Classification

In [None]:
# Predict label probabilities based on embedding using basic ML classifier
# - predict every cell label using cross-training, label as "bad_prediction" if confidently predicts incorrectly
# - can be used to filter out potential mislabels, or to assist expansion of labels to new unlabelled data


# Features: nucleus quantiles, cell quantiles and their absolute difference, standardized
X_diff = np.abs(np.array(embed_W) - np.array(embed_W_cell))
X = np.hstack([embed_W, embed_W_cell, X_diff])
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = np.array(labels)

le = LabelEncoder()
y_int = le.fit_transform(y)
class_names = le.classes_.tolist()

json_path = "result_data/cell_data_64.json"
with open(json_path, "r") as f:
    data = json.load(f)

id_to_entry = {entry["id"]: entry for entry in data}

# Resolve which JSON entry belongs to each feature row. Rows and JSON entries
# come from two separate directory walks with different skip rules, so they are
# matched through the segmentation file name when the file list is available.
# The positional match is only the fallback.
if 'file_names' in globals() and 'directory' in globals():
    seg_to_entry = {entry.get("seg_file"): entry for entry in data}
    row_entries = [
        seg_to_entry.get('crop_seg/' + os.path.relpath(file_names[row], directory))
        for row in range(len(X))
    ]
else:
    if len(data) < len(X):
        raise RuntimeError(f"JSON contains {len(data)} entries but {len(X)} feature rows exist.")
    row_entries = [data[row] for row in range(len(X))]

# A wrong prediction is "bad" when the top class is at least this many times
# more probable than the true class
BAD_PREDICTION_RATIO = 5


# Define classifier
def build_classifier(seed):
    """Stacked ensemble (RF, kNN, SVM, XGBoost; random-forest meta learner), seeded with `seed`."""
    base_learners = [
        ('rf', RandomForestClassifier(n_estimators=50, random_state=seed)),
        ('knn', KNeighborsClassifier(n_neighbors=15, weights='distance')),
        # random_state added so the probability estimates are reproducible
        ('svm', SVC(kernel='rbf', gamma='scale', C=50, probability=True, random_state=seed)),
        ('xgb', XGBClassifier(random_state=seed, eval_metric='mlogloss', n_estimators=50, subsample=0.8))
    ]
    meta_learner = RandomForestClassifier(n_estimators=50, random_state=seed)
    return StackingClassifier(estimators=base_learners, final_estimator=meta_learner, cv=3)


# 50/50 split (train/test); each half is predicted by a model trained on the other half
np.random.seed(42)
indices = np.arange(len(X))
np.random.shuffle(indices)
mid = len(indices) // 2
split1, split2 = indices[:mid], indices[mid:]

predictions = {}

for train_idx, test_idx in [(split1, split2), (split2, split1)]:
    clf = build_classifier(seed=0)
    clf.fit(X[train_idx], y_int[train_idx])
    probs = clf.predict_proba(X[test_idx])

    # Probability columns follow clf.classes_, which can be a subset of all
    # classes when a class is missing from the training half
    column_of_class = {int(cls): col for col, cls in enumerate(clf.classes_)}

    for i, idx in enumerate(test_idx):
        prob_vector = probs[i]
        pred_class_idx = int(clf.classes_[np.argmax(prob_vector)])
        true_class_idx = int(y_int[idx])

        highest_prob = float(np.max(prob_vector))
        highest_label = class_names[pred_class_idx]

        true_column = column_of_class.get(true_class_idx)
        true_prob = float(prob_vector[true_column]) if true_column is not None else 0.0
        true_label = class_names[true_class_idx]

        bad = bool((pred_class_idx != true_class_idx)
                   and (highest_prob >= BAD_PREDICTION_RATIO * true_prob))

        predictions[idx] = {
            "highest_prediction": (round(highest_prob, 4), highest_label),
            "true_prediction": (round(true_prob, 4), true_label),
            "bad_prediction": bad
        }


# Attach each prediction to its JSON entry
unmatched_rows = 0
label_mismatches = 0
for idx, pred in predictions.items():
    entry = row_entries[idx]
    if entry is None:
        unmatched_rows += 1
        continue
    if entry.get("label") != y[idx]:
        label_mismatches += 1
    id_to_entry[entry["id"]].update(pred)

if unmatched_rows:
    print(f"Warning: {unmatched_rows} feature rows have no JSON entry; their predictions were not saved.")
if label_mismatches:
    print(f"Warning: {label_mismatches} rows carry a label that differs from their JSON entry.")

with open(json_path, "w") as f:
    json.dump(list(id_to_entry.values()), f, indent=2)

bad_counts = Counter()
for idx, pred in predictions.items():
    if pred["bad_prediction"]:
        true_label = pred["true_prediction"][1]
        bad_counts[true_label] += 1

print("Bad prediction counts by true class:")
total = 0
for cls in class_names:
    print(f"{cls}: {bad_counts[cls]}")
    total += bad_counts[cls]

print(f"Total: {total} / {len(predictions)}")

In [None]:
# A further classifier to expand set of labelled data, trained only on data with previous "bad_prediction" = false
# - classifier accuracy is gauged on test data after training on data with "bad_prediction" = false
# - note that the classifier includes dead, blob, and wrong labels since expanded data may also contain these, however this decreases 
#   performance quality of the classifier due to the inclusion of these additional labels with minimal training instances.
# The nature of this discrete classification along a continuous process is also inherently difficult at borders


json_path = "result_data/cell_data_64.json"
n_repeats = 5
random_seed_base = 42

with open(json_path, "r") as f:
    data = json.load(f)

if 'N_shapes' in globals():
    n_samples = int(globals()['N_shapes'])
else:
    n_samples = min(len(X), len(y), len(data))

# Find the JSON entry of every feature row. Matching through the segmentation
# file name does not depend on the JSON and the feature rows being in the same
# order; the positional match is only used when the file list is unavailable.
if 'file_names' in globals() and 'directory' in globals():
    entry_by_seg = {entry.get("seg_file"): entry for entry in data}
    row_entries = [
        entry_by_seg.get('crop_seg/' + os.path.relpath(file_names[row], directory))
        for row in range(n_samples)
    ]
else:
    if len(data) < n_samples:
        raise RuntimeError(f"JSON contains {len(data)} entries but expected at least {n_samples}.")
    row_entries = [data[row] for row in range(n_samples)]

print(f"Using first n_samples = {n_samples} entries (must correspond to rows of X and labels y).")

X = np.asarray(X)[:n_samples]
y = np.asarray(y)[:n_samples]

# Keep a row only when its entry exists and is explicitly flagged bad_prediction == False
keep_mask = np.array(
    [entry is not None and not entry.get("bad_prediction", True) for entry in row_entries],
    dtype=bool,
)

num_total = len(keep_mask)
num_kept = keep_mask.sum()
print(f"Total considered samples: {num_total}; samples with bad_prediction==False (kept): {num_kept}")

if num_kept < 10:
    raise RuntimeError("Too few samples passed the bad_prediction==False filter. Aborting.")

X_filtered = X[keep_mask]
y_filtered = y[keep_mask]

class_counts = Counter(y_filtered)
print("Class distribution (kept):")
for cls, cnt in class_counts.items():
    print(f"  {cls}: {cnt}")

# Re-standardize on the kept rows only
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_filtered)


def accuracy_stack_and_plot2(X_in, y_in, seed, do_plot):
    """Train the stacking classifier on a stratified 80/20 split and return the test accuracy.

    The confusion matrix and accuracy are always printed. When `do_plot` is
    true, the raw confusion matrix and a normalized version (each count divided
    by sqrt(row_total_i * row_total_j)) are also shown as heatmaps. `seed`
    drives the split and every seeded estimator.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X_in, y_in, test_size=0.2, random_state=seed, stratify=y_in)

    stack = StackingClassifier(
        estimators=[
            ('rf', RandomForestClassifier(n_estimators=50, random_state=seed)),
            ('knn', KNeighborsClassifier(n_neighbors=15, weights='distance')),
            ('svm', SVC(kernel='rbf', gamma='scale', C=50, probability=True, random_state=seed)),
            ('xgb', XGBClassifier(random_state=seed, eval_metric='mlogloss', n_estimators=50, subsample=0.8, use_label_encoder=False)),
        ],
        final_estimator=RandomForestClassifier(n_estimators=50, random_state=seed),
        cv=3,
    )
    stack.fit(X_train, y_train)

    y_pred = stack.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)

    # Every class present in either split, in sorted order
    class_names_local = sorted(np.unique(np.concatenate([y_train, y_test])))

    counts = confusion_matrix(y_test, y_pred, labels=class_names_local)
    result_matrix = pd.DataFrame(counts, index=class_names_local, columns=class_names_local)

    print("")
    print("Confusion Matrix:")
    print(result_matrix)
    print(f"\nAccuracy: {accuracy:.4f}\n")

    if not do_plot:
        return accuracy

    plt.figure(figsize=(10, 8))
    sns.heatmap(result_matrix, annot=True, fmt='d', cmap="Blues", cbar=True)
    plt.title("Confusion Matrix (Stacking Classifier)")
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.tight_layout()
    plt.show()

    # Normalized confusion: divide by the geometric mean of the two row totals,
    # leaving 0 wherever that denominator is 0
    row_totals = counts.sum(axis=1)
    denominators = np.sqrt(np.outer(row_totals, row_totals))
    normalized_counts = np.zeros_like(counts, dtype=float)
    np.divide(counts, denominators, out=normalized_counts, where=denominators > 0)

    normalized_df = pd.DataFrame(normalized_counts,
                                 index=[f"True_{cls}" for cls in class_names_local],
                                 columns=[f"Pred_{cls}" for cls in class_names_local])

    plt.figure(figsize=(10, 8))
    sns.heatmap(normalized_df, annot=True, fmt=".5f", cmap="viridis", cbar=True)
    plt.title("Normalized Confusion Matrix (Row Frequency Normalization)")
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.tight_layout()
    plt.show()

    return accuracy


# Repeat the evaluation with consecutive seeds; only the last run draws the heatmaps
accuracy_stack_witness = []
for i, seed in enumerate(range(random_seed_base, random_seed_base + n_repeats)):
    do_plot = (i == n_repeats - 1)
    print(f"Run {i+1}/{n_repeats} (seed={seed})")
    accuracy_stack_witness.append(
        accuracy_stack_and_plot2(X_scaled, y_filtered.copy(), seed=seed, do_plot=do_plot)
    )
    print("Accuracies so far:", accuracy_stack_witness)

print("\nMean Stacking Accuracy:", np.mean(accuracy_stack_witness))

In [None]:
# Create a roughly balanced dataset by phase labels, filtering out mislabels/invalid stages.
# Copy both original and rotated masks into balanced folders and write balanced JSON containing both paths
# Note this rotation is unnecessary for our pre-aligned dataset, but is vital for other unaligned 3D datasets


random.seed(42)

original_metadata_path = 'result_data/cell_data_64.json'
balanced_mask_dir = 'result_data/masks_64_balanced'
balanced_metadata_path = 'result_data/cell_data_64_balanced.json'
rotated_balanced_dir = 'result_data/masks_64_rotated_balanced'

os.makedirs(balanced_mask_dir, exist_ok=True)
os.makedirs(rotated_balanced_dir, exist_ok=True)

excluded_labels = {"blob", "dead", "wrong"}

with open(original_metadata_path, 'r') as f:
    all_entries = json.load(f)

# Group the usable entries by label: not an excluded label, and explicitly
# flagged bad_prediction == False
valid_data_by_label = defaultdict(list)
for entry in all_entries:
    label = entry.get("label")
    if label in excluded_labels:
        continue
    if entry.get("bad_prediction") is False:
        valid_data_by_label[label].append(entry)

label_counts = {label: len(lst) for label, lst in valid_data_by_label.items()}
if not label_counts:
    raise RuntimeError("No entries found with bad_prediction == False after excluding labels. Aborting.")

print("Good-prediction counts per label (after excluding blob/dead/wrong):")
for lbl, cnt in label_counts.items():
    print(f"  {lbl}: {cnt}")

# Determine the minimum available per label
min_count = min(label_counts.values())
print(f"\nMinimum available per-label (good predictions): {min_count}")

# Every label contributes at most min_count + 100 entries. The smallest label
# has exactly min_count available, so it contributes all of them.
balanced_metadata = []
copy_warnings = []
for label, entries in valid_data_by_label.items():
    available = len(entries)
    sample_count = min(min_count + 100, available)

    sampled = random.sample(entries, sample_count)
    print(f"Sampling {len(sampled)} / {available} for label {label}")

    for entry in sampled:
        src_mask = entry.get("mask_path")
        if not src_mask:
            copy_warnings.append(f"Missing mask_path for id {entry.get('id')}")
            continue

        dst_mask = os.path.join(balanced_mask_dir, os.path.basename(src_mask))

        # The primary mask is required: an entry whose mask could not be copied
        # is left out, so the balanced JSON never points at a missing file.
        if not os.path.exists(src_mask):
            copy_warnings.append(f"Mask source not found: {src_mask} (id {entry.get('id')})")
            continue
        try:
            shutil.copy(src_mask, dst_mask)
        except Exception as e:
            copy_warnings.append(f"Error copying mask {src_mask} -> {dst_mask}: {e}")
            continue

        # The rotated mask stays best effort: a missing file is only reported
        # and the destination path is still recorded.
        rotated_src = entry.get("rotated_mask_path")
        if not rotated_src:
            rotated_src = src_mask.replace('masks_64', 'masks_64_rotated')

        dst_rotated = os.path.join(rotated_balanced_dir, os.path.basename(rotated_src))

        try:
            if os.path.exists(rotated_src):
                shutil.copy(rotated_src, dst_rotated)
            else:
                copy_warnings.append(f"Rotated mask source not found: {rotated_src} (id {entry.get('id')})")
        except Exception as e:
            copy_warnings.append(f"Error copying rotated mask {rotated_src} -> {dst_rotated}: {e}")

        new_entry = entry.copy()
        new_entry["mask_path"] = dst_mask
        new_entry["rotated_mask_path"] = dst_rotated

        balanced_metadata.append(new_entry)

with open(balanced_metadata_path, 'w') as f:
    json.dump(balanced_metadata, f, indent=2)


print(f"\nBalanced dataset created with {len(balanced_metadata)} total images.")
print(f"Balanced masks copied to: {balanced_mask_dir}")
print(f"Balanced rotated masks copied to: {rotated_balanced_dir}")
print(f"Balanced metadata saved to: {balanced_metadata_path}")

if copy_warnings:
    print("\n Warnings encountered during copying:")
    for w in copy_warnings:
        print("  -", w)

<p style="height: 200px;"></p>

## VAE Architecture

In [None]:
# Required for retrieval


class CellMaskDataset(Dataset):
    """Dataset over the ``.npy`` mask volumes found in a directory.

    Parameters
    ----------
    mask_dir : str
        Directory that is globbed for ``*.npy`` files (sorted by path).
    metadata_path : str, optional
        JSON file holding a list of entries. When given, only files whose path
        matches an entry's ``"mask_path"`` are kept.

    Each item is a ``(tensor, path)`` pair: the array loaded from disk as a
    ``float32`` tensor, and the path it was loaded from.
    """

    def __init__(self, mask_dir, metadata_path=None):
        self.paths = sorted(glob(os.path.join(mask_dir, '*.npy')))

        if metadata_path:
            with open(metadata_path, 'r') as f:
                metadata = json.load(f)
            # Compare normalized paths: a purely textual comparison silently
            # drops every sample when the metadata spells the same file
            # differently (e.g. "./dir/x.npy" vs "dir/x.npy").
            # To filter on rotated masks instead, read entry["rotated_mask_path"].
            valid_paths = {
                os.path.normpath(entry["mask_path"])
                for entry in metadata
                if entry.get("mask_path")
            }
            self.paths = [p for p in self.paths if os.path.normpath(p) in valid_paths]

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        arr = np.load(path)
        return torch.tensor(arr, dtype=torch.float32), path

In [None]:
# Functions specifying encoding and decoding details in VAE


class EncoderBranch(nn.Module):
    """Two-layer 3D convolutional encoder producing ``(mu, logvar)`` for one channel.

    The convolutions always use ``padding = stride - 1``; the ``padding``
    argument is accepted for signature compatibility but is not read.
    ``input_shape`` is only used at construction time, to size the linear heads
    by pushing a zero volume through the convolutions.
    """

    def __init__(self, input_shape=(1, 64, 64, 64), latent_dim=16, base_channels=8,
                 kernel_size=3, stride=2, padding=1):
        super().__init__()
        conv_padding = stride - 1
        self.conv1 = nn.Conv3d(1, base_channels, kernel_size,
                               stride=stride, padding=conv_padding)
        self.conv2 = nn.Conv3d(base_channels, 2 * base_channels, kernel_size,
                               stride=stride, padding=conv_padding)

        # Probe the conv stack once to learn how many features reach the linear heads.
        with torch.no_grad():
            probe = self._conv_features(torch.zeros(1, *input_shape))
        self.flattened_size = probe.numel()

        self.flatten = nn.Flatten()
        self.fc_mu = nn.Linear(self.flattened_size, latent_dim)
        self.fc_logvar = nn.Linear(self.flattened_size, latent_dim)

    def _conv_features(self, x):
        """Apply both ReLU-activated convolutions."""
        return F.relu(self.conv2(F.relu(self.conv1(x))))

    def forward(self, x):
        flat = self.flatten(self._conv_features(x))
        return self.fc_mu(flat), self.fc_logvar(flat)



class Decoder(nn.Module):
    """Maps a latent vector back to a multi-channel volume with values in (0, 1).

    The shape fed to the transposed convolutions is found by running a zero
    volume through two throw-away convolutions configured like the encoder's
    (``padding = stride - 1``). The ``padding`` argument is accepted but not read.
    """

    def __init__(self, output_shape=(2, 64, 64, 64), latent_dim=32, base_channels=8,
                 kernel_size=3, stride=2, padding=1):
        super().__init__()
        self.base_channels = base_channels
        self.output_shape = output_shape

        pad = stride - 1
        with torch.no_grad():
            # Temporary layers: only the shape of their output matters.
            probe_conv1 = nn.Conv3d(1, base_channels, kernel_size, stride=stride, padding=pad)
            probe_conv2 = nn.Conv3d(base_channels, base_channels * 2, kernel_size,
                                    stride=stride, padding=pad)
            one_channel = torch.zeros(1, *output_shape)[:, 0:1]
            encoded = F.relu(probe_conv2(F.relu(probe_conv1(one_channel))))
        self.unflatten_shape = encoded.shape[1:]
        self.linear_input_size = encoded.numel()

        self.fc = nn.Linear(latent_dim, self.linear_input_size)
        self.unflatten = nn.Unflatten(1, self.unflatten_shape)

        # output_padding = stride - 1 mirrors the encoder convolutions.
        self.deconv1 = nn.ConvTranspose3d(base_channels * 2, base_channels, kernel_size,
                                          stride=stride, padding=pad, output_padding=pad)
        self.deconv2 = nn.ConvTranspose3d(base_channels, output_shape[0], kernel_size,
                                          stride=stride, padding=pad, output_padding=pad)

    def forward(self, z):
        hidden = self.unflatten(F.relu(self.fc(z)))
        hidden = F.relu(self.deconv1(hidden))
        return torch.sigmoid(self.deconv2(hidden))

In [None]:
# Defines how to encode using dual-branch (easily extendible to more independent branches)


class DualBranchVAE(nn.Module):
    """VAE with one encoder branch per input channel and a single shared decoder.

    Channel 0 goes through ``branch_nucleus`` and channel 1 through
    ``branch_cell``. Each branch produces ``latent_dim // 2`` dimensions, and the
    two halves are concatenated (nucleus first) to form the full latent vector.
    The ``padding`` argument is accepted but not read: every sub-module is
    built with ``padding = stride - 1``.
    """

    def __init__(self, latent_dim=32, input_shape=(2, 64, 64, 64), base_channels=8,
                 kernel_size=3, stride=2, padding=1):
        super().__init__()
        self.latent_dim = latent_dim

        single_channel_shape = (1, *input_shape[1:])
        shared_kwargs = dict(base_channels=base_channels,
                             kernel_size=kernel_size,
                             stride=stride,
                             padding=stride - 1)

        self.branch_nucleus = EncoderBranch(input_shape=single_channel_shape,
                                            latent_dim=latent_dim // 2,
                                            **shared_kwargs)
        self.branch_cell = EncoderBranch(input_shape=single_channel_shape,
                                         latent_dim=latent_dim // 2,
                                         **shared_kwargs)
        self.decoder = Decoder(output_shape=input_shape,
                               latent_dim=latent_dim,
                               **shared_kwargs)

    def encode(self, x):
        """Return ``(mu, logvar)``, each the nucleus half followed by the cell half."""
        nucleus_stats = self.branch_nucleus(x[:, 0:1])  # keep the channel axis: (B, 1, D, H, W)
        cell_stats = self.branch_cell(x[:, 1:2])
        mu, logvar = (torch.cat(pair, dim=1) for pair in zip(nucleus_stats, cell_stats))
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from a standard normal."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        return self.decode(self.reparameterize(mu, logvar)), mu, logvar

In [None]:
def loss_function(recon_x, x, mu, logvar, beta=1.0):
    """
    Beta-VAE loss: summed binary cross-entropy plus a weighted KL divergence.

    - recon_x: Reconstruction produced by the decoder
    - x: Target volume, same shape as recon_x
    - mu, logvar: Mean and log-variance of latent space (for KL divergence)
    - beta: The weight for KL divergence term (controls regularization)

    Both recon_x and x are clamped to [1e-10, 1 - 1e-10] before the BCE is
    computed, so out-of-range values are silently pulled into range.

    Returns:
    - total_loss: Sum of reconstruction and KL losses.
    - bce_loss: Binary cross-entropy loss (reconstruction)
    - kl_loss: KL divergence loss (latent regularization)
    """
    eps = 1e-10
    recon_x = torch.clamp(recon_x, min=eps, max=1 - eps)
    x = torch.clamp(x, min=eps, max=1 - eps)

    # The previous "out of bounds" checks ran after the clamps above, so they
    # could never trigger; they only cost a host/device sync on every call.
    bce_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce_loss + beta * kl_loss, bce_loss, kl_loss

In [None]:
# Defines training using the set of best HPs from previous grid-search results. To recreate HP-search, enter full list
# of candidate values in sample_hyperparams() below


def run_training(config, dataset_dir, metadata_path, input_shape, latent_dim, device):
    """Train a DualBranchVAE on a random 80/20 split and save the best weights.

    Parameters
    ----------
    config : dict
        Hyperparameters; must provide 'lr', 'beta', 'stride' and 'base_channels'.
    dataset_dir, metadata_path : str
        Passed to CellMaskDataset to locate and filter the mask files.
    input_shape : tuple
        (channels, D, H, W) of one sample; forwarded to the model.
    latent_dim : int
        Total latent dimensionality of the model.
    device : torch.device or str
        Device on which the model and batches are placed.

    Training runs for at most 20 epochs. Validation uses the per-sample
    reconstruction (BCE) loss, which drives both the LR scheduler and early
    stopping (patience 2).

    Returns
    -------
    tuple
        (best_val_loss, best_epoch, config, model, best_mu, best_logvar,
        train_loader, val_loader, train_set, val_set). ``best_mu`` and
        ``best_logvar`` are numpy arrays taken from the last validation batch
        of the best epoch. ``model`` carries the best-epoch weights.

    Raises
    ------
    ValueError
        If the dataset contains no samples.
    RuntimeError
        If no epoch produced a validation loss below infinity (e.g. NaN loss).
    """
    dataset = CellMaskDataset(dataset_dir, metadata_path)
    if len(dataset) == 0:
        raise ValueError(f"No valid samples found in {dataset_dir}")

    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_set, val_set = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_set, batch_size=2, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=2, shuffle=True)

    model = DualBranchVAE(
        latent_dim=latent_dim,
        input_shape=input_shape,
        base_channels=config['base_channels'],
        stride=config['stride']
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)

    best_val_loss = float('inf')
    best_model_state = None
    best_mu = None
    best_logvar = None
    best_epoch = 0
    epochs_no_improve = 0
    early_stop_patience = 2

    for epoch in range(20):

        print(f"epoch number: {epoch}")

        model.train()
        for batch, _ in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            recon, mu, logvar = model(batch)
            loss, _, _ = loss_function(recon, batch, mu, logvar, beta=config['beta'])
            loss.backward()
            optimizer.step()

        print("Epoch trained")

        model.eval()
        total_val_recon_loss = 0
        with torch.no_grad():
            for batch, _ in val_loader:
                batch = batch.to(device)
                recon, mu, logvar = model(batch)
                _, recon_loss, _ = loss_function(recon, batch, mu, logvar, beta=config['beta'])
                total_val_recon_loss += recon_loss.item()

        avg_val_loss = total_val_recon_loss / len(val_loader.dataset)
        scheduler.step(avg_val_loss)

        print(f"Validation reconstruction loss: {avg_val_loss}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # state_dict() hands back references to the live parameter tensors,
            # which later optimizer steps overwrite in place. Clone them so this
            # really is a snapshot of the best epoch.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            best_mu = mu.cpu().numpy()
            best_logvar = logvar.cpu().numpy()
            best_epoch = epoch + 1
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= early_stop_patience:
            break

    if best_model_state is None:
        raise RuntimeError(
            "Validation loss never dropped below infinity (it may be NaN); "
            "there is no best model state to restore."
        )

    model.load_state_dict(best_model_state)

    model_save_path = 'result_data/dual_branch_vae_64_balanced.pth'
    # model_save_path = 'result_data/dual_branch_vae_64_rotated_balanced.pth'
    os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to: {model_save_path}")

    return best_val_loss, best_epoch, config, model, best_mu, best_logvar, train_loader, val_loader, train_set, val_set


def sample_hyperparams(latent_dim, n_samples=1):
    """Build the hyperparameter grid and return configurations drawn from it.

    Parameters
    ----------
    latent_dim : int
        Latent dimensionality; the base channel candidate is ``latent_dim // 4``.
    n_samples : int or None, default 1
        Number of configurations to draw at random, without replacement (capped
        at the grid size). Pass ``None`` to get the full grid, in order, which
        is what an exhaustive grid search needs.

    Returns
    -------
    list of dict
        Each dict has the keys 'lr', 'beta', 'stride' and 'base_channels'.
    """
    lrs = [5e-4]  # If running grid-search, adjust to all candidate values for the 4 HPs below
    betas = [0.5]
    strides = [2]
    base_channels_options = [latent_dim // 4]

    combos = [
        {'lr': lr, 'beta': beta, 'stride': stride, 'base_channels': base}
        for lr in lrs
        for beta in betas
        for stride in strides
        for base in base_channels_options
    ]

    if n_samples is None:
        return combos
    return random.sample(combos, min(n_samples, len(combos)))

In [None]:
# Run training on a 80/20 split of balanced dataset


# Everything assigned at top level here (model, val_loader, val_set, mu, logvar, ...)
# stays in the notebook namespace after the cell has run.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One dict per dataset variant to train on.
input_configs = [
    {
        'input_shape': (2, 64, 64, 64),
        'mask_dir': 'result_data/masks_64_balanced',   # result_data/masks_64_rotated_balanced
        'metadata_path': 'result_data/cell_data_64_balanced.json'
    }
]

latent_dims = [64]  # If running grid-search, adjust to all candidate values [16, 32, 64]
final_results = []  # best trial per (input config, latent_dim) combination

for config_dict in input_configs:
    input_shape = config_dict['input_shape']
    dataset_dir = config_dict['mask_dir']
    metadata_path = config_dict['metadata_path']

    for latent_dim in latent_dims:
        trials = []
        best_loss = float('inf')
        best_trial = None

        print(f"\nStarting search for shape {input_shape}, latent dimension {latent_dim}")

        configs = list(sample_hyperparams(latent_dim))
        total_configs = len(configs)

        for idx, config in enumerate(configs, start=1):

            print(f"\n[{idx}/{total_configs}] Testing config: {config}")

            try:
                val_loss, num_epochs, used_config, model, mu, logvar, train_loader, val_loader, train_set, val_set = run_training(
                    config=config,
                    dataset_dir=dataset_dir,
                    metadata_path=metadata_path,
                    input_shape=input_shape,
                    latent_dim=latent_dim,
                    device=device
                )
            except ValueError as e:
                # Only ValueError is treated as a recoverable data problem; anything else propagates.
                print(f"Skipping config due to data error: {e}")
                continue

            print(f"Completed config {idx}/{total_configs} with val loss: {val_loss:.4f}")

            trials.append({
                'config': used_config,
                'val_loss': val_loss
            })

            if val_loss < best_loss:
                best_loss = val_loss
                best_trial = {
                    'combination': {
                        'input_shape': input_shape,
                        'latent_dim': latent_dim
                    },
                    'best_config': used_config,
                    'val_loss': val_loss,
                    'epochs_ran': num_epochs,
                    # Same list object as `trials`, so it keeps growing with later trials.
                    'tried_configs': trials,
                    'mu': mu.tolist(),
                    'logvar': logvar.tolist()
                }

        if best_trial:
            final_results.append(best_trial)

<p style="height: 200px;"></p>

## VAE Evaluation

In [None]:
# Load model again if required, test on new data and save metrics


# Loading model:
# dataset = CellMaskDataset('result_data/masks_64_balanced', 'result_data/cell_data_64_balanced.json')
# train_size = int(0.8 * len(dataset))
# val_size = len(dataset) - train_size
# train_set, val_set = random_split(dataset, [train_size, val_size])
# train_loader = DataLoader(train_set, batch_size=2, shuffle=True)
# val_loader = DataLoader(val_set, batch_size=2, shuffle=True)

# model = DualBranchVAE(latent_dim=16, base_channels=8, stride=2).to(device)
# model.load_state_dict(torch.load('result_data/dual_branch_vae_64_balanced.pth'))


# The validation split from the training cell is evaluated as the test set.
test_loader, test_set = val_loader, val_set

model.eval()

# Per-sample results, keyed by sample id (mask file name without ".npy").
latent_vectors, logvar_vectors = {}, {}
recon_losses, voxel_accuracies, dice_scores = {}, {}, {}


def dice_coefficient(pred, target):
    """DICE overlap of two binary integer tensors: 2|A∩B| / (|A| + |B|).

    Returns 1.0 when both tensors are empty (all zeros), since two empty masks
    match perfectly.
    """
    overlap = (pred & target).sum().item()
    total = pred.sum().item() + target.sum().item()
    return 1.0 if total == 0 else (2.0 * overlap) / total


with torch.no_grad():
    for batch, path in tqdm(test_loader):
        batch = batch.to(device)
        recon, mu, logvar = model(batch)

        # Clamp once for the whole batch, then do the per-sample metrics on CPU.
        # Clamping is idempotent, so the samples need no second clamp.
        recon = torch.clamp(recon, min=1e-10, max=1 - 1e-10).cpu()
        batch = torch.clamp(batch, min=1e-10, max=1 - 1e-10).cpu()
        mu, logvar = mu.cpu(), logvar.cpu()

        for i in range(batch.size(0)):
            sample_id = os.path.basename(path[i]).replace('.npy', '')
            recon_sample, batch_sample = recon[i], batch[i]

            latent_vectors[sample_id] = mu[i].tolist()
            logvar_vectors[sample_id] = logvar[i].tolist()

            bce = F.binary_cross_entropy(recon_sample, batch_sample, reduction='sum')

            # Binarize both volumes at 0.5 for the voxel-accuracy and DICE metrics.
            recon_binary = (recon_sample >= 0.5).int()
            batch_binary = (batch_sample >= 0.5).int()

            matching_voxels = (recon_binary == batch_binary).sum().item()

            recon_losses[sample_id] = bce.item()
            voxel_accuracies[sample_id] = matching_voxels / batch_sample.numel()
            dice_scores[sample_id] = dice_coefficient(recon_binary, batch_binary)


# One JSON file per metric, each mapping sample id -> value.
for out_path, payload in (
    ('result_data/latent_space_test_64_balanced.json', latent_vectors),
    ('result_data/logvars_test_64_balanced.json', logvar_vectors),
    ('result_data/reconstruction_losses_test_64_balanced.json', recon_losses),
    ('result_data/voxel_accuracies_test_64_balanced.json', voxel_accuracies),
    ('result_data/dice_scores_test_64_balanced.json', dice_scores),
):
    with open(out_path, 'w') as f:
        json.dump(payload, f)

In [None]:
# Observe voxel-level reconstruction accuracies by phase


with open('result_data/voxel_accuracies_test_64_balanced.json', 'r') as f:
    voxel_accuracies = json.load(f)

with open('result_data/dice_scores_test_64_balanced.json', 'r') as f:
    dice_scores = json.load(f)

with open('result_data/cell_data_64_balanced.json', 'r') as f:
    metadata = json.load(f)

id_to_label = {entry["id"]: entry["label"] for entry in metadata}
valid_labels = ['M0', 'M1M2', 'M3', 'M4M5', 'M6M7_early', 'M6M7_half']

# Metrics over every evaluated sample, regardless of label.
acc_values = list(voxel_accuracies.values())
dice_values = list(dice_scores.values())

print("Overall Metrics:")
print(f"   Binary Voxel Accuracy — Mean: {np.mean(acc_values) * 100:.2f}%, Std: {np.std(acc_values) * 100:.2f}%")
print(f"   DICE Score            — Mean: {np.mean(dice_values) * 100:.2f}%, Std: {np.std(dice_values) * 100:.2f}%")

# Group the per-sample metrics by label; ids with an unknown label are ignored.
label_to_accs = defaultdict(list)
label_to_dices = defaultdict(list)

for id_, acc in voxel_accuracies.items():
    label = id_to_label.get(id_)
    if label not in valid_labels:
        continue
    label_to_accs[label].append(acc)
    label_to_dices[label].append(dice_scores[id_])

print("Metrics by Label:")
for label in valid_labels:
    accs = label_to_accs[label]
    dices = label_to_dices[label]

    if not accs:
        print(f"{label:12s} — No samples found")
        continue

    acc_part = f"Acc: {np.mean(accs) * 100:.2f}% ± {np.std(accs) * 100:.2f}%, "
    dice_part = f"DICE: {np.mean(dices) * 100:.2f}% ± {np.std(dices) * 100:.2f}%, "
    print(f"  {label:12s} — " + acc_part + dice_part)

In [None]:
# Generate images from latent space distribution, visualize central 2D cuts of structures


# Requires `model` (a trained DualBranchVAE) to be present in the notebook namespace.
with open('result_data/latent_space_test_64_balanced.json', 'r') as f:   # latent_space_test_64_rotated_balanced.json
    mu_dict = json.load(f)

with open('result_data/logvars_test_64_balanced.json', 'r') as f:   # logvars_test_64_rotated_balanced.json
    logvar_dict = json.load(f)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ids = list(mu_dict.keys())
if not ids:
    raise RuntimeError("No latent vectors found in result_data/latent_space_test_64_balanced.json")

mu_array = np.array([mu_dict[i] for i in ids])
logvar_array = np.array([logvar_dict[i] for i in ids])

# random.sample raises if asked for more items than exist, so cap the request.
num_samples = min(5, len(mu_array))
sample_indices = random.sample(range(len(mu_array)), num_samples)

# Draw one z ~ N(mu, sigma^2) from the posterior of each selected sample.
samples = []
for i in sample_indices:
    mu = np.array(mu_array[i])
    logvar = np.array(logvar_array[i])
    std = np.exp(0.5 * logvar)
    epsilon = np.random.randn(*mu.shape)
    z = mu + std * epsilon
    samples.append(z)

# Stack first: torch.tensor on a list of ndarrays is slow and emits a warning.
samples_tensor = torch.tensor(np.stack(samples), dtype=torch.float32).to(device)

model.eval()
with torch.no_grad():
    generated = model.decoder(samples_tensor)
    generated = (generated >= 0.2).int()  # binarize the decoder output at 0.2

for i in range(num_samples):
    volume = generated[i].cpu().numpy()
    mid_slice = volume.shape[1] // 2  # central cut along the first spatial axis (32 for 64^3 volumes)
    fig, axs = plt.subplots(1, 2, figsize=(8, 4))
    axs[0].imshow(volume[0, mid_slice, :, :], cmap='gray')
    axs[0].set_title("Channel 1 (Nucleus)")
    axs[1].imshow(volume[1, mid_slice, :, :], cmap='gray')
    axs[1].set_title("Channel 2 (Cell)")
    plt.suptitle(f"Sampled from Posterior {i+1}")
    plt.tight_layout()
    # plt.savefig(f"result_data/figures/POSTERIOR_SAMPLE_{i}.png", format="png", bbox_inches="tight")
    plt.show()
    plt.close(fig)  # release the figure so repeated runs do not accumulate open figures

In [None]:
# Visualize reconstructed images from latent space, before and after


def plot_vae_reconstructions_per_label(model, latent_file_path, metadata_path, dataset_dir, n_per_label=3, slice_idx=32):
    """Plot original vs. reconstructed 2D slices for a few random samples per label.

    Parameters
    ----------
    model : nn.Module
        Model exposing a ``decoder`` sub-module that maps latent vectors to volumes.
    latent_file_path : str
        JSON file mapping sample id -> latent vector.
    metadata_path : str
        JSON file with a list of entries carrying "id" and "label".
    dataset_dir : str
        Directory containing the original masks as ``<id>.npy``.
    n_per_label : int
        Maximum number of samples drawn per label.
    slice_idx : int
        Index along the first spatial axis of the slice to display.

    Labels without any latent vector are reported and skipped. A sample whose
    original mask cannot be loaded is reported and skipped. Reconstructions
    are binarized at 0.2 before display.
    """

    with open(latent_file_path, 'r') as f:
        latent_vectors = json.load(f)

    with open(metadata_path, 'r') as f:
        metadata = json.load(f)
    id_to_label = {entry["id"]: entry["label"] for entry in metadata}

    label_to_ids = defaultdict(list)
    for id_ in latent_vectors:
        if id_ in id_to_label:
            label_to_ids[id_to_label[id_]].append(id_)

    valid_labels = ['M0', 'M1M2', 'M3', 'M4M5', 'M6M7_early', 'M6M7_half']

    # nn.Module has no `device` attribute, so ask the parameters where the model
    # lives rather than guessing from CUDA availability.
    if hasattr(model, 'device'):
        device = model.device
    else:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    model.eval()

    # j numbers the labels from 1; it is only used by the optional savefig line below.
    for j, label in enumerate(valid_labels, start=1):
        candidate_ids = label_to_ids.get(label, [])
        if not candidate_ids:
            # An empty tensor has no latent axis and would make the decoder fail.
            print(f"No latent vectors found for label {label}; skipping")
            continue

        selected_ids = random.sample(candidate_ids, k=min(n_per_label, len(candidate_ids)))
        samples = [latent_vectors[id_] for id_ in selected_ids]
        samples_tensor = torch.tensor(samples, dtype=torch.float32).to(device)

        with torch.no_grad():
            reconstructions = model.decoder(samples_tensor)

        for i, id_ in enumerate(selected_ids):
            recon = reconstructions[i].cpu().numpy()
            recon = np.clip(recon, 1e-10, 1 - 1e-10)
            recon_binary = (recon >= 0.2).astype(np.uint8)

            try:
                mask_path = os.path.join(dataset_dir, f"{id_}.npy")
                original = np.load(mask_path)
            except Exception as e:
                print(f"Could not load original for ID {id_}: {e}")
                continue

            fig, axs = plt.subplots(2, 2, figsize=(8, 6))

            panels = [
                (axs[0, 0], original[0, slice_idx, :, :], "Original Channel 1"),
                (axs[0, 1], original[1, slice_idx, :, :], "Original Channel 2"),
                (axs[1, 0], recon_binary[0, slice_idx, :, :], "Reconstructed Channel 1"),
                (axs[1, 1], recon_binary[1, slice_idx, :, :], "Reconstructed Channel 2"),
            ]
            for ax, image, title in panels:
                ax.imshow(image, cmap='gray')
                ax.set_title(title)
                ax.axis('off')

            plt.suptitle(f"Label: {label} | ID: {id_}")
            plt.tight_layout(rect=[0, 0.03, 1, 0.95])
            # plt.savefig(f"result_data/figures/RECONSTRUCTION_{i}{j}.png", format="png", bbox_inches="tight")
            plt.show()
            plt.close(fig)  # avoid accumulating open figures across many samples


# Uses the `model` object currently in the notebook namespace. The latent file
# and mask directory should refer to the same dataset variant (see the
# alternatives noted inline).
plot_vae_reconstructions_per_label(
    model=model,
    latent_file_path='result_data/latent_space_test_64_balanced.json',   # latent_space_test_64_rotated_balanced.json
    metadata_path='result_data/cell_data_64_balanced.json',
    dataset_dir='result_data/masks_64_balanced',  # masks_64_rotated_balanced
    n_per_label=3,
    slice_idx=32
)

In [None]:
# Compute the latent space of the full balanced dataset; save the original 64-D embedding and a 70-D augmented
# version (64 latent dims + 6 size features). The size features are rescaled before being appended so that their
# mean per-dimension variance is 2x the mean per-dimension variance of the latent dims (used for curve fitting).


# Input / output locations
model_path = 'result_data/dual_branch_vae_64_balanced.pth'
metadata_path = 'result_data/cell_data_64_balanced.json'
out_latent_path = 'result_data/latent_space_full_64_balanced.json'
out_augmented_path = 'result_data/augmented_latent_space_full_64_balanced.json'

# Architecture of the saved model, plus the batch size used for encoding
LATENT_DIM = 64
BASE_CHANNELS = 16
STRIDE = 2
BATCH_SIZE = 8

with open(metadata_path, 'r') as f:
    metadata = json.load(f)

required_sizes = ['size_nucleus_1', 'size_nucleus_2', 'size_nucleus_3','size_cell_1', 'size_cell_2', 'size_cell_3']

# Keep only entries that have an id, an existing mask file and all six size fields.
entries = []
for entry in metadata:
    sample_id, mask_path = entry.get('id'), entry.get('mask_path')
    has_sizes = all(k in entry for k in required_sizes)

    if not (sample_id and mask_path):
        print(f"Skipping entry missing id or mask_path: {entry}")
    elif not os.path.exists(mask_path):
        print(f"Mask missing for {sample_id}: {mask_path} — skipping")
    elif not has_sizes:
        print(f"Size fields missing for {sample_id} — skipping")
    else:
        entries.append(entry)

N = len(entries)
print(f"Found {N} metadata entries with masks+sizes to process.")
if N == 0:
    raise RuntimeError("No valid entries found. Check metadata/mask paths/size fields.")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Rebuild the architecture, then restore the trained weights into it.
try:
    model = DualBranchVAE(latent_dim=LATENT_DIM, base_channels=BASE_CHANNELS, stride=STRIDE).to(device)
except NameError as e:
    raise RuntimeError("DualBranchVAE is not defined in the current notebook. "
                       "Define/import the model class before running this cell.") from e

state = torch.load(model_path, map_location=device)
model.load_state_dict(state)
model.eval()
print("Model loaded and set to eval()")


# Helper to load a mask and convert to tensor
def load_mask_tensor(path):
    """Load a ``(2, D, H, W)`` mask from ``path`` as a float32 tensor of shape ``(1, 2, D, H, W)``.

    Raises RuntimeError if the loaded object is not a numpy array, or if it is
    not 4-dimensional with a leading axis of size 2.
    """
    arr = np.load(path)
    if not isinstance(arr, np.ndarray):
        raise RuntimeError(f"Loaded object not numpy array from {path}")

    arr = arr.astype(np.float32)
    has_expected_layout = arr.ndim == 4 and arr.shape[0] == 2
    if not has_expected_layout:
        raise RuntimeError(f"Unexpected mask shape {arr.shape} for {path}. Expected (2, D, H, W).")

    # Prepend a batch axis so several masks can be concatenated along dim 0.
    return torch.from_numpy(arr).unsqueeze(0)


id_list = []       # sample ids, in the order their latents were computed
embeddings = []    # one latent mean vector per processed sample
sizes_list = []    # six size features per successfully loaded sample

batch_ids = []
batch_tensors = []


def _flush_latent_batch(pending_ids, pending_tensors):
    """Encode the pending masks, append their latent means and ids to the result lists, and empty the batch."""
    if not pending_tensors:
        return
    batch_stack = torch.cat(pending_tensors, dim=0).to(device)
    # encode() returns the same mu that forward() would return, without the
    # sampling and decoding work that is not needed here.
    mu, _ = model.encode(batch_stack)
    mu_np = mu.cpu().numpy()
    for pending_id, latent in zip(pending_ids, mu_np):
        embeddings.append(latent.astype(float))
        id_list.append(pending_id)
    pending_ids.clear()
    pending_tensors.clear()


with torch.no_grad():
    for entry in tqdm(entries, desc="Preparing and computing latents"):
        sid = entry['id']
        mask_path = entry['mask_path']

        try:
            t = load_mask_tensor(mask_path)
        except Exception as e:
            print(f"Error loading {mask_path}: {e} — skipping")
            continue

        batch_ids.append(sid)
        batch_tensors.append(t)

        sizes = np.array([
            entry['size_nucleus_1'],
            entry['size_nucleus_2'],
            entry['size_nucleus_3'],
            entry['size_cell_1'],
            entry['size_cell_2'],
            entry['size_cell_3']
        ], dtype=float)
        sizes_list.append(sizes)

        if len(batch_tensors) >= BATCH_SIZE:
            _flush_latent_batch(batch_ids, batch_tensors)

    # Always encode whatever is left. Counting processed samples against N
    # dropped the final partial batch whenever a mask had failed to load.
    _flush_latent_batch(batch_ids, batch_tensors)

if not embeddings:
    raise RuntimeError("No masks could be loaded; no latent vectors were computed.")

embeddings = np.vstack(embeddings)
sizes_arr = np.vstack(sizes_list)
assert sizes_arr.shape[0] == embeddings.shape[0], "size features and latents are misaligned"
ids_processed = id_list
M = embeddings.shape[0]
print(f"Computed latents for {M} samples (expected <= {N}).")

latent_map = {ids_processed[i]: embeddings[i].tolist() for i in range(M)}
os.makedirs(os.path.dirname(out_latent_path), exist_ok=True)
with open(out_latent_path, 'w') as f:
    json.dump(latent_map, f, indent=2)
print(f"Saved latent space to: {out_latent_path}")

# Rescale the size features so their mean per-dimension variance is twice the
# mean per-dimension variance of the latent dimensions.
var_sizes_per_dim = np.var(sizes_arr, axis=0)
mean_var_sizes = float(np.mean(var_sizes_per_dim))

var_latent_per_dim = np.var(embeddings, axis=0)
mean_var_latent = float(np.mean(var_latent_per_dim))

if mean_var_sizes == 0:
    scale_factor = 1.0
    print("Mean variance of size features is zero. Using scale_factor = 1.0")
else:
    scale_factor = float(np.sqrt((2.0 * mean_var_latent) / mean_var_sizes))

print(f"Scale factor computed: {scale_factor:.6g}")
print(f"Mean variance (latent dim): {mean_var_latent:.6g}")
print(f"Mean variance (size dims, pre-scale): {mean_var_sizes:.6g}")
print(f"Mean variance (size dims, post-scale target): {2.0 * mean_var_latent:.6g}")


# Augmented vector = latent mean followed by the six scaled size features.
augmented_map = {}
for i, sid in enumerate(ids_processed):
    scaled_sizes = (sizes_arr[i] * scale_factor).astype(float)
    augmented_map[sid] = np.concatenate([embeddings[i], scaled_sizes]).tolist()

# The scale factor is stored under its own key, ahead of the per-sample vectors.
out_dict = {"scale_factor": scale_factor}
out_dict.update(augmented_map)

with open(out_augmented_path, 'w') as f:
    json.dump(out_dict, f, indent=2)

print(f"Saved augmented latent space (with scale_factor at top) to: {out_augmented_path}")
print("Done.")

In [None]:
# Produce 10 UMAP plots (5 nucleus, 5 cell) with colorings to visualize various properties of the learned latent space


# Inputs for the UMAP visualisation cell
base_folder = "source_data/crop_seg"
augmented_latent_path = "result_data/augmented_latent_space_full_64_balanced.json"
metadata_path = "result_data/cell_data_64_balanced.json"

# The position of a label in this list is its numeric code (used for coloring).
valid_labels = ['M0', 'M1M2', 'M3', 'M4M5', 'M6M7_early', 'M6M7_half']
label_to_index = dict(zip(valid_labels, range(len(valid_labels))))
cmap_labels = plt.get_cmap("viridis")


def extract_two_channels(arr):
    """
    Return the input as an array whose first axis is the 2-element channel axis.

    - 3D input: the volume is duplicated into both channels, giving shape (2, ...).
    - 4D input with exactly one axis of size 2: that axis is moved to the front.
    - 4D input with several axes of size 2: the first axis is preferred, then
      the last, then the second.

    Raises ValueError for any other input (wrong number of dimensions, or a 4D
    array with no axis of size 2).
    """
    a = np.asarray(arr)

    if a.ndim == 3:
        return np.stack([a, a], axis=0)

    if a.ndim == 4:
        axes_eq_2 = [ax for ax, s in enumerate(a.shape) if s == 2]
        if len(axes_eq_2) == 1:
            return np.moveaxis(a, axes_eq_2[0], 0)
        # Ambiguous (several size-2 axes) or hopeless (none): try the usual
        # channel positions in order of preference. The former argmin-based
        # guess could never be reached and has been removed.
        for cax in (0, -1, 1):
            if a.shape[cax] == 2:
                return a if cax == 0 else np.moveaxis(a, cax, 0)

    raise ValueError(f"can't determine channel axis (shape: {a.shape})")


def compute_height(mask):
    """Number of slices along axis 0 that contain at least one nonzero voxel, as a float."""
    occupied_slices = np.any(mask, axis=(1, 2))
    return float(np.count_nonzero(occupied_slices))

def compute_volume(mask):
    """Number of nonzero voxels in the mask, as a float."""
    voxel_count = np.count_nonzero(mask)
    return float(voxel_count)

def compute_major_tilt_degrees(mask):
    """Angle in degrees (0-90) between the mask's first principal axis and array axis 0.

    The principal axis is the first PCA component of the coordinates of the
    nonzero voxels. Returns NaN when the mask has fewer than 3 nonzero voxels
    or when the PCA fit fails.
    """
    coords = np.argwhere(mask)
    if coords.shape[0] < 3:
        return float(np.nan)
    coords_centered = coords - coords.mean(axis=0)
    pca = PCA(n_components=3)
    try:
        pca.fit(coords_centered)
    except Exception:
        return float(np.nan)
    principal = pca.components_[0]
    z_axis = np.array([1.0, 0.0, 0.0])  # unit vector along array axis 0
    # The absolute value makes the result independent of the component's sign
    # and confines the angle to [0, 90] degrees, so no folding of angles above
    # 90 is needed afterwards.
    dot = np.abs(np.dot(principal, z_axis))
    dot = np.clip(dot / (np.linalg.norm(principal) * np.linalg.norm(z_axis)), -1.0, 1.0)
    return float(np.degrees(np.arccos(dot)))

def compute_sphericity(mask):
    """Sphericity pi^(1/3) * (6V)^(2/3) / A of a binary mask.

    V is the nonzero voxel count. A is the area of the marching-cubes surface
    at level 0.5. If that computation fails for any reason (including a
    non-positive area), A is approximated by the number of mask voxels that
    have at least one empty 6-neighbour. Returns NaN for an empty mask or when
    the fallback area is not positive.
    """
    V = np.count_nonzero(mask)
    if V <= 0:
        return float(np.nan)

    # Numerator shared by the exact and the approximate estimate.
    numerator = (np.pi ** (1.0/3.0)) * ((6.0 * V) ** (2.0/3.0))

    try:
        verts, faces, normals, values = measure.marching_cubes(mask.astype(np.uint8), level=0.5)
        triangles = verts[faces]
        edge_a = triangles[:, 1] - triangles[:, 0]
        edge_b = triangles[:, 2] - triangles[:, 0]
        # Triangle area = half the norm of the cross product of two edges.
        A = float((0.5 * np.linalg.norm(np.cross(edge_a, edge_b), axis=1)).sum())
        if A <= 0:
            raise RuntimeError("zero surface area")
        return float(numerator / A)
    except Exception:
        padded = np.pad(mask.astype(np.uint8), pad_width=1, mode='constant', constant_values=0)
        inside = padded[1:-1, 1:-1, 1:-1] == 1
        # Count of occupied face-neighbours (6-connectivity) for every voxel.
        neighbors_sum = (
            padded[2:, 1:-1, 1:-1] + padded[:-2, 1:-1, 1:-1] +
            padded[1:-1, 2:, 1:-1] + padded[1:-1, :-2, 1:-1] +
            padded[1:-1, 1:-1, 2:] + padded[1:-1, 1:-1, :-2]
        )
        # Boundary voxels: inside the mask but missing at least one neighbour.
        A_approx = float(np.count_nonzero(inside & (neighbors_sum != 6)))
        if A_approx <= 0:
            return float(np.nan)
        return float(numerator / A_approx)


with open(augmented_latent_path, 'r') as f:
    aug = json.load(f)

# Every key except the stored scale factor is a sample id.
aug_map = {key: value for key, value in aug.items() if key != "scale_factor"}

ids = list(aug_map.keys())
vectors = np.array([np.asarray(aug_map[i], dtype=float) for i in ids])
N = vectors.shape[0]
print(f"Loaded {N} augmented vectors (70-D)")

nucleus_idx = list(range(0,32)) + [64,65,66]   # first 32 latent dims + nucleus sizes
cell_idx = list(range(32,64)) + [67,68,69]    # next 32 latent dims + cell sizes

nucleus_X = vectors[:, nucleus_idx]
cell_X = vectors[:, cell_idx]

with open(metadata_path, 'r') as f:
    metadata = json.load(f)
id_to_meta = {entry['id']: entry for entry in metadata}

# Per-sample properties; NaN marks "not available".
label_numeric = np.full((N,), np.nan)
nuc_height, nuc_volume, nuc_sphericity, nuc_tilt = (np.full((N,), np.nan) for _ in range(4))
cell_height, cell_volume, cell_sphericity, cell_tilt = (np.full((N,), np.nan) for _ in range(4))


print("Computing morphological metrics for nucleus & cell (this can take a while)...")
for i, sid in enumerate(tqdm(ids)):
    meta = id_to_meta.get(sid)
    if meta is None:
        continue

    lab = meta.get('label')
    label_numeric[i] = label_to_index.get(lab, np.nan)

    seg_rel = meta.get('seg_file')
    if seg_rel is None:
        continue

    # Drop a leading "crop_seg/" so the path is relative to base_folder.
    if seg_rel.startswith(("crop_seg/", "./crop_seg/")):
        seg_rel_path = seg_rel.split("crop_seg/", 1)[1]
        image_path = os.path.join(base_folder, seg_rel_path)
    else:
        image_path = os.path.join(base_folder, seg_rel)

    if not os.path.exists(image_path):
        # Second attempt: the unmodified relative path under base_folder.
        alt = os.path.join(base_folder, seg_rel)
        if not os.path.exists(alt):
            print(f"Image file not found for id {sid}: {image_path}")
            continue
        image_path = alt

    try:
        im = tifffile.imread(image_path)
        chs = extract_two_channels(im)
    except Exception as e:
        print(f"Could not read or parse {image_path}: {e}")
        continue

    nuc_mask = (chs[0] > 0.5)
    cell_mask = (chs[1] > 0.5)

    for mask, (heights, volumes, sphericities, tilts) in (
        (nuc_mask, (nuc_height, nuc_volume, nuc_sphericity, nuc_tilt)),
        (cell_mask, (cell_height, cell_volume, cell_sphericity, cell_tilt)),
    ):
        heights[i] = compute_height(mask)
        volumes[i] = compute_volume(mask)
        sphericities[i] = compute_sphericity(mask)
        tilts[i] = compute_major_tilt_degrees(mask)


print("Running UMAP on nucleus and cell feature sets...")
umap_nuc = umap_module.UMAP(n_components=2, random_state=42)
umap_cell = umap_module.UMAP(n_components=2, random_state=42)

nuc_emb = umap_nuc.fit_transform(nucleus_X)
cell_emb = umap_cell.fit_transform(cell_X)

def normalize_for_cmap(vals):
    """Rescale values to [0, 1] (min-max, ignoring NaNs) for colormap lookup.

    Returns an all-zero array when every entry is NaN (or the input is empty)
    or when all finite entries are equal. Otherwise NaN entries stay NaN in
    the result.
    """
    arr = np.array(vals, dtype=float)
    # Nothing usable to scale against.
    if np.isnan(arr).all():
        return np.zeros_like(arr)
    lo = np.nanmin(arr)
    hi = np.nanmax(arr)
    # Constant data: avoid dividing by zero.
    if lo == hi:
        return np.zeros_like(arr)
    return (arr - lo) / (hi - lo)


# Plotting:
# 2 rows (nucleus / cell) x 6 columns: column 0 is a label legend, columns 1-5 are colourings.
fig, axes = plt.subplots(2, 6, figsize=(24, 8))
plt.subplots_adjust(wspace=0.6, hspace=0.3)

# Map numeric labels onto [0, 1] for cmap_labels; the epsilon guards against a zero range.
label_norm = (label_numeric - np.nanmin(label_numeric)) / (np.nanmax(label_numeric) - np.nanmin(label_numeric) + 1e-12)
label_colors = cmap_labels(np.clip(label_norm, 0, 1))

# Legend-only axes: empty scatters are drawn purely to create one legend entry per label.
unique_labels = sorted(set(label_to_index.keys()), key=lambda x: label_to_index[x])
for row in [0, 1]:
    ax = axes[row, 0]
    ax.axis("off")
    ax.set_title("Label key", fontsize=12)
    for i, lab in enumerate(unique_labels):
        color = cmap_labels(label_to_index[lab] / (len(unique_labels)-1 + 1e-12))
        ax.scatter([], [], c=[color], s=40, label=lab)
    ax.legend(loc="center", fontsize=10, frameon=False)

# Each entry: (panel title, raw values, precomputed colours or None, colormap name or None).
nuc_metrics = [
    ("True Phase Label", label_numeric, label_colors, None),
    ("Height (slices)", nuc_height, None, "viridis"),
    ("Volume (voxels)", nuc_volume, None, "plasma"),
    ("Sphericity", nuc_sphericity, None, "magma"),
    ("Major tilt (deg)", nuc_tilt, None, "cividis"),
]
for j, (title, vals, color_vals, cmap_name) in enumerate(nuc_metrics):
    ax = axes[0, j+1]
    # Defensive default; none of the entries listed above pass None for vals.
    if vals is None:
        vals = np.zeros(N)
    if color_vals is None:
        # Continuous metric: colours come from the min-max normalised values.
        norm_vals = normalize_for_cmap(vals)
        colors = plt.get_cmap(cmap_name)(norm_vals)
    else:
        colors = color_vals
    ax.scatter(nuc_emb[:,0], nuc_emb[:,1], c=colors, s=8)
    ax.set_title(f"Nucleus UMAP — {title}")
    ax.set_xlabel("UMAP1"); ax.set_ylabel("UMAP2")
    ax.set_xticks([]); ax.set_yticks([])
    if color_vals is None:
        # Colorbar is built from the raw value range so its ticks are in the metric's own units.
        sm = plt.cm.ScalarMappable(cmap=cmap_name, norm=plt.Normalize(vmin=np.nanmin(vals), vmax=np.nanmax(vals)))
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
        cbar.ax.tick_params(labelsize=8)

# Same panel layout as the nucleus row, drawn on the cell embedding in row 1.
cell_metrics = [
    ("True label", label_numeric, label_colors, None),
    ("Height (slices)", cell_height, None, "viridis"),
    ("Volume (voxels)", cell_volume, None, "plasma"),
    ("Sphericity", cell_sphericity, None, "magma"),
    ("Major tilt (deg)", cell_tilt, None, "cividis"),
]
for j, (title, vals, color_vals, cmap_name) in enumerate(cell_metrics):
    ax = axes[1, j+1]
    if vals is None:
        vals = np.zeros(N)
    if color_vals is None:
        norm_vals = normalize_for_cmap(vals)
        colors = plt.get_cmap(cmap_name)(norm_vals)
    else:
        colors = color_vals
    ax.scatter(cell_emb[:,0], cell_emb[:,1], c=colors, s=8)
    ax.set_title(f"Cell UMAP — {title}")
    ax.set_xlabel("UMAP1"); ax.set_ylabel("UMAP2")
    ax.set_xticks([]); ax.set_yticks([])
    if color_vals is None:
        sm = plt.cm.ScalarMappable(cmap=cmap_name, norm=plt.Normalize(vmin=np.nanmin(vals), vmax=np.nanmax(vals)))
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
        cbar.ax.tick_params(labelsize=8)

plt.suptitle("UMAP (Nucleus / Cell) — Colorings: label, height, volume, sphericity, tilt", fontsize=18)
# NOTE(review): savefig does not create directories; result_data/figures must already exist.
plt.savefig("result_data/figures/UMAP_COLOURINGS.png", format="png", bbox_inches="tight")
plt.show()

<p style="height: 200px;"></p>

## Principal Curve

In [None]:
# Compute mean latent space vectors for each phase label based on the full balanced data, visualize 2D slices


latent_path = 'result_data/latent_space_full_64_balanced.json'
metadata_path = 'result_data/cell_data_64_balanced.json'
model_path = 'result_data/dual_branch_vae_64_balanced.pth'

# latent_data is iterated as {sample id: latent vector}; metadata as a list of dicts with "id" and "label".
with open(latent_path, 'r') as f:
    latent_data = json.load(f)

with open(metadata_path, 'r') as f:
    metadata = json.load(f)

id_to_label = {entry["id"]: entry["label"] for entry in metadata}

# Group latent vectors by phase label; ids without metadata are ignored.
label_to_latents = {}
for id_, vec in latent_data.items():
    if id_ in id_to_label:
        label = id_to_label[id_]
        label_to_latents.setdefault(label, []).append(vec)

# Per-label mean latent vector.
label_means = {
    label: np.mean(vectors, axis=0)
    for label, vectors in label_to_latents.items()
}

# DualBranchVAE is defined elsewhere in the notebook; weights are loaded onto whichever device is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = DualBranchVAE(latent_dim=64, base_channels=16, stride=2).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# 'M0' appears at both ends, so it is reconstructed and plotted twice.
ordered_labels = ['M0', 'M1M2', 'M3', 'M4M5', 'M6M7_early', 'M6M7_half', 'M0']

with torch.no_grad():
    for label in ordered_labels:
        if label not in label_means:
            print(f"Skipping label {label}: no latent data.")
            continue

        # Decode the mean latent (batch of one) and drop the batch axis.
        mean_vec = torch.tensor(label_means[label], dtype=torch.float32).to(device).unsqueeze(0)
        recon = model.decode(mean_vec).cpu().squeeze(0).numpy()
        
        # Binarise the decoder output at a 0.15 threshold.
        recon = np.clip(recon, 1e-10, 1 - 1e-10)
        recon_binary = (recon >= 0.15).astype(np.uint8)

        # Show the slice at index 32 of the second spatial axis; channel 0 is titled nucleus, channel 1 cell body.
        fig, axs = plt.subplots(1, 2, figsize=(6, 3))
        axs[0].imshow(recon_binary[0, :, 32, :], cmap='gray')
        axs[0].set_title(f"{label} — Nucleus")
        axs[0].axis('off')

        axs[1].imshow(recon_binary[1, :, 32, :], cmap='gray')
        axs[1].set_title(f"{label} — Cell Body")
        axs[1].axis('off')

        plt.suptitle(f"Mean Shape Reconstruction: {label}", fontsize=14)
        plt.tight_layout()
        plt.show()

In [None]:
# Visualize same phase means as above, but in 3D interactive graphs


latent_path = 'result_data/latent_space_full_64_balanced.json'
metadata_path = 'result_data/cell_data_64_balanced.json'
model_path = 'result_data/dual_branch_vae_64_balanced.pth'

# Reload latents and metadata so this cell does not depend on variables from other cells.
with open(latent_path, 'r') as f:
    latent_data = json.load(f)

with open(metadata_path, 'r') as f:
    metadata = json.load(f)

id_to_label = {entry["id"]: entry["label"] for entry in metadata}

# Group latent vectors by phase label; ids without metadata are ignored.
label_to_latents = {}
for id_, vec in latent_data.items():
    if id_ in id_to_label:
        label = id_to_label[id_]
        label_to_latents.setdefault(label, []).append(vec)

# Per-label mean latent vector.
label_means = {
    label: np.mean(vectors, axis=0)
    for label, vectors in label_to_latents.items()
}

# DualBranchVAE is defined elsewhere in the notebook.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = DualBranchVAE(latent_dim=64, base_channels=16, stride=2).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# 'M0' appears at both ends, so it is plotted twice.
ordered_labels = ['M0', 'M1M2', 'M3', 'M4M5', 'M6M7_early', 'M6M7_half', 'M0']


def plot_3d_voxel(voxel_data, title="3D Reconstruction"):
    """Show the voxels equal to 1 as a Plotly 3D scatter with every axis fixed to [0, 64].

    Points are coloured by their third coordinate. The figure is displayed
    immediately; nothing is returned.
    """
    xs, ys, zs = np.nonzero(voxel_data == 1)

    points = go.Scatter3d(
        x=xs, y=ys, z=zs,
        mode='markers',
        marker=dict(size=2, color=zs, colorscale='Viridis', opacity=0.8)
    )
    # Fixed ranges and a cube aspect keep successive plots directly comparable.
    scene = dict(
        xaxis=dict(title='X', range=[0, 64]),
        yaxis=dict(title='Y', range=[0, 64]),
        zaxis=dict(title='Z', range=[0, 64]),
        aspectmode='cube'
    )

    fig = go.Figure(data=points)
    fig.update_layout(scene=scene, title=title, margin=dict(l=0, r=0, b=0, t=30))
    fig.show()


# Decode each phase mean and show nucleus / cell body as separate interactive 3D plots.
with torch.no_grad():
    for label in ordered_labels:
        if label not in label_means:
            print(f"Skipping label {label}: no latent data.")
            continue

        # Decode the mean latent (batch of one) and drop the batch axis.
        mean_vec = torch.tensor(label_means[label], dtype=torch.float32).to(device).unsqueeze(0)
        recon = model.decode(mean_vec).cpu().squeeze(0).numpy()

        # Binarise the decoder output at a 0.15 threshold.
        recon = np.clip(recon, 1e-10, 1 - 1e-10)
        recon_binary = (recon >= 0.15).astype(np.uint8)

        print(f"Label: {label} — Nucleus")
        plot_3d_voxel(recon_binary[0], title=f"{label} — Nucleus")

        print(f"Label: {label} — Cell Body")
        plot_3d_voxel(recon_binary[1], title=f"{label} — Cell Body")

In [None]:
# Defines the implementation of our own principal curve finder; it is an L1 regression regularized by arc length and curvature,
#         that must pass through all phase means and complete a full cycle -- 100 points are then sampled from the curve.
#         The curve is defined by the parameter t, with cubic splines to represent each dimension of the curve


def load_phase_means(latent_path, metadata_path, valid_labels):
    """Return the mean latent vector of each phase, in the order of valid_labels.

    Parameters
    ----------
    latent_path : path to a JSON object mapping sample id -> latent vector (list of numbers).
    metadata_path : path to a JSON list of dicts, each with an 'id' and a 'label' key.
    valid_labels : sequence of labels to compute means for; samples with other labels are ignored.

    Returns
    -------
    np.ndarray with one row per entry of valid_labels.

    Raises
    ------
    ValueError
        If any requested label has no latent vectors. Previously this produced a
        NaN mean (with a RuntimeWarning) or an inhomogeneous-array error.
    """
    with open(latent_path, 'r') as f:
        latent_data = json.load(f)
    with open(metadata_path, 'r') as f:
        metadata = json.load(f)

    id_to_label = {entry['id']: entry['label'] for entry in metadata}

    # Bucket latent vectors by label; ids absent from the metadata map to None and are dropped.
    vectors_by_phase = {label: [] for label in valid_labels}
    for id_, vec in latent_data.items():
        label = id_to_label.get(id_)
        if label in valid_labels:
            vectors_by_phase[label].append(vec)

    # Fail early and explicitly instead of averaging an empty list.
    missing = [label for label in valid_labels if not vectors_by_phase[label]]
    if missing:
        raise ValueError(f"No latent vectors found for label(s): {missing}")

    phase_means = [np.array(vectors_by_phase[label]).mean(axis=0) for label in valid_labels]
    return np.array(phase_means)


def curve_length(points):
    """Total Euclidean length of the open polyline through `points` (one point per row)."""
    segment_lengths = np.linalg.norm(np.diff(points, axis=0), axis=1)
    return np.sum(segment_lengths)


def curvature_penalty(points):
    """Sum of squared norms of the second finite differences of `points` along axis 0."""
    bend_norms = np.linalg.norm(np.diff(points, n=2, axis=0), axis=1)
    return np.sum(bend_norms ** 2)


def initialize_curve_from_means(fixed_points, num_points):
    """Build an initial curve by linear interpolation around the cycle of fixed points.

    Each segment i runs from fixed_points[i] towards fixed_points[(i + 1) % k].
    The segments share num_points as evenly as possible, with the first
    `num_points % k` segments getting one extra step. Interpolation never
    reaches alpha = 1, and alpha = 0 is only emitted for the first segment.
    The concatenated list is truncated to num_points rows.

    Returns an array of shape (num_points, dim).
    """
    dim = fixed_points.shape[1]
    num_phases = len(fixed_points)
    base_steps, extra = divmod(num_points, num_phases)

    samples = []
    for seg, start in enumerate(fixed_points):
        end = fixed_points[(seg + 1) % num_phases]
        steps = base_steps + 1 if seg < extra else base_steps
        alphas = np.linspace(0, 1, steps + 1, endpoint=False)
        # Only the very first segment keeps its alpha = 0 sample.
        if seg > 0:
            alphas = alphas[1:]
        samples.extend((1 - alpha) * start + alpha * end for alpha in alphas)

    curve_init = np.array(samples[:num_points])
    assert curve_init.shape == (num_points, dim)
    return curve_init


def curve_objective_free(params, fixed_points, num_points, lambda_length, lambda_curvature, lambda_regression, data_points):
    """Objective for the curve fit: weighted length + curvature + mean L1 distance to data.

    The candidate curve has `num_points` rows. The first len(fixed_points) rows
    are the fixed points themselves and the remaining rows come from `params`,
    the flattened free points being optimised.

    Parameters
    ----------
    params : 1-D array of length (num_points - len(fixed_points)) * dim.
    fixed_points : (num_fixed, dim) array kept constant during optimisation.
    num_points : total number of curve points.
    lambda_length, lambda_curvature, lambda_regression : weights of the three terms.
    data_points : (n_samples, dim) array-like of observations.

    Returns
    -------
    float : the weighted total. The individual terms are also printed on every call.

    Raises
    ------
    ValueError if data_points is empty.
    """
    # Local import: the notebook's import cell does not bring cdist into scope.
    from scipy.spatial.distance import cdist

    if len(data_points) == 0:
        raise ValueError("data_points must contain at least one sample")

    dim = fixed_points.shape[1]
    num_fixed = len(fixed_points)
    num_free = num_points - num_fixed

    # Assemble the full curve: fixed rows first, free rows after.
    curve = np.zeros((num_points, dim))
    curve[:num_fixed] = fixed_points
    curve[num_fixed:] = params.reshape((num_free, dim))

    length_term = curve_length(curve)
    curvature_term = curvature_penalty(curve)

    # Regression error: mean L1 distance from each data point to its nearest curve point.
    # One vectorised cdist call replaces a Python loop over every data point; this matters
    # because the optimiser evaluates this function very many times.
    l1_dists = cdist(np.asarray(data_points, dtype=float), curve, metric='cityblock')
    regression_term = np.sum(l1_dists.min(axis=1)) / len(data_points)

    total = (
        lambda_length * length_term +
        lambda_curvature * curvature_term +
        lambda_regression * regression_term
    )

    print(f"Length: {length_term:.4f}, Curv: {curvature_term:.4f}, Reg: {regression_term:.4f}, Total: {total:.4f}")
    return total


def optimize_closed_curve(fixed_points, num_points, max_iter, lambda_length, lambda_curvature, lambda_regression, data_points):
    """Fit the free curve points with L-BFGS-B and return the full curve.

    The returned (num_points, dim) array holds `fixed_points` in its first rows
    followed by the optimised free points. The free points start from the tail
    (rows num_fixed onward) of the linear-interpolation initial curve.
    """
    num_fixed = len(fixed_points)
    dim = fixed_points.shape[1]
    num_free = num_points - num_fixed

    # Initial guess for the free rows only, flattened for the optimiser.
    x0 = initialize_curve_from_means(fixed_points, num_points)[num_fixed:].flatten()

    objective_args = (fixed_points, num_points, lambda_length, lambda_curvature, lambda_regression, data_points)
    result = minimize(
        curve_objective_free,
        x0,
        args=objective_args,
        method='L-BFGS-B',
        options={'maxiter': max_iter, 'disp': True}
    )

    # Re-assemble the full curve in the same layout the objective uses.
    optimized = np.zeros((num_points, dim))
    optimized[:num_fixed] = fixed_points
    optimized[num_fixed:] = result.x.reshape((num_free, dim))
    return optimized


def parametrize_curve(points):
    """Return the normalised cumulative arc length of each row of `points`.

    The first value is 0 and the last is 1 for a curve of non-zero length.
    """
    step_lengths = np.linalg.norm(np.diff(points, axis=0), axis=1)
    cumulative = np.insert(np.cumsum(step_lengths), 0, 0)
    return cumulative / cumulative[-1]


def get_curve_interpolator(t_vals, curve_points):
    """Return gamma(t): one natural cubic spline per column of `curve_points`.

    gamma accepts a scalar or array-like t and returns the spline values
    stacked along the last axis, i.e. shape t.shape + (dim,).
    """
    splines = [
        CubicSpline(t_vals, curve_points[:, col], bc_type='natural')
        for col in range(curve_points.shape[1])
    ]

    def gamma(t_query):
        t_arr = np.asarray(t_query)
        per_dim = [spline(t_arr) for spline in splines]
        return np.stack(per_dim, axis=-1)

    return gamma


def sample_curve(gamma_fn, num_points=100, t_max=0.525):
    """Evaluate gamma_fn at `num_points` evenly spaced parameters in [0, t_max).

    Parameters
    ----------
    gamma_fn : callable taking an array of t values and returning the curve points.
    num_points : number of samples.
    t_max : exclusive upper bound of the sampled parameter range. The default is
        the value that was previously hard-coded here, so existing calls are unchanged.
        NOTE(review): why sampling stops at 0.525 rather than 1.0 is not documented
        in this notebook -- confirm before changing the default.

    Returns
    -------
    Whatever gamma_fn returns for the array of sampled t values.
    """
    ts = np.linspace(0, t_max, num_points, endpoint=False)
    return gamma_fn(ts)


def save_curve(points, path):
    """Write `points` (anything with a .tolist() method, e.g. a NumPy array) to `path` as JSON."""
    serialized = json.dumps(points.tolist())
    with open(path, 'w') as out_file:
        out_file.write(serialized)

In [None]:
# Runs fitting of curve to our full latent space data


metadata_path = 'result_data/cell_data_64_balanced.json'
valid_labels = ['M0','M1M2','M3','M4M5','M6M7_early','M6M7_half']

latent_path = 'result_data/latent_space_full_64_balanced.json'
output_path = 'result_data/latent_principal_curve_64D_100pts.json'

# latent_path = 'result_data/augmented_latent_space_full_64_balanced.json'       # Run after to produce augmented principal curve
# output_path = 'result_data/augmented_latent_principal_curve_64D_100pts.json'

# One mean latent vector per phase, in valid_labels order; these are the curve's fixed points.
phase_means = load_phase_means(latent_path, metadata_path, valid_labels)

# The same two files are read again here to collect the individual data points.
with open(latent_path, 'r') as f:
    latent_data = json.load(f)

with open(metadata_path, 'r') as f:
    metadata = json.load(f)

id_to_label = {entry['id']: entry['label'] for entry in metadata}

# Keep only samples whose label is one of valid_labels.
data_points = []
for id_, vec in latent_data.items():
    if id_to_label.get(id_) in valid_labels:
        data_points.append(vec)

data_points = np.array(data_points)

# All three objective terms are weighted equally.
optimized_curve = optimize_closed_curve(
    fixed_points=phase_means,
    num_points=100,
    max_iter=10000,
    lambda_length=1,
    lambda_curvature=1,
    lambda_regression=1,
    data_points=data_points
)

# Re-parametrise by normalised arc length, fit per-dimension splines, resample and save.
t_vals = parametrize_curve(optimized_curve)
gamma = get_curve_interpolator(t_vals, optimized_curve)
sampled_points = sample_curve(gamma, num_points=100)
save_curve(sampled_points, output_path)

print("done")

In [None]:
# Creates and saves a smooth, mean-reverting curve with stochastic deviations off of the principal curve


def sample_deviated_curve(base_curve, deviation_scale=0.1, reversion_strength=0.05, seed=72):
    """
    Return base_curve plus a smooth, mean-reverting random offset per dimension.

    For each dimension the offset starts at 0 and evolves along the curve as
        offset[i] = offset[i-1] - reversion_strength * offset[i-1] + N(0, deviation_scale).
    Note: this reseeds NumPy's global random state with `seed`.
    """
    np.random.seed(seed)
    num_points, dim = base_curve.shape
    deviations = np.zeros((num_points, dim))

    # Dimensions are processed one after another so the random draws happen
    # in dimension-major order.
    for column in range(dim):
        previous = 0.0
        for step in range(1, num_points):
            noise = np.random.normal(0, deviation_scale)
            pull_back = -reversion_strength * previous
            previous = previous + pull_back + noise
            deviations[step, column] = previous

    return base_curve + deviations


# Uses sampled_points and output_path from the curve-fitting cell, which must be run first.
deviated_curve = sample_deviated_curve(
    base_curve=sampled_points,
    deviation_scale=0.1,        # adjusts noise level
    reversion_strength=0.05,     # controls strength of tendency to stay near the original curve
    seed=1337 
)

# Derive the file name from the principal curve's own output_path, so that re-running the
# fit with the augmented paths writes "augmented_..._deviated.json" (the file the 3D UMAP
# cell loads) instead of overwriting the non-augmented deviated curve.
# For the default output_path this yields the same name that was previously hard-coded.
_deviated_root, _deviated_ext = os.path.splitext(output_path)
deviated_output_path = f"{_deviated_root}_deviated{_deviated_ext}"
save_curve(deviated_curve, deviated_output_path)
print("Saved deviated curve to:", deviated_output_path)

In [None]:
# Creates 3D animation along principal curve (without consideration of scaling factors)


# Load the saved principal curve (a list of latent vectors).
with open('result_data/latent_principal_curve_64D_100pts.json', 'r') as f:
    curve_points = json.load(f)

model_path = 'result_data/dual_branch_vae_64_balanced.pth'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# DualBranchVAE is defined elsewhere in the notebook.
model = DualBranchVAE(latent_dim=64, base_channels=16, stride=2).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# Decode every curve point and keep one binarised channel (0.15 threshold) per point.
recon_channel = []
with torch.no_grad():
    for i, vec in enumerate(curve_points):
        latent = torch.tensor(vec, dtype=torch.float32).to(device).unsqueeze(0)
        recon = model.decode(latent).cpu().squeeze(0).numpy()
        recon = np.clip(recon, 1e-10, 1 - 1e-10)
        recon_binary = (recon >= 0.15).astype(np.uint8)
        recon_channel.append(recon_binary[0])
        # recon_channel.append(recon_binary[1])   # for cell channel opposed to nucleus

# One animation frame per curve point: a scatter of the voxels equal to 1.
frames = []
for i, volume in enumerate(recon_channel):
    x, y, z = np.where(volume == 1)
    scatter = go.Scatter3d(
        x=x, y=y, z=z,
        mode='markers',
        marker=dict(size=2, color=z, colorscale='Viridis', opacity=0.8),
        name=f"Frame {i}"
    )
    frame = go.Frame(data=[scatter], name=str(i))
    frames.append(frame)

# The figure's initial trace shows the first curve point; axes are fixed to [0, 64].
init_x, init_y, init_z = np.where(recon_channel[0] == 1)
# init_x, init_y, init_z = np.where(recon_channel[1] == 1)   # for cell channel opposed to nucleus
fig = go.Figure(
    data=[go.Scatter3d(
        x=init_x, y=init_y, z=init_z,
        mode='markers',
        marker=dict(size=2, color=init_z, colorscale='Viridis', opacity=0.7)
    )],
    layout=go.Layout(
        title="3D Cell Nucleus Morphing Along Principal Curve",
        scene=dict(
            xaxis=dict(range=[0, 64], title='X'),
            yaxis=dict(range=[0, 64], title='Y'),
            zaxis=dict(range=[0, 64], title='Z'),
            aspectmode='cube'
        ),
        # Play / Pause buttons driving the frame animation.
        updatemenus=[dict(
            type='buttons',
            buttons=[
                dict(label='▶️ Play', method='animate', args=[None, {"frame": {"duration": 300, "redraw": True}, "fromcurrent": True, "loop": True}]),
                dict(label='⏸ Pause', method='animate', args=[[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate", "transition": {"duration": 0}}])
            ]
        )]
    ),
    frames=frames
)

fig.show()

# NOTE(review): write_html does not create directories; result_data/figures must already exist.
fig.write_html("result_data/figures/nucleus_morphing_principal_curve.html")
# fig.write_html("result_data/figures/cell_morphing_principal_curve.html")

In [None]:
# Creates same 3D animation along principal curve but with mesh visualization


# Load the saved principal curve (a list of latent vectors).
with open('result_data/latent_principal_curve_64D_100pts.json', 'r') as f:
    curve_points = json.load(f)

model_path = 'result_data/dual_branch_vae_64_balanced.pth'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# DualBranchVAE is defined elsewhere in the notebook.
model = DualBranchVAE(latent_dim=64, base_channels=16, stride=2).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# Decode every curve point and keep one binarised channel (0.15 threshold) per point.
recon_channel = []
with torch.no_grad():
    for vec in curve_points:
        latent = torch.tensor(vec, dtype=torch.float32).to(device).unsqueeze(0)
        recon = model.decode(latent).cpu().squeeze(0).numpy()
        recon = np.clip(recon, 1e-10, 1 - 1e-10)
        recon_binary = (recon >= 0.15).astype(np.uint8)
        recon_channel.append(recon_binary[0])   # nucleus channel
        # recon_channel.append(recon_binary[1]) # cell channel

def volume_to_mesh(volume, level=0.5):
    """Extract the iso-surface of `volume` at `level` via marching cubes and wrap it in a teal, semi-transparent Plotly Mesh3d."""
    vertices, triangles, _normals, _values = measure.marching_cubes(volume, level=level)
    # Columns of `vertices` are the three coordinates; columns of `triangles` index into them.
    vx, vy, vz = vertices.T
    ti, tj, tk = triangles.T
    return go.Mesh3d(
        x=vx, y=vy, z=vz,
        i=ti, j=tj, k=tk,
        color='teal',
        opacity=0.6,
        name="mesh"
    )

# One animation frame per curve point, each holding the surface mesh of that reconstruction.
frames = []
for i, volume in enumerate(recon_channel):
    mesh = volume_to_mesh(volume, level=0.5)
    frame = go.Frame(data=[mesh], name=str(i))
    frames.append(frame)

# The figure's initial trace is the mesh of the first curve point.
init_mesh = volume_to_mesh(recon_channel[0], level=0.5)   # nucleus channel
# init_mesh = volume_to_mesh(recon_channel[1], level=0.5)  # cell channel

fig = go.Figure(
    data=[init_mesh],
    layout=go.Layout(
        title="3D Cell Nucleus Morphing Along Principal Curve (Mesh Surface)",
        scene=dict(
            xaxis=dict(range=[0, 64], title='X'),
            yaxis=dict(range=[0, 64], title='Y'),
            zaxis=dict(range=[0, 64], title='Z'),
            aspectmode='cube'
        ),
        # Play / Pause buttons driving the frame animation.
        updatemenus=[dict(
            type='buttons',
            buttons=[
                dict(label='▶️ Play', method='animate',
                     args=[None, {"frame": {"duration": 300, "redraw": True},
                                  "fromcurrent": True, "loop": True}]),
                dict(label='⏸ Pause', method='animate',
                     args=[[None], {"frame": {"duration": 0, "redraw": False},
                                    "mode": "immediate",
                                    "transition": {"duration": 0}}])
            ]
        )]
    ),
    frames=frames
)

fig.show()

# NOTE(review): write_html does not create directories; result_data/figures must already exist.
fig.write_html("result_data/figures/nucleus_morphing_principal_curve_mesh.html")
# fig.write_html("result_data/figures/cell_morphing_principal_curve_mesh.html")

In [None]:
# Creates 3D animation along principal curve with consideration of scaling factors, changing axis-scale zooms as needed.
# Note that the 3D images were provided with higher single-layer pixels than number of layers, so depth axis (z)
# is also scaled by average ratio of axes to be compared properly against the other dimensions


def get_ratio(metadata_path="result_data/cell_data_64_balanced.json"):
    """Return mean((size_cell_2 + size_cell_3) / 2) divided by mean(size_cell_1).

    Parameters
    ----------
    metadata_path : path to a JSON list of dicts. Only entries that contain all
        of 'size_cell_1', 'size_cell_2' and 'size_cell_3' are used. Defaults to
        the balanced metadata file, so existing `get_ratio()` calls are unchanged.

    Returns
    -------
    float : the ratio (also printed), or inf when the mean of size_cell_1 is 0.
    """
    with open(metadata_path, "r") as f:
        metadata = json.load(f)

    required_keys = ("size_cell_1", "size_cell_2", "size_cell_3")
    size_cell_1_vals = []
    size_cell_23_vals = []
    for entry in metadata:
        # Skip entries missing any of the three size fields.
        if all(key in entry for key in required_keys):
            size_cell_1_vals.append(entry["size_cell_1"])
            size_cell_23_vals.append((entry["size_cell_2"] + entry["size_cell_3"]) / 2.0)

    avg_s1 = np.mean(size_cell_1_vals)
    avg_s23 = np.mean(size_cell_23_vals)
    ratio = avg_s23 / avg_s1 if avg_s1 != 0 else float("inf")
    print(f"Ratio (avg_s23 / avg_s1): {ratio:.4f}")
    return ratio

# Ratio of the mean of (size_cell_2 + size_cell_3) / 2 to the mean of size_cell_1; applied to the z aspect below.
ratio = get_ratio()

# The augmented curve's points are longer than 64 entries: the first 64 go to the decoder,
# the following entries are read as scale factors (see the slicing in the loop).
with open('result_data/augmented_latent_principal_curve_64D_100pts.json', 'r') as f:
    curve_points = json.load(f)
    
model_path = 'result_data/dual_branch_vae_64_balanced.pth'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# DualBranchVAE is defined elsewhere in the notebook.
model = DualBranchVAE(latent_dim=64, base_channels=16, stride=2).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

recon_channel = []
scale_vectors = [] # holds (sz, sy, sx)

with torch.no_grad():
    for vec in curve_points:
        vec = np.array(vec, dtype=np.float32)

        # Entries 0-63 are decoded; entries 64-66 are the scale for the channel shown here.
        latent64 = torch.tensor(vec[:64], dtype=torch.float32).to(device).unsqueeze(0)
        scale_vec = vec[64:67]
        #scale_vec = vec[67:70]  # for cell channel

        recon = model.decode(latent64).cpu().squeeze(0).numpy()
        recon = np.clip(recon, 1e-10, 1 - 1e-10)
        recon_binary = (recon >= 0.15).astype(np.uint8)
        volume = recon_binary[0]   
        # volume = recon_binary[1]    # for cell channel

        recon_channel.append(volume)
        # NOTE(review): scale_norm is computed but not used anywhere else in this cell;
        # the stored value is the raw scale_vec multiplied by 0.07.
        scale_norm = 1 + 0.5 * np.tanh(scale_vec)  
        scale_vectors.append(scale_vec*0.07)


# One frame per curve point; each frame also overrides the scene aspect ratio with that
# point's scale, with z additionally multiplied by `ratio`.
frames = []
for i, (volume, scale) in enumerate(zip(recon_channel, scale_vectors)):
    x, y, z = np.where(volume == 1)
    scatter = go.Scatter3d(
        x=x, y=y, z=z,
        mode='markers',
        marker=dict(size=2, color=z, colorscale='Viridis', opacity=0.8),
        name=f"Frame {i}"
    )

    frame = go.Frame(
        data=[scatter],
        name=str(i),
        layout=go.Layout(
            scene=dict(
                aspectratio=dict(z = ratio * scale[0], y=scale[1], x=scale[2])
            )
        )
    )
    frames.append(frame)

# The figure's initial trace shows the first curve point.
init_x, init_y, init_z = np.where(recon_channel[0] == 1)
# init_x, init_y, init_z = np.where(recon_channel[1] == 1)   # for cell channel
fig = go.Figure(
    data=[go.Scatter3d(
        x=init_x, y=init_y, z=init_z,
        mode='markers',
        marker=dict(size=2, color=init_z, colorscale='Viridis', opacity=0.8)
    )],
    layout=go.Layout(
        title="3D Cell Nucleus Morphing with Latent-Derived Scaling",
        scene=dict(
            xaxis=dict(range=[0, 64], title='X', showgrid=True),
            yaxis=dict(range=[0, 64], title='Y', showgrid=True),
            zaxis=dict(range=[0, 64], title='Z', showgrid=True),
            # Initial aspect ratio must belong to the same curve point as the initial trace
            # (index 0, matching frame 0). It previously used index 1, which mismatched the
            # displayed volume and would raise IndexError for a single-point curve.
            aspectratio=dict(
                z=ratio*scale_vectors[0][0],
                y=scale_vectors[0][1],
                x=scale_vectors[0][2]
            )
        ),
        # Play / Pause buttons driving the frame animation.
        updatemenus=[dict(
            type='buttons',
            buttons=[
                dict(label='▶️ Play', method='animate',
                     args=[None, {"frame": {"duration": 300, "redraw": True},
                                  "fromcurrent": True, "loop": True}]),
                dict(label='⏸ Pause', method='animate',
                     args=[[None], {"frame": {"duration": 0, "redraw": False},
                                    "mode": "immediate", "transition": {"duration": 0}}])
            ]
        )]
    ),
    frames=frames
)

fig.show()
fig.write_html("result_data/figures/nucleus_morphing_principal_curve_scaled.html")
# fig.write_html("result_data/figures/cell_morphing_principal_curve_scaled.html")

In [None]:
# Visualize central 2D slices of reconstructed points along the entire principal curve


curve_path = 'result_data/latent_principal_curve_64D_100pts.json'
model_path = 'result_data/dual_branch_vae_64_balanced.pth'

# Load the saved principal curve (a list of latent vectors).
with open(curve_path, 'r') as f:
    curve_points = json.load(f)

# DualBranchVAE is defined elsewhere in the notebook.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = DualBranchVAE(latent_dim=64, base_channels=16, stride=2).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# One two-panel figure per curve point (so as many figures as there are points on the curve).
with torch.no_grad():
    for i, vec in enumerate(curve_points):
        latent = torch.tensor(vec, dtype=torch.float32).to(device).unsqueeze(0)
        recon = model.decode(latent).cpu().squeeze(0).numpy()

        # Binarise the decoder output at a 0.15 threshold.
        recon = np.clip(recon, 1e-10, 1 - 1e-10)
        recon_binary = (recon >= 0.15).astype(np.uint8)

        # Slice at index 32 of the second spatial axis; channel 0 is titled nucleus, channel 1 cell body.
        fig, axs = plt.subplots(1, 2, figsize=(6, 3))
        axs[0].imshow(recon_binary[0, :, 32, :], cmap='gray')
        axs[0].set_title(f"Nucleus – Point {i+1}")
        axs[0].axis('off')

        axs[1].imshow(recon_binary[1, :, 32, :], cmap='gray')
        axs[1].set_title(f"Cell Body – Point {i+1}")
        axs[1].axis('off')

        plt.suptitle(f"Principal Curve Reconstruction – Point {i+1}", fontsize=14)
        plt.tight_layout()
        plt.show()

In [None]:
# Visualize principal curve UMAP in 3D with a randomly sampled stochastic path


latent_path = 'result_data/augmented_latent_space_full_64_balanced.json'
metadata_path = 'result_data/cell_data_64_balanced.json'
curve_path = 'result_data/augmented_latent_principal_curve_64D_100pts.json'
# NOTE(review): this augmented deviated file must already exist on disk -- confirm which cell writes it.
curve_path_deviated = 'result_data/augmented_latent_principal_curve_64D_100pts_deviated.json'

with open(latent_path, 'r') as f:
    latent_data = json.load(f)

with open(metadata_path, 'r') as f:
    metadata = json.load(f)

with open(curve_path, 'r') as f:
    principal_curve = json.load(f)

with open(curve_path_deviated, 'r') as f:
    deviated_curve = json.load(f)

# Collect latents that have a label, keeping labels in the same order as the vectors.
id_to_label = {entry["id"]: entry["label"] for entry in metadata}
label_to_latents = defaultdict(list)
all_latents = []
all_labels = []

for id_, vec in latent_data.items():
    if id_ in id_to_label:
        label = id_to_label[id_]
        label_to_latents[label].append(vec)
        all_latents.append(vec)
        all_labels.append(label)

# Data, principal curve and deviated curve are embedded together in a single UMAP fit
# (list concatenation, in that order) so they share one coordinate system.
combined = np.array(all_latents + principal_curve + deviated_curve)

umap_model = umap.UMAP(n_components=3, random_state=42)
embedding = umap_model.fit_transform(combined)

n_data = len(all_latents)
n_principal = len(principal_curve)
n_deviated = len(deviated_curve)

# Split the joint embedding back into its three parts using the concatenation order.
data_points = embedding[:n_data]
curve_embedded = embedding[n_data:n_data + n_principal]
deviated_embedded = embedding[n_data + n_principal:]

fig = go.Figure()

label_colors = {
    'M0': 'blue',
    'M1M2': 'green',
    'M3': 'orange',
    'M4M5': 'purple',
    'M6M7_early': 'brown',
    'M6M7_half': 'pink'
}

# One semi-transparent scatter trace per label; labels missing from label_colors are not drawn.
for label, color in label_colors.items():
    indices = [i for i, l in enumerate(all_labels) if l == label]
    if not indices:
        continue
    points = data_points[indices]
    fig.add_trace(go.Scatter3d(
        x=points[:, 0], y=points[:, 1], z=points[:, 2],
        mode='markers',
        marker=dict(size=3, color=color, opacity=0.3),
        name=label,
        text=[label]*len(points),
        hoverinfo='text'
    ))

# Principal curve (green) drawn as a connected line through its embedded points.
fig.add_trace(go.Scatter3d(
    x=[p[0] for p in curve_embedded],
    y=[p[1] for p in curve_embedded],
    z=[p[2] for p in curve_embedded],
    mode='lines+markers',
    line=dict(color='green', width=4),
    marker=dict(size=5, color='green'),
    name='Principal Curve',
    text=[f"Curve Point {i+1}" for i in range(len(curve_embedded))],
    hoverinfo='text'
))

# Deviated (stochastic) curve in red.
fig.add_trace(go.Scatter3d(
    x=[p[0] for p in deviated_embedded],
    y=[p[1] for p in deviated_embedded],
    z=[p[2] for p in deviated_embedded],
    mode='lines+markers',
    line=dict(color='red', width=4),
    marker=dict(size=5, color='red'),
    name='Deviated Curve',
    text=[f"Deviated Point {i+1}" for i in range(len(deviated_embedded))],
    hoverinfo='text'
))

fig.update_layout(
    title="3D UMAP: Principal Curve Through Cell Cycle Latent Space",
    scene=dict(
        xaxis_title='UMAP-1',
        yaxis_title='UMAP-2',
        zaxis_title='UMAP-3'
    ),
    legend=dict(x=0.02, y=0.98),
    margin=dict(l=0, r=0, b=0, t=40)
)

fig.show()
fig.write_html("result_data/figures/augmented_principal_curve_3D_UMAP.html")

In [None]:
# Visualize UMAP in 3D with tube of specified variance surrounding principal curve,
# Requires quantiles of projection distances to be calculated


latent_path = 'result_data/augmented_latent_space_full_64_balanced.json'
metadata_path = 'result_data/cell_data_64_balanced.json'
curve_path = 'result_data/augmented_latent_principal_curve_64D_100pts.json'

tube_percentile = 50          # e.g. 50 => tube covers ~50% of points per section
percentiles_to_report = [25, 50, 75, 90, 95]
n_chunks = 20                 # number of sections along curve
n_theta = 20                  # number of vertices around each tube ring (passed to make_tube_mesh)
scaling_multiplier = 0.5      # extra factor applied to the tube radii after conversion to UMAP units
random_state = 42

with open(latent_path, 'r') as f:
    latent_data_dict = json.load(f)

with open(metadata_path, 'r') as f:
    metadata = json.load(f)

with open(curve_path, 'r') as f:
    principal_curve = np.array(json.load(f))

id_to_label = {entry['id']: entry['label'] for entry in metadata}
valid_labels = ['M0','M1M2','M3','M4M5','M6M7_early','M6M7_half']

# Keep every latent that has a metadata label (valid_labels is not used as a filter here).
all_ids = []
all_latents = []
all_labels = []
for id_, vec in latent_data_dict.items():
    lbl = id_to_label.get(id_)
    if lbl is None:
        continue
    all_ids.append(id_)
    all_latents.append(np.array(vec, dtype=float))
    all_labels.append(lbl)

all_latents = np.vstack(all_latents)
principal_curve = np.asarray(principal_curve, dtype=float)
n_data, D = all_latents.shape
n_curve = principal_curve.shape[0]

print(f"Loaded {n_data} data points, latent dim = {D}, curve points = {n_curve}")

# Embed data and curve together so they share one 3-D UMAP coordinate system.
combined = np.vstack((all_latents, principal_curve))
um = umap.UMAP(n_components=3, random_state=random_state)
embedding = um.fit_transform(combined)
data_emb = embedding[:n_data]
curve_emb = embedding[n_data:]

# Pairwise Euclidean distances in the original space: (n_data, n_curve).
# This materialises an (n_data, n_curve, D) intermediate array.
diff = all_latents[:, None, :] - principal_curve[None, :, :]
distances_high = np.linalg.norm(diff, axis=2)

# For each data point: index of, and distance to, its nearest curve point.
nearest_idx = np.argmin(distances_high, axis=1)
d_high_min = distances_high[np.arange(n_data), nearest_idx]

# Distance in UMAP space to that same curve point (nearest in the original space).
d_low_min = np.linalg.norm(data_emb - curve_emb[nearest_idx], axis=1)

# Median ratio of embedded to original distance, used to convert radii into UMAP units.
eps = 1e-9
ratios = d_low_min / (d_high_min + eps)
ratios_clean = ratios[np.isfinite(ratios)]
if ratios_clean.size == 0:
    scale_ratio = 1.0
else:
    scale_ratio = float(np.median(ratios_clean))
print(f"Scale ratio (median low/high) used to convert radii -> UMAP units: {scale_ratio:.6g}")

# Per-curve-point radius: the tube_percentile of distances of the points assigned to it,
# falling back to the global percentile when no point is assigned.
radii_high = np.zeros(n_curve, dtype=float)
global_percentile_default = np.percentile(d_high_min, tube_percentile)
for k in range(n_curve):
    mask_k = (nearest_idx == k)
    if np.any(mask_k):
        radii_high[k] = np.percentile(d_high_min[mask_k], tube_percentile)
    else:
        radii_high[k] = global_percentile_default

radii_umb = radii_high * scale_ratio * scaling_multiplier

# Report global distance percentiles, also relative to the median latent-vector norm.
vec_norms = np.linalg.norm(all_latents, axis=1)
typical_size = np.median(vec_norms)
print("\nGlobal projection distance percentiles (HIGH-D) and relative to median vector norm:")
for p in percentiles_to_report:
    val = np.percentile(d_high_min, p)
    print(f"  {p:2d}th pct: {val:.6g}   (relative = {val/typical_size:.6g} of median vector norm)")

# Split the curve indices into n_chunks contiguous sections and average the distances of
# the data points assigned to each section.
indices_per_chunk = np.array_split(np.arange(n_curve), n_chunks)
chunk_mean_high = []
chunk_count = []
for chunk_idx, inds in enumerate(indices_per_chunk):
    assigned_mask = np.isin(nearest_idx, inds)
    vals = d_high_min[assigned_mask]
    chunk_count.append(vals.size)
    if vals.size:
        chunk_mean_high.append(float(np.mean(vals)))
    else:
        chunk_mean_high.append(float(np.nan))

print(f"\nChunk averages (HIGH-D) over {n_chunks} chunks (NaN => no assigned points):")
for ci, (cnt, meanv) in enumerate(zip(chunk_count, chunk_mean_high)):
    print(f"  chunk {ci+1:2d}: count={cnt:4d}, mean_dist={np.nan if np.isnan(meanv) else round(meanv,6)}")

def make_tube_mesh(curve_pts, radii, n_theta=20):
    """Build a triangulated tube of varying radius around a 3-D polyline.

    A ring of ``n_theta`` vertices is placed around every curve point, in the
    plane perpendicular to the local tangent, and consecutive rings are joined
    by two triangles per quad. The tube has no end caps.

    Parameters
    ----------
    curve_pts : array-like, shape (n, 3)
        Ordered points of the curve. Lists are accepted.
    radii : array-like of length >= n, or scalar
        Ring radius at each curve point. A scalar is used for every point.
    n_theta : int, default 20
        Number of vertices per ring.

    Returns
    -------
    verts : ndarray, shape (n * n_theta, 3)
        Ring vertices, ring after ring in curve order.
    (tri_i, tri_j, tri_k) : tuple of int ndarrays, each of length 2 * (n - 1) * n_theta
        Vertex indices of the triangle corners.

    Raises
    ------
    ValueError
        If a non-empty ``curve_pts`` is not of shape (n, 3).

    Notes
    -----
    Edge cases: an empty curve yields a ``(0, 3)`` vertex array and empty index
    arrays; a single-point curve yields one ring (oriented with the default
    x-axis tangent) and no triangles.

    Each ring's frame is built independently from a fixed reference axis
    (z, or y when the tangent is nearly parallel to z), so the frame can jump
    between neighbouring rings where that reference switches.
    """
    curve = np.asarray(curve_pts)
    n = len(curve)
    if n == 0:
        no_faces = np.zeros(0, dtype=int)
        return np.zeros((0, 3), dtype=float), (no_faces, no_faces.copy(), no_faces.copy())
    if curve.ndim != 2 or curve.shape[1] != 3:
        raise ValueError(f"curve_pts must have shape (n, 3), got {curve.shape}")

    radii = np.asarray(radii, dtype=float)
    if radii.ndim == 0:
        radii = np.full(n, float(radii))

    thetas = np.linspace(0, 2 * np.pi, n_theta, endpoint=False)
    ring_cos = np.cos(thetas)[:, None]
    ring_sin = np.sin(thetas)[:, None]

    verts = np.empty((n * n_theta, 3), dtype=float)
    for i in range(n):
        # Tangent: one-sided difference at the ends, central difference inside.
        if n == 1:
            tangent = np.zeros(3)  # no neighbour; falls through to the default axis
        elif i == 0:
            tangent = curve[1] - curve[0]
        elif i == n - 1:
            tangent = curve[-1] - curve[-2]
        else:
            tangent = curve[i + 1] - curve[i - 1]
        tangent = tangent.astype(float)
        tn = np.linalg.norm(tangent)
        if tn < 1e-8:
            # Degenerate (coincident neighbours): use an arbitrary fixed direction.
            tangent = np.array([1.0, 0.0, 0.0])
            tn = 1.0
        tangent = tangent / tn

        # Reference axis that is safely non-parallel to the tangent.
        arb = np.array([0.0, 0.0, 1.0])
        if abs(np.dot(arb, tangent)) > 0.9:
            arb = np.array([0.0, 1.0, 0.0])

        # Orthonormal basis (u, v) of the plane perpendicular to the tangent.
        u = np.cross(tangent, arb)
        u_norm = np.linalg.norm(u)
        if u_norm < 1e-8:
            u = np.array([1.0, 0.0, 0.0])
            u_norm = 1.0
        u = u / u_norm
        v = np.cross(tangent, u)
        v = v / np.linalg.norm(v)

        # All n_theta ring vertices at once.
        ring = curve[i] + radii[i] * (u * ring_cos + v * ring_sin)
        verts[i * n_theta:(i + 1) * n_theta] = ring

    # Quad (a, b, c, d) between ring i and ring i+1, split into (a, b, c) and (a, c, d).
    around = np.arange(n_theta)
    around_next = (around + 1) % n_theta  # wraps the last vertex back to the first
    ring_start = np.arange(n - 1)[:, None] * n_theta
    a = ring_start + around
    b = ring_start + around_next
    c = ring_start + n_theta + around_next
    d = ring_start + n_theta + around
    tri_i = np.stack([a, a], axis=-1).ravel()
    tri_j = np.stack([b, c], axis=-1).ravel()
    tri_k = np.stack([c, d], axis=-1).ravel()
    return verts, (tri_i, tri_j, tri_k)

# Tube surface around the embedded curve, using the radii converted to embedding units.
verts, tube_faces = make_tube_mesh(curve_emb, radii_umb, n_theta=n_theta)
tri_i, tri_j, tri_k = tube_faces
print(f"\nTube mesh: vertices={verts.shape[0]}, triangles={tri_i.shape[0]}")

# Fixed marker colour per class label; labels absent from this mapping are not plotted.
label_colors = {
    'M0': 'blue',
    'M1M2': 'green',
    'M3': 'orange',
    'M4M5': 'purple',
    'M6M7_early': 'brown',
    'M6M7_half': 'pink'
}

fig = go.Figure()

# One scatter trace per label so each class gets its own legend entry.
for label, color in label_colors.items():
    idxs = [i for i, sample_label in enumerate(all_labels) if sample_label == label]
    if not idxs:
        continue
    pts = data_emb[idxs]
    fig.add_trace(go.Scatter3d(
        x=pts[:, 0], y=pts[:, 1], z=pts[:, 2],
        mode='markers', marker=dict(size=3, color=color, opacity=0.3),
        name=label, text=[label] * len(pts), hoverinfo='text'
    ))

# Embedded principal curve.
fig.add_trace(go.Scatter3d(
    x=curve_emb[:, 0], y=curve_emb[:, 1], z=curve_emb[:, 2],
    mode='lines+markers', line=dict(color='red', width=3),
    marker=dict(size=4, color='red'), name='Principal Curve'
))

# Variable-radius tube. Mesh3d traces are hidden from the legend by default,
# so showlegend must be set explicitly for the name to appear there.
fig.add_trace(go.Mesh3d(
    x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
    i=tri_i, j=tri_j, k=tri_k,
    color='rgba(173,216,230,0.6)', opacity=0.40, name=f'Tube ({tube_percentile}th pct)',
    showlegend=True
))

fig.update_layout(
    title="3D UMAP with Principal Curve + Variable-Radius Tube",
    scene=dict(xaxis_title='UMAP-1', yaxis_title='UMAP-2', zaxis_title='UMAP-3'),
    margin=dict(l=0, r=0, b=0, t=40)
)

fig.show()

# Make sure the output folder exists so the export does not fail on a fresh checkout.
tube_html_path = Path("result_data/figures/augmented_principal_curve_3D_UMAP_Tube.html")
tube_html_path.parent.mkdir(parents=True, exist_ok=True)
fig.write_html(str(tube_html_path))

## Ciao!