# Custom Model

In [None]:
# Register the AnnData with the *custom* model class before constructing it.
# In scvi-tools the setup registry is tracked per model class, so data set up
# through scvi.model.SCVI.setup_anndata is not registered for a different
# model class; set it up through the class you are about to instantiate.
# NOTE(review): assumes CustomSCVIModel is an scvi-tools model exposing the
# standard `setup_anndata` classmethod — confirm in custom_scvi.
from custom_scvi import CustomSCVIModel

CustomSCVIModel.setup_anndata(
    adata,
    batch_key="sample_id",  # Replace with your batch column
    labels_key="cell_type",  # Replace with your cell type column
)

# Create the custom model (architecture hyperparameters are specific to
# CustomSCVIModel; see custom_scvi for their meaning)
custom_model = CustomSCVIModel(
    adata,
    n_hidden=256,
    n_latent=20,
    n_layers=4,
    dropout_rate=0.2,
    use_layer_norm=True,
    regularization_strength=0.8,
)

# Train the model; with early stopping enabled, training may end before
# max_epochs once the monitored metric stops improving for the patience window.
custom_model.train(
    max_epochs=500,
    early_stopping=True,
    early_stopping_patience=30,
)

# Save the model.
# NOTE(review): scvi-tools' save() refuses to write into an existing directory
# by default; pass overwrite=True if this cell is re-run intentionally.
custom_model.save("custom_scvi_model")

# Get the latent representation and store it alongside the data
latent_custom = custom_model.get_latent_representation()
adata.obsm["X_custom_scVI"] = latent_custom

# Analyze as before. These calls use the default keys, so they replace any
# earlier default neighbors graph and obsm["X_umap"] in this AnnData.
sc.pp.neighbors(adata, use_rep="X_custom_scVI")
sc.tl.umap(adata)
sc.tl.leiden(adata, key_added="custom_scvi_leiden")
sc.pl.umap(adata, color=["custom_scvi_leiden", "cell_type"])

# Compare Original vs Custom Model

In [None]:
# Load the original (standard scVI) model onto the same AnnData so both latent
# spaces describe the same cells in the same order.
original_model = scvi.model.SCVI.load("scvi_model_allen", adata)

# Get latent representations from both models
adata.obsm["X_original_scVI"] = original_model.get_latent_representation()
adata.obsm["X_custom_scVI"] = custom_model.get_latent_representation()

# Build a separate kNN graph and UMAP for each latent space. Distinct
# `key_added` values keep both results side by side instead of the second run
# overwriting the first (`key_added` on sc.tl.umap requires scanpy >= 1.10).
sc.pp.neighbors(adata, use_rep="X_original_scVI", key_added="neighbors_original")
sc.pp.neighbors(adata, use_rep="X_custom_scVI", key_added="neighbors_custom")

sc.tl.umap(adata, neighbors_key="neighbors_original", key_added="umap_original")
sc.tl.umap(adata, neighbors_key="neighbors_custom", key_added="umap_custom")

# Side-by-side comparison. sc.pl.umap has no `use_rep` argument (it plots the
# default "umap" basis); sc.pl.embedding with `basis=` plots any embedding
# stored in adata.obsm, which is what the two UMAPs above were written to.
fig, axes = plt.subplots(1, 2, figsize=(16, 7))
sc.pl.embedding(
    adata,
    basis="umap_original",
    color="cell_type",
    title="Original scVI",
    ax=axes[0],
    show=False,
)
sc.pl.embedding(
    adata,
    basis="umap_custom",
    color="cell_type",
    title="Custom scVI",
    ax=axes[1],
    show=False,
)
fig.tight_layout()
fig.savefig("model_comparison.png", dpi=300)
plt.show()

# Compare how well cell types separate in each latent space (silhouette score
# over cell-type labels; higher means better-separated types).
from sklearn.metrics import silhouette_score

# Silhouette compares all pairs of cells, which gets slow on large datasets.
# None scores every cell (exact); set e.g. 10_000 to score a random subset.
# The fixed random_state makes both models score the same subset of cells.
silhouette_sample_size = None
silhouette_random_state = 0

# Encode labels once; astype("category") makes .cat.codes work even if the
# column is stored as plain strings.
cell_type_codes = adata.obs["cell_type"].astype("category").cat.codes

silhouette_original = silhouette_score(
    adata.obsm["X_original_scVI"],
    cell_type_codes,
    sample_size=silhouette_sample_size,
    random_state=silhouette_random_state,
)

silhouette_custom = silhouette_score(
    adata.obsm["X_custom_scVI"],
    cell_type_codes,
    sample_size=silhouette_sample_size,
    random_state=silhouette_random_state,
)

print(f"Original model silhouette score: {silhouette_original:.4f}")
print(f"Custom model silhouette score: {silhouette_custom:.4f}")