In [1]:
# Attention: ACDC data must be obtained and put into the folder 'ACDC-Data' to run this notebook

import os
import pickle

import numpy as np
import matplotlib.pyplot as plt
import joblib
import jax.numpy as jnp
from ipywidgets import interact

from morphomatics.manifold import Kendall
from morphomatics.manifold import Bezierfold
from morphomatics.stats import RiemannianRegression

In [6]:
# Regression hyper-parameters shared by every fitted trend (and by the Bezierfold cell):
# number of spline segments and the polynomial degree of each segment.
n_segments = 2
degree = 3

def reg(Y, t):
    """Fit a closed (cyclic) Riemannian regression curve on the shape space M.

    Y -- data points on M, one per sample time
    t -- sample times of the points in Y
    Returns the `trend` of the fitted RiemannianRegression. Note that M is looked up
    in the notebook's global scope when the function is called, not when it is defined.
    """
    fitted = RiemannianRegression(M, Y, t, degree, n_segments, iscycle=True)
    return fitted.trend

# Kendall shape space for 263-by-2 point configurations (the contours are n-by-2 arrays).
M = Kendall([263, 2])

# Input contour files and output trend files per patient group.
# Both lists use the same order (DCM, HCM, MINF, NOR), so index j refers to the same group in each.
group_list = [
    "ACDC-Data/contours_registered_DCM.joblib",
    "ACDC-Data/contours_registered_HCM.joblib",
    "ACDC-Data/contours_registered_MINF.joblib",
    "ACDC-Data/contours_registered_NOR.joblib",
]
pickle_list = [
    "trends/trends_DCM.pkl",
    "trends/trends_HCM.pkl",
    "trends/trends_MINF.pkl",
    "trends/trends_NOR.pkl",
]

# Fit one cyclic regression trend per subject and save each group's trends to disk.
surf_full = []  # surf_full[g][k][l]: l-th raw frame of the k-th subject in group g
for j, group in enumerate(group_list):
    contours = joblib.load(group)
    contours = list(contours.values())

    # read in shapes (encoded as n-by-2 array)
    surfs = [np.array(c) for c in contours] # <- surfs[k][l] holds l-th frame of k-th subject
    surf_full.append(surfs)

    # Map every subject to shape space exactly once per group. Doing this inside the
    # per-subject loop re-projected already-mapped shapes and cost O(n_subjects^2)
    # conversions.
    surfs = [jnp.array([M.to_coords(s) for s in surfs_i]) for surfs_i in surfs]

    # Set corresponding times: frames are equally spaced over [0, 2). The endpoint t=2 is
    # dropped because the regression is fitted as a cycle (iscycle=True in `reg`).
    times = [jnp.linspace(0., 2, len(surfs_i)+1)[:-1] for surfs_i in surfs]

    subjecttrends = [reg(surfs_i, times_i) for surfs_i, times_i in zip(surfs, times)]

    # Create the output directory if needed, and close the file handle after writing.
    out_dir = os.path.dirname(pickle_list[j])
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(pickle_list[j], 'wb') as f:
        pickle.dump(subjecttrends, f, pickle.HIGHEST_PROTOCOL)

# Note: after this cell, `surfs` and `times` hold the mapped shapes and times of the
# LAST group in group_list only.


In [7]:
def _load_pickle(path):
    """Load one pickled object from `path` and close the file afterwards.

    Only use this on files written by this notebook: unpickling can execute arbitrary code.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


trends_dcm = _load_pickle('./trends/trends_DCM.pkl')
trends_hcm = _load_pickle('./trends/trends_HCM.pkl')
trends_minf = _load_pickle('./trends/trends_MINF.pkl')
trends_nor = _load_pickle('./trends/trends_NOR.pkl')
# Same order as group_list in the fitting cell, so group_idx indexes both consistently.
trends_by_group = [trends_dcm, trends_hcm, trends_minf, trends_nor]


group_idx = 0  # DCM = 0, HCM = 1, MINF = 2, NOR = 3
sample_idx = 0

# Build everything the plot needs from group_idx/sample_idx. The leftover `surfs`/`times`
# globals only describe whichever group the fitting loop processed last, so they are not used.
sample_shapes = jnp.array([M.to_coords(s) for s in surf_full[group_idx][sample_idx]])
sample_times = jnp.linspace(0., 2, len(sample_shapes) + 1)[:-1]  # same time grid as in fitting
sample_trend = trends_by_group[group_idx][sample_idx]

@interact
def plot(i=(0, len(sample_shapes) - 1, 1)):
    """Overlay frame i of the selected subject with its regression trend at that frame's time."""
    t = sample_times[i]
    pts = sample_trend.eval(t)
    plt.plot(*sample_shapes[i].T)
    plt.plot(*pts.T)
    plt.xlim(-0.1, 0.1)
    plt.ylim(-0.08, 0.08)
    return plt.gca()

interactive(children=(IntSlider(value=14, description='i', max=29), Output()), _dom_classes=('widget-interact'…

In [None]:
# Create manifold of Bezier splines, using the same number of segments, degree and
# closedness as the regression trends fitted above.
# NOTE(review): the keyword is spelled `isscycle` here, whereas RiemannianRegression is
# called with `iscycle` — confirm which spelling the installed morphomatics Bezierfold accepts.
B = Bezierfold(M, n_segments, degree, isscycle=True)