# Graph-Conditioned Diffusion Model Fine-tuning (Spatial Transcriptomics) - Precomputed Graphs

## Cell 1: Imports and Working Directory

In [32]:
import os
import random
import glob
import datetime
import time
import sys
import subprocess
import shlex
import pickle
# Pin the kernel's working directory to the folder that contains this notebook.
# This has to happen before the helper imports further down: the original note
# says it is what makes the local imports resolve (presumably because the
# current directory is on sys.path -- confirm for your kernel). The script paths
# in TrainingConfig ("preprocess_graphs.py", "train_graph_diffusion_ddp.py") are
# also relative and are checked with os.path.exists against this directory.
# NOTE(review): the path is hard-coded to one user's checkout; adjust it when
# running the notebook from a different location.
working_dir = os.path.dirname("/home1/jijh/diffusion_project/ADiffusion/graph_transformer_trial/graph_transformer.ipynb")
os.chdir(working_dir)
import torch
import numpy as np
import matplotlib.pyplot as plt
# Choose the progress-bar implementation: prefer the widget-based bar from
# tqdm.notebook and fall back to tqdm.auto when that import fails.
# NOTE(review): IS_NOTEBOOK only records that `tqdm.notebook` was importable;
# that may not prove a notebook front end is attached -- confirm before using
# the flag for anything other than logging.
try:
    from tqdm.notebook import tqdm as notebook_tqdm
    IS_NOTEBOOK = True
    print("Notebook environment detected.")
except ImportError:
    from tqdm.auto import tqdm as notebook_tqdm
    IS_NOTEBOOK = False
    print("Using standard tqdm.")

# Project-local helper modules (dataset.py, models.py, graph_utils.py). They are
# expected to live in the working directory set above or elsewhere on sys.path.
# SpatialGraphDataset is the dataset variant that loads precomputed graph data.
try:
    from dataset import SpatialGraphDataset
    from models import FullModel
    # NOTE(review): not referenced by the cells visible in this notebook; the
    # original note says the precompute script still needs it.
    from graph_utils import get_sinusoidal_positional_encoding # Still needed by precompute script
    print("Helper modules imported successfully.")
except ImportError as e:
     print(f"ERROR importing helper modules: {e}. Make sure .py files are in the correct location.")
     # Re-raise so the notebook stops here instead of failing later with NameError.
     raise e

%matplotlib inline


Notebook environment detected.
Helper modules imported successfully.


## Cell 2: Configuration Class

