In [None]:
# Imports and global plotting configuration
import numpy as np
import pandas as pd
from scipy.stats import pearsonr, spearmanr
from tqdm.auto import tqdm
import torch
from scipy.stats import ks_2samp
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib

# Apply the project-wide matplotlib style sheet provided by the workflow
plt.style.use(snakemake.input.mpl_style)

In [None]:
# Read the CellWhisperer term-score table and preview its first rows
cw_scores_path = snakemake.input.cw_transcriptome_term_scores
scores_df = pd.read_parquet(cw_scores_path)
scores_df.head()

In [None]:
# Read the GSVA table, index it by its "Unnamed: 0" column and split the
# "library" column off into its own Series (pop removes it from the frame).
gsva_table = pd.read_parquet(snakemake.input.gsva_results)
gsva_results = gsva_table.set_index("Unnamed: 0")
gsva_library = gsva_results.pop("library")
gsva_results.head()

In [None]:
# The correlation and KS cells below pair the rows and columns of both tables
# by position (iloc), so fail early with a readable message if the tables do
# not have the same size instead of hitting an opaque ValueError later on.
assert gsva_results.shape == scores_df.shape, (
    f"shape mismatch: GSVA {gsva_results.shape} vs. CW scores {scores_df.shape}"
)
assert (gsva_results.columns == scores_df.columns).all()

# index does not match everywhere because we eliminated the GO terms (GO:123456) in the previous step
assert gsva_results.index[0] == scores_df.index[0]

In [None]:
n_rows, n_cols = scores_df.shape

# Pearson correlation between the matching rows of both tables (paired by position)
row_results = []
for row_idx in tqdm(range(n_rows), total=n_rows):
    row_results.append(pearsonr(scores_df.iloc[row_idx, :], gsva_results.iloc[row_idx, :]))
row_correlations = np.array(row_results)

# Pearson correlation between the matching columns of both tables (paired by position)
column_results = []
for col_idx in tqdm(range(n_cols), total=n_cols):
    column_results.append(pearsonr(scores_df.iloc[:, col_idx], gsva_results.iloc[:, col_idx]))
column_correlations = np.array(column_results)

In [None]:
n_rows, n_cols = scores_df.shape

# Spearman correlation between the matching rows of both tables (paired by position)
row_results_spearman = []
for row_idx in tqdm(range(n_rows), total=n_rows):
    row_results_spearman.append(spearmanr(scores_df.iloc[row_idx, :], gsva_results.iloc[row_idx, :]))
row_correlations_spearman = np.array(row_results_spearman)

# Spearman correlation between the matching columns of both tables (paired by position)
column_results_spearman = []
for col_idx in tqdm(range(n_cols), total=n_cols):
    column_results_spearman.append(spearmanr(scores_df.iloc[:, col_idx], gsva_results.iloc[:, col_idx]))
column_correlations_spearman = np.array(column_results_spearman)

In [None]:
# Both frames share the same layout: (statistic, p-value) for Pearson, then for Spearman
corr_columns = ["pearson_rho", "pearson_p", "spearman_rho", "spearman_p"]

# Column-wise correlations, labelled with the column names of the GSVA table
gene_corr_df = pd.DataFrame(
    np.hstack([column_correlations, column_correlations_spearman]),
    index=gsva_results.columns,
    columns=corr_columns,
)
gene_corr_df["type"] = "gene"

# Row-wise correlations, labelled with the row index (terms) of the GSVA table
term_corr_df = pd.DataFrame(
    np.hstack([row_correlations, row_correlations_spearman]),
    index=gsva_results.index,
    columns=corr_columns,
)
term_corr_df["type"] = "term"

# Shorthands for the Pearson coefficients used by the plotting cells
gene_corrs = gene_corr_df["pearson_rho"]
term_corrs = term_corr_df["pearson_rho"]

In [None]:
# Empirical CDF of the per-term Pearson coefficients, with guide lines at
# the median (y=0.5) and at zero correlation (x=0)
fig, ax = plt.subplots(figsize=(3, 3))
sns.ecdfplot(data=term_corrs, ax=ax)
ax.axhline(0.5)
ax.axvline(0.0)
ax.set_title("Distribution of term-level correlations")

fig.tight_layout()
fig.savefig(snakemake.output.term_level_correlations)

In [None]:
# Show the 50 terms with the highest Pearson correlation
term_corrs.sort_values(ascending=False).head(50)

In [None]:
# Preview the first five per-term Pearson coefficients
term_corrs.head(5)

