In [7]:
import scanpy as sc
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import skimage.io as io
import random
from torch_geometric.data import Data

In [8]:
# Switch the working directory to the project checkout (presumably so that the
# `src.*` imports in the following cells resolve — confirm).
# NOTE(review): absolute, machine-specific path; adjust when running elsewhere.
PROJECT_ROOT = '/public/home/jijh/diffusion_project/ADiffusion'
os.chdir(PROJECT_ROOT)

In [9]:
# Import the preprocessing module and immediately reload it, so that the
# module object reflects the current source on disk even if it was already
# imported earlier in this kernel session.
import importlib
import src.preprocessing.data_process
importlib.reload(src.preprocessing.data_process)

In [10]:
from src.preprocessing.data_process import extract_patches, create_graph_data_dict, construct_affinity_matrix

In [11]:
def seed_everything(seed=0):
    """Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed: Integer seed applied to every generator. Defaults to 0.
    """
    for apply_seed in (
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
    ):
        apply_seed(seed)


# Fix the seed once, up front, for the whole notebook
random_seed = 0
seed_everything(random_seed)

# Data Preprocessing

In [176]:
from tqdm import tqdm

# Directory containing the per-sample .h5ad files
file_dir = "/public/home/jijh/st_project/cellbin_analysis/spatial_variation/wx_data/"

# Keep only "*.h5ad" files whose name contains "month".
# os.listdir returns entries in arbitrary order, so sort them: the insertion
# order of `adatas` (and of everything derived from it) is then reproducible.
files = sorted(f for f in os.listdir(file_dir) if f.endswith(".h5ad") and "month" in f)
file_paths = [os.path.join(file_dir, f) for f in files]

# Map "<file name up to the first dot>" -> AnnData
adatas = {}
for file_name, file_path in zip(files, file_paths):
    adatas[file_name.split(".")[0]] = sc.read(file_path)

# Preprocess each AnnData object in place
for key in tqdm(adatas.keys(), desc="Preprocessing datasets"):
    sc.pp.normalize_total(adatas[key], target_sum=1e4)  # Normalize counts per cell
    sc.pp.log1p(adatas[key])  # Logarithmize the data
    # NOTE: despite its name, the "raw" layer is a snapshot of X taken *after*
    # normalize_total + log1p (i.e. log-normalized values, not raw counts).
    adatas[key].layers["raw"] = adatas[key].X.copy()
    # sc.pp.scale(adatas[key], max_value=10)  # Scale the data to have a maximum value of 10
    # sc.tl.pca(adatas[key], svd_solver="arpack")  # Perform PCA

In [177]:
# Per-sample copy of the spatial coordinates stored in obsm["spatial"]
cell_coords = {
    sample_key: adata.obsm["spatial"].copy()
    for sample_key, adata in adatas.items()
}

In [178]:
# One affinity matrix per sample, built from the cell coordinates
# (radius mode with a cutoff of 100, in the units of the coordinates).
neighbors = {
    sample_key: construct_affinity_matrix(coords, mode='radius', cutoff=100)
    for sample_key, coords in cell_coords.items()
}

In [179]:

# Load the plaque dataset: every "*.tiff" file in img_dir whose name contains "plaque"
img_dir = "/public/home/jijh/st_project/cellbin_analysis/spatial_variation/wx_data/protein_seg_result/"
img_files = os.listdir(img_dir)
img_files = [i for i in img_files if i.endswith(".tiff") and "plaque" in i]

# Read the images, keyed by the file name up to the first dot
imgs = {}
for i in range(len(img_files)):
    imgs[img_files[i].split(".")[0]] = io.imread(os.path.join(img_dir, img_files[i]))

# Rename the images to match the adata keys: "<a>_<b>_<c>..." -> "<b>_<c>".
# Both parts[1] and parts[2] are read, so at least three "_"-separated parts
# are required; names with fewer parts are left untouched instead of raising
# IndexError.
for key in list(imgs.keys()):
    parts = key.split("_")
    if len(parts) > 2:
        new_key = parts[1] + "_" + parts[2]
        imgs[new_key] = imgs.pop(key)



In [180]:
# One list of 128-pixel patches per image, extracted around each cell coordinate
patches = {
    sample_key: extract_patches(image, cell_coords[sample_key], patch_size=128)
    for sample_key, image in imgs.items()
}

# Boolean version of every patch: True wherever the pixel value is positive
binary_patches = {
    sample_key: [patch > 0 for patch in sample_patches]
    for sample_key, sample_patches in patches.items()
}

# Number of positive pixels in each patch
areas = {
    sample_key: [np.sum(binary_patch) for binary_patch in sample_binary]
    for sample_key, sample_binary in binary_patches.items()
}

In [181]:
# Choose 5 patches with non-zero area for each image and plot them
import matplotlib.pyplot as plt

# Number of example patches to draw per image (one image per row)
n_patches = 5

# Plot the patches
fig, axes = plt.subplots(nrows=len(areas), ncols=n_patches, figsize=(15, 15))
axes = np.atleast_2d(axes)  # Ensure axes is 2D for consistent indexing

for i, key in enumerate(areas.keys()):
    non_zero_indices = np.where(np.array(areas[key]) > 0)[0]
    # random.sample raises ValueError when asked for more items than exist,
    # so never request more patches than there are non-empty ones.
    n_shown = min(n_patches, len(non_zero_indices))
    chosen_indices = random.sample(list(non_zero_indices), n_shown)
    for j in range(n_patches):
        ax = axes[i, j]
        if j < n_shown:
            idx = chosen_indices[j]
            ax.imshow(patches[key][idx], cmap="gray")
            ax.set_title(f"Area: {areas[key][idx]}")
        # Hide ticks and frame by hand: axis("off") would also suppress the
        # y-label that names the row below.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    axes[i, 0].set_ylabel(key, rotation=0, size="large", labelpad=50)

In [182]:
# Build one graph object per sample, using the "X" matrix as node embedding
graph_data_dict = create_graph_data_dict(
    adatas, areas, neighbors, cell_coords, embeddings=["X"]
)

# Attach the image patches to each graph. The per-cell patches are stacked
# into one NumPy array first so the tensor is created in a single conversion.
for sample_key, sample_graph in tqdm(graph_data_dict.items(), desc="Adding patches to graph data"):
    stacked_patches = np.array(patches[sample_key])
    sample_graph.patches = torch.tensor(stacked_patches, dtype=torch.float)

In [183]:
# Rescale every graph's edge attributes by that graph's own maximum value
for sample_graph in tqdm(graph_data_dict.values(), desc="Normalizing edge_attr"):
    largest = sample_graph.edge_attr.max()
    sample_graph.edge_attr = sample_graph.edge_attr / largest

In [184]:
graph_data_dict

In [88]:
# Histogram of all node-feature values, pooled over every graph
node_features = torch.cat([g.x for g in graph_data_dict.values()], dim=0)

fig_hist, ax_hist = plt.subplots(figsize=(10, 5))
ax_hist.hist(node_features.numpy().flatten(), bins=100)
ax_hist.set_xlabel("Node feature values")
ax_hist.set_ylabel("Frequency")
ax_hist.set_title("Distribution of node features")
plt.show()


# Graph Autoregressive Construction

In [137]:
import torch
import torch.nn as nn
from torch_geometric.nn import GATv2Conv, LayerNorm
from tqdm import tqdm

class GATv2EncoderLayer(nn.Module):
    """One encoder stage: GATv2 attention, head mixing, then a residual MLP.

    Pipeline of `forward`:
        GATv2Conv (heads concatenated) -> norm1 -> dropout -> GELU
        -> Linear(out_channels * heads -> out_channels)
        -> MLP with skip connection -> norm2 -> dropout

    Args:
        in_channels: Width of the incoming node features.
        out_channels: Width of the node features produced by this layer.
        heads: Number of attention heads of the GATv2 convolution.
        edge_dim: Edge feature width passed to GATv2Conv (None for no edge features).
        dropout: Dropout probability, used inside GATv2Conv and after both norms.
    """

    def __init__(self, in_channels, out_channels, heads=4, edge_dim=None, dropout=0.1):
        super().__init__()
        concat_width = out_channels * heads  # width of the concatenated head outputs

        self.gatv2 = GATv2Conv(
            in_channels,
            out_channels,
            heads=heads,
            edge_dim=edge_dim,
            concat=True,
            dropout=dropout,
            add_self_loops=True,
            bias=True,
            residual=True,
        )

        # Learnable projection that folds the concatenated heads back to out_channels
        self.heads_transform = nn.Linear(concat_width, out_channels)

        self.mlp = nn.Sequential(
            nn.Linear(out_channels, out_channels),
            nn.GELU(),
            nn.Linear(out_channels, out_channels),
        )

        # NOTE(review): this is torch_geometric's LayerNorm, whose default mode
        # is believed to normalize over the whole graph rather than per node —
        # confirm that this is the intended behaviour.
        self.norm1 = LayerNorm(concat_width)  # applied to the GATv2 output
        self.norm2 = LayerNorm(out_channels)  # applied after the residual MLP
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x, edge_index, edge_attr):
        """Encode node features `x` given the graph connectivity and edge features."""
        # Attention block
        hidden = self.gatv2(x, edge_index, edge_attr)
        hidden = self.activation(self.dropout(self.norm1(hidden)))

        # Mix the heads down to out_channels
        hidden = self.heads_transform(hidden)

        # Residual MLP block
        hidden = self.norm2(self.mlp(hidden) + hidden)
        return self.dropout(hidden)