In [33]:
class TrainingConfig:
    """Settings for the graph-conditioned diffusion fine-tuning run.

    All values are plain class attributes, so they can be read from the class
    itself or from an instance. Two classmethods turn the settings into
    command-line arguments for the external preprocessing and DDP training
    scripts.
    """

    # --- Sample Filtering ---
    # Whitelist of sample IDs, laid out by cohort prefix (order is significant
    # because it is forwarded verbatim as --allowed_ids).
    # Set to None or [] to process all.
    ALLOWED_SAMPLE_IDS = [
        # TENX
        'TENX159', 'TENX138', 'TENX131', 'TENX88', 'TENX87', 'TENX86', 'TENX85', 'TENX84', 'TENX83', 'TENX82',
        'TENX80', 'TENX79', 'TENX78', 'TENX77', 'TENX73', 'TENX69', 'TENX67', 'TENX61', 'TENX60', 'TENX58',
        'TENX56', 'TENX55', 'TENX54', 'TENX52', 'TENX43', 'TENX31', 'TENX30', 'TENX27', 'TENX19', 'TENX18',
        # ZEN
        'ZEN61', 'ZEN60',
        # SPA
        'SPA15', 'SPA14', 'SPA13', 'SPA12', 'SPA11', 'SPA10', 'SPA9', 'SPA8', 'SPA7', 'SPA6', 'SPA5', 'SPA4',
        # MEND
        'MEND131', 'MEND130', 'MEND129', 'MEND124', 'MEND123', 'MEND78', 'MEND77', 'MEND76',
        'MEND75', 'MEND74', 'MEND73', 'MEND72', 'MEND71', 'MEND68', 'MEND67', 'MEND66', 'MEND65', 'MEND64',
        'MEND63', 'MEND55', 'MEND53', 'MEND50', 'MEND46', 'MEND44', 'MEND43', 'MEND42',
        # NCBI
        'NCBI809', 'NCBI808', 'NCBI807', 'NCBI806', 'NCBI802', 'NCBI801', 'NCBI800', 'NCBI799',
        'NCBI720', 'NCBI719', 'NCBI718', 'NCBI717', 'NCBI716', 'NCBI715',
        'NCBI671', 'NCBI670', 'NCBI669', 'NCBI668', 'NCBI667', 'NCBI666', 'NCBI665', 'NCBI664', 'NCBI663', 'NCBI662',
        'NCBI661', 'NCBI660', 'NCBI659', 'NCBI658', 'NCBI657', 'NCBI656', 'NCBI655', 'NCBI654', 'NCBI653',
        'NCBI641', 'NCBI640', 'NCBI639', 'NCBI638', 'NCBI637', 'NCBI636', 'NCBI635', 'NCBI634', 'NCBI633', 'NCBI632',
        'NCBI631', 'NCBI630', 'NCBI629', 'NCBI628',
        'NCBI533', 'NCBI532', 'NCBI531', 'NCBI530', 'NCBI529', 'NCBI528', 'NCBI527',
        'NCBI410', 'NCBI409', 'NCBI408', 'NCBI407', 'NCBI406', 'NCBI405', 'NCBI404', 'NCBI403', 'NCBI402', 'NCBI401',
        'NCBI400', 'NCBI399', 'NCBI398', 'NCBI397', 'NCBI396', 'NCBI395', 'NCBI394', 'NCBI393', 'NCBI392', 'NCBI391',
        'NCBI390', 'NCBI389', 'NCBI388', 'NCBI387', 'NCBI386', 'NCBI385', 'NCBI384', 'NCBI383', 'NCBI382', 'NCBI381',
        'NCBI380', 'NCBI379', 'NCBI378', 'NCBI377', 'NCBI376', 'NCBI375', 'NCBI374', 'NCBI373', 'NCBI372', 'NCBI371',
        'NCBI370', 'NCBI369', 'NCBI368', 'NCBI367', 'NCBI366', 'NCBI365', 'NCBI364', 'NCBI363', 'NCBI362', 'NCBI361',
        'NCBI360', 'NCBI359', 'NCBI358', 'NCBI357', 'NCBI356', 'NCBI355', 'NCBI354', 'NCBI353', 'NCBI352', 'NCBI351',
        'NCBI350', 'NCBI349', 'NCBI348', 'NCBI347', 'NCBI346', 'NCBI345', 'NCBI344', 'NCBI343', 'NCBI342', 'NCBI341',
        'NCBI340', 'NCBI339', 'NCBI338', 'NCBI337', 'NCBI336',
        # MISC
        'MISC12', 'MISC11', 'MISC10', 'MISC9', 'MISC8', 'MISC7', 'MISC6', 'MISC5', 'MISC4', 'MISC3', 'MISC2', 'MISC1',
    ]

    # --- Hardware & Precision ---
    GPU_IDS = [0, 1, 2]  # <-- CHANGE IF NEEDED
    PRIMARY_GPU_ID = GPU_IDS[0] if GPU_IDS else 0
    # Zero when CUDA is unavailable or no GPU ids are listed.
    NUM_GPUS = len(GPU_IDS) if torch.cuda.is_available() and GPU_IDS else 0
    PRIMARY_DEVICE_NAME = "cpu" if NUM_GPUS == 0 else f"cuda:{PRIMARY_GPU_ID}"
    MIXED_PRECISION_TYPE = "bf16"
    DDP_MASTER_PORT = 29504

    # --- Data Paths ---
    latent_dir = "/cwStorage/nodecw_group/jijh/hest_output_latents_bf16"
    expr_base_dir = "/cwStorage/nodecw_group/jijh/hest_output"
    # Output locations get a "_filtered" suffix whenever a sample whitelist is active.
    _filter_suffix = "" if not ALLOWED_SAMPLE_IDS else "_filtered"
    graph_data_dir = f"/cwStorage/nodecw_group/jijh/hest_graph_data{_filter_suffix}"
    valid_spots_path = os.path.join(graph_data_dir, f"valid_spots_flat{_filter_suffix}.pkl")

    # --- Output / Model Paths ---
    checkpoint_dir = "/cwStorage/nodecw_group/jijh/model_path/"
    log_dir = f"/cwStorage/nodecw_group/jijh/training_log/logs_graph_diffusion_precomp{_filter_suffix}"
    final_model_save_path = os.path.join(checkpoint_dir, f"graph_unet_precomp{_filter_suffix}_final.pt")

    # --- Preprocessing Parameters ---
    num_raw_genes = None  # None is forwarded to the preprocessing script as -1
    pca_target_dim = 50
    normalization_scale_factor = 1e4
    pos_encoding_dim = 128
    graph_k_neighbors = 8
    preprocessing_n_jobs = 8
    force_rebuild_graphs = False

    # --- Model Parameters ---
    vae_model_path = "/cwStorage/nodecw_group/jijh/model_path/finetuned_taesd_v21_notebook_apr2.pt"
    vae_sd_version = 'v2.1'
    unet_sample_size = 64
    unet_in_channels = 4
    unet_out_channels = 4
    unet_block_out_channels = (320, 640, 1280, 1280)
    unet_down_block_types = ('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D')
    unet_up_block_types = ('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D')
    unet_cross_attention_dim = 768
    pretrained_unet_checkpoint_path = "/cwStorage/nodecw_group/jijh/model_path/unet_ddp_bf16_ep15_bs32x3_lr0.0001_acc4.pt"
    graph_transformer_hidden_dim = 768
    graph_transformer_dim = 768
    graph_transformer_layers = 6
    graph_transformer_heads = 8

    # --- Fine-tuning Parameters ---
    batch_size_per_gpu = 16
    epochs = 20
    lr_unet = 1e-5
    lr_graph = 1e-4
    optimizer_type = "AdamW"
    accumulation_steps = 4
    scheduler_train_timesteps = 1000

    # --- Logging & Saving ---
    num_workers = 16
    checkpoint_filename_prefix = f"graph_unet_precomp{_filter_suffix}_ddp"
    log_interval = 20
    sample_interval = 1
    save_interval = 1
    sampling_inference_steps = 30
    sampling_batch_size = 8

    # --- Script Paths (relative to the current working directory) ---
    preprocess_script_path = "preprocess_graphs.py"
    train_script_path = "train_graph_diffusion_ddp.py"

    @classmethod
    def get_preprocess_cmd(cls):
        """Assemble the argv list that launches the preprocessing script.

        Returns:
            list[str] | None: ``[interpreter, script, "--flag=value", ...]``,
            or ``None`` when ``preprocess_script_path`` does not exist.
        """
        script = cls.preprocess_script_path
        if not os.path.exists(script):
            return None

        # An unset gene count is passed on the command line as -1.
        raw_gene_count = -1 if cls.num_raw_genes is None else cls.num_raw_genes
        options = (
            ("latent_dir", cls.latent_dir),
            ("expr_base_dir", cls.expr_base_dir),
            ("output_graph_dir", cls.graph_data_dir),
            ("valid_spots_output_path", cls.valid_spots_path),
            ("num_raw_genes", raw_gene_count),
            ("pca_target_dim", cls.pca_target_dim),
            ("normalization_scale_factor", cls.normalization_scale_factor),
            ("pos_encoding_dim", cls.pos_encoding_dim),
            ("graph_k_neighbors", cls.graph_k_neighbors),
            ("n_jobs", cls.preprocessing_n_jobs),
        )

        command = [sys.executable, script]
        command.extend(f"--{flag}={value}" for flag, value in options)

        # Optional switches are appended only when enabled / non-empty.
        if cls.force_rebuild_graphs:
            command.append("--force_rebuild")
        if cls.ALLOWED_SAMPLE_IDS:
            command.append("--allowed_ids=" + ",".join(cls.ALLOWED_SAMPLE_IDS))
        return command

    @classmethod
    def get_train_script_args(cls):
        """Assemble the ``--flag=value`` arguments for the DDP training script.

        Returns:
            list[str]: one entry per option, in a fixed order. Tuple-valued
            settings are flattened into comma-separated strings.
        """
        block_channels = ",".join(map(str, cls.unet_block_out_channels))
        down_blocks = ",".join(cls.unet_down_block_types)
        up_blocks = ",".join(cls.unet_up_block_types)

        options = (
            # data and output locations
            ("graph_data_dir", cls.graph_data_dir),
            ("valid_spots_path", cls.valid_spots_path),
            ("latent_dir", cls.latent_dir),
            ("checkpoint_dir", cls.checkpoint_dir),
            ("log_dir", cls.log_dir),
            ("pretrained_unet_checkpoint_path", cls.pretrained_unet_checkpoint_path),
            ("vae_model_path", cls.vae_model_path),
            ("vae_sd_version", cls.vae_sd_version),
            # feature dimensions
            ("pca_target_dim", cls.pca_target_dim),
            ("pos_encoding_dim", cls.pos_encoding_dim),
            # UNet architecture
            ("unet_sample_size", cls.unet_sample_size),
            ("unet_in_channels", cls.unet_in_channels),
            ("unet_out_channels", cls.unet_out_channels),
            ("unet_block_out_channels", block_channels),
            ("unet_down_block_types", down_blocks),
            ("unet_up_block_types", up_blocks),
            ("unet_cross_attention_dim", cls.unet_cross_attention_dim),
            # graph transformer architecture
            ("graph_transformer_hidden_dim", cls.graph_transformer_hidden_dim),
            ("graph_transformer_dim", cls.graph_transformer_dim),
            ("graph_transformer_layers", cls.graph_transformer_layers),
            ("graph_transformer_heads", cls.graph_transformer_heads),
            # optimisation
            ("epochs", cls.epochs),
            ("batch_size_per_gpu", cls.batch_size_per_gpu),
            ("lr_unet", cls.lr_unet),
            ("lr_graph", cls.lr_graph),
            ("optimizer_type", cls.optimizer_type),
            ("accumulation_steps", cls.accumulation_steps),
            ("mixed_precision", cls.MIXED_PRECISION_TYPE),
            ("scheduler_train_timesteps", cls.scheduler_train_timesteps),
            # logging, checkpointing and sampling
            ("num_workers", cls.num_workers),
            ("checkpoint_filename_prefix", cls.checkpoint_filename_prefix),
            ("log_interval", cls.log_interval),
            ("sample_interval", cls.sample_interval),
            ("save_interval", cls.save_interval),
            ("sampling_inference_steps", cls.sampling_inference_steps),
            ("sampling_batch_size", cls.sampling_batch_size),
        )
        return [f"--{flag}={value}" for flag, value in options]