In [None]:
# Preview the first five rows of the per-term correlation table
term_corr_df.head(5)


In [None]:
fig, ax = plt.subplots(figsize=(5, 3))
ax.axvline(0.0)

# Attach the library label to every term and order libraries by mean correlation
term_corr_df["library"] = gsva_library
mean_rho_by_library = term_corr_df.groupby("library")["pearson_rho"].mean()
order = mean_rho_by_library.sort_values(ascending=False).index
sns.violinplot(data=term_corr_df, x="pearson_rho", y="library", ax=ax, order=order)
ax.set_xlabel('correlation for 14k samples', ha="right")

# A term counts as significant when its Pearson p-value is below the
# threshold AND its correlation is positive
significance_level = 0.05
has_low_pvalue = term_corr_df["pearson_p"] < significance_level
has_positive_rho = term_corr_df["pearson_rho"] > 0
term_corr_df["significant"] = has_low_pvalue & has_positive_rho

significant_by_library = term_corr_df.groupby('library')['significant']
significance_counts = significant_by_library.sum()
total_counts = significant_by_library.count()

# Annotate every violin with "<significant> of <total>"; x=0.9 is a
# hand-picked position in data coordinates
for y_pos, library in enumerate(order):
    text_label = f"* {significance_counts[library]} of {total_counts[library]}"
    ax.text(0.9, y_pos, text_label, color='black', ha="left", va="center")

ax.axvline(0.0, color='black')
sns.despine()
fig.tight_layout()
fig.savefig(snakemake.output.library_correlations)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(3, 3))

# Hand-picked OMIM terms that get a labelled marker on the curve
selected_diseases = ["ovarian cancer", "asthma", "thyroid carcinoma", "aids", "dyslexia", "brachydactyly", "nephropathy", "infections", "colorectal cancer"]

# Pearson coefficients of the 'OMIM_Expanded' terms, sorted ascending so that
# a term's position in the index corresponds to its rank
is_omim = gsva_library == "OMIM_Expanded"
omim_df = term_corrs[is_omim].sort_values()
n_omim = len(omim_df)

sns.ecdfplot(omim_df, ax=ax)
ax.axhline(0.5)
ax.axvline(0.0)
ax.set_title("Correlation coefficients for 'OMIM_Expanded'")

for disease in selected_diseases:
    rho = omim_df.loc[disease]
    # rank / n places the marker on the ECDF step belonging to this term
    ecdf_height = omim_df.index.get_loc(disease) / n_omim
    ax.text(rho, ecdf_height, disease)
    ax.scatter([rho], [ecdf_height], marker="x", color="black")

fig.tight_layout()
fig.savefig(snakemake.output.omim_correlations)

In [None]:
# For every term, split the GSVA scores by the sign of the CellWhisperer score
# (rows of both tables are paired by position) and compare the two groups with
# a two-sample Kolmogorov-Smirnov test.
stats_results = {}
binarized_frames = []

# TODO also try threshold=1 later

for row_pos, term in enumerate(tqdm(gsva_results.index)):
    gsva_row = gsva_results.iloc[row_pos]
    cw_row = scores_df.iloc[row_pos]
    pos = gsva_row[cw_row > 0]
    neg = gsva_row[cw_row <= 0]
    library = gsva_library.get(term)

    # One long-format frame per group; scalar entries are broadcast over the group
    for group_label, group in (("positive CW score", pos), ("negative CW score", neg)):
        binarized_frames.append(
            pd.DataFrame({
                "term": term,
                "type": group_label,
                "ids": group.index,
                "gsva_score": group.values,
                "library": library,
            })
        )

    stats_results[term] = ks_2samp(pos.values, neg.values)

plot_df = pd.concat(binarized_frames)
plot_df["term"] = pd.Categorical(plot_df.term)

In [None]:
# TODO select some interesting terms
terms_of_interest = [
    "Pluripotent Stem Cells",
    "Lung V2 (HLCA)-ann Level 2-Smooth Muscle",
    "Response To Endoplasmic Reticulum Stress (GO:0034976)",
]

subdf = plot_df[plot_df.term.isin(terms_of_interest)].copy()
# Convert the categorical column to plain strings: per the original author,
# plotting the categorical version is very slow
subdf["term"] = subdf["term"].astype(str)

In [None]:
# GSVA score distributions of the selected terms, split by CW-score group
fig, ax = plt.subplots(figsize=(5, 5))

