# Analysis of the Multilabel Workflow

In this notebook, we compare the impact of the number of labels on the optimization, dynamics, and identifiability of the model.

This notebook mainly serves to create **Figure 6**.

In [None]:
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
import pypesto.visualize.model_fit as model_fit

from matplotlib.colors import LinearSegmentedColormap
from analyze_results import CumulativeResultsLabels

In [None]:
# Shared settings that are reused below: the model identifier and the results
# location (presumably the directory holding the HDF5 result files). Both are
# passed to CumulativeResultsLabels when the results are loaded.
model_name = 'lipidomics_2023_08_29'
res_dict = '../Results_h5'

We start by loading the results of all the different models into a single object, which can then be used to visualize comparisons between them.

In [None]:
# Collect the results of all label configurations in one object.
# NOTE(review): the meaning of offset / n_labels / n_res is defined in
# analyze_results; presumably the label counts start at `offset`, there are
# `n_labels` configurations and `n_res` results in total -- TODO confirm.
cum_res = CumulativeResultsLabels(
    model_name, res_dict, offset=1, n_labels=5, n_res=1000
)

In [None]:
# In addition to the HDF5 results, we need the PEtab problems and (in the next
# cell) the corresponding pyPESTO problems. The PEtab problems are loaded from
# the given directory.
cum_res.load_petab_problems("../Petab_models_230829/multilabels")

In [None]:
# Create the pyPESTO problems that are accessed below as cum_res.pypesto_problems.
# NOTE(review): force_compile=True presumably recompiles the models on every
# run, which can make "Restart & Run All" slow -- confirm whether it is needed.
cum_res.load_pypesto_problems(force_compile=True)

In [None]:
# Sanity check before the analysis; presumably reports optimizer starts that
# failed (behavior is defined in analyze_results -- TODO confirm).
cum_res.overview_failed_starts()

In the following, we use the eigenvalues of the Fisher Information Matrix to get an indication of the identifiability of the parameters. This will create **Figure 6D**. First, we calculate the eigenvalues for all results we obtained.

In [None]:
# Eigenvalues of the Fisher Information Matrix. The following cells treat the
# return value as a sequence with one collection of eigenvalues per entry.
eigenvals = cum_res.calculate_fim_eigenvals()

In [None]:
# Prepare the eigenvalues for inspection and plotting. According to the
# original note, the entries are already sorted by number of labels.
# Only the real part of each eigenvalue is kept. For visualization, values
# below 1e-20 are clipped to 1e-20 so that log10 is well defined; the clipping
# only touches values below that cutoff.
eigenvals = [np.real(eigval) for eigval in eigenvals]
eig_cutted = [np.maximum(eigval, 1e-20) for eigval in eigenvals]
log_eigenvals = [np.log10(eigval) for eigval in eig_cutted]

eigen_val_median = [np.median(eigval) for eigval in eigenvals]
eigen_val_median_cutted = [np.median(eigval) for eigval in eig_cutted]

# Label the two outputs differently so they can be told apart.
print(f"medians: {eigen_val_median}")
print(f"medians cutted: {eigen_val_median_cutted}")

In [None]:
# Summary statistics per entry of `eigenvals`, as a rough indication of the
# general identifiability of the model: mean, median, number of eigenvalues
# below 1e-15 and fraction of eigenvalues above 1.
eigen_val_means = []
eigen_val_median = []
eigen_val_zero = []
eigen_val_geq1 = []
for spectrum in eigenvals:
    eigen_val_means.append(np.mean(spectrum))
    eigen_val_median.append(np.median(spectrum))
    eigen_val_zero.append(np.sum(spectrum < 1e-15))
    eigen_val_geq1.append(np.sum(spectrum > 1)/len(spectrum))

print(f"means: {eigen_val_means}")
print(f"medians: {eigen_val_median}")
print(eigen_val_zero)
print(eigen_val_geq1)

In [None]:
# Fill colors as #RRGGBBAA hex strings (all with alpha 0x50). They are indexed
# by position for the violin bodies and for the box plots further below.
colors = ["#E96F6F50", "#E9E96F50", "#65F26550", "#6FE9E950", "#6565F250", "#E96FE950"]

After the inspection, the following code block creates violin plots of the eigenvalues. This corresponds to **Figure 6D**.

In [None]:
# Figure 6D: violin plot of the log10 eigenvalues, one violin per entry of
# `log_eigenvals`.
print(f"medians: {eigen_val_median}")

