In [None]:
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from matplotlib.colors import TwoSlopeNorm
from scipy.stats import gaussian_kde

# Visualize pathway features

In [None]:
# Helper functions
def get_results_data(pkl_path):
    """Load a pickled results object from disk.

    Args:
        pkl_path: Path to the pickle file.

    Returns:
        The unpickled object. Callers in this notebook read the ``'risk'`` and
        ``'slide_ids'`` entries from it.
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
    # The context manager closes the handle even if unpickling fails.
    with open(pkl_path, 'rb') as f:
        return pickle.load(f)

def get_data_pathways(pathways, exp_name, data_type, fold):
    """Collect per-sample mean pathway expression and predicted risk (group level).

    For every test slide of ``fold`` and every requested pathway, the mean RNA
    expression over the pathway's genes is computed and paired with the slide's
    predicted risk score.

    Args:
        pathways: Column indices into ``hallmarks_signatures.csv`` selecting the pathways.
        exp_name: Experiment name (sub-directory of the results folder).
        data_type: TCGA cohort suffix used in the data/results paths, e.g. ``'brca'``.
        fold: Fold number used in the split and results paths.

    Returns:
        Long-format DataFrame with one row per (sample, pathway) and the columns
        ``"Sample"``, ``"Pathway"``, ``"Mean expression"`` and ``"risk"``.

    Raises:
        ValueError: If a test slide has no entry in the saved risk scores.
    """
    rna_data = pd.read_csv(f'../data/data_files/tcga_{data_type}/rna/rna_data.csv', index_col=0)
    hallmark_pathways = pd.read_csv('../data/data_files/hallmarks_signatures.csv')
    test_data = pd.read_csv(f'../data/data_files/tcga_{data_type}/splits/{fold}/test.csv', index_col=1)
    test_cases = test_data['slide_id'].values

    # Get predicted risks
    results_dir = f'../results/dss_survival_{data_type}/{exp_name}/Fold_{fold}/predicted_risk_scores_test.pkl'
    results_dict = get_results_data(results_dir)
    risk = results_dict['risk']
    slide_ids = results_dict['slide_ids']

    # Pathway label and member genes do not depend on the sample: resolve them once.
    pathway_info = []
    for pathway in pathways:
        column = hallmark_pathways.columns[pathway]
        pathway_label = f"{column.replace('_', ' ')} R({pathway})"
        genes = rna_data.columns.intersection(hallmark_pathways[column].values)
        pathway_info.append((pathway_label, genes))

    plot_data = []
    for sample in test_cases:
        # Predicted risk for this slide.
        try:
            case_id_index = slide_ids.index(sample)
        except ValueError:
            raise ValueError(f"Slide {sample} not found in results.") from None
        risk_value_case = risk[case_id_index]

        # Filter the RNA table once per sample instead of once per pathway.
        rna_data_case = rna_data[rna_data['Unnamed: 0'] == sample]

        for pathway_label, genes in pathway_info:
            # Mean over all selected cells (rows x genes); presumably a single
            # RNA row per sample — TODO confirm.
            mean_expression = np.mean(rna_data_case[genes].values)

            plot_data.append({
                "Sample": sample,
                "Pathway": pathway_label,
                "Mean expression": mean_expression,
                "risk": risk_value_case,
            })

    return pd.DataFrame(plot_data)

def get_info_case(case, pathway_indices, exp_name, data_type, fold):
    """Assemble ridge-plot data for a single case (sample level).

    Reads the RNA table, the hallmark signatures and the training split of
    ``fold``. It then looks up the predicted risk of ``case`` in the fold's saved
    test predictions and hands everything to ``get_data_ridge``.

    Returns:
        The value produced by ``get_data_ridge`` for ``pathway_indices``.
    """
    rna_table = pd.read_csv(f'../data/data_files/tcga_{data_type}/rna/rna_data.csv', index_col=0)
    signatures = pd.read_csv('../data/data_files/hallmarks_signatures.csv')
    train_ids = pd.read_csv(
        f'../data/data_files/tcga_{data_type}/splits/{fold}/train.csv', index_col=1
    )['slide_id'].values

    # Predicted risk of this case, taken from the fold's test-set predictions.
    predictions = get_results_data(
        f'../results/dss_survival_{data_type}/{exp_name}/Fold_{fold}/predicted_risk_scores_test.pkl'
    )
    risk_scores = predictions['risk']
    ids = predictions['slide_ids']
    position = ids.index(case)
    assert ids[position] == case, "Correct case ID not found in results."

    return get_data_ridge(pathway_indices, case, risk_scores[position], signatures, rna_table, train_ids)

def get_data_ridge(pathway_nr, case, risk_case, all_pathways, all_rna_data, tr_cases):
    """Collect one case's per-gene expression and the train-set mean, per pathway.

    Args:
        pathway_nr: Column indices into ``all_pathways`` selecting the pathways.
        case: Identifier matched against the ``'Unnamed: 0'`` column of ``all_rna_data``.
        risk_case: Predicted risk score of ``case``; stored unchanged in every entry.
        all_pathways: DataFrame whose columns are pathway names and whose cells are gene names.
        all_rna_data: RNA table with an ``'Unnamed: 0'`` sample-id column and gene columns.
        tr_cases: Identifiers of the training cases used for the reference mean.

    Returns:
        Dict keyed ``"Pathway 1"``, ``"Pathway 2"``, ... in ``pathway_nr`` order. Each
        value holds:
            - ``'name'``: pathway name without its first word
            - ``'risk'``: ``risk_case``
            - ``'genes'``: the pathway genes present in the RNA table
            - ``'values'``: the case's expression for those genes
            - ``'mean_values'``: per-gene mean over the training cases
        Returns ``None`` (after printing a warning) if the two gene vectors of a
        pathway differ in length.

    Raises:
        ValueError: If ``case`` has no row in ``all_rna_data``.
    """
    data = {}
    rna_data_case = all_rna_data[all_rna_data['Unnamed: 0'] == case]
    if rna_data_case.empty:
        raise ValueError(f"Case {case} not found in RNA data.")
    rna_data_train = all_rna_data[all_rna_data['Unnamed: 0'].isin(tr_cases)]

    for j, i in enumerate(pathway_nr):
        # Pathway name, e.g. 'HALLMARK_X_Y' -> 'HALLMARK X Y'
        pathway_name = all_pathways.columns[i].replace('_', ' ')
        genes = all_pathways[all_pathways.columns[i]].values

        # Gene expression of this pathway for this case (first matching row)
        case_ex_genes = rna_data_case.columns.intersection(genes)
        case_expression_data = rna_data_case[case_ex_genes].values[0]

        # Per-gene mean expression of this pathway over all train cases
        all_ex_genes = rna_data_train.columns.intersection(genes)
        mean_expression = np.mean(rna_data_train[all_ex_genes].values, axis=0)

        # Both vectors must cover the same genes. Use value equality (!=):
        # `is not` compares int identity, which is unreliable for larger counts.
        if len(mean_expression) != len(case_expression_data):
            print(f"Warning: Length mismatch for pathway {pathway_name}. Expected {len(mean_expression)}, got {len(case_expression_data)}.")
            return

        data[f"Pathway {j+1}"] = {
            # Drop the first word of the name (the prefix shared by hallmark pathways)
            'name': " ".join(pathway_name.split()[1:]),
            'risk': risk_case,
            'genes': np.array(case_ex_genes).flatten(),
            'values': np.array(case_expression_data),
            'mean_values': np.array(mean_expression),
        }

    return data

In [None]:
# Plot functions
def pathway_swarm_plot(plot_data, cmap = "coolwarm", save_fig_name=None, risk_col="risk"):
    """Swarm plot of mean pathway expression per sample, coloured by log2 predicted risk (group level).

    Args:
        plot_data: Long-format DataFrame as returned by ``get_data_pathways``, with
            columns ``"Pathway"``, ``"Mean expression"`` and ``risk_col``. It is
            not modified.
        cmap: Name of a Matplotlib colormap (a diverging map suits the zero-centred norm).
        save_fig_name: If given, the figure is also saved as ``{save_fig_name}.pdf``.
        risk_col: Column with the predicted risk scores; ``get_data_pathways``
            writes them to ``"risk"``.
    """
    # Colour by log2 of the predicted risk (work on a copy; the caller's frame is untouched).
    # NOTE(review): non-positive risks would give -inf/NaN here — assumes risk > 0, confirm.
    plot_data = plot_data.assign(log2risk=np.log2(plot_data[risk_col]))

    # Symmetric limits so the colour scale is centred at log2(risk) = 0.
    max_val = np.max(np.abs(plot_data["log2risk"]))
    norm = TwoSlopeNorm(vmin=-max_val, vcenter=0, vmax=max_val)
    cmap = plt.get_cmap(cmap)

    fig, ax = plt.subplots(figsize=(10, 8))

    # Create the swarm plot
    sns.swarmplot(
        data=plot_data,
        x="Mean expression",
        y="Pathway",
        hue="log2risk",
        hue_norm=norm,
        palette=cmap,
        legend=False,
        size=4.8,  # default is 5
        ax=ax,
    )

    # legend=False above, so add a colorbar manually
    sm = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
    cbar = fig.colorbar(sm, ax=ax, orientation='horizontal', shrink=0.3, aspect=10, pad=0.1)
    cbar.set_label("log2risk", fontsize=8)

    # Plot params
    for spine in ["top", "right", "left"]:
        ax.spines[spine].set_visible(False)
    ax.spines["bottom"].set_visible(True)
    ax.tick_params(axis="x", bottom=True)
    ax.tick_params(axis="y", left=True)
    ax.set_ylabel("")
    ax.set_xlabel("Mean pathway expression")

    if save_fig_name:
        plt.savefig(f"{save_fig_name}.pdf", dpi=300, bbox_inches='tight')

    plt.show()

def plot_ridge_pathways(pathway_data, pathway_indices, plot_name=None, spacing=0.7, ylim=(0, 8), xlim=(-7, 12)):
    """Ridge plot comparing one case's gene expression with the train-set mean, per pathway.

    For each pathway a Gaussian KDE of the case's per-gene expression (green) is
    drawn over a KDE of the per-gene train-set mean expression (grey). Ridges are
    stacked vertically, ``spacing`` apart.

    Args:
        pathway_data: Dict as returned by ``get_data_ridge``; each entry needs
            ``'name'``, ``'values'`` and ``'mean_values'``.
        pathway_indices: Pathway indices shown in the y-tick labels, in the same
            order as ``pathway_data``.
        plot_name: If given, the figure is also saved as ``{plot_name}_ridge_plot.pdf``.
        spacing: Vertical offset between consecutive ridges.
        ylim: ``(lower, upper)`` y-axis limits.
        xlim: ``(lower, upper)`` x-axis limits; also the extent of each ridge's baseline.

    Raises:
        ValueError: If ``pathway_data`` is ``None``.
    """
    if pathway_data is None:
        raise ValueError("No pathway data to plot (get_data_ridge returned None).")

    fig, ax = plt.subplots(figsize=(8, 6))

    pathways = list(pathway_data.values())

    # Draw from the top ridge down so lower ridges are painted over higher ones.
    for i, pathway in reversed(list(enumerate(pathways))):
        values = np.array(pathway['values'])
        mean_values = np.array(pathway['mean_values'])

        # KDE of the case's expression and of the train-set mean expression
        kde = gaussian_kde(values)
        x_range = np.linspace(values.min(), values.max(), 500)
        y = kde(x_range)

        kde_mean = gaussian_kde(mean_values)
        x_range_mean = np.linspace(mean_values.min(), mean_values.max(), 500)
        y_mean = kde_mean(x_range_mean)

        # Vertical offset for stacking
        offset = i * spacing

        # Full-width horizontal baseline spanning the visible x-range
        ax.hlines(offset, xlim[0], xlim[1], color='black', lw=1, zorder=0)

        # Mean ridge
        ax.fill_between(x_range_mean, offset, y_mean + offset, color='grey', alpha=0.4)
        ax.plot(x_range_mean, y_mean + offset, color='grey', lw=1)

        # Case ridge
        ax.fill_between(x_range, offset, y + offset, color='green', alpha=0.4)
        ax.plot(x_range, y + offset, color='green', lw=1)

    # Y-ticks
    ax.set_yticks([i * spacing for i in range(len(pathways))])
    labels = [f'{p["name"]} (R{pathway_indices[i]})' for i, p in enumerate(pathways)]

    ax.set_yticklabels(labels)
    ax.set_xlabel("Gene expression")

    ax.set_ylim(ylim[0], ylim[1])
    ax.set_xlim(xlim[0], xlim[1])

    # Remove spines for cleaner style
    for spine in ['top', 'right', 'left']:
        ax.spines[spine].set_visible(False)

    plt.tight_layout()
    # save in pdf format
    if plot_name:
        plt.savefig(f"{plot_name}_ridge_plot.pdf", dpi=300, bbox_inches='tight')

    plt.show()


### **(1)** Group level: mean pathway expression of every test sample per pathway, coloured by the predicted risk score.

In [None]:
# Group level: hallmark pathway column indices plus the experiment / fold to visualize.
pathways = [12, 29, 45, 13]
fold, exp_name, data_type = 2, 'DIMAF', 'brca'

plot_data_group = get_data_pathways(pathways=pathways, exp_name=exp_name, data_type=data_type, fold=fold)
pathway_swarm_plot(plot_data=plot_data_group)

### **(2)** Sample level: gene expression distribution per pathway for one sample, compared with the mean expression over the training samples, together with its predicted risk score.

In [None]:
# Sample level: one case, the hallmark pathway column indices, and the experiment / fold.
case = "TCGA-BH-A1EV"
pathways_case = [12, 29, 45, 13]
fold, exp_name, data_type = 2, 'DIMAF', 'brca'

data = get_info_case(case=case, pathway_indices=pathways_case, exp_name=exp_name, data_type=data_type, fold=fold)
plot_ridge_pathways(pathway_data=data, pathway_indices=pathways_case)
