# Tutorial 3: Segmentation-free analysis

In [None]:
import numpy as np
import pandas as pd
import cv2
import argparse
import logging
from omics.constants import *
import os
import pickle
import re
import torch
import torch.utils.data as data
from copy import deepcopy
import random
from sklearn.preprocessing import LabelEncoder
from scipy.spatial import cKDTree
import scanpy as sc
from pytorch_lightning.callbacks import (EarlyStopping, LearningRateMonitor,
                                         ModelCheckpoint, Callback)
# import squidpy as sq
import torch.nn as nn
from math import exp
from sklearn.model_selection import train_test_split
import random
from argparse import Namespace
import h5py
from tqdm import tqdm
from omics.datasets.data_module import testDataModule as DataModule
import hashlib
import anndata as ad
from step1_pretrain import Omics, EpochCallback
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import (EarlyStopping, LearningRateMonitor,
                                         ModelCheckpoint)
from argparse import ArgumentParser
import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import cv2
import logging
import joblib
import umap
from sklearn.manifold import TSNE
import json
from visualize import *
# Directory against which the checkpoint, config and PCA model paths are
# resolved. `__file__` is not defined when this code runs as a notebook cell
# (the file is laid out as notebook cells), so fall back to the current
# working directory instead of failing with NameError.
try:
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    BASE_DIR = os.getcwd()

## Initialize configs and fix seed

For efficiency, the spots are processed in batches (selected with `batch_id` out of `total_batches`); process every batch, then merge the results together.

In [None]:
import os

# setup environment variables
# Only expose GPU 0 to CUDA for this run.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# parameters
# All run settings live in one Namespace; the whole namespace is also forwarded
# as keyword arguments when the model checkpoint is loaded.
args = Namespace(
    ckpt_path=os.path.join(BASE_DIR, 'model_checkpoint.ckpt'),
    dataset_name="xenium_hbc1",
    config=os.path.join(BASE_DIR, '../../configs/bert_config.json'),
    f=None,
    # The dataset-selection cell reads `args.fold` to look up the train/val
    # slice lists (e.g. for 'xenium_hbc1'); without it that cell fails with
    # AttributeError. NOTE(review): 0 is assumed to be a valid fold key - confirm.
    fold=0,
    split_slice=None,
    # Target canvas size; the crop bounds below are expressed on this canvas
    # and are converted to data coordinates in the preprocessing cell.
    target_width=7524,
    target_height=5469,
    x_min=0,
    x_max=7524,
    y_min=0,
    y_max=5469,
    sample_ratio=0.1,
    merge_threshold=0.5,
    percentile=70,
    pca_model_path=os.path.join(BASE_DIR, 'xenium_hbc1_None_-1_7524_4_5469_pca.pkl'),
    color_model_path=os.path.join(BASE_DIR, 'xenium_hbc1_None_-1_7524_4_5469_pca_color.pkl'),
    batch_id=0,  # (0-9)
    total_batches=10,
    radius=20,
    linear_hidden_dim='256',
    max_points=20,
)

print(f"Processing batch {args.batch_id + 1}/{args.total_batches}")
print(f"ckpt_path: {args.ckpt_path}")
print(f"Dataset: {args.dataset_name}")

## Specify dataset

In [None]:
# Load the spot table for the dataset named in `args.dataset_name`.
# `split_slice` stays None unless the dataset defines fold-based slice lists;
# a non-empty value is later used to filter rows by their `batch` column.
# Raises NotImplementedError for an unknown dataset name.
split_slice = None
if args.dataset_name == 'mop1':
    adata = pd.read_pickle(MOUSE_SPOT_DATA1)
    # Bring the column names in line with the ones the rest of the pipeline reads.
    adata.rename(columns={'target_molecule_name': 'gene', 'cell_index': 'nucleus', 'slice_id': 'batch', 'global_x': 'x', 'global_y': 'y'}, inplace=True)
    adata['gene'] = adata['gene'].astype('category')
elif args.dataset_name == 'merfish':
    adata = pd.read_pickle(MERFISH_DATA1)
    split_slice = MERFISH_FOLDS[args.fold]['train']
