# Tutorial 4: Spot-level differential abundance analysis for AD

In [None]:
import matplotlib.pyplot as plt
import scanpy as sc
import anndata as ad
import numpy as np
import pandas as pd
import matplotlib as mpl
from matplotlib.colors import LinearSegmentedColormap
import tifffile
from model.datasets.data_module import DataModule
from model.datasets.pretrain_dataset import (SpatialRadiusDataset, 
                                             my_collate_fn)
import cv2
from argparse import Namespace
from skimage import filters, measure
from scipy import ndimage
from scipy.spatial.distance import cdist
import seaborn as sns
from pytorch_lightning import seed_everything, Trainer
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection
import argparse
import os
import gc
import json
from sklearn.decomposition import PCA
from scipy import stats
import meld
import joblib
from step1_pretrain import Omics, EpochCallback
from meld_analysis_2766g import *
import sklearn
from embedding_gen import *
# Directory that the relative paths below (checkpoint, BERT config) hang off.
# `__file__` only exists when this code is executed as a script; inside a
# notebook kernel it is undefined, so fall back to the working directory.
try:
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    BASE_DIR = os.getcwd()

# Set matplotlib parameters
# fonttype 42 makes the PDF backend embed TrueType fonts
mpl.rcParams['pdf.fonttype'] = 42
plt.rcParams['savefig.dpi'] = 300
plt.rcParams['figure.dpi'] = 300

## Initialize configs and fix seed

In [None]:
from argparse import Namespace

# All run parameters are written directly into this namespace instead of
# being parsed from the command line.
args = Namespace(
    freeze_bert=False,
    emb_dim=256,  # 128, 256
    num_workers=16,
    learning_rate=1e-3,
    momentum=0.9,
    weight_decay=0.05,
    batch_size=50,
    experiment_name="",
    lambda_1=1.,
    seed=42,
    data_pct=1,
    mask_ratio=0,
    radius=200,
    scale=1,
    ckpt_path=os.path.join(BASE_DIR, 'model_checkpoint.ckpt'),
    dataset_name='AD_2766g_m9723',
    fold='fold_1',
    max_points=100,
    label_type='nucleus',
    linear_hidden_dim=None,
    config=os.path.join(BASE_DIR, '../../configs/bert_config.json'),
    debug=False,  # whether to run debug mode
    output_dir="full_slice_results",  # output directory
    force_reprocess=False,  # whether to force reprocess data
    hidden_dim=768,
    output_dim=768,
    pca_path=None,
    norm_path=None,
    zoom_area=None
)

# `task` is not one of the fields defined above, so read it defensively;
# a plain `args.task` raises AttributeError and aborts the cell.
print(f"Task: {getattr(args, 'task', 'not specified')}")
print(f"Dataset: {args.dataset_name}")
print(f"Fold: {args.fold}")
print(f"Batch size: {args.batch_size}")
print(f"Learning rate: {args.learning_rate}")
print(f"Output directory: {args.output_dir}")

def parse_zoom_area(value):
    """Parse a zoom-area specification into a list of four floats.

    Args:
        value: ``None``, or a string holding four comma-separated numbers,
            optionally wrapped in square brackets (e.g. ``"[0, 0, 100, 50]"``).

    Returns:
        ``None`` when ``value`` is ``None``; otherwise a list of four floats
        in the order they appear in the string.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a string, contains a
            non-numeric field, or does not contain exactly four fields.
    """
    if value is None:
        return None
    try:
        values = [float(x) for x in value.strip('[]').split(',')]
        if len(values) != 4:
            # Describe the actual problem; the outer handler adds the
            # offending input to the final message.
            raise ValueError(
                f"expected 4 comma-separated numbers, got {len(values)}")
    except (AttributeError, TypeError, ValueError) as e:
        raise argparse.ArgumentTypeError(
            f"Invalid zoom area format: {value}, error: {e}") from e
    return values

# Run on a single GPU and mirror the debug flag into a module-level name
args.gpus = 1
debug_mode = args.debug
print('current mode: debug==', debug_mode)
print('current mode: force_reprocess==', args.force_reprocess)

# Make sure the output root and its two sub-folders exist
os.makedirs(args.output_dir, exist_ok=True)
for subdir in ("images", "data"):
    os.makedirs(os.path.join(args.output_dir, subdir), exist_ok=True)

## Generate spot embeddings

### Loading PCA model

