In [1]:
import os

import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=DeprecationWarning)
warnings.simplefilter(action="ignore", category=UserWarning)

In [2]:
import sys
import scanpy as sc
import numpy as np
import pandas as pd
import scarches as sca
import anndata as ad
from scipy import sparse
import gdown
import gzip
import shutil
import urllib.request

Global seed set to 0
 captum (see https://github.com/pytorch/captum).


In [3]:
# Configure plotting defaults in a single call. Every call to `set_figure_params`
# re-applies its default value for each argument that is not passed, so separate
# calls overwrite one another (a later call without `dpi`/`frameon` resets them).
sc.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))

In [4]:
ref_model_dir_prefix = "/bmbl_data/chenghao/reference/HLCA_reference"  # directory in which to store the reference model directory

# One-time download of the reference model archive, kept commented out.
# Uncomment to fetch the zip, unpack it into `ref_model_dir_prefix`, and delete the zip:
# url = "https://zenodo.org/record/7599104/files/HLCA_reference_model.zip"
# output = "HLCA_reference_model.zip"
# gdown.download(url, output, quiet=False)
# shutil.unpack_archive("HLCA_reference_model.zip", extract_dir=ref_model_dir_prefix)
# os.remove(output)

In [5]:
# Source URL of the reference embedding file and the local path it is stored under.
url = "https://zenodo.org/record/7599104/files/HLCA_full_v1.1_emb.h5ad"
path_reference_emb = (
    "/bmbl_data/chenghao/reference/HLCA_emb_and_metadata.h5ad"  # path to reference embedding to be created
)
# One-time download, kept commented out (uncomment to fetch `url` into `path_reference_emb`):
# output = path_reference_emb
# gdown.download(url, output, quiet=False)

In [6]:
adata_ref = sc.read_h5ad(path_reference_emb)

In [7]:
# Keep only the cells whose `core_or_extension` value is "core"; `.copy()` turns the
# view into an independent AnnData so the obs edits below do not touch a view.
adata_ref = adata_ref[adata_ref.obs.core_or_extension == "core", :].copy()
# remove all obs variables that have no entries anymore (i.e. obs columns that were only relevant for the HLCA extension)
cols_to_drop = [
    col for col in adata_ref.obs.columns if adata_ref.obs[col].isnull().all()
]
# Assign the result instead of mutating in place so the cell reads as one explicit step.
adata_ref.obs = adata_ref.obs.drop(columns=cols_to_drop)

In [8]:
adata_query_unprep=sc.read_h5ad('/bmbl_data/chenghao/U54/data-0727.h5ad')
adata_query_unprep

AnnData object with n_obs × n_vars = 15352 × 18082
    obs: 'sample'
    var: 'gene_ids', 'feature_types'

In [1]:
adata_query_unprep.obs['sample'].value_counts()

NameError: name 'adata_query_unprep' is not defined

In [10]:
raw_data=sc.read_csv("/bmbl_data/chenghao/sencell/fixRNA_counts.csv")

In [12]:
adata_query_unprep.X=raw_data.X.T

In [13]:
adata_query_unprep.X = sparse.csr_matrix(adata_query_unprep.X)

In [14]:
del adata_query_unprep.obsm
del adata_query_unprep.varm

In [10]:
# Read `var_names.csv` from the reference model directory; the file has no header row.
# Presumably this lists the genes the reference model expects — TODO confirm.
# NOTE(review): `ref_model_features` is not referenced again in the visible notebook.
ref_model_features = pd.read_csv(
    os.path.join("/bmbl_data/chenghao/reference/HLCA_reference/HLCA_reference_model", "var_names.csv"), header=None
)

In [11]:
ref_path="/bmbl_data/chenghao/reference/HLCA_reference/HLCA_reference_model"
path_gene_mapping_df = os.path.join(ref_path, "HLCA_reference_model_gene_order_ids_and_symbols.csv")
# # Download gene information from HLCA github:
url = "https://zenodo.org/record/7599104/files/HLCA_reference_model_gene_order_ids_and_symbols.csv" 
# gdown.download(url, path_gene_mapping_df, quiet=False)

