# Feature Trajectories within each cluster for each patient

### Show how feature values evolve over time within each cluster
### Features were grouped using hierarchical clustering

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import os

In [None]:
# Load the names of the features chosen by feature selection.
fts_sel = pd.read_csv('../Output/Submission-Long/features/Features_Selected.csv')
fts_sel = fts_sel['Feature'].values

# Load the rescaled feature values and keep only rows with ContourType 'Manual'.
# .copy() makes df_fts an independent frame, so adding the 'Selected' column
# below does not raise pandas' SettingWithCopyWarning on a filtered slice.
df_fts = pd.read_csv('../Output/Submission-Long/features/Features_Rescaled.csv')
df_fts = df_fts[df_fts['ContourType'] == 'Manual'].copy()

# Flag each row: 1 if its feature is in the selected set, otherwise 0.
df_fts['Selected'] = df_fts['Feature'].isin(fts_sel).astype(int)


In [None]:
# Load per-patient cluster labels. Each CSV file in labels_dir holds one
# patient's labels; the patient ID is taken from the file name (text before
# the first '.').
# NOTE(review): features above come from 'Submission-Long' but labels come
# from 'Submission' -- confirm this mix of output folders is intended.
labels_dir = '../Output/Submission/Clustering/Labels/'

label_frames = []
# Sort for a deterministic row order, and skip non-CSV entries (e.g. .DS_Store)
# that os.listdir would otherwise hand to pd.read_csv.
for fname in sorted(os.listdir(labels_dir)):
    if not fname.endswith('.csv'):
        continue
    df = pd.read_csv(os.path.join(labels_dir, fname))
    df['PatID'] = fname.split('.')[0]
    label_frames.append(df)

if not label_frames:
    raise FileNotFoundError(f'No cluster label CSV files found in {labels_dir}')

# Concatenate once (linear) instead of growing a frame inside the loop
# (quadratic); ignore_index gives a clean 0..n-1 index instead of repeats.
df_cluster_labels = pd.concat(label_frames, ignore_index=True)

# Every feature that received a cluster label for at least one patient.
fts_clustered = df_cluster_labels['Feature'].unique()

df_cluster_labels

In [None]:
# Keep only features that were clustered, then attach each row's cluster
# label by matching on (Feature, PatID).
# .copy() avoids SettingWithCopyWarning when PatID is reassigned below.
df_fts = df_fts[df_fts['Feature'].isin(fts_clustered)].copy()
# PatIDs derived from file names are strings, while the feature CSV may have
# parsed PatID as a number; cast both sides so the merge keys compare equal.
df_cluster_labels['PatID'] = df_cluster_labels['PatID'].astype(str)
df_fts['PatID'] = df_fts['PatID'].astype(str)
df_fts = df_fts.merge(df_cluster_labels, on=['Feature', 'PatID'], how='left')

# A left merge leaves ClusterLabel as NaN where a patient has no label for a
# feature; report it rather than let it pass silently.
n_unlabelled = int(df_fts['ClusterLabel'].isna().sum())
if n_unlabelled:
    print(f'{n_unlabelled} rows have no cluster label after the merge')

# sample() raises if asked for more rows than exist.
df_fts.sample(min(5, len(df_fts)))

In [None]:
# For every patient, save one figure per cluster showing each feature's
# trajectory across fractions. Selected features are drawn as thick red
# labelled lines; all other features as thin, semi-transparent grey lines.
sns.set_style('whitegrid')

