In [None]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import seaborn as sns
import textwrap

# cuML (RAPIDS) dimensionality-reduction estimators; PCA and UMAP are used below.
# NOTE(review): HDBSCAN is imported but never used in this notebook.
from cuml import PCA, UMAP, HDBSCAN
from tqdm.auto import tqdm
# Register tqdm's pandas integration with the bar label "Processing rows".
# NOTE(review): no `progress_apply` call appears in this notebook, so this has no visible effect yet.
tqdm.pandas(desc="Processing rows")

# Project-local tokenizer; used later to map ICD disease codes to ICD chapters.
from src import Tokenizer
from src.tokenizer import ICDLevel

sns.set_style('whitegrid')

In [None]:
def adjust_alpha(
        num_points,
        min_alpha=0.2,
        max_alpha=1.0,
        min_points=20000,
        max_points=500000,
        exponent=-5,
):
    """Interpolate from ``max_alpha`` down to ``min_alpha`` as the point count grows.

    The count is clamped to ``[min_points, max_points]`` and mapped to a progress
    value ``t`` in ``[0, 1]``. The result is
    ``min_alpha + (max_alpha - min_alpha) * exp(exponent * t) * (1 - t)``:
    ``max_alpha`` at or below ``min_points``, ``min_alpha`` at or above ``max_points``.

    Despite the name, this is also used for marker sizes. Callers pass other
    ``min_alpha``/``max_alpha`` bounds for that.
    """
    upper_bounded = min(max_points, num_points)
    clamped = max(min_points, upper_bounded)

    progress = (clamped - min_points) / (max_points - min_points)
    decay = np.exp(exponent * progress)
    return min_alpha + (max_alpha - min_alpha) * decay * (1 - progress)

In [None]:
def sort_categories(categories, category_index=None):
    """Return ``categories`` as a sorted list, optionally in a caller-defined order.

    NOTE: this previously delegated to ``get_sort_value``, which is not defined in
    this notebook. Any non-None ``category_index`` therefore raised ``NameError`` on
    a fresh kernel. The ordering is resolved here instead.

    Parameters
    ----------
    categories : iterable of hashable
        Category labels, e.g. a ``value_counts`` index.
    category_index : None, callable, mapping or iterable, optional
        - None: natural sort order (``sorted``).
        - callable: used directly as the sort key.
        - mapping (anything with ``keys()``, e.g. dict or Series): category -> rank.
        - any other iterable: the desired order; the first position of a category is its rank.
        Categories without a rank are placed after the ranked ones, in natural sort order.

    Returns
    -------
    list
    """
    if category_index is None:
        return sorted(categories)
    if callable(category_index):
        return sorted(categories, key=category_index)

    if hasattr(category_index, "keys"):
        rank = dict(category_index)
    else:
        rank = {}
        for position, category in enumerate(category_index):
            rank.setdefault(category, position)

    categories = list(categories)
    ranked = sorted((c for c in categories if c in rank), key=rank.__getitem__)
    unranked = sorted(c for c in categories if c not in rank)
    return ranked + unranked

In [None]:
def _format_tick(value):
    """Format a continuous-legend tick with up to two decimals, trailing zeros dropped.

    Integer-valued ticks render without a decimal point (``4.0`` -> ``"4"``).
    Fractional ticks keep their value (``0.35`` -> ``"0.35"``) instead of being
    truncated to an integer.
    """
    return f"{value:.2f}".rstrip("0").rstrip(".")


def _continuous_legend_entries(values, hue_norm, cmap, num_ticks=6):
    """Build legend handles/labels by sampling ``cmap`` at evenly spaced points of ``hue_norm``.

    The first/last label gets a ``<``/``>`` prefix when some of ``values`` fall
    below/above the normalization range.
    """
    low, high = hue_norm
    if isinstance(cmap, str):
        # A named palette was passed; the tick colors need a callable colormap.
        cmap = sns.color_palette(cmap, as_cmap=True)

    tick_locations = np.linspace(low, high, num_ticks)
    span = high - low
    # Guard a degenerate range (e.g. a constant column) against division by zero.
    normed_values = (tick_locations - low) / span if span else np.zeros(num_ticks)
    colors = cmap(normed_values)

    labels = [_format_tick(value) for value in tick_locations]
    handles = [
        mlines.Line2D([0], [0], marker='o', color='w', label=label, markersize=7, markerfacecolor=color)
        for label, color in zip(labels, colors)
    ]

    # Decide the open-ended markers from the data itself, not from legend strings.
    if (values < low).any():
        labels[0] = f"<{labels[0]}"
    if (values > high).any():
        labels[-1] = f">{labels[-1]}"
    return handles, labels


