# TSNE plots

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from os.path import join, split
import pandas as pd
from itertools import combinations
from deep_utils import JsonUtils
from enum import StrEnum
import os
import seaborn as sns
sns.set_theme(style='white')

In [None]:
# Locations of the cached latent features, the figure output folder and the raw data.
data_path = "latent_features"
output_path = "images_paper/tsne"
original_data_dir = "data/"
os.makedirs(output_path, exist_ok=True)

# One feature archive per data subset, loaded in a fixed order.
# NOTE: allow_pickle=True lets np.load unpickle object arrays, which can run
# arbitrary code; only point this at trusted files.
array_dict = {
    subset: np.load(join(data_path, archive_name), allow_pickle=True)
    for subset, archive_name in (
        ("test_external", "test_external_features.npz"),
        ("train", "train_features.npz"),
        ("test_internal", "test_internal_features.npz"),
    )
}
test_external_array = array_dict["test_external"]
train_array = array_dict["train"]
test_internal_array = array_dict["test_internal"]

In [None]:
class TitleName(StrEnum):
    """
    Human-readable names of the data subsets.

    The members are interpolated into f-strings to build the legend labels of
    the plots (e.g. ``f"{TitleName.train} Normal"``); being a ``StrEnum``, each
    member formats as its string value.
    """
    # Member names mirror the subset keys used elsewhere in this notebook
    # ('test_external', 'test_internal', 'train').
    test_external = "Test External"
    test_internal = "Test Internal"
    train = "Train"

def get_genders(sample_names: list) -> list:
    """
    Return one gender tag per sample, in the order of ``sample_names``.

    The module-level ``gender`` selector decides the behaviour. When it is
    empty, every sample is tagged with that same empty value, so a later
    ``tag == gender`` filter keeps all samples. Otherwise each name is looked
    up in ``internal_gender_data`` first and in ``external_gender_data`` as a
    fallback, and the stored value is returned with surrounding whitespace
    stripped.
    :param sample_names: names of the samples to tag
    :return: list of gender tags, one per entry of ``sample_names``
    :raises KeyError: if a name is present in neither gender mapping
    """
    if not gender:
        return [gender for _ in sample_names]

    output = []
    for name in sample_names:
        if name in internal_gender_data:
            source = internal_gender_data
        else:
            source = external_gender_data
        try:
            output.append(source[name].strip())
        except KeyError:
            # Fail with the offending sample name instead of an opaque
            # "'NoneType' object has no attribute 'strip'".
            raise KeyError(f"no gender entry for sample {name!r}") from None
    return output
    
def add_gender(filename, gender):
    """
    Tag an output file name with the selected gender.

    With an empty ``gender`` the name is returned untouched; otherwise
    ``_<gender>`` is passed as the ``suffix`` to ``DirUtils.split_extension``.
    :param filename: output file name
    :param gender: gender tag to add, or an empty value for none
    :return: the (possibly suffixed) file name
    """
    from deep_utils import DirUtils

    if not gender:
        return filename
    return DirUtils.split_extension(filename, suffix=f"_{gender}")

In [None]:
# Gender subset to plot: one of "m", "f" or "" (empty keeps every sample).
gender = ""  # m f ""

model_type = "tsne"
dpi = 500
names = ['test_external', 'train', 'test_internal']
# Per-subset label offsets read by get_names().
label_increase = {"test_external": 0, "train": 4, "test_internal": 6}

# Sample-name -> gender lookups for the internal and external cohorts.
internal_gender_data = JsonUtils.load("internal_gender_img_names.json")
external_gender_data = JsonUtils.load("external_gender_img_names.json")

# Every subset grouping to plot: singles first, then the full triple, then pairs.
graph_combinations = [
    combo
    for group_size in (1, 3, 2)
    for combo in combinations(names, group_size)
]

base_colors = [r"#ff7979", r"#50ad76", r"#d36efa"]
# Two colours repeated once per subset, paired positionally with the markers.
colors = base_colors[:2] * 3
markers = ["^", "^", "o", "o", "X", "X"]

## Anomaly Detection

In [None]:
task = "anomaly_detection"
size = 20
plt.rc('legend', fontsize=size)

# Legend text for each integer label found under "y" in the t-SNE archives.
label_map = {
    0: f"{TitleName.test_external} Normal",
    1: f"{TitleName.test_external} AAOCA",
    4: f"{TitleName.train} Normal",
    5: f"{TitleName.train} AAOCA",
    6: f"{TitleName.test_internal} Normal",
    7: f"{TitleName.test_internal} AAOCA",
}

