In [34]:
import scanpy as sc
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import skimage.io as io
import random
from torch_geometric.data import Data
os.chdir('/public/home/jijh/diffusion_project/ADiffusion')
import importlib
import src.preprocessing.data_process
importlib.reload(src.preprocessing.data_process)
from src.preprocessing.data_process import extract_patches, create_graph_data_dict, construct_affinity_matrix

In [35]:
def seed_everything(seed=0):
    """Seed every random number generator this notebook relies on.

    The same value is handed, in order, to Python's ``random`` module, the
    PyTorch CPU generator, the generators of all CUDA devices, and NumPy's
    legacy global generator.

    Args:
        seed: Integer seed shared by all generators (default 0).
    """
    seeders = (
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# Seed once, before any data handling or model construction.
random_seed = 0
seed_everything(random_seed)

# Data Preprocessing

In [36]:
from tqdm import tqdm

file_dir = "/public/home/jijh/st_project/cellbin_analysis/spatial_variation/wx_data/"  # Directory containing the data files

# Keep only the ".h5ad" files whose name contains "month".
# os.listdir() returns entries in arbitrary order, so the listing is sorted to make
# the insertion order of `adatas` (and of everything later built by iterating over it)
# identical from run to run.
files = sorted(f for f in os.listdir(file_dir) if f.endswith(".h5ad") and "month" in f)
file_paths = [os.path.join(file_dir, f) for f in files]  # Full path of every selected file

# Read each file; the dictionary key is the file name up to its first "."
adatas = {}
for file_name, file_path in zip(files, file_paths):
    adatas[file_name.split(".")[0]] = sc.read(file_path)

# Preprocess each AnnData object in place
for key in tqdm(adatas.keys(), desc="Preprocessing datasets"):
    sc.pp.normalize_total(adatas[key], target_sum=1e4)  # Normalize counts per cell
    sc.pp.log1p(adatas[key])  # Logarithmize the data
    # NOTE: despite its name, the "raw" layer holds the normalized + log1p values
    # (a snapshot taken before scaling), not the raw counts. The key is kept as-is
    # because code outside this cell may read it.
    adatas[key].layers["raw"] = adatas[key].X.copy()
    sc.pp.scale(adatas[key], max_value=10)  # Scale the data, clipping values at 10
    sc.tl.pca(adatas[key], svd_solver="arpack")  # Perform PCA

In [None]:
# Extract spatial coordinates for each cell (copied so later edits cannot touch the AnnData)
cell_coords = {}
for key in adatas.keys():
    cell_coords[key] = adatas[key].obsm["spatial"].copy()

# Build one neighbourhood structure per sample from the coordinates
neighbors = {}
for key in cell_coords.keys():
    neighbors[key] = construct_affinity_matrix(cell_coords[key], mode='number', cutoff=10)

# Load the plaque dataset: every ".tiff" file whose name contains "plaque".
# Sorted because os.listdir() returns entries in arbitrary order.
img_dir = "/public/home/jijh/st_project/cellbin_analysis/spatial_variation/wx_data/protein_seg_result/"
img_files = sorted(f for f in os.listdir(img_dir) if f.endswith(".tiff") and "plaque" in f)

# Read the images, keyed by the file name up to its first "."
imgs = {}
for img_file in img_files:
    imgs[img_file.split(".")[0]] = io.imread(os.path.join(img_dir, img_file))

# Rename the imgs to match the adata keys: "<prefix>_<a>_<b>..." becomes "<a>_<b>".
# The guard needs at least three "_"-separated parts because parts[2] is read;
# the previous `len(parts) > 1` check raised IndexError for names with a single "_".
for key in list(imgs.keys()):
    parts = key.split("_")
    if len(parts) > 2:
        new_key = parts[1] + "_" + parts[2]
        imgs[new_key] = imgs.pop(key)


# Extract patches from the images, one per cell coordinate.
# A KeyError here means an image key has no matching entry in `cell_coords`.
patches = {}
for key in imgs.keys():
    patches[key] = extract_patches(imgs[key], cell_coords[key], patch_size=128)

# Convert the patches to binary masks (pixel > 0)
binary_patches = {}
for key in patches.keys():
    binary_patches[key] = [patch > 0 for patch in patches[key]]

# Calculate the area (number of positive pixels) of each patch
areas = {}
for key in binary_patches.keys():
    areas[key] = [np.sum(patch) for patch in binary_patches[key]]

In [37]:
# Assemble one graph object per sample from expression data, patch areas,
# neighbourhoods and coordinates.
graph_data_dict = create_graph_data_dict(
    adatas,
    areas,
    neighbors,
    cell_coords,
    embeddings=["X"],
)

# Attach the image patches of each sample to its graph. The list of patches is
# stacked into one NumPy array first so that a single float tensor is created.
for key, graph in tqdm(graph_data_dict.items(), desc="Adding patches to graph data"):
    stacked_patches = np.array(patches[key])
    patches_tensor = torch.tensor(stacked_patches, dtype=torch.float)
    graph.patches = patches_tensor

# Rescale every graph's edge attributes by that graph's own maximum value.
for key, graph in tqdm(graph_data_dict.items(), desc="Normalizing edge_attr"):
    edge_attr_max = graph.edge_attr.max()
    graph.edge_attr = graph.edge_attr / edge_attr_max

# VAE

In [39]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Reload the saved dictionaries.
# This rebinds `graph_data_dict`, replacing the object built by create_graph_data_dict()
# in the preprocessing section. No cell visible in this part of the notebook writes these
# two files, so they have to exist on disk already.
# NOTE(review): torch.load() unpickles arbitrary Python objects - only load files from a
# trusted source. Recent PyTorch releases may also need an explicit `weights_only=False`
# to load non-tensor objects - confirm against the installed version.
save_dir = "/public/home/jijh/diffusion_project/data_storage"
graph_data_dict = torch.load(os.path.join(save_dir,"graph_data_dict.pth"))
positive_nodes_dict = torch.load(os.path.join(save_dir,"positive_nodes_dict.pth"))

# Display the loaded dictionary as the cell output.
graph_data_dict

In [40]:
# Import transform functions from torchvision
from torchvision.transforms import transforms

# Create a dataset of only positive patches
class PositivePatchDataset(Dataset):
    """Dataset of patches that applies random flips and a small random rotation.

    The patch tensor is stored on ``device``; every returned item is moved back
    to that same device after augmentation.

    Note: a class with the same name but a different constructor signature is
    defined again further down in this notebook and replaces this one once run.
    """

    def __init__(self, patches, device):
        """
        Args:
            patches (torch.Tensor): stacked patches, indexed along dim 0.
            device: device the patch tensor is moved to.
        """
        augmentations = [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
        ]
        self.patches = patches.to(device)
        self.transform = transforms.Compose(augmentations)

    def __len__(self):
        """Number of patches (size of dim 0)."""
        return self.patches.shape[0]

    def __getitem__(self, idx):
        """Return one augmented patch as a tensor on the dataset's device."""
        # Drop the leading dim, scale by 255 and cast to uint8 so the patch can
        # go through the PIL-based augmentations.
        plane = self.patches[idx].squeeze(0)
        as_uint8 = (plane.cpu().numpy() * 255).astype(np.uint8)
        augmented = self.transform(transforms.ToPILImage()(as_uint8))
        # Back to a tensor, placed on the same device as the stored patches.
        return transforms.ToTensor()(augmented).to(self.patches.device)
    
# Filter for positive nodes: collect, per graph, the patches of the nodes listed in
# positive_nodes_dict.
# NOTE: `positive_dataset` and `plaque_loader` built here are rebound by a later cell
# (which also splits the data into train/val), so this cell's results are superseded
# once that cell runs.
# NOTE(review): the list name `positive_pathces` is a typo of "patches"; it is left
# unchanged here because it is a notebook-level global.
positive_pathces = []

for key, data in graph_data_dict.items():
    # Node indices are moved to the CPU before being used to index data.patches.
    positive_nodes = positive_nodes_dict[key].cpu()
    # Move the last axis to position 1 (channels-first); the original note gives the
    # result as [num_positive, 1, 128, 128].
    positive_patches_sample = data.patches[positive_nodes].permute(0, 3, 1, 2)  # [num_positive, 1, 128, 128]
    positive_pathces.append(positive_patches_sample)

# Concatenate all positive patches along the first axis
positive_patches_sample = torch.cat(positive_pathces, dim=0)

# Create a DataLoader for the positive patches (shuffled batches of 32)
positive_dataset = PositivePatchDataset(positive_patches_sample, device)
plaque_loader = DataLoader(positive_dataset, batch_size=32, shuffle=True)


In [41]:
from diffusers import AutoencoderKL

# Build the KL autoencoder: single-channel input/output, four encoder and four
# decoder blocks, and a 16-channel latent space.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vae = AutoencoderKL(
    in_channels=1,
    out_channels=1,
    down_block_types=("DownEncoderBlock2D",) * 4,
    up_block_types=("UpDecoderBlock2D",) * 4,
    block_out_channels=(64, 128, 256, 512),
    latent_channels=16,
).to(device)

In [42]:
# Import necessary libraries
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import numpy as np
from PIL import Image