reversed_subdf = subdf.iloc[::-1]  # rows are plotted in reverse order
sns.violinplot(data=reversed_subdf, x="term", y="gsva_score", hue="type", ax=ax)

# Slightly rotate the long term names so they stay readable
tick_labels = ax.get_xticklabels()
ax.set_xticklabels(tick_labels, ha="right", rotation=10)
fig.tight_layout()

fig.savefig(snakemake.output.cw_binarized_gsva_scores)

In [None]:
# Turn the {term: KS result} dict into a table with one row per term.
# NOTE: this rebinds `stats_results` from a dict to a DataFrame, so the cell
# only works once per run of the KS cell that fills the dict.
ks_records = [
    {
        "term": term_name,
        "statistic": ks_result.statistic,
        "pvalue": ks_result.pvalue,
        "sign": ks_result.statistic_sign,
        "statistic_location": ks_result.statistic_location,
    }
    for term_name, ks_result in stats_results.items()
]
stats_results = pd.DataFrame(ks_records)

# Signed KS statistic with the sign flipped
stats_results["inv_signed_stat"] = -1 * stats_results["sign"] * stats_results["statistic"]
stats_results["library"] = stats_results["term"].apply(gsva_library.get)
stats_results = stats_results.set_index("term")


In [None]:
# Scatterplot of one (cherry-picked) term: GSVA score vs. CellWhisperer score
fig, ax = plt.subplots(figsize=(3, 3))

term = snakemake.params.selected_top_term
# The other cells of this notebook look up scores_df with the " (GO:...)"
# suffix stripped, because the score table's index does not carry it. Do the
# same here so a GO term does not raise a KeyError; terms without the suffix
# are left unchanged by the split.
cw_term = term.split(' (GO:')[0]
sns.scatterplot(x=gsva_results.loc[term], y=scores_df.loc[cw_term], s=1, ax=ax, rasterized=True)
ax.set_xlabel("GSVA")
ax.set_ylabel("CellWhisperer score")
ax.set_title(f"\"{term}\"")
# (0, 10) is a hand-picked position in data coordinates
ax.text(0, 10, f"KS stat={stats_results.loc[term, 'inv_signed_stat']:.2f}")

plt.tight_layout()
fig.savefig(snakemake.output.top_term_correlation)

In [None]:
fig, ax = plt.subplots(figsize=(5, 3))
ax.axvline(0.0)

# Order libraries by their mean sign-flipped KS statistic (largest first)
mean_stat_by_library = stats_results.groupby("library")["inv_signed_stat"].mean()
order = mean_stat_by_library.sort_values(ascending=False).index
sns.violinplot(data=stats_results, y="library", x="inv_signed_stat", order=order, ax=ax)

# A term counts as significant when its KS p-value is below the threshold
# AND the sign of its KS statistic is negative
significance_level = 0.05
has_low_pvalue = stats_results['pvalue'] < significance_level
has_negative_sign = stats_results['sign'] < 0
stats_results['significant'] = has_low_pvalue & has_negative_sign

significant_by_library = stats_results.groupby('library')['significant']
significance_counts = significant_by_library.sum()
total_counts = significant_by_library.count()

# Annotate every violin with "<significant> of <total>"; x=0.6 is a
# hand-picked position in data coordinates
for y_pos, library in enumerate(order):
    text_label = f"* {significance_counts[library]} of {total_counts[library]}"
    ax.text(0.6, y_pos, text_label, color='black', ha="left", va="center")

fig.tight_layout()
sns.despine()

fig.savefig(snakemake.output.library_ks_statistics)

In [None]:
# Horizontal bar chart of the mean per-term Pearson correlation of each library
fig, ax = plt.subplots(figsize=(3.5, 2.5))
mean_corr_by_library = term_corrs.groupby(gsva_library).mean().sort_values(ascending=True)
mean_corr_by_library.plot.barh(ax=ax)
ax.set_xlabel('correlation for 14k samples', ha="right")
fig.tight_layout()
# NOTE(review): the per-library violin plot earlier in this notebook is saved
# to this same output path, so this bar chart replaces it on disk - confirm
# which of the two figures is meant to be kept.
fig.savefig(snakemake.output.library_correlations)

In [None]:
# Preview the first five rows of the long-format binarised GSVA table
plot_df.head(5)

In [None]:
# The ten 'OMIM_Expanded' terms with the largest sign-flipped KS statistic
top_diseases = (
    stats_results
    .sort_values("inv_signed_stat", ascending=False)
    .query("library == 'OMIM_Expanded'")
    .head(10)
    .index
)

