In [None]:
import os
import tempfile
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import scanpy as sc
import scvi
import anndata

In [None]:
# Fix the global scvi-tools seed first so everything stochastic below is reproducible.
scvi.settings.seed = 0

# Scratch directory handle. NOTE(review): not used in the visible cells — confirm it is still needed.
save_dir = tempfile.TemporaryDirectory()

# Record which scvi-tools version produced the outputs of this notebook.
print(f"Last run with scvi-tools version: {scvi.__version__}")

In [None]:
import os
import scanpy as sc
import pandas as pd

def read_xenium(path, sample_name):
    """
    Custom reader for Xenium data using only H5 and cell CSVs.

    Loads ``{sample_name}_cell_feature_matrix.h5`` as the expression matrix and
    attaches ``{sample_name}_cells.csv`` as per-cell metadata.

    Parameters
    ----------
    path : str
        Directory containing the Xenium files.
    sample_name : str
        Prefix for the files (e.g., 'KidneySample1'); also stored in
        ``obs["sample"]``.

    Returns
    -------
    AnnData
        AnnData object whose ``obs`` is the cells CSV (indexed by the matrix
        barcodes). ``obsm["spatial"]`` holds the ``x_centroid``/``y_centroid``
        columns, and ``obs["sample"]`` is set to ``sample_name``.

    Raises
    ------
    ValueError
        If the cells CSV does not describe exactly the cells in the H5 matrix.
    """
    gene_counts_file = os.path.join(path, f"{sample_name}_cell_feature_matrix.h5")
    cell_metadata_file = os.path.join(path, f"{sample_name}_cells.csv")

    adata = sc.read_10x_h5(filename=gene_counts_file)
    cells = pd.read_csv(cell_metadata_file)

    # Metadata is attached one row per cell, so both files must list the same
    # number of cells.
    if len(cells) != adata.n_obs:
        raise ValueError(
            f"{sample_name}: {cell_metadata_file} has {len(cells)} rows but "
            f"{gene_counts_file} has {adata.n_obs} cells"
        )

    if "cell_id" in cells.columns:
        # Align rows by cell ID instead of trusting that both files list cells
        # in the same order (a silent row shift would mislabel every cell).
        csv_ids = cells["cell_id"].astype(str)
        if not csv_ids.is_unique or not adata.obs_names.isin(csv_ids).all():
            raise ValueError(
                f"{sample_name}: cell_id values in {cell_metadata_file} do not "
                f"match the barcodes in {gene_counts_file}"
            )
        cells = cells.set_index(csv_ids.to_numpy()).loc[adata.obs_names]

    cells.index = adata.obs_names
    adata.obs = cells

    # Cell centroids as an (n_cells, 2) array for spatial plotting/analysis.
    adata.obsm["spatial"] = adata.obs[["x_centroid", "y_centroid"]].to_numpy()

    # Sample label, used later to build the batch key.
    adata.obs["sample"] = sample_name
    return adata


In [None]:
# Root folder holding the per-sample Xenium exports. Set the XENIUM_DATA_DIR
# environment variable to run this notebook on another machine.
base_dir = os.environ.get("XENIUM_DATA_DIR", '/Volumes/Active/Xenium/Data/KID_final_Dataset/')
sample_list = ['xen4_3781','xen6_3781','xen10_3723','xen10_3946','xen12_3990','xen17_3612','xen21_KPMP057']

# One AnnData per sample, in the same order as sample_list.
adata_list = [read_xenium(base_dir, sample) for sample in sample_list]

In [None]:
# Per-sample QC. Slicing with .copy() creates NEW AnnData objects, so the
# results must be stored back into adata_list. Rebinding only the loop
# variable would silently discard the filtering and the "counts" layer that
# the RESOLVI setup later reads via layer="counts".
prom_genes = ['GPX5', 'FKBP5', 'PIGR', 'IGFBP7', 'PSAP', 'VIM', 'GPX3']  # genes excluded from every sample

qc_adata_list = []
for adata in adata_list:
    adata = adata[:, ~adata.var_names.isin(prom_genes)].copy()
    sc.pp.filter_cells(adata, min_counts=5)  # drop cells with fewer than 5 total counts
    sc.pp.filter_genes(adata, min_counts=1)  # drop genes left with no counts
    # Keep the raw counts in their own layer before anything modifies X.
    adata.layers["counts"] = adata.X.copy()
    qc_adata_list.append(adata)
adata_list = qc_adata_list

In [None]:
# Give each of the seven samples its own name for the cells that follow
# (same order as adata_list).
adata1, adata2, adata3, adata4, adata5, adata6, adata7 = adata_list[:7]