# Define the PositivePatchDataset
class PositivePatchDataset(Dataset):
    """Patch dataset with random flip / rotation augmentation.

    Unlike the earlier definition of the same name, this version takes no
    device argument: the patch tensor is stored as given and the conversion
    back to a tensor is part of the transform pipeline.
    """

    def __init__(self, patches):
        """
        Args:
            patches (torch.Tensor): Tensor of patches with shape [N, 1, 128, 128]
                (shape as stated by the original author).
        """
        pipeline = [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ToTensor(),  # PIL image -> tensor, last step of the pipeline
        ]
        self.patches = patches  # stored as given, no device transfer
        self.transform = transforms.Compose(pipeline)

    def __len__(self):
        """Number of patches (size of dim 0)."""
        return self.patches.shape[0]

    def __getitem__(self, idx):
        """Return one augmented patch produced by the transform pipeline."""
        # Drop the leading dim, scale by 255 and cast to uint8 for PIL.
        plane = self.patches[idx].squeeze(0).cpu().numpy()
        grayscale = Image.fromarray((plane * 255).astype(np.uint8), mode='L')
        return self.transform(grayscale)


In [43]:
# Relies on graph_data_dict and positive_nodes_dict from the loading cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Gather the positive-node patches of every graph, channels-first.
positive_patches = []
for key, data in graph_data_dict.items():
    positive_nodes = positive_nodes_dict[key].cpu()  # index tensor on the CPU
    selected_patches = data.patches[positive_nodes]
    positive_patches_sample = selected_patches.permute(0, 3, 1, 2)  # [num_positive, 1, 128, 128]
    positive_patches.append(positive_patches_sample)

# One tensor holding the positive patches of all graphs.
positive_patches_sample = torch.cat(positive_patches, dim=0)  # [total_positive, 1, 128, 128]

# Wrap in the augmenting dataset and split 80% / 20% into train / validation.
positive_dataset = PositivePatchDataset(positive_patches_sample)
total_items = len(positive_dataset)
train_size = int(0.8 * total_items)
val_size = total_items - train_size
train_dataset, val_dataset = random_split(positive_dataset, [train_size, val_size])

# Loader settings shared by both splits.
batch_size = 64
num_workers = 0  # 0 for notebook environments
loader_options = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),  # pinned host memory only when a GPU is present
)

# Only the training loader shuffles.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

# Both loaders under one name, keyed by split.
plaque_loader = {'train': train_loader, 'val': val_loader}


In [44]:
# Value range of the dataset items: flatten every item returned by positive_dataset
# (i.e. after its transform pipeline) and take the global min / max.
patch_values = np.array([patch.numpy().flatten() for patch in positive_dataset])
patch_min = patch_values.min()
patch_max = patch_values.max()

patch_min, patch_max


In [45]:
import torch
from torch import nn, optim
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm  # Use tqdm for notebook
import os

# ---- Training schedule ----
num_epochs = 20
warmup_epochs = 10                            # epochs over which the KL term is warmed up
log_interval = 100                            # how often to log training status
save_interval = 5                             # checkpoint every N epochs
visualize_interval = max(num_epochs // 5, 1)  # visualize every 1/5 of the run, at least every epoch

# ---- Objective and optimizer ----
# Reconstruction term: binary cross-entropy computed on raw (pre-sigmoid) outputs.
reconstruction_loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
optimizer = optim.Adam(vae.parameters(), lr=1e-4)
# Optional LR schedule, currently disabled:
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# ---- Bookkeeping ----
os.makedirs('vae_checkpoints', exist_ok=True)  # checkpoint directory
train_losses, val_losses = [], []              # per-epoch averages, used for plotting

# Function to visualize reconstructions
def visualize_reconstructions(model, dataloader, device, epoch, num_images=5):
    """Plot originals (top row) against model reconstructions (bottom row).

    The first batch of ``dataloader`` is encoded, a latent is sampled from the
    resulting distribution and decoded, and the first ``num_images`` pairs are
    shown in a 2 x num_images grid.

    Args:
        model: autoencoder exposing ``encode(x).latent_dist`` and ``decode(z).sample``.
        dataloader: iterable yielding batches of images (tensors only, no labels).
        device: device the batch is moved to before encoding.
        epoch: epoch number, used only in the figure title.
        num_images: maximum number of pairs to draw; clamped to the batch size.

    Side effects:
        Leaves ``model`` in eval mode and shows a matplotlib figure.
    """
    model.eval()
    with torch.no_grad():
        data = next(iter(dataloader))  # Retrieve only data
        data = data.to(device)
        latent_dist = model.encode(data).latent_dist  # Obtain latent distribution
        recon_data = model.decode(latent_dist.sample()).sample  # Decode sampled latent vectors

        # Move tensors to CPU and convert to numpy
        data = data.cpu().numpy()
        recon_data = recon_data.cpu().numpy()

    # Never index past the end of the batch: a small dataset (or a small batch
    # size) can yield fewer than `num_images` items, which used to raise IndexError.
    num_images = min(num_images, len(data))

    # NOTE(review): the decoder output is plotted as-is. When the model is trained
    # with a logits-based loss, these values are pre-sigmoid, so their range differs
    # from the originals - confirm whether a sigmoid should be applied for display.
    plt.figure(figsize=(10, 4))
    for i in range(num_images):
        # Original Image
        plt.subplot(2, num_images, i + 1)
        plt.imshow(data[i].squeeze(), cmap='gray')
        plt.title("Original")
        plt.axis('off')

        # Reconstructed Image
        plt.subplot(2, num_images, i + 1 + num_images)
        plt.imshow(recon_data[i].squeeze(), cmap='gray')
        plt.title("Reconstructed")
        plt.axis('off')

    plt.suptitle(f"Reconstruction at Epoch {epoch}")
    plt.tight_layout()
    plt.show()



In [46]:
# Training and Validation Loop
for epoch in range(1, num_epochs + 1):
    # KL warm-up weight: grows linearly with the epoch and is capped at 1.0.
    # It depends only on the epoch, so it is computed once here instead of on
    # every batch. This also guarantees it is defined for the validation phase
    # even if the training loader yields no batches.
    annealing_factor = min(1.0, epoch / warmup_epochs)

    # Training Phase
    vae.train()
    epoch_train_loss = 0.0

    # Initialize tqdm progress bar for training
    train_bar = tqdm(plaque_loader['train'], desc=f"Epoch {epoch}/{num_epochs} [Train]", leave=False)

    for batch_idx, data in enumerate(train_bar):
        data = data.to(device)  # Move data to device
        optimizer.zero_grad()

        # Forward pass: Encode and Decode
        latent_dist = vae.encode(data).latent_dist  # Obtain latent distribution
        recon_data = vae.decode(latent_dist.sample()).sample  # Decode sampled latent vectors

        # Extract mu and logvar
        mu = latent_dist.mean
        logvar = latent_dist.logvar

        # Compute Loss: reconstruction term plus the warm-up-weighted KL term
        # (KL averaged over all latent elements)
        recon_loss = reconstruction_loss_fn(recon_data, data)
        kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        loss = recon_loss + kl_loss * annealing_factor

        # Backward and Optimize
        loss.backward()
        torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=10.0)  # Optional: Gradient clipping
        optimizer.step()

        # Update training loss
        epoch_train_loss += loss.item()

        # Update tqdm progress bar with loss
        train_bar.set_postfix({'Loss': loss.item(), 'Recon': recon_loss.item(), 'KL': kl_loss.item()})

    # Calculate average training loss for the epoch
    average_train_loss = epoch_train_loss / len(plaque_loader['train'])
    train_losses.append(average_train_loss)

    # Validation Phase
    vae.eval()
    epoch_val_loss = 0.0

    with torch.no_grad():
        val_bar = tqdm(plaque_loader['val'], desc=f"Epoch {epoch}/{num_epochs} [Val]", leave=False)
        for data in val_bar:
            data = data.to(device)
            latent_dist = vae.encode(data).latent_dist  # Obtain latent distribution
            recon_data = vae.decode(latent_dist.sample()).sample  # Decode sampled latent vectors

            # Extract mu and logvar
            mu = latent_dist.mean
            logvar = latent_dist.logvar

            # Compute Loss with the same KL weight as this epoch's training phase
            recon_loss = reconstruction_loss_fn(recon_data, data)
            kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
            loss = recon_loss + kl_loss * annealing_factor

            # Update validation loss
            epoch_val_loss += loss.item()

            # Update tqdm progress bar with loss
            val_bar.set_postfix({'Loss': loss.item(), 'Recon': recon_loss.item(), 'KL': kl_loss.item()})

    # Calculate average validation loss for the epoch
    average_val_loss = epoch_val_loss / len(plaque_loader['val'])
    val_losses.append(average_val_loss)

    print(f"Epoch [{epoch}/{num_epochs}] | Train Loss: {average_train_loss:.4f} | Val Loss: {average_val_loss:.4f}")

    # # Step the scheduler
    # scheduler.step()

    # Save checkpoint periodically
    if epoch % save_interval == 0 or epoch == num_epochs:
        checkpoint_path = f'vae_checkpoints/vae_bce_hidden16_epoch_{epoch}.pth'
        torch.save(vae.state_dict(), checkpoint_path)
        print(f"Saved model checkpoint at {checkpoint_path}")

    # Visualize Reconstructions at specified intervals
    if epoch % visualize_interval == 0 or epoch == num_epochs:
        visualize_reconstructions(vae, plaque_loader['val'], device, epoch, num_images=5)