class MaskedNodePredictorWithEncoder(nn.Module):
    """Stack of GATv2 encoder layers followed by an MLP that reconstructs node features.

    The model does not corrupt the input itself: whatever is in `data.x` is
    encoded as-is, and predictions are returned only for the rows selected
    by `mask`.

    Args:
        in_features: Width of the input node features (also the output width).
        hidden_channels: Width of the hidden node embeddings.
        edge_dim: Edge feature width forwarded to each encoder layer.
        heads: Number of attention heads per encoder layer.
        num_encoders: Number of stacked encoder layers.
        dropout: Dropout probability forwarded to each encoder layer.
    """

    def __init__(self, in_features, hidden_channels, edge_dim=None, heads=4, num_encoders=2, dropout=0.1):
        super().__init__()

        # The first layer consumes the raw features, later layers the hidden width
        input_widths = [in_features if depth == 0 else hidden_channels for depth in range(num_encoders)]
        self.encoders = nn.ModuleList(
            [
                GATv2EncoderLayer(
                    in_channels=width,
                    out_channels=hidden_channels,
                    heads=heads,
                    edge_dim=edge_dim,
                    dropout=dropout,
                )
                for width in input_widths
            ]
        )

        # Prediction head: maps embeddings back to the original feature space
        self.predictor = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.GELU(),
            nn.Linear(hidden_channels, in_features),
        )

    def forward(self, data, mask):
        """Return predicted features for the nodes selected by `mask`.

        Args:
            data: Object exposing `x`, `edge_index` and `edge_attr`.
            mask: Index/boolean selector applied to the node dimension.
        """
        # Work on a copy so the caller's feature tensor is never touched
        hidden = data.x.clone()

        for layer in self.encoders:
            hidden = layer(hidden, data.edge_index, data.edge_attr)

        # Only the selected nodes are passed through the prediction head
        return self.predictor(hidden[mask])



In [138]:
# Weighted MSE loss
class WeightedMSELoss(nn.Module):
    """Mean squared error in which non-zero targets count `pos_weight` times.

    Args:
        pos_weight: Weight of entries whose target is non-zero; entries with
            a zero target have weight 1.
    """

    def __init__(self, pos_weight=10.0):
        super().__init__()
        self.pos_weight = pos_weight

    def forward(self, prediction, target):
        """Return the mean of the element-wise weighted squared errors."""
        is_nonzero = (target != 0).float()
        is_zero = (target == 0).float()
        weight = is_nonzero * self.pos_weight + is_zero
        squared_error = (prediction - target) ** 2
        return (weight * squared_error).mean()
    


In [113]:
import torch
import torch.nn as nn
from torch_geometric.nn import GATv2Conv, LayerNorm
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import random

# Objective: Mimic BERT-style training (Masked Node Prediction)

# `device` is needed throughout this cell; define it here so the cell does not
# rely on a later cell having been executed first.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset preparation: train on a private copy of one sample's graph
data = graph_data_dict['13months-disease-replicate_1'].clone().to(device)

# Uncorrupted node features. Targets are always read from this tensor and each
# epoch's corruption starts from it, so masking never accumulates over epochs.
clean_x = data.x

# Extract data dimensions
num_nodes = clean_x.size(0)
in_features = clean_x.size(1)
hidden_channels = 64
edge_dim = data.edge_attr.size(1) if data.edge_attr is not None else None
heads = 4

# Initialize model
masked_pre_model = MaskedNodePredictorWithEncoder(
    in_features=in_features,
    hidden_channels=hidden_channels,
    edge_dim=edge_dim,
    heads=heads,
    num_encoders=4,
    dropout=0.2,
).to(device)

# Learning rate schedule parameters
warmup_epochs = 200   # linear warmup from warmup_lr to initial_lr
initial_lr = 1e-4     # learning rate reached at the end of warmup
warmup_lr = 1e-6      # learning rate at epoch 0
epochs = 500
eta_min = 1e-7        # floor of the cosine schedule

# Initialize optimizer with warmup_lr
optimizer = torch.optim.Adam(masked_pre_model.parameters(), lr=warmup_lr)

# Cosine schedule for the epochs that follow the warmup
scheduler = CosineAnnealingLR(optimizer, T_max=epochs - warmup_epochs, eta_min=eta_min)

# Training loop with BERT-style masking
loss_history = []        # exponentially smoothed loss, one entry per epoch
lr_history = []          # learning rate used in each epoch
smoothing_factor = 0.9
smoothed_loss = None
# Weighted MSE loss: non-zero targets count 10x
criterion = WeightedMSELoss(pos_weight=10.0)

for epoch in tqdm(range(epochs), desc="Training", leave=True):
    masked_pre_model.train()
    optimizer.zero_grad()

    # Warmup learning rate (set by hand), then hand over to the cosine scheduler
    if epoch < warmup_epochs:
        lr = warmup_lr + (initial_lr - warmup_lr) * epoch / warmup_epochs
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    else:
        scheduler.step()

    # Record current learning rate
    lr_history.append(optimizer.param_groups[0]['lr'])

    # Randomly select 15% of the nodes to mask; targets come from the clean features
    mask = torch.rand(num_nodes, device=device) < 0.15
    target = clean_x[mask]

    # Apply BERT-style masking strategy, one draw per node:
    #   80% -> replace the node's features with the [MASK] token (all zeros)
    #   10% -> replace them with standard-normal noise
    #   10% -> leave them unchanged
    roll = torch.rand(num_nodes, device=device)
    zero_rows = mask & (roll < 0.8)
    noise_rows = mask & (roll >= 0.8) & (roll < 0.9)
    modified_data_x = clean_x.clone()
    modified_data_x[zero_rows] = 0
    modified_data_x[noise_rows] = torch.randn_like(modified_data_x[noise_rows])

    # Feed the corrupted features through the model, then put the clean
    # features back so the next epoch masks and targets uncorrupted data.
    data.x = modified_data_x
    try:
        predictions = masked_pre_model(data, mask)
    finally:
        data.x = clean_x

    # Compute weighted loss only for masked nodes
    loss = criterion(predictions, target)

    # Smooth loss using exponential moving average
    if smoothed_loss is None:
        smoothed_loss = loss.item()
    else:
        smoothed_loss = smoothing_factor * smoothed_loss + (1 - smoothing_factor) * loss.item()

    # Backward pass and optimization
    loss.backward()
    optimizer.step()

    # Log smoothed loss
    loss_history.append(smoothed_loss)

    # Report progress every 50 epochs
    if epoch % 50 == 0:
        tqdm.write(f"Epoch {epoch}/{epochs}, Loss: {smoothed_loss:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}")

In [94]:
# Plot learning rate schedule
import matplotlib.pyplot as plt
# Learning-rate schedule over the epochs
fig_lr, ax_lr = plt.subplots(figsize=(8, 4))
ax_lr.plot(lr_history, label='Learning Rate')
ax_lr.set_xlabel('Epoch')
ax_lr.set_ylabel('Learning Rate')
ax_lr.set_title('Learning Rate Schedule')
ax_lr.legend()
plt.show()

