In [1]:
# %autoreload
import sys

# NOTE(review): these are machine-specific absolute checkouts; the notebook will
# only import `eals_mogp` on a host that has the repositories at these locations.
for _repo_dir in (
    "/home/ubuntu/git/eals_data/",
    "/home/ubuntu/git/eals_features/",
    "/home/ubuntu/git/eals_mogp_new/src/eals_mogp",
):
    # Skip entries that are already present so re-running this cell does not
    # keep appending duplicates to sys.path.
    if _repo_dir not in sys.path:
        sys.path.append(_repo_dir)

# The `mogp` alias is not referenced elsewhere in the visible notebook;
# presumably the import is needed so the pickled model can be loaded — TODO confirm.
import eals_mogp as mogp

In [3]:
import os

import numpy as np
import pandas as pd; pd.set_option('display.max_columns', None)
import warnings
warnings.simplefilter(action="ignore")

import matplotlib.pyplot as plt
plt.rcdefaults()
from matplotlib.gridspec import GridSpec
import seaborn as sns

from scipy.stats import pearsonr, spearmanr
from statsmodels.regression.linear_model import OLS
import statsmodels.api as sm

import utils

# Config
sns.set_context('notebook', font_scale=1.3)

In [4]:
# Set the seaborn plotting context used by the rest of the notebook.
# NOTE: this call replaces the 'notebook' context (font_scale=1.3) that the
# import/config cell set just before.
sns.set_context("talk")

## Functions

In [5]:
import joblib

