In [None]:
import sys; sys.path.append("../resources/")
from dotplot_utils import *
import pandas as pd
import seaborn as sns
# make this notebook work better with Scanpy
import warnings; warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
# make output directories
import os

# makedirs creates the parent "ST_out" along with the plots subdirectory, and
# exist_ok=True keeps the cell safe to re-run (no separate exists() check that
# could race with another process creating the same directory)
os.makedirs("ST_out/plots_overview/", exist_ok=True)

In [None]:
# NOTE(review): `sc` is never imported explicitly in this notebook; presumably it is
# scanpy, brought into scope by `from dotplot_utils import *` — confirm
# figure defaults: transparent background, 400 dpi when saving
sc.set_figure_params(transparent=True, dpi_save=400)
# directory for saved figures (same path as the plots directory created earlier)
sc.settings.figdir = "ST_out/plots_overview/"

---
## Read in NMF scRNA-seq reference to check for genes in cell state loadings

In [None]:
# load the NMF scRNA-seq reference from disk
a_sc = sc.read("../data/scRNA/VUMC_NMF/VUMC_NMF_k30_dt0_1.h5ad")
a_sc  # last expression: show the object summary as the cell output

In [None]:
# which of these candidate genes are present among the reference's variable names?
candidate_genes = [
    "DPEP1",
    "NT5E",
    "IL6",
    "ITGB1",
    "ITGA2",
    "ITGA11",
    "ITGAV",
    "ITGB5",
    "ITGB3",
    "THBS1",
    "CD47",
    "CD36",
]
a_sc.var_names[a_sc.var_names.isin(candidate_genes)]

In [None]:
# labels for the 30 reference NMF usages, listed in usage order:
# position i (1-based) in this list is the label for "usage_<i>"
refNMF_state_names = [
    "STM", "END1", "BL1", "FIB1", "CRC1",
    "MYE1", "TL1", "MYE2", "CRC2", "CT",
    "SSC", "CRC3", "EE1", "MYE3", "PLA",
    "FIB2", "MYE4", "GOB", "MAS", "MYE5",
    "CRC4", "ABS", "TUF", "FIB3", "FIB4",
    "TL2", "END2", "TL3", "EE2", "BL2",
]

# "usage_<i>" -> label, with keys inserted in usage order
rename_refNMF = {
    "usage_{}".format(position): state_name
    for position, state_name in enumerate(refNMF_state_names, start=1)
}

---

In [None]:
from cNMF.cnmf import cnmf_markers

# run cnmf_markers on the reference with the saved gene spectra scores (300 genes requested);
# the result is read back from a_sc.uns["cnmf_markers"] just below
spectra_score_path = "../data/scRNA/VUMC_NMF/VUMC_NMF.gene_spectra_score.k_30.dt_0_1.txt"
cnmf_markers(adata=a_sc, n_genes=300, spectra_score_file=spectra_score_path)

# copy the table so relabelling its columns leaves the entry in a_sc.uns untouched
markers = a_sc.uns["cnmf_markers"].copy()
# NOTE(review): positional relabel — assumes the columns come back ordered
# usage_1 ... usage_30, matching the order of rename_refNMF; confirm
markers.columns = list(rename_refNMF.values())

In [None]:
# preview the first rows of the relabelled marker table
markers.head()

In [None]:
import kitchen.ingredients as k

In [None]:
# NOTE(review): `comb_k` is not defined in any cell visible in this notebook, so on a
# fresh kernel this cell fails with NameError unless the name is provided elsewhere
# (e.g. by the star import at the top) — confirm which object is intended here
# titles: every obs column whose name starts with "usage_"
p = k.rank_genes_cnmf(comb_k, ncols=2, n_points=24, titles=list(comb_k.obs.columns[comb_k.obs.columns.str.startswith("usage_")]))
#plt.savefig("cnmf_spectra_test.png")

---

In [None]:
from collections import Counter

