# Animation of 3D models generated by spateo

In [1]:
import io, warnings, os
from pathlib import Path
import base64
from IPython.display import HTML

import matplotlib as mpl
import numpy as np
import spateo as st
warnings.filterwarnings('ignore')

## Load the data

In [2]:
# Work from the project root so that every relative path below resolves against it,
# then make sure the output folder for the rendered animations exists.
os.chdir("/media/pc001/Yao/Projects/Project_spateo/mouse_heart")
out_image_path = "animations/mouse_heart_models"
os.makedirs(out_image_path, exist_ok=True)

In [3]:
# NOTE(review): `cpo` is not referenced by any other cell visible in this notebook
# (the `three_d_animate` call passes the literal cpo="yz") -- confirm whether it is
# still needed or whether the animation was meant to use it.
cpo = "xz"
 
# The sample id is interpolated into both the directory name and the file name below.
sample_id = "E9.5"
heart_adata = st.read_h5ad(f"updated_data/{sample_id}_h5ad/mouse_{sample_id}_heart_morphogenesis.h5ad")

In [4]:
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9, so look the
# colormap up in the `matplotlib.colormaps` registry and only fall back to the old accessor
# on releases that predate the registry.
# NOTE(review): "vlag_r" is not a built-in Matplotlib colormap -- presumably it is registered
# by seaborn, pulled in through the spateo import; confirm.
cmap_name = "vlag_r"  # alternatives: gist_rainbow, rainbow, hsv, tab20
lscmap = mpl.colormaps[cmap_name] if hasattr(mpl, "colormaps") else mpl.cm.get_cmap(cmap_name)

# One evenly spaced colour of the colormap per heart region, stored as hex strings.
regions = ["Left ventricle", "Right ventricle", "Outflow tract", "Right atrium", "Left atrium"]
regions_hex_list = [mpl.colors.to_hex(lscmap(i)) for i in np.linspace(0, 1, len(regions))]
regions_colors = dict(zip(regions, regions_hex_list))

# Point-cloud model built from the aligned 3D coordinates, grouped by `heart_regions`
# and coloured with the region -> hex mapping above.
heart_pc, _ = st.tdr.construct_pc(
    adata=heart_adata,
    spatial_key="3d_align_spatial",
    groupby="heart_regions",
    key_added="heart_regions",
    colormap=regions_colors,
)

## Animation of trajectory model

In [5]:
# Run spateo's `morphofield_gp` on the aligned 3D coordinates. The result is keyed by
# "VecFld_morpho"; later cells read it back from `heart_adata.uns["VecFld_morpho"]`.
# NX is given the same aligned coordinates that `spatial_key` points to.
# NOTE(review): `inplace=True` presumably writes the result onto `heart_adata` -- confirm
# against the spateo documentation.
st.tdr.morphofield_gp(
    adata=heart_adata,
    spatial_key="3d_align_spatial",
    vf_key="VecFld_morpho",
    NX=np.asarray(heart_adata.obsm['3d_align_spatial']),
    inplace=True,
)

In [6]:
# Run spateo's `morphopath` on the "VecFld_morpho" vector field. The output key
# "fate_morpho" is the same key the following cells pass as `fate_key`.
# NOTE(review): `t_end` and `interpolation_num` presumably control the integration
# horizon and the number of interpolated time points -- confirm in the spateo docs.
st.tdr.morphopath(
    adata=heart_adata,
    vf_key="VecFld_morpho",
    key_added="fate_morpho",
    t_end=2000,
    interpolation_num=10,
    cores=1,
)

In [7]:
# Build the trajectory model from the "fate_morpho" prediction. Each sampled point is
# labelled with its observation name so it can be mapped back to `heart_adata` below.
cell_ids = np.asarray(heart_adata.obs.index)
trajectory_model, _ = st.tdr.construct_trajectory(
    adata=heart_adata,
    fate_key="fate_morpho",
    n_sampling=1500,
    sampling_method="trn",
    key_added="obs_index",
    label=cell_ids,
)

