In [72]:
from pathlib import Path
from enum import Enum, auto
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import seaborn as sns
from tqdm import tqdm
from scipy.spatial import distance
from scipy.cluster import hierarchy


from src.custom_model_videos_res18 import Resnet18Rnn
from src.data_loader_diagnosis_videos import get_loader

class Task(Enum):
    """Identifiers of the three tasks handled in this notebook.

    Member names match the `task_name` strings used to locate the per-task
    result folders; values are the consecutive integers 1, 2, 3.
    """

    IJA = 1
    RJA_LOW = 2
    RJA_HIGH = 3

In [62]:
# Legend definitions shared by every clustermap: one solid-coloured patch per
# class, plus the matching list of label strings in the same order.

# Binary target: two classes.
binary_labels = ["TD", "ASD"]
binary_handles = [
    Patch(facecolor=colour, edgecolor=colour, label=name)
    for colour, name in zip(["g", "y"], binary_labels)
]

# Multi-class target: three classes.
multiclass_labels = ["non-ASD", "mild-moderate ASD", "severe ASD"]
multiclass_handles = [
    Patch(facecolor=colour, edgecolor=colour, label=name)
    for colour, name in zip(["g", "y", "blue"], multiclass_labels)
]

In [71]:
# Attention-weight clustermaps.
#
# For every (target, task) combination this cell
#   1. reads the participant table of the corresponding result folder,
#   2. averages the per-fold attention weights (`alphas_v1.npy`) over all folds,
#   3. draws a row-clustered heatmap (participants x frames) with one colour
#      strip entry per participant, and
#   4. writes the figure to `plots/<result folder>/clustermap_<task>.pdf`.

N_FOLDS = 10  # fold_0 ... fold_9 are read for every task

# Sequence length per task; it drives the x-axis tick positions and labels.
# NOTE(review): assumed to equal the number of columns in alphas_v1.npy — confirm.
TASK_SEQ_LEN = {'IJA': 100, 'RJA_LOW': 50, 'RJA_HIGH': 50}

for target_name in ['label', 'sev_ados']:
    # Everything that depends only on the target is chosen once per target.
    if target_name == 'label':
        color_map = {0: 'g', 1: 'y'}
        target_dir = Path(f'BINARY_FOLD_{target_name}')
        handles = binary_handles
        class_labels = binary_labels
    else:
        color_map = {0: 'g', 1: 'y', 2: 'blue'}
        target_dir = Path(f'MULTI_FOLD_{target_name}')
        handles = multiclass_handles
        class_labels = multiclass_labels

    for task_name, seq_len in TASK_SEQ_LEN.items():
        project_path = target_dir.joinpath(task_name)

        # ---- participant table -------------------------------------------
        df = pd.read_csv(project_path.joinpath('participant_information_df.csv'))
        # The file number is parsed from the last two characters of the file
        # name (a leading underscore is stripped for single-digit numbers).
        # NOTE(review): this breaks for numbers >= 100 or names carrying an
        # extension — confirm the naming scheme.
        df['file_num'] = df['file_name'].apply(lambda x: int(x[-2:].replace('_', '')))
        id_severity_dict = dict(zip(df['id'], df['severity']))

        # One row per participant: the file with the smallest file number.
        file_df = df.groupby(['id'])['file_num'].agg('min').reset_index()
        # Look the file name up by (id, file_num). A lookup keyed on file_num
        # alone returns another participant's file whenever two participants
        # share the same trailing number.
        file_df = file_df.merge(
            df[['id', 'file_num', 'file_name']].drop_duplicates(['id', 'file_num'], keep='last'),
            on=['id', 'file_num'],
            how='left',
        )
        file_df['severity'] = file_df['id'].map(id_severity_dict)

        # ---- attention weights averaged over folds -----------------------
        fold_alphas = [
            np.load(project_path.joinpath(f"fold_{fold_num}").joinpath("alphas_v1.npy"))
            for fold_num in tqdm(range(N_FOLDS))
        ]
        whole_alphas = np.array(fold_alphas).mean(axis=0)

        # Row i of the attention matrix is labelled and coloured with row i of
        # `file_df`, so both must have the same number of rows.
        # NOTE(review): this also presumes the rows are ordered by ascending
        # participant id (the groupby order) — confirm against the exporter.
        if whole_alphas.shape[0] != len(file_df):
            raise ValueError(
                f"{project_path}: {whole_alphas.shape[0]} attention rows but "
                f"{len(file_df)} participants; cannot label the clustermap rows."
            )

        alpha_df = pd.DataFrame(
            whole_alphas,
            columns=[f'frame_{i}' for i in range(1, whole_alphas.shape[1] + 1)],
        )

        label_colors = file_df['severity'].map(color_map).values
        yticklabels = file_df['id'].to_list()

        # ---- figure -------------------------------------------------------
        # Rows are clustered with average linkage; the column (frame) order is
        # kept as-is so the x-axis stays chronological.
        clust = sns.clustermap(
            alpha_df,
            row_colors=label_colors,
            method="average",
            col_cluster=False,
            figsize=(15, 20),
            cmap="rocket_r",
            dendrogram_ratio=(0.1, 0.2),
            cbar_pos=(1, 0, 0.03, 1),  # placeholder; repositioned below
            yticklabels=yticklabels,
            robust=True,
        )

        clust.ax_heatmap.set_xlabel("Video Frames", fontsize=10)
        clust.ax_heatmap.set_ylabel("Participant ID", fontsize=10)
        clust.fig.subplots_adjust(top=0.9)
        # Put the colour bar to the right of the heatmap, outside the axes.
        clust.cax.set_position([1.03, 0.065, 0.03, 0.665])
        clust.cax.set_ylabel("Attention Weight", fontsize=10)
        clust.cax.tick_params(labelsize=10)

        # Class legend in a single row directly above the heatmap.
        clust.ax_heatmap.legend(
            handles,
            class_labels,
            loc="lower left",
            bbox_to_anchor=(0, 1.00),
            ncol=len(handles),
            fontsize=10,
            frameon=False,
        )

        # A tick every 10 columns, centred on the cell (+0.5); the labels are
        # three times the column index.
        # NOTE(review): presumably one attention weight covers 3 video frames — confirm.
        clust.ax_heatmap.set_xticks(np.arange(0, seq_len + 1, 10) + 0.5, minor=False)
        clust.ax_heatmap.set_xticklabels(np.arange(0, seq_len * 3 + 1, 30), minor=False, fontsize=10)

        # ---- save ---------------------------------------------------------
        save_path = Path(f"plots/{project_path.parent.name}/clustermap_{task_name}.pdf")
        save_path.parent.mkdir(parents=True, exist_ok=True)
        # Save and close this clustermap's own figure instead of relying on
        # pyplot's implicit "current figure".
        clust.fig.savefig(
            save_path,
            bbox_inches='tight',
            pad_inches=0.0,
            dpi=300,
            transparent=True,
            format='pdf',
        )
        plt.close(clust.fig)

100%|██████████| 10/10 [00:00<00:00, 1454.29it/s]
100%|██████████| 10/10 [00:00<00:00, 2124.24it/s]
100%|██████████| 10/10 [00:00<00:00, 2168.83it/s]
100%|██████████| 10/10 [00:00<00:00, 2169.06it/s]
100%|██████████| 10/10 [00:00<00:00, 2155.90it/s]
100%|██████████| 10/10 [00:00<00:00, 2148.83it/s]


In [None]:
# Scratch check: joining an empty segment leaves the path unchanged.
Path('MULTI_FOLD_sev_ados') / ''