def _categorical_hue(plot_column, category_index, palette, other, cat_alpha, other_alpha):
    """Return ``(hue_column, hue_order, palette, alpha)`` for a categorical column.

    - At most 20 categories: every category gets its own color. With ``other=True``,
      the last category in ``hue_order`` is drawn faint (``other_alpha``). With the
      default palette it is also drawn gray.
    - More than 20 categories: the 19 most frequent are kept and everything else
      (including missing values) becomes a faint gray "Other" group.
    """
    value_counts = plot_column.value_counts()
    num_categories = len(value_counts)

    if num_categories <= 20:
        hue_column = plot_column.copy()
        hue_order = sort_categories(value_counts.index.tolist(), category_index)
        alpha = cat_alpha

        if other:
            if palette is None:
                base = sns.color_palette("tab20", n_colors=20)
                del base[14]  # delete gray color
                palette = list(base[:max(num_categories - 1, 0)]) + ["gray"]
            # The last category in hue_order is the faint background group.
            is_foreground = ~plot_column.isin(hue_order[-1:])
            alpha = np.where(is_foreground, cat_alpha, other_alpha)
        elif palette is None:
            palette = "tab20" if num_categories > 10 else "tab10"

        return hue_column, hue_order, palette, alpha

    top_k_categories = value_counts.nlargest(19).index.tolist()
    is_top = plot_column.isin(top_k_categories)

    # object dtype so "Other" can be assigned even when the column is categorical.
    hue_column = plot_column.astype(object).where(is_top, "Other")
    hue_order = sort_categories(top_k_categories, category_index) + ["Other"]
    if palette is None:
        palette = sns.color_palette("tab20", n_colors=20)
        del palette[14]  # delete gray color
        palette.append("gray")

    alpha = np.where(is_top, cat_alpha, other_alpha)
    return hue_column, hue_order, palette, alpha