for pat in df_fts['PatID'].unique():
    df_pat = df_fts[df_fts['PatID'] == pat]

    print('-------------------')
    print(pat)

    out_dir = os.path.join('.', 'ExampleCluster', pat)
    os.makedirs(out_dir, exist_ok=True)

    # Clusters are numbered by order of first appearance (i + 1). A NaN label
    # (rows left unmatched by the left merge) still consumes a number, which
    # keeps numbering identical to the combined grid figure, but no empty
    # figure is written for it.
    for i, cluster in enumerate(df_pat['ClusterLabel'].unique()):
        if pd.isna(cluster):
            continue
        df_cluster = df_pat[df_pat['ClusterLabel'] == cluster]

        fig, ax = plt.subplots(figsize=(4, 4))

        # sort=False keeps features in first-appearance order (same drawing
        # order as iterating over unique()).
        for ft, df_ft in df_cluster.groupby('Feature', sort=False):
            # Line plots connect points in row order, so sort by the x value.
            df_ft = df_ft.sort_values('Fraction')
            is_selected = df_ft['Selected'].iloc[0] == 1
            ax.plot(
                df_ft['Fraction'],
                df_ft['FeatureValue'],
                color='red' if is_selected else 'grey',
                linewidth=2.5 if is_selected else 1,
                label=ft if is_selected else None,
                alpha=1 if is_selected else 0.7,
            )

        ax.set_title(f'Cluster - {i + 1}', fontsize=18)
        # NOTE(review): assumes rescaled values lie in [0, 1] and that there
        # are fractions 1-5 -- confirm against Features_Rescaled.csv.
        ax.set_ylim(0, 1)
        ax.set_xticks(np.arange(1, 5.1, 1))
        ax.tick_params(axis='both', labelsize=12)

        fig.savefig(os.path.join(out_dir, f'Cluster-{i+1}.png'), dpi=300, bbox_inches='tight')
        plt.close(fig)

In [None]:
# For every patient, draw all of that patient's clusters in a single grid
# figure (n_cols panels per row) and save it as
# ExampleCluster/<PatID>/<PatID>_all.png. Selected features are red and
# labelled; all others are grey.
sns.set_style('whitegrid')

n_cols = 3

for pat in df_fts['PatID'].unique():
    df_pat = df_fts[df_fts['PatID'] == pat]

    print(pat)

    clusters = df_pat['ClusterLabel'].unique()
    # Size the grid from the non-NaN clusters only. The original sized it with
    # nunique() but iterated unique() (which includes NaN), so it could index
    # past the end of the axes array.
    n_panels = int(pd.notna(clusters).sum())
    print(n_panels)
    if n_panels == 0:
        continue

    n_rows = int(np.ceil(n_panels / n_cols))
    # squeeze=False always returns a 2-D array, so a single row needs no
    # special-case indexing.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(10, 5 * n_rows), squeeze=False)
    flat_axes = axs.ravel()

    panel = 0
    for i, cluster in enumerate(clusters):
        # Skip rows with no cluster label but keep i, so "Cluster {i + 1}"
        # matches the per-cluster PNG numbering.
        if pd.isna(cluster):
            continue
        ax = flat_axes[panel]
        panel += 1

        df_cluster = df_pat[df_pat['ClusterLabel'] == cluster]
        for ft, df_ft in df_cluster.groupby('Feature', sort=False):
            # Line plots connect points in row order, so sort by the x value.
            df_ft = df_ft.sort_values('Fraction')
            is_selected = df_ft['Selected'].iloc[0] == 1
            ax.plot(
                df_ft['Fraction'],
                df_ft['FeatureValue'],
                color='red' if is_selected else 'grey',
                linewidth=2,
                label=ft if is_selected else None,
            )

        ax.set_title(f'Cluster {i + 1}', fontsize=18)
        # NOTE(review): assumes rescaled values lie in [0, 1] and that there
        # are fractions 1-5 -- confirm against Features_Rescaled.csv.
        ax.set_ylim(0, 1)
        ax.set_xticks(np.arange(1, 5.1, 1))
        # Only add a legend when a selected (labelled) line exists; otherwise
        # matplotlib warns that no artists have labels.
        if ax.get_legend_handles_labels()[1]:
            ax.legend(loc='upper right', fontsize=12)
        ax.tick_params(axis='both', labelsize=12)

    # Remove grid cells left over when the cluster count is not a multiple of n_cols.
    for ax in flat_axes[panel:]:
        fig.delaxes(ax)

    # Create the output folder here too, so this cell does not depend on the
    # previous cell having run.
    out_dir = os.path.join('.', 'ExampleCluster', pat)
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, f'{pat}_all.png'), dpi=300, bbox_inches='tight')
    plt.show()
    # Close explicitly so figures do not accumulate across many patients.
    plt.close(fig)
    