In [12]:
import pandas as pd

# Table of the reference model's genes: gene ids in the index, symbols in `gene_symbol`.
gene_id_to_gene_name_df = pd.read_csv(path_gene_mapping_df, index_col=0)

# The query's var index holds gene names; copy them into a regular column for matching.
gene_name_column_name = "gene_names"
adata_query_unprep.var[gene_name_column_name] = adata_query_unprep.var.index

# Number of query genes whose name appears among the model's gene symbols.
n_overlap = adata_query_unprep.var[gene_name_column_name].isin(
    gene_id_to_gene_name_df.gene_symbol
).sum()
n_genes_model = len(gene_id_to_gene_name_df)

print(
    f"Number of model input genes detected: {n_overlap} out of {n_genes_model} ({round(n_overlap/n_genes_model*100)}%)"
)

Number of model input genes detected: 1731 out of 2000 (87%)


In [13]:
# Boolean mask over query genes: True where the gene name is one of the model's gene symbols.
_is_model_gene = adata_query_unprep.var[gene_name_column_name].isin(
    gene_id_to_gene_name_df.gene_symbol
)
# subset your data to genes used in the reference model
adata_query_unprep = adata_query_unprep[:, _is_model_gene].copy()

# add gene ids for the gene names, and store in .var.index
_symbol_to_gene_id = dict(
    zip(gene_id_to_gene_name_df.gene_symbol, gene_id_to_gene_name_df.index)
)
adata_query_unprep.var.index = adata_query_unprep.var[gene_name_column_name].map(
    _symbol_to_gene_id
)
# remove index name to prevent bugs later on
adata_query_unprep.var.index.name = None
adata_query_unprep.var["gene_ids"] = adata_query_unprep.var.index

In [14]:
adata_query_unprep.shape

(15352, 1731)

In [15]:
import anndata as ad
from scipy import sparse

def sum_by(adata: ad.AnnData, col: str) -> ad.AnnData:
    """Sum the rows of ``adata.X`` that share a value in ``adata.obs[col]``.

    Parameters
    ----------
    adata
        Object whose obs rows are grouped. NOTE: it is modified in place —
        ``strings_to_categoricals()`` is called on it and ``adata.obs[col]`` is
        cast to a categorical column.
    col
        Name of the obs column that defines the groups.

    Returns
    -------
    A new AnnData with one obs row per category of ``adata.obs[col]`` (obs index =
    the categories), the same ``var`` as the input, and ``X`` holding the per-group sums.

    Raises
    ------
    ValueError
        If ``adata.obs[col]`` contains missing values, which have no category to be
        summed into.
    """
    adata.strings_to_categoricals()
    adata.obs[col] = adata.obs[col].astype('category')

    # isinstance check replaces the deprecated pd.api.types.is_categorical_dtype.
    assert isinstance(adata.obs[col].dtype, pd.CategoricalDtype)

    cat = adata.obs[col].values
    # Missing values are encoded as code -1, which is not a valid row index for the
    # indicator matrix below; fail early with a message that names the column.
    if (cat.codes < 0).any():
        raise ValueError(
            f"adata.obs[{col!r}] contains missing values; cannot assign those rows to a group"
        )

    # Sparse (n_categories x n_obs) indicator: entry (k, i) is True when obs row i
    # belongs to category k, so `indicator @ X` adds up the rows of each category.
    indicator = sparse.coo_matrix(
        (np.broadcast_to(True, adata.n_obs), (cat.codes, np.arange(adata.n_obs))),
        shape=(len(cat.categories), adata.n_obs),
    )

    return ad.AnnData(
        indicator @ adata.X, var=adata.var, obs=pd.DataFrame(index=cat.categories)
    )


# sum_by groups obs rows, so transpose first to put genes on the obs axis, sum the
# entries that share a `gene_ids` value, then transpose back to the original orientation.
adata_query_unprep = sum_by(adata_query_unprep.transpose(), col="gene_ids").transpose()