In [None]:
# add cell label information
#sample_list = ['xen4_3781','xen5_3916', 'xen6_3781','xen7_3916','xen10_3723','xen10_3946','xen11_3712','xen11_3919','xen12_3990','xen12_KPMP038','xen13_3782','xen13_KPMP038','xen16_3609','xen16_3785','xen17_3612','xen17_3673,'xen21_3729','xen21_3811','xen21_KPMP057','xen22_KPMP102','xen22A_3809','xen22B_3809','xen23_3782','xen23_3990','xen23_KPMP038','xen24_3782','xen24_3990','xen24_KPMP038','xen25_4091','xen25_4126','xen25_KPMP101','xen26_4091','xen26_4132','xen26_KPMP101']

adata_list = [adata1, adata2, adata3, adata4, adata5, adata6, adata7]
csv_files = [
   ## paths to the celltype dfs for each sample
]

# zip() silently stops at the shorter input, so a missing/short csv_files list
# would leave samples unlabelled. Fail loudly here instead.
if len(csv_files) != len(adata_list):
    raise ValueError(
        f"Expected {len(adata_list)} cell-type CSVs (one per sample, same order), "
        f"got {len(csv_files)}"
    )

for adata, csv_file in zip(adata_list, csv_files):
    labels_df = pd.read_csv(csv_file)
    labels_df["cell_id"] = labels_df["cell_id"].astype(str)
    labels_df = labels_df.set_index("cell_id")
    # A repeated cell_id would duplicate rows in the left join below and break
    # the one-row-per-cell layout of obs.
    if not labels_df.index.is_unique:
        raise ValueError(f"{csv_file}: duplicate cell_id values")
    # Join on cell ID (string on both sides); cells missing from the CSV get NaN.
    adata.obs.index = adata.obs.index.astype(str)
    adata.obs = adata.obs.join(labels_df, how="left")


In [None]:
# Re-bind the per-sample names from adata_list. The label join assigns .obs on
# the same AnnData objects that adata_list holds, so these names should already
# refer to the labelled objects; re-binding keeps them in sync regardless.
adata1, adata2, adata3, adata4, adata5, adata6, adata7 = adata_list[:7]

In [None]:
# Ensure cell names are unique within each individual sample object.
for adata in (adata1, adata2, adata3, adata4, adata5, adata6, adata7):
    adata.obs_names_make_unique()

In [None]:
# Stack the seven samples into a single AnnData (only adata1..adata7 exist).
all_adata = anndata.concat([adata1, adata2, adata3, adata4, adata5, adata6, adata7])

# Map each sample name to its batch label (used as batch_key for RESOLVI).
# Both xen10 samples intentionally share the "xen10" batch.
SAMPLE_TO_BATCH = {
    "xen4_3781": "xen4",
    "xen6_3781": "xen6",
    "xen10_3723": "xen10",
    "xen10_3946": "xen10",
    "xen12_3990": "xen12",
    "xen17_3612": "xen17",
    "xen21_KPMP057": "xen21",
}

# Fail loudly on a sample without a mapping instead of silently keeping its raw name.
unmapped = set(all_adata.obs["sample"].astype(str)) - set(SAMPLE_TO_BATCH)
if unmapped:
    raise ValueError(f"No sample_id mapping for samples: {sorted(unmapped)}")
all_adata.obs["sample_id"] = all_adata.obs["sample"].astype(str).map(SAMPLE_TO_BATCH)


In [None]:
# Overview of the cell-type labels, including how many cells received no label (NaN),
# instead of dumping one row per cell.
all_adata.obs['v2.subclass.levelKIL'].value_counts(dropna=False)

In [None]:
# Also expose the centroid coordinates under an "X_"-prefixed key
# (presumably for plotting helpers that look up basis="spatial" — TODO confirm).
centroid_coords = all_adata.obsm["spatial"]
all_adata.obsm["X_spatial"] = centroid_coords

In [None]:
# Display the combined AnnData summary (shape and obs/obsm/layers keys).
all_adata

In [None]:
# Cell-type labels whose cells are dropped before modelling.
celltypes_to_remove = ["low_quality", "unassigned"]
is_kept = ~all_adata.obs['v2.subclass.levelKIL'].isin(celltypes_to_remove)
all_adata = all_adata[is_kept].copy()

In [None]:
# Check which samples still have cells after removing low_quality/unassigned cells.
all_adata.obs['sample'].unique()

In [None]:
# Display the AnnData summary after the cell-type filter.
all_adata

In [None]:
# Replace "/" with "_" in the cell-type label column ("v2.subclass.levelKIL").
# regex=False: the "/" is matched as a literal character.
label_key = "v2.subclass.levelKIL"
all_adata.obs[label_key] = all_adata.obs[label_key].str.replace("/", "_", regex=False)

