In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns

In [None]:
# Make the project sources importable from this notebook.
src_path: str = "../../src"
# Guard the append so re-running this cell does not stack duplicate entries
# on sys.path.
if src_path not in sys.path:
    sys.path.append(src_path)

In [None]:
# Project-local import. NOTE(review): presumably resolved through the
# `src_path` entry appended to `sys.path` in the previous cell, which would be
# why it is not part of the top import cell — confirm.
from components.functional_analysis.orgdb import OrgDB

# NOTE(review): `org_db` is not referenced by any of the figure cells visible
# here; confirm it is still needed further down before keeping this cell.
org_db = OrgDB("Homo sapiens")

In [None]:
# Root directory holding the analysis results read by the figure cells below.
tcga_su2c_root = Path("/share/computing/Perez/projects_backups/TCGA_PRAD_SU2C_RNASeq")

# Output directory for the exported figures; created on first run.
figures_path = (
    Path("/media/ssd/Perez/storage") / "TCGA_PRAD_SU2C_RNASeq_manuscript_figures"
)
figures_path.mkdir(parents=True, exist_ok=True)

## Figure 1: Sample/dataset distribution

---


In [None]:
# Load the sample annotation table and reduce it to the three sample types
# shown in Figure 1.
annot_path = (
    Path("/media/ssd/Perez/storage/TCGA_PRAD_SU2C_RNASeq")
    / "data"
    / "samples_annotation_tcga_prad_su2c_clusters_clinical_trt_fixed.csv"
)
annot_raw = pd.read_csv(annot_path, index_col=0)

# Built as one non-mutating chain: the original filtered the frame with a
# boolean mask and then assigned columns into the filtered result, which is
# the pattern that triggers pandas' SettingWithCopyWarning. Keeping
# `annot_raw` untouched also makes this cell safe to re-run.
annot_df = (
    annot_raw.assign(
        # Normalise the sample-type labels to lower case before filtering.
        sample_cluster_no_replicates=lambda d: d[
            "sample_cluster_no_replicates"
        ].str.lower()
    )
    .loc[
        lambda d: d["sample_cluster_no_replicates"].isin(("norm", "prim", "met_bb"))
    ]
    .assign(
        # Rows without a project id are labelled "SU2C-PCF".
        project_id=lambda d: d["project_id"].fillna("SU2C-PCF"),
        # Shorten the "met_bb" label to "met" for plotting.
        sample_cluster_no_replicates=lambda d: d[
            "sample_cluster_no_replicates"
        ].replace("met_bb", "met"),
    )
)

In [None]:
# Count non-null patient ids per (dataset, sample type) pair ...
counts_long = annot_df.groupby(["project_id", "sample_cluster_no_replicates"])[
    "patient_id"
].count()

# ... then pivot the sample type into columns, with datasets as rows in
# descending index order.
df = counts_long.unstack(level=1).sort_index(ascending=False)
df

In [None]:
# Figure 1: stacked bar chart with one bar per dataset and one coloured
# segment per sample type.
fig, ax = plt.subplots(1, 1, figsize=(6, 6), dpi=100)
palette = {
    "norm": "#9ACD32",
    "prim": "#4A708B",
    "met": "#8B3A3A",
}
ax = df.plot(kind="bar", stacked=True, color=palette, ax=ax, width=0.8)

# Vertical distance (in data units) between a segment's top edge and its label.
label_drop = 30
for p in ax.patches:
    # Skip zero-height patches so no "0" label is drawn.
    if p.get_height() != 0:
        ax.annotate(
            str(int(p.get_height())),
            (
                # Horizontal centre of the bar, derived from the patch itself
                # instead of a hard-coded half width.
                p.get_x() + p.get_width() / 2,
                # A stacked segment starts at get_y(), not at 0, so its top
                # edge is get_y() + get_height(). Using the height alone
                # positions the label relative to the wrong baseline for
                # every segment that sits on top of another one.
                p.get_y() + p.get_height() - label_drop,
            ),
            weight="bold",
            fontsize=12,
            ha="center",
        )