# Plot Training and Validation Losses.
# The x-axis is derived from the number of recorded values rather than from
# `num_epochs`: the loss lists are created in a different cell, so re-running this
# cell alone appends to them and a fixed range would no longer match their length.
plt.figure(figsize=(10, 5))
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.legend()
plt.grid(True)
plt.show()


# CLIP + Diffusion

In [50]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
from diffusers import AutoencoderKL
from torch_geometric.nn import GATConv
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np

# Select the compute device: CUDA when available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

############################################
# 1) Define the GAT model
############################################
class GAT_Embedder(nn.Module):
    """
    Two stacked GATConv layers, each followed by LayerNorm, ELU and dropout,
    and a final Linear + LayerNorm head that projects to ``embed_dim``.

    A residual connection around the second GAT layer is added only when its
    input and output widths are equal; with the default ``heads=8`` the widths
    differ (gat_hidden * heads vs. gat_hidden), so no skip is applied.
    """
    def __init__(self,
                 in_features=2766,  # node feature dimension
                 gat_hidden=128,    # hidden units per attention head
                 heads=8,           # attention heads in the first layer
                 dropout=0.5,
                 embed_dim=256):    # size of the output embedding
        super().__init__()
        first_layer_width = gat_hidden * heads

        # Block 1: multi-head attention; output width is gat_hidden * heads.
        self.gat1 = GATConv(in_channels=in_features, out_channels=gat_hidden,
                            heads=heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(first_layer_width)

        # Block 2: a single head brings the width back to gat_hidden.
        self.gat2 = GATConv(in_channels=first_layer_width, out_channels=gat_hidden,
                            heads=1, dropout=dropout)
        self.norm2 = nn.LayerNorm(gat_hidden)

        # Projection head to the embedding dimension.
        self.mlp = nn.Sequential(
            nn.Linear(gat_hidden, embed_dim),
            nn.LayerNorm(embed_dim),
        )
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Embed every node: ``x`` [N, in_features] -> [N, embed_dim]."""
        # Block 1 -> [N, gat_hidden*heads]
        hidden = self.gat1(x, edge_index)
        hidden = F.dropout(F.elu(self.norm1(hidden)), p=self.dropout, training=self.training)

        # Block 2 -> [N, gat_hidden], with a skip only when the widths agree
        out = self.gat2(hidden, edge_index)
        if out.shape[-1] == hidden.shape[-1]:
            out = out + hidden
        out = F.dropout(F.elu(self.norm2(out)), p=self.dropout, training=self.training)

        # Projection head -> [N, embed_dim]
        return self.mlp(out)

############################################
# 2) Define the contrastive loss (CLIP-style)
############################################
def clip_style_loss(node_emb, img_emb, temperature=1.0):
    """
    Symmetric cross-entropy contrastive loss in the style of CLIP.

    Both embedding sets are L2-normalised along the last axis, so
    ``logits[i, j]`` is the cosine similarity between node ``i`` and image
    ``j`` divided by ``temperature``. The matching pair for row ``i`` is
    column ``i``.

    Args:
        node_emb: [B, D] graph-side embeddings.
        img_emb: [B, D] image-side embeddings, row-aligned with ``node_emb``.
        temperature: divisor applied to the similarities (default 1.0).

    Returns:
        Scalar tensor: mean of the graph->image and image->graph losses.
    """
    unit_nodes = F.normalize(node_emb, dim=-1)
    unit_imgs = F.normalize(img_emb, dim=-1)

    # [B, B] temperature-scaled cosine similarities
    logits = torch.matmul(unit_nodes, unit_imgs.t()) / temperature

    # Positive pairs lie on the diagonal
    labels = torch.arange(node_emb.size(0), device=node_emb.device)

    graph_to_image = F.cross_entropy(logits, labels)
    image_to_graph = F.cross_entropy(logits.t(), labels)
    return 0.5 * (graph_to_image + image_to_graph)

############################################
# 3) Define the trainer (GraphCLIPTrainer)
############################################
class GraphCLIPTrainer:
    """
    CLIP-style trainer that aligns GAT node embeddings with VAE-encoded image patches.

    It bundles:
      - the graph model (GAT_Embedder)
      - a VAE that is only used for encoding (always called under no_grad)
      - a trainable image projector mapping pooled VAE latents to the embedding size
      - training / validation logic
      - learning-rate scheduling, early stopping and checkpoint saving
      - "method B": the whole graph is re-forwarded for every mini-batch

    Note: for large graphs, re-forwarding the full graph at every step is costly in
    both time and GPU memory.
    """
    def __init__(
        self,
        model: nn.Module,
        vae: AutoencoderKL,
        graph_data_dict: dict,  # {key: Data(...)}
        device: torch.device,
        lr=1e-3,
        weight_decay=1e-5,
        max_epochs=10,
        batch_size=128,
        val_batch_size=256,   # stored but not used by any method of this class
        temperature=0.07,
        patience=5,
        save_dir="./graphclip_checkpoints",
        log_dir="./logs"       # TensorBoard log directory
    ):
        """
        Args:
            model: graph embedder; ``model.mlp[-1]`` must be a LayerNorm whose
                ``normalized_shape[0]`` gives the embedding size.
            vae: autoencoder used to encode patches.
            graph_data_dict: mapping from sample key to a graph object exposing
                ``x``, ``edge_index``, ``patches`` and ``y``.
            device: device on which all tensors are placed.
            lr: learning rate of both parameter groups.
            weight_decay: weight decay of the Adam optimizer.
            max_epochs: upper bound on the number of epochs run by ``fit``.
            batch_size: number of positive nodes per contrastive mini-batch.
            val_batch_size: kept for interface compatibility; currently unused.
            temperature: temperature of the contrastive loss.
            patience: epochs without validation improvement before early stopping.
            save_dir: directory for checkpoints (created if missing).
            log_dir: TensorBoard log directory.
        """
        self.model = model
        self.vae = vae
        self.graph_data_dict = graph_data_dict
        self.device = device

        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.temperature = temperature
        self.patience = patience
        self.save_dir = save_dir

        # (A) Optimizer, first parameter group: the graph model.
        self.optimizer = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        # (B) Second parameter group: the trainable image projector, mapping the
        #     16 pooled latent channels to the graph embedding size.
        self.img_projector = nn.Sequential(
            nn.Linear(16, model.mlp[-1].normalized_shape[0], bias=False),
            nn.LayerNorm(model.mlp[-1].normalized_shape[0])
        ).to(device)

        self.optimizer.add_param_group({"params": self.img_projector.parameters()})

        # (C) LR scheduler: two parameter groups, hence two min_lr entries.
        # NOTE(review): `verbose` is deprecated in recent PyTorch releases - confirm
        # against the installed version.
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=2,
            verbose=True,
            min_lr=[1e-6, 1e-6]  # one entry per parameter group
        )

        # Gradient clipping threshold
        self.max_grad_norm = 5.0

        # Early-stopping state
        self.best_val_loss = float('inf')
        self.no_improve_epochs = 0

        # TensorBoard
        self.writer = SummaryWriter(log_dir=log_dir)

        # Make sure the checkpoint directory exists
        os.makedirs(save_dir, exist_ok=True)

    def _vae_encode_project(self, patches: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of patches with the VAE, average the sampled latent over its
        two spatial axes, and map the result through the image projector.

        The VAE part runs under no_grad; only the projector receives gradients.
        """
        with torch.no_grad():
            latent_dist = self.vae.encode(patches).latent_dist
            z = latent_dist.sample()
            z_pooled = z.mean(dim=[2,3])
        img_emb = self.img_projector(z_pooled)
        return img_emb

    def visualize_alignment(self, node_embs, img_embs, epoch):
        """
        Reduce node and image embeddings to 2D with t-SNE, draw both sets plus a line
        between each corresponding pair, and log the figure to TensorBoard.
        """
        # Move the embeddings to the CPU and convert them to NumPy
        node_embs = node_embs.detach().cpu().numpy()
        img_embs = img_embs.detach().cpu().numpy()

        all_embeds = np.concatenate([node_embs, img_embs], axis=0)
        # t-SNE needs perplexity < n_samples; with few samples the default of 30
        # would be rejected, so cap it. Unchanged for 31 or more samples.
        tsne = TSNE(n_components=2, random_state=42,
                    perplexity=min(30.0, len(all_embeds) - 1))
        all_embeds_2d = tsne.fit_transform(all_embeds)

        node_embs_2d = all_embeds_2d[:len(node_embs)]
        img_embs_2d = all_embeds_2d[len(node_embs):]

        # Scatter plot of both embedding sets
        plt.figure(figsize=(10, 10))
        plt.scatter(node_embs_2d[:,0], node_embs_2d[:,1], c='blue', label='Node Embeddings', alpha=0.5)
        plt.scatter(img_embs_2d[:,0], img_embs_2d[:,1], c='red', label='Image Embeddings', alpha=0.5)

        # Lines joining each node embedding to its paired image embedding
        for i in range(min(len(node_embs), len(img_embs))):
            plt.plot(
                [node_embs_2d[i,0], img_embs_2d[i,0]],
                [node_embs_2d[i,1], img_embs_2d[i,1]],
                c='gray',
                linewidth=0.5,
                alpha=0.3
            )

        plt.legend()
        plt.title(f'CLIP Alignment at Epoch {epoch}')
        plt.xlabel('t-SNE Dimension 1')
        plt.ylabel('t-SNE Dimension 2')

        # Log the figure to TensorBoard, then release it
        self.writer.add_figure('CLIP/Alignment', plt.gcf(), global_step=epoch)
        plt.close()

    def train_one_graph_fullbatch_strictB(self, data, epoch_idx) -> float:
        """
        Method B (strict version) for a single graph. For every mini-batch:
          1) forward the whole graph to get all node embeddings,
          2) pick the batch's positive nodes, compute the contrastive loss,
             backpropagate and take an optimizer step.

        Only full batches are used (a trailing remainder is dropped). When the graph
        has at least 2 but fewer than ``batch_size`` positive nodes, a single step is
        taken on all of them instead of skipping the graph entirely.

        Returns:
            Mean loss over the mini-batches of this graph, or 0.0 when no step could
            be taken (no positive nodes, or a single positive node).
        """
        self.model.train()
        self.img_projector.train()

        x = data.x.to(self.device)
        edge_index = data.edge_index.to(self.device)
        patches = data.patches.permute(0,3,1,2).float().to(self.device)
        y = data.y.view(-1).to(self.device)

        pos_idx = (y > 0).nonzero(as_tuple=True)[0]
        num_pos = pos_idx.numel()

        if num_pos == 0:
            # No positive nodes: nothing to contrast
            return 0.0

        # Shuffle the positive nodes
        pos_idx = pos_idx[torch.randperm(num_pos)]

        # Previously a graph with fewer positives than batch_size produced zero steps,
        # so it was never trained on and silently reported a loss of 0.0. Shrink the
        # batch to the available positives (at least 2 are needed for negatives).
        batch_size = self.batch_size
        if 2 <= num_pos < batch_size:
            batch_size = num_pos

        # Number of full mini-batches
        steps = num_pos // batch_size

        total_loss = 0.0

        for step in range(steps):
            # 1) Full-graph forward
            node_emb_all = self.model(x, edge_index)

            # 2) Slice out this step's mini-batch of positive nodes
            batch_nodes = pos_idx[step*batch_size : (step+1)*batch_size]  # [B]

            batch_node_emb = node_emb_all[batch_nodes]

            # 3) Encode and project the matching patches
            batch_patches = patches[batch_nodes]
            batch_img_emb = self._vae_encode_project(batch_patches)

            # 4) Contrastive loss
            loss = clip_style_loss(batch_node_emb, batch_img_emb, temperature=self.temperature)

            # 5) Backward pass and update, with per-module gradient clipping
            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.max_grad_norm)
            nn.utils.clip_grad_norm_(self.img_projector.parameters(), max_norm=self.max_grad_norm)
            self.optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / max(steps, 1)
        return avg_loss

    def train_one_epoch(self, epoch_idx):
        """
        Run train_one_graph_fullbatch_strictB on every graph in graph_data_dict and
        return the mean of the per-graph losses for this epoch.

        Graphs for which no step was taken contribute 0.0 to that mean. No tqdm
        progress bar is used here.
        """
        self.model.train()
        self.img_projector.train()

        epoch_loss = 0.0
        count = 0

        for key, data in self.graph_data_dict.items():
            data_loss = self.train_one_graph_fullbatch_strictB(data, epoch_idx)
            epoch_loss += data_loss
            count += 1
            # Log to TensorBoard
            self.writer.add_scalar('Train/Loss_per_graph', data_loss, epoch_idx * len(self.graph_data_dict) + count)

        avg_loss = epoch_loss / max(count, 1)
        self.writer.add_scalar('Train/Loss_avg', avg_loss, epoch_idx)
        print(f"  => [Epoch {epoch_idx}] train avg contrastive loss = {avg_loss:.4f}")
        return avg_loss

    def validate(self, epoch_idx) -> float:
        """
        Compute the contrastive loss on a random sample of at most 200 positive nodes
        per graph, and collect the sampled embeddings for the alignment plot.

        Note: the samples come from the same graphs used for training; there is no
        held-out split in this class.

        Returns:
            Mean loss over the graphs that contributed (0.0 if none did).
        """
        self.model.eval()
        self.img_projector.eval()

        val_loss = 0.0
        val_count = 0

        # Embeddings collected for visualization
        visualize_node_embs = []
        visualize_img_embs = []

        # No gradients are needed for validation. Without no_grad every full-graph
        # forward kept its autograd graph alive (also through the collected
        # embeddings), wasting GPU memory.
        with torch.no_grad():
            for key, data in self.graph_data_dict.items():
                x = data.x.to(self.device)
                edge_index = data.edge_index.to(self.device)
                patches = data.patches.permute(0,3,1,2).float().to(self.device)
                y = data.y.view(-1).to(self.device)

                pos_idx = (y > 0).nonzero(as_tuple=True)[0]

                if pos_idx.numel() == 0:
                    # No positive nodes: skip this graph
                    continue

                N = pos_idx.numel()
                # Randomly keep at most M positive nodes
                M = min(N, 200)
                sampled_pos_idx = pos_idx[torch.randperm(N)[:M]]

                # Full-graph forward (same idea as method B)
                node_emb_all = self.model(x, edge_index)

                # Keep only the M sampled nodes and their patches
                node_emb_sample = node_emb_all[sampled_pos_idx]
                patches_sample = patches[sampled_pos_idx]

                # VAE encoding + projection
                img_emb_sample = self._vae_encode_project(patches_sample)

                if M > 1:
                    loss = clip_style_loss(node_emb_sample, img_emb_sample, temperature=self.temperature)
                    val_loss += loss.item()
                    val_count += 1
                    # Log to TensorBoard
                    self.writer.add_scalar('Validate/Loss_per_graph', loss.item(), epoch_idx * len(self.graph_data_dict) + val_count)

                    # Collect embeddings for visualization
                    visualize_node_embs.append(node_emb_sample)
                    visualize_img_embs.append(img_emb_sample)

        avg_val_loss = val_loss / max(val_count, 1)
        self.writer.add_scalar('Validate/Loss_avg', avg_val_loss, epoch_idx)
        print(f"  => [Validate] avg contrastive loss = {avg_val_loss:.4f}")

        # Visualize the alignment when any embeddings were collected
        if visualize_node_embs and visualize_img_embs:
            node_embs = torch.cat(visualize_node_embs, dim=0)
            img_embs = torch.cat(visualize_img_embs, dim=0)
            self.visualize_alignment(node_embs, img_embs, epoch_idx)

        return avg_val_loss

    def fit(self):
        """
        Main training loop:
          for epoch in 1..max_epochs:
            train_one_epoch -> validate -> scheduler step -> early-stopping check

        The graph model's state_dict is saved whenever the validation loss improves,
        and once more when training ends.
        NOTE(review): only ``self.model`` is checkpointed; the trained
        ``self.img_projector`` is not written to disk - confirm whether it should be.
        """
        epoch = 0  # keeps the final checkpoint name defined even if max_epochs < 1
        progress_bar = tqdm(range(1, self.max_epochs+1), desc="Training Progress")
        for epoch in progress_bar:
            train_loss = self.train_one_epoch(epoch)
            val_loss = self.validate(epoch)

            # Scheduler is driven by the validation loss
            self.scheduler.step(val_loss)
            progress_bar.set_postfix({"train_loss": train_loss, "val_loss": val_loss})

            # Log to TensorBoard
            self.writer.add_scalar('Train/Loss_avg_epoch', train_loss, epoch)
            self.writer.add_scalar('Validate/Loss_avg_epoch', val_loss, epoch)

            # Early stopping
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.no_improve_epochs = 0
                best_path = os.path.join(self.save_dir, f"best_model_epoch_{epoch}.pth")
                torch.save(self.model.state_dict(), best_path)
                print(f"  => Best model updated, saved at {best_path}")
                # Log to TensorBoard
                self.writer.add_scalar('Validate/Best_Loss', val_loss, epoch)
            else:
                self.no_improve_epochs += 1
                if self.no_improve_epochs >= self.patience:
                    print("Early stopping triggered!")
                    break

        # Save the last model
        last_path = os.path.join(self.save_dir, f"last_model_epoch_{epoch}.pth")
        torch.save(self.model.state_dict(), last_path)
        print(f"Training finished. Final model saved to {last_path}")
        # Close the TensorBoard writer
        self.writer.close()