# Smoothed training loss over the epochs
fig_loss, ax_loss = plt.subplots(figsize=(10, 5))
ax_loss.plot(loss_history, label='Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()
ax_loss.set_title('Loss History')
plt.show()


In [95]:
# Copy the encoder part with parameters of the masked_pre_model as a new model

import torch
import torch.nn as nn

class EncoderOnly(nn.Module):
    """Wrapper that runs only a given stack of encoder layers.

    The layers passed in are stored as-is (no copy is made), so the wrapper
    shares its parameters with whatever model the layers came from.
    """

    def __init__(self, encoders):
        """
        Args:
            encoders (nn.ModuleList): Encoder layers, applied in order.
        """
        super().__init__()
        self.encoders = encoders

    def forward(self, x, edge_index, edge_attr):
        """
        Pass node features through every encoder layer in sequence.

        Args:
            x (Tensor): Node feature matrix.
            edge_index (Tensor): Graph connectivity.
            edge_attr (Tensor): Edge feature matrix.

        Returns:
            Tensor: Output of the last encoder layer (`x` itself if there are no layers).
        """
        hidden = x
        for layer in self.encoders:
            hidden = layer(hidden, edge_index, edge_attr)
        return hidden

# Wrap the encoder layers of the trained masked_pre_model. The layers are
# shared with masked_pre_model, not copied.
encoder_model = EncoderOnly(encoders=masked_pre_model.encoders)
encoder_model = encoder_model.to(device)

# Inference mode (disables dropout); because the layers are shared, this also
# switches masked_pre_model's encoder layers to eval mode.
encoder_model.eval()

# Sanity check: the wrapper's parameters equal the leading parameters of
# masked_pre_model (zip stops once the wrapper's parameters are exhausted).
for wrapped_param, source_param in zip(encoder_model.parameters(), masked_pre_model.parameters()):
    assert torch.equal(wrapped_param, source_param), "Parameters do not match!"



In [148]:
from src.visualization.validate_prediction import collect_predictions, visualize_all


In [97]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

In [114]:
# Collect predictions and targets from 10 evaluation rounds with 15% of the
# nodes masked, on a fresh copy of the training sample's graph
evaluation_settings = dict(num_evaluations=10, mask_percentage=0.15, device=device)
all_predictions, all_targets = collect_predictions(
    model=masked_pre_model,
    data=graph_data_dict['13months-disease-replicate_1'].clone().to(device),
    **evaluation_settings,
)


In [115]:
all_predictions.shape

In [116]:
# Perform all visualizations: hand the collected predictions and targets to
# visualize_all (imported from src.visualization.validate_prediction)
visualize_all(all_predictions, all_targets)

In [36]:
# Perform all visualizations
# NOTE(review): identical to the previous cell and carrying an older
# execution count — looks like a stale duplicate; candidate for removal.
visualize_all(all_predictions, all_targets)

In [89]:
# Work on a clone of the sample's graph moved to `device`, so that changes to
# `data` do not touch the object stored in graph_data_dict.
data = graph_data_dict['13months-disease-replicate_1'].clone().to(device)

In [90]:
data.x[0, :]  # feature vector of the first node (not displayed: not the cell's last expression)

# Histogram of a single node-feature column across all nodes
node_feature_idx = 0
feature_values = data.x[:, node_feature_idx].cpu().numpy()

fig_feat, ax_feat = plt.subplots(figsize=(10, 5))
ax_feat.hist(feature_values, bins=100)
ax_feat.set_xlabel(f"Node feature {node_feature_idx} values")
ax_feat.set_ylabel("Frequency")
ax_feat.set_title(f"Distribution of node feature {node_feature_idx}")
plt.show()


In [47]:
# Pure inference: disable gradient tracking so no autograd graph is kept
# alive for the whole graph (the result is only converted to NumPy later).
with torch.no_grad():
    encoded_embeddings = encoder_model(data.x, data.edge_index, data.edge_attr)

In [56]:
encoded_embeddings.shape

In [52]:
test_adata

In [72]:
def _neighbors_umap_leiden(adata, rep_key, umap_key, leiden_key):
    """Build the kNN graph on `rep_key`, then store a UMAP copy and Leiden labels.

    The UMAP coordinates are copied to `adata.obsm[umap_key]` because every
    later call overwrites `adata.obsm["X_umap"]`.
    """
    sc.pp.neighbors(adata, use_rep=rep_key, n_neighbors=10)
    sc.tl.umap(adata)
    adata.obsm[umap_key] = adata.obsm["X_umap"].copy()
    sc.tl.leiden(adata, resolution=1.0, key_added=leiden_key)


test_adata = adatas['13months-disease-replicate_1'].copy()
test_adata.obsm["X_encoded"] = encoded_embeddings.detach().cpu().numpy()

# 1) Encoder embeddings
_neighbors_umap_leiden(test_adata, "X_encoded", "X_umap_encoded", "leiden_X_encoded")

# 2) Expression matrix X itself
_neighbors_umap_leiden(test_adata, "X", "X_umap_X", "leiden_X")

# 3) 50-component PCA of X
sc.pp.pca(test_adata, n_comps=50)
_neighbors_umap_leiden(test_adata, "X_pca", "X_umap_pca", "leiden_X_pca")

# 4) 50-component PCA of the encoder embeddings. The PCA is run on the bare
#    array; with return_info=True the first element of the result holds the
#    PCA coordinates.
X_encoded = test_adata.obsm["X_encoded"]
pca_result = sc.tl.pca(X_encoded, n_comps=50, svd_solver='arpack', return_info=True)
test_adata.obsm["X_encoded_pca"] = pca_result[0]
_neighbors_umap_leiden(test_adata, "X_encoded_pca", "X_umap_encoded_pca", "leiden_X_encoded_pca")

# One UMAP per representation (encoded, X, PCA of X, PCA of encoded),
# coloured by the annotated cell type
for umap_basis in ("X_umap_encoded", "X_umap_X", "X_umap_pca", "X_umap_encoded_pca"):
    sc.pl.embedding(test_adata, basis=umap_basis, color=["top_level_cell_type_x"], ncols=1, frameon=False, wspace=0.5)


In [76]:
# Re-run Leiden with resolution 0.1 (the previous cell used 1.0) and overwrite
# the "leiden_X_encoded_pca" labels. The clustering uses whichever neighbor
# graph is currently stored in test_adata.
sc.tl.leiden(test_adata, resolution=0.1, key_added="leiden_X_encoded_pca")

In [77]:
# Step 4: Show each clustering on the tissue coordinates, one panel per representation.
# Assumes the coordinates are stored in `test_adata.obsm['spatial']`.
fig, axes = plt.subplots(1, 4, figsize=(24, 6))

# (obs column holding the Leiden labels, panel title)
cluster_panels = [
    ("leiden_X_encoded", "Clusters: X_encoded"),
    ("leiden_X", "Clusters: X"),
    ("leiden_X_pca", "Clusters: X_pca"),
    ("leiden_X_encoded_pca", "Clusters: X_encoded_pca"),
]

for ax, (cluster_key, panel_title) in zip(axes, cluster_panels):
    sc.pl.embedding(
        test_adata,
        basis="spatial",
        color=cluster_key,
        frameon=False,
        ax=ax,
        show=False,
        title=panel_title,
    )

# Equal aspect keeps the tissue geometry undistorted; hide axis decorations
for ax in axes:
    ax.set_aspect("equal")
    ax.axis("off")

plt.tight_layout()
plt.show()

# GPT 优化后

In [118]:
import torch
import torch.nn as nn
from torch_geometric.nn import GATv2Conv, LayerNorm
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import random

# Model definition
class GATv2EncoderLayer(nn.Module):
    """One encoder stage: GATv2 attention, head mixing, then a residual MLP.

    Pipeline of `forward`:
        GATv2Conv (heads concatenated) -> norm1 -> dropout -> GELU
        -> Linear(out_channels * heads -> out_channels)
        -> MLP with skip connection -> norm2 -> dropout

    Args:
        in_channels: Width of the incoming node features.
        out_channels: Width of the node features produced by this layer.
        heads: Number of attention heads of the GATv2 convolution.
        edge_dim: Edge feature width passed to GATv2Conv (None for no edge features).
        dropout: Dropout probability, used inside GATv2Conv and after both norms.
    """

    def __init__(self, in_channels, out_channels, heads=4, edge_dim=None, dropout=0.1):
        super().__init__()
        concat_width = out_channels * heads  # width of the concatenated head outputs

        self.gatv2 = GATv2Conv(
            in_channels,
            out_channels,
            heads=heads,
            edge_dim=edge_dim,
            concat=True,
            dropout=dropout,
            add_self_loops=True,
            bias=True,
            residual=True,
        )

        # Learnable projection that folds the concatenated heads back to out_channels
        self.heads_transform = nn.Linear(concat_width, out_channels)

        self.mlp = nn.Sequential(
            nn.Linear(out_channels, out_channels),
            nn.GELU(),
            nn.Linear(out_channels, out_channels),
        )

        self.norm1 = LayerNorm(concat_width)  # applied to the GATv2 output
        self.norm2 = LayerNorm(out_channels)  # applied after the residual MLP
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x, edge_index, edge_attr):
        """Encode node features `x` given the graph connectivity and edge features."""
        # Attention block
        hidden = self.gatv2(x, edge_index, edge_attr)
        hidden = self.activation(self.dropout(self.norm1(hidden)))

        # Mix the heads down to out_channels
        hidden = self.heads_transform(hidden)

        # Residual MLP block
        hidden = self.norm2(self.mlp(hidden) + hidden)
        return self.dropout(hidden)