eig = log_eigenvals
# Undo the log10, take the median and convert back to log10; these values are
# written next to the violins.
eigenvals_log_cutted_median = []
for log_spectrum in eig:
    eigenvals_log_cutted_median.append(np.log10(np.median(np.power(10, log_spectrum))))
print(f"medians cutted: {eigenvals_log_cutted_median}")

fig, ax = plt.subplots(figsize=(15, 5))
violin_parts = ax.violinplot(eig, showmedians=True)

# Color each violin body by its position.
violin_bodies = violin_parts['bodies']
for body_idx in range(len(violin_bodies)):
    violin_bodies[body_idx].set_facecolor(colors[body_idx])
    violin_bodies[body_idx].set_edgecolor('black')

# Place the rounded median to the right of and slightly below each violin.
for violin_idx, log_median in enumerate(eigenvals_log_cutted_median):
    ax.text(violin_idx + 1.15, log_median - .5, str(round(log_median, 4)), fontsize=12)

ax.tick_params(left=False, bottom=False, which='both')
plt.savefig("Figure5/violins.pdf")
plt.savefig("Figure5/violins.svg")

In **Figure 6A**, we visualize the different dynamics of the model using MAG as an example. This mainly serves to show how the number of labels influences the dynamics. We first define the necessary function:

In [None]:
def experimental_setup(result, problem, n_labels):
    """Plot the simulated MAG observable trajectories for one label configuration.

    ``model_fit.time_trajectory_model`` is called twice with the same
    arguments. The lines of the first call are only read to build the sum
    over all plotted lines. The axes of the second call are restyled
    (colors, markers, arrows, ticks), get the sum curve added, and are
    returned.

    Parameters
    ----------
    result
        Forwarded as ``result`` to ``model_fit.time_trajectory_model``.
    problem
        Forwarded as ``problem`` to ``model_fit.time_trajectory_model``.
    n_labels
        Number of MAG observables to plot. Index 0 maps to
        ``observable_MAG_ul``, every other index ``i`` to
        ``observable_MAG_<i>``.

    Returns
    -------
    The axes returned by the second ``time_trajectory_model`` call, after
    restyling.
    """
    # Observable suffix: "ul" for index 0, otherwise the index itself.
    a = lambda i: "ul" if i==0 else f"{i}"
    obs_names = [
        f'observable_MAG_{a(i_label)}'
        for i_label in range(n_labels)
    ]
    # Alias of the same list; the first plot uses all observables.
    obs_names_all = obs_names
    
    ax_temp = model_fit.time_trajectory_model(
        result=result,
        problem=problem,
        timepoints=np.linspace(0, 130, 1301),
        state_names=[],
        state_ids=[],
        observable_ids=obs_names_all)
    # Sum the y-data of all lines of the first plot.
    # NOTE(review): `all_dag` starts out as the y-data object of the first
    # line and is then updated with `+=`. If that object is a NumPy array, the
    # first line's data is modified in place -- confirm this is intended.
    for i_line, line in enumerate(ax_temp.lines):
        xx = line.get_xdata()
        yy = line.get_ydata()
        if i_line == 0:
            all_dag = yy
            all_dag_x = xx
        else:
            all_dag += yy
    # Extend the sum curve to the left: prepend x = -40 with the first y value,
    # which gives a flat segment before the first timepoint.
    all_dag_x = np.concatenate(([-40], all_dag_x))
    all_dag = np.concatenate(([all_dag[0]], all_dag))
    # if n_labels > 1:
    #  obs_names = obs_names[1:]
    ax = model_fit.time_trajectory_model(
        result=result,
        problem=problem,
        timepoints=np.linspace(0, 130, 1301),
        state_names=[],
        state_ids=[],
        observable_ids=obs_names)

    # One color per plotted line, indexed by line position. There are six
    # entries, so more than six lines would raise an IndexError below.
    label_colors = [
        (0.68, 0.74, 0.84),
        (0.502, 0.733, 0.509),
        (0.944, 0.730, 0.418),
        (0.944, 0.558, 0.744),
        (0.509, 0.502, 0.733),
        (0.733, 0.502, 0.509),
    ]
    color_t0 = (0.4, 0.4, 0.4)
    
    # Horizontal spacing of the shifted points: 120 split into n_labels - 1
    # equal steps (120 if there is only one observable).
    time_division = 120
    if n_labels > 1:
        time_division = 120/(n_labels-1)
        
    for i_line, line in enumerate(ax.lines):
        if i_line == 1:
            # Extra marker for the second line; the shift (i_line-1) * time_division
            # is 0 here.
            # NOTE(review): this reads index 1201 before the line is extended
            # by one prepended point below, so it refers to a different sample
            # than the index-1201 marker set afterwards -- confirm intended.
            ax.plot(
                [line.get_xdata()[1201] - (i_line-1) * time_division],
                [line.get_ydata()[1201]],
                color=label_colors[i_line],
                marker='o'
            )
        xx = line.get_xdata()
        yy = line.get_ydata()
        # Prepend x = -30 with the first y value (flat segment to the left).
        xx = np.concatenate(([-30], xx))
        yy = np.concatenate(([yy[0]], yy))
        line.set_data(xx, yy)
        # Show a single marker at data index 1201 of the extended line.
        line.set_markevery([1201])
        line.set_marker('o')
        line.set_c(label_colors[i_line])
        # Text-less annotation: only the dashed arrow is drawn, pointing from
        # the marked point to the same y value shifted left by
        # (i_line-1) * time_division.
        ax.annotate(
            # Label and coordinate
            f'',
            xy=(line.get_xdata()[1201]- (i_line-1) * time_division, line.get_ydata()[1201]),
            xytext=(line.get_xdata()[1201], line.get_ydata()[1201]),
            color=line.get_color(),
            # Custom arrow
            arrowprops=dict(arrowstyle='->', color=line.get_color(),
                            lw=1, ls='--')
        )
    # Add the sum curve computed from the first plot. Its legend label is not
    # shown because the legend is removed further below.
    ax.plot(all_dag_x, all_dag, label = "Sum over all label combinations", color = (0.23, 0.46, 0.69))
    # Ticks at multiples of time_division, starting one step before 0; the
    # first tick is moved to -20, the left axis limit set below.
    xticks = [(i-1) * time_division for i in range(n_labels + 2)]
    xticks[0] = -20
    # NOTE(review): for i = 0 this produces "$T_-1$" (no braces around -1),
    # which may not render as the subscript "-1" -- confirm.
    xlabels = [f"$T_{i-1}$" for i in range(n_labels + 2)]
    ax.set_xticks(xticks, labels=xlabels)
    ax.axvline(x=120, color='#A7A7A7', linestyle='dashed', linewidth=1)
    ax.axvline(x=0, color='#A7A7A7', linestyle='dashed', linewidth=1)
    ax.get_xticklabels()[0].set_color(color_t0)
    ax.grid(False)

    # The second y tick label is overwritten with '0'; note that the y ticks
    # are cleared again by set_yticks([]) further below.
    labeling = [item.get_text() for item in ax.get_yticklabels()]
    labeling[1] = '0'
    ax.set_yticklabels(labeling)
    ax.set_xlim((-20, 130))
    plt.title('')
    plt.ylabel('')
    plt.xlabel('')
    ax.figure.set_size_inches(5, 2.5)
    ax.set_yticks([])
    ax.get_legend().remove()
    # ensure 0 is within axis limits (fully transparent horizontal line)
    ax.axhline(y=0, color='#00000000') 

    return ax