def plot_latent(sample_info_df, column, sort=False, category_index=None, x_lim=None, y_lim=None, legend=True, ax=None, other=False, title=None, palette=None, legend_title=None, set_legend_pos=False, extra_legends=None):
    """Scatter a 2-D embedding colored by a categorical or numeric variable.

    Parameters
    ----------
    sample_info_df : pd.DataFrame
        Embedding; columns ``0`` and ``1`` are used as x and y.
    column : str or pd.Series
        Column name in ``sample_info_df``, or a Series to color by. A Series whose
        index differs from ``sample_info_df.index`` is reindexed to it, so colors
        and per-point alphas line up with the plotted rows.
    sort : bool
        When ``column`` is a column name, sort the rows by it first.
    category_index : optional
        Category ordering forwarded to ``sort_categories``.
    x_lim, y_lim : pair of floats, optional
        Axis limits.
    legend : bool
        Whether to draw a legend.
    ax : matplotlib Axes, optional
        Target axes. A new figure is created when omitted.
    other : bool
        For at most 20 categories, draw the last category as a faint background group.
    title, palette, legend_title : optional
        Axes title, seaborn palette (colormap for numeric data) and legend title.
    set_legend_pos : bool
        Place the legend outside the axes on the right instead of the upper right.
    extra_legends : (handles, labels), optional
        Extra legend entries appended after the hue entries.

    Returns
    -------
    matplotlib Figure when this function created it (``ax`` is None), otherwise None.
    """
    created_fig = None
    if ax is None:
        created_fig, ax = plt.subplots(figsize=(15, 12) if legend else (12, 12))

    if sort and isinstance(column, str):
        sample_info_df = sample_info_df.sort_values(column)

    if isinstance(column, str):
        plot_column = sample_info_df[column]
    elif isinstance(column, pd.Series):
        plot_column = column if column.index.equals(sample_info_df.index) else column.reindex(sample_info_df.index)
    else:
        plot_column = pd.Series(column, index=sample_info_df.index)

    num_non_nan = plot_column.notna().sum()
    # bool is numeric to pandas but should be drawn as discrete classes.
    categorical = not pd.api.types.is_numeric_dtype(plot_column) or plot_column.dtype == bool

    # Denser plots get fainter and smaller markers.
    cat_alpha = adjust_alpha(num_non_nan, min_points=20000, max_points=500000, min_alpha=0.2, max_alpha=0.6)
    other_alpha = adjust_alpha(num_non_nan, min_points=20000, max_points=500000, min_alpha=0.01, max_alpha=0.1)
    s = adjust_alpha(num_non_nan, min_points=20000, max_points=500000, min_alpha=3, max_alpha=15)

    if categorical:
        hue_column, hue_order, palette, alpha = _categorical_hue(
            plot_column, category_index, palette, other, cat_alpha, other_alpha
        )
        hue_norm = None
    else:
        hue_column = plot_column.astype(float)
        # Clip the color scale to the 5th-95th percentile so outliers do not wash it out.
        hue_norm = (plot_column.quantile(0.05), plot_column.quantile(0.95))
        hue_order = None
        alpha = cat_alpha
        if palette is None:
            palette = sns.cubehelix_palette(start=.5, rot=-.5, as_cmap=True)

    scatterplot = sns.scatterplot(
        data=sample_info_df,
        x=0,
        y=1,
        hue=hue_column,
        hue_norm=hue_norm,
        hue_order=hue_order,
        linewidth=0,
        alpha=alpha,
        s=s / 2,
        palette=palette,
        legend=legend,
        ax=ax,
    )

    ax.set_title(title)
    ax.set_xlabel(None)
    ax.set_ylabel(None)
    ax.grid(False)
    ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False, labelbottom=False,
                   labelleft=False)

    if x_lim is not None:
        ax.set_xlim(*x_lim)
    if y_lim is not None:
        ax.set_ylim(*y_lim)

    if legend:
        if hue_norm is not None:
            handles, labels = _continuous_legend_entries(hue_column, hue_norm, palette)
        else:
            handles, labels = scatterplot.get_legend_handles_labels()

        labels = [textwrap.fill(label, 20) for label in labels]

        if extra_legends is not None:
            extra_handles, extra_labels = extra_legends
            handles = list(handles) + list(extra_handles)
            labels = labels + list(extra_labels)

        if set_legend_pos:
            ax.legend(handles=handles, labels=labels, bbox_to_anchor=(1., 0.5), loc="center left", title=legend_title)
        else:
            ax.legend(handles=handles, labels=labels, loc="upper right", title=legend_title)

    if created_fig is not None:
        created_fig.tight_layout()
        return created_fig

In [None]:
# Trained model checkpoint whose per-modality sample representations are visualised below.
model_path = 'saved_models/2024-01-23_13-47-36/checkpoint_157001'
modalities = ['disease', 'drug', 'personal', 'lab']

representations = {
    modality: pd.read_parquet(f'{model_path}/{modality}_representations.parquet')
    for modality in modalities
}

# `data` is indexed like the lab representations. Later cells build boolean masks
# from it and apply them to every modality's embedding, so each modality's samples
# must be covered by that index.
reference_index = representations['lab'].index
for modality, representation in representations.items():
    assert representation.index.isin(reference_index).all(), (
        f"{modality} representations contain samples missing from the lab representations"
    )

full_data = pd.read_parquet('data/processed/ukb.parquet')
data = full_data.loc[reference_index]

In [None]:
def get_umap(df, n_neighbors=15, min_dist=0.1, min_cluster_size=50, min_samples=50, umap_metric="cosine", n_components=2):
    """Embed the rows of ``df`` with cuML UMAP (fixed ``random_state=42``).

    ``min_cluster_size`` and ``min_samples`` are accepted for call compatibility
    but are not used. They are not passed to UMAP.

    Returns a DataFrame of the embedding, indexed like ``df``, with integer
    column labels starting at 0.
    """
    reducer = UMAP(
        n_components=n_components,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        metric=umap_metric,
        random_state=42,
    )
    embedding = reducer.fit_transform(df.to_numpy())
    return pd.DataFrame(embedding, index=df.index)