for combination in graph_combinations:
    # Only the three-subset plot pins its legend corner.
    if combination == ('test_external', 'train', 'test_internal'):
        loc = "upper right"
    else:
        loc = "best"
    path = join(data_path, model_type + "_" + "-".join(combination) + ".npz")
    # Shared stem of the two files (figure + spreadsheet) written below.
    base_name = task + "_" + model_type + "_" + "-".join(combination)

    # np.load returns an NpzFile that keeps the archive open; read what is
    # needed and let the context manager close the file handle.
    with np.load(path) as array:
        sample_names = array["sample_names"]
        x_all = array["x_reduced"]
        y_all = array["y"]

    # Keep only the samples whose gender tag equals the selected gender.
    gender_filter = [g == gender for g in get_genders(sample_names)]
    x_reduced = x_all[gender_filter]
    y_names = [label_map[i] for i, j in zip(y_all, gender_filter) if j]
    result_df = pd.DataFrame({'x': x_reduced[:, 0], 'y': x_reduced[:, 1], 'label': y_names})
    # NOTE(review): presumably sorted so labels line up positionally with the
    # palette/markers lists - confirm against seaborn's level ordering.
    result_df = result_df.sort_values("label")

    plt.figure(figsize=(16, 16))
    try:
        sns.scatterplot(x='x', y='y',
                        hue='label',
                        style="label",
                        data=result_df,
                        s=50,
                        palette=colors[:2 * len(combination)],
                        markers=markers[:2 * len(combination)]
                        )
        plt.xlim((np.min(x_reduced[:, 0]) - 1, np.max(x_reduced[:, 0]) + 1))
        plt.ylim((np.min(x_reduced[:, 1]) - 1, np.max(x_reduced[:, 1]) + 1))

        plt.legend(bbox_to_anchor=(1.13, 1.00), loc=loc, borderaxespad=0.0, fontsize=size, markerscale=2)
        plt.axis("off")
        plt.savefig(join(output_path, add_gender(base_name + ".jpg", gender)), bbox_inches='tight', dpi=dpi)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close()

    result_df.to_excel(join(output_path, add_gender(base_name + ".xlsx", gender)), index=False)
    print(f"done: {combination}")

## Origin Classification

In [None]:
def get_names(data_path: str, name: str) -> dict:
    """
    Collect case names and offset labels below ``data_path``.

    Every directory that contains ``img_cropped_resampled.nii.gz`` counts as
    one case. Its label is the integer name of the directory two levels above
    it, shifted by the module-level ``label_increase[name]``; its image name is
    the directory's own name.
    :param data_path: root directory to walk
    :param name: subset key used to look up the offset in ``label_increase``
    :return: dict with parallel lists under "labels" and "img_names"
    :raises ValueError: if the label directory name is not an integer
    """
    data = {"labels": [], "img_names": []}
    for dirpath, _, filenames in os.walk(data_path):
        if "img_cropped_resampled.nii.gz" not in filenames:
            continue
        # Split on the platform separator (after normalising mixed "/" and
        # "\\") rather than a hard-coded "/", so the label directory is found
        # on every OS.
        parts = os.path.normpath(dirpath).split(os.sep)
        lbl = int(parts[-3])
        data['img_names'].append(split(dirpath)[-1])
        data['labels'].append(lbl + label_increase[name])
    return data

In [None]:
task = 'origin_classification'

# Offsets that get_names() adds to the per-directory labels of each subset.
label_increase = {"test_external": 0, "train": 2, "test_internal": 4}

# Legend text for each offset label.
label_map = {
    0: f"{TitleName.test_external} R-AAOCA",
    1: f"{TitleName.test_external} L-AAOCA",
    2: f"{TitleName.train} R-AAOCA",
    3: f"{TitleName.train} L-AAOCA",
    4: f"{TitleName.test_internal} R-AAOCA",
    5: f"{TitleName.test_internal} L-AAOCA",
}

# Walk each subset directory once (label_increase must already be set above).
data_dict = {
    subset: get_names(join(original_data_dir, relative_dir), subset)
    for subset, relative_dir in (
        ("test_external", "origin_classification/test_external"),
        ("train", "origin_classification/train"),
        ("test_internal", "origin_classification/test_internal"),
    )
}
test_external_data = data_dict["test_external"]
train_data = data_dict["train"]
test_internal_data = data_dict["test_internal"]

colors = base_colors[:2] * 3
# Bare expression: shown as the cell output in the notebook.
test_external_data

