In [None]:
# %% [markdown]
# # 1. Data Loading and Preprocessing
# 
# This notebook loads the PBMC3k dataset and derives pseudo cell-type labels
# from Louvain clusters. It then applies our standard preprocessing and splits
# the cells into a "reference" and a "query" set for demonstration purposes.
# Both sets are saved for later use.

# %%
import scanpy as sc
import anndata as ad
import numpy as np
import os
import sys

# Make the project root (the parent of the current working directory)
# importable so `scpred_py` resolves to the local checkout. Inserting at the
# front of sys.path (rather than appending) ensures the local source takes
# precedence over any installed copy of the package. The membership guard keeps
# re-runs of this cell from adding duplicate entries.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.insert(0, module_path)

from scpred_py import _preprocessing

In [None]:
# %% [markdown]
# ## Load Data

# %%
adata = sc.datasets.pbmc3k()
adata.var_names_make_unique()
print(adata)

In [None]:
# %% [markdown]
# ## Preprocessing
# 
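# We first derive pseudo cell-type labels with a standard Scanpy workflow
# (filtering, normalization, log-transform, HVG selection, scaling, PCA,
# neighbor graph, Louvain clustering), treating each Louvain cluster as a cell
# type. We then run our custom preprocessing function,
# `_preprocessing.standard_preprocess`, to produce the data used by scPred,
# so reference and query are processed consistently.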

# %%
# Derive pseudo cell-type labels for training with a standard Scanpy
# clustering workflow; each Louvain cluster is treated as a cell type.

# Quality filtering is applied to `adata` in place. It only drops cells/genes,
# so the remaining values are left as loaded.
sc.pp.filter_cells(adata, min_genes=200)
sc.pp.filter_genes(adata, min_cells=3)

# Normalize, log-transform, and cluster a separate copy. Transforming `adata`
# itself would (a) feed already-normalized, log-transformed values into
# `standard_preprocess` below, double-applying those steps if it performs them
# too (NOTE(review): assumes it expects untransformed counts -- confirm), and
# (b) re-transform the data every time this cell is re-run.
adata_clust = adata.copy()
sc.pp.normalize_total(adata_clust, target_sum=1e4)
sc.pp.log1p(adata_clust)
sc.pp.highly_variable_genes(adata_clust, min_mean=0.0125, max_mean=3, min_disp=0.5)
adata_hvg = adata_clust[:, adata_clust.var.highly_variable].copy()  # work on HVG subset
sc.pp.scale(adata_hvg, max_value=10)
sc.tl.pca(adata_hvg, svd_solver='arpack')
sc.pp.neighbors(adata_hvg, n_neighbors=10, n_pcs=40)
sc.tl.louvain(adata_hvg, random_state=42)  # Louvain clusters serve as 'cell types'

# Map the labels back onto the filtered data. adata_hvg only subsets genes,
# so its obs_names match adata's and the assignment aligns cell-for-cell.
adata.obs['cell_type'] = adata_hvg.obs['louvain']
assert adata.obs['cell_type'].notna().all(), "every cell should receive a cluster label"

# Now apply our *intended* preprocessing (without PCA yet) for scPred.
adata_proc = _preprocessing.standard_preprocess(adata.copy())
print(adata_proc)

In [None]:
# %% [markdown]
# ## Split Data (Reference vs. Query)
# 
# We'll randomly split the cells into 70% reference and 30% query.

# %%
# Randomly split cells into reference and query sets. A dedicated, seeded
# Generator makes the split identical on every run (and on Restart & Run All),
# unlike shuffling with numpy's unseeded global RNG.
SPLIT_SEED = 42      # seed for the reference/query split
REF_FRACTION = 0.7   # fraction of cells assigned to the reference set

n_cells = adata_proc.shape[0]
rng = np.random.default_rng(SPLIT_SEED)
indices = rng.permutation(n_cells)

n_ref = int(REF_FRACTION * n_cells)
ref_idx = indices[:n_ref]
query_idx = indices[n_ref:]

# Sanity checks: the two sets are disjoint and together cover every cell.
assert len(ref_idx) + len(query_idx) == n_cells
assert np.intersect1d(ref_idx, query_idx).size == 0

ref_adata = adata_proc[ref_idx, :].copy()
query_adata = adata_proc[query_idx, :].copy()

print(f"Reference data shape: {ref_adata.shape}")
print(f"Query data shape: {query_adata.shape}")

In [None]:
# %% [markdown]
# ## Save Data
# 
# We save the reference and query AnnData objects as `.h5ad` files under `../data/processed/`.

# %%
# Output directory for the processed splits (relative to the notebook's CWD).
PROCESSED_DIR = os.path.join('..', 'data', 'processed')

# exist_ok=True creates the directory if needed and is a no-op otherwise.
# This avoids a separate existence check and keeps re-runs safe.
os.makedirs(PROCESSED_DIR, exist_ok=True)

ref_adata.write(os.path.join(PROCESSED_DIR, 'pbmc3k_ref.h5ad'))
query_adata.write(os.path.join(PROCESSED_DIR, 'pbmc3k_query.h5ad'))

print("Data saved.")