ax.legend(title="Sample type")
ax.set_ylabel("Number of samples")
ax.set_xlabel("Dataset of origin")
ax.tick_params(axis="x", labelrotation=0)
ax.set_title("Dataset sample type distribution")
ax.set_axisbelow(True)
ax.grid(axis="y", zorder=-10)
fig.tight_layout()
for ext in ("pdf", "svg"):
    fig.savefig(figures_path.joinpath(f"1_sample_dist.{ext}"))

## Figure 6/7: WGCNA Network

---


### A: prim/norm

In [None]:
# Hub gene table of the prim/norm WGCNA run.
hub_genes_csv = (
    tcga_su2c_root
    / "wgcna"
    / "sample_cluster_no_replicates_MET_BB+NORM+PRIM__PRIM_vs_NORM_padj_0_05_up_1_0"
    / "standard"
    / "results"
    / "bicor_signed_hub_genes.csv"
)
hub_genes_df = pd.read_csv(hub_genes_csv, index_col=0)

# Map each index label of the table to its SYMBOL value, then discard the
# "M0" entry (raises KeyError if it is absent).
# NOTE(review): "M0" is presumably the WGCNA bin for unassigned genes — confirm.
mod_hub_genes = hub_genes_df["SYMBOL"].to_dict()
mod_hub_genes.pop("M0")
mod_hub_genes

In [None]:
# Directory of the prim/norm network results.
network_dir = (
    tcga_su2c_root
    / "rich_network"
    / "sample_cluster_no_replicates_MET_BB+NORM+PRIM__PRIM_vs_NORM_padj_0_05_up_1_0"
    / "standard"
    / "random_forest"
)
network_file = network_dir / (
    "sample_cluster_no_replicates_MET_BB+NORM+PRIM__"
    "PRIM_vs_NORM_padj_0_05_up_1_0_corr_th_0_1_full_wgcna_ml_net.graphml"
)

# Load the GraphML network; node attributes are tabulated in the next cell.
graph = nx.read_graphml(network_file)

In [None]:
# One row per graph node, one column per node attribute, ordered by module.
node_attrs = dict(graph.nodes)
graph_nodes_df = (
    pd.DataFrame(node_attrs)
    .T.sort_values("module")
    # The transposed frame holds generic objects; cast the two numeric
    # attributes to float.
    .astype({"shap_value": float, "log2FoldChange": float})
)
graph_nodes_df.head()

In [None]:
# Figure 6 (prim/norm): SHAP value distribution per WGCNA module, with the
# top genes and the module hub genes labelled.
fig, ax = plt.subplots(1, 1, figsize=(10, 10), dpi=300)

# Both layers share the same categorical axis: raw points plus a box summary.
layer_kwargs = dict(
    data=graph_nodes_df,
    x="module",
    y="shap_value",
    hue="module",
    dodge=False,
    ax=ax,
)
ax = sns.stripplot(jitter=False, **layer_kwargs)
ax = sns.boxplot(**layer_kwargs)

# Module sizes in sorted module order. Every x position below (gene labels,
# hub labels, tick labels) is derived from this single ordering so they cannot
# drift apart. The original built the tick labels from `value_counts()`, which
# orders by frequency, so labels could end up under the wrong module.
# NOTE(review): assumes seaborn lays the categories out in this same sorted
# order (true when `graph_nodes_df` is sorted by "module") — confirm.
module_sizes = graph_nodes_df.groupby("module").size()
module_pos = {module: pos for pos, module in enumerate(module_sizes.index)}

# Annotate top genes by SHAP value.
n_top = 5  # at most this many labels per module
shap_cutoff = 0.2  # only genes above this SHAP value are labelled
label_gap = 0.03  # minimum vertical spacing between stacked labels
for module, mod_df in graph_nodes_df.groupby("module"):
    x_pos = module_pos[module]
    # head() tolerates modules with fewer than n_top genes, where positional
    # indexing up to n_top would raise IndexError. Sorting a Series also
    # avoids an in-place sort on the groupby sub-frame.
    top_shap = mod_df["shap_value"].sort_values(ascending=False).head(n_top)
    # Start above any possible label so the first gene takes its own height.
    text_y = float("inf")
    for gene, shap in top_shap[top_shap > shap_cutoff].items():
        wanted_y = shap + label_gap
        # Use the gene's own height if it clears the previous label by more
        # than label_gap; otherwise step down from the previous label.
        text_y = wanted_y if text_y - wanted_y > label_gap else text_y - label_gap
        ax.annotate(
            gene,
            xy=(x_pos, shap),
            xytext=(x_pos + 0.2, text_y),
            arrowprops=dict(facecolor="black", arrowstyle="->"),
            fontsize=10,
        )