In [16]:
adata_query_unprep.shape

(15352, 1731)

In [17]:
adata_query_unprep.var = adata_query_unprep.var.join(gene_id_to_gene_name_df).rename(columns={"gene_symbol":"gene_names"})

In [18]:
adata_query_unprep.obs_names.is_unique

True

In [19]:
adata_query_unprep

AnnData object with n_obs × n_vars = 15352 × 1731
    obs: 'sample'
    var: 'gene_names'

In [20]:
adata_query = sca.models.SCANVI.prepare_query_anndata(
    adata=adata_query_unprep, reference_model=ref_path, inplace=False
)

[34mINFO    [0m File [35m/bmbl_data/chenghao/reference/HLCA_reference/HLCA_reference_model/[0m[95mmodel.pt[0m already downloaded        
[34mINFO    [0m Found [1;36m86.55000000000001[0m% reference vars in query data.                                                    


In [21]:
adata_query_unprep

AnnData object with n_obs × n_vars = 15352 × 1731
    obs: 'sample'
    var: 'gene_names'

In [22]:
adata_query.obs["dataset"] = "batch1"

In [23]:
adata_query

AnnData object with n_obs × n_vars = 15352 × 2000
    obs: 'sample', 'dataset'
    var: 'gene_names'

In [24]:
# The query cells carry no annotations yet, so mark every cell as "unlabeled" in the
# `scanvi_label` column before handing the data to the reference model. Setting it in
# this cell keeps the load below from depending on a cell that appears later in the
# notebook. (Presumably `scanvi_label` is the labels column the reference model was
# set up with — it is present in the reference obs — TODO confirm.)
adata_query.obs["scanvi_label"] = "unlabeled"

# Build the surgery model from the stored reference model and the query data.
surgery_model = sca.models.SCANVI.load_query_data(
    adata_query,
    ref_path,
    freeze_dropout=True,
)

[34mINFO    [0m File [35m/bmbl_data/chenghao/reference/HLCA_reference/HLCA_reference_model/[0m[95mmodel.pt[0m already downloaded        




In [25]:
adata_query.obs["scanvi_label"] = "unlabeled"


In [26]:
# Build the surgery model from the stored reference model and the query data.
# This is the same call as in an earlier cell; rebinding `surgery_model` here
# replaces the model created there.
surgery_model = sca.models.SCANVI.load_query_data(
    adata_query,
    ref_path,
    freeze_dropout=True,
)

[34mINFO    [0m File [35m/bmbl_data/chenghao/reference/HLCA_reference/HLCA_reference_model/[0m[95mmodel.pt[0m already downloaded        


In [27]:
# Upper bound on the number of fine-tuning epochs.
surgery_epochs = 50
# Extra keyword arguments forwarded to `surgery_model.train(...)`.
early_stopping_kwargs_surgery = {
    # Switch early stopping on: the early_stopping_* settings below only configure
    # the early-stopping callback, they do not enable it by themselves.
    "early_stopping": True,
    "early_stopping_monitor": "elbo_train",
    "early_stopping_patience": 10,
    "early_stopping_min_delta": 0.001,
    "plan_kwargs": {"weight_decay": 0.0},
}

In [28]:
import torch
torch.cuda.is_available()

True

In [29]:
surgery_model.train(max_epochs=surgery_epochs, use_gpu =True, **early_stopping_kwargs_surgery)

[34mINFO    [0m Training for [1;36m50[0m epochs.                                                                                   


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Epoch 50/50: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:01<00:00,  1.22s/it, loss=874, v_num=1]

`Trainer.fit` stopped: `max_epochs=50` reached.


Epoch 50/50: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:01<00:00,  1.23s/it, loss=874, v_num=1]


In [30]:
# Wrap the query cells' latent representation from the surgery model in its own AnnData ...
adata_query_latent = ad.AnnData(
    X=surgery_model.get_latent_representation(adata_query)
)
# ... and carry over the query's obs table, row for row.
adata_query_latent.obs = adata_query.obs.loc[adata_query.obs.index, :]


In [31]:
path_celltypes = os.path.join(ref_path, "HLCA_celltypes_ordered.csv")
url = "https://github.com/LungCellAtlas/HLCA_reproducibility/raw/main/supporting_files/celltype_structure_and_colors/manual_anns_and_leveled_anns_ordered.csv" # "https://github.com/LungCellAtlas/mapping_data_to_the_HLCA/raw/main/supporting_files/HLCA_celltypes_ordered.csv"
# gdown.download(url, path_celltypes, quiet=False)

In [32]:
cts_ordered = pd.read_csv(path_celltypes, index_col=0).rename(
    columns={f"Level_{lev}": f"labtransf_ann_level_{lev}" for lev in range(1, 6)}
)

In [33]:
adata_ref.obs = adata_ref.obs.join(cts_ordered, on="ann_finest_level")


In [34]:
knn_transformer = sca.utils.knn.weighted_knn_trainer(
    train_adata=adata_ref,
    train_adata_emb="X",  # location of our joint embedding
    n_neighbors=50,
)

Weighted KNN with n_neighbors = 50 ... 

In [35]:
# Transfer labels from the reference to the query cells using the kNN model trained
# above. The call returns two objects, bound here to `labels` and `uncert`.
# NOTE(review): later cells rename `labtransf_ann_level_{lev}` columns and then read
# `ann_level_{lev}_transferred_label_unfiltered`; with label_keys="ann_finest_level"
# those columns are presumably not produced — confirm which label_keys this run needs.
labels, uncert = sca.utils.knn.weighted_knn_transfer(
    query_adata=adata_query_latent,
    query_adata_emb="X",  # location of our embedding, query_adata.X in this case
    label_keys="ann_finest_level",  # (start of) obs column name(s) for which to transfer labels
    knn_model=knn_transformer,
    ref_adata_obs=adata_ref.obs,
)

finished!


In [36]:
cts_ordered

Unnamed: 0,labtransf_ann_level_1,labtransf_ann_level_2,labtransf_ann_level_3,labtransf_ann_level_4,labtransf_ann_level_5,ordering,colors
Basal resting,Epithelial,Airway epithelium,Basal,Basal resting,4_Basal resting,3,#FFFF00
Suprabasal,Epithelial,Airway epithelium,Basal,Suprabasal,4_Suprabasal,4,#1CE6FF
Hillock-like,Epithelial,Airway epithelium,Basal,Hillock-like,4_Hillock-like,11,#FF34FF
Deuterosomal,Epithelial,Airway epithelium,Multiciliated lineage,Deuterosomal,4_Deuterosomal,13,#FF4A46
Multiciliated (nasal),Epithelial,Airway epithelium,Multiciliated lineage,Multiciliated,Multiciliated (nasal),15,#008941
...,...,...,...,...,...,...,...
Monocyte-derived Mph,Immune,Myeloid,Macrophages,Interstitial macrophages,Monocyte-derived Mph,151,#FAD09F
Interstitial Mph perivascular,Immune,Myeloid,Macrophages,Interstitial macrophages,Interstitial Mph perivascular,155,#FF8A9A
Classical monocytes,Immune,Myeloid,Monocytes,Classical monocytes,4_Classical monocytes,160,#D157A0
Non-classical monocytes,Immune,Myeloid,Monocytes,Non-classical monocytes,4_Non-classical monocytes,163,#BEC459


In [37]:
adata_ref

AnnData object with n_obs × n_vars = 584944 × 30
    obs: 'suspension_type', 'donor_id', 'is_primary_data', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'tissue_ontology_term_id', 'organism_ontology_term_id', 'sex_ontology_term_id', "3'_or_5'", 'BMI', 'age_or_mean_of_age_range', 'age_range', 'anatomical_region_ccf_score', 'ann_coarse_for_GWAS_and_modeling', 'ann_finest_level', 'ann_level_1', 'ann_level_2', 'ann_level_3', 'ann_level_4', 'ann_level_5', 'cause_of_death', 'core_or_extension', 'dataset', 'fresh_or_frozen', 'log10_total_counts', 'lung_condition', 'mixed_ancestry', 'original_ann_level_1', 'original_ann_level_2', 'original_ann_level_3', 'original_ann_level_4', 'original_ann_level_5', 'original_ann_nonharmonized', 'reannotation_type', 'sample', 'scanvi_label', 'sequencing_platform', 'smoking_status', 'study', 'subject_type', 'tissue_coarse_unharmonized', 'tis

In [2]:
labels

NameError: name 'labels' is not defined

In [39]:
labels["ann_finest_level"].value_counts()

Alveolar fibroblasts             2833
AT2                              1679
CD4 T cells                      1100
Monocyte-derived Mph              904
EC general capillary              897
CD8 T cells                       698
Classical monocytes               673
Alveolar macrophages              632
Plasma cells                      548
Pericytes                         536
Peribronchial fibroblasts         465
Interstitial Mph perivascular     450
EC aerocyte capillary             413
AT1                               357
NK cells                          312
Adventitial fibroblasts           299
Smooth muscle                     270
EC arterial                       239
Multiciliated (non-nasal)         237
Myofibroblasts                    211
Mast cells                        199
EC venous pulmonary               190
Non-classical monocytes           175
B cells                           150
Basal resting                     148
pre-TB secretory                  147
Goblet (nasa

In [40]:
labels

Unnamed: 0,ann_finest_level
AAACAAGCAGCAATTGAAGTAGAG-1-OSU20126_Healthy_Old,Alveolar fibroblasts
AAACCAATCATGCGTCAAGTAGAG-1-OSU20126_Healthy_Old,Alveolar fibroblasts
AAACCAGGTAAAGCATAAGTAGAG-1-OSU20126_Healthy_Old,EC general capillary
AAACCAGGTATTACCAAAGTAGAG-1-OSU20126_Healthy_Old,EC general capillary
AAACCAGGTCAATTCAAAGTAGAG-1-OSU20126_Healthy_Old,Alveolar fibroblasts
...,...
TTTGGACGTTTGACTAATCATGTG-1-OSU10172_IPF_LL,Alveolar fibroblasts
TTTGGCGGTGTTTGCGATCATGTG-1-OSU10172_IPF_LL,Adventitial fibroblasts
TTTGGCGGTTTGCTCCATCATGTG-1-OSU10172_IPF_LL,Alveolar fibroblasts
TTTGTGAGTACAAAGTATCATGTG-1-OSU10172_IPF_LL,Non-classical monocytes


In [41]:
labels.to_csv("labels.csv")

In [50]:
adata_query_unprep

AnnData object with n_obs × n_vars = 49298 × 1758
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'percent.ribo', 'percent.mito', 'test_res.0.1', 'seurat_clusters', 'ann_finest_level'
    var: 'gene_names'

In [48]:
adata_query_unprep.obs['ann_finest_level']=labels['ann_finest_level']

In [49]:
adata_query_unprep.write("/bmbl_data/chenghao/sencell/fixRNA_annotated.h5ad")

In [None]:
import random
def umapPlot(embedding,cell_index_ls,clusters=None,reduce=False,labels=None,colored_celltypes=None):
    """Scatter-plot the first two columns of an embedding, optionally colored by group.

    Parameters
    ----------
    embedding
        2-D array-like with one row per point. A torch tensor should be passed as
        ``tensor.cpu().detach()``.
    cell_index_ls
        Not used by this function; kept so existing calls keep working.
    clusters
        Optional sequence of groups; each group holds the row indices of
        ``embedding`` that belong to it. When None, all points are drawn in one color.
    reduce
        When True, ``embedding`` is first projected with ``umap.UMAP()`` (default settings).
    labels
        Optional sequence paired with ``clusters``. Only its length matters: when
        given, groups beyond ``len(labels)`` are not drawn. When None, every group is drawn.
    colored_celltypes
        Not used by this function; kept so existing calls keep working.

    Returns
    -------
    None. Draws onto a new matplotlib figure (6x6 inches, 300 dpi) with axis ticks removed.
    """
    # Imported locally: no notebook cell imports matplotlib before this function,
    # and umap is only imported in a later cell.
    import matplotlib.pyplot as plt

    if reduce:
        import umap

        reducer = umap.UMAP()
        embedding = reducer.fit_transform(embedding)

    # Plain ndarray so that column slicing yields 1-D vectors (e.g. for np.matrix input).
    embedding = np.asarray(embedding)

    plt.figure(figsize=(6,6),dpi=300)

    # Palette assembled from four qualitative colormaps; built as a flat list so it
    # does not matter whether `.colors` is a tuple or a list.
    color_ls = [
        color
        for cmap_name in ('tab20', 'Set3', 'Set1', 'Set2')
        for color in plt.get_cmap(cmap_name).colors
    ]

    if clusters is None:
        plt.scatter(embedding[:,0],embedding[:,1],alpha=0.5,s=5)
    else:
        # Pairing with `labels` limits plotting to the shorter of the two sequences;
        # without labels, plot every cluster.
        if labels is None:
            groups = list(clusters)
        else:
            groups = [cluster for cluster, _ in zip(clusters, labels)]
        for color_idx, cluster in enumerate(groups):
            plt.scatter(
                embedding[cluster,0],
                embedding[cluster,1],
                alpha=0.4,
                s=1,
                # Wrap around so more groups than palette entries reuse colors
                # instead of raising IndexError.
                color=color_ls[color_idx % len(color_ls)],
                zorder=0,
            )

    plt.xticks([])
    plt.yticks([])

    
# NOTE(review): `row_numbers`, `cluster_cell_ls`, `celltype_names` and `colored_celltypes`
# are not defined in any visible cell; they must already exist in the kernel for this call to run.
# `reduce` is left at its default (False), so the first two columns of the dense matrix are plotted as-is.
umapPlot(adata_query_unprep.X.todense(),row_numbers,clusters=cluster_cell_ls,labels=celltype_names,colored_celltypes=colored_celltypes)


In [50]:
adata_query_unprep.X.todense().shape

(49327, 1758)

In [53]:
import umap

# Densify the query matrix into a plain ndarray and project it with UMAP (default settings).
cell_matrix = adata_query_unprep.X.toarray()
reducer = umap.UMAP()
embedding = reducer.fit_transform(cell_matrix)

ModuleNotFoundError: No module named 'utils'

In [36]:
adata_query_unprep = sc.read_h5ad('/bmbl_data/chenghao/combined_scRNAseq_20230928.h5ad')

In [39]:
adata_query_unprep.obs['ann_finest_level']=labels['ann_finest_level']

In [41]:
adata_query_unprep.write("/bmbl_data/chenghao/sencell/data/adenoma_annotated.h5ad")

In [37]:
adata_query

AnnData object with n_obs × n_vars = 66218 × 2000
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'percent.ribo', 'percent.mito', 'feature.mad.higher', 'count.mad.lower', 'count.mad.higher', 'mito.mad.higher', 'ribo.mad.higher', 'integrated_snn_res.0.1', 'seurat_clusters', 'disease', 'location', 'celltype_seurat_transfer_ipf', 'dataset', '_scvi_batch', 'scanvi_label', '_scvi_labels'
    var: 'gene_names'
    uns: '_scvi_uuid', '_scvi_manager_uuid'

In [39]:
# Labels whose transfer uncertainty exceeds this value are later replaced by "Unknown".
uncertainty_threshold = 0.2

# Give the per-level label columns their final "unfiltered" names ...
labels.rename(
    columns=dict(
        (f"labtransf_ann_level_{lev}", f"ann_level_{lev}_transferred_label_unfiltered")
        for lev in range(1, 6)
    ),
    inplace=True,
)
# ... and the matching uncertainty columns theirs.
uncert.rename(
    columns=dict(
        (f"labtransf_ann_level_{lev}", f"ann_level_{lev}_transfer_uncert")
        for lev in range(1, 6)
    ),
    inplace=True,
)

In [42]:
# Attach the transferred labels and their uncertainties to the query obs (index-aligned left joins).
adata_query.obs = adata_query.obs.join(labels).join(uncert)

In [43]:
# For every annotation level, replace labels that were transferred with too much
# uncertainty by "Unknown" and store the result in a new column.
for lev in range(1, 6):
    unfiltered_label = adata_query.obs[f"ann_level_{lev}_transferred_label_unfiltered"]
    too_uncertain = (
        adata_query.obs[f"ann_level_{lev}_transfer_uncert"] > uncertainty_threshold
    )
    adata_query.obs[f"ann_level_{lev}_transferred_label"] = unfiltered_label.mask(
        too_uncertain, "Unknown"
    )

In [44]:
# Report, per annotation level, the share of query cells whose label was set to "Unknown".
print(
    f"Percentage of unknown per level, with uncertainty_threshold={uncertainty_threshold}:"
)
for level in range(1, 6):
    n_unknown = adata_query.obs[f"ann_level_{level}_transferred_label"].eq("Unknown").sum()
    print(f"Level {level}: {np.round(n_unknown/adata_query.n_obs*100,2)}%")

Percentage of unknown per level, with uncertainty_threshold=0.2:
Level 1: 1.54%
Level 2: 3.17%
Level 3: 12.74%
Level 4: 27.61%
Level 5: 31.26%


In [48]:
adata_query.obs['ann_level_3_transferred_label'].to_csv("labels.csv")

In [46]:
adata_ref

AnnData object with n_obs × n_vars = 584944 × 30
    obs: 'suspension_type', 'donor_id', 'is_primary_data', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'tissue_ontology_term_id', 'organism_ontology_term_id', 'sex_ontology_term_id', "3'_or_5'", 'BMI', 'age_or_mean_of_age_range', 'age_range', 'anatomical_region_ccf_score', 'ann_coarse_for_GWAS_and_modeling', 'ann_finest_level', 'ann_level_1', 'ann_level_2', 'ann_level_3', 'ann_level_4', 'ann_level_5', 'cause_of_death', 'core_or_extension', 'dataset', 'fresh_or_frozen', 'log10_total_counts', 'lung_condition', 'mixed_ancestry', 'original_ann_level_1', 'original_ann_level_2', 'original_ann_level_3', 'original_ann_level_4', 'original_ann_level_5', 'original_ann_nonharmonized', 'reannotation_type', 'sample', 'scanvi_label', 'sequencing_platform', 'smoking_status', 'study', 'subject_type', 'tissue_coarse_unharmonized', 'tis

In [68]:
# Labels whose transfer uncertainty exceeds this value are later replaced by "Unknown".
uncertainty_threshold = 0.2

# This cell repeats the renaming already present in an earlier cell of the notebook.
# Rename the per-level label columns to their "unfiltered" names ...
labels.rename(
    columns={
        f"labtransf_ann_level_{lev}": f"ann_level_{lev}_transferred_label_unfiltered"
        for lev in range(1, 6)
    },
    inplace=True,
)
# ... and the matching uncertainty columns.
uncert.rename(
    columns={
        f"labtransf_ann_level_{lev}": f"ann_level_{lev}_transfer_uncert"
        for lev in range(1, 6)
    },
    inplace=True,
)

In [69]:
# Tag each cell with its origin before stacking reference and query embeddings.
adata_query_latent.obs["ref_or_query"] = "query"
adata_ref.obs["ref_or_query"] = "ref"

# Outer join keeps obs columns that exist in only one of the two objects.
combined_emb = sc.concat(
    (adata_ref, adata_query_latent), index_unique=None, join="outer"
)  # index_unique="_", batch_key="ref_or_query")

# Cast every obs column that is neither categorical nor float to str.
for cat in combined_emb.obs.columns:
    column = combined_emb.obs[cat]
    if isinstance(column.values, pd.Categorical) or pd.api.types.is_float_dtype(column):
        continue
    print(
        f"Setting obs column {cat} (not categorical neither float) to strings to prevent writing error."
    )
    combined_emb.obs[cat] = column.astype(str)

Setting obs column is_primary_data (not categorical neither float) to strings to prevent writing error.
Setting obs column ann_finest_level (not categorical neither float) to strings to prevent writing error.
Setting obs column dataset (not categorical neither float) to strings to prevent writing error.
Setting obs column scanvi_label (not categorical neither float) to strings to prevent writing error.
Setting obs column labtransf_ann_level_1 (not categorical neither float) to strings to prevent writing error.
Setting obs column labtransf_ann_level_2 (not categorical neither float) to strings to prevent writing error.
Setting obs column labtransf_ann_level_3 (not categorical neither float) to strings to prevent writing error.
Setting obs column labtransf_ann_level_4 (not categorical neither float) to strings to prevent writing error.
Setting obs column labtransf_ann_level_5 (not categorical neither float) to strings to prevent writing error.
Setting obs column colors (not categorical n

In [71]:
# Attach the transferred labels and their uncertainties to the combined obs (index-aligned left joins).
combined_emb.obs = combined_emb.obs.join(labels).join(uncert)

In [72]:
# For every annotation level, replace labels that were transferred with too much
# uncertainty by "Unknown" and store the result in a new column.
for lev in range(1, 6):
    unfiltered_label = combined_emb.obs[f"ann_level_{lev}_transferred_label_unfiltered"]
    too_uncertain = (
        combined_emb.obs[f"ann_level_{lev}_transfer_uncert"] > uncertainty_threshold
    )
    combined_emb.obs[f"ann_level_{lev}_transferred_label"] = unfiltered_label.mask(
        too_uncertain, "Unknown"
    )

In [73]:
# Report, per annotation level, the number of "Unknown" labels in the combined object
# as a percentage of the number of query cells.
print(
    f"Percentage of unknown per level, with uncertainty_threshold={uncertainty_threshold}:"
)
for level in range(1, 6):
    n_unknown = combined_emb.obs[f"ann_level_{level}_transferred_label"].eq("Unknown").sum()
    print(f"Level {level}: {np.round(n_unknown/adata_query.n_obs*100,2)}%")

Percentage of unknown per level, with uncertainty_threshold=0.2:
Level 1: 1.54%
Level 2: 3.17%
Level 3: 12.74%
Level 4: 27.61%
Level 5: 31.26%


In [None]:
sc.pp.neighbors(combined_emb, n_neighbors=30)
sc.tl.umap(combined_emb)

In [None]:
sc.pl.umap(combined_emb, color="ref_or_query", frameon=False, wspace=0.6)

In [None]:
# copy the original query adata, including gene counts
adata_query_final = adata_query_unprep.copy()

# copy over scArches/reference-based embedding, rows ordered like the final object's obs
adata_query_final.obsm["X_scarches_emb"] = adata_query_latent[
    adata_query_final.obs.index, :
].X

# Keep the current var index in a `gene_ids` column, then index genes by `gene_names`.
adata_query_final.var["gene_ids"] = adata_query_final.var.index
adata_query_final.var.index = adata_query_final.var.gene_names
adata_query_final.var.index.name = None

# Normalize per cell with counts_per_cell_after=10000, then log1p-transform.
sc.pp.normalize_per_cell(adata_query_final, counts_per_cell_after=10000)
sc.pp.log1p(adata_query_final)

# Copy every transferred-annotation column from the combined object onto the final query object.
for col in combined_emb.obs.columns:
    if not (col.startswith("ann_level") and "transfer" in col):
        continue
    adata_query_final.obs[col] = combined_emb.obs.loc[
        adata_query_final.obs.index, col
    ]
        
        
        

In [None]:
sc.pp.neighbors(adata_query_final, use_rep="X_scarches_emb")
sc.tl.umap(adata_query_final)

In [2]:
1

1