In [None]:
# %% [markdown]
# # 2. Training an scPred Model
# 
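# This notebook loads the preprocessed PBMC3k reference data
# (`../data/processed/pbmc3k_ref.h5ad`), trains our `ScPredModel` on its
# `cell_type` labels, and pickles the trained model to
# `../models/scpred_pbmc3k_model.pkl`.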

# %%
import scanpy as sc
import anndata as ad
import numpy as np
import os
import sys
import pickle

# Make the project root (the parent of the notebook's working directory)
# importable so the local `scpred_py` package resolves without installation.
module_path = os.path.abspath(os.pardir)

# Guard against appending a duplicate entry when this cell is re-run.
if module_path not in sys.path:
    sys.path.append(module_path)

from scpred_py import ScPredModel

In [None]:
# %% [markdown]
# ## Load Reference Data

# %%
# Preprocessed PBMC3k reference AnnData used for training below.
ref_h5ad_path = '../data/processed/pbmc3k_ref.h5ad'
ref_adata = ad.read_h5ad(ref_h5ad_path)
print(ref_adata)

# Number of cells per label in the `cell_type` column, useful for spotting
# class imbalance before training.
label_counts = ref_adata.obs['cell_type'].value_counts()
print("Cell types:\n", label_counts)

In [None]:
# %% [markdown]
# ## Initialize and Train the Model
# 
# We use the `ScPredModel` class and train it on our reference data.
# We pass the name of the `.obs` column that holds the cell type labels
# (`cell_type`) and the number of PCA components to use (`n_components=30`).

# %%
# Training options: the `.obs` column holding the cell type labels, and the
# number of components (presumably principal components, given the fitted
# `pca_model_` below) used by the model.
label_column = 'cell_type'
n_pcs = 30
train_kwargs = dict(cell_type_key=label_column, n_components=n_pcs)

scpred_model = ScPredModel()
scpred_model.train(ref_adata, **train_kwargs)

# Report the fitted pieces the trained model exposes.
print("\nModel Trained!")
print("PCA Model:", scpred_model.pca_model_)
print("Classifier:", scpred_model.classifier_)
print("Reference Genes:", len(scpred_model.reference_genes_))

In [None]:
# %% [markdown]
# ## Save the Trained Model
# 
# We save the trained model object with `pickle` for later use.
# Note: unpickling can execute arbitrary code, so only load this file
# from a trusted source.

# %%
# Single source of truth for the output location, so the directory we create,
# the file we write, and the path we report can never drift apart.
MODEL_PATH = '../models/scpred_pbmc3k_model.pkl'

# exist_ok=True replaces the racy exists()-then-makedirs() check and keeps
# this cell idempotent on re-run.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)

# NOTE: unpickling executes arbitrary code -- only ever load this file from a
# trusted source.
with open(MODEL_PATH, 'wb') as f:
    pickle.dump(scpred_model, f)

print(f"Trained model saved to {MODEL_PATH}")