###############################################
# Notebook usage
###############################################

# 1) Data: `graph_data_dict` (defined in an earlier cell) maps sample keys to graphs, e.g.
#    {'8months-disease-replicate_1': Data(...),
#     '13months-disease-replicate_1': Data(...),
#      ... }

# 2) Freeze the VAE: it is only used as a fixed patch encoder from here on
vae.eval()
for p in vae.parameters():
    p.requires_grad = False

# 3) Initialise the GAT model.
#    The input width is read from the data instead of being hard-coded, so the cell
#    keeps working when the number of node features changes; 2766 (the previous
#    hard-coded value) is used only when no graph is available.
_first_graph = next(iter(graph_data_dict.values()), None)
node_feature_dim = _first_graph.x.size(-1) if _first_graph is not None else 2766

clip_model = GAT_Embedder(
    in_features=node_feature_dim,   # node feature dimension
    gat_hidden=128,     # hidden units per attention head
    heads=8,            # attention heads in the first layer
    dropout=0.5,
    embed_dim=256       # embedding dimension
).to(device)

# 4) Create the trainer and start training
trainer = GraphCLIPTrainer(
    model=clip_model,
    vae=vae,
    graph_data_dict=graph_data_dict,
    device=device,
    lr=1e-3,
    weight_decay=1e-5,
    max_epochs=50,       # increase as needed
    batch_size=128,      # positive nodes per contrastive mini-batch
    temperature=0.07,    # commonly used temperature
    patience=10,         # adjust as needed
    save_dir="./graphclip_checkpoints",
    log_dir="./logs"     # TensorBoard log directory
)

