In [6]:
# Cell 1: Imports and Configuration
import os
import sys
import random
import glob
import datetime
import time
import subprocess
import shlex
import re
import warnings
from pathlib import Path
import io

import numpy as np
import PIL.Image
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

from torch_geometric.data import Data, Batch
from torch_geometric.nn import GPSConv, GATConv

from diffusers import AutoencoderTiny, UNet2DConditionModel, DDPMScheduler

# Install a blanket "ignore" filter: no message/category is given, so every
# warning raised in this kernel process after this point is suppressed.
# NOTE(review): this also hides deprecation/numerical warnings — confirm that
# silencing everything (rather than specific categories) is intended.
warnings.filterwarnings('ignore')

# --- Configuration Class for Diffusion Training V3 ---
class TrainingConfigDiffusion:
    """Settings for the conditioned-diffusion (V3) DDP training run.

    Every setting is a plain class attribute. ``get_script_args`` converts the
    attributes into ``--name=value`` command-line arguments for the training
    script located at ``TRAIN_SCRIPT_PATH``.
    """

    # --- Hardware & Precision ---
    GPU_IDS = [0, 1]
    # Evaluates to 0 when CUDA is not available on this machine.
    NUM_GPUS = len(GPU_IDS) if torch.cuda.is_available() else 0
    DDP_MASTER_PORT = 29505
    MIXED_PRECISION_TYPE = "bf16"

    # --- Data & Model Paths ---
    # [IMPORTANT] This MUST point to your new directory with preprocessed graphs (containing data.latent).
    GRAPH_DATA_DIR = "/cwStorage/nodecw_group/jijh/hest_graph_data_with_vae_latents"      # <--- MUST be changed
    CHECKPOINT_DIR = "/cwStorage/nodecw_group/jijh/model_path/conditioned_diffusion_v1_sampleN2"      # <--- MUST be changed
    LOG_DIR = "/cwStorage/nodecw_group/jijh/training_log/conditioned_diffusion_v1_sampleN2"                # <--- MUST be changed
    VAE_MODEL_PATH = "/cwStorage/nodecw_group/jijh/model_path/finetuned_taesd_v21_notebook_apr2.pt"              # <--- MUST be changed

    # <<< --- NEW CONFIGURATION PARAMETERS START --- >>>
    # [IMPORTANT] Path to your pre-trained conditioner. Set to None if you want to train from scratch
    # (the --pretrained_conditioner_path flag is then omitted from the generated arguments).
    PRETRAINED_CONDITIONER_PATH = "/cwStorage/nodecw_group/jijh/model_path/clip_preprocessed_v2/clip_graph_gigapath_preprocessed_ep50_step350_bs64x2_lr0.0001.pt"
    # [IMPORTANT] Set to True to freeze the conditioner and only train the UNet.
    # Set to False to fine-tune both the conditioner and the UNet.
    FREEZE_CONDITIONER = True
    # <<< --- NEW CONFIGURATION PARAMETERS END --- >>>

    # Pre-trained UNet weights. Set to None to omit the --pretrained_unet_path flag.
    PRETRAINED_UNET_PATH = "/cwStorage/nodecw_group/jijh/model_path/unet_ddp_bf16_ep15_bs32x3_lr0.0001_acc4.pt"

    # [NEW] Optional file to specify which samples to train on
    # Set to a file path (e.g., "/path/to/my_sample_ids.txt") or leave as None to train on all.
    TRAIN_SAMPLE_ID_FILE = "/home1/jijh/diffusion_project/ADiffusion/test_sample_ids.txt"                                    # <--- (optional) change

    # Path to the new training script
    TRAIN_SCRIPT_PATH = "/home1/jijh/diffusion_project/ADiffusion/src/pipeline/train_condition_diffusion_ddp_v3.py" # <--- MUST be changed

    # --- Model Architecture (ensure these match your data) ---
    CONDITIONER_INPUT_DIM = 50
    CONDITIONER_HIDDEN_DIM = 256
    CONDITIONER_OUTPUT_DIM = 768
    CONDITIONER_N_LAYERS = 4
    CONDITIONER_N_HEADS = 4
    CONDITIONER_ATTN_DROPOUT = 0.1
    UNET_SAMPLE_SIZE = 32
    UNET_IN_CHANNELS = 4
    UNET_OUT_CHANNELS = 4
    # Comma-separated lists; passed to the script verbatim as single string arguments.
    UNET_BLOCK_OUT_CHANNELS = "320,640,1280,1280"
    UNET_DOWN_BLOCK_TYPES = "DownBlock2D,CrossAttnDownBlock2D,CrossAttnDownBlock2D,DownBlock2D"
    UNET_UP_BLOCK_TYPES = "UpBlock2D,CrossAttnUpBlock2D,CrossAttnUpBlock2D,UpBlock2D"
    # Tied to the conditioner output width so the two values cannot drift apart.
    UNET_CROSS_ATTENTION_DIM = CONDITIONER_OUTPUT_DIM

    # --- Training Hyperparameters ---
    EPOCHS = 150
    BATCH_SIZE_PER_GPU = 32
    LEARNING_RATE = 5e-5
    ACCUMULATION_STEPS = 2
    NUM_WORKERS = 16

    # --- Logging & Saving ---
    CHECKPOINT_FILENAME_PREFIX = "cond_unet_v3"
    SAVE_INTERVAL_EPOCHS = 50
    SAMPLE_INTERVAL_STEPS = 100

    @classmethod
    def get_script_args(cls):
        """Generate the command-line arguments for the new training script.

        The optional path settings (``PRETRAINED_CONDITIONER_PATH``,
        ``PRETRAINED_UNET_PATH`` and ``TRAIN_SAMPLE_ID_FILE``) are emitted only
        when they hold a non-empty value. Leaving one as ``None`` therefore
        omits its flag instead of passing the literal string ``"None"`` as a
        file path.

        Returns:
            list[str]: one ``--name=value`` string per emitted option, in a
            stable order.
        """
        args = [
            f"--graph_data_dir={cls.GRAPH_DATA_DIR}", f"--checkpoint_dir={cls.CHECKPOINT_DIR}",
            f"--log_dir={cls.LOG_DIR}", f"--vae_model_path={cls.VAE_MODEL_PATH}",
            f"--epochs={cls.EPOCHS}", f"--batch_size_per_gpu={cls.BATCH_SIZE_PER_GPU}",
            f"--lr={cls.LEARNING_RATE}", f"--accumulation_steps={cls.ACCUMULATION_STEPS}",
            f"--mixed_precision={cls.MIXED_PRECISION_TYPE}", f"--num_workers={cls.NUM_WORKERS}",
            f"--conditioner_input_dim={cls.CONDITIONER_INPUT_DIM}",
            f"--conditioner_hidden_dim={cls.CONDITIONER_HIDDEN_DIM}",
            f"--conditioner_output_dim={cls.CONDITIONER_OUTPUT_DIM}",
            f"--conditioner_n_layers={cls.CONDITIONER_N_LAYERS}",
            f"--conditioner_n_heads={cls.CONDITIONER_N_HEADS}",
            f"--conditioner_attn_dropout={cls.CONDITIONER_ATTN_DROPOUT}",
            f"--unet_sample_size={cls.UNET_SAMPLE_SIZE}", f"--unet_in_channels={cls.UNET_IN_CHANNELS}",
            f"--unet_out_channels={cls.UNET_OUT_CHANNELS}",
            f"--unet_block_out_channels={cls.UNET_BLOCK_OUT_CHANNELS}",
            f"--unet_down_block_types={cls.UNET_DOWN_BLOCK_TYPES}",
            f"--unet_up_block_types={cls.UNET_UP_BLOCK_TYPES}",
            f"--unet_cross_attention_dim={cls.UNET_CROSS_ATTENTION_DIM}",
            f"--save_interval={cls.SAVE_INTERVAL_EPOCHS}",
            f"--sample_interval_steps={cls.SAMPLE_INTERVAL_STEPS}",
            f"--checkpoint_filename_prefix={cls.CHECKPOINT_FILENAME_PREFIX}",
        ]
        # Optional: skipped when training the conditioner from scratch (value is None).
        if cls.PRETRAINED_CONDITIONER_PATH:
            args.append(f"--pretrained_conditioner_path={cls.PRETRAINED_CONDITIONER_PATH}")
        # The boolean is serialized as lowercase "true"/"false".
        args.append(f"--freeze_conditioner={str(cls.FREEZE_CONDITIONER).lower()}")
        # Optional: skipped when no pre-trained UNet checkpoint is configured.
        if cls.PRETRAINED_UNET_PATH:
            args.append(f"--pretrained_unet_path={cls.PRETRAINED_UNET_PATH}")
        # [NEW] Add the sample ID file argument only if it's specified
        if cls.TRAIN_SAMPLE_ID_FILE:
            args.append(f"--train_sample_id_file={cls.TRAIN_SAMPLE_ID_FILE}")
        return args