class MaskedNodePredictorWithEncoder(nn.Module):
    """Stack of GATv2 encoder layers followed by an MLP that reconstructs node features.

    Unlike the earlier definition of the same name in this notebook, `forward`
    takes the tensors directly and returns predictions for *every* node;
    selecting the masked entries is left to the caller.

    Args:
        in_features: Width of the input node features (also the output width).
        hidden_channels: Width of the hidden node embeddings.
        edge_dim: Edge feature width forwarded to each encoder layer.
        heads: Number of attention heads per encoder layer.
        num_encoders: Number of stacked encoder layers.
        dropout: Dropout probability forwarded to each encoder layer.
    """

    def __init__(self, in_features, hidden_channels, edge_dim=None, heads=4, num_encoders=2, dropout=0.1):
        super().__init__()

        # The first layer consumes the raw features, later layers the hidden width
        input_widths = [in_features if depth == 0 else hidden_channels for depth in range(num_encoders)]
        self.encoders = nn.ModuleList(
            [
                GATv2EncoderLayer(
                    in_channels=width,
                    out_channels=hidden_channels,
                    heads=heads,
                    edge_dim=edge_dim,
                    dropout=dropout,
                )
                for width in input_widths
            ]
        )

        # Prediction head: maps embeddings back to the original feature space
        self.predictor = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.GELU(),
            nn.Linear(hidden_channels, in_features),
        )

    def forward(self, x, edge_index, edge_attr):
        """Encode all nodes and return the predicted features for each of them."""
        hidden = x
        for layer in self.encoders:
            hidden = layer(hidden, edge_index, edge_attr)
        return self.predictor(hidden)

# Loss functions
class WeightedMSELoss(nn.Module):
    """Mean squared error in which non-zero targets count `pos_weight` times.

    Args:
        pos_weight: Weight of entries whose target is non-zero; entries with
            a zero target have weight 1.
    """

    def __init__(self, pos_weight=10.0):
        super().__init__()
        self.pos_weight = pos_weight

    def forward(self, prediction, target):
        """Return the mean of the element-wise weighted squared errors."""
        # Per-element weight: pos_weight where the target is non-zero, 1 elsewhere
        is_nonzero = (target != 0).float()
        is_zero = (target == 0).float()
        weight = is_nonzero * self.pos_weight + is_zero
        squared_error = (prediction - target) ** 2
        return (weight * squared_error).mean()

class CombinedLoss(nn.Module):
    """Weighted sum of a mean-squared-error term and a mean-absolute-error term.

    Args:
        mse_weight: Multiplier of the MSE term.
        l1_weight: Multiplier of the L1 term.
    """

    def __init__(self, mse_weight=1.0, l1_weight=0.1):
        super().__init__()
        self.mse = nn.MSELoss()
        self.l1 = nn.L1Loss()
        self.mse_weight = mse_weight
        self.l1_weight = l1_weight

    def forward(self, prediction, target):
        """Return mse_weight * MSE(prediction, target) + l1_weight * L1(prediction, target)."""
        mse_term = self.mse_weight * self.mse(prediction, target)
        l1_term = self.l1_weight * self.l1(prediction, target)
        return mse_term + l1_term

# Start from an uncorrupted copy of the graph rather than reusing whatever
# `data` an earlier cell left behind.
data = graph_data_dict['13months-disease-replicate_1'].clone().to(device)

# Clean node features: never modified below. Targets are read from this tensor
# and each epoch's corruption starts from it.
clean_x = data.x

# Initialize the model
# NOTE(review): hidden_channels, edge_dim and heads, as well as warmup_epochs,
# initial_lr and warmup_lr used in the loop, are taken from the earlier
# training cell — confirm that this dependency is intended.
masked_pre_model = MaskedNodePredictorWithEncoder(
    in_features=clean_x.size(1),
    hidden_channels=hidden_channels,
    edge_dim=edge_dim,
    heads=heads,
    num_encoders=4,
    dropout=0.2,
).to(device)

# Initialize the loss function
criterion = CombinedLoss(mse_weight=1.0, l1_weight=0.1)

# Optimizer (weight_decay adds L2 regularization) and learning-rate schedule
optimizer = torch.optim.Adam(masked_pre_model.parameters(), lr=1e-4, weight_decay=1e-5)
# NOTE(review): T_max=300 matches a 500-epoch run with 200 warmup epochs; with
# epochs = 100 below and warmup_epochs still 200 from the earlier cell, the
# loop never leaves the warmup branch and this scheduler is never stepped.
scheduler = CosineAnnealingLR(optimizer, T_max=300, eta_min=1e-7)

# Training loop
loss_history = []        # exponentially smoothed loss, one entry per epoch
lr_history = []          # learning rate used in each epoch
smoothing_factor = 0.9
smoothed_loss = None
epochs = 100

for epoch in tqdm(range(epochs), desc="Training", leave=True):
    masked_pre_model.train()
    optimizer.zero_grad()

    # Learning-rate warmup, then cosine schedule
    if epoch < warmup_epochs:
        lr = warmup_lr + (initial_lr - warmup_lr) * epoch / warmup_epochs
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    else:
        scheduler.step()

    # Record the current learning rate
    lr_history.append(optimizer.param_groups[0]['lr'])

    # Feature-level masking: 15% of all (node, feature) entries are selected
    mask = torch.rand(clean_x.size(), device=device) < 0.15
    target = clean_x[mask]  # clean values of the selected entries

    # Masking strategy, one draw per selected entry:
    #   80% -> replace with 0
    #   10% -> replace with standard-normal noise
    #   10% -> keep the original value
    roll = torch.rand(clean_x.size(), device=device)
    zero_entries = mask & (roll < 0.8)
    noise_entries = mask & (roll >= 0.8) & (roll < 0.9)
    modified_data_x = clean_x.clone()
    modified_data_x[zero_entries] = 0
    modified_data_x[noise_entries] = torch.randn_like(modified_data_x[noise_entries])

    # Forward pass on the corrupted features; data.x itself is left untouched,
    # so there is nothing to restore afterwards and masking cannot accumulate.
    predictions = masked_pre_model(modified_data_x, data.edge_index, data.edge_attr)

    # Loss on the masked entries only
    loss = criterion(predictions[mask], target)

    # Exponentially smoothed loss
    if smoothed_loss is None:
        smoothed_loss = loss.item()
    else:
        smoothed_loss = smoothing_factor * smoothed_loss + (1 - smoothing_factor) * loss.item()

    # Backward pass and optimization
    loss.backward()
    optimizer.step()

    # Record the loss
    loss_history.append(smoothed_loss)

    # Print loss and learning rate every 50 epochs
    if epoch % 50 == 0:
        tqdm.write(f"Epoch {epoch}/{epochs}, Loss: {smoothed_loss:.6f}, LR: {optimizer.param_groups[0]['lr']:.8f}")

# Optional: save the training history
torch.save({
    'loss_history': loss_history,
    'lr_history': lr_history,
}, 'training_history.pth')


In [119]:
# Wrap the encoder layers of the trained masked_pre_model. The layers are
# shared with masked_pre_model, not copied.
encoder_model = EncoderOnly(encoders=masked_pre_model.encoders)
encoder_model = encoder_model.to(device)

# Inference mode (disables dropout); because the layers are shared, this also
# switches masked_pre_model's encoder layers to eval mode.
encoder_model.eval()

# Sanity check: the wrapper's parameters equal the leading parameters of
# masked_pre_model (zip stops once the wrapper's parameters are exhausted).
for wrapped_param, source_param in zip(encoder_model.parameters(), masked_pre_model.parameters()):
    assert torch.equal(wrapped_param, source_param), "Parameters do not match!"


In [120]:
# Encode a fresh copy of the sample's graph with the trained encoder
data = graph_data_dict['13months-disease-replicate_1'].clone().to(device)
# Pure inference: disable gradient tracking so no autograd graph is kept
# alive for the whole graph (the result is only converted to NumPy later).
with torch.no_grad():
    encoded_embeddings = encoder_model(data.x, data.edge_index, data.edge_attr)

In [121]:
encoded_embeddings.shape

In [122]:
def _neighbors_umap_leiden(adata, rep_key, umap_key, leiden_key):
    """Build the kNN graph on `rep_key`, then store a UMAP copy and Leiden labels.

    The UMAP coordinates are copied to `adata.obsm[umap_key]` because every
    later call overwrites `adata.obsm["X_umap"]`.
    """
    sc.pp.neighbors(adata, use_rep=rep_key, n_neighbors=10)
    sc.tl.umap(adata)
    adata.obsm[umap_key] = adata.obsm["X_umap"].copy()
    sc.tl.leiden(adata, resolution=1.0, key_added=leiden_key)


