In [None]:
import os

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pathlib as pl
import scanpy as sc
import seaborn as sns

from matplotlib.colors import ListedColormap

# Preliminary: save only the malignant cells of each dataset

In [None]:
resdir = pl.Path("/path/to/datasets/")
savedir = pl.Path("/path/to/where/to/save")

# Only ``*.h5ad`` files are datasets. Anything else that lives in the folder
# (hidden files, sub-directories, ...) is skipped instead of being turned into
# a non-existent "<stem>.h5ad" path that makes the read fail.
for f in sorted(resdir.glob("*.h5ad")):
    ds_name = f.stem
    adata = sc.read_h5ad(f)

    # One output folder per dataset: <savedir>/<dataset name>/malignant.h5ad
    outdir = savedir / ds_name
    outdir.mkdir(parents=True, exist_ok=True)

    # Column indexing rather than attribute access, so the lookup cannot be
    # shadowed by a DataFrame attribute.
    is_malignant = adata.obs["malignant_key"] == "malignant"
    adata[is_malignant].copy().write(outdir / "malignant.h5ad")

# Benchmark

### 1. Get the annotations for the true programs

In [None]:
# Folder with one sub-directory per dataset; each sub-directory is expected to
# contain a "malignant.h5ad" file (that is what the loop below reads).
datasetdir = pl.Path("/path/to/malignant/datasets")

# Folder that receives the per-program differential-expression tables.
# NOTE: this placeholder is a relative path (no leading slash), so it resolves
# against the current working directory.
annotdir = pl.Path("path/to/where/to/save/annotations")

def get_prog_sig(adata: ad.AnnData, program: str = "program"):
    """Rank genes for every group of ``adata.obs[program]``.

    ``adata.X`` is temporarily normalised to 10,000 counts per cell and
    log1p-transformed, ``sc.tl.rank_genes_groups`` is run with ``program`` as
    the grouping column, and the original matrix is put back afterwards.

    Parameters
    ----------
    adata
        Annotated data matrix; ``adata.X`` is treated as the raw counts.
    program
        Name of the ``adata.obs`` column holding the group labels.

    Returns
    -------
    dict
        Maps each (non-missing) label found in ``adata.obs[program]`` to the
        data frame returned by ``sc.get.rank_genes_groups_df`` for that group.

    Notes
    -----
    ``adata`` is modified in place: ``adata.layers["counts"]`` is (over)written
    with a copy of the incoming ``adata.X``, and whatever
    ``sc.tl.rank_genes_groups`` stores on ``adata`` is left there.
    """
    # Back up the matrix before the in-place preprocessing below.
    adata.layers["counts"] = adata.X.copy()
    try:
        sc.pp.normalize_total(adata, target_sum=10000)
        sc.pp.log1p(adata)
        sc.tl.rank_genes_groups(adata, groupby=program)
        diff_gex = {
            prog: sc.get.rank_genes_groups_df(adata, group=prog)
            # Cells without a label do not form a group, so drop missing values.
            for prog in adata.obs[program].dropna().unique()
        }
    finally:
        # Restore the backed-up matrix even when one of the steps above raised,
        # so the caller never ends up with a half-transformed ``adata.X``.
        adata.X = adata.layers["counts"].copy()
        # X is no longer log-transformed, so drop the log1p bookkeeping entry.
        # ``pop`` tolerates it being absent.
        adata.uns.pop("log1p", None)
    return diff_gex

# Compute and save the differential-expression table of every program, one
# folder per dataset: <annotdir>/<dataset name>/<program>.csv
for d in sorted(datasetdir.iterdir()):
    # Datasets are folders; ignore stray files such as ".DS_Store".
    if not d.is_dir():
        continue
    # ``name`` (not ``stem``): a dot in a folder name is not a file extension,
    # and ``stem`` would silently cut the name at the last dot.
    dataset_name = d.name
    print(dataset_name)

    outdir = annotdir / dataset_name
    outdir.mkdir(parents=True, exist_ok=True)

    adata = sc.read_h5ad(d / "malignant.h5ad")
    diff_gex = get_prog_sig(adata=adata, program="program")
    for prog, table in diff_gex.items():
        table.to_csv(outdir / f"{prog}.csv")