In [None]:
# Inspect the distinct cell-type labels after the "/" -> "_" replacement.
all_adata.obs['v2.subclass.levelKIL'].unique()

In [None]:
# Cell IDs may repeat across samples after concatenation; suffix the repeats
# so that every obs name is unique ("-" is the default separator).
all_adata.obs_names_make_unique(join="-")

In [None]:
label_key = 'v2.subclass.levelKIL'

# Keep only cells that carry a cell-type label (the left join leaves NaN otherwise).
has_label = all_adata.obs[label_key].notna()
all_adata = all_adata[has_label].copy()

# Store the labels as a pandas categorical restricted to the categories still present.
label_categories = all_adata.obs[label_key].astype('category')
all_adata.obs[label_key] = label_categories.cat.remove_unused_categories()


In [None]:
# Number of cells per sample after QC, label filtering and NaN removal.
all_adata.obs['sample'].value_counts()

In [None]:
# Register all_adata with RESOLVI: raw counts are read from the "counts" layer,
# "sample_id" is the batch covariate, and the cell-type column supplies labels
# for the semi-supervised model built next.
scvi.external.RESOLVI.setup_anndata(
    all_adata,
    layer="counts",
    batch_key="sample_id",
    labels_key="v2.subclass.levelKIL",
)

In [None]:
# Display the AnnData summary after RESOLVI registration.
all_adata

In [None]:
# Build the RESOLVI model on the registered data; semisupervised=True presumably
# makes use of the labels_key registered in setup_anndata — TODO confirm in the RESOLVI docs.
model = scvi.external.RESOLVI(all_adata, semisupervised=True)

In [None]:
# Number of training epochs for RESOLVI.
max_epochs = 50
model.train(max_epochs=max_epochs)

In [None]:
# Per-cell latent embedding from the trained model; used below for the neighbour graph and UMAP.
latent_embedding = model.get_latent_representation(all_adata)
all_adata.obsm["X_resolVI"] = latent_embedding

In [None]:
# Build the kNN graph on the RESOLVI latent space (not on expression), embed with UMAP,
# and colour by batch to eyeball how well samples are integrated.
sc.pp.neighbors(all_adata, use_rep="X_resolVI")
sc.tl.umap(all_adata)
sc.pl.umap(all_adata, color="sample_id")  

In [None]:
# Cells per sample (same table as the earlier value_counts cell; nothing in between removes cells).
all_adata.obs["sample"].value_counts()

In [None]:
# sample corrected expression counts
# Draw posterior samples of the "px_rate" site from the corrected generative model
# and summarise them with the median, stored under the name "post_sample_q50".
# NOTE(review): num_samples=3 and summary_frequency=30 are not explained here —
# check RESOLVI.sample_posterior docs to confirm they are adequate.
samples_corr = model.sample_posterior(
    model=model.module.model_corrected,
    return_sites=["px_rate"],
    summary_fun={"post_sample_q50": np.median},
    num_samples=3,
    summary_frequency=30,
)

# The .loc lookup below implies the result is a nested mapping
# {summary_name: {site: values}}; after the transpose, rows are summary names
# and columns are sites.
samples_corr = pd.DataFrame(samples_corr).T

# save corrected counts in new layer:
all_adata.layers["generated_expression"] = samples_corr.loc["post_sample_q50", "px_rate"]

In [None]:
# Kidney marker genes plotted from the RESOLVI-corrected expression layer.
marker_genes = ["LRP2", "UMOD", "SLC12A1"]
# One title per panel. With a single title string, scanpy only applies it to the
# first panel and falls back to gene names for the others.
sc.pl.umap(
    all_adata,
    color=marker_genes,
    layer="generated_expression",
    vmax="p98",  # clip the colour scale at the 98th percentile
    title=[f"{gene} (corrected expression)" for gene in marker_genes],
)

In [None]:
# Hard (soft=False) cell-type predictions from the semi-supervised RESOLVI model.
hard_predictions = model.predict(all_adata, num_samples=3, soft=False)
all_adata.obs["predicted_labels"] = hard_predictions

In [None]:
# Show the RESOLVI-predicted cell types on the latent-space UMAP.
sc.pl.umap(all_adata, color=["predicted_labels"])

In [None]:
# Persist the annotated object (counts and generated_expression layers, X_resolVI
# latent, UMAP, predicted labels) to disk.
# NOTE(review): hardcoded absolute, dated output path — consider a configurable output directory.
all_adata.write_h5ad("/Volumes/Active/Xenium/Banksy_Atlas/20250602_Atlasresolvi_seven_samples_obj.h5ad")