# Metadata exploration

This notebook explores the metadata of the TCGA-PRAD and SU2C datasets, which includes clinical information.

---
## Setup and imports

In [None]:
# Reload imported modules automatically before running each cell, so edits to
# project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
from copy import deepcopy
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.pyplot import figure
from sklearn import preprocessing

In [None]:
# Make the project's source tree (relative to this notebook) importable.
src_path: str = "../../src"
sys.path.insert(len(sys.path), src_path)

### Global variables

In [None]:
# Storage root and the layout of the combined TCGA-PRAD + SU2C RNA-Seq project.
root: Path = Path("/media/ssd/Perez/storage")
data_root: Path = root / "TCGA_PRAD_SU2C_RNASeq"
data_path: Path = data_root / "data"
counts_path: Path = data_path / "star_counts"
deseq2_path: Path = data_root / "deseq2"
plots_path: Path = data_root / "plots"

In [None]:
# SU2C PCF 2019 RNA-Seq dataset root and its sample-annotation folder.
data_su2c_root: Path = root / "SU2C_PCF_2019_RNASeq"
data_su2c_annotations_path: Path = data_su2c_root / "samples_annotations"

In [None]:
# Sample annotation table with the `sample_cluster` assignments (TCGA-PRAD + SU2C,
# per its file name); the first CSV column is used as the sample index.
anno_file: Path = data_path / "samples_annotation_tcga_prad_su2c_clusters.csv"
annot_df = pd.read_csv(filepath_or_buffer=anno_file, index_col=0)

In [None]:
from components.functional_analysis.orgdb import OrgDB
from data.utils import filter_genes_wrt_annotation
from r_wrappers.utils import map_gene_id

In [None]:
# Organism annotation database (project wrapper), used below to filter genes and to
# map ENSEMBL IDs to gene symbols.
org_db = OrgDB(
    "Homo sapiens",
)

---
## 1. TCGA-PRAD

...


---
## 2. SU2C

In [None]:
# Extended per-sample SU2C annotation, indexed by the first CSV column.
su2c_annotation = pd.read_csv(
    filepath_or_buffer=data_su2c_annotations_path / "samples_annotation_rna_downloaded.csv",
    index_col=0,
)

In [None]:
# DESeq2 outputs for the metastatic samples (presumably the dds counts and the
# VST-transformed expression, per the file names), both indexed by the first column.
su2c_dds, su2c_vst = (
    pd.read_csv(deseq2_path.joinpath(file_name), index_col=0)
    for file_name in ("Metastatic_dds.csv", "Metastatic_vst.csv")
)

### 2.1. Select SU2C samples and add extended annotation

In [None]:
# Keep only metastatic samples and append the extended SU2C annotation columns
# (joined on the sample index).
is_metastatic = annot_df["sample_type"] == "Metastatic"
annot_df_su2c = annot_df.loc[is_metastatic].copy().join(su2c_annotation)

### 2.2. Number of patients in common among clusters

In [None]:
# Patient IDs of the samples in each sample cluster (one list per cluster).
grouped_patient_ids = annot_df_su2c.groupby("sample_cluster")["patient_id"].agg(list)
grouped_patient_ids

In [None]:
# Patients contributing at least one sample to *every* sample cluster.
patient_sets_per_cluster = [set(ids) for ids in grouped_patient_ids]
common_patients = set.intersection(*patient_sets_per_cluster)
cluster_names = set(annot_df_su2c["sample_cluster"])
print(
    f"There are {len(common_patients)} common patients between sample clusters"
    f" {cluster_names}: \n{common_patients}"
)

### 2.3. Explore categorical fields (fields with between 2 and 15 distinct values)

In [None]:
# Number of distinct non-null values per annotation field; report the constant ones.
annot_df_su2c_nunique = annot_df_su2c.nunique()
single_value_fields = annot_df_su2c_nunique.index[
    (annot_df_su2c_nunique == 1).to_numpy()
].tolist()
print(
    "Fields with a single value: \n"
    f" {single_value_fields}"
)

In [None]:
# Candidate categorical fields: between 2 and 15 distinct values, as stated in the
# section header. `between` is inclusive on both ends and also excludes all-null
# columns (nunique == 0), which a `!= 1` test would let through.
annot_df_su2c_categories = annot_df_su2c.loc[
    :, annot_df_su2c_nunique.between(2, 15)
].copy()

#### 2.3.1. Plot field category counts between sample clusters