trainer.fit()


In [56]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from copy import deepcopy
from tqdm.auto import tqdm

from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL

# ------------------------------------------------------------------
# 1) Load your pretrained CLIP-like GAT embedder + projection
#    We'll call it "GraphClipEncoder" for clarity.
#    This can be your GAT_Embedder plus the projector you used.
#    We'll assume you have something like this:
# ------------------------------------------------------------------

class GraphClipEncoder(nn.Module):
    """
    Holds the GAT embedder and the image projector together in one module.

    Only the GAT embedder takes part in ``forward``; the projector is stored
    (and therefore registered as a sub-module) but never called here.
    """
    def __init__(self, gat_model, img_projector, embed_dim=256):
        """
        Args:
            gat_model: module called as ``gat_model(x, edge_index)``.
            img_projector: image-side projection module; stored only.
            embed_dim: embedding size, stored as an attribute.
        """
        super().__init__()
        self.gat_model = gat_model
        self.img_projector = img_projector
        self.embed_dim = embed_dim

    def forward(self, x, edge_index):
        """Return the GAT embedder's output for ``(x, edge_index)`` unchanged."""
        # The image projector was only needed for the alignment objective;
        # the node embedding is used as produced by the GAT model.
        return self.gat_model(x, edge_index)

# ------------------------------------------------------------------
# 2) Define your GraphConditionedLDM model/trainer
# ------------------------------------------------------------------

class GraphConditionedLDMTrainer:
    """
    Trains a conditional UNet to predict the noise added to VAE latents of
    per-node image patches, conditioned on graph-encoder node embeddings.

    The VAE and the graph encoder are only ever run under ``torch.no_grad()``;
    the optimizer holds the UNet parameters exclusively.
    """

    def __init__(
        self,
        graph_data_dict,
        vae: "AutoencoderKL",
        graph_clip_encoder: nn.Module,
        unet: "UNet2DConditionModel",
        noise_scheduler: "DDPMScheduler",
        device="cuda",
        batch_size=16,
        lr=1e-4,
        max_epochs=10,
        save_dir="./checkpoints_ldm",
        freeze_graph_encoder=True,
    ):
        """
        Args:
            graph_data_dict: dict of {key: PyG Data}; each Data provides ``.x``,
                ``.edge_index`` and ``.patches`` (documented by the author as
                (N, H, W, 1)).
            vae: pretrained VAE whose ``encode(...).latent_dist.sample()`` yields latents.
            graph_clip_encoder: module called as ``encoder(x, edge_index)`` returning one
                embedding row per node.
            unet: conditional UNet called with ``sample``, ``timestep`` and
                ``encoder_hidden_states``; its cross-attention width must match the
                embedding width.
            noise_scheduler: diffusers scheduler providing ``config.num_train_timesteps``
                and ``add_noise``.
            device: device on which the models and batches are placed.
            batch_size: number of nodes per mini-batch.
            lr: Adam learning rate for the UNet.
            max_epochs: number of epochs run by ``fit``.
            save_dir: directory for UNet checkpoints (created if missing).
            freeze_graph_encoder: if True, sets ``requires_grad=False`` on the encoder.
                The encoder is not optimized by this trainer either way.
        """
        self.graph_data_dict = graph_data_dict
        self.vae = vae.eval().to(device)
        # The encoder is inference-only here (no_grad, not in the optimizer), so it is put in
        # eval mode like the VAE; otherwise dropout/batch-norm layers would perturb the condition.
        self.graph_clip_encoder = graph_clip_encoder.eval().to(device)
        self.unet = unet.to(device)
        self.noise_scheduler = noise_scheduler
        self.device = device
        self.batch_size = batch_size
        self.lr = lr
        self.max_epochs = max_epochs

        os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir

        # Freeze the VAE and (optionally) the graph encoder.
        for p in self.vae.parameters():
            p.requires_grad = False
        if freeze_graph_encoder:
            for p in self.graph_clip_encoder.parameters():
                p.requires_grad = False

        # Only the UNet is optimized.
        self.optimizer = Adam(self.unet.parameters(), lr=self.lr)

        # Flat list of (graph_key, node_idx) pairs used for shuffling and batching.
        self.all_nodes = []
        self._prepare_all_nodes()

    def _prepare_all_nodes(self):
        """Append one (graph_key, node_idx) pair per node of every graph to ``self.all_nodes``."""
        for gkey, data in self.graph_data_dict.items():
            self.all_nodes.extend((gkey, i) for i in range(data.x.size(0)))
        print(f"Total nodes across all graphs: {len(self.all_nodes)}")

    def _get_batch(self, indices):
        """
        Build one mini-batch from a list of (graph_key, node_idx) pairs.

        Returns:
            cond_emb_batch: [B, 1, embed_dim] node embeddings (sequence length 1 for
                the UNet's cross-attention).
            patch_batch: [B, ...] the selected rows of ``data.patches`` on ``self.device``.

        Rows are ordered by graph (in first-seen order), not by the order of ``indices``;
        embeddings and patches use the same ordering, so pairs stay aligned.
        """
        from collections import defaultdict

        batch_by_gkey = defaultdict(list)
        for gkey, idx in indices:
            batch_by_gkey[gkey].append(idx)

        cond_list = []
        patch_list = []
        for gkey, node_idxs in batch_by_gkey.items():
            data = self.graph_data_dict[gkey]

            # The encoder needs the whole graph, so embed every node and pick rows afterwards.
            with torch.no_grad():
                node_emb_all = self.graph_clip_encoder(
                    data.x.to(self.device), data.edge_index.to(self.device)
                )  # [N, embed_dim]

            idx_cpu = torch.as_tensor(node_idxs, dtype=torch.long)
            cond_list.append(node_emb_all[idx_cpu.to(self.device)].unsqueeze(1))  # [b, 1, embed_dim]

            # Select the needed rows first and transfer only those, instead of copying the
            # graph's full patch tensor to the device on every batch.
            patches = data.patches
            patch_list.append(patches[idx_cpu.to(patches.device)].to(self.device))

        cond_emb_batch = torch.cat(cond_list, dim=0)
        patch_batch = torch.cat(patch_list, dim=0)
        return cond_emb_batch, patch_batch

    def train_one_epoch(self):
        """
        Run one pass over all nodes in random order and return the mean batch loss.

        Raises:
            ValueError: if there are no nodes to train on (previously a ZeroDivisionError).
        """
        if not self.all_nodes:
            raise ValueError("No nodes to train on: graph_data_dict yielded an empty node list.")

        self.unet.train()
        # Shuffle all node positions for this epoch.
        order = torch.randperm(len(self.all_nodes)).tolist()
        num_batches = (len(order) + self.batch_size - 1) // self.batch_size
        # Read via .config; direct attribute access on diffusers schedulers is deprecated.
        num_train_timesteps = self.noise_scheduler.config.num_train_timesteps

        epoch_loss = 0.0
        batch_iter = tqdm(range(num_batches), desc="Training Batches", leave=False)
        for b in batch_iter:
            chunk = order[b * self.batch_size : (b + 1) * self.batch_size]
            batch_pairs = [self.all_nodes[i] for i in chunk]

            # 1) Condition embeddings and the matching patches.
            cond_emb_batch, patch_batch = self._get_batch(batch_pairs)
            B = patch_batch.size(0)

            # 2) Encode patches to latents with the frozen VAE.
            with torch.no_grad():
                # Channels-last -> channels-first; cast to float so integer/double patches
                # are accepted by the VAE.
                patch_batch = patch_batch.permute(0, 3, 1, 2).float()
                latents = self.vae.encode(patch_batch).latent_dist.sample()

            # 3) Sample a timestep per item and add the corresponding noise.
            t = torch.randint(0, num_train_timesteps, (B,), device=self.device).long()
            noise = torch.randn_like(latents)
            noisy_latents = self.noise_scheduler.add_noise(latents, noise, t)

            # 4) Predict the noise (the UNet keyword is `timestep`, not `timesteps`).
            model_out = self.unet(
                sample=noisy_latents,
                timestep=t,
                encoder_hidden_states=cond_emb_batch,  # [B, 1, cross_attention_dim]
            )
            pred_noise = model_out.sample

            # 5) Noise-prediction MSE loss and parameter update.
            loss = F.mse_loss(pred_noise, noise)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            loss_value = loss.item()
            epoch_loss += loss_value
            batch_iter.set_postfix(loss=loss_value)

        return epoch_loss / num_batches

    def fit(self):
        """Train for ``max_epochs`` epochs, saving the UNet state dict every 5th epoch."""
        for epoch in tqdm(range(1, self.max_epochs + 1), desc="Epochs"):
            train_loss = self.train_one_epoch()
            print(f"[Epoch {epoch}/{self.max_epochs}] | Train Loss: {train_loss:.6f}")
            if epoch % 5 == 0:
                save_path = os.path.join(self.save_dir, f"unet_epoch_{epoch}.pth")
                torch.save(self.unet.state_dict(), save_path)
                print(f"Saved checkpoint at {save_path}")

