In [None]:
import scanpy as sc
import numpy as np
import pandas as pd
import os
import sys
import anndata as ad
import tacco as tc
import squidpy as sq

In [None]:
def read_xenium(path, sample_name):
    """
    Custom reader for Xenium data using only the H5 count matrix and the cells CSV.

    Parameters
    ----------
    path : str
        Directory containing the Xenium files.
    sample_name : str
        Prefix for the files (e.g., 'KidneySample1'). The function reads
        ``{sample_name}_cell_feature_matrix.h5`` and ``{sample_name}_cells.csv``.

    Returns
    -------
    AnnData
        AnnData object whose ``.obs`` holds the rows of the cells CSV (indexed by
        the matrix barcodes) plus a ``sample`` column, and whose
        ``.obsm["spatial"]`` holds the ``x_centroid`` / ``y_centroid`` columns.

    Raises
    ------
    ValueError
        If the number of metadata rows does not match the number of cells in the
        count matrix and the rows cannot be matched through ``cell_id``.
    """
    # File paths
    gene_counts_file = os.path.join(path, f"{sample_name}_cell_feature_matrix.h5")
    cell_metadata_file = os.path.join(path, f"{sample_name}_cells.csv")

    # Load gene counts and per-cell metadata
    adata = sc.read_10x_h5(filename=gene_counts_file)
    cell_meta = pd.read_csv(cell_metadata_file)

    barcodes = adata.obs_names

    # Prefer matching metadata rows to matrix barcodes by identifier: attaching the
    # CSV purely by row position silently mislabels cells if the two files are not
    # in the same order. Only done when every barcode is found in `cell_id`.
    if "cell_id" in cell_meta.columns:
        csv_ids = pd.Index(cell_meta["cell_id"].astype(str))
        if csv_ids.is_unique and barcodes.isin(csv_ids).all():
            cell_meta = cell_meta.iloc[csv_ids.get_indexer(barcodes)]

    # Positional fallback (original behaviour) is only valid for equal lengths.
    if len(cell_meta) != len(barcodes):
        raise ValueError(
            f"{cell_metadata_file} has {len(cell_meta)} rows but "
            f"{gene_counts_file} has {len(barcodes)} cells; cannot align metadata."
        )

    # The `cell_id` column (if any) is kept in .obs; the index is the barcode.
    adata.obs = cell_meta.set_index(barcodes)
    # Add spatial information
    adata.obsm["spatial"] = adata.obs[["x_centroid", "y_centroid"]].to_numpy()
    # Add sample label
    adata.obs["sample"] = sample_name
    return adata


In [None]:
base_dir = '/storage2/fs1/sanjayjain/Active/Xenium/Data/KID_final_Dataset/'
sample_list = ['xen4_3781','xen6_3781','xen10_3723','xen10_3946','xen12_3990','xen17_3612','xen21_KPMP057']

# One AnnData per sample, in the same order as sample_list.
adata_list = [read_xenium(base_dir, sample_id) for sample_id in sample_list]


## Pre-QC plots

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Marker shapes are handed out to patients in order of first appearance and
# recycled if there are more patients than shapes.
markers = [
    'o','s','D','^','v','<','>','p','h','H','*','+','x',
    '1','2','3','4','|','_','.','P','X'
]
marker_map = {}

# Build one summary record per sample.
metrics_list = []
for adata in adata_list:
    sample_name = adata.obs['sample'].iloc[0]   # e.g., xen4_KPMP123 or xen4_3120
    patient_id = sample_name.split('_')[-1]     # patient ID = text after the last underscore
    marker_map.setdefault(patient_id, markers[len(marker_map) % len(markers)])

    # Median is taken over cells with at least one transcript.
    transcript_counts = adata.obs['transcript_counts']
    median_transcripts = transcript_counts[transcript_counts > 0].median()

    metrics_list.append({
        'Sample': sample_name,
        'Patient': patient_id,
        'Total Cells': adata.shape[0],
        'Median Transcripts': median_transcripts,
        'Marker': marker_map[patient_id]
    })

metrics_df = pd.DataFrame(metrics_list)