### 2. Get the correlation between the scores of the found signatures and those of the true signatures

In [None]:
def score_sig(adata, signature, score_name):
    """Score ``signature`` on ``adata`` and store the result in ``adata.obs``.

    ``adata.X`` is temporarily normalised to 10,000 counts per cell and
    log1p-transformed, ``sc.tl.score_genes`` is run on it, and the original
    matrix is put back afterwards.

    Parameters
    ----------
    adata
        Annotated data matrix; ``adata.X`` is treated as the raw counts.
    signature
        Genes forwarded to ``sc.tl.score_genes`` as ``gene_list``.
    score_name
        Name under which ``sc.tl.score_genes`` stores the score.

    Returns
    -------
    The same ``adata`` object that was passed in (modified in place).

    Notes
    -----
    ``adata.layers["counts"]`` is (over)written with a copy of the incoming
    ``adata.X``. Every call copies and re-normalises the whole matrix, so
    scoring many signatures this way is linear in the number of signatures.
    """
    # Back up the matrix before the in-place preprocessing below.
    adata.layers["counts"] = adata.X.copy()
    try:
        sc.pp.normalize_total(adata, target_sum=10000)
        sc.pp.log1p(adata)
        sc.tl.score_genes(adata, gene_list=signature, score_name=score_name)
    finally:
        # Restore a *copy*, so that ``X`` and the "counts" layer never share
        # memory: a later in-place transform of ``X`` must not corrupt the
        # layer. The ``finally`` also covers the case where a step raised.
        adata.X = adata.layers["counts"].copy()
        # X is no longer log-transformed, so drop the log1p bookkeeping entry.
        # ``pop`` tolerates it being absent.
        adata.uns.pop("log1p", None)

    return adata

In [None]:
resdir = pl.Path("path/to/cansig/results")
datadir = pl.Path("path/to/malignant/datasets")
annotdir = pl.Path("/path/to/annotations")

# Settings of the benchmark, named once instead of being repeated inline.
N_TOP_GENES = 50       # number of leading entries of every signature that get scored
CORR_THRESHOLD = 0.65  # a metasignature "matches" a program above this correlation

all_res = []
for f in sorted(resdir.iterdir()):
    # Each dataset has its own results folder; ignore stray files.
    if not f.is_dir():
        continue
    # ``name`` (not ``stem``): a dot in a folder name is not a file extension.
    dataset_name = f.name
    print(dataset_name)
    msdir = f / "metasignatures" / "signatures"
    dsannot = annotdir / dataset_name
    adata = sc.read_h5ad(datadir / dataset_name / "malignant.h5ad")

    # Found signatures: one file per metasignature, "outlier" is not evaluated.
    metasignatures = {ms.stem: pd.read_csv(ms, index_col=0) for ms in msdir.iterdir()}
    meta_list = sorted(name for name in metasignatures if name != "outlier")

    # True signatures: the "names" column of each saved annotation table.
    # Sorted so the row order of the result does not depend on the filesystem.
    knownsigs = {ann.stem: pd.read_csv(ann)["names"] for ann in sorted(dsannot.iterdir())}
    known_list = list(knownsigs)

    # Score every found and every true signature on the same cells.
    for sig in meta_list:
        adata = score_sig(adata, metasignatures[sig].values.ravel()[:N_TOP_GENES], score_name=sig)
    for sig in known_list:
        adata = score_sig(adata, knownsigs[sig].values.ravel()[:N_TOP_GENES], score_name=sig)

    # Correlation of the scores: rows = metasignatures, columns = true programs.
    df_corr = adata.obs[meta_list + known_list].corr()
    df_corr = df_corr.loc[meta_list, known_list]
    is_corr = df_corr > CORR_THRESHOLD

    n_metasigs = pd.Series([len(meta_list)])
    # Per metasignature: number of true programs it correlates with.
    corr_sigs = is_corr.sum(axis=1).loc[meta_list]
    n_uncorr_sigs = pd.Series([(corr_sigs == 0).sum()])

    # A true program counts as found (1.0) when it is the best match of at
    # least one metasignature that passes the threshold.
    found_sigs = pd.Series(np.zeros(len(known_list)), index=known_list)
    matched = df_corr.loc[is_corr.any(axis=1)]
    if not matched.empty:  # do not call idxmax on an empty frame
        found_sigs.loc[matched.idxmax(axis=1).unique()] = 1

    # Per true program: best correlation reached by any metasignature.
    max_corr = df_corr.max().loc[known_list]

    res = pd.concat([n_metasigs, n_uncorr_sigs, corr_sigs, found_sigs, max_corr])
    res.index = (
        ["n_metasigs", "n_uncorr_sigs"]
        + meta_list
        + known_list
        + [f"{name}_corr" for name in known_list]
    )
    res.name = dataset_name
    all_res.append(res)