# Store the third column of the vector-field array `V` as a per-cell observation.
morpho_velocity = heart_adata.uns["VecFld_morpho"]["V"]
heart_adata.obs["V_z"] = morpho_velocity[:, 2].flatten()

# Look up `V_z` for exactly the cells kept in the model (via their stored observation
# names) and attach it to the model's point data for colouring.
sampled_ids = np.asarray(trajectory_model.point_data["obs_index"])
sampled_v_z = np.asarray(heart_adata[sampled_ids].obs["V_z"])
st.tdr.add_model_labels(
    model=trajectory_model,
    key_added="V_z",
    labels=sampled_v_z,
    colormap="Spectral",
    where="point_data",
    inplace=True,
)

In [11]:
from anndata import AnnData
from typing import Optional, Union, Tuple
def construct_genesis(
    adata: "AnnData",
    fate_key: str = "fate_morpho",
    n_steps: int = 100,
    logspace: bool = False,
    t_end: Optional[Union[int, float]] = None,
    key_added: str = "genesis",
    label: Optional[Union[str, list, np.ndarray]] = None,
    color: Union[str, list, dict] = "skyblue",
    alpha: Union[float, list, dict] = 1.0,
) -> tuple:
    """
    Reconstruct a cell-level developmental change model from the cell fate prediction results.

    The initial cell states stored under ``adata.uns[fate_key]`` are pushed through the vector field stored under
    ``adata.uns["VecFld_" + fate_key[5:]]`` and the resulting coordinates of every frame are handed to
    ``construct_genesis_X``.

    Args:
        adata: AnnData object that contains the fate prediction in the ``.uns`` attribute.
        fate_key: The key under which the fate prediction is stored. The first five characters are stripped to derive
                  the vector field key, so the key is expected to look like ``"fate_<name>"``.
        n_steps: The number of frames (time points) to generate.
        logspace: Whether to sample the time points evenly in log space between 0 and the largest predicted time.
                  If not, the sorted unique set of all time points from all cell states' fate prediction (plus 0) is
                  evenly sub-sampled to ``n_steps`` time points.
        t_end: If given, time points larger than ``t_end`` are discarded before sampling.
        key_added: The key under which to add the labels.
        label: The label of cell developmental change model, passed through to ``construct_genesis_X``.
        color: Color to use for plotting model.
        alpha: The opacity of the color to use for plotting model.

    Returns:
        A tuple ``(cells_developmental_model, plot_cmap)`` exactly as returned by ``construct_genesis_X``: the model
        containing the cells of all frames and the recommended colormap value for plotting.

    Raises:
        Exception: If ``fate_key`` is not present in ``adata.uns``.
    """

    # Validate the cheap precondition first so that a missing key is reported immediately,
    # before the heavy optional dependencies are imported.
    if fate_key not in adata.uns_keys():
        raise Exception(
            f"You need to first perform develop_trajectory prediction before animate the prediction, please run "
            f"st.tdr.develop_trajectory(adata, key_added='{fate_key}') before running this function"
        )

    from dynamo.vectorfield import SvcVectorField
    from scipy.integrate import odeint
    from spateo.tdr.models.models_migration.morphopath_model import construct_genesis_X
    from spateo.tdr.morphometrics.morphofield_dg.GPVectorField import GPVectorField

    # Collect every predicted time point of every cell from the *requested* fate key.
    # np.unique sorts and de-duplicates, so the per-cell ordering is irrelevant, and no
    # intermediate array of the per-cell lists is built (that conversion fails on
    # NumPy >= 1.24 when the lists have different lengths).
    per_cell_times = adata.uns[fate_key]["t"].values()
    flats = np.unique([int(item) for sublist in per_cell_times for item in sublist])
    flats = np.hstack((0, flats))
    flats = np.sort(flats) if t_end is None else np.sort(flats[flats <= t_end])
    time_vec = (
        np.logspace(0, np.log10(max(flats) + 1), n_steps) - 1
        if logspace
        else flats[(np.linspace(0, len(flats) - 1, n_steps)).astype(int)]
    )

    # The vector field lives under the same suffix as the fate key ("fate_xxx" -> "VecFld_xxx").
    basis = fate_key[5:]
    vf_key = "VecFld_" + basis
    if adata.uns[vf_key]["method"] == "gaussian_process":
        vf = GPVectorField()
        vf.from_adata(adata, vf_key=vf_key)

        def velocity(x, _t):
            # Only the first element of what the GP field returns is used as the velocity.
            return vf.func(x)[0]

    else:
        vf = SvcVectorField()
        vf.from_adata(adata, basis=basis)

        def velocity(x, _t):
            return vf.func(x)

    def displace(x, dt):
        # Integrate one point from time 0 to dt; row 1 of the result is the end position.
        return odeint(velocity, x, [0, dt])

    init_states = adata.uns[fate_key]["init_states"]
    pts = [state.tolist() for state in init_states]
    stages_X = []
    # NOTE(review): each frame advances the *previous* frame's positions by the absolute
    # value `time_vec[i]`, so the total integrated time at frame i is sum(time_vec[:i + 1])
    # rather than time_vec[i]. Kept as-is to preserve existing output; confirm whether the
    # increment between consecutive time points was intended instead.
    for step_time in time_vec:
        pts = [displace(cur_pts, step_time)[1].tolist() for cur_pts in pts]
        stages_X.append(np.asarray(pts))

    cells_developmental_model, plot_cmap = construct_genesis_X(
        stages_X=stages_X, n_spacing=None, key_added=key_added, label=label, color=color, alpha=alpha
    )

    return cells_developmental_model, plot_cmap