test_adata = adatas['13months-disease-replicate_1'].copy()
test_adata.obsm["X_encoded"] = encoded_embeddings.detach().cpu().numpy()

# 1) Encoder embeddings
_neighbors_umap_leiden(test_adata, "X_encoded", "X_umap_encoded", "leiden_X_encoded")

# 2) Expression matrix X itself
_neighbors_umap_leiden(test_adata, "X", "X_umap_X", "leiden_X")

# 3) 50-component PCA of X
sc.pp.pca(test_adata, n_comps=50)
_neighbors_umap_leiden(test_adata, "X_pca", "X_umap_pca", "leiden_X_pca")

# 4) 50-component PCA of the encoder embeddings. The PCA is run on the bare
#    array; with return_info=True the first element of the result holds the
#    PCA coordinates.
X_encoded = test_adata.obsm["X_encoded"]
pca_result = sc.tl.pca(X_encoded, n_comps=50, svd_solver='arpack', return_info=True)
test_adata.obsm["X_encoded_pca"] = pca_result[0]
_neighbors_umap_leiden(test_adata, "X_encoded_pca", "X_umap_encoded_pca", "leiden_X_encoded_pca")

# One UMAP per representation (encoded, X, PCA of X, PCA of encoded),
# coloured by the annotated cell type
for umap_basis in ("X_umap_encoded", "X_umap_X", "X_umap_pca", "X_umap_encoded_pca"):
    sc.pl.embedding(test_adata, basis=umap_basis, color=["top_level_cell_type_x"], ncols=1, frameon=False, wspace=0.5)


In [123]:
# Step 4: Show each clustering on the tissue coordinates, one panel per representation.
# Assumes the coordinates are stored in `test_adata.obsm['spatial']`.
fig, axes = plt.subplots(1, 4, figsize=(24, 6))

# (obs column holding the Leiden labels, panel title)
cluster_panels = [
    ("leiden_X_encoded", "Clusters: X_encoded"),
    ("leiden_X", "Clusters: X"),
    ("leiden_X_pca", "Clusters: X_pca"),
    ("leiden_X_encoded_pca", "Clusters: X_encoded_pca"),
]

for ax, (cluster_key, panel_title) in zip(axes, cluster_panels):
    sc.pl.embedding(
        test_adata,
        basis="spatial",
        color=cluster_key,
        frameon=False,
        ax=ax,
        show=False,
        title=panel_title,
    )

# Equal aspect keeps the tissue geometry undistorted; hide axis decorations
for ax in axes:
    ax.set_aspect("equal")
    ax.axis("off")

plt.tight_layout()
plt.show()

In [149]:
# Collect predictions and targets from 10 evaluation rounds with 15% masking,
# on a fresh copy of the training sample's graph.
# NOTE(review): masked_pre_model now refers to the variant whose forward takes
# (x, edge_index, edge_attr) — confirm collect_predictions calls it that way.
evaluation_settings = dict(num_evaluations=10, mask_percentage=0.15, device=device)
all_predictions, all_targets = collect_predictions(
    model=masked_pre_model,
    data=graph_data_dict['13months-disease-replicate_1'].clone().to(device),
    **evaluation_settings,
)


In [None]:
all_predictions.shape

# PCA 实验

In [127]:
from tqdm import tqdm

# Directory containing the per-sample .h5ad files
file_dir = "/public/home/jijh/st_project/cellbin_analysis/spatial_variation/wx_data/"

# Keep only "*.h5ad" files whose name contains "month".
# os.listdir returns entries in arbitrary order, so sort them: the insertion
# order of `adatas` (and of everything derived from it) is then reproducible.
files = sorted(f for f in os.listdir(file_dir) if f.endswith(".h5ad") and "month" in f)
file_paths = [os.path.join(file_dir, f) for f in files]

# Map "<file name up to the first dot>" -> AnnData
adatas = {}
for file_name, file_path in zip(files, file_paths):
    adatas[file_name.split(".")[0]] = sc.read(file_path)

# Preprocess each AnnData object in place
for key in tqdm(adatas.keys(), desc="Preprocessing datasets"):
    sc.pp.normalize_total(adatas[key], target_sum=1e4)  # Normalize counts per cell
    sc.pp.log1p(adatas[key])  # Logarithmize the data
    # NOTE: despite its name, the "raw" layer is a snapshot of X taken after
    # normalize_total + log1p and before scaling (not raw counts).
    adatas[key].layers["raw"] = adatas[key].X.copy()
    sc.pp.scale(adatas[key], max_value=10)  # Standardize each gene; clip values above 10
    sc.tl.pca(adatas[key], svd_solver="arpack")  # Perform PCA (stored in obsm["X_pca"])

In [128]:
# Per-sample copy of the spatial coordinates stored in obsm["spatial"]
cell_coords = {
    sample_key: adata.obsm["spatial"].copy()
    for sample_key, adata in adatas.items()
}

In [129]:
# One affinity matrix per sample, built from the cell coordinates
# (radius mode with a cutoff of 100, in the units of the coordinates).
neighbors = {
    sample_key: construct_affinity_matrix(coords, mode='radius', cutoff=100)
    for sample_key, coords in cell_coords.items()
}

In [130]:

# Load the plaque dataset: every "*.tiff" file in img_dir whose name contains "plaque"
img_dir = "/public/home/jijh/st_project/cellbin_analysis/spatial_variation/wx_data/protein_seg_result/"
img_files = os.listdir(img_dir)
img_files = [i for i in img_files if i.endswith(".tiff") and "plaque" in i]

# Read the images, keyed by the file name up to the first dot
imgs = {}
for i in range(len(img_files)):
    imgs[img_files[i].split(".")[0]] = io.imread(os.path.join(img_dir, img_files[i]))

# Rename the images to match the adata keys: "<a>_<b>_<c>..." -> "<b>_<c>".
# Both parts[1] and parts[2] are read, so at least three "_"-separated parts
# are required; names with fewer parts are left untouched instead of raising
# IndexError.
for key in list(imgs.keys()):
    parts = key.split("_")
    if len(parts) > 2:
        new_key = parts[1] + "_" + parts[2]
        imgs[new_key] = imgs.pop(key)



In [131]:
# One list of 128-pixel patches per image, extracted around each cell coordinate
patches = {
    sample_key: extract_patches(image, cell_coords[sample_key], patch_size=128)
    for sample_key, image in imgs.items()
}

# Boolean version of every patch: True wherever the pixel value is positive
binary_patches = {
    sample_key: [patch > 0 for patch in sample_patches]
    for sample_key, sample_patches in patches.items()
}

# Number of positive pixels in each patch
areas = {
    sample_key: [np.sum(binary_patch) for binary_patch in sample_binary]
    for sample_key, sample_binary in binary_patches.items()
}

In [132]:
adatas

In [133]:
# Build one graph object per sample, using the PCA coordinates ("X_pca") as node embedding
graph_data_dict = create_graph_data_dict(
    adatas, areas, neighbors, cell_coords, embeddings=["X_pca"]
)

# Attach the image patches to each graph. The per-cell patches are stacked
# into one NumPy array first so the tensor is created in a single conversion.
for sample_key, sample_graph in tqdm(graph_data_dict.items(), desc="Adding patches to graph data"):
    stacked_patches = np.array(patches[sample_key])
    sample_graph.patches = torch.tensor(stacked_patches, dtype=torch.float)

In [134]:
# Rescale every graph's edge attributes by that graph's own maximum value
for sample_graph in tqdm(graph_data_dict.values(), desc="Normalizing edge_attr"):
    largest = sample_graph.edge_attr.max()
    sample_graph.edge_attr = sample_graph.edge_attr / largest

In [135]:
graph_data_dict

In [139]:
import torch
import torch.nn as nn
from torch_geometric.nn import GATv2Conv, LayerNorm
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import random

# Objective: BERT-style pre-training (masked node prediction) on a single graph.

# Dataset preparation: train on a device-side clone so the graph stored in
# graph_data_dict itself is never modified.
data = graph_data_dict['13months-disease-replicate_1'].clone().to(device)

# Pristine node features. Every epoch derives BOTH its targets and its corrupted input from
# this tensor; reading them from data.x would pick up the corruption written by earlier
# epochs, so the model would end up being trained to reproduce zeros / noise.
original_x = data.x.clone()

# Extract data dimensions
num_nodes = original_x.size(0)
in_features = original_x.size(1)
hidden_channels = 64
edge_dim = data.edge_attr.size(1) if data.edge_attr is not None else None
heads = 4