# Annotate module hub genes. The x position is looked up from the module the
# gene is plotted in, not from the enumeration order of `mod_hub_genes`, so
# the arrow lands on the gene's point even if the two orderings differ.
for hub_gene in mod_hub_genes.values():
    x_pos = module_pos[graph_nodes_df.loc[hub_gene, "module"]]
    ax.annotate(
        hub_gene,
        xy=(x_pos, graph_nodes_df.loc[hub_gene, "shap_value"]),
        xytext=(x_pos - 0.1, -0.05),
        ha="left",
        arrowprops=dict(facecolor="black", arrowstyle="->", connectionstyle="Angle3"),
        fontweight="bold",
        fontsize=12,
    )

ax.set_ylabel("SHAP value", fontsize=14)
ax.set_xlabel("WGCNA Module", fontsize=14)
ax.set_xticks(np.arange(len(module_sizes)))
ax.set_xticklabels(
    [f"{module}\n(n={size})" for module, size in module_sizes.items()],
    fontsize=12,
)
ax.tick_params(axis="y", labelsize=12)
# "Shapley", not "shapely": SHAP values are named after Shapley values.
ax.set_title(
    "Distribution of Shapley values\nWGCNA prim/norm (up-regulated DEGs)", fontsize=16
)
ax.set_axisbelow(True)
ax.grid(axis="y", zorder=-10)
ax.set_ylim(-0.1, 1.1)
ax.set_xlim(-1, len(module_sizes))
# Replace the hue legend with an empty, frameless one.
ax.legend([], [], frameon=False)
fig.tight_layout()
for ext in ("pdf", "svg"):
    fig.savefig(figures_path.joinpath(f"6_prim_norm_wgcna.{ext}"))

### B: met/prim

In [None]:
# Hub gene table of the met/prim WGCNA run.
hub_genes_csv = (
    tcga_su2c_root
    / "wgcna"
    / "sample_cluster_no_replicates_MET_BB+NORM+PRIM__MET_BB_vs_PRIM_padj_0_05_up_1_0"
    / "standard"
    / "results"
    / "bicor_signed_hub_genes.csv"
)
hub_genes_df = pd.read_csv(hub_genes_csv, index_col=0)

# Map each index label of the table to its SYMBOL value, then discard the
# "M0" entry (raises KeyError if it is absent).
# NOTE(review): "M0" is presumably the WGCNA bin for unassigned genes — confirm.
mod_hub_genes = hub_genes_df["SYMBOL"].to_dict()
mod_hub_genes.pop("M0")
mod_hub_genes

In [None]:
# Directory of the met/prim network results.
network_dir = (
    tcga_su2c_root
    / "rich_network"
    / "sample_cluster_no_replicates_MET_BB+NORM+PRIM__MET_BB_vs_PRIM_padj_0_05_up_1_0"
    / "standard"
    / "random_forest"
)
network_file = network_dir / (
    "sample_cluster_no_replicates_MET_BB+NORM+PRIM__"
    "MET_BB_vs_PRIM_padj_0_05_up_1_0_corr_th_0_1_full_wgcna_ml_net.graphml"
)

# Load the GraphML network; node attributes are tabulated in the next cell.
graph = nx.read_graphml(network_file)

In [None]:
# One row per graph node, one column per node attribute, ordered by module.
node_attrs = dict(graph.nodes)
graph_nodes_df = (
    pd.DataFrame(node_attrs)
    .T.sort_values("module")
    # The transposed frame holds generic objects; cast the two numeric
    # attributes to float.
    .astype({"shap_value": float, "log2FoldChange": float})
)
graph_nodes_df.head()