# ------------------------------------------------------------------
# 3) Putting it all together (example usage)
# ------------------------------------------------------------------


# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# This cell relies on objects created in earlier cells (they are not defined here):
#   graph_data_dict: dict of {key: PyG Data}
#   vae: the VAE instance passed to the trainer below
#   clip_model: the graph encoder passed as `graph_clip_encoder` below

# UNet configuration. cross_attention_dim is set to 256 to match the graph embedding size.
unet_config = {
    "sample_size": 16,             # e.g. if your latent resolution is 16x16 -- TODO confirm against the VAE
    "in_channels": 16,             # if your VAE latent_channels=16 -- TODO confirm against the VAE
    "out_channels": 16,            # same as in_channels (the UNet predicts noise of the latent's shape)
    "layers_per_block": 2,
    "block_out_channels": (320, 640, 640, 1280),
    # Only the first down block and the last up block use cross-attention.
    "down_block_types": (
        "CrossAttnDownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
    ),
    "up_block_types": (
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "CrossAttnUpBlock2D",
    ),
    # This is crucial: cross_attention_dim must match your GAT embed_dim
    "cross_attention_dim": 256,
}
# Freshly initialised (untrained) UNet built from the config above.
unet = UNet2DConditionModel(**unet_config)

# A basic DDPM scheduler: 1000 training timesteps with a linear beta schedule.
noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.02,
    beta_schedule="linear",
)

# Build the trainer
trainer = GraphConditionedLDMTrainer(
    graph_data_dict=graph_data_dict,
    vae=vae,
    graph_clip_encoder=clip_model,
    unet=unet,
    noise_scheduler=noise_scheduler,
    device=device,
    batch_size=16,
    lr=1e-4,
    max_epochs=20,
    save_dir="./checkpoints_ldm",
    freeze_graph_encoder=True,  # freeze the GAT+CLIP if desired
)

# Train! (checkpoints are written by the trainer into save_dir)
trainer.fit()


In [52]:
# Echo the graph encoder so the notebook renders its repr as the cell output.
clip_model

In [53]:
# Echo the VAE so the notebook renders its repr as the cell output.
vae

## 改进后：仅使用正样本节点（label > 0）训练，并预计算 VAE latent

In [62]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from collections import defaultdict
from tqdm.auto import tqdm

# 假设你已经导入或定义了以下模块：
# - AutoencoderKL (作为 VAE 模型)
# - UNet2DConditionModel (作为扩散模型中的 UNet)
# - GAT_Embedder (你的图编码器)

# -------------------------
# GraphConditionedLDMTrainer 类
# -------------------------
class GraphConditionedLDMTrainer:
    """
    Trains a conditional UNet to predict the noise added to precomputed VAE latents,
    conditioned on graph-encoder node embeddings, using positive nodes (label > 0) only.

    The VAE and the graph encoder are inference-only here: both run under
    ``torch.no_grad()`` and the optimizer holds the UNet parameters exclusively.
    Latents are computed once per graph in ``__init__``; node embeddings are computed
    once per graph on first use and then reused.
    """

    def __init__(
        self,
        graph_data_dict,         # dict of PyG Data objects with data.x, data.edge_index, data.patches, data.y
        vae: nn.Module,          # pretrained VAE (frozen); maps patches to latents
        graph_clip_encoder: nn.Module,  # trained graph encoder, called as encoder(x, edge_index)
        unet: nn.Module,         # conditional UNet; cross_attention_dim must match the encoder output width
        noise_scheduler,         # diffusion scheduler, e.g. DDPMScheduler
        device="cuda",
        batch_size=16,
        lr=1e-4,
        max_epochs=10,
        save_dir="./checkpoints_ldm",
        freeze_graph_encoder=True,
    ):
        """
        Args:
            graph_data_dict: dict {graph_key: Data}. Each Data is documented by the author as:
                             - data.x: [N, in_features]
                             - data.edge_index: [2, E]
                             - data.patches: [N, H, W, 1]  (image patches)
                             - data.y: [N, 1]  (node labels; positives have label > 0)
                             ``data.latents`` is added to every Data by this constructor.
            vae: pretrained VAE used to encode patches into latents.
            graph_clip_encoder: trained graph encoder producing one embedding per node.
            unet: conditional UNet that predicts the noise.
            noise_scheduler: scheduler providing ``config.num_train_timesteps`` and ``add_noise``.
            device: "cuda" or "cpu".
            batch_size: number of positive nodes per mini-batch.
            lr: Adam learning rate for the UNet.
            max_epochs: number of epochs run by ``fit``.
            save_dir: directory for UNet checkpoints (created if missing).
            freeze_graph_encoder: if True, sets ``requires_grad=False`` on the encoder.
                The encoder is not optimized by this trainer either way.
        """
        self.graph_data_dict = graph_data_dict
        self.device = device
        self.batch_size = batch_size
        self.lr = lr
        self.max_epochs = max_epochs

        # Move models to the device. The encoder is put in eval mode like the VAE because it
        # is only used for inference; this also makes its output deterministic, which the
        # embedding cache below relies on.
        self.vae = vae.eval().to(device)
        self.graph_clip_encoder = graph_clip_encoder.eval().to(device)
        self.unet = unet.to(device)
        self.noise_scheduler = noise_scheduler

        os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir

        # Freeze the VAE (it is only used to precompute latents).
        for p in self.vae.parameters():
            p.requires_grad = False
        # Optionally freeze the graph encoder.
        if freeze_graph_encoder:
            for p in self.graph_clip_encoder.parameters():
                p.requires_grad = False

        # The optimizer updates UNet parameters only.
        self.optimizer = Adam(self.unet.parameters(), lr=self.lr)
        # Mixed precision is only meaningful on CUDA; disable it explicitly elsewhere.
        self.use_amp = torch.device(device).type == "cuda"
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)

        # graph_key -> [N, embed_dim] node embeddings, filled lazily by _node_embeddings().
        self._node_emb_cache = {}

        # Precompute the (fixed) VAE latents for every graph.
        self.precompute_latents()
        # Collect positive nodes only (label > 0).
        self.all_nodes = []
        self._prepare_positive_nodes()

    def precompute_latents(self):
        """
        Encode every graph's patches with the VAE and store the result as ``data.latents``.

        Patches are sliced first and moved to the device one chunk at a time, so the full
        patch tensor of a graph never has to sit on the device at once. The sampled latents
        are kept on the device.
        """
        print("预计算 VAE latent 表示...")
        # Small chunk size to bound peak memory during encoding.
        latent_batch_size = 64
        for key, data in self.graph_data_dict.items():
            patches = data.patches  # documented as [N, H, W, 1]
            N = patches.size(0)
            latent_chunks = []
            for start in tqdm(range(0, N, latent_batch_size),
                              desc=f"预计算图 {key} 的 latent", leave=False):
                # Channels-last -> channels-first, cast to float, for this chunk only.
                patch_batch = (
                    patches[start:start + latent_batch_size]
                    .to(self.device)
                    .permute(0, 3, 1, 2)
                    .float()
                )
                with torch.no_grad():
                    with torch.cuda.amp.autocast(enabled=self.use_amp):
                        latent_chunks.append(self.vae.encode(patch_batch).latent_dist.sample())
            # Join the chunks into one tensor per graph.
            data.latents = torch.cat(latent_chunks, dim=0)
            # Hand cached-but-unused blocks back once per graph.
            torch.cuda.empty_cache()
        print("预计算完成。")

    def _prepare_positive_nodes(self):
        """Append (graph_key, node_index) to ``self.all_nodes`` for every node with label > 0."""
        for gkey, data in self.graph_data_dict.items():
            y = data.y.view(-1)  # flatten the label tensor to 1-D
            pos_idx = (y > 0).nonzero(as_tuple=True)[0]
            self.all_nodes.extend((gkey, idx) for idx in pos_idx.tolist())
        print(f"所有图中正样本节点总数：{len(self.all_nodes)}")

    def _node_embeddings(self, gkey):
        """Return the node embeddings of graph ``gkey``, running the encoder only on first use."""
        cached = self._node_emb_cache.get(gkey)
        if cached is None:
            data = self.graph_data_dict[gkey]
            with torch.no_grad():
                cached = self.graph_clip_encoder(
                    data.x.to(self.device), data.edge_index.to(self.device)
                )  # [N, embed_dim]
            self._node_emb_cache[gkey] = cached
        return cached

    def _get_batch(self, indices):
        """
        Build one mini-batch from a list of (graph_key, node_index) pairs.

        Returns:
            cond_emb_batch: [B, 1, embed_dim] condition embeddings (one token per node).
            latent_batch: [B, ...] the matching rows of the precomputed ``data.latents``.

        Rows are ordered by graph (in first-seen order), not by the order of ``indices``;
        embeddings and latents use the same ordering, so pairs stay aligned.
        """
        batch_by_gkey = defaultdict(list)
        for gkey, idx in indices:
            batch_by_gkey[gkey].append(idx)

        cond_emb_list = []
        latent_list = []
        for gkey, node_idxs in batch_by_gkey.items():
            data = self.graph_data_dict[gkey]
            node_idxs_tensor = torch.as_tensor(node_idxs, dtype=torch.long, device=self.device)
            # Add a length-1 token axis for the UNet's cross-attention: [b, 1, embed_dim].
            cond_emb_list.append(self._node_embeddings(gkey)[node_idxs_tensor].unsqueeze(1))
            latent_list.append(data.latents[node_idxs_tensor])

        cond_emb_batch = torch.cat(cond_emb_list, dim=0)
        latent_batch = torch.cat(latent_list, dim=0)
        return cond_emb_batch, latent_batch

    def train_one_epoch(self):
        """
        Run one pass over all positive nodes in random order and return the mean batch loss.

        Raises:
            ValueError: if no positive nodes were collected (previously a ZeroDivisionError).
        """
        if not self.all_nodes:
            raise ValueError("No positive nodes (label > 0) to train on.")

        self.unet.train()
        # Shuffle all positive-node positions for this epoch.
        order = torch.randperm(len(self.all_nodes)).tolist()
        num_batches = (len(order) + self.batch_size - 1) // self.batch_size
        # Read via .config; direct attribute access on diffusers schedulers is deprecated.
        num_train_timesteps = self.noise_scheduler.config.num_train_timesteps

        epoch_loss = 0.0
        batch_iter = tqdm(range(num_batches), desc="Training Batches", leave=False)
        for b in batch_iter:
            chunk = order[b * self.batch_size : (b + 1) * self.batch_size]
            batch_pairs = [self.all_nodes[i] for i in chunk]

            # Condition embeddings and the precomputed latents (the denoising targets' source).
            cond_emb_batch, latents = self._get_batch(batch_pairs)
            B = latents.size(0)

            # Sample a timestep per item and add the corresponding noise to the latents.
            t = torch.randint(0, num_train_timesteps, (B,), device=self.device).long()
            noise = torch.randn_like(latents)
            noisy_latents = self.noise_scheduler.add_noise(latents, noise, t)

            self.optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=self.use_amp):
                # The UNet keyword is `timestep`, not `timesteps`.
                model_out = self.unet(
                    sample=noisy_latents,
                    timestep=t,
                    encoder_hidden_states=cond_emb_batch,  # [B, 1, cross_attention_dim]
                )
                pred_noise = model_out.sample
                loss = F.mse_loss(pred_noise, noise)

            # Scaled backward pass and optimizer step (plain step when AMP is disabled).
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            loss_value = loss.item()
            epoch_loss += loss_value
            batch_iter.set_postfix(loss=loss_value)

        return epoch_loss / num_batches

    def fit(self):
        """Train for ``max_epochs`` epochs, saving the UNet state dict every 5th epoch."""
        for epoch in tqdm(range(1, self.max_epochs + 1), desc="Epochs"):
            train_loss = self.train_one_epoch()
            print(f"[Epoch {epoch}/{self.max_epochs}] | Train Loss: {train_loss:.6f}")
            if epoch % 5 == 0:
                save_path = os.path.join(self.save_dir, f"unet_epoch_{epoch}.pth")
                torch.save(self.unet.state_dict(), save_path)
                print(f"Saved checkpoint at {save_path}")

