In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os
import warnings
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
import torch
import scanpy as sc
from sklearn.metrics import (
    silhouette_score, 
    adjusted_rand_score, 
    normalized_mutual_info_score,
    homogeneity_score,
)
import anndata as AnnData
import scvi
from scvi.model import SCVI
from scib_metrics.benchmark import Benchmarker, BioConservation, BatchCorrection

sys.path.append(os.path.abspath("../src"))
from model import InformedSCVI
from pathway import get_pathway_masks, get_random_masks, filter_genes
from train import plot_loss, split_train_val

warnings.simplefilter("ignore")
sc.logging.print_header()
sc.settings.figsize = (10, 10)
sc.settings.figdir = "../results/figures/"
scvi.settings.seed = 0
print("Last run with scvi-tools version:", scvi.__version__)
sns.set_theme()
torch.set_float32_matmul_precision("high")

[rank: 0] Seed set to 0


Last run with scvi-tools version: 1.4.0.post1


In [2]:
CELL_TYPE_KEY = 'cell1'
# KEY_OF_INTEREST = 'condition'

# Check GPU availability
print(f"Torch cuda available: {torch.cuda.is_available()}", flush=True)

Torch cuda available: True


In [10]:
def run_model(adata, 
              activation="relu",
              masks=None,
              likelihood="zinb",
              layer="RNA", 
              key="scVI"):
    
    # Define model
    InformedSCVI.setup_anndata(
        adata,
        layer=layer,
        batch_key="sample",
    )
    print(type(masks))
    model = InformedSCVI(
        adata,
        gene_likelihood=likelihood,
        activation=activation,
        masks=masks,
    )

    # Train model
    train_indices, val_indices = split_train_val(adata, CELL_TYPE_KEY)
    early_stopping_kwargs = {
        "early_stopping": True,
        "early_stopping_patience": 15,
        "early_stopping_monitor": "validation_loss",
    }
    datasplitter_kwargs = {
        "external_indexing": [train_indices, val_indices],
    }
    model.train(datasplitter_kwargs=datasplitter_kwargs, **early_stopping_kwargs) # Can add: max_epochs=10

    # Save model
    model.save(f"../results/models/{key}", overwrite=True, save_anndata=False)


    # Save latent representation
    latent = model.get_latent_representation()
    adata.obsm[f"X_{key}"] = latent
    
    return adata

In [4]:
def compute_metrics(adata, key):
    model = InformedSCVI.load(f"../results/models/{key}", adata=adata)

    plot_loss(model, save_path=f"../results/figures/{key}/loss.png")

    # Save normalized expression
    adata_subset = adata[adata.obs.cell1 == "Tumor"]
    denoised = model.get_normalized_expression(adata_subset, library_size=1e4)
    denoised = pd.DataFrame(
        denoised,
        index=adata_subset.obs_names,
        columns=adata_subset.var_names,
    )
    os.makedirs(f"../results/models/{key}", exist_ok=True)
    denoised.to_csv(f"../results/models/{key}/denoised_data.csv")
    
    # Clustering
    sc.pp.neighbors(adata, use_rep=f"X_{key}", key_added=f"neighbors_{key}")
    sc.tl.umap(adata, min_dist=0.3, key_added=f"umap_{key}", neighbors_key=f"neighbors_{key}")
    sc.tl.leiden(adata, key_added=f"leiden_{key}", resolution=0.1, neighbors_key=f"neighbors_{key}")
    sc.settings.figdir = f"../results/figures/{key}/"
    for feature in ["leiden", "cell1", "cell2", "sample"]:
        sc.pl.embedding(adata, basis=f"umap_{key}", color=feature, title=feature, frameon=False, show=False, save=f"_{feature}.png")

    # Calculate metrics

    elbo = model.get_elbo()
    reconstruction_error = model.get_reconstruction_error()['reconstruction_loss']
    #transform tensors to floats
    elbo = elbo.item()
    reconstruction_error = reconstruction_error.item()

    silhouette = silhouette_score(adata.obsm[f"X_{key}"], adata.obs[f"leiden_{key}"], metric="euclidean")
    ARI = adjusted_rand_score(adata.obs["cell2"], adata.obs[f"leiden_{key}"])
    NMI = normalized_mutual_info_score(adata.obs["cell2"], adata.obs[f"leiden_{key}"])
    homogeneity = homogeneity_score(adata.obs["cell2"], adata.obs[f"leiden_{key}"])
    metrics = pd.DataFrame(
        {
            "elbo": [elbo],
            "reconstruction_error": [reconstruction_error],
            "silhouette": [silhouette],
            "ARI": [ARI],
            "NMI": [NMI],
            "homogeneity": [homogeneity],
        }
    )
    metrics.to_csv(f"../results/models/{key}/metrics.csv", index=False)

    return