# Build the config object, then make sure both output directories exist
# (checkpoints first, logs second; already-existing directories are fine).
config_diff = TrainingConfigDiffusion()
for output_dir in (config_diff.CHECKPOINT_DIR, config_diff.LOG_DIR):
    os.makedirs(output_dir, exist_ok=True)
print("Config loaded. Please ensure all paths are correct before proceeding.")

Config loaded. Please ensure all paths are correct before proceeding.


In [7]:
# Cell 2: Launch DDP Training Script for Diffusion V3
#
# Runs the training script through `python -m torch.distributed.run` as a child
# process, restricted to the configured GPUs, and streams its combined
# stdout/stderr into the cell output until it exits.

print("\n--- Preparing to Launch DDP Training for Diffusion Model V3 ---")

if not os.path.exists(config_diff.TRAIN_SCRIPT_PATH):
    print(f"Error: Training script '{config_diff.TRAIN_SCRIPT_PATH}' not found.")
elif config_diff.NUM_GPUS < 1:
    # Without this guard the launcher would be started with --nproc_per_node=0.
    print(f"Error: NUM_GPUS is {config_diff.NUM_GPUS}; at least one GPU is required to launch DDP training.")
else:
    python_executable = sys.executable
    # Copy the environment so the GPU restriction applies to the child only.
    modified_env = os.environ.copy()
    cuda_visible_devices = ",".join(map(str, config_diff.GPU_IDS))
    modified_env["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices

    cmd = [
        python_executable, "-m", "torch.distributed.run",
        f"--nproc_per_node={config_diff.NUM_GPUS}",
        f"--master_port={config_diff.DDP_MASTER_PORT}",
        config_diff.TRAIN_SCRIPT_PATH,
    ]
    cmd.extend(config_diff.get_script_args())

    print("\nLaunching command:")
    print(shlex.join(cmd))
    print("-" * 30 + "\nScript Output:\n" + "-" * 30)

    # stderr is merged into stdout so a single reader sees everything in order;
    # bufsize=1 requests line buffering on the text-mode pipe.
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                               text=True, bufsize=1, encoding='utf-8', errors='replace',
                               env=modified_env)
    try:
        # Iterating the pipe blocks until a line arrives and stops at EOF, so
        # there is no polling loop to spin once the child closes its output.
        for line in process.stdout:
            # rstrip (not strip): keep leading indentation of tracebacks/logs.
            print(line.rstrip())
        rc = process.wait()
    finally:
        # If the cell is interrupted or printing fails, do not leave the
        # training job running in the background holding the GPUs.
        if process.poll() is None:
            process.terminate()
            try:
                process.wait(timeout=30)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait()
        process.stdout.close()
    print("-" * 30 + f"\n--- Script Finished with code {rc} ---")