In [16]:
# One label array is supplied per frame, so the frame count is bound once and reused
# for both `n_steps` and the length of the label list.
n_frames = 100
v_z = heart_adata.uns["VecFld_morpho"]["V"][:, 2]
cells_models, _ = construct_genesis(
    adata=heart_adata,
    fate_key="fate_morpho",
    n_steps=n_frames,
    logspace=True,
    t_end=500,
    label=[v_z] * n_frames,
    color="RdBu_r",
)

In [17]:
# Render the per-frame cell models on top of the static trajectory model and write the
# result to `Heart_morphofield_model_animation.mp4` inside `out_image_path`.
# `key="genesis"` matches the default `key_added` of `construct_genesis`; the static
# trajectory model is coloured by the "V_z" labels attached to it earlier.
# NOTE(review): the camera position is the literal "yz" here, while the `cpo = "xz"`
# variable defined near the top of the notebook is not used -- confirm which is intended.
st.pl.three_d_animate(
    models=cells_models,
    stable_model=trajectory_model,
    key="genesis",
    stable_kwargs=dict(
        key="V_z",
        model_style="wireframe",
        model_size=5,
        opacity=0.5,
        colormap="RdBu_r",
        show_legend=False,
    ),
    filename=os.path.join(out_image_path, f"Heart_morphofield_model_animation.mp4"),
    colormap="RdBu_r",
    model_style="points",
    model_size=12,
    jupyter="static",
    background="white",
    text=f"\nPredicted migration trajectories of Heart (Mouse E9.5)",
    text_kwargs={"font_size":20},
    window_size=(2560, 2048),
    cpo="yz",
    framerate=6)

## Merge all animations

In [18]:
# Combine the listed mp4 files into a single file in the same output folder.
# NOTE(review): `Heart_mapping_model_animation.mp4` is not produced by any cell visible in
# this notebook -- presumably it comes from a separate notebook; make sure it exists in
# `out_image_path` before running this cell.
st.pl.merge_animations(
    mp4_files=[
        os.path.join(out_image_path, f"Heart_mapping_model_animation.mp4"),
        os.path.join(out_image_path, f"Heart_morphofield_model_animation.mp4"),
    ],
    filename=os.path.join(out_image_path, f"Mouse_Heart_organgenesis_animation.mp4"),
)