The following code then produces the individual panels of **Figure 6A**. To be transparent, we visualize the general trajectories twice: once without the sum over all labels and once with it. This shows that in the case of zero labels, while the trajectory might look like it is moving, this only means that the numerically found steady state is not perfect. This can be seen from the very small difference between the maximum and minimum value of the y-axis. If we include 0 in the axis range, the final figure shows that there are no dynamics.

In [None]:
# Figure 6A: one panel per label configuration.

plt.rcParams['axes.edgecolor'] = 'black'
for i in range(5):
    n_labels_panel = i + 1
    print(f"n_labels = {n_labels_panel}")
    # The result at index i*200 is paired with problem i
    # (presumably 200 results per configuration -- TODO confirm).
    ax = experimental_setup(
        cum_res.results[i*200],
        cum_res.pypesto_problems[i],
        n_labels=n_labels_panel,
    )
    # Add 10 % headroom above the curves. `y_limits` of the last panel is
    # reused by the legend cell below.
    y_lower, y_upper = ax.get_ylim()
    y_limits = (y_lower, y_upper * 1.1)
    ax.set_ylim(y_limits)
    plt.tick_params(left=False, bottom=False, which='both')
    plt.savefig(f"Figure5/Experimental_setup_{i}labels_sum.svg")

In [None]:
# Only for legend: redraw the panel with the most labels and attach a manually
# built legend to it.
i = 4
print(f"n_labels = {i+1}")
ax = experimental_setup(cum_res.results[i*200], cum_res.pypesto_problems[i], n_labels=i+1)
ax.set_ylim(y_limits)

