# Training with different topologies - Ubiquitin mutants

Run this notebook on Google Colab:

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AG-Peter/encodermap/blob/main/tutorials/notebooks_MD/devel_Ub_mutants.ipynb)

Find the documentation of EncoderMap:

https://ag-peter.github.io/encodermap

**Goals:**

In this tutorial you will learn:
- [Why different topologies can be challenging to combine into an ML model.](#problems_topologies)
- [How EncoderMap uses sparse tensors and sparse matrix multiplication.](#sparse_matrices)

### For Google Colab only:

If you're on Google Colab, please uncomment these lines and install EncoderMap.

In [None]:
# !wget https://gist.githubusercontent.com/kevinsawade/deda578a3c6f26640ae905a3557e4ed1/raw/b7403a37710cb881839186da96d4d117e50abf36/install_encodermap_google_colab.sh
# !sudo bash install_encodermap_google_colab.sh

## Import Libraries
Before we can get started using EncoderMap we first need to import the EncoderMap library:

In [None]:
import encodermap as em
import numpy as np
import tensorflow as tf  # used as `tf` in the sparse-matrix section
from pathlib import Path

%load_ext autoreload
%autoreload 2

<a id='problems_topologies'></a>

## Protein topologies

The topology of a protein can be understood as a hierarchical catalogue of chains, residues, atoms, and their connectivity via bonds, angles, and dihedrals.

- chain 0:
  - residue 0: MET1
    - atom 0: H1
    - atom 1: H2
    - atom 3: N
    - atom 4: CA
    - atom 5: CB
    - atom 6: CG
    - atom 7: SD
    - atom 8: CE
    - atom 9: C
    - atom 10: O
  - residue 1: ALA2
    - ...
  - ...
 
- Bonds:
  - H1 - N
  - H2 - N
  - N - CA
  - ...
 
- Angles:
  - N - CA - C
  - ...
 
- Dihedrals:
  - $\psi$ of MET1 (N-CA-C-N)
  - $\omega$ of MET1 (CA-C-N-CA)
  - $\phi$ of ALA2 (C-N-CA-C)
  - $\psi$ of ALA2 (N-CA-C-N)
  - ...
  - $\chi_1$ of MET1 (N-CA-CB-CG)
  - $\chi_2$ of MET1 (CA-CB-CG-SD)
  - $\chi_3$ of MET1 (CB-CG-SD-CE)
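
To make the hierarchy concrete, the following sketch stores a tiny chain as plain Python data and derives the backbone dihedral labels from it. The `Residue` dataclass and the `backbone_dihedrals` helper are hypothetical names used for illustration only; they are not part of EncoderMap. The rule used here (no $\phi$ for the first residue, no $\psi$ and $\omega$ for the last) follows from the atom quadruples listed above.

```python
from dataclasses import dataclass


@dataclass
class Residue:
    """Hypothetical container for one residue (not an EncoderMap class)."""
    name: str
    index: int  # 1-based position in the chain


def backbone_dihedrals(residues):
    """List the backbone dihedral labels that are defined for a chain."""
    labels = []
    last = len(residues) - 1
    for i, res in enumerate(residues):
        tag = f"{res.name}{res.index}"
        if i > 0:  # phi needs the C atom of the previous residue
            labels.append(f"phi {tag}")
        if i < last:  # psi and omega need atoms of the next residue
            labels.append(f"psi {tag}")
            labels.append(f"omega {tag}")
    return labels


mae_chain = [Residue("MET", 1), Residue("ALA", 2), Residue("GLU", 3)]
maeg_chain = mae_chain + [Residue("GLY", 4)]

print(backbone_dihedrals(mae_chain))
print(len(backbone_dihedrals(mae_chain)), len(backbone_dihedrals(maeg_chain)))
```

Adding a single residue already changes the number of backbone dihedrals, which is why peptides of different length cannot simply be stacked into one array.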
 
Let's have a look at four very short peptides and their topologies. The coordinate files `MAE.pdb`, `MAD.pdb`, `MAEG.pdb`, and `MAGE.pdb` represent peptides with the sequences MET-ALA-GLU, MET-ALA-ASP, MET-ALA-GLU-GLY, and MET-ALA-GLY-GLU, respectively. EncoderMap's `load` function loads them as `SingleTraj` objects:

In [None]:
output_dir = Path("/home/kevin/git/encoder_map_private/tests/data/topological_examples/")

mae = em.load(output_dir / "MAE.pdb")
mad = em.load(output_dir / "MAD.pdb")
maeg = em.load(output_dir / "MAEG.pdb")
mage = em.load(output_dir / "MAGE.pdb")

mae

EncoderMap offers a way of quickly looking at peptides with the `plot_ball_and_stick()` function from the `plot` module. We can let EncoderMap highlight atoms, bonds, angles or dihedrals.

In [None]:
# You can change the global styles of EncoderMap's plots with this variable:
# This layout is more suitable for publishing your work with EncoderMap
# GLOBAL_LAYOUT = {
#     "paper_bgcolor": 'rgba(0,0,0,0)',
#     "plot_bgcolor": 'rgba(0,0,0,0)',
#     "template": "plotly_white",
#     "scene": {
#         "xaxis": {
#             "visible": False,
#         },
#         "yaxis": {
#             "visible": False,
#         },
#         "zaxis": {
#             "visible": False,
#         },
#     },
# }
# em.plot.plotting.GLOBAL_LAYOUT = GLOBAL_LAYOUT

# reset with
em.plot.plotting.GLOBAL_LAYOUT = {}

# You can change the themes of EncoderMap's plotly plots with
# import plotly.io as pio
# pio.templates.default = "plotly_dark"

In [None]:
fig = em.plot.plot_ball_and_stick(
    mae,
    highlight="angles",
)

In [None]:
fig = em.plot.plot_ball_and_stick(
    mad,
    highlight="angles",
)

In [None]:
fig = em.plot.plot_ball_and_stick(
    maeg,
    highlight="dihedrals",
)

In [None]:
fig = em.plot.plot_ball_and_stick(
    mage,
    highlight="dihedrals",
)

From playing around with these topologies, we can deduce that all four peptides have their own unique topologies.

But somehow, they are still very similar.

**MAE and MAD differ only in the number of carbon atoms at the sidechain of glutamic acid and aspartic acid.**

This is called a point mutation. In nature, even a single residue exchange can have drastic consequences, up to impairing the function of the entire expressed protein. Diseases linked to point mutations range from cancer to neurodegenerative diseases and genetic disorders such as neurofibromatosis.

**MAE and MAEG differ only in the C-terminal tail**

Proteins often consist of a more rigid center (the globular domain, DG) flanked by flexible N- and C-terminal tails. These tails are often overlooked because they are difficult to analyze via X-ray diffraction, yet they can still interact with the environment of the protein.

**MAEG and MAGE have the same number of backbone torsions and sidechain torsions but exhibit different sequences**

These two examples can be understood as sequence homologs or members of an evolutionary protein family. Proteins in such families are similar, save for a few regions of evolutionary variety (so-called non-conserved regions).

For these reasons, it would be beneficial to compare proteins with different topologies and describe them with one unifying theory. At the start of such an endeavor we will first look at the alignment of the four peptides.

```
CLUSTAL W (1.83) multiple sequence alignment

MAE             MA-E
MAD             MA-D
MAGE            MAGE
MAEG            MAEG
                **  
```
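
The last line of the alignment marks the fully conserved columns with `*`. As a minimal sketch, this line can be recomputed from the aligned sequences in plain Python. The helper `conservation_line` is a hypothetical name, and it only checks for identical residues; the additional similarity symbols that CLUSTAL can print are ignored here.

```python
# Aligned sequences copied from the alignment; "-" denotes a gap.
aligned = {
    "MAE": "MA-E",
    "MAD": "MA-D",
    "MAGE": "MAGE",
    "MAEG": "MAEG",
}


def conservation_line(sequences):
    """Mark columns in which every sequence has the same residue."""
    marks = []
    for column in zip(*sequences):
        conserved = len(set(column)) == 1 and column[0] != "-"
        marks.append("*" if conserved else " ")
    return "".join(marks)


line = conservation_line(aligned.values())
print(repr(line))
```

Only the first two positions (MET and ALA) are shared by all four peptides, so any analysis restricted to conserved positions would ignore the mutated or inserted residues.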


### The feature space of the four peptides

<center><img src="topology_examples.png" width="800"/></center>


The scheme above gives a better idea of how the feature spaces of the four peptides relate to each other.

## Analyzing different topologies

We have now concluded that these peptide sequences are similar, and we want to analyze them together. However, because their topologies differ, we first need to answer the following question:

<div class="alert alert-info" role="alert">How can we treat these different topologies?</div>

### Solution 1: Choose the Intersection of features

For the intersection of features, we could use:

In [None]:
mae.featurizer.add_list_of_feats("all", periodic=False)
mad.featurizer.add_list_of_feats("all", periodic=False)
# np.isin replaces np.in1d, which is deprecated in recent NumPy versions
intersection = np.isin(mae.featurizer.describe(), mad.featurizer.describe())
(np.array(mae.featurizer.describe())[intersection]).tolist()
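
The effect of taking the intersection can be sketched without EncoderMap. The labels below are simplified, hand-written stand-ins for feature labels (an assumption for illustration); real label lists come from `featurizer.describe()`.

```python
import numpy as np

# Hypothetical, simplified sidechain-dihedral labels of MAE and MAD.
mae_labels = [
    "CHI1 MET 1", "CHI2 MET 1", "CHI3 MET 1",
    "CHI1 GLU 3", "CHI2 GLU 3", "CHI3 GLU 3",
]
mad_labels = [
    "CHI1 MET 1", "CHI2 MET 1", "CHI3 MET 1",
    "CHI1 ASP 3", "CHI2 ASP 3",
]

# Boolean mask over the MAE labels: True where MAD has the same label.
shared_mask = np.isin(mae_labels, mad_labels)
shared = np.array(mae_labels)[shared_mask].tolist()
lost = np.array(mae_labels)[~shared_mask].tolist()

print("shared:", shared)
print("lost:  ", lost)
```

With residue-specific labels, every feature of the mutated residue is dropped, even though $\chi_1$ and $\chi_2$ exist in both GLU and ASP. This is the main drawback of the intersection approach.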

**EncoderMap now can deal with these kinds of different topologies**

EncoderMap does this by combining multiple `SingleTraj` objects into a `TrajEnsemble`, which can collect trajectories with different topologies and align them by using generic feature names:

In [None]:
trajs = em.TrajEnsemble(
    [mae, mad, maeg, mage],
)
trajs

EncoderMap also offers a way to quickly load the most important features of such a peptide. With the `load_CVs("all")` method of the `TrajEnsemble` class, we can directly load:

- backbone positions
- backbone distances
- backbone angles
- backbone dihedrals
- sidechain dihedrals

In [None]:
trajs.load_CVs("all", ensemble=False, periodic=False)

We can then have a look at the data using the `_CVs` attribute of the `TrajEnsemble`. In this case, we are just interested in the `side_dihedrals`:

In [None]:
trajs._CVs.side_dihedrals

The side dihedrals of `MAE.pdb` are defined as:

- SIDECHDIH CHI1  RESID  MET:   1 CHAIN 0
- SIDECHDIH CHI2  RESID  MET:   1 CHAIN 0
- SIDECHDIH CHI3  RESID  MET:   1 CHAIN 0
- SIDECHDIH CHI1  RESID  GLU:   3 CHAIN 0
- SIDECHDIH CHI2  RESID  GLU:   3 CHAIN 0
- SIDECHDIH CHI3  RESID  GLU:   3 CHAIN 0
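
How generic feature names allow different topologies to share one array can be sketched with NumPy. The sketch assumes that features are aligned by residue position and $\chi$ index and that missing features are filled with NaN; the dihedral values are made up, and EncoderMap's actual generic names may differ.

```python
import numpy as np

# Made-up dihedral values in radians, keyed by (residue position, chi index).
mae_chi = {(1, 1): -1.1, (1, 2): 3.0, (1, 3): 1.2,
           (3, 1): -1.0, (3, 2): 3.1, (3, 3): 0.2}
mad_chi = {(1, 1): -1.2, (1, 2): 2.9, (1, 3): 1.1,
           (3, 1): -1.1, (3, 2): 0.4}  # ASP has no chi3

# The union of all generic keys defines the columns of the combined table.
columns = sorted(set(mae_chi) | set(mad_chi))

# Features that a peptide does not have are filled with NaN.
table = np.array(
    [[peptide.get(col, np.nan) for col in columns] for peptide in (mae_chi, mad_chi)]
)

print(columns)
print(table)
```

In contrast to the intersection, no feature is lost. The price is a table that contains NaN values, which a neural network cannot process directly.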

<a id='sparse_matrices'></a>

## Excursus: Sparse matrices

Simply put, a normal matrix multiplication is carried out by calculating the dot products of the rows of the first matrix with the columns of the second matrix. In the example below, the matrices $A^{2\times3}$ and $B^{3\times2}$ are multiplied to yield $C^{2\times2}$. The first element of $C$ is obtained by computing $1 \cdot 7 + 2 \cdot 9 + 3 \cdot 11 = 58$. In Python, NumPy arrays can be multiplied with the matrix-multiplication operator `@`:

In [None]:
A = np.array([
    [1, 2, 3],
    [4, 5, 6],
])

B = np.array([
    [7, 8],
    [9, 10],
    [11, 12],
])

C = A @ B
C
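
To see what the `@` operator does internally, the same product can be written with explicit loops. This is only an illustrative sketch; in practice `@` is both faster and clearer.

```python
import numpy as np

A = np.array([
    [1, 2, 3],
    [4, 5, 6],
])
B = np.array([
    [7, 8],
    [9, 10],
    [11, 12],
])

# C[i, j] is the dot product of row i of A with column j of B.
C_loops = np.zeros((A.shape[0], B.shape[1]), dtype=A.dtype)
for i in range(A.shape[0]):
    for j in range(B.shape[1]):
        C_loops[i, j] = sum(A[i, k] * B[k, j] for k in range(A.shape[1]))

print(C_loops)
print(np.array_equal(C_loops, A @ B))
```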

The same can be done with TensorFlow tensors:

In [None]:
A_t = tf.convert_to_tensor(A)
B_t = tf.convert_to_tensor(B)
C_t = A_t @ B_t
C_t

However, when some of the values are unknown, we can't carry out normal matrix multiplication, because multiplying a number with not-a-number (NaN) yields NaN.

In [None]:
A = np.array([
    [1, 2, 3],
    [4, np.nan, 6],
])

B = np.array([
    [7, 8],
    [9, 10],
    [11, 12],
])

C = A @ B
C
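
A short check makes the extent of the problem visible: a single NaN in one row of `A` invalidates the complete corresponding row of the product, while the other rows stay intact.

```python
import numpy as np

A = np.array([
    [1, 2, 3],
    [4, np.nan, 6],
])
B = np.array([
    [7, 8],
    [9, 10],
    [11, 12],
])

C = A @ B

# Every dot product that involves the second row of A contains the NaN.
nan_rows = np.isnan(C).all(axis=1)
print(nan_rows)
```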

The same happens in TensorFlow. Moreover, because TensorFlow matrix multiplications are often carried out in sequence (a dense layer of a neural network carries out the operation $\hat{y} = w \cdot x + b$), NaNs can propagate through the whole model:

In [None]:
model = tf.keras.Sequential([
    tf.keras.layers.Dense(units=3, input_shape=(3, )),
    tf.keras.layers.Dense(units=2),
])

A = np.array([
    [1, 2, 3],
    [4, np.nan, 6],
    [7, 8, 9],
])

B = np.array([
    [7, 8],
    [9, 10],
    [11, 12],
])

model.compile(optimizer="Adam", loss="mse")
history = model.fit(A, B, batch_size=2, epochs=10)

model(A)

This can be solved by sparse matrix multiplication. Normally, sparse matrix multiplication is a tool to speed up matrix multiplication when many of the matrix elements are zero. Here, we use it to allow some elements to be undefined: only the defined elements are stored and enter the multiplication.

In [None]:
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(3, ), sparse=True),
    tf.keras.layers.Dense(units=3),
    tf.keras.layers.Dense(units=2),
])

A = np.array([
    [1, 2, 3],
    [4, np.nan, 6],
    [7, 8, 9],
])
indices = np.where(~np.isnan(A))
A = tf.SparseTensor(
    indices=np.vstack(indices).T,
    values=A[indices],
    dense_shape=A.shape,
)

B = np.array([
    [7, 8],
    [9, 10],
    [11, 12],
])

model.compile(optimizer="Adam", loss="mse")
history = model.fit(A, B, batch_size=2, epochs=10)

model(A)
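
What the sparse multiplication does with the undefined entry can be sketched in NumPy. The sketch stores only the defined elements as (row, column, value) triplets, similar to the `indices` and `values` of the `tf.SparseTensor` above, and multiplies them with a dense matrix `W`. The matrix `W` is a made-up stand-in for layer weights, and the sketch assumes that entries absent from a sparse tensor contribute nothing to the product, which is equivalent to treating them as zero.

```python
import numpy as np

A = np.array([
    [1, 2, 3],
    [4, np.nan, 6],
    [7, 8, 9],
])
W = np.array([
    [7.0, 8.0],
    [9.0, 10.0],
    [11.0, 12.0],
])

# Coordinate (COO) representation: positions and values of defined entries.
rows, cols = np.where(~np.isnan(A))
values = A[rows, cols]

# Sparse-dense product: each stored value scales one row of W and is
# accumulated into the output row it belongs to.
out = np.zeros((A.shape[0], W.shape[1]))
np.add.at(out, rows, values[:, None] * W[cols])

print(out)
```

The undefined element never enters the computation, so no NaN appears in the output. The result is the same as replacing the NaN with zero before a dense multiplication.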

This model can now be used with a wide variety of sparse data that potentially has missing values.

In [None]:
C = np.array([
    [1, 2, 3],
    [4, 5, np.nan],
    [7, 8, np.nan],
])

indices = np.where(~np.isnan(C))

C = tf.SparseTensor(
    indices=np.vstack(indices).T,
    values=C[indices],
    dense_shape=C.shape,
)

model(C)

## Ub mutants

### Load data

In [None]:
output_dir = em.get_from_kondata(
    "Ub_K11_mutants",
    mk_parentdir=True,
    silence_overwrite_message=True,
)
output_dir = Path(output_dir)

In [None]:
OVERWRITE = False

trajs_file = output_dir / "trajs.h5"

if trajs_file.is_file() and not OVERWRITE:
    trajs = em.load(trajs_file)
else:
    custom_aas = {
        "KAC": (
            "K",
            {
                "bonds": [
                    ("-C", "N"),  # the peptide bond to the previous aa
                    ("N", "CA"),
                    ("N", "H"),
                    ("CA", "C"),
                    ("C", "O"),
                    ("CA", "CB"),
                    ("CB", "CG"),
                    ("CG", "CD"),
                    ("CD", "CE"),
                    ("CE", "NZ"),
                    ("NZ", "HZ"),
                    ("NZ", "CH"),
                    ("CH", "OI2"),
                    ("CH", "CI1"),
                    ("C", "+N"),  # the peptide bond to the next aa
                ],
                "CHI1": ["N", "CA", "CB", "CG"],
                "CHI2": ["CA", "CB", "CG", "CD"],
                "CHI3": ["CB", "CG", "CD", "CE"],
                "CHI4": ["CG", "CD", "CE", "NZ"],
                "CHI5": ["CD", "CE", "NZ", "CH"],
            },
        )
    }
    trajs = list(output_dir.rglob("*.xtc"))
    tops = [t.parent / "start.pdb" for t in trajs]
    common_str = ["wt", "Ac", "Q", "R", "C"]
    basename_fn = lambda x: x.split("/")[-2]
    trajs = em.load(
        trajs=trajs,
        tops=tops,
        common_str=common_str,
        basename_fn=basename_fn,
    )
    trajs.load_custom_topology(custom_aas)
    trajs.load_CVs("all", ensemble=True)
    trajs.save(trajs_file, overwrite=True)

### Make images of topology

In [None]:
%load_ext autoreload
%autoreload 2

from encodermap.plot.plotting import _plot_ball_and_stick
import warnings

In [None]:
trajs.trajs_by_common_str.keys()

In [None]:
traj = trajs.trajs_by_common_str["wt"][0]

GLOBAL_LAYOUT = {
    "paper_bgcolor": 'rgba(0,0,0,0)',
    "plot_bgcolor": 'rgba(0,0,0,0)',
    "template": "plotly_white",
    "scene": {
        "xaxis": {
            "visible": False,
        },
        "yaxis": {
            "visible": False,
        },
        "zaxis": {
            "visible": False,
        },
    },
}
em.plot.plotting.GLOBAL_LAYOUT = GLOBAL_LAYOUT

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    fig = _plot_ball_and_stick(
        traj,
        atom_indices=traj.top.select("resid 9 to 11 and not element H"),
        highlight="side_dihedrals",
        add_angle_arcs=True,
        persistent_hover=True,
        angle_arcs_true_to_value=False,
        flatten=False,
    )

fig.show()

In [None]:
traj = trajs.trajs_by_common_str["Ac"][0]

GLOBAL_LAYOUT = {
    "paper_bgcolor": 'rgba(0,0,0,0)',
    "plot_bgcolor": 'rgba(0,0,0,0)',
    "template": "plotly_white",
    "scene": {
        "xaxis": {
            "visible": False,
        },
        "yaxis": {
            "visible": False,
        },
        "zaxis": {
            "visible": False,
        },
    },
}
em.plot.plotting.GLOBAL_LAYOUT = GLOBAL_LAYOUT

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    fig = _plot_ball_and_stick(
        traj,
        atom_indices=traj.top.select("resid 9 to 11 and not element H"),
        highlight="side_dihedrals",
        add_angle_arcs=True,
        persistent_hover=True,
        angle_arcs_true_to_value=False,
        flatten=False,
    )

fig.show()

In [None]:
traj = trajs.trajs_by_common_str["Ac"][0]

GLOBAL_LAYOUT = {
    "paper_bgcolor": 'rgba(0,0,0,0)',
    "plot_bgcolor": 'rgba(0,0,0,0)',
    "template": "plotly_white",
    "scene": {
        "xaxis": {
            "visible": False,
        },
        "yaxis": {
            "visible": False,
        },
        "zaxis": {
            "visible": False,
        },
    },
}
em.plot.plotting.GLOBAL_LAYOUT = GLOBAL_LAYOUT

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    fig = _plot_ball_and_stick(
        traj,
        atom_indices=traj.top.select("resid 9 to 11"),
        highlight="side_dihedrals",
        add_angle_arcs=True,
        persistent_hover=True,
        angle_arcs_true_to_value=False,
    )

fig.show()

In [None]:
traj = trajs.trajs_by_common_str["Q"][0]

GLOBAL_LAYOUT = {
    "paper_bgcolor": 'rgba(0,0,0,0)',
    "plot_bgcolor": 'rgba(0,0,0,0)',
    "template": "plotly_white",
    "scene": {
        "xaxis": {
            "visible": False,
        },
        "yaxis": {
            "visible": False,
        },
        "zaxis": {
            "visible": False,
        },
    },
}
em.plot.plotting.GLOBAL_LAYOUT = GLOBAL_LAYOUT

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    fig = _plot_ball_and_stick(
        traj,
        atom_indices=traj.top.select("resid 9 to 11"),
        highlight="side_dihedrals",
        add_angle_arcs=True,
        persistent_hover=True,
        angle_arcs_true_to_value=False,
    )

fig.show()

In [None]:
traj = trajs.trajs_by_common_str["C"][0]

GLOBAL_LAYOUT = {
    "paper_bgcolor": 'rgba(0,0,0,0)',
    "plot_bgcolor": 'rgba(0,0,0,0)',
    "template": "plotly_white",
    "scene": {
        "xaxis": {
            "visible": False,
        },
        "yaxis": {
            "visible": False,
        },
        "zaxis": {
            "visible": False,
        },
    },
}
em.plot.plotting.GLOBAL_LAYOUT = GLOBAL_LAYOUT

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    fig = _plot_ball_and_stick(
        traj,
        atom_indices=traj.top.select("resid 9 to 11"),
        highlight="side_dihedrals",
        add_angle_arcs=True,
        persistent_hover=True,
        angle_arcs_true_to_value=False,
    )

fig.show()

### Choose parameters

In [None]:
import xarray as xr
da = xr.open_dataset("../tests/data/linear_dimers/trajs.h5", group="CVs", engine="h5netcdf").central_cartesians
linear_dimers_cartesians = da.stack({"frame": ("traj_num", "frame_num")}).transpose("frame", ...).dropna("frame", how="all")
linear_dimers_cartesians.shape

In [None]:
em.plot.distance_histogram_interactive(
    data=em.misc.pairwise_dist(
        linear_dimers_cartesians[::1000, 1::3],
    ),
    periodicity=float("inf"),
    initial_guess=[40, 10, 5, 1, 2, 5],
    n_values=1000,
)

In [None]:
p = em.ADCParameters(
    cartesian_pwd_start=1,
    cartesian_pwd_step=3,
    cartesian_dist_sig_parameters=(7, 6, 3, 1, 2, 3),
)

em.plot.distance_histogram_interactive(
    data=em.misc.pairwise_dist(
        trajs.central_cartesians[::1000, p.cartesian_pwd_start::p.cartesian_pwd_step],
    ),
    periodicity=float("inf"),
    initial_guess=p.cartesian_dist_sig_parameters,
    n_values=1000,
)
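
The six numbers passed as `initial_guess` and `cartesian_dist_sig_parameters` parametrize two sigmoid functions, one applied to the high-dimensional distances and one to the low-dimensional distances. The sketch below uses the sigmoid form known from sketch-map and the original EncoderMap publication. Both the functional form and the parameter order (`sigma`, `a`, `b` for the high-dimensional space, followed by the same triplet for the low-dimensional space) are assumptions here and should be checked against the EncoderMap documentation.

```python
import numpy as np


def sigmoid(r, sigma, a, b):
    """Sketch-map style sigmoid; equals 0.5 at the distance r == sigma."""
    return 1 - (1 + (2 ** (a / b) - 1) * (r / sigma) ** a) ** (-b / a)


# Same values as `cartesian_dist_sig_parameters` above (order is assumed).
sig_h, a_h, b_h, sig_l, a_l, b_l = (7, 6, 3, 1, 2, 3)

r = np.linspace(0, 20, 201)
high = sigmoid(r, sig_h, a_h, b_h)  # applied to high-dimensional distances
low = sigmoid(r, sig_l, a_l, b_l)   # applied to low-dimensional distances

print(sigmoid(7.0, sig_h, a_h, b_h))
```

Under these assumptions, `sigma` sets the distance at which the sigmoid reaches 0.5, while `a` and `b` control how steeply it rises below and above that distance. In the interactive histogram, `sigma` of the high-dimensional sigmoid is usually placed so that the relevant range of pairwise distances falls onto the slope.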

### Train

In [None]:
total_steps = 5000

parameters = em.ADCParameters(
    main_path=em.misc.run_path(output_dir / "runs"),
    use_sidechains=True,
    use_backbone_angles=True,
    cartesian_dist_sig_parameters=(7, 6, 3, 1, 2, 3),
    n_steps=total_steps,
    cartesian_cost_scale=1,
    cartesian_cost_variant="mean_abs",
    cartesian_cost_scale_soft_start=(
        int(total_steps / 10 * 9),
        int(total_steps / 10 * 9) + total_steps // 50,
    ),
    cartesian_pwd_start=1,
    cartesian_pwd_step=3,
    dihedral_cost_scale=1,
    dihedral_cost_variant="mean_abs",
    distance_cost_scale=0,
    cartesian_distance_cost_scale=100,
    checkpoint_step=max(1, int(total_steps / 10)),
    l2_reg_constant=0.001,
    center_cost_scale=0,
    tensorboard=True,
)

emap = em.AngleDihedralCartesianEncoderMap(
    trajs=trajs,
    parameters=parameters,  # use the full parameter set defined above, not `p`
    read_only=False,
    use_dataset_when_possible=True,
)
emap.add_images_to_tensorboard()
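
The tuple passed as `cartesian_cost_scale_soft_start` defines a range of training steps in which the Cartesian cost is switched on gradually. The sketch below evaluates the two step numbers for `total_steps = 5000` and models the ramp as a linear increase from 0 to the final scale. The linear shape and the helper `soft_start_scale` are assumptions for illustration, not EncoderMap's implementation.

```python
def soft_start_scale(step, start, end, final_scale):
    """Hypothetical linear ramp from 0 (at `start`) to `final_scale` (at `end`)."""
    if step <= start:
        return 0.0
    if step >= end:
        return float(final_scale)
    return final_scale * (step - start) / (end - start)


total_steps = 5000
start = int(total_steps / 10 * 9)  # the ramp begins after 90 % of the training
end = start + total_steps // 50    # and lasts for 2 % of the training

print(start, end)
for step in (0, start, (start + end) // 2, end, total_steps):
    print(step, soft_start_scale(step, start, end, final_scale=1))
```

With these settings, the network first learns from the angle and dihedral costs alone, and the Cartesian cost only contributes during the last tenth of the training.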

### Preliminary plots

In [None]:
trajs, emap = em.load_project("Ub_K11_mutants", load_autoencoder=True)

In [None]:
lowd = emap.encode()
trajs.load_CVs(lowd, "lowd")

In [None]:
em.plot.plot_trajs_by_parameter(trajs, "free_energy")

In [None]:
em.plot.plot_trajs_by_parameter(trajs, "common_str")

In [None]:
em.plot.plot_trajs_by_parameter(trajs, "common_str", type="heatmap", nbins=50)

In [None]:
em.plot.plot_trajs_by_parameter(trajs, "encoded_frame")

In [None]:
em.plot.plot_trajs_by_parameter(trajs, "traj_num")

### Decoding

In [None]:
trajs, emap = em.load_project("Ub_K11_mutants", load_autoencoder=True)

In [None]:
sess = em.InteractivePlotting(autoencoder=emap)

In [None]:
# sess.generate(None)
sess.save(None)

### Look at saved paths with `interactive_path_visualization`

In [None]:
import encodermap as em
import pandas as pd
import numpy as np
from pathlib import Path

%load_ext autoreload
%autoreload 2

# directory that contains the files written by the interactive session
generated_dir = Path(
    "/home/kevin/git/encoder_map_private/tests/data/Ub_K11_mutants/checkpoints/"
    "finished_training/tf2_15/generated_paths/2024-03-20T18:15:15+01:00"
)

traj = em.load(generated_dir / "generated.xtc", generated_dir / "generated.pdb")
lowd = pd.read_csv(generated_dir / "lowd.csv")
path = np.load(generated_dir / "path.npy")

em.plot.interactive_path_visualization(
    traj,
    lowd,
    path,
)

## Linear Ub-dimers and FAT10

### Load data

In [None]:
import encodermap as em
from pathlib import Path
import xarray as xr
import numpy as np
from tqdm import tqdm
import re

%load_ext autoreload
%autoreload 2

In [None]:
ub_dimer_dir = Path(em.get_from_kondata(
    "linear_dimers",
    mk_parentdir=True,
    silence_overwrite_message=True,
))

fat10_dir = Path(em.get_from_kondata(
    "FAT10",
    mk_parentdir=True,
    silence_overwrite_message=True,
    download_checkpoints=True,
))

In [None]:
# don't need full trajs, so use subset for topology
# trajs = em.TrajEnsemble.from_dataset(fat10_dir / "linear_dimers_and_fat10.h5")

# load just the first Ubi dimer and first FAT10 sim
sub_trajs = em.load(
    trajs=[
        ub_dimer_dir / "01.xtc",
        fat10_dir / "01.xtc",
    ],
    tops=[
        ub_dimer_dir / "01.pdb",
        fat10_dir / "01.pdb",
    ],
    basename_fn=lambda x: "_".join(list(x.split("/")[-2:])),
    common_str=["linear_dimer", "FAT10"],
    traj_num=[0, 12],
)

# xr.open_dataset does not tax the system memory at all
ds = xr.open_dataset(fat10_dir / "linear_dimers_and_fat10.h5", group="CVs")

# load the dataset
sub_trajs.load_CVs(
    ds.sel(traj_num=[0, 12]),
)
sub_trajs

In [None]:
emap = em.AngleDihedralCartesianEncoderMap.from_checkpoint(
    trajs=sub_trajs,
    checkpoint_path=fat10_dir / "checkpoints/finished_training/tf2_15/saved_model_50000.keras",
)

In [None]:
keras_file = Path("/home/kevin/encodermap/tests/data/FAT10/runs/2nd_run_lower_training_rate_to_0_0001/saved_model_50000.keras")

emap_for_steps = em.AngleDihedralCartesianEncoderMap.from_checkpoint(
    trajs=sub_trajs,
    checkpoint_path=keras_file,
)

total = ds.sizes["traj_num"] * (ds.sizes["frame_num"] // 100)
print(f"projecting for {total} steps")

lowd = {}

with tqdm(total=total) as pbar:
    for traj_num, sub_ds in ds.groupby("traj_num", squeeze=False):
        sub_ds = sub_ds.squeeze("traj_num")
        length = sub_ds.sizes["frame_num"]
        indices = np.split(
            np.arange(length), np.arange(101, length, 101)
        )
        for ind in indices:
            data = [
                sub_ds.central_angles.values[ind],
                sub_ds.central_dihedrals.values[ind],
                sub_ds.side_dihedrals.values[ind],
            ]
            l = emap_for_steps.encode(data=data)
            lowd.setdefault(traj_num, []).append(l)
            pbar.update()
        lowd[traj_num] = np.vstack(lowd[traj_num])
total_lowd = np.vstack(
    [lowd[k] for k in range(ds.sizes["traj_num"])],
)
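
The loop above encodes each trajectory in chunks to keep the memory footprint small. How `np.split` produces these chunks can be checked on a small example; the trajectory length of 250 frames is made up for illustration.

```python
import numpy as np

length = 250  # hypothetical number of frames in one trajectory
chunk = 101   # chunk size used in the loop above

# Split points at 101, 202, ... yield consecutive index arrays of
# `chunk` frames each, plus a shorter remainder at the end.
indices = np.split(np.arange(length), np.arange(chunk, length, chunk))
sizes = [len(ind) for ind in indices]

print(sizes)
print(indices[-1][[0, -1]])
```

Every frame is visited exactly once, and the number of chunks is the trajectory length divided by the chunk size, rounded up. A progress-bar total estimated as `length // 100` is therefore only approximate.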


In [None]:
# iterate over trajs in dataset to save memory

lowd_per_steps_file = fat10_dir / "all_lowd.npz"
lowd_per_steps = {}

if not lowd_per_steps_file.is_file():
    for keras_file in (fat10_dir / "runs/1st_run").glob("*.keras"):
        if "2024" in str(keras_file):
            continue
        steps = int(re.findall(r"\d+", keras_file.stem)[0])
    
        lowd_file = Path(fat10_dir / f"complete_linear_dimers_and_FAT10_lowd_{steps}.npy")
    
        emap_for_steps = em.AngleDihedralCartesianEncoderMap.from_checkpoint(
            trajs=sub_trajs,
            checkpoint_path=keras_file,
        )
        
        if not lowd_file.is_file():
            total = ds.sizes["traj_num"] * (ds.sizes["frame_num"] // 100)
            
            lowd = {}
            
            with tqdm(total=total) as pbar:
                for traj_num, sub_ds in ds.groupby("traj_num", squeeze=False):
                    sub_ds = sub_ds.squeeze("traj_num")
                    length = sub_ds.sizes["frame_num"]
                    indices = np.split(
                        np.arange(length), np.arange(101, length, 101)
                    )
                    for ind in indices:
                        data = [
                            sub_ds.central_angles.values[ind],
                            sub_ds.central_dihedrals.values[ind],
                            sub_ds.side_dihedrals.values[ind],
                        ]
                        l = emap_for_steps.encode(data=data)
                        lowd.setdefault(traj_num, []).append(l)
                        pbar.update()
                    lowd[traj_num] = np.vstack(lowd[traj_num])
            total_lowd = np.vstack(
                [lowd[k] for k in range(ds.sizes["traj_num"])],
            )
        
            np.save(lowd_file, total_lowd)
        else:
            total_lowd = np.load(lowd_file)
        
        lowd_per_steps[steps] = total_lowd
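else:
    # the projections were already computed; load them instead of overwriting the file
    with np.load(lowd_per_steps_file) as npz:
        lowd_per_steps = {int(k): npz[k] for k in npz.files}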

In [None]:
np.savez(lowd_per_steps_file, **{str(k): v for k, v in lowd_per_steps.items()})
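
The conversion of the keys to strings is needed because `np.savez` receives the arrays as keyword arguments, and keyword names must be strings. The round trip can be sketched with made-up arrays and a temporary file:

```python
import tempfile
from pathlib import Path

import numpy as np

# Made-up projections, keyed by the integer number of training steps.
lowd_per_steps = {10000: np.zeros((5, 2)), 50000: np.ones((5, 2))}

with tempfile.TemporaryDirectory() as tmp:
    npz_file = Path(tmp) / "all_lowd.npz"

    # Integer keys are converted to strings for saving ...
    np.savez(npz_file, **{str(k): v for k, v in lowd_per_steps.items()})

    # ... and converted back to integers when loading.
    with np.load(npz_file) as npz:
        restored = {int(k): npz[k] for k in npz.files}

print(sorted(restored))
```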

In [None]:
lowd_per_steps.keys()

In [None]:
import plotly.graph_objects as go

trace = em.plot.plot_free_energy(*total_lowd.T)

fig = go.Figure(
    data=[trace],
    layout={"width": 800, "height": 800},
)

fig.show()

In [None]:
emap_for_steps.encode()