# Initialize model
masked_pre_model = MaskedNodePredictorWithEncoder(
    in_features=in_features,
    hidden_channels=hidden_channels,
    edge_dim=edge_dim,
    heads=heads,
    num_encoders=4,
    dropout=0.2,
).to(device)

# Learning rate schedule parameters: linear warmup from warmup_lr towards initial_lr,
# then cosine annealing down to eta_min for the remaining epochs.
warmup_epochs = 200
initial_lr = 1e-4
warmup_lr = 1e-6
epochs = 500
eta_min = 1e-7

# Initialize optimizer with warmup_lr
optimizer = torch.optim.Adam(masked_pre_model.parameters(), lr=warmup_lr)

# Create cosine scheduler (only stepped once the warmup phase is over)
scheduler = CosineAnnealingLR(optimizer, T_max=epochs - warmup_epochs, eta_min=eta_min)

# Training loop with BERT-style masking
loss_history = []
lr_history = []
smoothing_factor = 0.9
smoothed_loss = None
# Use the weighted MSE loss
criterion = WeightedMSELoss(pos_weight=10.0)

for epoch in tqdm(range(epochs), desc="Training", leave=True):
    masked_pre_model.train()
    optimizer.zero_grad()

    # Warmup learning rate
    # NOTE(review): scheduler.step() runs before optimizer.step() here, whereas PyTorch
    # recommends the opposite order; left as-is so the recorded schedule is unchanged.
    if epoch < warmup_epochs:
        lr = warmup_lr + (initial_lr - warmup_lr) * epoch / warmup_epochs
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    else:
        scheduler.step()

    # Record current learning rate
    lr_history.append(optimizer.param_groups[0]['lr'])

    # Randomly select 15% of the nodes to mask; targets are the uncorrupted features
    mask = torch.rand(num_nodes, device=device) < 0.15
    target = original_x[mask]

    # Apply BERT-style masking strategy with one uniform draw per node:
    #   draw < 0.8        -> replace the features with the [MASK] token (0)
    #   0.8 <= draw < 0.9 -> replace the features with Gaussian noise
    #   draw >= 0.9       -> leave the features unchanged
    corruption_draw = torch.rand(num_nodes, device=device)
    zero_rows = mask & (corruption_draw < 0.8)
    noise_rows = mask & (corruption_draw >= 0.8) & (corruption_draw < 0.9)

    modified_data_x = original_x.clone()
    modified_data_x[zero_rows] = 0
    modified_data_x[noise_rows] = torch.randn_like(original_x[noise_rows])

    # The model consumes the whole data object, so hand it this epoch's corrupted features
    data.x = modified_data_x

    # Forward pass
    predictions = masked_pre_model(data, mask)

    # Compute weighted loss only for masked nodes
    loss = criterion(predictions, target)

    # Smooth loss using exponential moving average
    if smoothed_loss is None:
        smoothed_loss = loss.item()
    else:
        smoothed_loss = smoothing_factor * smoothed_loss + (1 - smoothing_factor) * loss.item()

    # Backward pass and optimization
    loss.backward()
    optimizer.step()

    # Log smoothed loss
    loss_history.append(smoothed_loss)

    # Update progress bar with loss info
    if epoch % 50 == 0:
        tqdm.write(f"Epoch {epoch}/{epochs}, Loss: {smoothed_loss:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}")

# Leave `data` holding the uncorrupted features for any cell that reuses it
data.x = original_x

In [None]:
# Collect predictions and targets from the trained masked-node model.
# The model is evaluated on a fresh device-side clone of the stored graph, with
# num_evaluations=10 and mask_percentage=0.15 handed to collect_predictions
# (defined outside this cell — presumably it re-draws the mask per evaluation; TODO confirm).
all_predictions, all_targets = collect_predictions(
    model=masked_pre_model,
    data=graph_data_dict['13months-disease-replicate_1'].clone().to(device),
    num_evaluations=10,
    mask_percentage=0.15,
    device=device
)


In [152]:
# Import the submodule explicitly before reloading it: importlib.reload() needs an already
# imported module object, and `src.visualization.validate_prediction` only resolves as an
# attribute once that submodule has been imported.
import src.visualization.validate_prediction
importlib.reload(src.visualization.validate_prediction)

from src.visualization.validate_prediction import visualize_all

In [None]:
# Inspect the shape of the collected predictions.
all_predictions.shape

In [153]:
# Run visualize_all on the collected predictions and their matching targets.
visualize_all(all_predictions, all_targets)

In [143]:
# Assuming masked_pre_model is already trained and on the correct device, wrap its encoder
# stack in an encoder-only module.
# NOTE(review): the ModuleList is passed by reference, so EncoderOnly presumably shares
# (rather than copies) the trained weights — confirm against EncoderOnly's definition.
encoder_model = EncoderOnly(encoders=masked_pre_model.encoders).to(device)

# Put the model in evaluation mode, since it is only used for inference from here on
encoder_model.eval()

# (Optional) Verify that the encoder-only model carries the trained encoder weights.
# Compare against the encoder sub-module only: zipping with masked_pre_model.parameters()
# would pair the wrong tensors whenever the full model registers any other parameter
# before its encoders.
for param_encoder, param_original in zip(
    encoder_model.parameters(), masked_pre_model.encoders.parameters()
):
    assert torch.equal(param_encoder, param_original), "Parameters do not match!"


In [144]:
# Encode a fresh, uncorrupted clone of the stored graph.
data = graph_data_dict['13months-disease-replicate_1'].clone().to(device)

# Pure inference: disable autograd so no computation graph is kept for the whole graph.
with torch.no_grad():
    encoded_embeddings = encoder_model(data.x, data.edge_index, data.edge_attr)

In [145]:
# Inspect the shape of the encoder output.
encoded_embeddings.shape

In [146]:
# Work on a copy of the sample and store the encoder output next to the other representations
test_adata = adatas['13months-disease-replicate_1'].copy()
test_adata.obsm["X_encoded"] = encoded_embeddings.detach().cpu().numpy()

# Each entry: (representation passed to sc.pp.neighbors,
#              obsm key under which the resulting UMAP is saved,
#              obs key for the Leiden labels)
comparison_runs = [
    ("X_encoded", "X_umap_encoded", "leiden_X_encoded"),
    ("X", "X_umap_X", "leiden_X"),
    ("X_pca", "X_umap_pca", "leiden_X_pca"),
    ("X_encoded_pca", "X_umap_encoded_pca", "leiden_X_encoded_pca"),
]

for rep_key, umap_key, leiden_key in comparison_runs:
    if rep_key == "X_pca":
        # PCA of the expression matrix, computed on the AnnData object itself
        sc.pp.pca(test_adata, n_comps=50)
    elif rep_key == "X_encoded_pca":
        # PCA of the encoder output: sc.tl.pca is handed the bare array, and with
        # return_info=True the result is a tuple whose first element holds the coordinates
        X_encoded = test_adata.obsm["X_encoded"]
        pca_result = sc.tl.pca(X_encoded, n_comps=50, svd_solver='arpack', return_info=True)
        test_adata.obsm["X_encoded_pca"] = pca_result[0]

    # Neighbors -> UMAP -> Leiden on the current representation. sc.tl.umap always writes
    # to obsm["X_umap"], so the result is copied to its own key before the next run.
    sc.pp.neighbors(test_adata, use_rep=rep_key, n_neighbors=10)
    sc.tl.umap(test_adata)
    test_adata.obsm[umap_key] = test_adata.obsm["X_umap"].copy()
    sc.tl.leiden(test_adata, resolution=1.0, key_added=leiden_key)

# Plot every saved UMAP (X_encoded, X, X_pca, X_encoded_pca), colored by cell type
for _, umap_key, _ in comparison_runs:
    sc.pl.embedding(test_adata, basis=umap_key, color=["top_level_cell_type_x"], ncols=1, frameon=False, wspace=0.5)


In [147]:
# Step 4: Show the four Leiden clusterings side by side on the 'spatial' basis
# (assumes the coordinates are stored in `test_adata.obsm['spatial']`)
fig, axes = plt.subplots(1, 4, figsize=(24, 6))

# (obs column holding the cluster labels, panel title), in left-to-right order
spatial_panels = [
    ("leiden_X_encoded", "Clusters: X_encoded"),
    ("leiden_X", "Clusters: X"),
    ("leiden_X_pca", "Clusters: X_pca"),
    ("leiden_X_encoded_pca", "Clusters: X_encoded_pca"),
]

for ax, (cluster_key, panel_title) in zip(axes, spatial_panels):
    sc.pl.embedding(
        test_adata,
        basis="spatial",
        color=cluster_key,
        frameon=False,
        ax=ax,
        show=False,
        title=panel_title,
    )