def get_pca(df, n_components=None, variance_threshold=0.95):
    """Project the rows of ``df`` onto their leading principal components (cuML PCA).

    Parameters
    ----------
    df : pd.DataFrame
        Numeric feature matrix, one row per sample.
    n_components : int, optional
        Number of components to keep. When None, keep the smallest number whose
        cumulative explained-variance ratio reaches ``variance_threshold``.
    variance_threshold : float
        Cumulative explained-variance target used when ``n_components`` is None.

    Returns
    -------
    pd.DataFrame
        Projected data indexed like ``df``, one column per kept component.
    """
    hidden_states = df.to_numpy()

    if n_components is None:
        # A fit with the maximum number of components exposes the full variance spectrum.
        max_components = min(hidden_states.shape)
        full_pca = PCA(n_components=max_components, random_state=42).fit(hidden_states)
        reached = np.cumsum(full_pca.explained_variance_ratio_) >= variance_threshold
        # argmax of an all-False mask is 0, which would silently keep a single component.
        # If rounding keeps the sum just under the threshold, keep everything instead.
        n_components = int(np.argmax(reached)) + 1 if reached.any() else max_components

    pca = PCA(n_components=n_components, random_state=42)
    pca_data = pca.fit_transform(hidden_states)

    return pd.DataFrame(pca_data, index=df.index)

# Per-modality PCA, keeping enough components to reach 95% explained variance (get_pca default).
pca_dfs = {}
for modality, representation in representations.items():
    pca_dfs[modality] = get_pca(representation)

for mod, pca_df in pca_dfs.items():
    print('PCA > .95 shape of', mod, pca_df.shape)

# NOTE(review): UMAP is fitted on the raw representations, not on `pca_dfs`;
# the PCA results above are only reported. Confirm this is intended.
umap_dfs = {}
for modality, representation in representations.items():
    umap_dfs[modality] = get_umap(representation)

In [None]:
# Per-patient burden counts.
# NOTE(review): polypharmacy counts *distinct* drug codes, while multimorbidity counts
# every disease-code entry (no de-duplication). Confirm the asymmetry is intended.
polypharmacy = data['drug_codes'].map(lambda codes: len(set(codes)))
multimorbidity = data['disease_codes'].map(len)

In [None]:
# Disease-modality UMAP coloured by multimorbidity (left) and by polypharmacy (right).
lim = 15
axis_limits = {'x_lim': [-lim, lim], 'y_lim': [-lim, lim]}
f, ax = plt.subplots(1, 2, figsize=(12, 6))

disease_umap = umap_dfs['disease']
plot_latent(disease_umap, multimorbidity, ax=ax[0], legend_title='Multimorbidity', **axis_limits)

# Patients without any prescription form a faint gray background layer on the right panel.
missing_drugs = polypharmacy == 0

sns.scatterplot(
    data=disease_umap[missing_drugs],
    x=0,
    y=1,
    color='gray',
    linewidth=0,
    alpha=0.003,
    s=3,
    legend=False,
    ax=ax[1],
)

missing_handle = mlines.Line2D([], [], color='gray', marker='o', linestyle='None',
                               markersize=5, label='Missing')

with_drugs = ~missing_drugs
plot_latent(
    disease_umap[with_drugs],
    polypharmacy[with_drugs],
    ax=ax[1],
    legend_title='Polypharmacy',
    extra_legends=[[missing_handle], ['Missing']],
    **axis_limits,
)

In [None]:
# Polypharmacy over each modality's UMAP. Patients without prescriptions form a gray
# background. Panels are titled by modality key so the four views can be told apart.
f, axs = plt.subplots(2, 2, figsize=(16, 15))
lim = 15

# Loop-invariant legend proxy for the gray background layer.
missing_handle = mlines.Line2D([], [], color='gray', marker='o', linestyle='None',
                               markersize=5, label='Missing')