In [None]:
# Figure 6 (met/prim): SHAP value distribution per WGCNA module, with the
# top genes and the module hub genes labelled.
fig, ax = plt.subplots(1, 1, figsize=(10, 10), dpi=300)

# Both layers share the same categorical axis: raw points plus a box summary.
layer_kwargs = dict(
    data=graph_nodes_df,
    x="module",
    y="shap_value",
    hue="module",
    dodge=False,
    ax=ax,
)
ax = sns.stripplot(jitter=False, **layer_kwargs)
ax = sns.boxplot(**layer_kwargs)

# Module sizes in sorted module order. Every x position below (gene labels,
# hub labels, tick labels) is derived from this single ordering so they cannot
# drift apart. The original built the tick labels from `value_counts()`, which
# orders by frequency, so labels could end up under the wrong module.
# NOTE(review): assumes seaborn lays the categories out in this same sorted
# order (true when `graph_nodes_df` is sorted by "module") — confirm.
module_sizes = graph_nodes_df.groupby("module").size()
module_pos = {module: pos for pos, module in enumerate(module_sizes.index)}

# Annotate top genes by SHAP value.
n_top = 5  # at most this many labels per module
shap_cutoff = 0.2  # only genes above this SHAP value are labelled
label_gap = 0.03  # minimum vertical spacing between stacked labels
for module, mod_df in graph_nodes_df.groupby("module"):
    x_pos = module_pos[module]
    # head() tolerates modules with fewer than n_top genes, where positional
    # indexing up to n_top would raise IndexError. Sorting a Series also
    # avoids an in-place sort on the groupby sub-frame.
    top_shap = mod_df["shap_value"].sort_values(ascending=False).head(n_top)
    # Start above any possible label so the first gene takes its own height.
    text_y = float("inf")
    for gene, shap in top_shap[top_shap > shap_cutoff].items():
        wanted_y = shap + label_gap
        # Use the gene's own height if it clears the previous label by more
        # than label_gap; otherwise step down from the previous label.
        text_y = wanted_y if text_y - wanted_y > label_gap else text_y - label_gap
        ax.annotate(
            gene,
            xy=(x_pos, shap),
            xytext=(x_pos + 0.2, text_y),
            arrowprops=dict(facecolor="black", arrowstyle="->"),
            fontsize=10,
        )

# Annotate module hub genes. The x position is looked up from the module the
# gene is plotted in, not from the enumeration order of `mod_hub_genes`, so
# the arrow lands on the gene's point even if the two orderings differ.
for hub_gene in mod_hub_genes.values():
    x_pos = module_pos[graph_nodes_df.loc[hub_gene, "module"]]
    ax.annotate(
        hub_gene,
        xy=(x_pos, graph_nodes_df.loc[hub_gene, "shap_value"]),
        xytext=(x_pos - 0.1, -0.05),
        ha="left",
        arrowprops=dict(facecolor="black", arrowstyle="->", connectionstyle="Angle3"),
        fontweight="bold",
        fontsize=12,
    )

ax.set_ylabel("SHAP value", fontsize=14)
ax.set_xlabel("WGCNA Module", fontsize=14)
ax.set_xticks(np.arange(len(module_sizes)))
ax.set_xticklabels(
    [f"{module}\n(n={size})" for module, size in module_sizes.items()],
    fontsize=12,
)
ax.tick_params(axis="y", labelsize=12)
# "Shapley", not "shapely": SHAP values are named after Shapley values.
ax.set_title(
    "Distribution of Shapley values\nWGCNA met/prim (up-regulated DEGs)", fontsize=16
)
ax.set_ylim(-0.1, 1.1)
ax.set_xlim(-1, len(module_sizes))
ax.set_axisbelow(True)
ax.grid(axis="y", zorder=-10)
# Replace the hue legend with an empty, frameless one.
ax.legend([], [], frameon=False)
fig.tight_layout()
for ext in ("pdf", "svg"):
    fig.savefig(figures_path.joinpath(f"6_met_prim_wgcna.{ext}"))