def plot_clusters_trajectories_mean_std(
    savepath,
    df_run,
    time_col,
    dependant_variable,
    df_clust=None,
    df_onset=None,
    substract_initial_value=False,
    xlim_dict=None,
    ylim=None,
    figsize=(6, 4),
    plot_regression_line=False,
    plot_confidence=False,
    add_onset_date=False,
):
    """
    Plot per-subject trajectories colored by their cluster assignment.

    Cluster assignments are read from ``data_and_metadata.pkl`` (key
    ``"results"``, which must expose ``user_id`` and ``cluster`` columns).
    Optionally overlays, for every cluster listed in ``df_clust``, a straight
    line built from its ``intercept``/``slope`` and the band drawn by the
    fitted model's ``plot_confidence`` method.

    Original plot_confidence method documentation:
    https://gpy.readthedocs.io/en/deploy/GPy.core.html#GPy.core.gp.GP

    Args:
        - savepath: directory containing ``fvcPercPred_2019.pkl`` and
          ``data_and_metadata.pkl``. An empty/None value falls back to
          ``'./mogp'``.
        - df_run: must contain user_id, dependant_variable and time_col.
        - time_col: years_since_first_session, days_since_onset, etc. Must be
          a key of ``xlim_dict``.
        - dependant_variable: fvcPercPred_2019, etc.
        - df_clust: required despite the ``None`` default. One row per cluster
          with columns ``id``, ``name``, ``intercept`` and ``slope``. The
          regression line is drawn as ``intercept + slope * x`` with no
          rescaling, so these must already be expressed in the units of
          ``df_run[dependant_variable]`` and ``time_col``.
        - df_onset (optional): required when ``add_onset_date`` is True. Must
          contain ``user_id`` and ``symptom_onset_date_since_session`` (treated
          as a number of days).
        - substract_initial_value: subtract each subject's first (earliest)
          value from its trajectory before plotting.
        - xlim_dict: mapping ``time_col -> (xmin, xmax)``; a default mapping is
          used when None.
        - ylim: y-axis limits; defaults to ``(0.1, 1.5)``.
        - figsize: figure size passed to ``plt.subplots``.
        - plot_regression_line: draw each cluster's dashed regression line.
        - plot_confidence: draw each cluster's 2.5-97.5 percentile band via the
          model's ``plot_confidence`` (called with ``plot_raw=False``).
        - add_onset_date: whether to plot crosses at onset date.

    Returns:
        (fig, ax, results): the matplotlib figure and axes, and the loaded
        ``results`` table with an added ``group`` column (copy of ``cluster``).

    Raises:
        ValueError: if ``df_clust`` is None, if ``add_onset_date`` is True
            without ``df_onset``, or if ``add_onset_date`` is requested for a
            ``time_col`` with no known day conversion.
        KeyError: if ``time_col`` is not a key of ``xlim_dict``.
    """
    if df_clust is None:
        raise ValueError(
            "df_clust is required: it must provide the 'id', 'name', "
            "'intercept' and 'slope' columns of the clusters to plot."
        )
    if add_onset_date and df_onset is None:
        raise ValueError("add_onset_date=True requires df_onset to be provided.")

    if xlim_dict is None:
        xlim_dict = {
            "years_since_first_session": (-3, 1.5),
            "months_since_first_session": (-3 * 12, 1.5 * 12),
            "days_since_first_session": (-3 * 365.25, 1.5 * 365.25),
            "days_since_onset": (0, 3500),
        }
    if time_col not in xlim_dict:
        raise KeyError(
            f"time_col {time_col!r} has no entry in xlim_dict "
            f"(available: {list(xlim_dict)})"
        )
    x_min, x_max = xlim_dict[time_col][0], xlim_dict[time_col][-1]

    # Days per unit of each supported time column, used to place onset markers.
    # The previous code divided days by 12 for months. The values below assume
    # 30-day months and 365-day years, which is how the *_since_first_session
    # columns of the loaded data relate to each other — TODO confirm.
    onset_days_per_unit = {
        "years_since_first_session": 365,
        "months_since_first_session": 30,
        "days_since_first_session": 1,
    }
    if (
        add_onset_date
        and time_col != "days_since_onset"
        and time_col not in onset_days_per_unit
    ):
        raise ValueError(
            f"add_onset_date is not supported for time_col {time_col!r}"
        )

    # Load from saved dir (honor `savepath`; fall back to the historical default).
    # SECURITY: joblib.load unpickles arbitrary objects — only load trusted files.
    model_dir = str(savepath).rstrip("/") if savepath else "./mogp"
    mogp_metadata = joblib.load(f"{model_dir}/data_and_metadata.pkl")
    results = mogp_metadata["results"]
    results["group"] = results["cluster"]

    # The fitted model is only needed to draw the confidence bands.
    # NOTE(review): the file name is fixed to the FVC model regardless of
    # `dependant_variable` — confirm whether other variables have their own model.
    mogp_model = None
    if plot_confidence:
        mogp_model_path = f"{model_dir}/fvcPercPred_2019.pkl"
        print(f"loading model from: {mogp_model_path}")
        mogp_model = joblib.load(mogp_model_path)

    df_clusters_all = df_clust.copy()
    # First row per cluster id / user id, matching the former `.values[0]` lookups.
    cluster_info = df_clusters_all.drop_duplicates("id").set_index("id")
    user_to_cluster = (
        results.drop_duplicates("user_id").set_index("user_id")["cluster"].to_dict()
    )

    # One color per cluster, assigned in order of decreasing intercept. The same
    # mapping is used for trajectories, regression lines and confidence bands.
    group_colors = {
        cluster_id: plt.cm.tab10(rank)
        for rank, cluster_id in enumerate(
            df_clusters_all.sort_values("intercept", ascending=False).id
        )
    }

    dict_var_names = {
        "fvcPercPred_2019": "FVC (% Predicted)",
        "vcPercPred": "SVC (% Predicted)",
    }

    with sns.plotting_context("notebook", font_scale=1.4):
        fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=150)

        legends = []  # cluster ids that already own a legend entry
        skipped_clusters = set()  # clusters present in results but not in df_clust
        uids = df_run.user_id.unique()
        for uid, df_user in df_run.groupby("user_id", sort=False):
            # A trajectory needs at least two points and a cluster assignment.
            if len(df_user) <= 1 or uid not in user_to_cluster:
                continue
            cluster_id = user_to_cluster[uid]
            if cluster_id not in group_colors:
                skipped_clusters.add(cluster_id)
                continue

            df_plot = df_user.sort_values(time_col)
            x = df_plot[time_col]
            y = df_plot[dependant_variable]
            if substract_initial_value:
                y = y - y.iloc[0]

            # Only the first trajectory of each cluster carries the legend label.
            label_name = None
            if cluster_id not in legends:
                label_name = cluster_info.at[cluster_id, "name"]
                legends.append(cluster_id)
            ax.plot(
                x,
                y,
                ".-",
                color=group_colors[cluster_id],
                alpha=0.5,
                label=label_name,
            )

        if skipped_clusters:
            print(
                "Skipped subjects assigned to clusters missing from df_clust: "
                f"{sorted(skipped_clusters, key=str)}"
            )

        ax.set_xlabel(time_col.replace("_", " ").title())
        ax.set_ylabel(dict_var_names.get(dependant_variable, dependant_variable))
        col = "FVC" if "fvc" in dependant_variable else "SVC"
        ax.set_title(f"{col} trajectories")

        # Legend entries follow the color order (decreasing intercept) rather
        # than the order in which subjects happened to be plotted.
        handles, labels = ax.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        sorted_labels = [
            str(cluster_info.at[cluster_id, "name"])
            for cluster_id in group_colors
            if cluster_id in legends
        ]
        sorted_handles = [by_label[label] for label in sorted_labels]
        ax.legend(sorted_handles, sorted_labels)

        for cluster_id in df_clusters_all.id:
            color = group_colors[cluster_id]
            if plot_regression_line:
                intercept = cluster_info.at[cluster_id, "intercept"]
                slope = cluster_info.at[cluster_id, "slope"]
                # Extend the line past the axis limits so it spans the whole plot.
                xs = [x_min * 1.6, x_max * 1.6]
                ys = [intercept + slope * x_value for x_value in xs]
                ax.plot(xs, ys, "--", color=color, linewidth=4, alpha=0.5)
            if plot_confidence:
                mogp_model.obsmodel[cluster_id].model.plot_confidence(
                    color=color,
                    ax=ax,
                    label=None,
                    plot_raw=False,  # If True, it returns confidence interval. If False, it returns prediction interval.
                    lower=2.5,
                    upper=97.5,
                    plot_limits=[0, x_max * 1.6],
                )

        # Add the onset date for each subject as a cross on its cluster's line.
        if add_onset_date:
            onset_by_user = (
                df_onset.drop_duplicates("user_id")
                .set_index("user_id")["symptom_onset_date_since_session"]
                .to_dict()
            )
            for uid in uids:
                onset_days = onset_by_user.get(uid)
                if pd.isnull(onset_days):
                    continue
                if uid not in user_to_cluster:
                    print(f"User {uid} not in results")
                    continue
                cluster_id = user_to_cluster[uid]
                if cluster_id not in cluster_info.index:
                    # Cluster not described in df_clust: no line to put the cross on.
                    continue
                if time_col == "days_since_onset":
                    x_onset = 0
                else:
                    x_onset = onset_days / onset_days_per_unit[time_col]
                intercept = cluster_info.at[cluster_id, "intercept"]
                slope = cluster_info.at[cluster_id, "slope"]
                y_onset = intercept + slope * x_onset
                ax.plot(
                    x_onset,
                    y_onset,
                    "x",
                    color="black",
                    markersize=8,
                    markeredgewidth=2,
                )

        ax.set_xlim(x_min, x_max)
        if ylim is None:
            ax.set_ylim(0.1, 1.5)
        else:
            ax.set_ylim(ylim)
        ax.grid(True)
    return fig, ax, results