# One row per legend entry: (color, label, linestyle, marker).
legend_entries = [
    ((0.502, 0.733, 0.509), "Label 1", "-", ""),
    ((0.944, 0.730, 0.418), "Label 2", "-", ""),
    ((0.944, 0.558, 0.744), "Label 3", "-", ""),
    ((0.509, 0.502, 0.733), "Label 4", "-", ""),
    ((0.23, 0.46, 0.69), "Sum over all Labels", "-", ""),
    ((0.68, 0.74, 0.84), "Unlabeled", "-", ""),
    ("black", "Measurement", "None", "o"),
    ("black", "Retrieved Timepoint", "None", r"$\dashleftarrow$"),
]
# Keep the per-column lists available under their original names.
label_colors, labels, linestyles, markers = (
    list(column) for column in zip(*legend_entries)
)
handles = [
    mlines.Line2D(
        [], [],
        color=entry_color,
        linestyle=entry_linestyle,
        label=entry_label,
        marker=entry_marker,
        markersize=8,
    )
    for entry_color, entry_label, entry_linestyle, entry_marker in legend_entries
]
# remove grid
ax.grid(False)

# Legend to the right of the axes, in two columns, with a black frame.
legend_box = ax.legend(
    handles=handles,
    loc='center left',
    bbox_to_anchor=(1, 0.5),
    ncol=2,
)
legend_box.get_frame().set_edgecolor('black')
# tight layout
plt.tight_layout()

plt.tick_params(left=False, bottom=False, which='both')

plt.savefig("Figure5/Experimental_setup_legends.pdf")

While this is not a figure, we include it to show that in the case of zero labels we are indeed in a numerically found steady state for all species.

In [None]:
# Evaluate the objective of the first problem at the parameters of the first
# optimizer entry of the first result and keep the first returned rdata.
first_objective = cum_res.pypesto_problems[0].objective
first_parameters = cum_res.results[0].optimize_result[0].x
a = first_objective(first_parameters, return_dict=True)["rdatas"][0]

# Print only the entries whose key starts with 'preeq_'.
for key, value in a.items():
    if not key.startswith('preeq_'):
        continue
    print(f"{key:>20s}: ", value)

In [None]:
# Ensure we are in steady state: plot the trajectories for the first result
# and the first problem on the same time grid (0 to 130, 1301 points) that is
# used for Figure 6A. No observable_ids are passed here.
ax = model_fit.time_trajectory_model(
        result=cum_res.results[0],
        problem=cum_res.pypesto_problems[0],
        timepoints=np.linspace(0, 130, 1301),
        state_names=[],
        state_ids=[],
)

In [None]:
# Report the size of the AMICI model behind each pyPESTO problem.
for i_lab, problem in enumerate(cum_res.pypesto_problems):
    print(f"{i_lab} Labels")
    amici_model = problem.objective.amici_model
    print(f"Number of Observables: {amici_model.ny}")
    print(f"Number of States: {amici_model.nx_rdata}")

We compare the computation time of the different models in **Figure 6C**. The two code blocks below generate that figure, taking into account the computation time of the whole optimization, split by number of labels.

In [None]:
# Computation times: the mean of `optimize_result.time` for every result ...
mean_times = np.array([np.mean(result.optimize_result.time) for result in cum_res.results])
# ... grouped by the label value stored in `cum_res.res_labels`. The label
# values start at `cum_res.offset` (the same convention as in the correlation
# analysis below) instead of a hard-coded 1. The result has one row per label
# configuration; this requires the same number of results in every group.
mean_times_per_label = np.array([
    mean_times[cum_res.res_labels == label]
    for label in range(cum_res.offset, cum_res.offset + cum_res.n_labels)
])

In [None]:
# Figure 6C: one box per label configuration, showing the distribution of the
# mean computation times of all results of that configuration.
plt.rcParams['axes.edgecolor']='black'
fig, ax = plt.subplots()
# `mean_times_per_label` has one row per label configuration, so iterate over
# its rows. (Indexing the transposed array, as done previously, picks one
# result from every configuration instead of all results of one configuration.)
for i, label_times in enumerate(mean_times_per_label):
    ax.boxplot(
        label_times,
        positions=[i*0.4],
        boxprops=dict(facecolor=colors[i]),
        medianprops=dict(linewidth=2),
        labels=[i],
        patch_artist=True,
    )
# log scale
ax.set_yscale("log")
ax.set_xlabel("Labels")
ax.set_ylabel("Computation time (s)")
ax.set_xlim((-0.25, 1.85))
plt.tick_params(left = False, bottom = False, which='both')
plt.savefig("Figure5/Time_cost.pdf")

