# Tutorial 1: Generate input for pre-training

In [None]:
from step0_spot_input_gen import generate_spot_inputs, load_spot_data_with_full_info
import os
import pickle
import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import LabelEncoder
from scipy.spatial import cKDTree
import h5py
from tqdm import tqdm
from argparse import ArgumentParser
import random
from omics.constants import *

## Initialize the args and fix seeds

Available dataset keys (use one as the `dataset` value in the config below):

seqFISH+ 3T3: seqfish

MERFISH U2-OS: merfish

MERFISH MOp: mop1_filtered

CosMx lung cancer: cosmx_lung5_rep1

Xenium breast cancer: xenium_hbc1

STARmap PLUS AD1: AD_2766g_m9498

STARmap PLUS AD2: AD_2766g_m11351

STARmap PLUS WT1: AD_2766g_m9494

STARmap PLUS WT2: AD_2766g_m11346

In [None]:
# Run settings; the whole mapping is unpacked as keyword arguments
# into generate_spot_inputs in the generation cell.
config = dict(
    dataset="cosmx_lung5_rep1",
    k_neighbors=50,
    data_pct=30,
    max_points=200,
    seed=42,
)


## Generate spot inputs

In [None]:
# Announce the step, then run the generator with every entry of `config`
# passed as a keyword argument.
print("Starting to generate spot input data...")
cache_path = generate_spot_inputs(**config)

# Report the value returned by the generator; later cells load from it.
cache_report = f"\nData cached to: {cache_path}"
print(cache_report)

## Test on the processed spots

In [None]:
# Reload the inputs from `cache_path` and print a set of sanity checks
# (array shapes, gene-encoding consistency, batch grouping).
print("Testing data loading...")
data, max_points, available_gene_idx, full_classes, selected_genes = load_spot_data_with_full_info(cache_path)

print("\n=== Loaded data info ===")
print("Data shapes:")
for key, value in data.items():
    # Every entry of `data` is expected to expose `.shape`.
    print(f"  {key}: {value.shape}")

print("\nLabel Encoder info:")
print(f"  Total genes: {len(full_classes)}")
print(f"  Selected genes: {len(selected_genes)}")
print(f"  Max points: {max_points}")
print(f"  Available gene indices: {len(available_gene_idx)}")

# Look the gene ids up once; they are used for three separate summaries.
gene_ids = data['gene_ids']
print("\nGene encoding consistency verification:")
print(f"  Unique gene_ids in data: {len(np.unique(gene_ids))}")
print(f"  gene_ids range: [{np.min(gene_ids)}, {np.max(gene_ids)}]")
print(f"  All selected genes exist in full_classes: {all(gene in full_classes for gene in selected_genes)}")

# Compute the unique batch ids once and reuse them for both the count and
# the listing instead of running np.unique twice over the same array.
unique_batches = np.unique(data['batch_ids'])
print("\nBatch grouping verification:")
print(f"  Unique batch count: {len(unique_batches)}")
print(f"  Batch list: {unique_batches}")