# Spatial ResolVI Integration Testing

This notebook tests the spatial encoder integration in ResolVI on real spatial transcriptomics data. It reuses the workflow from `test_arrayed_perturb.ipynb` and additionally checks that the spatial encoder is wired into the model and yields spatially aware perturbation predictions.

## Test Objectives:
1. **Validate Spatial Encoder Integration** - Ensure the spatial encoder is connected to the shift network
2. **Test Spatial Data Pipeline** - Verify that spatial coordinates flow through the model correctly
3. **Compare Spatial vs Non-Spatial Effects** - Assess whether spatial context improves perturbation predictions
4. **Spatial Pattern Analysis** - Show spatial relationships in perturbation effects
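Objective 4 could be checked, for example, by correlating a per-cell perturbation-effect score with a spatial covariate already stored in `obs`, such as `crypt_villi_axis` or `epithelial_distance`. The sketch below runs on synthetic data; the effect score and the helper `spatial_effect_correlation` are hypothetical and not part of ResolVI.

```python
import numpy as np
from scipy.stats import pearsonr

def spatial_effect_correlation(effect, axis_position):
    """Pearson correlation between a per-cell effect score and a spatial covariate."""
    effect = np.asarray(effect, dtype=float)
    axis_position = np.asarray(axis_position, dtype=float)
    if effect.shape != axis_position.shape:
        raise ValueError("effect and axis_position must have the same shape")
    r, p = pearsonr(effect, axis_position)
    return float(r), float(p)

# Synthetic example: an effect that grows linearly along the axis, plus noise
rng = np.random.default_rng(0)
axis = rng.uniform(0.0, 1.0, size=500)
effect = 2.0 * axis + rng.normal(scale=0.1, size=500)
r, p = spatial_effect_correlation(effect, axis)
print(f"r = {r:.3f}, p = {p:.2e}")
```

Pearson's r only captures linear trends; a rank correlation may suit monotone but nonlinear patterns along the crypt-villus axis better.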

In [1]:
# Standard imports
import os
import tempfile
import warnings
import matplotlib.pyplot as plt
import numpy as np
import scanpy as sc
import pandas as pd
import seaborn as sns
import torch
from scipy.spatial import distance_matrix
from scipy.stats import pearsonr
import anndata as ad
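`distance_matrix` returns a dense n × n array, which does not scale to this dataset: for the 810,225 cells loaded below, a float64 matrix would need about 810,225² × 8 bytes ≈ 5.25 TB. A k-nearest-neighbor query on a KD-tree avoids materializing that matrix. A minimal sketch with `scipy.spatial.cKDTree`; the helper `spatial_knn` and the default `k` are illustrative assumptions.

```python
import numpy as np
from scipy.spatial import cKDTree

def spatial_knn(coords, k=10):
    """Distances and indices of the k nearest neighbors of each cell, excluding itself."""
    coords = np.asarray(coords, dtype=float)
    tree = cKDTree(coords)
    # Query k + 1 neighbors because each point's closest match is itself (distance 0).
    # Assumes no duplicate coordinates; with duplicates, self may not come first.
    dist, idx = tree.query(coords, k=k + 1)
    return dist[:, 1:], idx[:, 1:]

# Toy example: four cells on a unit square plus one distant cell
coords = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [10.0, 10.0]])
dist, idx = spatial_knn(coords, k=2)
print(dist[0], idx[0])
```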

Data Loading

In [2]:
# Load the query dataset (absolute path on the analysis machine)
path_to_query_adata = "/mnt/sata2/Analysis_Alex_2/perturb4_no_baysor/final_object_corrected.h5ad"
query_adata = sc.read_h5ad(path_to_query_adata)

print(f"Loaded data shape: {query_adata.shape}")
print(f"Available obs keys: {list(query_adata.obs.keys())}")
print(f"Available obsm keys: {list(query_adata.obsm.keys())}")

Loaded data shape: (810225, 480)
Available obs keys: ['x_centroid', 'y_centroid', 'transcript_counts', 'control_probe_counts', 'genomic_control_counts', 'control_codeword_counts', 'unassigned_codeword_counts', 'deprecated_codeword_counts', 'total_counts', 'cell_area', 'nucleus_area', 'nucleus_count', 'segmentation_method', 'total_transcripts', 'nuclear_transcripts', 'cytoplasmic_transcripts', 'nuclear_transcript_percentage', 'topic', 'batch', '_scvi_batch', '_scvi_labels', 'leiden', 'Class', 'reference_crypt_villi', 'villi_number', 'peyers', 'Topic 1', 'Topic 2', 'Topic 3', 'Topic 4', 'Topic 5', 'Topic 6', 'Topic 7', 'Topic 8', 'Topic 9', 'Topic 10', 'Topic 11', 'Topic 12', 'Topic 13', 'Topic 14', 'Topic 15', 'crypt_villi_axis', 'epithelial_distance', 'epithelial_distance_clipped', 'epithelial_distance_transformed', 'guide_rnas', 'cell_types', 'cluster_cellcharter', 'epithelial_distance_scaled']
Available obsm keys: ['X_cellcharter', 'X_mde', 'X_pca', 'X_scVI_replicates', 'X_spatial']
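The printed keys cover `obs` and `obsm`, but the `setup_anndata` call later also reads `layer="raw"`, which is not listed here. A small pre-flight check can confirm that every field the setup call references exists before training. `missing_fields` is a hypothetical helper; with the real object it would be called as `missing_fields(list(query_adata.obs.keys()), list(query_adata.obsm.keys()), list(query_adata.layers.keys()))`.

```python
def missing_fields(obs_keys, obsm_keys, layer_keys,
                   required_obs=("cell_types", "batch", "guide_rnas"),
                   required_obsm=("X_spatial",),
                   required_layers=("raw",)):
    """Return the required field names that are absent, grouped by AnnData slot."""
    missing = {
        "obs": [k for k in required_obs if k not in obs_keys],
        "obsm": [k for k in required_obsm if k not in obsm_keys],
        "layers": [k for k in required_layers if k not in layer_keys],
    }
    # Drop slots with nothing missing, so an empty dict means "all present"
    return {slot: keys for slot, keys in missing.items() if keys}

# A subset of the printed obs keys, the printed obsm keys, and a hypothetical empty layers list
obs_keys = ["x_centroid", "y_centroid", "batch", "guide_rnas", "cell_types"]
obsm_keys = ["X_cellcharter", "X_mde", "X_pca", "X_scVI_replicates", "X_spatial"]
print(missing_fields(obs_keys, obsm_keys, layer_keys=[]))
```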


ResolVI Import

In [3]:
import sys
# Put the local scvi-spatial checkout first on the path so it shadows any installed scvi-tools
sys.path.insert(0, '/home/dpatravali/Desktop/scvi-spatial/src')
print("Added to path:", sys.path[0])

import scvi.external.resolvi as RESOLVI
import scvi
print("Importing from:", scvi.external.resolvi.__file__)

# Import the spatial encoder class for direct inspection
from scvi.external.resolvi import SpatialEncoder

Added to path: /home/dpatravali/Desktop/scvi-spatial/src
Importing from: /home/dpatravali/Desktop/scvi-spatial/src/scvi/external/resolvi/__init__.py
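Because `sys.path.insert(0, ...)` only affects modules that have not already been imported in the session, it can be worth asserting that `scvi` really resolves to the local checkout rather than an installed release. A minimal standard-library sketch; `imported_from` is a hypothetical helper, demonstrated on the stdlib `json` package so it runs anywhere.

```python
import json
from pathlib import Path
from types import ModuleType

def imported_from(module: ModuleType, root) -> bool:
    """True if the module's source file lives under the given root directory."""
    module_file = getattr(module, "__file__", None)
    if module_file is None:  # built-in and namespace modules have no source file
        return False
    return Path(module_file).resolve().is_relative_to(Path(root).resolve())

# Demo on a stdlib package: json's __init__.py lives inside its own package directory
print(imported_from(json, Path(json.__file__).parent))
# In this notebook the check would be, e.g.:
# assert imported_from(scvi, "/home/dpatravali/Desktop/scvi-spatial/src")
```

`Path.is_relative_to` requires Python 3.9 or later; the environment in the training log is Python 3.11.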


Spatial Coordinates

In [6]:
query_adata.obsm["X_spatial"]

array([[ 703.83251953, 2350.98022461],
       [ 717.39398193, 2364.40039062],
       [ 706.7855835 , 2360.75854492],
       ...,
       [5484.81005859, 7346.80517578],
       [5478.77685547, 7352.08691406],
       [5532.67822266, 7486.91748047]], shape=(810225, 2))
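The first rows of `obsm["X_spatial"]` match the `x_centroid` / `y_centroid` columns in `obs` (703.83, 2350.98 for the first cell). For objective 2, a quick check is to confirm this holds for every cell and that no coordinate is missing before the array reaches the spatial encoder. The sketch assumes `X_spatial` is meant to equal those two columns, which the displayed rows suggest but the notebook does not state; `coords_match_centroids` is a hypothetical helper.

```python
import numpy as np
import pandas as pd

def coords_match_centroids(obs: pd.DataFrame, coords: np.ndarray, atol: float = 1e-3) -> bool:
    """Check that an (n_cells, 2) coordinate array is finite and equals the obs centroids."""
    coords = np.asarray(coords, dtype=float)
    if coords.shape != (len(obs), 2) or not np.isfinite(coords).all():
        return False
    centroids = obs[["x_centroid", "y_centroid"]].to_numpy(dtype=float)
    return bool(np.allclose(coords, centroids, atol=atol))

# Toy check built from the first two cells displayed in this notebook
obs = pd.DataFrame(
    {"x_centroid": [703.832520, 717.393982], "y_centroid": [2350.980225, 2364.400391]},
    index=["aaaackdi-1", "aaaaddii-1"],
)
coords = np.array([[703.83251953, 2350.98022461], [717.39398193, 2364.40039062]])
print(coords_match_centroids(obs, coords))
# With the real data: coords_match_centroids(query_adata.obs, query_adata.obsm["X_spatial"])
```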

Cell Metadata

In [7]:
query_adata.obs

| | x_centroid | y_centroid | transcript_counts | control_probe_counts | genomic_control_counts | control_codeword_counts | unassigned_codeword_counts | deprecated_codeword_counts | total_counts | cell_area | ... | Topic 14 | Topic 15 | crypt_villi_axis | epithelial_distance | epithelial_distance_clipped | epithelial_distance_transformed | guide_rnas | cell_types | cluster_cellcharter | epithelial_distance_scaled |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| aaaackdi-1 | 703.832520 | 2350.980225 | 256 | 0 | 0 | 0 | 0 | 0 | 256 | 50.123439 | ... | 0.037465 | 2.026178 | 0.934038 | 0.076727 | 0.071227 | -0.936883 | Other cells | Enterocytes | 1 | 0.071227 |
| aaaaddii-1 | 717.393982 | 2364.400391 | 487 | 0 | 0 | 0 | 0 | 0 | 487 | 111.897192 | ... | 0.878993 | 1.372876 | 0.920061 | 0.076727 | 0.073572 | -0.936883 | Other cells | Enterocytes | 1 | 0.073572 |
| aaaafjep-1 | 706.785583 | 2360.758545 | 317 | 0 | 0 | 0 | 0 | 0 | 317 | 73.424065 | ... | 0.296977 | 1.588671 | 0.925218 | 0.076727 | 0.067487 | -0.936883 | Other cells | Enterocytes | 1 | 0.067487 |
| aaaaklej-1 | 709.713379 | 2355.399902 | 402 | 0 | 0 | 0 | 0 | 0 | 402 | 72.791878 | ... | 0.232374 | 1.598706 | 0.926149 | 0.076727 | 0.073855 | -0.936883 | Other cells | Enterocytes | 1 | 0.073855 |
| aaaakpai-1 | 744.997314 | 2352.819092 | 172 | 0 | 0 | 0 | 0 | 0 | 172 | 41.543752 | ... | 1.491616 | 1.023234 | 0.906294 | 0.087120 | 0.087120 | -0.807832 | Other cells | Enterocytes | 1 | 0.087120 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| oinanmfc-1 | 5558.731445 | 7586.157227 | 45 | 0 | 0 | 0 | 0 | 0 | 45 | 43.124220 | ... | -0.279271 | -0.579267 | 0.387358 | 0.151869 | 0.151869 | -0.075063 | Other cells | Macrophages | 2 | 0.151869 |
| oinaojph-1 | 5561.004395 | 7594.268555 | 25 | 0 | 0 | 0 | 0 | 0 | 25 | 32.377032 | ... | -0.279271 | -0.579267 | 0.405426 | 0.125362 | 0.125362 | -0.360792 | Other cells | Plasma Cells | 4 | 0.125362 |
| oinbbbgh-1 | 5484.810059 | 7346.805176 | 32 | 0 | 0 | 0 | 0 | 0 | 32 | 67.327971 | ... | 1.618375 | 0.876619 | 0.920581 | 0.114040 | 0.114040 | -0.488728 | Other cells | Fibroblasts/Mesenchymal Cells | 4 | 0.114040 |
| oinbfame-1 | 5478.776855 | 7352.086914 | 26 | 0 | 0 | 0 | 0 | 0 | 26 | 37.615158 | ... | 2.359092 | 0.677170 | 0.934221 | 0.116618 | 0.116618 | -0.459282 | Other cells | Endothelial Cells | 4 | 0.116618 |

AnnData Setup

In [8]:
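# Register the AnnData fields, including the spatial coordinates, with ResolVI
RESOLVI.RESOLVI.setup_anndata(
    query_adata,
    labels_key="cell_types",
    layer="raw",                          # layer used as model input (assumed to hold raw counts)
    batch_key="batch",
    perturbation_key="guide_rnas",        # obs column with the guide assigned to each cell
    control_perturbation="sgCd19",        # guide treated as the control condition
    spatial_key="X_spatial",              # obsm key holding the (x, y) coordinates
    background_key="guide_rnas",
    background_category="Other cells",    # cells with this label are treated as background
    categorical_covariate_keys=["batch"],
)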

INFO     Generating sequential column names
INFO     Generating sequential column names
INFO     Generating sequential column names

Model Initialization

In [9]:
# Initialize ResolVI with the spatial encoder enabled
spatial_resolvi = RESOLVI.RESOLVI(
    query_adata,
    semisupervised=True,          # same settings as test_arrayed_perturb.ipynb
    n_latent=32,
    perturbation_hidden_dim=128,
    n_input_spatial=2,            # dimensionality of the spatial coordinates (x, y)
    control_penalty_weight=1.0,
)

print("✅ Spatial RESOLVI model initialized successfully")

# Verify spatial encoder is present
has_spatial_encoder = hasattr(spatial_resolvi.module.model, 'spatial_encoder')
print(f"✅ Model has spatial encoder: {has_spatial_encoder}")

if has_spatial_encoder:
    spatial_encoder = spatial_resolvi.module.model.spatial_encoder
    print(f"   Spatial encoder type: {type(spatial_encoder).__name__}")
    print(f"   Spatial encoder input dim: {spatial_encoder.encoder[0].in_features}")
    print(f"   Spatial encoder output dim: {spatial_encoder.mean_encoder.out_features}")

✅ Spatial RESOLVI model initialized successfully
✅ Model has spatial encoder: True
   Spatial encoder type: SpatialEncoder
   Spatial encoder input dim: 2
   Spatial encoder output dim: 32
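The checks above read `spatial_encoder.encoder[0].in_features` and `spatial_encoder.mean_encoder.out_features`, which implies a hidden stack whose first element is a linear layer, followed by a separate mean head (and presumably a variance head). The real `SpatialEncoder` lives in the local scvi-spatial checkout and is not shown here, so the following is a hypothetical PyTorch module with the same attribute layout, mapping 2-D coordinates to a 32-dimensional Gaussian latent. The class name `ToySpatialEncoder`, the hidden size, the softplus variance, and the coordinate standardization are all assumptions.

```python
import torch
from torch import nn

class ToySpatialEncoder(nn.Module):
    """Hypothetical stand-in: (x, y) coordinates -> mean and variance of a latent code."""

    def __init__(self, n_input: int = 2, n_output: int = 32, n_hidden: int = 128):
        super().__init__()
        # encoder[0] is a Linear layer, so encoder[0].in_features == n_input
        self.encoder = nn.Sequential(
            nn.Linear(n_input, n_hidden),
            nn.ReLU(),
        )
        self.mean_encoder = nn.Linear(n_hidden, n_output)
        self.var_encoder = nn.Linear(n_hidden, n_output)

    def forward(self, coords: torch.Tensor):
        h = self.encoder(coords)
        mean = self.mean_encoder(h)
        # softplus keeps the variance positive; the small offset avoids exactly zero
        var = nn.functional.softplus(self.var_encoder(h)) + 1e-4
        return mean, var

enc = ToySpatialEncoder()
coords = torch.tensor([[703.83, 2350.98], [717.39, 2364.40]])
# Raw coordinates here span thousands of units; standardizing them is a common choice
coords = (coords - coords.mean(dim=0)) / coords.std(dim=0)
mean, var = enc(coords)
print(enc.encoder[0].in_features, enc.mean_encoder.out_features, mean.shape)
```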


Dataset-Dependent Priors

In [10]:
# Compute dataset-dependent priors and pass the log-count statistics to the guide
priors = spatial_resolvi.compute_dataset_dependent_priors()
print(f"Dataset priors: {priors}")

# Store as plain Python floats (the priors dict holds NumPy scalars)
spatial_resolvi.module.guide.downsample_counts_mean = float(priors["mean_log_counts"])
spatial_resolvi.module.guide.downsample_counts_std = float(priors["std_log_counts"])

Dataset priors: {'background_ratio': np.float32(0.00067531277), 'median_distance': np.float64(372.9849853515625), 'mean_log_counts': np.float32(5.433722), 'std_log_counts': np.float32(0.8109416)}
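The printed `mean_log_counts` (≈5.43) and `std_log_counts` (≈0.81) look like the mean and standard deviation of log total counts per cell: exp(5.43) ≈ 229 counts, in the range of the `total_counts` values in the `obs` table. That reading is an assumption, since the body of `compute_dataset_dependent_priors` is not shown. Under that assumption, a NumPy sketch of the computation; `log_count_prior` is a hypothetical helper, and it returns plain floats as the cell above stores them.

```python
import numpy as np

def log_count_prior(counts: np.ndarray) -> dict:
    """Mean and population std (ddof=0) of log total counts per cell; dense input assumed."""
    totals = np.asarray(counts, dtype=float).sum(axis=1)
    # Cells with zero counts are skipped because log(0) is undefined
    log_totals = np.log(totals[totals > 0])
    return {
        "mean_log_counts": float(log_totals.mean()),
        "std_log_counts": float(log_totals.std()),
    }

# Toy count matrix: 4 cells x 2 genes with totals 10, 100, 1000 and 0
counts = np.array([[10, 0], [50, 50], [1000, 0], [0, 0]])
print(log_count_prior(counts))
```

The library may use a different estimator (for example `ddof=1` or a median), so the sketch is only for sanity-checking the order of magnitude.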


Training

In [11]:
# Train on control and perturbed cells; background cells are still used for neighbor computations
print("Starting training with spatial integration...")
spatial_resolvi.train(
    max_epochs=100,
    check_val_every_n_epoch=100,
    lr=3e-4,
    train_on_perturbed_only=True,  # restrict the training set to control and perturbed cells
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/dpatravali/miniforge3/envs/resolvi-env/lib/python3.11/site-packages/lightning/pytorch/trainer/configuration_validator.py:68: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Starting training with spatial integration...
Training configuration with train_on_perturbed_only=True:
  Training set: 2465 perturbation-relevant cells (0.3%)
    - Control cells: 454 (0.1%)
    - Perturbed cells: 2011 (0.2%)
  Excluded from training: 807760 background cells (99.7%)
  Background/neighbor computations: all 810225 cells


/home/dpatravali/miniforge3/envs/resolvi-env/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.
/home/dpatravali/miniforge3/envs/resolvi-env/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (5) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


TypeError: unsupported operand type(s) for /: 'NoneType' and 'float'
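The summary printed by `train(train_on_perturbed_only=True)` splits cells by their `guide_rnas` value: the control guide (`sgCd19`), other guides (perturbed), and the background category (`Other cells`), which is excluded from training but still used for neighbor computations. The sketch below reproduces that bookkeeping with pandas so the counts can be checked independently of the model. That the method partitions cells exactly this way is an assumption inferred from the `setup_anndata` arguments and the log; `training_partition` and the guide name `sgGeneA` are hypothetical.

```python
import numpy as np
import pandas as pd

def training_partition(guides: pd.Series, control: str, background: str) -> dict:
    """Count control, perturbed and background cells from a per-cell guide column.

    Assumes no missing guide labels (NaN would be counted as perturbed).
    """
    guides = guides.astype(str)
    is_control = guides == control
    is_background = guides == background
    is_perturbed = ~(is_control | is_background)
    n = len(guides)
    counts = {
        "control": int(is_control.sum()),
        "perturbed": int(is_perturbed.sum()),
        "background": int(is_background.sum()),
    }
    counts["train"] = counts["control"] + counts["perturbed"]
    # Row positions eligible for the training split
    train_idx = np.flatnonzero((is_control | is_perturbed).to_numpy())
    # Percentages are relative to all cells, as in the training log
    percent = {k: f"{100 * v / n:.1f}%" for k, v in counts.items()}
    return {"counts": counts, "percent": percent, "train_idx": train_idx}

# Toy column standing in for query_adata.obs["guide_rnas"]
guides = pd.Series(["Other cells", "sgCd19", "sgGeneA", "Other cells", "sgGeneA"])
result = training_partition(guides, control="sgCd19", background="Other cells")
print(result["counts"], result["train_idx"])
```

With the real data, `training_partition(query_adata.obs["guide_rnas"], "sgCd19", "Other cells")` should reproduce the 454 / 2,011 / 807,760 split in the log if the assumption holds.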