In [None]:
# %% [markdown]
# # 3. Predicting with an scPred Model
# 
# This notebook loads a trained `ScPredModel` and the query data,
# then performs cell type prediction.

# %%
import scanpy as sc
import anndata as ad
import numpy as np
import pandas as pd
import os
import sys
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix

# Add project root to path
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
# %% [markdown]
# ## Load Model and Query Data

# %%
# Load the trained model
with open('../models/scpred_pbmc3k_model.pkl', 'rb') as f:
    scpred_model = pickle.load(f)

print("Loaded Model:", scpred_model)

# Load the query data
query_adata = ad.read_h5ad('../data/processed/pbmc3k_query.h5ad')
print("\nQuery Data:", query_adata)

In [None]:
# %% [markdown]
# ## Perform Prediction
# 
# We use the `predict` method of our loaded model.
# **Important**: The current `_core.py` implementation re-fits PCA
# on the genes shared with the query and re-scales the query data.
# This is a simplification: the original scPred method instead projects
# the query onto the reference PCA space, so this step is the main
# candidate for refinement if predictions look inaccurate.

# %%
query_adata_pred = scpred_model.predict(query_adata)

print("\nQuery Data with Predictions:")
print(query_adata_pred.obs[['cell_type', 'scpred_prediction']].head())
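
# %% [markdown]
# The projection that the note above refers to can be sketched with plain NumPy.
# This is a minimal illustration under stated assumptions, not the `_core.py`
# implementation: `fit_reference_pca` and `project_query` are hypothetical helper
# names, and the toy matrices stand in for expression data already restricted to
# the same genes in the same order. The key point is that the query is scaled with
# the *reference* means and standard deviations and multiplied by the *reference*
# loadings, rather than having PCA re-fit on it.

# %%
import numpy as np


def fit_reference_pca(ref_X, n_components):
    """Return per-gene mean/std and PCA loadings learned from the reference only."""
    mean = ref_X.mean(axis=0)
    std = ref_X.std(axis=0)
    std[std == 0] = 1.0  # avoid division by zero for constant genes
    scaled = (ref_X - mean) / std
    # Rows of vt are the principal axes; transpose to genes x components
    _, _, vt = np.linalg.svd(scaled, full_matrices=False)
    return mean, std, vt[:n_components].T


def project_query(query_X, mean, std, loadings):
    """Scale the query with reference statistics and project onto reference PCs."""
    return ((query_X - mean) / std) @ loadings


# Toy data only (assumption): random values standing in for expression matrices
_rng = np.random.default_rng(0)
_ref_X = _rng.normal(size=(50, 8))    # toy reference: 50 cells x 8 genes
_query_X = _rng.normal(size=(5, 8))   # toy query: 5 cells, same 8 genes
_mean, _std, _loadings = fit_reference_pca(_ref_X, n_components=3)
_query_pcs = project_query(_query_X, _mean, _std, _loadings)
print(_query_pcs.shape)  # (5, 3): one row per query cell, one column per reference PC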

In [None]:
# %% [markdown]
# ## Evaluate Predictions
# 
# Because the query data was split off from labeled data, it still carries
# true labels in `obs['cell_type']`, so we can evaluate the predictions against them.

# %%
true_labels = query_adata_pred.obs['cell_type']
predicted_labels = query_adata_pred.obs['scpred_prediction']

print("\nClassification Report:\n")
print(classification_report(true_labels, predicted_labels))

In [None]:
# %% [markdown]
# ## Visualize Results
# 
# Let's visualize the confusion matrix.

# %%
# Use the classifier's class order for both axes so rows and columns line up
classes = scpred_model.classifier_.classes_
cm = confusion_matrix(true_labels, predicted_labels, labels=classes)
cm_df = pd.DataFrame(cm, index=classes, columns=classes)

plt.figure(figsize=(8, 6))
sns.heatmap(cm_df, annot=True, fmt='d', cmap='Blues')
plt.title('Confusion Matrix')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()
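
# %% [markdown]
# Raw counts can hide errors in rare cell types, so it may also help to look at
# per-class recall by normalizing each row. A minimal sketch, assuming a square
# count matrix like `cm` above; `row_normalize` is a hypothetical helper and the
# toy matrix is only for illustration.

# %%
import numpy as np


def row_normalize(counts):
    """Divide each row by its total: entry (i, j) becomes the fraction of true class i predicted as j."""
    counts = np.asarray(counts, dtype=float)
    totals = counts.sum(axis=1, keepdims=True)
    # Rows with no true cells stay all-zero instead of producing NaN
    return np.divide(counts, totals, out=np.zeros_like(counts), where=totals > 0)


# Toy counts (assumption): three classes, the last with no true cells
_toy_cm = np.array([[8, 2, 0],
                    [1, 9, 0],
                    [0, 0, 0]])
print(row_normalize(_toy_cm))

# In this notebook, the same helper could be applied to the real matrix, e.g.:
# cm_norm_df = pd.DataFrame(row_normalize(cm), index=cm_df.index, columns=cm_df.columns)
# sns.heatmap(cm_norm_df, annot=True, fmt='.2f', cmap='Blues')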

In [None]:
# %% [markdown]
# We can also visualize the UMAP of the query data, colored by true
# and predicted labels.

# %%
# Compute UMAP on the query data using its projected PCs.
# `predict` returns them in `.obsm['X_scpred_pca']`; only re-run it if they are missing.
if 'X_scpred_pca' not in query_adata_pred.obsm:
    query_adata_pred.obsm['X_scpred_pca'] = scpred_model.predict(query_adata_pred.copy()).obsm['X_scpred_pca']

# Calculate UMAP based on *our projected* PCs
sc.pp.neighbors(query_adata_pred, n_neighbors=10, use_rep='X_scpred_pca')
sc.tl.umap(query_adata_pred)

# %%
sc.pl.umap(query_adata_pred, color=['cell_type', 'scpred_prediction'], title=['True Labels', 'scPred Predictions'])

In [None]:
# %% [markdown]
# ## Next Steps
# 
# This notebook shows the basic workflow. To improve it, focus on:
# 1.  **Hyperparameter Tuning**: Implement `GridSearchCV` in `_training.py`.
# 2.  **PCA Projection Accuracy**: Ensure the scaling and gene handling *exactly* match `scPred`'s method before PCA projection. This might involve saving scaling factors from the reference.
# 3.  **Feature Selection**: Implement the specific informative gene selection used by `scPred`.
# 4.  **Probability Thresholding**: `scPred` includes steps to handle "unassigned" cells based on probability thresholds.
# 5.  **Robustness & Error Handling**: Add input validation and clear error messages (for example, for missing genes or a missing `cell_type` column).
# 6.  **Testing**: Implement `pytest` tests in the `tests/` directory.
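
# %% [markdown]
# Item 4 can be sketched as follows. This is a minimal illustration, assuming the
# classifier exposes class probabilities (for example via scikit-learn's
# `predict_proba`); `assign_with_threshold` is a hypothetical helper, and the 0.55
# cutoff and the "unassigned" label are illustrative choices to be checked against
# `scPred` itself.

# %%
import numpy as np


def assign_with_threshold(probs, classes, threshold=0.55, unassigned="unassigned"):
    """Pick the most probable class per cell, or `unassigned` if its probability is below `threshold`."""
    probs = np.asarray(probs, dtype=float)
    classes = np.asarray(classes, dtype=object)
    best = probs.argmax(axis=1)
    labels = classes[best]  # fancy indexing returns a copy, so `classes` is not modified
    labels[probs.max(axis=1) < threshold] = unassigned
    return labels


# Toy probabilities (assumption): rows are cells, columns follow the order of `classes`
_toy_probs = np.array([[0.90, 0.05, 0.05],   # confident: first class
                       [0.40, 0.35, 0.25],   # ambiguous: below the cutoff
                       [0.10, 0.20, 0.70]])  # confident: third class
_toy_labels = assign_with_threshold(_toy_probs, ["B cell", "T cell", "NK cell"])
print(_toy_labels)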