In [11]:
adata = sc.read("../data/NBsmall/NB.bone.Met_preprocessed.h5ad")

adata = adata[:, ~adata.var['ensembl_gene_id'].isna()]

Only considering the two last: ['.Met_preprocessed', '.h5ad'].
Only considering the two last: ['.Met_preprocessed', '.h5ad'].


In [14]:
genes_per_pathway, genes_per_circuit, circuits_per_pathway = get_pathway_masks()
adata, genes_per_pathway, genes_per_circuit, circuits_per_pathway = filter_genes(adata, genes_per_pathway, genes_per_circuit, circuits_per_pathway)
frac = genes_per_circuit.sum(axis=1).mean() / genes_per_circuit.shape[1]
rnd_genes_per_pathway, rnd_genes_per_circuit, rnd_circuits_per_pathway = get_random_masks(adata.var_names, genes_per_circuit.shape[0], genes_per_pathway.shape[0], frac=frac, seed=42)

masks_keggNB = [genes_per_circuit, circuits_per_pathway]
print(f"Masks dimensions: {[mask.shape for mask in masks_keggNB]}")
masks_keggNB_rnd = [rnd_genes_per_circuit, rnd_circuits_per_pathway]
masks_keggNB_pathways = [genes_per_pathway]
print(f"Masks dimensions: {[mask.shape for mask in masks_keggNB_pathways]}")
masks_keggNB_pathways_rnd = [rnd_genes_per_pathway]

Current directory: /cluster/work/bewi/members/rquiles/piscvi/notebooks
Filtering genes based on pathways and circuits with minimum 4 genes per circuit
Circuits in genes_per_circuit: 3832
Circuits in circuits_per_pathway: 3832
Masks dimensions: [(3832, 6486), (367, 3832)]
Masks dimensions: [(367, 6486)]


In [12]:
adata = adata.copy()

In [15]:
adata = run_model(
    adata,
    key="piscVI",
    masks=masks_keggNB,
)

setup_anndata: InformedSCVI
<class 'list'>
Hello
Activation function encoder: ReLU
Activation function encoder: ReLU
Activation function decoder: ReLU


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


2 masks found for module encoder.
Mask dimensions for module encoder: [torch.Size([3832, 6486]), torch.Size([367, 3832])]
Trainer kwargs:  {'early_stopping_patience': 15, 'early_stopping_monitor': 'validation_loss', 'early_stopping': True}


Training:   0%|          | 0/218 [00:00<?, ?it/s]

Monitored metric validation_loss did not improve in the last 15 records. Best score: 2016.654. Signaling Trainer to stop.


In [16]:
compute_metrics(adata, key="piscVI")

[34mINFO    [0m File ..[35m/results/models/piscVI/[0m[95mmodel.pt[0m already downloaded                                                 
setup_anndata: InformedSCVI
Hello
Activation function encoder: ReLU
Activation function encoder: ReLU
Activation function decoder: ReLU
[34mINFO    [0m Received view of anndata, making copy.                                                                    
[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             


<Figure size 1000x600 with 0 Axes>

In [17]:
adata = run_model(
    adata,
    key="piscVII",
    masks=masks_keggNB_rnd,
)

setup_anndata: InformedSCVI
<class 'list'>
Hello
Activation function encoder: ReLU
Activation function encoder: ReLU
Activation function decoder: ReLU


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


2 masks found for module encoder.
Mask dimensions for module encoder: [torch.Size([3832, 6486]), torch.Size([367, 3832])]
Trainer kwargs:  {'early_stopping_patience': 15, 'early_stopping_monitor': 'validation_loss', 'early_stopping': True}


SLURM auto-requeueing enabled. Setting signal handlers.


Training:   0%|          | 0/218 [00:00<?, ?it/s]

Monitored metric validation_loss did not improve in the last 15 records. Best score: 1998.307. Signaling Trainer to stop.


In [18]:
compute_metrics(adata, key="piscVII")

[34mINFO    [0m File ..[35m/results/models/piscVII/[0m[95mmodel.pt[0m already downloaded                                                
setup_anndata: InformedSCVI
Hello
Activation function encoder: ReLU
Activation function encoder: ReLU
Activation function decoder: ReLU
[34mINFO    [0m Received view of anndata, making copy.                                                                    
[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             


<Figure size 1000x600 with 0 Axes>