# Equal aspect ratio and hidden axes on every panel
for ax in axes:
    ax.set_aspect("equal")
    ax.axis("off")

plt.tight_layout()
plt.show()

# 两步训练

In [185]:
##################################################
# Build the "zero or non-zero" label matrix
##################################################
# Z[i, d] = 1 where X[i, d] != 0, otherwise 0 (stored as float)
Z = data.x.ne(0).to(torch.float)

In [186]:
# Display the zero / non-zero indicator tensor as the cell output.
Z

In [200]:
class TwoStepGATModel(nn.Module):
    """
    Shared encoder stack followed by two per-feature prediction heads.

    Components:
      1) encoders:        `num_encoders` stacked GATv2EncoderLayer modules
      2) classifier_head: one logit per feature (zero vs. non-zero)
      3) regressor_head:  one regressed value per feature
    """

    def __init__(
        self,
        in_features,
        hidden_channels,
        edge_dim=None,
        heads=4,
        num_encoders=2,
        dropout=0.1,
    ):
        super().__init__()

        # Only the first encoder layer sees the raw feature width; all later layers
        # operate on hidden_channels.
        self.encoders = nn.ModuleList(
            [
                GATv2EncoderLayer(
                    in_channels=hidden_channels if layer_idx > 0 else in_features,
                    out_channels=hidden_channels,
                    heads=heads,
                    edge_dim=edge_dim,
                    dropout=dropout,
                )
                for layer_idx in range(num_encoders)
            ]
        )

        # Head 1: classification (zero vs. non-zero), final output [N, D]
        self.classifier_head = self._build_head(hidden_channels, in_features)
        # Head 2: regression (value where non-zero), also [N, D]
        self.regressor_head = self._build_head(hidden_channels, in_features)

    @staticmethod
    def _build_head(hidden_channels, out_features):
        """Two-layer MLP (Linear -> GELU -> Linear) mapping hidden_channels to out_features."""
        return nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.GELU(),
            nn.Linear(hidden_channels, out_features),
        )

    def forward(self, x, edge_index, edge_attr):
        """
        Args:
          x:          [N, D] node features
          edge_index: graph connectivity
          edge_attr:  edge features (optional)
        Returns:
          logits_class: [N, D] logits of the zero / non-zero classifier
          pred_reg:     [N, D] output of the value regressor
        """
        # 1) Run a copy of the input through the encoder stack
        hidden = x.clone()
        for layer in self.encoders:
            hidden = layer(hidden, edge_index, edge_attr)

        # 2) Both heads read the same encoded representation
        return self.classifier_head(hidden), self.regressor_head(hidden)