## Data

In [6]:
# Presumably the directory holding the fitted MoGP artifacts — TODO confirm.
# NOTE(review): this variable is only printed here; the plotting call further
# down passes '' as its savepath argument instead of this value.
savepath = 'notebooks/paper_rpft_update_12_24/code_data_replicability/mogp/'
print(savepath)

notebooks/paper_rpft_update_12_24/code_data_replicability/mogp/


In [7]:
# Location and names of the two CSV inputs of the figure.
PATH = 'data'
FILE_DATA = 'data_fig4_mogp.csv'
FILE_CLUST = 'data_fig4_clusters.csv'

# Per-session measurements; the subject identifier column is exposed as `user_id`.
df_run = pd.read_csv(os.path.join(PATH, FILE_DATA)).rename(
    columns={'subject_order': 'user_id'}
)

# One row per cluster (id, name, intercept, slope, ...).
df_clusters_all = pd.read_csv(os.path.join(PATH, FILE_CLUST))

display(df_run.head())
display(df_clusters_all.head())

Unnamed: 0,user_id,session_id_hash,UTC_date,days_since_first_session,months_since_first_session,years_since_first_session,n_fvc_atleast_usable,pftType,fvcPercPred_2019,vcPercPred,is_proctored,slope_fvcPercPred_2019,intercept_fvcPercPred_2019,slope_vcPercPred,intercept_vcPercPred,y_pred_fvcPercPred_2019,y_pred_vcPercPred
0,pALS 31,d198f225e1ad61f6db090ff9a11697f4578bbab87f237e...,2023-03-08,29,0.966667,0.079452,2,fvc,0.83,,True,-0.001184,0.857624,,,0.857624,
1,pALS 31,fed0e8f99d12054066518cc1139315278e72c09d9a0b88...,2023-03-14,35,1.166667,0.09589,4,fvc,0.89,,True,-0.001184,0.857624,,,0.850518,
2,pALS 31,e5705d02b8cb67a9fa62407f28c35c2e95c755cc4a19ce...,2023-03-21,42,1.4,0.115068,3,fvc,0.9,,True,-0.001184,0.857624,,,0.842228,
3,pALS 31,24874555757aa344d0b9d6bbd389fac0e5e48d2d538c31...,2023-03-29,50,1.666667,0.136986,4,fvc,0.8,,True,-0.001184,0.857624,,,0.832754,
4,pALS 31,2a4c4538b8e4f09a63444b0f737fb1dc1c3319b274043f...,2023-04-04,56,1.866667,0.153425,4,fvc,0.78,,True,-0.001184,0.857624,,,0.825648,