elif args.dataset_name == 'seqfish':
    adata = pd.read_pickle(SEQFISH_DATA1)
elif args.dataset_name == 'AD_64g_m9721':
    adata = pd.read_pickle(AD_64g_9721_DATA)
    adata.rename(columns={'slice': 'batch'}, inplace=True)
elif args.dataset_name == 'AD_64g_m9781':
    adata = pd.read_pickle(AD_64g_9781_DATA)
    adata.rename(columns={'slice': 'batch'}, inplace=True)
elif args.dataset_name == 'AD_64g_m9919':
    adata = pd.read_pickle(AD_64g_9919_DATA)
elif args.dataset_name == 'AD_64g_m9930':
    adata = pd.read_pickle(AD_64g_9930_DATA)
elif args.dataset_name == 'AD_2766g_m9707':
    adata = pd.read_pickle(AD_2766g_9707_DATA)
elif args.dataset_name == 'AD_2766g_m9735':
    adata = pd.read_pickle(AD_2766g_9735_DATA)
elif args.dataset_name == 'AD_2766g_m9723':
    adata = pd.read_pickle(AD_2766g_9723_DATA)
elif args.dataset_name == 'AD_2766g_m9494':
    adata = pd.read_pickle(AD_2766g_9494_DATA)
elif args.dataset_name == 'AD_2766g_m11346':
    adata = pd.read_pickle(AD_2766g_11346_DATA)
elif args.dataset_name == 'AD_2766g_m11351':
    adata = pd.read_pickle(AD_2766g_11351_DATA)
elif args.dataset_name == 'AD_2766g_m9723_2':
    adata = pd.read_pickle(AD_2766g_9723_2_DATA)
elif args.dataset_name == 'xenium_hbc1':
    adata = pd.read_pickle(Xenium_hbc_rep1_DATA)
    # Build a new list from train + val. Using `+=` on the looked-up list would
    # extend the list stored inside XENIUM_HBC_FOLDS1 in place, so every re-run
    # of this cell would append the validation slices to the constant again.
    split_slice = (list(XENIUM_HBC_FOLDS1[args.fold]['train'])
                   + list(XENIUM_HBC_FOLDS1[args.fold]['val']))
elif args.dataset_name == 'xenium_hbc1_rep2':
    adata = pd.read_pickle(Xenium_hbc_rep2_DATA)
    # Same as above: copy instead of mutating the shared fold table.
    split_slice = (list(XENIUM_HBC_FOLDS1[args.fold]['train'])
                   + list(XENIUM_HBC_FOLDS1[args.fold]['val']))
else:
    raise NotImplementedError

## Preprocessing data

In [None]:
# Data preprocessing
# An explicit `args.split_slice` overrides the slice list chosen with the
# dataset; purely numeric values are converted to int.
if args.split_slice:
    requested_slice = args.split_slice
    split_slice = [int(requested_slice)] if requested_slice.isdigit() else [requested_slice]

# Keep only the rows whose `batch` is one of the selected slices.
if split_slice:
    adata = adata.query("batch.isin(@split_slice)")

# step1: determine resolution and process spots
total_minx, total_maxx, total_miny, total_maxy = get_min_max_coordinates(adata)
scale_x, scale_y, pixel_size_x, pixel_size_y = determine_visualization_resolution(
    total_minx, total_maxx, total_miny, total_maxy, args.target_width, args.target_height)

# coordinate transformation
# The crop bounds are given on the target canvas; turn each one into a
# fraction of the canvas size and map that fraction onto the data extent.
# NOTE: the bounds are overwritten in place, so running this cell a second
# time rescales values that have already been converted.
x_span = total_maxx - total_minx
y_span = total_maxy - total_miny

args.x_min = int(total_minx + ((args.x_min / args.target_width) * x_span))
args.x_max = int(total_minx + ((args.x_max / args.target_width) * x_span))
args.y_min = int(total_miny + ((args.y_min / args.target_height) * y_span))
args.y_max = int(total_miny + ((args.y_max / args.target_height) * y_span))

