In [1]:
# In a Jupyter Notebook cell
import scvi
import torch
import os
import numpy as np
import pandas as pd
import anndata
import scanpy as sc # For reading AnnData


  from .autonotebook import tqdm as notebook_tqdm


In [3]:

print(f"PyTorch CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"Number of CUDA devices: {torch.cuda.device_count()}")
    print(f"PyTorch CUDA version: {torch.version.cuda}")
else:
    print("CRITICAL ERROR: CUDA not available. This notebook requires GPU.")
    # You might want to stop execution if no GPU

# --- scvi-tools Settings ---
# Set dl_num_workers (e.g., to 0 for simplicity in notebook debugging, or 2 if preferred)
scvi.settings.dl_num_workers = 0
print(f"Set scvi.settings.dl_num_workers to: {scvi.settings.dl_num_workers}")

# --- File Paths ---
base_dir = '/home/minhang/mds_project/data/cohort_adata/multiVI_model'
original_adata_path = os.path.join(base_dir, 'adata.h5ad')
model_pt_path = os.path.join(base_dir, 'model.pt')

# --- Load Original AnnData (for model context) ---
print(f"Loading original AnnData object from: {original_adata_path}")
adata_mvi_original = sc.read_h5ad(original_adata_path)
adata_mvi_original.var_names_make_unique()
print(f"Original AnnData loaded: {adata_mvi_original.shape}")

# --- Setup AnnData for scvi-tools (using the original AnnData) ---
print("Setting up original AnnData for MULTIVI model...")
scvi.model.MULTIVI.setup_anndata(
    adata_mvi_original,
    batch_key="Tech",
    protein_expression_obsm_key="ADT",
    categorical_covariate_keys=["sample"]
)
print("Original AnnData setup complete.")

# --- Initialize Model Shell & Load State ---
print("Initializing MULTIVI model shell...")
n_genes_val = (adata_mvi_original.var["modality"] == "Gene Expression").sum()
n_regions_val = (adata_mvi_original.var["modality"] == "peaks").sum()
model_shell = scvi.model.MULTIVI(
    adata_mvi_original,
    n_genes=n_genes_val,
    n_regions=n_regions_val,
)
print(f"Model shell created. Inferred params: {model_shell}") # Shows n_hidden, n_latent etc.

print(f"Loading model state from: {model_pt_path}")
loaded_full_checkpoint = torch.load(model_pt_path, map_location='cpu', weights_only=False)
actual_state_dict = loaded_full_checkpoint['model_state_dict']
model_shell.module.load_state_dict(actual_state_dict)
model_shell.is_trained_ = True
model = model_shell # Assign to 'model'
print("Model state loaded.")

# --- Move Model to GPU ---
if torch.cuda.is_available():
    target_device = "cuda:0"
    model.to_device(target_device)
    print(f"Model moved to device: {model.device}")
else:
    print("Model cannot be moved to GPU as CUDA is not available.")

PyTorch CUDA available: True
Number of CUDA devices: 8
PyTorch CUDA version: 11.8
Set scvi.settings.dl_num_workers to: 0
Loading original AnnData object from: /home/minhang/mds_project/data/cohort_adata/multiVI_model/adata.h5ad
Original AnnData loaded: (192149, 335386)
Setting up original AnnData for MULTIVI model...
[34mINFO    [0m Using column names from columns of adata.obsm[1m[[0m[32m'ADT'[0m[1m][0m                                                      
Original AnnData setup complete.
Initializing MULTIVI model shell...


Model shell created. Inferred params: 
Loading model state from: /home/minhang/mds_project/data/cohort_adata/multiVI_model/model.pt
Model state loaded.
Model moved to device: cuda:0


In [4]:
# Ensure model and adata_mvi_original are loaded and model is on GPU
if 'model' not in locals() or 'adata_mvi_original' not in locals():
    print("ERROR: 'model' or 'adata_mvi_original' not loaded. Please run the setup cell first.")
elif model.device.type != 'cuda':
    print("ERROR: Model is not on CUDA device. Please move it to GPU.")
else:
    print(f"Model is on device: {model.device}. Proceeding with batch processing...")

    batch_size_inspect = 128  # Use a small batch size for quick inspection

    # Create a DataLoader using the adata_mvi_original
    inspect_scdl = model._make_data_loader(
        adata=adata_mvi_original,
        shuffle=False, # Keep order for consistency, though not strictly needed for key inspection
        batch_size=batch_size_inspect,
    )
    print(f"DataLoader created. Will process one batch of size up to {batch_size_inspect}.")

    with torch.no_grad():
        model.module.eval()

        for i, tensors_cpu in enumerate(inspect_scdl):
            print(f"\nProcessing batch {i+1} for inspection...")
            
            # 1. Manually move tensors to the model's device
            tensors_gpu = {
                k: v.to(model.device) if isinstance(v, torch.Tensor) else v
                for k, v in tensors_cpu.items()
            }
            print("  Tensors moved to GPU.")

            # 2. Perform inference
            inference_kwargs = {"n_samples": 1} 
            inference_inputs = model.module._get_inference_input(tensors_gpu)
            inference_outputs = model.module.inference(**inference_inputs, **inference_kwargs)
            print("  Inference step complete.")
            # print(f"  Inference output keys: {inference_outputs.keys()}") # Optional: view inference keys

            # 3. Perform generation
            generative_kwargs = {"use_z_mean": True} 
            generative_inputs = model.module._get_generative_input(tensors_gpu, inference_outputs)
            generative_outputs_dict = model.module.generative(**generative_inputs, **generative_kwargs)
            print("  Generative step complete.")

            # --- INSPECTION ---
            print("\n  --- Inspecting generative_outputs_dict ---")
            print(f"  Top-level keys: {generative_outputs_dict.keys()}")

            if "px" in generative_outputs_dict:
                print(f"    Keys in generative_outputs_dict['px'] (for RNA): {generative_outputs_dict['px'].keys()}")
                if "px_scale" in generative_outputs_dict["px"]:
                    print(f"      Shape of px_scale (RNA): {generative_outputs_dict['px']['px_scale'].shape}")
            
            if "py" in generative_outputs_dict:
                print(f"    Keys in generative_outputs_dict['py'] (for Protein): {generative_outputs_dict['py'].keys()}")
                # Example: check for 'py_mean' or 'py_normalized'
                if "py_mean" in generative_outputs_dict["py"]:
                     print(f"      Shape of py_mean (Protein): {generative_outputs_dict['py']['py_mean'].shape}")
                if "py_normalized" in generative_outputs_dict["py"]:
                     print(f"      Shape of py_normalized (Protein): {generative_outputs_dict['py']['py_normalized'].shape}")

            if "pa" in generative_outputs_dict:
                print(f"    Keys in generative_outputs_dict['pa'] (for ATAC): {generative_outputs_dict['pa'].keys()}")
                # Example: check for 'pa_probs'
                if "pa_probs" in generative_outputs_dict["pa"]:
                    print(f"      Shape of pa_probs (ATAC): {generative_outputs_dict['pa']['pa_probs'].shape}")
            
            print("  --- End of inspection for this batch ---")
            
            break # Process only the first batch

    print("\nInspection finished.")

Model is on device: cuda:0. Proceeding with batch processing...
DataLoader created. Will process one batch of size up to 128.

Processing batch 1 for inspection...
  Tensors moved to GPU.
  Inference step complete.
  Generative step complete.

  --- Inspecting generative_outputs_dict ---
  Top-level keys: dict_keys(['p', 'px_scale', 'px_r', 'px_rate', 'px_dropout', 'py_', 'log_pro_back_mean'])
  --- End of inspection for this batch ---

Inspection finished.


In [5]:
# In a Jupyter Notebook cell (assuming Cell 1 for setup and model loading has been run)

if 'model' not in locals() or 'adata_mvi_original' not in locals():
    print("ERROR: 'model' or 'adata_mvi_original' not loaded. Please run the setup cell first.")
elif model.device.type != 'cuda':
    print("ERROR: Model is not on CUDA device. Please move it to GPU.")
else:
    print(f"Model is on device: {model.device}. Proceeding with batch processing for inspection...")

    batch_size_inspect = 128
    inspect_scdl = model._make_data_loader(
        adata=adata_mvi_original,
        shuffle=False,
        batch_size=batch_size_inspect,
    )
    print(f"DataLoader created. Will process one batch of size up to {batch_size_inspect}.")

    # Get expected dimensions for validation
    n_proteins_expected = adata_mvi_original.obsm["ADT"].shape[1]
    n_regions_expected = (adata_mvi_original.var["modality"] == "peaks").sum()
    print(f"Expected number of proteins: {n_proteins_expected}")
    print(f"Expected number of regions (peaks): {n_regions_expected}")

    with torch.no_grad():
        model.module.eval()
        for i, tensors_cpu in enumerate(inspect_scdl):
            print(f"\nProcessing batch {i+1} for inspection...")
            tensors_gpu = {
                k: v.to(model.device) if isinstance(v, torch.Tensor) else v
                for k, v in tensors_cpu.items()
            }
            print("  Tensors moved to GPU.")

            inference_kwargs = {"n_samples": 1}
            inference_inputs = model.module._get_inference_input(tensors_gpu)
            inference_outputs = model.module.inference(**inference_inputs, **inference_kwargs)
            print("  Inference step complete.")

            generative_kwargs = {"use_z_mean": True}
            generative_inputs = model.module._get_generative_input(tensors_gpu, inference_outputs)
            generative_outputs_dict = model.module.generative(**generative_inputs, **generative_kwargs)
            print("  Generative step complete.")

            # --- DETAILED INSPECTION ---
            print("\n  --- Inspecting generative_outputs_dict ---")
            print(f"  Top-level keys: {list(generative_outputs_dict.keys())}")

            # RNA (already known)
            if "px_scale" in generative_outputs_dict:
                print(f"\n    RNA ('px_scale'):")
                print(f"      Shape: {generative_outputs_dict['px_scale'].shape}") # Batch x n_genes

            # Protein
            if "py_" in generative_outputs_dict:
                print(f"\n    Protein ('py_'):")
                protein_output = generative_outputs_dict['py_']
                if isinstance(protein_output, dict):
                    print(f"      'py_' is a dictionary. Keys: {list(protein_output.keys())}")
                    # Common keys for scvi-tools protein output: 'py_mean', 'py_normalized', 'y_pred'
                    for prot_key in ['py_mean', 'py_normalized', 'y_pred', 'total_sum_mean']: # Add other potential keys if needed
                        if prot_key in protein_output:
                            prot_tensor_shape = protein_output[prot_key].shape
                            print(f"        Found sub-key '{prot_key}' with shape: {prot_tensor_shape}")
                            if prot_tensor_shape[1] == n_proteins_expected:
                                print(f"          Shape matches expected n_proteins ({n_proteins_expected}). This is likely your corrected protein data.")
                            else:
                                print(f"          WARNING: Shape {prot_tensor_shape[1]} does not match expected n_proteins ({n_proteins_expected}).")
                elif hasattr(protein_output, 'shape'): # If py_ is directly a tensor
                    prot_tensor_shape = protein_output.shape
                    print(f"      'py_' is a tensor with shape: {prot_tensor_shape}")
                    if prot_tensor_shape[1] == n_proteins_expected:
                        print(f"        Shape matches expected n_proteins ({n_proteins_expected}). This could be your corrected protein data.")
                    else:
                        print(f"        WARNING: Shape {prot_tensor_shape[1]} does not match expected n_proteins ({n_proteins_expected}).")
                else:
                    print(f"      'py_' is of type: {type(protein_output)}")
            
            if "log_pro_back_mean" in generative_outputs_dict:
                 print(f"\n    Protein Background ('log_pro_back_mean'):")
                 print(f"      Shape: {generative_outputs_dict['log_pro_back_mean'].shape}")


            # ATAC-seq (Accessibility)
            if "p" in generative_outputs_dict: # This was the unusual key
                print(f"\n    Accessibility (key 'p'):")
                atac_output = generative_outputs_dict['p']
                if isinstance(atac_output, dict):
                    print(f"      'p' is a dictionary. Keys: {list(atac_output.keys())}")
                    # Common keys for scvi-tools ATAC output: 'pa_probs', 'rate'
                    for atac_key in ['pa_probs', 'rate', 'probs']: # Add other potential keys
                        if atac_key in atac_output:
                            atac_tensor_shape = atac_output[atac_key].shape
                            print(f"        Found sub-key '{atac_key}' with shape: {atac_tensor_shape}")
                            if atac_tensor_shape[1] == n_regions_expected:
                                print(f"          Shape matches expected n_regions ({n_regions_expected}). This is likely your corrected ATAC data.")
                            else:
                                print(f"          WARNING: Shape {atac_tensor_shape[1]} does not match expected n_regions ({n_regions_expected}).")
                elif hasattr(atac_output, 'shape'): # If 'p' is directly a tensor
                    atac_tensor_shape = atac_output.shape
                    print(f"      'p' is a tensor with shape: {atac_tensor_shape}")
                    if atac_tensor_shape[1] == n_regions_expected:
                        print(f"        Shape matches expected n_regions ({n_regions_expected}). This could be your corrected ATAC data (e.g., probabilities).")
                    else:
                        print(f"        WARNING: Shape {atac_tensor_shape[1]} does not match expected n_regions ({n_regions_expected}).")

                else:
                    print(f"      'p' is of type: {type(atac_output)}")
            
            print("\n  --- End of detailed inspection for this batch ---")
            break # Process only the first batch

    print("\nInspection finished. Review the output above to identify the correct keys and shapes for protein and ATAC.")

Model is on device: cuda:0. Proceeding with batch processing for inspection...
DataLoader created. Will process one batch of size up to 128.
Expected number of proteins: 170
Expected number of regions (peaks): 298785

Processing batch 1 for inspection...
  Tensors moved to GPU.
  Inference step complete.
  Generative step complete.

  --- Inspecting generative_outputs_dict ---
  Top-level keys: ['p', 'px_scale', 'px_r', 'px_rate', 'px_dropout', 'py_', 'log_pro_back_mean']

    RNA ('px_scale'):
      Shape: torch.Size([128, 36601])

    Protein ('py_'):
      'py_' is a dictionary. Keys: ['back_alpha', 'back_beta', 'rate_back', 'fore_scale', 'rate_fore', 'mixing', 'scale', 'r']

    Protein Background ('log_pro_back_mean'):
      Shape: torch.Size([128, 170])

    Accessibility (key 'p'):
      'p' is a tensor with shape: torch.Size([128, 298785])
        Shape matches expected n_regions (298785). This could be your corrected ATAC data (e.g., probabilities).

  --- End of detailed insp