# In [None]:
################################################################
### Convert the public scRNA-seq data (GSE151974) to AnnData
# python anndata
import scanpy as sc
import anndata
from scipy import io
from scipy.sparse import coo_matrix, csr_matrix
import numpy as np 
import os

##load
#load.spares matrix
X = io.mmread("GSE151974_matrix.mtx")

#create anndata object
adata = anndata.AnnData(X=X.transpose(). tocsr())

#load cell metadata
cell_meta = pd.read_csv("GSE151974_barcodes.csv")

#load gene names
with open("GSE151974_features.csv", "r") as f:
  gene_names = f.read().splitlines()

#set anndata observations and index obs by barcodes, var- by gene names
adata.obs = cell_meta
adata.obs.index = adata.obs['barcode']
adata.var.index = gene_names 

#save dataset as anndata format h5ad
adata.write("GSE151974.h5ad")


#################################################################

import scanpy as sc
import os
atg7wt = sc.read_h5ad('atg7wt.h5ad')
atg7ko = sc.read_h5ad('atg7ko.h5ad')
atgwt.obs

import pickle

# save
with open('atg7wt.pickle','wb') as f:
   pickle.dump(atgwt,f)
   
# cell2location python   pickle load
conda env list
conda activate cell2location
import scanpy as sc
import os
os.getcwd()

# load
with open('atg7wt.pickle', 'rb') as f:
  atg7wt=pickle.load(f)
  
  
atg7wt.obs
atg7wt

with open('ATG7_ST.pickle', 'rb') as f:
  atg7=pickle.load(f)
  


################################################################
##### cell2location 

# shell (run in a terminal, not in Python): conda activate python-3.10

### load packages
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

import cell2location as c2l

### load Visium data
# load Visium data  ; adata_vis = sc.read_visium('/path/')
adata_wt=sc.read_visium('/ATG7_WT_Mav1_CR_23_15193_FP_R_SSV_1/Spatial_matrix/',
                        count_file='filtered_feature_bc_matrix.h5',
                        library_id='ATG7_WT',)
adata_wt.var_names_make_unique()


adata_ko=sc.read_visium('/ATG7_KO_MAV-3_CR_23_15194_FP_R_SSV_1/Spatial_matrix/',
                        count_file='filtered_feature_bc_matrix.h5',
                        library_id='ATG7_KO',)
adata_ko.var_names_make_unique()

adata_wt.obs['library_id']='ATG7_WT'
adata_ko.obs['library_id']='ATG7_KO'

adata_wt.obs['Cell_barcodes'] = adata_wt.obs_names
adata_wt.obs.index="ATG7_WT" + adata_wt.obs.index

adata_ko.obs['Cell_barcodes'] = adata_ko.obs_names
adata_ko.obs.index="ATG7_WT" + adata_ko.obs.index


library_names=["ATG7_WT", "ATG7_KO"]
adata_st = adata_wt.concatenate(
  adata_ko,
  batch_key="library_id",
  uns_merge="unique",
  batch_categories=library_names
)

del(adata_wt, adata_ko)

# rename genes to ENSEMBL ID
adata_st.var['feature_name'] = adata_st.var_names,copy() # adata_st.var.set_index('gene_ids', drop=True, inplace=True)

sc.pl.spatial(adata_st, color='PTPRC', gene_symbols='feature_name')

# find mitochondrial (MT) genes
adata_st.var['MT_gene'] = [gene.startswith('MT-') for gene in adata_st.var['feature_name']]

# remove MT genes for spatial mapping (keeping their counts in the object)

adata_st.obsm['MT'] = adata_st[:, adata_st.var['MT_gene'].values].X.toarray()
adata_st = adata_st[:, ~adata_st.var['MT_gene'].values]


#################################################################
### load scRNA as refernce
adata_sc = sc.read(f'/GSE151974.h5ad',)

# rename genes to ENSEMBL ID
adata_sc.var['feature_name'] = adata_sc.var.index
adata_sc.obs['celltype1'].value_counts()
adata_sc.obs['celltype2'].value_counts()
adata_sc.layers["counts"]=adata_sc.X.copy()

# shared features
shared_features =[
  features for features in adata_st.var_names if features in adata_sc.var_names
]

adata_st = adata_st[:, shared_features].copy()
adata_sc = adata_sc[:, shared_features].copy()

# cutoff
...default parameters ;
cell_count_cutoff=5, cell_percentage_cutoff2=0.03, nonz_mean_cutoff=1.12

selected = c2l.utils.filtering.filter_genes(
  adata_sc, cell_count_cutoff=5, cell_percentage_cutoff2=0.03, nonz_mean_cutoff=1.12)

# filter the object
adata_sc = adata_sc[:, selected].copy()
adata_st = adata_st[:, selected].copy()

# celltype
adata_sc.obs['celltype1']=adata_sc.obs['celltype1'].astype('str')
adata_sc.obs['celltype1']=adata_sc.obs['celltype1'].astype('category')