for (mod, df), ax in zip(umap_dfs.items(), axs.flatten()):
    sns.scatterplot(
        data=df[missing_drugs],
        x=0,
        y=1,
        color='gray',
        linewidth=0,
        alpha=0.003,
        s=3,
        legend=False,
        ax=ax,
    )

    plot_latent(df[~missing_drugs], polypharmacy[~missing_drugs], ax=ax, legend_title='Polypharmacy',
                extra_legends=[[missing_handle], ['Missing']], x_lim=[-lim, lim], y_lim=[-lim, lim], title=mod)

f.tight_layout()

In [None]:
# Polypharmacy per multimorbidity. Patients with no disease codes get NaN instead of
# the inf (or 0/0 NaN) a raw division by zero would produce.
PP_per_MM = polypharmacy / multimorbidity.where(multimorbidity != 0)

In [None]:
# Patients with at least one prescription and at least one disease code, i.e. where PP/MM is finite and non-zero.
relative_mask = polypharmacy.ne(0) & multimorbidity.ne(0)

In [None]:
# PP/MM over every modality's UMAP. Patients outside `relative_mask` form a gray "NA"
# background; only the 'drug' panel carries a legend.
titles = {
    'disease': 'Disease',
    'drug': 'Prescriptions',
    'personal': 'Personal',
    'lab': 'Laboratory',
}
lim = 15
axis_limits = dict(x_lim=[-lim, lim], y_lim=[-lim, lim])

f, axs = plt.subplots(2, 2, figsize=(8, 9))
axs = axs.flatten()

for mod, ax in zip(umap_dfs, axs):
    df = umap_dfs[mod]

    sns.scatterplot(
        data=df[~relative_mask],
        x=0,
        y=1,
        color='gray',
        linewidth=0,
        alpha=0.003,
        s=3,
        legend=False,
        ax=ax,
    )

    na_handle = mlines.Line2D([], [], color='gray', marker='o', linestyle='None',
                              markersize=5, label='NA')
    plot_latent(
        df[relative_mask],
        PP_per_MM[relative_mask],
        ax=ax,
        legend_title='PP/MM',
        extra_legends=[[na_handle], ['NA']],
        legend=(mod == 'drug'),
        title=titles[mod],
        **axis_limits,
    )

f.tight_layout()
f.savefig('pppermm_latent_all_mod.jpg', dpi=450)

In [None]:
# NOTE(review): this cell repeats the preceding PP/MM cell verbatim (same panels and
# the same output file 'pppermm_latent_all_mod.jpg'); one of the two can be deleted.
lim = 15
titles = dict(disease='Disease', drug='Prescriptions', personal='Personal', lab='Laboratory')

f, axs = plt.subplots(2, 2, figsize=(8, 9))
axs = axs.flatten()

for i, (mod, df) in enumerate(umap_dfs.items()):
    ax = axs[i]
    show_legend = mod == 'drug'

    sns.scatterplot(data=df[~relative_mask], x=0, y=1, color='gray', linewidth=0,
                    alpha=0.003, s=3, legend=False, ax=ax)

    na_handle = mlines.Line2D([], [], color='gray', marker='o', linestyle='None',
                              markersize=5, label='NA')
    plot_latent(df[relative_mask], PP_per_MM[relative_mask], ax=ax, legend_title='PP/MM',
                extra_legends=[[na_handle], ['NA']], x_lim=[-lim, lim], y_lim=[-lim, lim],
                legend=show_legend, title=titles[mod])

f.tight_layout()
f.savefig('pppermm_latent_all_mod.jpg', dpi=450)

In [None]:
# Polypharmacy over every modality's UMAP. Patients without prescriptions form a gray
# "Missing" background; only the 'drug' panel carries a legend.
# NOTE(review): the background alpha here is 0.03, ten times the 0.003 used in the
# other background layers. Confirm that is intended.
titles = {
    'disease': 'Fused Disease',
    'drug': 'Drug Prescriptions',
    'personal': 'Personal',
    'lab': 'Laboratory',
}
lim = 15
axis_limits = dict(x_lim=[-lim, lim], y_lim=[-lim, lim])

f, axs = plt.subplots(2, 2, figsize=(8, 9))
axs = axs.flatten()