# -------------------------
# 示例使用代码
# -------------------------

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Make sure the following variables have been loaded or initialised before running this cell:
# graph_data_dict: dict, e.g. { 'graph1': data1, 'graph2': data2, ... }
#   each data provides data.x, data.edge_index, data.y, data.patches
# vae: the VAE model with its weights already loaded
# graph_clip_encoder: the trained graph encoder (passed below as `clip_model`)
# unet: a UNet2DConditionModel whose cross_attention_dim matches the graph encoder's output width
# noise_scheduler: e.g. a DDPMScheduler instance
#
# Example:
# from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
#
# unet = UNet2DConditionModel(**unet_config)
# noise_scheduler = DDPMScheduler(...)
# vae = AutoencoderKL(...)  # and load the pretrained weights
#
# graph_data_dict is produced by the data preprocessing module.

# NOTE(review): `unet` and `noise_scheduler` are not created in this cell; they are whatever
# objects are currently bound in the notebook namespace -- confirm this is the intended UNet.
trainer = GraphConditionedLDMTrainer(
    graph_data_dict=graph_data_dict,
    vae=vae,
    graph_clip_encoder=clip_model,
    unet=unet,
    noise_scheduler=noise_scheduler,
    device=device,
    batch_size=16,
    lr=1e-4,
    max_epochs=20,
    save_dir="./checkpoints_ldm",
    freeze_graph_encoder=True,  # freeze the graph encoder to speed up training
)

# Run training; the trainer writes UNet checkpoints into save_dir.
trainer.fit()


In [63]:
# Ask PyTorch's CUDA caching allocator to release memory blocks it holds but is not using.
torch.cuda.empty_cache()

In [None]:
import random
import torch
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import numpy as np

def generate_predicted_image(trainer, graph_key, node_index, num_inference_steps=50):
    """
    Sample one image patch for a node by running the reverse diffusion loop.

    Parameters:
      - trainer: object exposing device, graph_data_dict, graph_clip_encoder, unet,
        noise_scheduler and vae (a GraphConditionedLDMTrainer with precomputed latents).
      - graph_key: key of the target graph in trainer.graph_data_dict.
      - node_index: int, index of the node whose embedding conditions the sampling.
      - num_inference_steps: int, number of scheduler steps used for denoising.

    Returns:
      - The decoded image; `.sample` of the VAE decode output when that attribute exists,
        otherwise the decode output itself.

    Side effects: switches trainer.unet to eval mode and resets the scheduler's timesteps.
    """
    dev = trainer.device
    graph = trainer.graph_data_dict[graph_key]

    # Borrow the shape (with a batch axis of 1) from this node's precomputed latent.
    noise_shape = graph.latents[node_index:node_index+1].shape

    # Embed every node of the graph, then keep only the requested node as the condition.
    with torch.no_grad():
        all_embeddings = trainer.graph_clip_encoder(graph.x.to(dev), graph.edge_index.to(dev))
    cond = all_embeddings[node_index][None, None]  # [1, 1, embed_dim]

    # Start from pure Gaussian noise in latent space.
    latent = torch.randn(noise_shape, device=dev)

    trainer.unet.eval()
    scheduler = trainer.noise_scheduler
    scheduler.set_timesteps(num_inference_steps)

    progress = tqdm(
        scheduler.timesteps,
        desc=f"Sampling (Graph: {graph_key} Node: {node_index})",
        leave=False,
    )
    with torch.no_grad():
        # Walk the scheduler's timesteps, removing the predicted noise at each step.
        for t in progress:
            step_t = torch.tensor([t], device=dev).long()
            with torch.cuda.amp.autocast():
                unet_out = trainer.unet(
                    sample=latent,
                    timestep=step_t,
                    encoder_hidden_states=cond,
                )
            eps = unet_out.sample  # predicted noise
            latent = scheduler.step(eps, t, latent)["prev_sample"]

        # Map the final latent back to image space.
        decoded = trainer.vae.decode(latent)

    return getattr(decoded, "sample", decoded)

def get_original_image(data, node_index):
    """
    Retrieve the original image patch for one node as a float, channels-first batch of 1.

    Parameters:
      - data: graph data object exposing data.patches, indexable by node and yielding a
        patch laid out as [H, W, C] (documented by the author as [N, H, W, 1]).
      - node_index: int, the index of the target node.

    Returns:
      - A new float32 tensor of shape [1, C, H, W]. It is always a copy, so modifying it
        does not touch data.patches.
    """
    patch = data.patches[node_index]  # [H, W, C]
    if torch.is_tensor(patch):
        # torch.tensor(existing_tensor) emits a copy-construct UserWarning; clone instead.
        patch = patch.detach().clone()
    else:
        # Array-like input (e.g. a NumPy array) is converted with a copy.
        patch = torch.tensor(patch)
    # [H, W, C] -> [C, H, W] -> [1, C, H, W]
    return patch.permute(2, 0, 1).unsqueeze(0).float()

def postprocess_image(image_tensor):
    """
    Turn an image tensor into a min-max normalised, channels-last NumPy array for plotting.

    Parameters:
      - image_tensor: tensor of shape [1, C, H, W].

    Returns:
      - NumPy array of shape [H, W, C]; a single-channel input is tiled to 3 channels.
        Values are rescaled by the image's own min and max (a 1e-8 term guards the division).
    """
    chw = image_tensor.squeeze(0).cpu()  # drop the batch axis -> [C, H, W]
    if chw.shape[0] == 1:
        # Grayscale: stack the single channel three times so imshow gets an RGB image.
        chw = torch.cat([chw, chw, chw], dim=0)
    hwc = chw.permute(1, 2, 0).numpy()  # [H, W, C]
    lo = hwc.min()
    hi = hwc.max()
    return (hwc - lo) / (hi - lo + 1e-8)