# %% Cell 3: Instantiate Config
config = TrainingConfig()
# The preprocessing and training scripts write into these folders, so create
# any that are missing up front (existing ones are left untouched).
for _output_dir in (config.graph_data_dir, config.checkpoint_dir, config.log_dir):
    os.makedirs(_output_dir, exist_ok=True)


## Cell 4: Run Preprocessing Script

In [29]:
# Run the external preprocessing script once, unless its outputs already exist.
# Expected outputs: the valid-spots pickle and at least one "*.pt" graph file
# in config.graph_data_dir.
print("--- Running Preprocessing Script ---")
preprocess_cmd = config.get_preprocess_cmd()

# get_preprocess_cmd() returns None when the script file cannot be found.
if not preprocess_cmd:
    print(f"Error: Preprocessing script '{config.preprocess_script_path}' not found.")
    raise FileNotFoundError("Preprocessing script not found.")

graph_file_pattern = os.path.join(config.graph_data_dir, "*.pt")

# Re-run when forced, or when EITHER expected artifact is missing. The success
# check after a run requires both the valid-spots file and graph files, so the
# skip guard has to require both too; otherwise a directory holding only the
# pickle would be skipped silently.
run_preprocessing = (
    config.force_rebuild_graphs
    or not os.path.exists(config.valid_spots_path)
    or not glob.glob(graph_file_pattern)
)