In [None]:
# One figure row per term: GSVA-vs-CellWhisperer scatter on the left, ECDFs of
# the GSVA scores of both CW-score groups (with the KS distance marked) on the right.
topn = 10
terms = stats_results.sort_values("inv_signed_stat", ascending=False).iloc[:topn].index  # TODO also take some of the top correlation
terms = terms.union(selected_diseases).union(top_diseases).unique()

# squeeze=False keeps `axes` two-dimensional even if only a single term is plotted
fig, axes = plt.subplots(len(terms), 2, figsize=(5, 2 * len(terms)), squeeze=False)
for i, term in enumerate(tqdm(terms)):
    scatter_ax, ecdf_ax = axes[i, 0], axes[i, 1]

    # scores_df is looked up with the " (GO:...)" suffix stripped
    sns.scatterplot(x=gsva_results.loc[term], y=scores_df.loc[term.split(' (GO:')[0]], s=1, ax=scatter_ax, rasterized=True)  # TODO probably need to get rid of the GO
    scatter_ax.set_xlabel("GSVA")
    scatter_ax.set_ylabel("CellWhisperer score")
    scatter_ax.set_title(f"\"{term}\"")
    scatter_ax.text(0, 7, f"ρ={term_corrs.loc[term]:.2f}")

    # Filter the (large) long-format table once per term instead of three times
    term_df = plot_df[plot_df.term == term]
    pos_scores = term_df.loc[term_df["type"] == "positive CW score", "gsva_score"]
    neg_scores = term_df.loc[term_df["type"] == "negative CW score", "gsva_score"]

    # Height of both ECDFs at the location of the KS statistic. An ECDF counts
    # observations <= x (scipy evaluates the KS distance the same way), so the
    # comparison must be inclusive for the marker to end exactly on both curves.
    statistic_loc = stats_results.loc[term, "statistic_location"]
    y1 = (pos_scores <= statistic_loc).sum() / len(pos_scores)
    y2 = (neg_scores <= statistic_loc).sum() / len(neg_scores)

    ecdf_ax.plot([statistic_loc, statistic_loc],
                 [y1, y2], color="black")
    sns.ecdfplot(data=term_df, hue="type", x="gsva_score", ax=ecdf_ax)
    ecdf_ax.text(0.1, 0.8, f"KS stat={stats_results.loc[term, 'inv_signed_stat']:.2f}")

    ecdf_ax.get_legend().remove()

plt.tight_layout()
fig.savefig(snakemake.output.cherry_picked_examples)

In [None]:
# Exploratory check for a single term: split its GSVA scores at a CW score
# of 2 (instead of 0) and compare the two groups.
term = terms[6]
gsva_for_term = gsva_results.loc[term]
cw_for_term = scores_df.loc[term.split(' (GO:')[0]]
pos = gsva_for_term[cw_for_term > 2]
neg = gsva_for_term[cw_for_term <= 2]

# Long-format table with one block of rows per group
single_term_df = pd.concat([
    pd.DataFrame({
        "term": term,
        "type": group_label,
        "ids": group.index,
        "gsva_score": group.values,
    })
    for group_label, group in (("positive CW score", pos), ("negative CW score", neg))
])
single_term_df["library"] = single_term_df["term"].apply(gsva_library.get)
single_term_df["term"] = pd.Categorical(single_term_df.term)

sns.ecdfplot(data=single_term_df[single_term_df.term == term], hue="type", x="gsva_score")
# Last expression: the KS result is shown as the cell output
ks_2samp(pos.values, neg.values)

In [None]:
# Copy the KS results into the per-term correlation table (aligned on the
# term index); keys are the new column names, values the source columns.
ks_column_sources = {
    "ks_statistic_signed_inv": "inv_signed_stat",
    "ks_statistic": "statistic",
    "ks_pvalue": "pvalue",
    "ks_sign": "sign",
    "ks_stat_location": "statistic_location",
    "ks_significant": "significant",
}
for target_column, source_column in ks_column_sources.items():
    term_corr_df[target_column] = stats_results[source_column]

# Remove the Pearson-based "significant" flag that was added for the violin plot
term_corr_df = term_corr_df.drop(columns=["significant"])


In [None]:
# Stack the per-term and per-gene tables (distinguished by their "type"
# column) and write them to the workflow's CSV output
results_df = pd.concat([term_corr_df, gene_corr_df], axis=0)
results_df.to_csv(snakemake.output.gsva_correlation_results)