In [None]:
# Drop every cluster-assignment column from the feature list, then re-append
# `sample_cluster` as the grouping column.
categorical_fields = list(
    filter(lambda col: "cluster" not in col, annot_df_su2c_categories.columns)
)
print(f"Categorical fields: \n{categorical_fields}")
annot_df_su2c_categories = annot_df_su2c_categories[
    [*categorical_fields, "sample_cluster"]
]

In [None]:
# Long format: one row per (sample, categorical feature) with the category in `value`.
df_melted = annot_df_su2c_categories.melt(
    id_vars=["sample_cluster"], var_name="cat_feature"
)
# Count samples per (cluster, feature, value). Naming the count column through
# `reset_index(name=...)` works on every pandas version: pandas >= 2.0 calls it
# "count" (older versions 0), so a `rename(columns={0: "counts"})` would silently do
# nothing on 2.x and the plots below (y="counts") would fail.
df_melted_groupped = df_melted.value_counts().reset_index(name="counts")
df_melted_groupped.to_csv(
    data_path.joinpath("su2c_clusters_categorical_features.csv"), index=False
)

In [None]:
# One bar-plot panel per categorical feature on a 5-column grid; bars are sample
# counts per (cluster, category value). Sorted so the panel order is reproducible:
# set iteration order of strings changes between Python sessions (hash randomization).
features = sorted(set(df_melted_groupped["cat_feature"]), key=str)
n_cols = 5
# Ceiling division so the grid always has a cell for every panel; `len // 4` rows
# gave 0 rows for < 4 features and too few cells for e.g. 7 features.
n_rows = -(-len(features) // n_cols)

sns.set(style="whitegrid")
fig_size = int(len(features) * 1.5)
fig = plt.figure(figsize=(fig_size, fig_size), dpi=300)
for ax_num, feature in enumerate(features):
    ax = fig.add_subplot(n_rows, n_cols, ax_num + 1)
    sns.barplot(
        x="sample_cluster",
        y="counts",
        hue="value",
        data=df_melted_groupped[df_melted_groupped["cat_feature"] == feature],
        ax=ax,
    )
    ax.set_title(feature)

fig.tight_layout()
fig.savefig(plots_path.joinpath("su2c_clusters_categorical_features_comparison.pdf"))
# Close (not just clear) the figure to release the large 300-dpi canvas.
plt.close(fig)

#### 2.3.2. Plot field category counts between sample clusters (only shared patients)

In [None]:
# Restrict to samples from patients present in every cluster. Note: this rebinds
# `annot_df_su2c_categories`, so re-running the 2.3.1 cells afterwards would use the
# filtered subset.
in_common_patients = annot_df_su2c["patient_id"].isin(common_patients)
annot_df_su2c_categories = annot_df_su2c_categories.loc[in_common_patients]

In [None]:
# Long format: one row per (sample, categorical feature) with the category in `value`.
df_melted = annot_df_su2c_categories.melt(
    id_vars=["sample_cluster"], var_name="cat_feature"
)
# Count samples per (cluster, feature, value). `reset_index(name=...)` names the count
# column on every pandas version (pandas >= 2.0 would otherwise call it "count", so a
# `rename(columns={0: ...})` is a silent no-op there).
df_melted_groupped = df_melted.value_counts().reset_index(name="counts")
df_melted_groupped.to_csv(
    data_path.joinpath("su2c_clusters_common_patients_categorical_features.csv"),
    index=False,
)

In [None]:
# Same bar-plot grid as in 2.3.1, restricted to patients shared by all clusters.
# Sorted for a reproducible panel order across sessions.
features = sorted(set(df_melted_groupped["cat_feature"]), key=str)
n_cols = 5
# Ceiling division: enough grid cells for every panel (`len // 4` rows could be 0 or
# too few and make `add_subplot` raise).
n_rows = -(-len(features) // n_cols)

sns.set(style="whitegrid")
fig_size = int(len(features) * 1.5)
fig = plt.figure(figsize=(fig_size, fig_size), dpi=300)
for ax_num, feature in enumerate(features):
    ax = fig.add_subplot(n_rows, n_cols, ax_num + 1)
    sns.barplot(
        x="sample_cluster",
        y="counts",
        hue="value",
        data=df_melted_groupped[df_melted_groupped["cat_feature"] == feature],
        ax=ax,
    )
    ax.set_title(feature)

fig.tight_layout()
fig.savefig(
    plots_path.joinpath(
        "su2c_clusters_common_patients_categorical_features_comparison.pdf"
    )
)
# Close (not just clear) the figure to release the large 300-dpi canvas.
plt.close(fig)

### 2.4. Explore continuous fields (e.g., age)

In [None]:
# Non-constant numeric annotation fields, plus the cluster label used for grouping.
# Note: a test like `annot_df_su2c.dtypes is not object` is an identity check on the
# dtypes Series itself and is always True, which would let string columns in and
# break the violin plots; select numeric dtypes explicitly instead.
numeric_columns = annot_df_su2c.select_dtypes(include="number").columns
annot_df_su2c_continous = annot_df_su2c.loc[
    :,
    (annot_df_su2c_nunique != 1) & annot_df_su2c_nunique.index.isin(numeric_columns),
].copy()
annot_df_su2c_continous["sample_cluster"] = annot_df_su2c["sample_cluster"]

#### 2.4.1. Plot continuous field distributions between sample clusters

In [None]:
# Drop identifier-like columns (any name containing "id", case-insensitive).
continuous_fields = list(
    filter(lambda col: "id" not in col.lower(), annot_df_su2c_continous.columns)
)
print(f"Continuous fields: \n{continuous_fields}")
annot_df_su2c_continous = annot_df_su2c_continous[continuous_fields]

In [None]:
# Long format: one row per (sample, continuous feature) with the number in `value`.
df_melted = annot_df_su2c_continous.melt(
    id_vars=["sample_cluster"], var_name="cont_feature"
)
# Occurrence counts per (cluster, feature, value); `reset_index(name=...)` names the
# count column consistently across pandas versions ("count" on >= 2.0, 0 before).
df_melted_groupped = df_melted.value_counts().reset_index(name="counts")
df_melted_groupped.to_csv(
    data_path.joinpath("su2c_clusters_continous_features.csv"), index=False
)

In [None]:
sns.set(style="whitegrid")

# One violin-plot panel per continuous feature on a 5-column grid, sorted for a
# reproducible panel order across sessions.
features = sorted(set(df_melted["cont_feature"]), key=str)
n_cols = 5
# Ceiling division: enough grid cells for every panel (`len // 4` rows could be 0 or
# too few and make `add_subplot` raise).
n_rows = -(-len(features) // n_cols)

fig_size = int(len(features) * 1.5)
fig = figure(figsize=(fig_size, fig_size), dpi=300)
for ax_num, feature in enumerate(features):
    ax = fig.add_subplot(n_rows, n_cols, ax_num + 1)
    sns.violinplot(
        x="sample_cluster",
        y="value",
        data=df_melted[df_melted["cont_feature"] == feature],
        ax=ax,
    )
    ax.set_title(feature)

fig.tight_layout()
fig.savefig(plots_path.joinpath("su2c_clusters_continous_features_comparison.pdf"))
# Close (not just clear) the figure to release the large 300-dpi canvas.
plt.close(fig)

#### 2.4.2. Plot continuous field distributions between sample clusters (only shared patients)

In [None]:
# Restrict to samples from patients present in every cluster. Note: this rebinds
# `annot_df_su2c_continous`, so re-running the 2.4.1 cells afterwards would use the
# filtered subset.
in_common_patients = annot_df_su2c["patient_id"].isin(common_patients)
annot_df_su2c_continous = annot_df_su2c_continous.loc[in_common_patients]

In [None]:
# Long format: one row per (sample, continuous feature) with the number in `value`.
df_melted = annot_df_su2c_continous.melt(
    id_vars=["sample_cluster"], var_name="cont_feature"
)
# Occurrence counts per (cluster, feature, value); `reset_index(name=...)` names the
# count column consistently across pandas versions ("count" on >= 2.0, 0 before).
df_melted_groupped = df_melted.value_counts().reset_index(name="counts")
df_melted_groupped.to_csv(
    data_path.joinpath("su2c_clusters_common_patients_continous_features.csv"),
    index=False,
)

In [None]:
sns.set(style="whitegrid")

# Same violin-plot grid as in 2.4.1, restricted to patients shared by all clusters.
# Sorted for a reproducible panel order across sessions.
features = sorted(set(df_melted["cont_feature"]), key=str)
n_cols = 5
# Ceiling division: enough grid cells for every panel (`len // 4` rows could be 0 or
# too few and make `add_subplot` raise).
n_rows = -(-len(features) // n_cols)

fig_size = int(len(features) * 1.5)
fig = figure(figsize=(fig_size, fig_size), dpi=300)
for ax_num, feature in enumerate(features):
    ax = fig.add_subplot(n_rows, n_cols, ax_num + 1)
    sns.violinplot(
        x="sample_cluster",
        y="value",
        data=df_melted[df_melted["cont_feature"] == feature],
        ax=ax,
    )
    ax.set_title(feature)

fig.tight_layout()
fig.savefig(
    plots_path.joinpath(
        "su2c_clusters_common_patients_continous_features_comparison.pdf"
    )
)
# Close (not just clear) the figure to release the large 300-dpi canvas.
plt.close(fig)

### 2.5. Gene expression distribution distance between clusters

#### 2.5.1. Compute gene expression median difference between clusters

In [None]:
# Keep genes passing `filter_genes_wrt_annotation`, transpose so genes become columns,
# and min-max scale each gene column to [0, 1] across samples.
su2c_vst_scaled = su2c_vst.loc[
    filter_genes_wrt_annotation(su2c_vst.index, org_db)
].T.copy()
su2c_vst_scaled.iloc[:, :] = preprocessing.MinMaxScaler().fit_transform(su2c_vst_scaled)
su2c_vst_scaled

In [None]:
# Attach each sample's cluster label, looked up by sample ID in `annot_df`.
cluster_labels = annot_df["sample_cluster"].loc[su2c_vst_scaled.index]
su2c_vst_scaled["sample_cluster"] = cluster_labels

In [None]:
# Per-cluster median of every scaled gene, transposed to genes x clusters.
su2c_vst_clustered_median = su2c_vst_scaled.groupby("sample_cluster").median().T
su2c_vst_clustered_median

In [None]:
# Absolute difference of per-gene medians between the two metastatic clusters,
# largest first.
su2c_vst_clustered_median_diff = (
    su2c_vst_clustered_median["Metastatic_A"]
    .sub(su2c_vst_clustered_median["Metastatic_B"])
    .abs()
    .sort_values(ascending=False)
)
su2c_vst_clustered_median_diff

In [None]:
# Number of genes whose cluster medians differ by more than 0.5 on the [0, 1] scale.
(su2c_vst_clustered_median_diff > 0.5).sum()

#### 2.5.2. Plot genes with highest median difference

In [None]:
top_n = 50
# IDs of the `top_n` genes with the largest median difference (the series is sorted
# in descending order).
top_genes = su2c_vst_clustered_median_diff.head(top_n).index
top_genes

In [None]:
# Scaled expression of the top genes plus the cluster label, with ENSEMBL column names
# replaced by gene symbols.
ensembl_to_symbol = dict(
    zip(top_genes, map_gene_id(top_genes, org_db, "ENSEMBL", "SYMBOL"))
)
su2c_vst_clustered_plots = su2c_vst_scaled.loc[
    :, [*top_genes, "sample_cluster"]
].rename(columns=ensembl_to_symbol)

In [None]:
# Long format: one row per (sample, gene) with the scaled expression in `value`.
df_melted = pd.melt(
    su2c_vst_clustered_plots, id_vars=["sample_cluster"], var_name="gene"
)
df_melted

In [None]:
sns.set(style="whitegrid")

# One horizontal violin panel per top gene on a 5-column grid. Sorted for a
# reproducible panel order; `key=str` tolerates non-string labels (e.g. NaN from an
# unmapped gene ID).
genes_symbol = sorted(set(df_melted["gene"]), key=str)
n_cols = 5
# Ceiling division: enough grid cells even when the gene count is not a multiple of 5.
n_rows = -(-len(genes_symbol) // n_cols)
# Size the figure from this plot's own grid, not from `features` (a variable left
# over from the continuous-feature cells, unrelated to the number of genes).
fig = figure(figsize=(n_cols * 3, max(n_rows, 1) * 2.5), dpi=300)
for ax_num, gene_symbol in enumerate(genes_symbol):
    ax = fig.add_subplot(n_rows, n_cols, ax_num + 1)
    sns.violinplot(
        y="sample_cluster",
        x="value",
        data=df_melted[df_melted["gene"] == gene_symbol],
        ax=ax,
    )
    ax.set_title(gene_symbol)

fig.tight_layout()
fig.savefig(
    plots_path.joinpath(f"su2c_clusters_top_{top_n}_distributed_genes_comparison.pdf")
)
# Close (not just clear) the figure to release the large 300-dpi canvas.
plt.close(fig)