In [199]:
class ZeroNonZeroLoss(nn.Module):
    """
    Multi-task loss:
      1) classification loss (zero vs. non-zero): BCEWithLogitsLoss with pos_weight
      2) regression loss: MSE, evaluated only where the true label is non-zero
    The two terms are combined as lambda_cls * cls + lambda_reg * reg.
    """

    def __init__(self, lambda_cls=0.5, lambda_reg=0.5, pos_weight=5.0):
        super().__init__()
        self.lambda_cls = lambda_cls
        self.lambda_reg = lambda_reg
        # Binary cross-entropy on logits, with the positive class weighted by pos_weight
        self.criterion_cls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
        # Plain mean-squared error for the regression term
        self.criterion_reg = nn.MSELoss(reduction='mean')

    def forward(self, logits_class, pred_reg, x, z):
        """
        logits_class: [N, D] classifier output (logits)
        pred_reg:     [N, D] regressor output
        x:            [N, D] node features (ground truth)
        z:            [N, D] 0/1 indicator (1 = non-zero)
        Returns:
          total_loss
        """
        # 1) Classification term over every entry
        cls_term = self.criterion_cls(logits_class, z)

        # 2) Regression term restricted to entries labelled non-zero
        nonzero_entries = z.eq(1)
        if nonzero_entries.any():
            reg_term = self.criterion_reg(pred_reg[nonzero_entries], x[nonzero_entries])
        else:
            # No non-zero entry at all: the regression term contributes nothing
            reg_term = torch.tensor(0.0, device=x.device)

        return self.lambda_cls * cls_term + self.lambda_reg * reg_term
    
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLossWithLogits(nn.Module):
    """
    Per-label focal loss for binary targets.

    Takes raw logits (the sigmoid is applied internally) and supports a `pos_weight`
    on the positive class to counter class imbalance.
    """

    def __init__(self, gamma=2.0, alpha=1.0, pos_weight=1.0, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.pos_weight = pos_weight
        self.reduction = reduction

    def forward(self, logits, targets):
        """
        logits:  [N, D] raw scores (no sigmoid applied)
        targets: [N, D] with values in {0, 1}

        Returns a scalar for reduction 'mean' / 'sum'; any other reduction value
        returns the unreduced element-wise loss.
        """
        # Element-wise, pos_weight-ed binary cross-entropy on the logits
        positive_weight = torch.tensor(self.pos_weight, device=logits.device)
        elementwise_bce = F.binary_cross_entropy_with_logits(
            logits,
            targets,
            pos_weight=positive_weight,
            reduction='none',
        )

        # Probability the model assigns to the true class of each entry
        prob = torch.sigmoid(logits)
        prob_of_truth = prob * targets + (1 - prob) * (1 - targets)

        # Focal modulation: confidently correct entries are down-weighted
        modulation = (1 - prob_of_truth).pow(self.gamma)
        focal = self.alpha * modulation * elementwise_bce

        if self.reduction == 'sum':
            return focal.sum()
        if self.reduction == 'mean':
            return focal.mean()
        return focal

class ZeroNonZeroFocalMultiTaskLoss(nn.Module):
    """
    Multi-task loss:
      1) classification loss (zero vs. non-zero): FocalLossWithLogits
      2) regression loss: MSE, evaluated only where the true label is non-zero
    The two terms are combined as lambda_cls * cls + lambda_reg * reg.
    """

    def __init__(self, lambda_cls=0.5, lambda_reg=0.5, focal_gamma=2.0, focal_alpha=1.0, focal_pos_weight=20.0):
        super().__init__()
        self.lambda_cls = lambda_cls
        self.lambda_reg = lambda_reg
        # Mean-reduced focal loss for the classification term
        self.focal_loss = FocalLossWithLogits(
            gamma=focal_gamma,
            alpha=focal_alpha,
            pos_weight=focal_pos_weight,
            reduction='mean',
        )
        # Plain MSE for the regression term
        self.criterion_reg = nn.MSELoss(reduction='mean')

    def forward(self, logits_class, pred_reg, x, z):
        """
        logits_class: [N, D] classifier output (logits)
        pred_reg:     [N, D] regressor output
        x:            [N, D] node features (ground truth)
        z:            [N, D] 0/1 indicator (1 = non-zero)
        """
        # 1) Classification term (focal loss) over every entry
        cls_term = self.focal_loss(logits_class, z)

        # 2) Regression term restricted to entries labelled non-zero;
        #    it is a zero tensor when there is no such entry
        nonzero_entries = z.eq(1)
        if nonzero_entries.any():
            reg_term = self.criterion_reg(pred_reg[nonzero_entries], x[nonzero_entries])
        else:
            reg_term = torch.tensor(0.0, device=x.device)

        return self.lambda_cls * cls_term + self.lambda_reg * reg_term


In [201]:
# Model hyper-parameters
hidden_channels = 64
heads = 4
num_encoders = 4
dropout = 0.2

# Uncorrupted node features. They are only ever read in this cell: every epoch corrupts a
# fresh clone, so data.x stays consistent with the label matrix Z across epochs and for
# the evaluation cells that follow.
clean_x = data.x

# Take the dimensions from the graph that is actually trained on rather than relying on
# values left over from an earlier cell.
num_nodes = clean_x.size(0)
in_features = clean_x.size(1)
edge_dim = data.edge_attr.size(1) if data.edge_attr is not None else None

# Initialize the model
model = TwoStepGATModel(
    in_features=in_features,
    hidden_channels=hidden_channels,
    edge_dim=edge_dim,
    heads=heads,
    num_encoders=num_encoders,
    dropout=dropout
).to(device)

# Optimizer and scheduler: linear warmup from warmup_lr towards initial_lr,
# then cosine annealing down to eta_min for the remaining epochs.
warmup_epochs = 500
initial_lr = 1e-3
warmup_lr = 1e-5
epochs = 1000
eta_min = 1e-5

optimizer = torch.optim.Adam(model.parameters(), lr=warmup_lr)
scheduler = CosineAnnealingLR(optimizer, T_max=epochs - warmup_epochs, eta_min=eta_min)

# 1) Initialize the loss. With lambda_cls=1 and lambda_reg=0 only the zero / non-zero
#    classification term contributes to the total in this step.
criterion_multi = ZeroNonZeroFocalMultiTaskLoss(
    lambda_cls=1,
    lambda_reg=0,
    focal_gamma=2.0,
    focal_alpha=1.0,
    focal_pos_weight=20.0  # increased to up-weight the positive (non-zero) class
).to(device)

loss_history = []
lr_history = []
smoothing_factor = 0.9
smoothed_loss = None

# Training loop
for epoch in tqdm(range(epochs), desc="Training", leave=True):
    model.train()
    optimizer.zero_grad()

    # Warmup learning rate
    # NOTE(review): scheduler.step() runs before optimizer.step() here, whereas PyTorch
    # recommends the opposite order; left as-is so the recorded schedule is unchanged.
    if epoch < warmup_epochs:
        lr = warmup_lr + (initial_lr - warmup_lr) * epoch / warmup_epochs
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    else:
        scheduler.step()

    lr_history.append(optimizer.param_groups[0]['lr'])

    ###########################
    # 1) Draw a random node mask (15%)
    ###########################
    mask = torch.rand(num_nodes, device=device) < 0.15
    target_x = clean_x[mask]         # true values (masked nodes only)
    target_z = Z[mask]               # zero / non-zero labels (masked nodes only)

    # BERT-style corruption with one uniform draw per node:
    #   draw < 0.8        -> features replaced by 0
    #   0.8 <= draw < 0.9 -> features replaced by Gaussian noise
    #   draw >= 0.9       -> features kept
    corruption_draw = torch.rand(num_nodes, device=device)
    zero_rows = mask & (corruption_draw < 0.8)
    noise_rows = mask & (corruption_draw >= 0.8) & (corruption_draw < 0.9)

    modified_data_x = clean_x.clone()
    modified_data_x[zero_rows] = 0
    modified_data_x[noise_rows] = torch.randn_like(clean_x[noise_rows])

    ###########################
    # 2) Forward pass on the corrupted features (data.x itself is left untouched)
    ###########################
    logits_class, pred_reg = model(
        x=modified_data_x,
        edge_index=data.edge_index,
        edge_attr=data.edge_attr
    )
    # Keep only the outputs of the masked nodes
    logits_class_masked = logits_class[mask]
    pred_reg_masked     = pred_reg[mask]

    ###########################
    # 3) Loss and backward pass
    ###########################
    loss = criterion_multi(
        logits_class_masked,  # [num_masked, D]
        pred_reg_masked,      # [num_masked, D]
        target_x,             # [num_masked, D]
        target_z              # [num_masked, D]
    )

    # Exponential moving average of the loss
    if smoothed_loss is None:
        smoothed_loss = loss.item()
    else:
        smoothed_loss = smoothing_factor * smoothed_loss + (1 - smoothing_factor) * loss.item()

    loss.backward()
    optimizer.step()

    loss_history.append(smoothed_loss)

    if epoch % 50 == 0:
        tqdm.write(f"Epoch {epoch}/{epochs}, Loss: {smoothed_loss:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}")


# Plot the learning-rate schedule and the smoothed training loss, one figure each.
# Each entry: (values per epoch, legend label, y-axis label, figure title)
training_curves = [
    (lr_history, 'Learning Rate', 'Learning Rate', 'Learning Rate Schedule'),
    (loss_history, 'Train Loss (smoothed)', 'Loss', 'Training Loss'),
]

for curve_values, curve_label, y_label, figure_title in training_curves:
    plt.figure(figsize=(8,4))
    plt.plot(curve_values, label=curve_label)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.title(figure_title)
    plt.legend()
    plt.show()


In [202]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import precision_recall_curve, roc_curve, auc
import seaborn as sns

# Select the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): earlier cells already use `device`, so it is presumably also defined
# further up the notebook — confirm, and keep a single definition near the imports.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def visualize_zero_nonzero_classification(model, data, Z, threshold=0.5, sample_ratio=1.0):
    """
    Visualize how well the model's classification head separates zero from non-zero entries.

    Produces, in order: a histogram of predicted probabilities split by true label, a
    confusion-matrix heatmap at `threshold`, a printed classification report, a
    precision-recall curve and a ROC curve.

    Args:
        model:          trained model; called as model(x=..., edge_index=..., edge_attr=...)
                        and expected to return (classification logits, second output).
                        Only the logits are used.
        data:           graph object exposing x, edge_index and edge_attr
        Z:              zero / non-zero labels (0 or 1) with one entry per logit, [N, D]
        threshold:      probability cut-off used to turn probabilities into 0/1 predictions
        sample_ratio:   fraction of entries to evaluate, in (0, 1]; values of 1.0 or more
                        use every entry. Sampling is without replacement via np.random.

    Raises:
        ValueError: if sample_ratio is not positive, or if Z does not contain exactly one
            label per predicted logit.

    Note:
        The model is switched to evaluation mode and left in that mode.
    """
    if sample_ratio <= 0:
        raise ValueError(f"sample_ratio must be in (0, 1], got {sample_ratio}")

    model.eval()
    with torch.no_grad():
        # Classification logits; the second model output is not needed here
        logits_class, _ = model(
            x=data.x,
            edge_index=data.edge_index,
            edge_attr=data.edge_attr
        )
        # Probabilities in [0, 1], shape [N, D]
        probs = torch.sigmoid(logits_class)

    if Z.numel() != probs.numel():
        raise ValueError(
            f"Z has {Z.numel()} entries but the model produced {probs.numel()} logits"
        )

    # Flatten to 1-D vectors; reshape (unlike view) also accepts non-contiguous tensors
    true_label = Z.reshape(-1).cpu().numpy()      # [N*D]
    pred_prob  = probs.reshape(-1).cpu().numpy()  # [N*D]

    # Optional: evaluate only a random subset when the data is very large
    total_size = len(true_label)
    if sample_ratio < 1.0:
        # Keep at least one entry so a tiny ratio cannot produce an empty sample
        sample_size = max(1, int(sample_ratio * total_size))
        idx = np.random.choice(total_size, sample_size, replace=False)
        true_label = true_label[idx]
        pred_prob  = pred_prob[idx]

    # 1) Histogram of predicted probabilities, split by true label
    fig, ax = plt.subplots(figsize=(8,5))
    prob_for_0 = pred_prob[true_label == 0]
    prob_for_1 = pred_prob[true_label == 1]

    ax.hist(prob_for_0, bins=50, alpha=0.5, label='True=0', density=True, color='blue')
    ax.hist(prob_for_1, bins=50, alpha=0.5, label='True=1', density=True, color='red')
    ax.set_xlabel("Predicted Probability")
    ax.set_ylabel("Density")
    ax.set_title("Distribution of Predicted Probability (Zero vs Non-Zero)")
    ax.legend()
    plt.show()

    # 2) Confusion matrix at the requested threshold
    pred_label = (pred_prob >= threshold).astype(int)
    cm = confusion_matrix(true_label, pred_label, labels=[0,1])

    plt.figure(figsize=(4,4))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Pred=0','Pred=1'],
                yticklabels=['True=0','True=1'])
    plt.title(f"Confusion Matrix (threshold={threshold})")
    plt.show()

    # 3) Classification report
    print("Classification Report (threshold={:.2f}):".format(threshold))
    print(classification_report(true_label, pred_label, digits=4))

    # 4) Precision-recall curve (threshold-independent)
    precision, recall, _ = precision_recall_curve(true_label, pred_prob)
    pr_auc = auc(recall, precision)
    plt.figure(figsize=(5,5))
    plt.plot(recall, precision, label=f'PR curve (AUC={pr_auc:.4f})')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend()
    plt.show()

    # 5) ROC curve (threshold-independent)
    fpr, tpr, _ = roc_curve(true_label, pred_prob)
    roc_auc = auc(fpr, tpr)
    plt.figure(figsize=(5,5))
    plt.plot(fpr, tpr, label=f'ROC curve (AUC={roc_auc:.4f})')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
    plt.legend()
    plt.show()


In [203]:
# Prerequisites from earlier cells:
#  model           : trained TwoStepGATModel (or any model with a classification head)
#  Z               : [N, D] 0/1 matrix (1 where the feature value is non-zero)
#  graph_data_dict : per-sample graph objects
#  device          : cuda or cpu

# Make sure the model, the data and the labels live on the same device
model.to(device)

# Evaluate on a fresh clone of the stored graph, so the classifier is never scored on
# node features that a training loop has masked or noised in place.
data = graph_data_dict['13months-disease-replicate_1'].clone().to(device)
Z = Z.to(device)  # [N, D]

visualize_zero_nonzero_classification(model, data, Z, threshold=0.5, sample_ratio=0.2)