if run_preprocessing:
    print("Executing command:")
    print(shlex.join(preprocess_cmd))
    print("-" * 30)
    start_time = time.time()
    # Output is captured and printed after the script exits (not streamed live).
    result = subprocess.run(preprocess_cmd, capture_output=True, text=True, encoding='utf-8', errors='replace')
    end_time = time.time()
    print(f"Preprocessing script finished in {end_time - start_time:.2f} seconds.")
    print("--- Preprocessing Script STDOUT ---")
    print(result.stdout)
    print("--- Preprocessing Script STDERR ---")
    print(result.stderr)

    if result.returncode != 0:
        print(f"ERROR: Preprocessing script failed with exit code {result.returncode}.")
        raise RuntimeError("Preprocessing script failed.")

    print("Preprocessing script completed successfully.")
    # A zero exit code is not enough: verify both artifacts were produced.
    if not os.path.exists(config.valid_spots_path):
        raise FileNotFoundError(f"Preprocessing script finished but valid spots file is missing: {config.valid_spots_path}")
    if not glob.glob(graph_file_pattern):
        raise FileNotFoundError(f"Preprocessing script finished but no graph files found in: {config.graph_data_dir}")
else:
    print(f"Skipping preprocessing. Required output file '{config.valid_spots_path}' already exists.")
    print("Set config.force_rebuild_graphs = True to re-run.")