In [None]:
if args.pca_path is None:
    # No saved model supplied: start from an unfitted 32-component PCA
    pca = PCA(n_components=32)
else:
    # joblib.load unpickles its input, so only point these paths at trusted files
    pca = joblib.load(args.pca_path)
    # The normalization file is unpacked as a (min_vals, max_vals) pair
    min_vals, max_vals = joblib.load(args.norm_path)

## Loading pre-trained SpotFormer

In [None]:
# Fix every random seed before touching the model
seed_everything(args.seed)

# Restore the model, forwarding every config field as a keyword argument.
# NOTE(review): the stock Lightning `load_from_checkpoint` expects a
# `checkpoint_path` argument, while the config only carries `ckpt_path` -
# confirm that `Omics` accepts this call as written.
model = Omics.load_from_checkpoint(**vars(args))

# Switch to inference mode, then move the weights to the GPU
model.eval()
model.to('cuda')

## Loading datasets and generating embeddings

In [None]:
# Masking strategy handed to the data module further down.
args.mask_function = 'random'

# Pick the train/val split of the requested fold for the configured dataset.
# The *_FOLDS tables are not defined in the visible part of this file;
# presumably they arrive through one of the star imports - TODO confirm.
# NOTE(review): the config above sets dataset_name='AD_2766g_m9723', which
# matches none of the branches below, so this cell raises ValueError as
# written - confirm which fold table that dataset should use.
if args.dataset_name == 'seqfish':
    train_split = SEQFISH_FOLDS[args.fold]['train']
    val_split = SEQFISH_FOLDS[args.fold]['val']
elif args.dataset_name == 'merfish':
    train_split = MERFISH_FOLDS[args.fold]['train']
    val_split = MERFISH_FOLDS[args.fold]['val']
elif args.dataset_name == 'mop1':
    train_split = MOP_FOLDS1[args.fold]['train']
    val_split = MOP_FOLDS1[args.fold]['val']
elif args.dataset_name == 'AD_64g_m9721' or args.dataset_name == 'AD_64g_m9781':
    train_split = AD_FOLDS[args.fold]['train']
    val_split = AD_FOLDS[args.fold]['val']
elif args.dataset_name == 'xenium_hbc1':
    train_split = XENIUM_HBC_FOLDS1[args.fold]['train']
    val_split = XENIUM_HBC_FOLDS1[args.fold]['val']
elif args.dataset_name == 'cosmx_lung5_rep1':
    train_split = COSMX_FOLDS51[args.fold]['train']
    val_split = COSMX_FOLDS51[args.fold]['val']
else:
    # Any other dataset name is rejected outright.
    raise ValueError(f"Dataset {args.dataset_name} not supported")

# Prefix shared by the artefacts written for this label type / dataset / fold
file_prefix = f"{args.label_type}_{args.dataset_name}_{args.fold}"
adata_path = os.path.join(args.output_dir, "data", f"{file_prefix}_full_slice.h5ad")

# Keyword options forwarded to the data module
datamodule_options = dict(
    radius=args.radius,
    mask_ratio=args.mask_ratio,
    mask_function=args.mask_function,
    dataset_name=args.dataset_name,
    max_points=args.max_points,
    train_split=train_split,
    val_split=val_split,
    label_type=args.label_type,
)
datamodule = DataModule(
    SpatialRadiusDataset,
    my_collate_fn,
    args.data_pct,
    args.batch_size,
    args.num_workers,
    **datamodule_options,
)

# One loader per split; the train split uses the inference variant
train_dataloader = datamodule.inference_train_dataloader()
val_dataloader = datamodule.val_dataloader()


### Generate embeddings and save

In [None]:
# Run the model over both splits; keep the representations and the indices
train_reps, _, _, train_indices = process_and_save_data(
    train_dataloader, model, "train", debug=debug_mode)
val_reps, _, _, val_indices = process_and_save_data(
    val_dataloader, model, "val", debug=debug_mode)

# Load the dataset tables that correspond to the processed indices
full_adata, train_adata, val_adata, train_slice, val_slice = load_dataset(
    args.dataset_name, args.fold, train_indices, val_indices, debug=debug_mode)

# Extract coordinates
train_coords = train_adata[['x', 'y']].values
val_coords = val_adata[['x', 'y']].values

# Build the training AnnData. `AnnData` is referenced through the `ad` alias
# because that is the only anndata import visible in this file.
train_anndata = ad.AnnData(obs=train_adata)
train_anndata.obsm['spatial'] = train_coords
train_anndata.obs['set_type'] = 'train'