In [None]:
# Names of the datasets to report, in display order: every scenario prefix
# combined with the suffixes "_1" and "_2".
datasets = [
    f"{scenario}_{idx}"
    for scenario in ("be", "highdiff", "lowdiff", "morecells", "highcnv", "smalldataset")
    for idx in (1, 2)
]

In [None]:
# Place the per-dataset result Series side by side (one column per dataset)
# and keep the columns listed in `datasets`, in that order.
full_res_df = pd.concat(all_res, axis="columns")[datasets]

In [None]:
# Indicator heatmap: one row per true program, one column per dataset.
fig, ax = plt.subplots(1, 1, figsize=(3, 1))
sns.heatmap(
    full_res_df.loc[["program1", "program2", "program3"]],
    cmap=ListedColormap(["blue", "red"]),
    # Pin the colour scale so 0 is always blue and 1 always red; without it the
    # mapping follows the data range and an all-0 or all-1 table is drawn in a
    # single, possibly misleading colour.
    vmin=0,
    vmax=1,
    linewidths=2,
    ax=ax,
    cbar=False,
)
# Rotation must be a number (or "horizontal"/"vertical"); numeric strings such
# as "45" are rejected by current matplotlib releases.
ax.set_yticklabels(["State 1", "State 2", "State 3"], rotation=0, verticalalignment="center")
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment="right");

In [None]:
# savefig does not create missing folders, so make sure the target exists.
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/simulated_found_program_indicator.svg", bbox_inches="tight")

In [None]:
# One-row heatmap with the number of metasignatures per dataset that do not
# correlate with any true program; the count itself is printed in each cell.
fig, ax = plt.subplots(1, 1, figsize=(3, 0.3))
sns.heatmap(
    full_res_df.loc[["n_uncorr_sigs"]],
    cmap=ListedColormap(["red", "blue"]),
    # Pin the two-colour scale: 0 -> red, 1 or more -> blue. Without it the
    # split point follows the data range (e.g. an all-equal row is drawn in
    # one colour whatever its value).
    # NOTE(review): assumes the colours are meant as a zero / non-zero indicator.
    vmin=0,
    vmax=1,
    linewidths=2,
    ax=ax,
    cbar=False,
    annot=True,
)
# Rotation must be a number (or "horizontal"/"vertical"); numeric strings such
# as "45" are rejected by current matplotlib releases.
ax.set_yticklabels(["N uncorr sig."], rotation=0, verticalalignment="center")
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment="right");

In [None]:
# savefig does not create missing folders, so make sure the target exists.
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/simulated_uncorr_program_indicator.svg", bbox_inches="tight")