# filter data
# Restrict the spots to the (inclusive) crop rectangle.
crop_is_defined = all(
    bound is not None
    for bound in (args.x_min, args.x_max, args.y_min, args.y_max)
)
if crop_is_defined:
    adata = adata.query("(x >= @args.x_min) & (x <= @args.x_max) & (y >= @args.y_min) & (y <= @args.y_max)")

# process spots in batches
# Pick the subset of spots belonging to this run's batch.
patch_adata = select_batch_spots(adata, args.batch_id, args.total_batches)

## Loading dataset

In [None]:
# set label encoder
label_encoder = LabelEncoder()
label_encoder.fit(adata['gene'].values)
annotation_num = patch_adata['gene'].unique().shape[0]

# create data loader
val_split = []
epoch_callback = EpochCallback()
dm = DataModule(SpatialRadiusDataset, my_collate_fn,
                data_pct=1.0, batch_size=1000, 
                num_workers=8, mask_ratio=0, radius=20,
                mask_function='random', dataset_name=args.dataset_name,
                max_points=20, train_split=split_slice, val_split=val_split,
                label_type='pretrain',callback=epoch_callback,
                x_min=args.x_min, x_max=args.x_max, y_min=args.y_min, y_max=args.y_max,
                adata=patch_adata, batch_id=args.batch_id)
dl = dm.inference_train_dataloader()

## Loading model

In [None]:
# Show which config file is in use, then restore the model from the checkpoint.
print(args.config)

# Every entry of `args` is forwarded as a keyword argument; strict=False is
# passed through to the checkpoint loader.
hparams = vars(args)
model = Omics.load_from_checkpoint(checkpoint_path=args.ckpt_path, strict=False, **hparams)

# Switch to evaluation mode and move the model to the GPU.
model.eval()
model = model.to('cuda')

## Generating spot embeddings

In [None]:
# Collect the representations produced by `get_rep` for the inference
# dataloader and the loaded model.
train_rep = get_rep(dl, model)

# create save prefix and stats file path
# The prefix encodes the dataset, the slice override and the crop bounds.
os.makedirs("vis_results", exist_ok=True)
save_prefix = f"vis_results/{args.dataset_name}_{args.split_slice}_{args.x_min}_{args.x_max}_{args.y_min}_{args.y_max}"

# unified stats file path
stats_file_path = f"global_stats_{args.dataset_name}_{args.percentile}.json"

# generate embeddings (using PCA + percentile clipping and normalization)
embedding_options = dict(
    pca_dim=50,
    percentile=args.percentile,
    pca_model_path=args.pca_model_path,
    color_model_path=args.color_model_path,
    save_prefix=save_prefix,
    batch_id=args.batch_id,
    stats_file_path=stats_file_path,
)
(pca_embeddings,
 normalized_embeddings,
 pca_model,
 pca_color_model) = generate_embeddings_with_pca_percentile_clipping(train_rep, **embedding_options)


## Visualization generation and saving

In [None]:
# visualize and save embeddings
output_image, coord_color_file = visualize_and_save_embeddings(
    total_minx, total_maxx, total_miny, total_maxy,
    patch_adata['x'], patch_adata['y'], patch_adata['gene'], patch_adata['cell'],
    normalized_embeddings, save_prefix, args.target_width, args.target_height, args.batch_id)

## Generating adata with colors

In [None]:
# Sample and create adata
final_adata = sample_spots_and_create_adata(
    patch_adata, pca_embeddings, normalized_embeddings, 
    sample_ratio=args.sample_ratio, save_prefix=save_prefix, batch_id=args.batch_id)

## Saving results

In [None]:
# step5: organize results for delivery
results_dir = organize_results_for_delivery(
    save_prefix, args.dataset_name, args.split_slice, 
    args.x_min, args.x_max, args.y_min, args.y_max, args.batch_id)

print(f"\n=== Finished ===")
print(f"Results directory: {results_dir}")
print(f"Final spot number: {len(patch_adata)}")
print(f"AnnData shape: {final_adata.shape}")
print(f"Stats file: {stats_file_path}")