# Project the training representations with PCA. A pre-fitted model is only
# applied; a fresh one is fitted here and saved for later reuse.
if args.pca_path is not None:
    reduced_reps = pca.transform(train_reps)
else:
    reduced_reps = pca.fit_transform(train_reps)
    pca_path = os.path.join(args.output_dir, "data", f"{file_prefix}_pca.pkl")
    joblib.dump(pca, pca_path)

# Min-max scale every component using the training split's own range.
# NOTE(review): when args.pca_path is set, the min/max values loaded from
# args.norm_path are overwritten here and never used - confirm whether they
# were meant to be applied instead.
min_vals = reduced_reps.min(axis=0)
max_vals = reduced_reps.max(axis=0)
normalized_reps = (reduced_reps - min_vals) / (max_vals - min_vals)
train_anndata.obsm['X_pca'] = normalized_reps

# Validation split: same PCA projection, scaled by its own min/max
val_anndata = ad.AnnData(obs=val_adata)
val_anndata.obsm['spatial'] = val_coords
val_anndata.obs['set_type'] = 'val'
reduced_reps = pca.transform(val_reps)
min_vals = reduced_reps.min(axis=0)
max_vals = reduced_reps.max(axis=0)
normalized_reps = (reduced_reps - min_vals) / (max_vals - min_vals)
val_anndata.obsm['X_pca'] = normalized_reps

# Concatenate training and validation data (replaces the loaded full_adata)
full_adata = ad.concat([train_anndata, val_anndata], join="outer")

# Save the combined AnnData object; 'nucleus' is cast to str before writing
print(f"Saving full AnnData object to {adata_path}...")
full_adata.obs['nucleus'] = full_adata.obs['nucleus'].astype(str)
full_adata.write_h5ad(adata_path)


## Initialize configs for the MELD and plaque analysis

In [None]:
# Parameters for the MELD / plaque analysis, written out explicitly
args = Namespace(**{
    # change to actual h5ad file paths
    'data_paths': [
        '/path/to/control_sample1.h5ad',
        '/path/to/AD_sample1.h5ad',
    ],
    # condition label corresponding to each dataset
    'conditions': ['control', 'AD'],
    'plaque_sample': 'AD_2766g_m9723',
    # change to actual tif file path
    'plaque_image_path': '/path/to/plaque_protein_image.tif',
    'output_dir': './results',
    'target_sample_size': 200000,
    'beta': 67,
    'knn': 7,
    'min_plaque_area': 100,
    'surrounding_kernel_size': 9,
    'surrounding_iterations': 15,
})

# Echo the key settings
for summary_line in (
    f"Processing {len(args.data_paths)} datasets",
    f"Conditions: {args.conditions}",
    f"Plaque sample: {args.plaque_sample}",
    f"Output directory: {args.output_dir}",
):
    print(summary_line)

## Preprocess and sample spot embeddings

In [None]:
# Banner plus a recap of the inputs for this stage
print("=== MELD Analysis with Protein Plaque Distance Analysis ===")
for recap_line in (
    f"Data paths: {args.data_paths}",
    f"Conditions: {args.conditions}",
    f"Plaque sample: {args.plaque_sample}",
    f"Output directory: {args.output_dir}",
):
    print(recap_line)

# Output locations for figures and data files
figures_dir, data_dir = create_output_dirs(args.output_dir)

# Load the listed datasets and draw the working sample of spot embeddings
sample_adata = process_and_sample_datasets(
    args.data_paths, args.conditions, args.target_sample_size)

## MELD analysis

In [None]:
# Run MELD on the sampled spots
sample_densities, sample_likelihoods = run_meld_analysis(sample_adata, args.beta, args.knn)

# AD likelihood = row-wise mean over every likelihood column whose name does
# not contain "control"; if there is none, average over all columns instead
experimental_samples = [
    name for name in sample_likelihoods.columns if 'control' not in name.lower()
]
if not experimental_samples:
    print("Warning: No experimental samples found for AD likelihood calculation")
    likelihood_columns = sample_likelihoods
else:
    likelihood_columns = sample_likelihoods[experimental_samples]
sample_adata.obs['AD_likelihood'] = likelihood_columns.mean(axis=1).values

# Persist the annotated sample
meld_results_path = os.path.join(data_dir, 'sample_adata_with_meld.h5ad')
sample_adata.write_h5ad(meld_results_path)
print(f"MELD results saved to: {meld_results_path}")