Unnamed: 0,id,name,n_subjects,intercept,intercept_std,slope,sd,cluster_sorted_intercept
0,2,B,4,122.7,0.2,-0.41,0.02,A
1,7,A,14,102.7,0.5,-0.02,0.04,B
2,4,C,12,81.1,0.6,-0.95,0.05,C
3,11,D,4,51.6,1.2,-2.09,0.1,D


## Figure

In [8]:
# from eals_radcliff.data_modules.mogp.plotting import plot_clusters_trajectories_mean_std2
# Columns plotted on the x axis (time) and y axis (outcome) of the figure.
time_col, dependant_variable = 'years_since_first_session', 'fvcPercPred_2019'

# x-axis limits handed to the plotting function, keyed by time column.
# Only the column selected above needs an entry.
xlim_dict = dict(years_since_first_session=(0, 1.5))

# Human-readable axis labels, keyed by column name.
dict_var_names = {
    # outcome variables
    'fvcPercPred_2019': 'FVC % predicted',
    'vcPercPred': 'SVC % predicted',
    # time axes
    'months_since_first_session': 'Months since enrollment',
    'months_since_onset': 'Months since onset',
    'days_since_first_session': 'Days since enrollment',
    'days_since_onset': 'Days since onset',
    'years_since_first_session': 'Years since enrollment',
    'years_since_onset': 'Years since onset',
}

In [9]:
# Draw the trajectories with each cluster's regression line and confidence band.
fig, ax, res = plot_clusters_trajectories_mean_std(
    '',
    df_run,
    time_col,
    dependant_variable,
    df_clust=df_clusters_all,
    substract_initial_value=False,
    xlim_dict=xlim_dict,
    plot_regression_line=True,
    plot_confidence=True,
)

ax.set_title('')

# y axis: the data are fractions of predicted, shown as percentages.
# Limits and tick positions are fixed first and the labels are derived from
# those final positions; labelling the automatic ticks and moving the ticks
# afterwards leaves labels attached to the wrong (or no) positions.
ax.set_ylim(0, 1.5)
y_ticks = np.arange(0, 1.51, 0.5)
ax.set_yticks(y_ticks)
ax.set_yticklabels([int(np.round(tick * 100)) for tick in y_ticks])
ax.set_ylabel(dict_var_names[dependant_variable])

# x axis: the data are in years, shown as months with one tick every 4 months.
# The relabelling is only valid for the years column, so it is guarded.
if time_col == 'years_since_first_session':
    ax.set_xlim(0, 16/12)
    x_ticks = np.arange(0, 1.5, 1/12 * 4)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels([int(np.round(tick * 12)) for tick in x_ticks])
    ax.set_xlabel(dict_var_names['months_since_first_session'])

# Put the legend outside the plot
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5));

# turn off legend
# ax.get_legend().remove()

loading model from: ./mogp/fvcPercPred_2019.pkl


KeyError: 0