def visualize_comparison(trainer, num_samples=5, num_inference_steps=50):
    """
    Randomly pick nodes from trainer.all_nodes and plot original vs. generated patches.

    The figure has 2 rows and num_samples columns:
      - The first row shows the original patch of each selected node.
      - The second row shows the patch generated for the same node.

    Parameters:
      - trainer: object exposing all_nodes (list of (graph_key, node_index)) and
        graph_data_dict, and accepted by generate_predicted_image.
      - num_samples: int, number of nodes (columns) to display; works for 1 as well.
      - num_inference_steps: int, number of reverse-diffusion steps per generated patch.

    Raises:
      - ValueError (from random.sample) if num_samples exceeds len(trainer.all_nodes).
    """
    selected_nodes = random.sample(trainer.all_nodes, num_samples)

    # squeeze=False keeps `axes` two-dimensional even when num_samples == 1, so the
    # axes[row, col] indexing below is always valid.
    fig, axes = plt.subplots(2, num_samples, figsize=(4*num_samples, 8), squeeze=False)

    for col, (graph_key, node_index) in enumerate(selected_nodes):
        data = trainer.graph_data_dict[graph_key]

        # Original patch, converted for display.
        original_np = postprocess_image(get_original_image(data, node_index))

        # Patch sampled from the diffusion model for the same node.
        predicted = generate_predicted_image(
            trainer, graph_key, node_index, num_inference_steps=num_inference_steps
        )
        predicted_np = postprocess_image(predicted)

        # Top row: original.
        axes[0, col].imshow(original_np)
        axes[0, col].set_title(f"Original\n(Graph: {graph_key}, Node: {node_index})", fontsize=12)
        axes[0, col].axis("off")

        # Bottom row: prediction.
        axes[1, col].imshow(predicted_np)
        axes[1, col].set_title(f"Predicted\n(Graph: {graph_key}, Node: {node_index})", fontsize=6)
        axes[1, col].axis("off")

    plt.tight_layout()
    plt.show()

# Example usage:
if __name__ == '__main__':
    # Relies on `trainer` already existing in the notebook namespace, with
    # trainer.all_nodes being a list of tuples: [(graph_key, node_index), ...].
    # Shows 5 random nodes, using 100 reverse-diffusion steps per generated patch.
    visualize_comparison(trainer, num_samples=5, num_inference_steps=100)


In [75]:
import random
import torch
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import numpy as np

def generate_predicted_image(trainer, graph_key, node_index, num_inference_steps=50):
    """
    Sample one image patch for a node by running the reverse diffusion loop.

    Parameters:
      - trainer: object exposing device, graph_data_dict, graph_clip_encoder, unet,
        noise_scheduler and vae (a GraphConditionedLDMTrainer with precomputed latents).
      - graph_key: key of the target graph in trainer.graph_data_dict.
      - node_index: int, index of the node whose embedding conditions the sampling.
      - num_inference_steps: int, number of scheduler steps used for denoising.

    Returns:
      - The decoded image; `.sample` of the VAE decode output when that attribute exists,
        otherwise the decode output itself.

    Side effects: switches trainer.unet to eval mode and resets the scheduler's timesteps.
    """
    dev = trainer.device
    graph = trainer.graph_data_dict[graph_key]

    # Borrow the shape (with a batch axis of 1) from this node's precomputed latent.
    noise_shape = graph.latents[node_index:node_index+1].shape

    # Embed every node of the graph, then keep only the requested node as the condition.
    with torch.no_grad():
        all_embeddings = trainer.graph_clip_encoder(graph.x.to(dev), graph.edge_index.to(dev))
    cond = all_embeddings[node_index][None, None]  # [1, 1, embed_dim]

    # Start from pure Gaussian noise in latent space.
    latent = torch.randn(noise_shape, device=dev)

    trainer.unet.eval()
    scheduler = trainer.noise_scheduler
    scheduler.set_timesteps(num_inference_steps)

    progress = tqdm(
        scheduler.timesteps,
        desc=f"Sampling (Graph: {graph_key} Node: {node_index})",
        leave=False,
    )
    with torch.no_grad():
        # Walk the scheduler's timesteps, removing the predicted noise at each step.
        for t in progress:
            step_t = torch.tensor([t], device=dev).long()
            with torch.cuda.amp.autocast():
                unet_out = trainer.unet(
                    sample=latent,
                    timestep=step_t,
                    encoder_hidden_states=cond,
                )
            eps = unet_out.sample  # predicted noise
            latent = scheduler.step(eps, t, latent)["prev_sample"]

        # Map the final latent back to image space.
        decoded = trainer.vae.decode(latent)

    return getattr(decoded, "sample", decoded)

def get_original_image(data, node_index):
    """
    Retrieve the original image patch for one node as a float, channels-first batch of 1.

    Parameters:
      - data: graph data object exposing data.patches, indexable by node and yielding a
        patch laid out as [H, W, C] (documented by the author as [N, H, W, 1]).
      - node_index: int, the index of the target node.

    Returns:
      - A new float32 tensor of shape [1, C, H, W]. It is always a copy, so modifying it
        does not touch data.patches.
    """
    patch = data.patches[node_index]  # [H, W, C]
    if torch.is_tensor(patch):
        # torch.tensor(existing_tensor) emits a copy-construct UserWarning; clone instead.
        patch = patch.detach().clone()
    else:
        # Array-like input (e.g. a NumPy array) is converted with a copy.
        patch = torch.tensor(patch)
    # [H, W, C] -> [C, H, W] -> [1, C, H, W]
    return patch.permute(2, 0, 1).unsqueeze(0).float()

def postprocess_image(image_tensor):
    """
    Turn an image tensor into a min-max normalised, channels-last NumPy array for plotting.

    Parameters:
      - image_tensor: tensor of shape [1, C, H, W].

    Returns:
      - NumPy array of shape [H, W, C]; a single-channel input is tiled to 3 channels.
        Values are rescaled by the image's own min and max (a 1e-8 term guards the division).
    """
    chw = image_tensor.squeeze(0).cpu()  # drop the batch axis -> [C, H, W]
    if chw.shape[0] == 1:
        # Grayscale: stack the single channel three times so imshow gets an RGB image.
        chw = torch.cat([chw, chw, chw], dim=0)
    hwc = chw.permute(1, 2, 0).numpy()  # [H, W, C]
    lo = hwc.min()
    hi = hwc.max()
    return (hwc - lo) / (hi - lo + 1e-8)

def visualize_comparison(trainer, num_samples=5, num_inference_steps=50):
    """
    Randomly select num_samples nodes and visualize their original and predicted image patches.
    The visualization is arranged in 2 rows and num_samples columns:
      - The first row shows the Original images for the selected samples.
      - The second row shows the corresponding Predicted images.
      
    Parameters:
      - trainer: The GraphConditionedLDMTrainer instance. This function reads
        trainer.all_nodes (a sequence of (graph_key, node_index) tuples) and
        trainer.graph_data_dict (graph_key -> graph data with .patches).
      - num_samples: int, number of samples (columns) to display. Must be between 1 and
        len(trainer.all_nodes), because nodes are drawn without replacement.
      - num_inference_steps: int, number of inference steps during reverse diffusion sampling.

    Raises:
      - ValueError: if num_samples is outside [1, len(trainer.all_nodes)].

    Note:
      Nodes are drawn with the global `random` module, so the selection depends on its
      current state. The figure is 4 * num_samples inches wide.
    """
    population = len(trainer.all_nodes)
    if not 1 <= num_samples <= population:
        raise ValueError(
            f"num_samples must be between 1 and {population} "
            f"(len(trainer.all_nodes)), got {num_samples}"
        )
    selected_nodes = random.sample(trainer.all_nodes, num_samples)
    
    # squeeze=False keeps `axes` 2-D even when num_samples == 1, so axes[row, col] always works.
    fig, axes = plt.subplots(2, num_samples, figsize=(4 * num_samples, 8), squeeze=False)
    
    for col, (graph_key, node_index) in enumerate(selected_nodes):
        data = trainer.graph_data_dict[graph_key]
        
        # Ground-truth patch, normalized for display.
        original_np = postprocess_image(get_original_image(data, node_index))
        
        # Patch generated by the reverse diffusion process for the same node.
        predicted = generate_predicted_image(
            trainer, graph_key, node_index, num_inference_steps=num_inference_steps
        )
        predicted_np = postprocess_image(predicted)
        
        # Shorten long graph keys to "<first>...<last>" of their dash-separated parts.
        key_parts = graph_key.split("-")
        graph_name = key_parts[0] + "..." + key_parts[-1]
        
        # Row 0: original, row 1: predicted, both in column `col`.
        for row, (label, panel) in enumerate((("Original", original_np), ("Predicted", predicted_np))):
            ax = axes[row, col]
            ax.imshow(panel)
            ax.set_title(f"{label}\n(Graph: {graph_name}, Node: {node_index})", fontsize=12)
            ax.axis("off")
    
    plt.tight_layout()
    plt.show()

# Demo invocation. It relies on `trainer` already existing in the session, with
# trainer.graph_data_dict populated and trainer.all_nodes holding
# (graph_key, node_index) tuples.
if __name__ == "__main__":
    visualize_comparison(
        trainer,
        num_samples=5,
        num_inference_steps=100,
    )


In [77]:
# Wider panel: 20 randomly drawn nodes, 100 inference steps each (figure width is 4 * num_samples inches).
visualize_comparison(trainer, num_samples=20, num_inference_steps=100)

In [78]:
# 50 randomly drawn nodes at 100 inference steps each; nodes are re-drawn on every call.
visualize_comparison(trainer, num_samples=50, num_inference_steps=100)

In [79]:
# Same sample count as the previous cell but only 20 inference steps, to compare against the 100-step run.
visualize_comparison(trainer, num_samples=50, num_inference_steps=20)

In [81]:
# Longer sampling schedule: 10 randomly drawn nodes with 200 inference steps each.
visualize_comparison(trainer, num_samples=10, num_inference_steps=200)

In [83]:
# Same settings as the example-usage call (5 nodes, 100 steps); the nodes are drawn at random, so the panel may differ.
visualize_comparison(trainer, num_samples=5, num_inference_steps=100)