## Visualize MELD analysis results

In [None]:
# Two-colour palette used for the condition labels
colors = ['#479EA2', '#C16AAF']


def _umap_by_condition(**kwargs):
    """Draw the UMAP of the sampled spots coloured by condition."""
    return sc.pl.umap(sample_adata, color=['condition'], palette=colors, **kwargs)


# plot_and_save receives the plotting callable, a figure name and the target folder
plot_and_save(_umap_by_condition, 'umap_condition', figures_dir)


## Process plaque data

In [None]:
print(f"\n=== Processing Plaque Analysis for {args.plaque_sample} ===")

# Keep only the spots that belong to the plaque sample
in_plaque_sample = sample_adata.obs['dataset'] == args.plaque_sample
sample_adata_subset = sample_adata[in_plaque_sample].copy()

# An empty subset only triggers a warning; processing continues regardless
if not len(sample_adata_subset):
    print(f"Warning: No data found for sample {args.plaque_sample}")

# Process the plaque image and unpack its five outputs
plaque_outputs = process_plaque_image(args.plaque_image_path, args.min_plaque_area)
abeta_img, abeta_gray, binary_mask_closed, labeled_mask, plaque_df = plaque_outputs

# Build the region mask from the closed binary mask
region_mask = create_region_masks(
    binary_mask_closed,
    args.surrounding_kernel_size,
    args.surrounding_iterations,
)

## Calculate distances to plaques

In [None]:
# For each spot in the subset: distance to, and id of, its nearest plaque
min_distances, nearest_plaque_ids = calculate_distances_to_plaques(
    sample_adata_subset, plaque_df, abeta_img)

# Attach both results to the subset
subset_obs = sample_adata_subset.obs
subset_obs['nearest_plaque_distance'] = min_distances
subset_obs['nearest_plaque_id'] = nearest_plaque_ids

# Spots whose distance is infinite get -1 in both columns
no_plaque_found = min_distances == float('inf')
subset_obs.loc[no_plaque_found, 'nearest_plaque_id'] = -1
subset_obs.loc[no_plaque_found, 'nearest_plaque_distance'] = -1

## Analyze AD likelihood by distance

In [None]:
# Analyze AD likelihood by distance on the plaque-sample subset
analysis_results = analyze_ad_likelihood_by_distance(sample_adata_subset, figures_dir)

# Copy the plaque columns back into the full dataset. The values are taken
# from the subset rather than from the raw distance arrays so that spots
# without a plaque carry the same -1 sentinel in both objects instead of inf.
# Positional copy (to_numpy) is valid because the subset was sliced from
# sample_adata with this same condition, so the row order matches.
mask = sample_adata.obs['dataset'] == args.plaque_sample
sample_adata.obs.loc[mask, 'nearest_plaque_distance'] = (
    sample_adata_subset.obs['nearest_plaque_distance'].to_numpy())
sample_adata.obs.loc[mask, 'nearest_plaque_id'] = (
    sample_adata_subset.obs['nearest_plaque_id'].to_numpy())

## Save final results

In [None]:
# Save the full annotated dataset
final_results_path = os.path.join(data_dir, 'sample_adata_with_plaque_analysis.h5ad')
sample_adata.write_h5ad(final_results_path)
print(f"Final results saved to: {final_results_path}")

# Save the plaque table
plaque_df.to_csv(os.path.join(data_dir, f'plaque_information_{args.plaque_sample}.csv'),
                 index=False)

# Save the distance analysis when it produced anything. default=str turns
# values json cannot encode natively into strings; the explicit encoding
# keeps the file independent of the platform default.
if analysis_results:
    results_json_path = os.path.join(
        data_dir, f'distance_analysis_results_{args.plaque_sample}.json')
    with open(results_json_path, 'w', encoding='utf-8') as f:
        json.dump(analysis_results, f, indent=2, default=str)

print("\n=== Analysis Complete ===")
print(f"Final dataset shape: {sample_adata.shape}")
print(f"Plaques identified: {len(plaque_df)}")
# .get() keeps a result without a 'significant' entry from raising KeyError
if analysis_results and analysis_results.get('significant'):
    print(f"✓ Significant difference found (p={analysis_results['p_value']:.6f})")
    print(f"  Close to plaque mean: {analysis_results['close_mean']:.4f}")
    print(f"  Far from plaque mean: {analysis_results['far_mean']:.4f}")
else:
    print("No significant difference found between close/far groups")