def locate_genes_in_loadings(markers, genes):
    """
    Report which columns of a marker table contain each queried gene.

    For every gene, prints the columns of `markers` in which it appears, followed
    by the index label of its first matching row in each of those columns (printed
    as the "ranking"). Genes found in no column are reported as not detected.
    A final line prints a `Counter` of how many queried genes hit each column.

    Parameters
    ----------
    markers : pandas.DataFrame
        Table whose cells hold gene names, one column per loading.
    genes : iterable of str
        Gene names to look up; cells are compared with `==`, so matches are exact.

    Returns
    -------
    None
        Results are printed, not returned.
    """
    hits_per_column = Counter()
    for gene in genes:
        # compare every cell to the gene, then collapse to one flag per column
        column_has_gene = (markers.values == gene).any(0)
        matched = markers.columns[column_has_gene].tolist()
        if not matched:
            print("{} not detected".format(gene))
            continue

        print("{} detected in {}".format(gene, matched))
        hits_per_column.update(matched)
        for column in matched:
            # index label of the first row in this column holding the gene
            row_is_gene = (markers[column] == gene).values
            first_label = markers.index[row_is_gene][0]
            print("\t{} ranking = {}".format(column, first_label))
    print("\n{}".format(hits_per_column))

In [None]:
# genes to look up in the relabelled marker table
# NOTE(review): "PMCA4b" looks like a protein isoform name rather than a gene symbol
# (possibly ATP2B4), so an exact-match lookup may never find it — confirm
query_genes = [
    "DPEP1",
    "DDR1",
    "PAK4",
    "TGFBI",
    "PMCA4b",
    "RNLS",
]
locate_genes_in_loadings(markers, query_genes)

---
## Read in ST data

In [None]:
# sample metadata table; its first column becomes the index, which later cells use
# as the sample ID when building the per-sample .h5ad file paths
sample_key = pd.read_csv("../resources/ST/visium_sample_key.csv", index_col=0)

In [None]:
%%time
outs = []
for s in sample_key.index:
    # one "<sample>_master.h5ad" file per entry in the sample key index
    h5ad_path = "../data/ST/{}_master.h5ad".format(s)
    a = sc.read(h5ad_path)
    print("Read adata from {}".format(h5ad_path))

    # keep the object for the concatenation step
    outs.append(a)

# every sample in the key must have produced exactly one object
assert len(outs) == len(sample_key), "Check length of outs"

In [None]:
%%time
# concatenate anndata objects
a_comb = outs[0].concatenate(
    outs[1:],
    join="outer",
    batch_categories=list(sample_key.index),
    fill_value=0,
)
del a_comb.obsm
del a_comb.var

---
## Rename things and set up for plotting

In [None]:
# strip the reference-NMF suffix from the cell state columns in master anndata.obs
refNMF_suffix = "_VUMCrefNMF30"
has_suffix = a_comb.obs.columns.str.endswith(refNMF_suffix)
et = a_comb.obs.columns[has_suffix].tolist()
et_new = [column.replace(refNMF_suffix, "") for column in et]
a_comb.obs.rename(
    columns={old_name: new_name for old_name, new_name in zip(et, et_new)},
    inplace=True,
)

In [None]:
# annotation value -> hex colour
cmap_dict = {
    # Tumor Type
    "SSL/HP": "#c4a4e1",
    "MSI-H": "#7a4fa3",
    "MSS": "#ffc101",
    "TA/TVA": "#fee799",
    "NL": "#1f77b4",
    # Tumor Location
    "Cecum": "#1f4e79",
    "Ascending": "#2e74b7",
    "Hepatic Flexure": "#bdd6ef",
    "Transverse": "#ff717a",
    "Descending": "#fe0001",
    "Sigmoid": "#c00101",
    # this one's global
    "nan": "#ffffff",
    # These are black and white for T and F
    "T": "#000000",
    "F": "#ffffff",
}

# stage and grade levels each take colours from a "Reds" palette with one entry per level
stage_levels = ["AD", "I", "II", "III/IV"]
grade_levels = ["G1", "G2", "G3"]
stage_colordict = dict(
    zip(stage_levels, sns.color_palette("Reds", len(stage_levels)).as_hex())
)
grade_colordict = dict(
    zip(grade_levels, sns.color_palette("Reds", len(grade_levels)).as_hex())
)