--- Running Preprocessing Script ---
Executing command:
/public/home/jijh/micromamba/envs/gpu_env/bin/python preprocess_graphs.py --latent_dir=/cwStorage/nodecw_group/jijh/hest_output_latents_bf16 --expr_base_dir=/cwStorage/nodecw_group/jijh/hest_output --output_graph_dir=/cwStorage/nodecw_group/jijh/hest_graph_data_filtered --valid_spots_output_path=/cwStorage/nodecw_group/jijh/hest_graph_data_filtered/valid_spots_flat_filtered.pkl --num_raw_genes=-1 --pca_target_dim=50 --normalization_scale_factor=10000.0 --pos_encoding_dim=128 --graph_k_neighbors=8 --n_jobs=8 --allowed_ids=TENX159,TENX138,TENX131,TENX88,TENX87,TENX86,TENX85,TENX84,TENX83,TENX82,TENX80,TENX79,TENX78,TENX77,TENX73,TENX69,TENX67,TENX61,TENX60,TENX58,TENX56,TENX55,TENX54,TENX52,TENX43,TENX31,TENX30,TENX27,TENX19,TENX18,ZEN61,ZEN60,SPA15,SPA14,SPA13,SPA12,SPA11,SPA10,SPA9,SPA8,SPA7,SPA6,SPA5,SPA4,MEND131,MEND130,MEND129,MEND124,MEND123,MEND78,MEND77,MEND76,MEND75,MEND74,MEND73,MEND72,MEND71,MEND68,MEND67,MEND66,MEND65,ME

## Cell 5: Optional - Test Dataset Loading

In [34]:
# Smoke test: build the dataset from the precomputed files and fetch one item.
print("\n--- Testing Dataset Loading (Precomputed) ---")

# PyTorch Geometric is only needed for the type check below. Resolve it on its
# own so that an ImportError raised while fetching an item is not misreported
# as "PyTorch Geometric not found".
try:
    from torch_geometric.data import Data
except ImportError:
    Data = None