In [None]:
for combination in graph_combinations:
    # Only the three-subset plot pins its legend corner.
    if combination == ('test_external', 'train', 'test_internal'):
        loc = "lower right"
    else:
        loc = "best"
    # Shared stem of the two files (figure + spreadsheet) written below.
    base_name = task + "_" + model_type + "_" + "-".join(combination)

    # np.load returns an NpzFile that keeps the archive open; read what is
    # needed and let the context manager close the file handle.
    with np.load(join(data_path, model_type + "_" + "-".join(combination)+".npz")) as array:
        sample_names = array['sample_names']
        x_reduced = array["x_reduced"]
        preds = array['preds']

    # Merge the labels / image names of every subset in this combination.
    data = dict(labels=[], img_names=[])
    for name in combination:
        d = data_dict[name]
        data['labels'].extend(d['labels'])
        data['img_names'].extend(d['img_names'])

    genders = get_genders(sample_names)
    name_lbl = dict(zip(data['img_names'], data['labels']))
    new_x, new_y, new_labels, new_names, new_preds = [], [], [], [], []

    # Keep samples that have a label for this task and match the selected gender.
    # new_names / new_preds are not used by the plot below.
    # NOTE(review): drop them if no later cell reads them.
    for index, name in enumerate(sample_names):
        if name in name_lbl and genders[index] == gender:
            new_x.append(x_reduced[index, 0])
            new_y.append(x_reduced[index, 1])
            new_labels.append(name_lbl[name])
            new_names.append(name)
            new_preds.append(preds[index])

    result_df = pd.DataFrame({'x': new_x, 'y': new_y, 'label': [label_map[lbl] for lbl in new_labels]})
    result_df = result_df.sort_values("label")

    plt.figure(figsize=(16, 16))
    try:
        sns.scatterplot(x='x',
                        y='y',
                        hue='label',
                        style="label",
                        data=result_df,
                        s=50,
                        palette=colors[:2 * len(combination)],
                        markers=markers[:2 * len(combination)]
                        )
        plt.xlim((np.min(new_x) - 1, np.max(new_x) + 1))
        plt.ylim((np.min(new_y) - 1, np.max(new_y) + 1))
        plt.legend(loc=loc, fontsize="18", markerscale=2)
        plt.axis("off")
        plt.savefig(join(output_path, add_gender(base_name + ".jpg", gender)), bbox_inches='tight', dpi=dpi)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close()

    result_df.to_excel(join(output_path, add_gender(base_name + ".xlsx", gender)), index=False)
    print(f"done: {combination}")

## Risk Classification

In [None]:
task = 'risk_classification'

# Offsets that get_names() adds to the per-directory labels of each subset.
label_increase = {"test_external": 0, "train": 2, "test_internal": 4}

# Legend text for each offset label.
label_map = {
    0: f"{TitleName.test_external} Low Risk",
    1: f"{TitleName.test_external} High Risk",
    2: f"{TitleName.train} Low Risk",
    3: f"{TitleName.train} High Risk",
    4: f"{TitleName.test_internal} Low Risk",
    5: f"{TitleName.test_internal} High Risk",
}

# Walk each subset directory once (label_increase must already be set above).
data_dict = {
    subset: get_names(join(original_data_dir, relative_dir), subset)
    for subset, relative_dir in (
        ("test_external", "risk_classification/test_external"),
        ("train", "risk_classification/train"),
        ("test_internal", "risk_classification/test_internal"),
    )
}
test_external_data = data_dict["test_external"]
train_data = data_dict["train"]
test_internal_data = data_dict["test_internal"]

# Number of classes per subset; sizes the palette/marker slices in the plot loop.
n_samples = 2
colors = base_colors[:2] * 3
markers = ["^", "^", "o", "o", "X", "X"]

In [None]:
for combination in graph_combinations:
    # Only the three-subset plot pins its legend corner.
    if combination == ('test_external', 'train', 'test_internal'):
        loc = "lower right"
    else:
        loc = "best"
    # Shared stem of the two files (figure + spreadsheet) written below.
    base_name = task + "_" + model_type + "_" + "-".join(combination)

    # np.load returns an NpzFile that keeps the archive open; read what is
    # needed and let the context manager close the file handle.
    with np.load(join(data_path, model_type + "_" + "-".join(combination)+".npz")) as array:
        sample_names = array['sample_names']
        x_reduced = array["x_reduced"]
        preds = array['preds']

    # Merge the labels / image names of every subset in this combination.
    data = dict(labels=[], img_names=[])
    for name in combination:
        d = data_dict[name]
        data['labels'].extend(d['labels'])
        data['img_names'].extend(d['img_names'])

    genders = get_genders(sample_names)
    name_lbl = dict(zip(data['img_names'], data['labels']))
    new_x, new_y, new_labels, new_names, new_preds = [], [], [], [], []

    # Keep samples that have a label for this task and match the selected gender.
    # new_names / new_preds are not used by the plot below.
    # NOTE(review): drop them if no later cell reads them.
    for index, name in enumerate(sample_names):
        if name in name_lbl and genders[index] == gender:
            new_x.append(x_reduced[index, 0])
            new_y.append(x_reduced[index, 1])
            new_labels.append(name_lbl[name])
            new_names.append(name)
            new_preds.append(preds[index])

    result_df = pd.DataFrame({'x': new_x, 'y': new_y, 'label': [label_map[lbl] for lbl in new_labels]})
    result_df = result_df.sort_values("label")

    plt.figure(figsize=(16, 16))
    try:
        sns.scatterplot(x='x',
                        y='y',
                        hue='label',
                        style="label",
                        data=result_df,
                        s=50,
                        palette=colors[:n_samples * len(combination)],
                        markers=markers[:n_samples * len(combination)]
                        )
        plt.xlim((np.min(new_x) - 1, np.max(new_x) + 1))
        plt.ylim((np.min(new_y) - 1, np.max(new_y) + 1))
        plt.legend(loc=loc, fontsize="18", markerscale=2)
        plt.axis("off")
        plt.savefig(join(output_path, add_gender(base_name + ".jpg", gender)), bbox_inches='tight', dpi=dpi)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close()

    result_df.to_excel(join(output_path, add_gender(base_name + ".xlsx", gender)), index=False)
    print(f"done: {combination}")