In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import muon as mu
import networkx as nx
import scvi
import glob
import scglue
import os
from scipy import sparse
import matplotlib.pyplot as plt
import seaborn as sns

# Scanpy runtime settings: verbosity level 3 and n_jobs set to 8.
sc.settings.verbosity = 3
sc.settings.n_jobs = 8
# NOTE(review): this constructs a Trainer and immediately discards it; the
# instance is never referenced again in this notebook, so it presumably does
# not select the MPS accelerator for the scGLUE training below — confirm
# whether this line is needed at all.
scvi.train.Trainer(accelerator = 'mps')
# Render inline figures in 'retina' format.
%config InlineBackend.figure_format='retina'

# scGLUE model training

In [None]:
# Load the integration inputs from the results directory: the RNA and ATAC
# AnnData objects and the guidance graph linking their features.
# (Presumably all three were written by an earlier preprocessing notebook — confirm.)
TG_rna = sc.read_h5ad("../TG_data/Results/TG_rna_pp.h5ad")
TG_atac = sc.read_h5ad("../TG_data/Results/TG_atac_pp.h5ad")
guidance = nx.read_graphml("../TG_data/Results/guidance.graphml.gz")

In [None]:
# Register the RNA modality with scGLUE: "NB" model on the raw "counts" layer,
# restricted to highly variable genes, with the scANVI latent space as the
# cell representation, "Source" as the batch key and "subtype" as cell type.
scglue.models.configure_dataset(
    TG_rna,
    "NB",
    use_highly_variable=True,
    use_layer="counts",
    use_rep="X_scANVI",
    use_batch="Source",
    use_cell_type='subtype',
)

In [None]:
# Register the ATAC modality with scGLUE: "NB" model on highly variable
# features, with the LSI embedding as the cell representation and "subtype"
# as cell type. No layer or batch key is passed for this modality.
scglue.models.configure_dataset(
    TG_atac,
    "NB",
    use_highly_variable=True,
    use_rep="X_lsi",
    use_cell_type='subtype',
)

In [None]:
from itertools import chain

# Keep only the guidance-graph nodes that are flagged as highly variable in
# either modality, and take an independent copy of the resulting subgraph.
rna_hvf_names = TG_rna.var.query("highly_variable").index
atac_hvf_names = TG_atac.var.query("highly_variable").index
guidance_hvf = guidance.subgraph(chain(rna_hvf_names, atac_hvf_names)).copy()

In [None]:
# Train the scGLUE model on both modalities using the HVF-restricted guidance
# graph; training artefacts are written to the local "glue" directory.
glue = scglue.models.fit_SCGLUE(
    {"rna": TG_rna, "atac": TG_atac}, guidance_hvf,
    fit_kws={"directory": "glue"}
)
# Save the model at the exact location the following cell loads it from, so a
# fresh "Restart & Run All" reloads the model that was just trained instead of
# failing on a missing file or silently picking up a stale one.
glue_model_path = "../TG_data/Results/scglue_model/glue.dill"
os.makedirs(os.path.dirname(glue_model_path), exist_ok=True)
glue.save(glue_model_path)

In [None]:
# Reload the trained scGLUE model from the results directory. This rebinds
# `glue`, so the cells below can run without repeating the training cell.
glue = scglue.models.load_model("../TG_data/Results/scglue_model/glue.dill")

In [None]:
# Compute integration consistency scores for the trained model against the
# HVF-restricted guidance graph.
# NOTE(review): the result is only kept in memory as `dx`; nothing in this
# cell writes it to a file.
dx = scglue.models.integration_consistency(
    glue, {"rna": TG_rna, "atac": TG_atac}, guidance_hvf
)

In [None]:
# Disabled diagnostic: line plot of "consistency" against "n_meta" from `dx`,
# with a dashed dark-red reference line at y=0.05. Uncomment to inspect.
# _ = sns.lineplot(x="n_meta", y="consistency", data=dx).axhline(y=0.05, c="darkred", ls="--")

In [None]:
# Get cell embeddings
TG_rna.obsm["X_glue"] = glue.encode_data("rna", TG_rna)
TG_atac.obsm["X_glue"] = glue.encode_data("atac", TG_atac)

In [None]:
# Stack RNA and ATAC cells into a single AnnData.
# NOTE(review): the next cell rebinds TG_combined to an object read from disk,
# so this in-memory result is discarded when the notebook runs top to bottom.
TG_combined = sc.concat([TG_rna, TG_atac])

In [None]:
# Load the combined object from disk; this replaces the TG_combined built by
# sc.concat in the previous cell.
# NOTE(review): later cells take obsm["X_glue"] and obs columns such as
# 'scANVI_pred' and 'balancing_weight' from this file, and no visible cell
# writes it — confirm it is in sync with the embeddings computed above.
TG_combined = sc.read_h5ad("../TG_data/Results/TG_combined.h5ad")

In [None]:
# Get feature embeddings
feature_embeddings = glue.encode_graph(guidance_hvf)
feature_embeddings = pd.DataFrame(feature_embeddings, index=glue.vertices)