for mod, ax in zip(umap_dfs, axs):
    df = umap_dfs[mod]
    with_drugs = ~missing_drugs

    sns.scatterplot(
        data=df[missing_drugs],
        x=0,
        y=1,
        color='gray',
        linewidth=0,
        alpha=0.03,
        s=3,
        legend=False,
        ax=ax,
    )

    missing_handle = mlines.Line2D([], [], color='gray', marker='o', linestyle='None',
                                   markersize=5, label='Missing')
    plot_latent(
        df[with_drugs],
        polypharmacy[with_drugs],
        ax=ax,
        legend_title='Polypharmacy',
        extra_legends=[[missing_handle], ['Missing']],
        legend=(mod == 'drug'),
        title=titles[mod],
        **axis_limits,
    )

f.tight_layout()
f.savefig('pp_latent_all_mod.jpg', dpi=450)

In [None]:
# Multimorbidity over every modality's UMAP; only the 'drug' panel carries a legend.
titles = {
    'disease': 'Fused Disease',
    'drug': 'Drug Prescriptions',
    'personal': 'Personal',
    'lab': 'Laboratory',
}
lim = 15

f, axs = plt.subplots(2, 2, figsize=(8, 9))
axs = axs.flatten()

for ax, (mod, df) in zip(axs, umap_dfs.items()):
    plot_latent(
        df,
        multimorbidity,
        ax=ax,
        legend_title='Multimorbidity',
        x_lim=[-lim, lim],
        y_lim=[-lim, lim],
        title=titles[mod],
        legend=(mod == 'drug'),
    )

f.tight_layout()
f.savefig('mm_latent_all_mod.jpg', dpi=450)

In [None]:
# Sex (column '31-0.0') over every modality's UMAP; only the 'drug' panel carries a legend.
f, axs = plt.subplots(2, 2, figsize=(8, 9))
axs = axs.flatten()
lim = 15

titles = {
    'disease' : 'Fused Disease',
    'drug' : 'Drug Prescriptions',
    'personal': 'Personal',
    'lab': 'Laboratory',
}

# plot_latent treats numeric columns as continuous (quantile-normalized colormap), so
# numeric sex codes would get a misleading colour bar. Casting to category forces a
# discrete per-class legend; string-valued data is unaffected.
sex = data['31-0.0'].astype('category')

for i, mod in enumerate(umap_dfs):
    ax = axs[i]
    df = umap_dfs[mod]
    plot_latent(df, sex, ax=ax, legend_title='Sex', x_lim=[-lim, lim], y_lim=[-lim, lim], title=titles[mod], legend=mod=='drug')

f.tight_layout()
plt.savefig('sex_latent_all_mod.jpg', dpi=450)

In [None]:
# Pretrained tokenizer; its ICD handler is used below to map disease codes to ICD chapters.
tokenizer_path = 'pretrained_tokenizers/tokenizer'
tokenizer = Tokenizer.from_pretrained(tokenizer_path)

In [None]:
def dominant_chapter(disease_codes):
    """Most frequent ICD-chapter token among a patient's disease codes, or 'NA' when there are none."""
    if len(disease_codes) == 0:
        return 'NA'
    chapter_tokens = tokenizer.icd_handler.finest_level_to_level_token(disease_codes, ICDLevel.Chapter)
    return pd.Series(chapter_tokens).value_counts().index[0]


chapter = data['disease_codes'].apply(dominant_chapter)

In [None]:
# Dominant ICD chapter over every modality's UMAP; only the 'drug' panel carries a
# legend, placed outside the axes on the right.
titles = {
    'disease': 'Fused Disease',
    'drug': 'Drug Prescriptions',
    'personal': 'Personal',
    'lab': 'Laboratory',
}
lim = 15

f, axs = plt.subplots(2, 2, figsize=(8, 9))
axs = axs.flatten()

for ax, (mod, df) in zip(axs, umap_dfs.items()):
    plot_latent(
        df,
        chapter,
        ax=ax,
        legend_title='ICD Chapter',
        x_lim=[-lim, lim],
        y_lim=[-lim, lim],
        title=titles[mod],
        legend=(mod == 'drug'),
        set_legend_pos=True,
    )

f.tight_layout()
f.savefig('disease_latent_all_mod.jpg', dpi=450)