try:
    test_config = TrainingConfig() # Use current config
    start_time = time.time()
    test_dataset = SpatialGraphDataset(test_config)
    end_time = time.time()
    print(f"Dataset loaded in {end_time - start_time:.2f} seconds.")
    print(f"Number of valid spots: {len(test_dataset)}")
    print(f"Detected node feature dimension: {test_dataset.node_feature_dim}")

    if len(test_dataset) > 0:
        print("\nTesting __getitem__...")
        if Data is None:
            print("PyTorch Geometric not found. Cannot fully test __getitem__ returning Data object.")
        else:
            try:
                item = test_dataset[0]
                print("Item retrieved successfully.")
                print("Item type:", type(item))
                if isinstance(item, Data):
                    # `Data.keys` is a property in some PyG releases and a
                    # method in others; handle both so the key names are
                    # printed rather than a bound-method repr.
                    item_keys = item.keys() if callable(item.keys) else item.keys
                    print("Item keys:", list(item_keys))
                    print("Node features shape:", item.x.shape)
                    print("Target latent shape:", item.y.shape)
                else:
                    print("Retrieved item is not a PyG Data object.")
            except Exception as e:
                print(f"Error during __getitem__ test: {e}")
                raise

except Exception as e:
    print(f"Error during dataset test: {e}")
    print("Check paths to precomputed data and valid spots file.")
    # Bare raise re-raises the active exception with its original traceback.
    raise



--- Testing Dataset Loading (Precomputed) ---
Loading valid spots list from: /cwStorage/nodecw_group/jijh/hest_graph_data_filtered/valid_spots_flat_filtered.pkl
Loaded 122423 valid spot entries.
Found 41 precomputed graph files.
Detected node feature dimension from graphs: 178
Dataset loaded in 0.05 seconds.
Number of valid spots: 122423
Detected node feature dimension: 178

Testing __getitem__...
ERROR: Failed loading latent file /cwStorage/nodecw_group/jijh/hest_output_latents_bf16/MEND123/MEND123_28618_24487.pt for idx 0: [Errno 2] No such file or directory: '/cwStorage/nodecw_group/jijh/hest_output_latents_bf16/MEND123/MEND123_28618_24487.pt'
Error during __getitem__ test: [Errno 2] No such file or directory: '/cwStorage/nodecw_group/jijh/hest_output_latents_bf16/MEND123/MEND123_28618_24487.pt'
Error during dataset test: [Errno 2] No such file or directory: '/cwStorage/nodecw_group/jijh/hest_output_latents_bf16/MEND123/MEND123_28618_24487.pt'
Check paths to precomputed data and va

FileNotFoundError: [Errno 2] No such file or directory: '/cwStorage/nodecw_group/jijh/hest_output_latents_bf16/MEND123/MEND123_28618_24487.pt'

## Cell 6: Optional - Test Model Instantiation

In [None]:
# Smoke test: build FullModel from the config, optionally load the pretrained
# UNet weights, and report parameter counts.
print("\n--- Testing Model Instantiation ---")
try:
    model_config = TrainingConfig()
    # Optional hook, currently disabled: pass the node feature dimension
    # detected by the dataset test on to the model via the config object.
    # setattr(model_config, 'graph_input_dim', test_dataset.node_feature_dim)
    test_model = FullModel(model_config)
    print("FullModel initialized.")

    unet_ckpt_path = model_config.pretrained_unet_checkpoint_path
    if not (unet_ckpt_path and os.path.exists(unet_ckpt_path)):
        print("Pretrained UNet path not specified or not found.")
    else:
        print(f"Loading UNet weights from: {unet_ckpt_path}")
        test_model.load_unet_weights(unet_ckpt_path)

    # Single pass over the parameters, tallying all and trainable elements.
    total_params = 0
    trainable_params = 0
    for _param in test_model.parameters():
        _count = _param.numel()
        total_params += _count
        if _param.requires_grad:
            trainable_params += _count
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

except ImportError as e:
    # Reported but not re-raised: a missing optional layer package should not
    # abort the rest of the notebook.
    print(f"Import Error during model test (likely PyG layers): {e}")
except Exception as e:
    print(f"Error during model test: {e}")
    raise e