# fold both into the shared lookup (stage entries first, then grade)
cmap_dict.update(stage_colordict)
cmap_dict.update(grade_colordict)

In [None]:
# colour each patient by the colour of their tumor type
patient_colordict = {
    patient: cmap_dict[tumor_type]
    for patient, tumor_type in zip(sample_key.patient_name, sample_key.tumor_type)
}

In [None]:
# define heatmap widths
# one figure width per grouping; the plotting loop passes these as the first element of
# `figsize` (presumably inches, as in matplotlib — confirm)
# NOTE(review): milwrm_width, evolution_width and cnv_domain_width are not referenced
# by any cell visible in this notebook
milwrm_width = 4.1
tumor_type_width = 3.8
tumor_loc_width = 3.7
tumor_stage_width = 3.2
tumor_grade_width = 3.0
patient_width = 8.7
pathology_width = 4.0
evolution_width = 3.0
cnv_domain_width = 3.0

---
# Gene plots

In [None]:
# convert the expression matrix to a dense array
from scipy.sparse import issparse

# guard on sparsity so the cell can be re-run: once X is dense it no longer has a
# sparse-to-dense conversion method and the unguarded call would raise
if issparse(a_comb.X):
    # toarray() gives a plain numpy ndarray; todense() would give np.matrix instead
    a_comb.X = a_comb.X.toarray()

In [None]:
# keep an independent copy of the current X under layers["raw_counts"]
# (assumes X still holds unnormalised counts at this point — TODO confirm)
a_comb.layers["raw_counts"] = a_comb.X.copy()

In [None]:
# features to plot, keyed by the group label they are listed under;
# the empty-string key holds the genes that carry no group label
ungrouped_genes = ["DPEP1", "DDR1", "PAK4", "TGFBI", "RNLS"]
custom_dict = {
    "": ungrouped_genes,
    "MSS": ["iCMS2", "Stem", "IES"],
    "MSI-H": ["iCMS3", "GOB", "SSC", "Metaplasia"],
    "NL": ["ABS", "CT"],
}
# figure height used when plotting this feature set
custom_dict_height = 4

In [None]:
# feature sets to plot: (name used in the output filename, label -> features dict, figure height)
feature_sets = [
    ("custom_genes", custom_dict, custom_dict_height),
]

# groupings to plot each feature set by, one row per figure:
# (name used in the output filename, obs column to group by, figure width,
#  explicit group order or None, value -> colour lookup or None)
groupings = [
    ("tumortype", "Tumor Type", tumor_type_width, None, cmap_dict),
    ("tumorloc", "Tumor Location", tumor_loc_width,
     ["Cecum", "Ascending", "Hepatic Flexure", "Transverse", "Descending", "Sigmoid"], cmap_dict),
    ("tumorstage", "Tumor Stage", tumor_stage_width, ["NL", "AD", "I", "II", "III/IV"], cmap_dict),
    ("tumorgrade", "Tumor Grade", tumor_grade_width, ["NL", "G1", "G2", "G3"], cmap_dict),
    ("patient", "Patient", patient_width, None, patient_colordict),
    ("pathology_annotation", "pathology_annotation", pathology_width, None, None),
]

for features_name, features_list, height in feature_sets:
    for group_name, group, width, groupby_order, groupby_colordict in groupings:
        cody_heatmap(
            a_comb,
            groupby=group,
            # use the feature set of the current outer iteration (previously hard-coded
            # to custom_dict, which ignored the loop variable); flatten the per-label
            # lists into one list, preserving order
            features=[feature for members in features_list.values() for feature in members],
            cluster_vars=False,
            vars_dict=features_list,
            groupby_order=groupby_order,
            groupby_colordict=groupby_colordict,
            # cluster the groups only when no explicit order is given
            cluster_obs=groupby_order is None,
            figsize=(width, height),
            save="ST_out/plots_overview/{}_{}_dotplot.png".format(group_name, features_name),
            dpi=400,
            cmap="Greys",
            size_title="Fraction of spots\nin group (%)",
        )