# Sort samples by the number that follows the "xen" prefix.
metrics_df['Order'] = metrics_df['Sample'].str.extract(r'xen(\d+)', expand=False).astype(float)
metrics_df = metrics_df.sort_values(by='Order')

# Mean and 2 x standard deviation across samples for each metric.
mean_cells = metrics_df['Total Cells'].mean()
std2_cells = 2 * metrics_df['Total Cells'].std()

mean_transcripts = metrics_df['Median Transcripts'].mean()
std2_transcripts = 2 * metrics_df['Median Transcripts'].std()

# (metric column, y-axis label, title, mean, 2*SD) for each figure, in display order.
plot_specs = [
    ('Total Cells', 'Total Cells', 'Total Cells per Sample',
     mean_cells, std2_cells),
    ('Median Transcripts', 'Median Transcripts per Cell', 'Median Transcripts per Cell per Sample',
     mean_transcripts, std2_transcripts),
]

for column, ylabel, title, center, spread in plot_specs:
    fig, ax = plt.subplots(figsize=(12,6))

    # Each patient gets exactly one legend entry (the first point drawn for them).
    labelled_patients = set()
    for sample, patient, marker, value in zip(
        metrics_df['Sample'], metrics_df['Patient'], metrics_df['Marker'], metrics_df[column]
    ):
        # Color based on patient ID prefix
        point_color = 'blue' if patient.startswith('KPMP') else 'green'
        point_label = "" if patient in labelled_patients else patient
        labelled_patients.add(patient)
        ax.scatter(sample, value, marker=marker, color=point_color, s=120, label=point_label)

    # Reference lines; the lower bound is clipped at zero.
    ax.axhline(center, color='black', linestyle='-', linewidth=1.5, label='Mean')
    ax.axhline(center + spread, color='red', linestyle='--', linewidth=1, label='+2 SD')
    ax.axhline(max(center - spread, 0), color='red', linestyle='--', linewidth=1, label='-2 SD')

    plt.xticks(rotation=45, ha='right')
    ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
    ax.set_xlabel('Sample', fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(title='Patient ID', bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(axis='y', linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.show()


In [None]:
# Show the AnnData summary of the first loaded sample as a quick sanity check.
adata1 = adata_list[0]
adata1

In [None]:
# Genes removed from every sample before QC filtering and clustering.
prom_genes = ['GPX5', 'FKBP5', 'PIGR', 'IGFBP7','PSAP','VIM','GPX3']

# Per-sample preprocessing. The processed object is written back into adata_list
# by index: `adata[:, genes].copy()` creates a NEW object, so merely rebinding the
# loop variable would leave adata_list holding the unfiltered, unnormalised data.
#
# This cell is not idempotent (re-running it would normalise already-normalised
# data); re-run the loading cell first if it needs to be executed again.
#
# NOTE(review): after this cell `.X` holds log1p-normalised values and the raw
# counts live in `.layers["counts"]`. The TACCO cells further down pass `.X` with
# assume_valid_counts=True -- confirm which matrix they are meant to use.
for i, adata in enumerate(adata_list):
    genes_to_keep = [gene for gene in adata.var_names if gene not in prom_genes]
    adata = adata[:, genes_to_keep].copy()

    # QC filters: drop cells with fewer than 5 counts and genes with no counts.
    sc.pp.filter_cells(adata, min_counts=5)
    sc.pp.filter_genes(adata, min_counts=1)

    # Keep the filtered raw counts before normalisation overwrites .X.
    adata.layers["counts"] = adata.X.copy()
    sc.pp.normalize_total(adata, target_sum=1e6)
    sc.pp.log1p(adata)

    # Embedding and clustering.
    sc.pp.pca(adata, n_comps=40)
    sc.pp.neighbors(adata, n_neighbors=4 , metric='cosine')
    sc.tl.umap(adata, min_dist=0.01, spread=1)
    sc.tl.leiden(adata)

    adata_list[i] = adata

## Post-QC plots

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np


def _plot_metric_per_sample(metrics_df, column, ylabel, title, mean_value, two_sd):
    """
    Scatter one per-sample QC metric with mean and +/-2 SD reference lines.

    Parameters
    ----------
    metrics_df : pandas.DataFrame
        One row per sample; must contain 'Sample', 'Patient', 'Marker' and `column`.
    column : str
        Name of the metric column plotted on the y-axis.
    ylabel : str
        Y-axis label.
    title : str
        Figure title.
    mean_value : float
        Height of the solid 'Mean' line.
    two_sd : float
        Offset of the dashed '+2 SD' / '-2 SD' lines from the mean; the lower
        line is clipped at zero.
    """
    plt.figure(figsize=(12,6))

    # Each patient gets a single legend entry: only the first point is labelled.
    seen_patients = set()
    for _, row in metrics_df.iterrows():
        patient = row['Patient']
        # Color based on patient ID prefix (same rule for every plot).
        color = 'blue' if patient.startswith('KPMP') else 'green'
        plt.scatter(row['Sample'], row[column], marker=row['Marker'], color=color, s=120,
                    label="" if patient in seen_patients else patient)
        seen_patients.add(patient)

    plt.axhline(mean_value, color='black', linestyle='-', linewidth=1.5, label='Mean')
    plt.axhline(mean_value + two_sd, color='red', linestyle='--', linewidth=1, label='+2 SD')
    plt.axhline(max(mean_value - two_sd, 0), color='red', linestyle='--', linewidth=1, label='-2 SD')

    plt.xticks(rotation=45, ha='right')
    plt.ylabel(ylabel, fontsize=12, fontweight='bold')
    plt.xlabel('Sample', fontsize=12, fontweight='bold')
    plt.title(title, fontsize=14, fontweight='bold')
    plt.legend(title='Patient ID', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(axis='y', linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.show()


# Collect metrics
metrics_list = []
markers = [
    'o','s','D','^','v','<','>','p','h','H','*','+','x',
    '1','2','3','4','|','_','.','P','X'
]
marker_map = {}

for adata in adata_list:
    sample_name = adata.obs['sample'].iloc[0]   # e.g., xen4_KPMP123 or xen4_3120
    patient_id = sample_name.split('_')[-1]     # patient ID = last part of name

    # Marker shapes are assigned per patient in order of first appearance.
    if patient_id not in marker_map:
        marker_map[patient_id] = markers[len(marker_map) % len(markers)]

    total_cells = adata.shape[0]
    # Median transcripts is computed over cells with at least one transcript.
    non_zero_transcripts = adata.obs[adata.obs['transcript_counts'] > 0]
    median_transcripts = non_zero_transcripts['transcript_counts'].median()
    median_cell_area = adata.obs['cell_area'].median()

    metrics_list.append({
        'Sample': sample_name,
        'Patient': patient_id,
        'Total Cells': total_cells,
        'Median Transcripts': median_transcripts,
        'Median Cell Area': median_cell_area,
        'Marker': marker_map[patient_id]
    })

metrics_df = pd.DataFrame(metrics_list)

# Order samples by the number that follows the "xen" prefix.
metrics_df['Order'] = metrics_df['Sample'].str.extract(r'xen(\d+)').astype(float)
metrics_df = metrics_df.sort_values(by='Order')

# Compute mean and 2 x SD across samples for each metric.
mean_cells = metrics_df['Total Cells'].mean()
std2_cells = 2 * metrics_df['Total Cells'].std()

mean_transcripts = metrics_df['Median Transcripts'].mean()
std2_transcripts = 2 * metrics_df['Median Transcripts'].std()

mean_area = metrics_df['Median Cell Area'].mean()
std2_area = 2 * metrics_df['Median Cell Area'].std()

# -----------------------------
# 1) Plot: Total Cells per Sample
_plot_metric_per_sample(metrics_df, 'Total Cells', 'Total Cells',
                        'Total Cells per Sample', mean_cells, std2_cells)

# 2) Plot: Median Transcripts per Cell
_plot_metric_per_sample(metrics_df, 'Median Transcripts', 'Median Transcripts per Cell',
                        'Median Transcripts per Cell per Sample', mean_transcripts, std2_transcripts)

# 3) Plot: Median Cell Area per Sample
_plot_metric_per_sample(metrics_df, 'Median Cell Area', 'Median Cell Area',
                        'Median Cell Area per Sample', mean_area, std2_area)

In [None]:
# Display the per-sample QC metrics table built in the previous cell.
metrics_df

In [None]:
import anndata as ad

# Atlas object from which the TACCO reference (adataref) is derived.
atlas_path = "/storage2/fs1/sanjayjain/Active/Stephanie/objects/atlasv2_ds_50prop.h5ad"
adata_50prop = ad.read_h5ad(atlas_path)

In [None]:
## Specific to our atlas object: these cell types are removed from the atlas so
## that TACCO does not call them on the Xenium data.
cell_types_to_remove = ['Ad', 'PapE']

# Keep only reference cells whose level-1 subclass is not in the removal list.
is_removed = adata_50prop.obs['v2.subclass.l1'].isin(cell_types_to_remove)
adataref = adata_50prop[~is_removed].copy()

## TACCO level 1 annotation

In [None]:
import tacco as tc

# Level-1 annotation: TACCO writes its per-cell result for every reference
# 'v2.subclass.l1' class into adata.obsm['v2.subclass.l1'].
for adata in adata_list:
    adata.X = adata.X.astype(np.float32)
    tc.tl.annotate(adata, adataref, annotation_key='v2.subclass.l1', result_key='v2.subclass.l1', assume_valid_counts=True)

    # Hard label = column with the largest value in each row. Computed without
    # modifying the obsm table, so it keeps only the per-class columns (adding
    # the label as an extra column there would mix strings into the score matrix).
    adata.obs['v2.subclass.l1'] = adata.obsm['v2.subclass.l1'].idxmax(axis=1)

In [None]:
import tacco as tc
import numpy as np

# From here on `adata_50prop` refers to the filtered reference (adataref); the
# unfiltered atlas is no longer reachable through this name.
adata_50prop = adataref

# Level-2 annotation: within each level-1 class, annotate the query cells of that
# class against the reference cells of the same class.
for adata in adata_list:
    level1_labels = adata.obs["v2.subclass.l1"]
    broad_classes = level1_labels.unique().tolist()
    annotation_dict = {}   # cell id -> assigned 'v2.subclass.l2' label

    for bc in broad_classes:
        # Subset query and reference for the current broad class
        adata_sub_query = adata[level1_labels == bc].copy()
        adata_sub_ref = adata_50prop[adata_50prop.obs["v2.subclass.l1"] == bc].copy()

        # Skip if no observations in either
        if adata_sub_query.n_obs == 0 or adata_sub_ref.n_obs == 0:
            continue

        # TACCO annotation; with result_key=None the result is returned
        # instead of being stored on the query object.
        dist_df = tc.tl.annotate(
            adata_sub_query,
            adata_sub_ref,
            annotation_key="v2.subclass.l2",
            result_key=None,
            assume_valid_counts=True
        )

        # Hard label = column with the largest value in each row.
        annotation_dict.update(dist_df.idxmax(axis=1).to_dict())

    # Assign all labels in one vectorised step. Cells without a label stay NaN.
    # (Filling a float NaN column one cell at a time via .loc is very slow for
    # large samples and forces a float -> object dtype upcast.)
    adata.obs["v2.subclass.l2"] = adata.obs_names.to_series().map(annotation_dict).to_numpy()

    print(f"Annotated {adata.obs['v2.subclass.l2'].notna().sum()} cells in current dataset.")


In [None]:
# Lookup from 'v2.subclass.l2' labels to the coarser 'v2.subclass.levelKIL' labels.
# Labels that are absent from this table map to NaN.
l2_to_levelKIL = {
    "POD": 'POD',
    "dPOD": 'altPOD',
    "PEC": 'PEC',
    "PT-S1": 'PT',
    "dPT": 'altPT',
    "PT-S2": 'PT',
    "PT-S3": 'PT',
    "aPT": 'altPT',
    "frPT": 'altPT',
    "cycPT": 'altPT',
    "DTL2": 'DTL',
    "aDTL": 'altDTL',
    "DTL1": 'DTL',
    "DTL3": 'DTL',
    "dDTL": 'altDTL',
    "ATL": 'ATL',
    "dATL": 'altATL',
    "M-TAL": 'TAL',
    "dM-TAL": 'TAL',
    "C/M-TAL": 'TAL',
    "C-TAL": 'TAL',
    "MD": 'MD',
    "dC-TAL": 'TAL',
    "frTAL": 'altTAL',
    "aTAL": 'altTAL',
    "cycTAL": 'altTAL',
    "DCT": 'DCT',
    "dDCT": 'altDCT',
    "aDCT": 'altDCT',
    "CNT": 'CNT',
    "dCNT": 'altCNT',
    "aCNT": 'altCNT',
    "C-PC": 'PC',
    "M-PC": 'PC',
    "dM-PC": 'altPC',
    "IMCD": 'PC',
    "dIMCD": 'altPC',
    "PapE": 'PapE',
    "C-IC-A": 'IC',
    "dC-IC-A": 'altIC',
    "M-IC-A": 'IC',
    "dM-IC-A": 'IC',
    "tPC-IC": 'IC',
    "IC-B": 'IC',
    "EC-GC": 'EC-GC',
    "dEC-GC": 'EC-GC',
    "aEC-GC": 'EC-GC',
    "EC-AA": 'EC',
    "EC-DVR": 'EC',
    "EC-PTC": 'EC',
    "EC-AVR": 'EC',
    "dEC-AVR": 'EC',
    "infEC-AVR": 'EC',
    "EC-V": 'EC',
    "EC-PCV": 'EC',
    "angEC-PTC": 'EC',
    "EC-EA": 'EC',
    "dEC-PTC": 'EC',
    "infEC-PTC": 'EC',
    "EC-LYM": 'EC',
    "cycEC": 'EC',
    "M-FIB": 'FIB',
    "dM-FIB": 'altFIB',
    "FIB": 'FIB',
    "MYOF": 'MYOF',
    "infFIB": 'infFIB',
    "dFIB": 'FIB',
    "pvFIB": 'pvFIB',
    "MC": 'MC',
    "REN": 'REN',
    "VSMC": 'VSMC',
    "VSMC/P": 'VSMC_P',
    "dVSMC": 'VSMC',
    "Ad": 'Ad',
    "B": 'B',
    "PL": 'PL',
    "T": 'T',
    "NK": 'NK',
    "ERY": 'ERY',
    "MAST": 'MAST',
    "resMAC": 'resMAC',
    "moMAC-INF": 'moMAC-INF',
    "moMAC": 'moMAC',
    "DC": 'DC',
    "MON": 'MON',
    "N": 'N',
    "cycT": 'T',
    "cycMAC": 'MAC',
    "SC/NEU": 'SC_NEU',
}

# Derive the levelKIL column for every sample from its level-2 labels.
for adata in adata_list:
    level2_labels = adata.obs['v2.subclass.l2']
    adata.obs['v2.subclass.levelKIL'] = level2_labels.map(l2_to_levelKIL)

adata_list

In [None]:
# Write one CSV per sample (into the current working directory) holding the
# levelKIL label of every cell, with the cell index exported as 'cell_id'.
for name, adata in zip(sample_list, adata_list):
    filename = f"cell_types_{name}_Xenium.csv"
    cell_types_df = adata.obs[['v2.subclass.levelKIL']].rename_axis('cell_id')
    cell_types_df.to_csv(filename)

In [None]:
# Display the summaries of all AnnData objects before they are saved.
adata_list

In [None]:
import os

# Destination directory for the per-sample .h5ad files (created if missing).
output_folder = "/storage2/fs1/sanjayjain/Active/Asmita/KIDneypaper/Xenium-processed objects/"
os.makedirs(output_folder, exist_ok=True)

for i, adata in enumerate(adata_list, start=1):
    # File stem: the sample label stored in .obs, or a 1-based positional
    # fallback when the object has no 'sample' column.
    has_sample_column = 'sample' in adata.obs.columns
    sample_name = adata.obs['sample'].iloc[0] if has_sample_column else f"adata_{i}"

    # Spaces and slashes are replaced so the stem is a single path component.
    sample_name = sample_name.replace(" ", "_").replace("/", "_")

    filename = os.path.join(output_folder, f"{sample_name}.h5ad")
    adata.write(filename)
    print(f"Saved: {filename}")