Lastly, we want to see how the identifiability of the parameters changes with an increasing number of labels. For this, we compare the correlation between true and estimated parameters. An increase would point towards better identifiability. This reproduces **Figure 6E**.

In [None]:
# Color function for the correlation scatter plot
def determine_color(value):
    """Return the hex color of the category a correlation value falls into.

    Values below 0.2 are orange, values in [0.2, 0.8) are purple and
    everything else is blue. A NaN fails both comparisons and therefore
    also ends up as blue.
    """
    orange, purple, blue = '#FF6601', '#CC99FF', '#4A4AFF'
    if value < 0.2:
        return orange
    in_middle_band = 0.2 <= value < 0.8
    return purple if in_middle_band else blue

# Figure 6E: correlation between estimated and nominal ("true") parameters for
# each label configuration.

# Alias kept so that the name stays available to other cells.
cumRes = cum_res

# Group the estimated and nominal parameter vectors by label configuration.
# The estimate is the `x` of the first optimizer entry of each result.
parameters_est = []
parameters_true = []
for i_label in range(cum_res.n_labels):
    indices = np.where(cum_res.res_labels == i_label + cum_res.offset)[0]
    parameters_est.append(
        np.array([cum_res.results[i_res].optimize_result[0].x for i_res in indices])
    )
    parameters_true.append(
        np.array([cum_res.petab_problems[i_res].x_nominal_free_scaled for i_res in indices])
    )

# Per configuration: one Pearson correlation per parameter (column), computed
# across all results of that configuration.
correlations = [
    np.array([
        np.corrcoef(par_est[:, i_par], par_true[:, i_par])[0, 1]
        for i_par in range(par_est.shape[1])
    ]) for par_est, par_true in zip(parameters_est, parameters_true)
]

# Calculate counts of parameters below 0.2 and above 0.8
below_0_2 = [np.sum(correlation < 0.2) for correlation in correlations]
above_0_8 = [np.sum(correlation > 0.8) for correlation in correlations]

fig, ax = plt.subplots(figsize=(15, 5))

# Seeded generator so that the horizontal jitter, and therefore the saved
# figure, is identical on every run.
rng = np.random.default_rng(0)
jitters = []
# Plot the data points with jitter; the color depends on the correlation value.
for i, data in enumerate(correlations):
    jitter = rng.normal(scale=0.1, size=len(data))  # jitter for better visualization
    jitter_x = np.full_like(data, i + 1) + jitter
    jitters.append(jitter)
    for x, y in zip(jitter_x, data):
        ax.scatter(x, y, color=determine_color(y), alpha=0.5)

# Connect the points of the same parameter index between neighboring
# configurations with grey lines. This assumes every configuration has the
# same number of parameters as the first one.
for i_par in range(len(correlations[0])):
    for j in range(len(correlations) - 1):
        ax.plot(
            [j + 1 + jitters[j][i_par], j + 2 + jitters[j + 1][i_par]],
            [correlations[j][i_par], correlations[j + 1][i_par]],
            color='gray',
            alpha=0.5,
        )

# Add table with counts
labels = [f'{i + 1} Label' for i in range(len(correlations))]
table_data = [below_0_2, above_0_8]
table = ax.table(
    cellText=table_data,
    loc='top',
    cellLoc='center',
    rowLabels=['Below 0.2', 'Above 0.8'],
    colLabels=labels,
    cellColours=[['white'] * len(correlations), ['white'] * len(correlations)],
    colColours=['lightgray'] * len(correlations),
)

# Adjust table properties
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1, 1.5)

# Remove the borders of all table cells
for cell in table.get_celld().values():
    cell.set_linewidth(0)
# Remove grid
ax.grid(False)

# Add legend. Passing only labels assigns them to the first artists on the
# axes (the first scatter points); their legend handles are recolored below so
# that they show the three category colors.
legend_labels = ['< 0.2', '> 0.2, < 0.8', '> 0.8']
legend_colors = ['#FF6601', '#CC99FF', '#4A4AFF']
legend = ax.legend(legend_labels, title='Correlation', loc='upper left', title_fontsize='large', fontsize='medium', markerscale=1.5, frameon=True)
legend.get_frame().set_edgecolor('black')
for handle, color in zip(legend.legend_handles, legend_colors):
    handle.set_color(color)  # Set marker color
    handle.set_edgecolor('black')  # Set marker edge color

plt.tick_params(left = False, bottom = False, which='both')
plt.tight_layout()
plt.savefig("Figure5/correlations.pdf")