adata_sc.obs['CellType']=adata_sc.obs['CellType'].astype('str')
adata_sc.obs['CellType']=adata_sc.obs['CellType'].astype('category')

print(adata_sc.obs['celltype1'].astype('category').cat.categories)
print(adata_sc.obs['CellType'].astype('category').cat.categories)

adata_sc.obs['celltype1'].value_counts()
adata_sc.obs['celltype2'].value_counts()
adata_sc.obs['CellType'].value_counts()


#################################################################
### Set the number of threads to the number of CPU cores you want to use
# Training GPU  -> False

import torch
import os

num_cores = 7
os.environ["OMP_NUM_THREADS"] = str(num_cores)
os.environ["MKL_NUM_THREADS"] = str(num_cores)
os.environ["NUMEXPR_NUM_THREADS"] = str(num_cores)
os.environ["VECLIB_MAXIMUM_THREADS"] = str(num_cores)

print(f"MPS 장치를 지원하도록 build가 되었는가? {torch.backends.mps.is_built()}")
print(f"MPS 장치가 사용 가능한가? {torch.backends.mps.is_available()}") 

device = torch.device("mps")


#################################################################
### Estimation of reference cell type signatures (NB regression)

# prepare anndata for the regression model
c2l.models.RegressionModel.setup_anndata(
  adata=adata_sc,
  batch_key ="Oxygen",
  labels_key='CellType',
  categorical_covariate_keys=['CellType'],
  layer="counts"
)

adata_sc.obs['CellType']

# create the regression model
model = c2l.models.RegressionModel(adata_sc)
model.view_anndata_setup()

# default, try on GPU:
use_gpu = False #has GPU use_gpu = True
model.train(max_epochs=250, batch_size=2500, train_size=1, lr=0.002, use_gpu=use_gpu) # time spend

# model plot
model.plot_history(20)
model.plot_QC()

model.export_posterior(
    adata_sc,
    sample_kwargs={"num_samples": 1000, "batch_size": 2500, "use_gpu": use_gpu},
)


# export estimated gene expression in each cluster
if "means_per_cluster_mu_fg" in adata_sc.varm.keys():
    inf_aver = adata_sc.varm["means_per_cluster_mu_fg"][
        [f"means_per_cluster_mu_fg_{i}" for i in adata_sc.uns["mod"]["factor_names"]]
    ].copy()
else:
    inf_aver = adata_sc.var[
        [f"means_per_cluster_mu_fg_{i}" for i in adata_sc.uns["mod"]["factor_names"]]
    ].copy()

inf_aver.columns = adata_sc.uns["mod"]["factor_names"]
inf_aver.head()

inf_aver.to_csv("inf_aver.csv")

#################################################################
### Cell2location: spatial mapping
c2l.models.Cell2location.setup_anndata(
    adata=adata_st,
    batch_key="library_id",
)

model = c2l.models.Cell2location(
    adata_st,
    cell_state_df=inf_aver,
    N_cells_per_location=8,
)

model.view_anndata_setup()

model.train(max_epochs=30000, batch_size=None, train_size=1, use_gpu=False)  # time spend

# model plot
model.plot_history()
model.plot_QC()   
    
adata_st = model.export_posterior(
    adata_st,
    sample_kwargs={
        "num_samples": 1000,
        "batch_size": model.adata.n_obs,
        "use_gpu": False,
    },
)   



# add adata_sc celltype to adata_st
adata_st.obs[adata_st.uns["mod"]["factor_names"]] = adata_st.obsm[
    "q05_cell_abundance_w_sf"
]

adata_sc.obs['CellType'].value_counts()
adata_sc.obs['celltype1'].value_counts()

adata_st.obs['library_id'].value_counts()
adata_st.obs.columns



### save
# adata_st, adata_sc save / load
adata_st.write('./obj/adata_st_cell2loc.h5ad')  # save
adata_sc.write('./obj/adata_sc_cell2loc.h5ad')  # save

adata_st=sc.read_h5ad('./obj/adata_st_cell2loc.h5ad') #load

# model save / load
model.save(f"./obj/", overwrite=True)  # name as model.pt
model = torch.load('./obj/model.pt')



### ref 
# save model with dill
import dill 
with open('model_cell2loc.pickle', 'wb') as fw:
    dill.dump(model, fw)
    
# load model with dill
import dill
with open('./model_cell2loc.pickle', 'rb') as f:
    model=dill.load(f)

# save model with pickle
with open('model_cell2loc.pickle', 'wb') as fw:
    pickle.dump(model, fw)

# load model with pickle
with open('./model_cell2loc.pickle', 'rb') a f:
    model=pickle.load(f)



####################################################
### Visualizing cell abundance in spatial coordinates
# select one slide for visualization
slide = c2l.utils.select_slide(adata_st, "ATG7_WT", batch_key="library_id")
slide.obsm['spatial'] = slide.obsm['spatial'].astype(float)