In [None]:
# Bundle both modalities into one MuData container keyed by "rna" and "atac".
TG_mdata = mu.MuData({"rna": TG_rna, "atac": TG_atac})

In [None]:
# Take the joint cell embedding from TG_combined, selected by the MuData
# cell names, and store a copy at the MuData level.
combined_by_mdata_cells = TG_combined[TG_mdata.obs_names]
TG_mdata.obsm["X_glue"] = combined_by_mdata_cells.obsm["X_glue"].copy()

In [None]:
# Attach the feature embeddings to each modality, reindexed to that
# modality's variable names.
for _modality, _adata in (("rna", TG_rna), ("atac", TG_atac)):
    _embedding = feature_embeddings.reindex(_adata.var_names)
    TG_mdata[_modality].varm["X_glue"] = _embedding.to_numpy()

In [None]:
# Copy the cell-level annotation columns from the combined object onto the
# MuData-level obs table. The column list is named once so that both sides of
# the assignment cannot drift apart.
shared_obs_columns = [
    'Tissue', 'Conditions', 'Source', 'Technology', 'Strains',
    'Sex', 'Time', 'subtype', 'scANVI_pred', 'balancing_weight',
]
TG_mdata.obs[shared_obs_columns] = TG_combined.obs[shared_obs_columns]

In [None]:
# Label every cell by modality: 'expression' by default, 'peak' for cells whose
# Source is 'Renthal_ATAC'.
TG_mdata.obs['modality'] = 'expression'
# Use a single .loc assignment instead of chained indexing
# (obs.modality[mask] = ...), which triggers SettingWithCopyWarning and does
# not modify the frame when pandas Copy-on-Write is active.
TG_mdata.obs.loc[TG_mdata.obs['Source'] == 'Renthal_ATAC', 'modality'] = 'peak'

In [None]:
# Derive two coarser labelings from the scANVI predictions (as strings):
#   MetaType        - the listed neuronal subtypes collapsed into 'Neurons'
#   NociceptiveType - the listed subtypes collapsed into 'Nociceptors'
# Labels not in the respective list are kept unchanged.
neuron_labels = ['NF1','NF2','NF3','NP','PEP1','PEP2','SST','cLTMR1','p_cLTMR2']
nociceptor_labels = ['NP','PEP1','PEP2','SST','cLTMR1','p_cLTMR2']
predicted_labels = TG_mdata.obs["scANVI_pred"].astype(str)
TG_mdata.obs["MetaType"] = predicted_labels.replace(neuron_labels, 'Neurons')
TG_mdata.obs["NociceptiveType"] = predicted_labels.replace(nociceptor_labels, 'Nociceptors')

In [None]:
# Build the neighbor graph on the joint "X_glue" embedding using the cosine
# metric, then compute a UMAP layout from it.
sc.pp.neighbors(TG_mdata, use_rep="X_glue", metric="cosine")
sc.tl.umap(TG_mdata)

In [None]:
# Convert Technology to plain strings, then set explicit labels for the two
# Renthal sources.
TG_mdata.obs["Technology"] = TG_mdata.obs["Technology"].astype(str)
# .loc performs the write on the obs frame itself; the previous chained form
# (obs.Technology[mask] = ...) triggers SettingWithCopyWarning and is a no-op
# when pandas Copy-on-Write is active.
TG_mdata.obs.loc[TG_mdata.obs["Source"] == 'Renthal_ATAC', "Technology"] = '10x_ATAC'
TG_mdata.obs.loc[TG_mdata.obs["Source"] == 'Renthal_RNA', "Technology"] = '10x_RNA'

In [None]:
# Draw one UMAP per annotation, in this order, with shared figure settings.
sc.set_figure_params(dpi=100, figsize=(5,5), frameon=False)
umap_color_keys = (
    "scANVI_pred",
    "MetaType",
    "NociceptiveType",
    "modality",
    "Source",
    "Technology",
)
for _color_key in umap_color_keys:
    sc.pl.umap(TG_mdata, color=[_color_key], size=3)

In [None]:
# Binary pain state: conditions 'Naive' and 'PBS' map to 'noPain', every other
# condition to 'Pain'. Stored as a categorical ordered noPain -> Pain.
no_pain_conditions = ['Naive', 'PBS']
pain_state = TG_mdata.obs.Conditions.map(
    lambda condition: 'noPain' if condition in no_pain_conditions else 'Pain'
)
TG_mdata.obs['PainState'] = (
    pain_state.astype('category').cat.reorder_categories(['noPain', 'Pain'])
)

In [None]:
# Persist the annotated MuData object to the results directory (gzip-compressed).
TG_mdata.write("../TG_data/Results/TG_mdata.h5mu", compression="gzip")

In [None]:
# Save the HVF-restricted guidance graph next to the other results.
nx.write_graphml(guidance_hvf, "../TG_data/Results/guidance_hvf.graphml.gz")