--- Preparing to Launch DDP Training for Diffusion Model V3 ---

Launching command:
/public/home/jijh/micromamba/envs/gpu_env/bin/python -m torch.distributed.run --nproc_per_node=2 --master_port=29505 /home1/jijh/diffusion_project/ADiffusion/src/pipeline/train_condition_diffusion_ddp_v3.py --graph_data_dir=/cwStorage/nodecw_group/jijh/hest_graph_data_with_vae_latents --checkpoint_dir=/cwStorage/nodecw_group/jijh/model_path/conditioned_diffusion_v1_sampleN2 --log_dir=/cwStorage/nodecw_group/jijh/training_log/conditioned_diffusion_v1_sampleN2 --vae_model_path=/cwStorage/nodecw_group/jijh/model_path/finetuned_taesd_v21_notebook_apr2.pt --epochs=150 --batch_size_per_gpu=32 --lr=5e-05 --accumulation_steps=2 --mixed_precision=bf16 --num_workers=16 --conditioner_input_dim=50 --conditioner_hidden_dim=256 --conditioner_output_dim=768 --conditioner_n_layers=4 --conditioner_n_heads=4 --conditioner_attn_dropout=0.1 --unet_sample_size=32 --unet_in_channels=4 --unet_out_channels=4 --unet_block_out_