with mpl.rc_context({"figure.figsize": [8, 8]}):
    sc.pl.spatial(
        slide,
        cmap="magma",
        color=adata_st.uns["mod"]["factor_names"],
        ncols= 9,
        size=1,
        img_key="hires",
        # limit color scale at 99.2% quantile of cell abundance
        vmin=0,
        vmax="p99.2",
    )


slide2 = c2l.utils.select_slide(adata_st, "ATG7_KO", batch_key="library_id")
slide2.obsm['spatial'] = slide2.obsm['spatial'].astype(float)

with mpl.rc_context({"figure.figsize": [4.5, 5]}):
    sc.pl.spatial(
        slide2,
        cmap="magma",
        color=adata_st.uns["mod"]["factor_names"], 
        ncols= 9,
        size=1,
        img_key="hires",
        # limit color scale at 99.2% quantile of cell abundance
        vmin=0,
        vmax="p99.2",
    )

adata_st.obs
adata_sc



#select up to clusters

from cell2location.plt import plot_spatial

clust_labels = ["CD4 T cell 1", "CD4 T cell 2", "CD8 T cell 1", "CD8 T cell 2" ,"gd T cell", "ILC2", "NK cell"]
clust_labels = ["Mono", "Neut 1","Neut 2","Mast Ba2"]
clust_labels = ["DC1","DC2","Int Mf","Alv Mf"]


clust_labels = ["CD4 T cell 1", "CD4 T cell 2", "CD8 T cell 1", "CD8 T cell 2" ,"gd T cell"]
clust_labels = ["CD4 T cell 1", "CD4 T cell 2", "CD8 T cell 1", "CD8 T cell 2" ,"gd T cell"]

               
clust_col = ['' +str(i) for i in clust_labels]

slide = c2l.utils.select_slide(adata_st, "ATG7_WT", batch_key="library_id")
slide.obsm['spatial'] = slide.obsm['spatial'].astype(float)

slide2 = c2l.utils.select_slide(adata_st, "ATG7_KO", batch_key="library_id")
slide2.obsm['spatial'] = slide2.obsm['spatial'].astype(float)

with mpl.rc_context({"figure.figsize": [15,15]}):
    fig=plot_spatial(
        adata=slide,
        color=clust_col, labels = clust_labels,
        show_img=True, style='fast', max_color_quantile=0.992,
        circle_diameter=6, colorbar_position="right")

with mpl.rc_context({"figure.figsize": [15,15]}):
    fig=plot_spatial(
        adata=slide2,
        color=clust_col, labels = clust_labels,
        show_img=True, style='fast', max_color_quantile=0.992,
        circle_diameter=6, colorbar_position="right")



###############################################################################
# NMF

results_folder = 'result/'

# create paths and names to results folders for reference regression and cell2location models
ref_run_name = f'{results_folder}/reference_signatures'
run_name = f'{results_folder}/cell2location_map'


columns_to_convert = ['in_tissue', 'array_row', 'array_col']
for col in columns_to_convert:
    adata_st.obs[col] = adata_st.obs[col].astype(str)

adata_st2=adata_st.copy()
non_numeric = 'library_id', 'Cell_barcodes'
adata_st2.obs = adata_st2.obs.drop([non_numeric], errors = 'ignore')

adata_st2.obs=df = adata_st2.obs.drop('nan', axis=1)


index=cell2loc.uns["mod"]["factor_names"].index('RAS_SCGB3A2+')
cell2loc.uns["mod"]["factor_names"][index]='Epi_RAS'
cell2loc.obs.rename(columns={'RAS_SCGB3A2+': 'Epi_RAS'}, inplace=True)


# 'factor_names' 키가 'nan' 값을 가지고 있다면 삭제
if adata_st2.uns["mod"]["factor_names"] == 'nan':
    del adata_st2.uns["mod"]["factor_names"]

index=adata_st2.uns["mod"]["factor_names"].index('RAS_SCGB3A2+')
adata_st2.uns["mod"]["factor_names"][index]='Epi_RAS'
adata_st2.obs.rename(columns={'RAS_SCGB3A2+': 'Epi_RAS'}, inplace=True)

res_dict



from cell2location import run_colocation
res_dict, adata_vis = run_colocation(
    adata_st2,
    model_name='CoLocatedGroupsSklearnNMF',
    train_args={
      'n_fact': np.arange(5, 15), # IMPORTANT: use a wider range of the number of factors (5-30)
      'sample_name_col': 'library_id', # columns in adata_vis.obs that identifies sample
      'n_restarts': 3 # number of training restarts
    },
    # the hyperparameters of NMF can be also adjusted:
    model_kwargs={'alpha': 0.01, 'init': 'random', "nmf_kwd_args": {"tol": 0.000001}},
    export_args={'path': f'{run_name}/CoLocatedComb/'}
)

