# 新段落

## Installation

In [None]:
!pip install git+https://github.com/google-research/swirl-dynamics.git@main

In [None]:
!pip install tensorflow-datasets

## Imports

In [None]:
import functools

from clu import metric_writers
import jax
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import optax
import orbax.checkpoint as ocp
import tensorflow as tf
import tensorflow_datasets as tfds

from swirl_dynamics import templates
from swirl_dynamics.lib import diffusion as dfn_lib
from swirl_dynamics.lib import solvers as solver_lib
from swirl_dynamics.projects import probabilistic_diffusion as dfn

## Example I - Unconditional diffusion model with guidance

### Dataset

First, we need a dataset containing samples whose distribution is to be modeled by the diffusion model. For demonstration purposes, we use the MNIST dataset provided by TensorFlow Datasets.

Our code setup accepts any Python iterable object as a dataloader. The expectation is that it continuously yields a dictionary with a field named `x` whose corresponding value is a numpy array with shape `(batch, *spatial_dims, channels)`.

In [None]:
def get_mnist_dataset(split: str, batch_size: int):
  """Returns an endless iterator over batches of MNIST images.

  Args:
    split: TFDS split specification, e.g. "train[:75%]".
    batch_size: Number of images per batch.

  Returns:
    A NumPy iterator whose elements are dicts with a single field "x" holding
    images rescaled from [0, 255] to [0, 1].
  """

  def to_model_input(example):
    # `DenoisingModel` expects the samples under the "x" field.
    return {"x": tf.cast(example["image"], tf.float32) / 255.0}

  pipeline = (
      tfds.load("mnist", split=split)
      .map(to_model_input)
      .repeat()  # Cycle forever; the training length is set by step counts.
      .batch(batch_size)
      .prefetch(tf.data.AUTOTUNE)
  )
  return pipeline.as_numpy_iterator()


# Standard deviation of the normalized MNIST pixels. It configures both the
# diffusion scheme and the preconditioning of the denoiser network.
DATA_STD = 0.31

### Architecture

Next, let's define the U-Net backbone. The "preconditioning" ensures that the inputs and outputs of the network are roughly standardized (for more details, see Appendix B.6 of [this paper](https://arxiv.org/abs/2206.00364)).

In [None]:
# U-Net denoiser for single-channel MNIST images, wrapped in EDM-style
# preconditioning: `sigma_data` (the data standard deviation) is used to keep
# the network's inputs and outputs roughly standardized.
denoiser_model = dfn_lib.PreconditionedDenoiserUNet(
    # Output and preconditioning.
    out_channels=1,
    sigma_data=DATA_STD,
    # Backbone layout (presumably one entry per U-Net resolution level).
    num_channels=(64, 128),
    downsample_ratio=(2, 2),
    num_blocks=4,
    padding="SAME",
    # Size of the noise-level embedding.
    noise_embed_dim=128,
    # Attention settings.
    use_attention=True,
    use_position_encoding=True,
    num_heads=8,
)

### Training

For diffusion model training, the above-defined U-Net backbone serves as a denoiser: it takes as input a batch of samples corrupted by isotropic Gaussian noise and outputs its best guess of the corresponding uncorrupted images.

Besides the backbone architecture, we also need to specify how to sample the noise levels (i.e. standard deviations) used to corrupt the samples and the weighting for each noise level in the loss function (for available options and configurations, see [`swirl_dynamics.lib.diffusion.diffusion`](https://github.com/google-research/swirl-dynamics/blob/main/swirl_dynamics/lib/diffusion/diffusion.py)):

In [None]:
# Variance-exploding diffusion whose noise level follows a tangent schedule.
diffusion_scheme = dfn_lib.Diffusion.create_variance_exploding(
    data_std=DATA_STD,
    sigma=dfn_lib.tangent_noise_schedule(),
)

# Training wrapper around the denoiser. Noise levels used to corrupt the
# samples are drawn with `log_uniform_sampling`, and the loss at each level is
# weighted with `edm_weighting`.
model = dfn.DenoisingModel(
    denoiser=denoiser_model,
    # Shape of a single sample without the batch dimension; for MNIST this is
    # one 28x28 single-channel image.
    input_shape=(28, 28, 1),
    noise_weighting=dfn_lib.edm_weighting(data_std=DATA_STD),
    noise_sampling=dfn_lib.log_uniform_sampling(
        diffusion_scheme, clip_min=1e-4, uniform_grid=True,
    ),
)

We are now ready to define the learning parameters.

In [None]:
# !rm -R -f $workdir  # optional: clear the working directory

In [None]:
# Training hyperparameters for Example I (`#@param` marks Colab form fields).
# Total optimizer steps; also the `decay_steps` of the learning-rate schedule.
num_train_steps = 100_000  #@param
# Output directory for metrics and checkpoints (`{workdir}/checkpoints`).
workdir = "/tmp/diffusion_demo_mnist"  #@param
# Batch sizes of the training and evaluation dataloaders.
train_batch_size = 32  #@param
eval_batch_size = 32  #@param
# Warmup-cosine learning-rate schedule: start at `initial_lr`, warm up to
# `peak_lr` over `warmup_steps`, then decay toward `end_lr`.
initial_lr = 0.0  #@param
peak_lr = 1e-4  #@param
warmup_steps = 1000  #@param
end_lr = 1e-6  #@param
# Decay rate of the exponential moving average of the model parameters.
ema_decay = 0.999  #@param
# Checkpoint every `ckpt_interval` steps, keeping at most `max_ckpt_to_keep`.
ckpt_interval = 1000  #@param
max_ckpt_to_keep = 5  #@param

To start training, we first need to initialize the trainer.

In [None]:
# Single-device trainer. For multi-device training with data parallelism, use
# `trainers.DistributedDenoisingTrainer` instead (the `trainers` module is not
# imported in this notebook).
trainer = dfn.DenoisingTrainer(
    # An exponential moving average of the model parameters is tracked over
    # training steps. This alleviates the "color-shift" problems known to
    # exist in diffusion models.
    ema_decay=ema_decay,
    # Adam whose learning rate warms up from `initial_lr` to `peak_lr` over
    # `warmup_steps`, then follows a cosine decay toward `end_lr` across
    # `num_train_steps`.
    optimizer=optax.adam(
        learning_rate=optax.warmup_cosine_decay_schedule(
            init_value=initial_lr,
            peak_value=peak_lr,
            warmup_steps=warmup_steps,
            decay_steps=num_train_steps,
            end_value=end_lr,
        ),
    ),
    rng=jax.random.PRNGKey(888),
    model=model,
)

Now we are ready to start training. A couple of "callbacks" are passed to assist with monitoring and checkpointing.

The first step will be a little slow because JAX needs to JIT-compile the step function (the same goes for the first step where evaluation is performed). Subsequent steps should run much faster.

In [None]:
# Train the unconditional model. Training uses the first 75% of the MNIST train
# split and evaluation uses the remaining 25%, so the two do not overlap.
templates.run_train(
    train_dataloader=get_mnist_dataset(
        split="train[:75%]", batch_size=train_batch_size
    ),
    trainer=trainer,
    workdir=workdir,
    total_train_steps=num_train_steps,
    metric_writer=metric_writers.create_default_writer(
        workdir, asynchronous=False
    ),
    metric_aggregation_steps=100,
    eval_dataloader=get_mnist_dataset(
        split="train[75%:]", batch_size=eval_batch_size
    ),
    # Evaluation cadence (in steps) and size (in batches).
    eval_every_steps=1000,
    num_batches_per_eval=2,
    callbacks=(
        # This callback displays the training progress in a tqdm bar
        templates.TqdmProgressBar(
            total_train_steps=num_train_steps,
            train_monitors=("train_loss",),
        ),
        # This callback saves model checkpoint periodically
        templates.TrainStateCheckpoint(
            base_dir=workdir,
            options=ocp.CheckpointManagerOptions(
                save_interval_steps=ckpt_interval, max_to_keep=max_ckpt_to_keep
            ),
        ),
    ),
)

### Inference

#### Unconditional generation

After training is complete, the trained denoiser may be used to generate unconditional samples.

First, let's restore the model from the latest checkpoint.

In [None]:
# Restore train state from checkpoint. By default (`step=None`), the most
# recently saved checkpoint is restored. Alternatively, one can directly use
# `trainer.train_state` if continuing from the training section above.
trained_state = dfn.DenoisingModelTrainState.restore_from_orbax_ckpt(
    f"{workdir}/checkpoints", step=None
)
# Construct the inference function from the EMA copy of the parameters.
denoise_fn = dfn.DenoisingTrainer.inference_fn_from_state_dict(
    trained_state, use_ema=True, denoiser=denoiser_model
)

Diffusion samples are generated by plugging the trained denoising function into a stochastic differential equation (parametrized by the diffusion scheme) and solving it backwards in time.

In [None]:
# Unconditional SDE sampler: the trained denoiser drives the reverse-time SDE
# of `diffusion_scheme`, integrated with Euler-Maruyama on the time grid from
# `edm_noise_decay` (256 steps, ending at sigma = 1e-3).
sampler = dfn_lib.SdeSampler(
    scheme=diffusion_scheme,
    denoise_fn=denoise_fn,
    input_shape=(28, 28, 1),
    integrator=solver_lib.EulerMaruyama(),
    tspan=dfn_lib.edm_noise_decay(
        diffusion_scheme, rho=7, num_steps=256, end_sigma=1e-3,
    ),
    guidance_transforms=(),  # No guidance for unconditional sampling.
    apply_denoise_at_end=True,
    return_full_paths=False,  # Set to `True` if the full sampling paths are needed
)

The sampler may be run by calling its `.generate()` method. Optionally, we may JIT-compile this method so that it runs faster when called repeatedly.

In [None]:
# JIT-compile sampling. `num_samples` sets the leading dimension of the output
# array, so it has to be a static (compile-time) argument.
generate = jax.jit(
    sampler.generate,
    static_argnames=("num_samples",),
)

In [None]:
# Draw 4 unconditional samples; the fixed key makes the result reproducible.
samples = generate(num_samples=4, rng=jax.random.PRNGKey(8888))

Visualize the generated samples:

In [None]:
# Plot generated samples side by side, one panel per sample, so the cell keeps
# working if `num_samples` above is changed. `squeeze=False` keeps `ax`
# two-dimensional even when there is only one panel.
num_plots = samples.shape[0]
fig, ax = plt.subplots(1, num_plots, figsize=(2 * num_plots, 2), squeeze=False)
for i in range(num_plots):
  # The model was trained on data scaled to [0, 1]; display as 8-bit gray.
  ax[0, i].imshow(samples[i, :, :, 0] * 255, cmap="gray", vmin=0, vmax=255)
  ax[0, i].axis("off")
  ax[0, i].set_title(f"Sample #{i}")

plt.tight_layout()
plt.show()

#### Guided generation

To achieve 'guided' generation, we can modify a trained denoising function and tailor it to produce samples with specific desired characteristics. For instance, in an out-filling task where the goal is to generate full images from a given patch, we can guide the denoiser to create samples whose crops at certain positions precisely align with the provided patch.

In [None]:
guidance_fn = dfn_lib.InfillFromSlices(
    # This specifies the location of the guide input using Python slices over
    # (batch, height, width): `slice(None)` keeps every sample in the batch,
    # and the guide input corresponds to the 7x7 patch in the center of the
    # image.
    slices=(slice(None), slice(11, 18), slice(11, 18)),

    # This parameter controls how "hard" the denoiser pushes for the
    # conditioning to be satisfied. It is a tradeoff between strictness of
    # constraint satisfaction and diversity in the generated samples.
    guide_strength=0.1,
)

This transform function is passed through the `guidance_transforms` arg of the sampler.

In [None]:
# Same sampler configuration as the unconditional one, but the trained
# denoising function is modified by the infill guidance transform.
guided_sampler = dfn_lib.SdeSampler(
    scheme=diffusion_scheme,
    denoise_fn=denoise_fn,
    guidance_transforms=(guidance_fn,),
    input_shape=(28, 28, 1),
    integrator=solver_lib.EulerMaruyama(),
    tspan=dfn_lib.edm_noise_decay(
        diffusion_scheme, rho=7, num_steps=256, end_sigma=1e-3,
    ),
    apply_denoise_at_end=True,
    return_full_paths=False,
)

# `num_samples` fixes the output shape, so it must be static under `jax.jit`.
guided_generate = jax.jit(
    guided_sampler.generate, static_argnames=("num_samples",)
)

We construct an example guidance input from a real sample and use it to guide the sampling:

In [None]:
# Take one test image and crop the observed patch with the very slices the
# guidance transform uses, so the guide input always matches
# `sample[guidance_fn.slices]` (here: the central 7x7 patch).
test_ds = get_mnist_dataset(split="test", batch_size=1)
test_example = next(iter(test_ds))["x"]
example_guidance_inputs = {"observed_slices": test_example[guidance_fn.slices]}

In [None]:
# Guided sampling. The guidance input must be compatible with
# `sample[guidance_fn.slices]`, i.e. it is the observed patch.
guided_samples = guided_generate(
    num_samples=4,
    guidance_inputs=example_guidance_inputs,
    rng=jax.random.PRNGKey(66),
)

Visualize guided samples:

In [None]:
# Plot guide patch.
fig, ax = plt.subplots(1, 1, figsize=(2, 2))
ax.imshow(
    test_example[0, 11:18, 11:18, 0] * 255, cmap="gray", vmin=0, vmax=255
)
ax.axis("off")
ax.set_title("Guide patch")
plt.tight_layout()
plt.show()

# Plot generated samples, one panel per sample. `squeeze=False` keeps `ax`
# two-dimensional even when only one sample is drawn.
num_plots = guided_samples.shape[0]
fig, ax = plt.subplots(1, num_plots, figsize=(2 * num_plots, 2), squeeze=False)
for i in range(num_plots):
  panel = ax[0, i]
  panel.imshow(
      guided_samples[i, :, :, 0] * 255, cmap="gray", vmin=0, vmax=255
  )
  # Mark out the patch where guidance is enabled (pixels 11..17). `imshow`
  # centers pixel k at coordinate k, so the patch edge lies at 10.5.
  square = patches.Rectangle(
      xy=(10.5, 10.5), width=7, height=7, fill=False, edgecolor="red"
  )
  panel.add_patch(square)
  panel.axis("off")
  panel.set_title(f"Sample #{i}")

plt.tight_layout()
plt.show()

## Example II - Conditional diffusion model

In the above example, we trained an *unconditional* diffusion model and applied conditioning at inference time. This is not always easy to do, depending on how the conditioning input relates to the samples.

Alternatively, we can directly *train a conditional model*, where the conditional signal is provided at training time as an additional input to the denoising neural network, which may then use it to compute the denoised target.

Below we show an example of how to accomplish this. We again generate samples of handwritten digits, using the MNIST dataset for training. We will condition the generation on the `x[11:18, 11:18]` patch.

### Dataset

Besides the sample in `x`, the dataset for training conditional models requires a `cond` key, which contains the conditioning signals.

In [None]:
def preproc_example(example: dict[str, tf.Tensor]):
  """Converts a raw MNIST example into the `{"x", "cond"}` training layout.

  `x` is the image rescaled from [0, 255] to [0, 1]; `cond` holds the central
  7x7 crop of `x` (rows and columns 11..17) under the key "channel:low_res".
  """
  processed = {}
  processed["x"] = tf.cast(example["image"], tf.float32) / 255.0

  # The "channel:" prefix indicates that the conditioning signal is to be
  # incorporated by resizing and concatenating along the channel dimension.
  # This is implemented at the backbone level.
  processed["cond"] = {"channel:low_res": processed["x"][11:18, 11:18]}
  return processed


def get_cond_mnist_dataset(split: str, batch_size: int):
  """Returns an endless NumPy iterator of preprocessed conditional batches.

  Args:
    split: TFDS split specification, e.g. "train[:75%]".
    batch_size: Number of examples per batch.
  """
  ds = tfds.load("mnist", split=split)
  ds = ds.map(preproc_example)
  ds = ds.repeat()
  ds = ds.batch(batch_size)
  ds = ds.prefetch(tf.data.AUTOTUNE)
  ds = ds.as_numpy_iterator()
  return ds

# Standard deviation of the normalized data (same value as in Example I).
DATA_STD = 0.31

### Architecture

The architecture is similar to the unconditional case. We provide additional args that specify how to resize the conditioning signal (in order to be compatible with the noisy sample for channel-wise concatenation).

In [None]:
# Conditional U-Net: the same backbone as the unconditional denoiser, plus
# settings that control how "channel:"-prefixed conditioning inputs are resized
# (to be concatenated with the noisy sample along the channel axis) and
# embedded.
cond_denoiser_model = dfn_lib.PreconditionedDenoiserUNet(
    # Conditioning.
    cond_resize_method="cubic",
    cond_embed_dim=128,
    # Output and preconditioning.
    out_channels=1,
    sigma_data=DATA_STD,
    # Backbone (identical to the unconditional model).
    num_channels=(64, 128),
    downsample_ratio=(2, 2),
    num_blocks=4,
    padding="SAME",
    noise_embed_dim=128,
    use_attention=True,
    use_position_encoding=True,
    num_heads=8,
)

### Training

The `DenoisingModel` is again similar to the unconditional case. We additionally provide the shape information of the `cond` input.

In [None]:
# Same variance-exploding diffusion scheme as in Example I.
diffusion_scheme = dfn_lib.Diffusion.create_variance_exploding(
    data_std=DATA_STD,
    sigma=dfn_lib.tangent_noise_schedule(),
)

cond_model = dfn.DenoisingModel(
    denoiser=cond_denoiser_model,
    input_shape=(28, 28, 1),
    # `cond_shape` must agree with the expected structure and shape (without
    # the batch dimension) of the `cond` input: here the single-channel 7x7
    # center crop produced by `preproc_example`.
    cond_shape={"channel:low_res": (7, 7, 1)},
    noise_weighting=dfn_lib.edm_weighting(data_std=DATA_STD),
    noise_sampling=dfn_lib.log_uniform_sampling(
        diffusion_scheme, clip_min=1e-4, uniform_grid=True,
    ),
)

The rest mostly repeats the unconditional training example, replacing the datasets and model with their conditional counterparts.

In [None]:
# !rm -R -f $cond_workdir  # optional: clear the working directory

In [None]:
# Training hyperparameters for Example II (`#@param` marks Colab form fields).
# Total optimizer steps; also the `decay_steps` of the learning-rate schedule.
num_train_steps = 100_000  #@param
# Output directory for metrics and checkpoints (`{cond_workdir}/checkpoints`).
cond_workdir = "/tmp/cond_diffusion_demo_mnist"  #@param
# Batch sizes of the training and evaluation dataloaders.
train_batch_size = 32  #@param
eval_batch_size = 32  #@param
# Warmup-cosine learning-rate schedule: start at `initial_lr`, warm up to
# `peak_lr` over `warmup_steps`, then decay toward `end_lr`.
initial_lr = 0.0  #@param
peak_lr = 1e-4  #@param
warmup_steps = 1000  #@param
end_lr = 1e-6  #@param
# Decay rate of the exponential moving average of the model parameters.
ema_decay = 0.999  #@param
# Checkpoint every `ckpt_interval` steps, keeping at most `max_ckpt_to_keep`.
ckpt_interval = 1000  #@param
max_ckpt_to_keep = 5  #@param

In [None]:
# Same optimizer, EMA and training loop as the unconditional example; only the
# model, the datasets and the working directory differ.
cond_trainer = dfn.DenoisingTrainer(
    model=cond_model,
    rng=jax.random.PRNGKey(888),
    optimizer=optax.adam(
        learning_rate=optax.warmup_cosine_decay_schedule(
            init_value=initial_lr,
            peak_value=peak_lr,
            warmup_steps=warmup_steps,
            decay_steps=num_train_steps,
            end_value=end_lr,
        ),
    ),
    ema_decay=ema_decay,
)

templates.run_train(
    # Disjoint train (first 75%) / eval (last 25%) portions of the train split.
    train_dataloader=get_cond_mnist_dataset(
        split="train[:75%]", batch_size=train_batch_size
    ),
    trainer=cond_trainer,
    workdir=cond_workdir,
    total_train_steps=num_train_steps,
    metric_writer=metric_writers.create_default_writer(
        cond_workdir, asynchronous=False
    ),
    metric_aggregation_steps=100,
    eval_dataloader=get_cond_mnist_dataset(
        split="train[75%:]", batch_size=eval_batch_size
    ),
    eval_every_steps=1000,
    num_batches_per_eval=2,
    callbacks=(
        # Progress bar showing the training loss.
        templates.TqdmProgressBar(
            total_train_steps=num_train_steps,
            train_monitors=("train_loss",),
        ),
        # Periodic checkpoints under `cond_workdir`.
        templates.TrainStateCheckpoint(
            base_dir=cond_workdir,
            options=ocp.CheckpointManagerOptions(
                save_interval_steps=ckpt_interval, max_to_keep=max_ckpt_to_keep
            ),
        ),
    ),
)

### Inference

To perform inference/sampling, let's load back the trained conditional model checkpoint:

In [None]:
# Restore the most recently saved conditional checkpoint (`step=None`) and
# build a denoising function from its EMA parameters.
trained_state = dfn.DenoisingModelTrainState.restore_from_orbax_ckpt(
    f"{cond_workdir}/checkpoints", step=None
)
cond_denoise_fn = dfn.DenoisingTrainer.inference_fn_from_state_dict(
    trained_state, denoiser=cond_denoiser_model, use_ema=True
)

The conditional sampler again follows the previous example, with the only exception being that the conditional model replaces the unconditional one.

Below we do not apply any guidance, but one can be easily added in the same way as in the unconditional example above.

In [None]:
# Same sampler setup as Example I, driven by the conditional denoiser. No
# guidance transforms are used here.
cond_sampler = dfn_lib.SdeSampler(
    scheme=diffusion_scheme,
    denoise_fn=cond_denoise_fn,
    guidance_transforms=(),
    input_shape=(28, 28, 1),
    integrator=solver_lib.EulerMaruyama(),
    tspan=dfn_lib.edm_noise_decay(
        diffusion_scheme, rho=7, num_steps=256, end_sigma=1e-3,
    ),
    apply_denoise_at_end=True,
    return_full_paths=False,
)

We again JIT the generate function for the sake of faster repeated sampling calls. Here we employ `functools.partial` to specify `num_samples=5`, making it easier to vectorize across the batch dimension with `jax.vmap`.

In [None]:
# Fix `num_samples` (bound positionally as the first argument of `generate`),
# so the compiled function is called with `(rng, cond, guidance_inputs)` and
# can be vectorized over the batch with `jax.vmap`.
num_samples_per_cond = 5

generate = jax.jit(functools.partial(
    cond_sampler.generate, num_samples_per_cond
))

Loading a test batch of conditions with 4 elements:

In [None]:
# Load one test batch of conditions. The batch is sized by `batch_size` (not a
# separate literal) so it always matches the `batch_size` RNG keys that are
# split for `jax.vmap` when sampling.
batch_size = 4
test_ds = get_cond_mnist_dataset(split="test", batch_size=batch_size)
test_batch_cond = next(iter(test_ds))["cond"]

The vectorized generate function is applied to the loaded batch. The vectorization occurs along the leading dimensions of both the random seeds and the conditions (for those unfamiliar with vectorized operations in JAX, think of it as a more efficient `for` loop that iterates over the random seeds and batch conditions zipped together).

In [None]:
# Vectorize sampling over the test batch: RNG keys and conditions are mapped
# along their leading (batch) axis, while the guidance inputs are broadcast.
cond_samples = jax.vmap(generate, in_axes=(0, 0, None))(
    jax.random.split(jax.random.PRNGKey(8888), num=batch_size),
    test_batch_cond,
    None,  # No guidance transforms are used, so there are no guidance inputs.
)

The result `cond_samples` has shape `(batch_size, num_samples_per_cond, *input_shape)`.

In [None]:
# Expected: (batch_size, num_samples_per_cond, 28, 28, 1).
print(cond_samples.shape)

Visualize generated examples alongside their low-res conditioning:

In [None]:
# For each test condition, show the conditioning patch followed by the samples
# generated from it.
for i in range(batch_size):
  # Conditioning patch for batch element i.
  fig, ax = plt.subplots(1, 1, figsize=(2, 2))
  ax.imshow(
      test_batch_cond["channel:low_res"][i, :, :, 0] * 255,
      cmap="gray", vmin=0, vmax=255
  )
  ax.axis("off")
  ax.set_title(f"Low-res condition: #{i + 1}")
  # Lay out this figure explicitly: `plt.tight_layout()` only acts on the
  # current (most recently created) figure.
  fig.tight_layout()

  # Plot generated samples. `squeeze=False` keeps `ax` two-dimensional so the
  # indexing also works when `num_samples_per_cond == 1`.
  fig, ax = plt.subplots(
      1, num_samples_per_cond, figsize=(num_samples_per_cond * 2, 2),
      squeeze=False,
  )
  for j in range(num_samples_per_cond):
    panel = ax[0, j]
    panel.imshow(
        cond_samples[i, j, :, :, 0] * 255, cmap="gray", vmin=0, vmax=255
    )
    # Outline where the conditioning crop was taken (pixels 11..17). `imshow`
    # centers pixel k at coordinate k, so the patch edge lies at 10.5.
    square = patches.Rectangle(
        xy=(10.5, 10.5), width=7, height=7, fill=False, edgecolor="red"
    )
    panel.add_patch(square)
    panel.set_title(f"conditional sample: #{j + 1}")
    panel.axis("off")
  fig.tight_layout()

plt.show()

## New content: Hopfield / neural SDE models for time-series forecasting (PyTorch)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from torchdiffeq import odeint_adjoint as odeint  # 用于SDE/SDDE积分


# ============= 1. 基础组件：现代Hopfield层（三种模型复用） =============
class HopfieldLayer(nn.Module):
    """Multi-head associative-memory layer in the style of modern Hopfield networks.

    Modeled on the continuous modern Hopfield update of arXiv:2008.02217,
    xi_new = X softmax(beta X^T xi): every position of the sequence retrieves a
    softmax-weighted mixture of the (projected) patterns at all positions. The
    retrieved features are projected and added back to the input (residual
    connection), so the output has the same shape as the input.

    Args:
        dim_model: Feature size of the input and output.
        num_heads: Number of parallel retrieval heads; must divide ``dim_model``.
        beta: Inverse temperature multiplied onto the scaled similarity scores;
            larger values make the retrieval sharper.

    Raises:
        ValueError: If ``num_heads`` is not a positive divisor of ``dim_model``.
    """

    def __init__(self, dim_model, num_heads=4, beta=1.0):
        super().__init__()
        # Validate explicitly (an ``assert`` is stripped under ``python -O``) and
        # before the integer division below could hit num_heads == 0.
        if num_heads <= 0 or dim_model % num_heads != 0:
            raise ValueError("dim_model must be divisible by num_heads")
        self.dim_model = dim_model
        self.num_heads = num_heads
        self.beta = beta  # inverse temperature: sharpness of the attention weights
        self.head_dim = dim_model // num_heads

        # Query/key/value projections used for pattern storage and retrieval,
        # plus the output projection applied before the residual add.
        self.query_proj = nn.Linear(dim_model, dim_model)
        self.key_proj = nn.Linear(dim_model, dim_model)
        self.value_proj = nn.Linear(dim_model, dim_model)
        self.output_proj = nn.Linear(dim_model, dim_model)

    def forward(self, x):
        """Return memory-enhanced features for ``x`` of shape (B, T, dim_model)."""
        batch_size, seq_len, _ = x.shape

        def split_heads(t):
            # (B, T, D) -> (B, H, T, head_dim)
            return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        Q = split_heads(self.query_proj(x))
        K = split_heads(self.key_proj(x))
        V = split_heads(self.value_proj(x))

        # Scaled dot-product similarities, sharpened by beta, then normalized
        # into association weights over all positions.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.head_dim)
        scores = scores * self.beta
        attn_weights = F.softmax(scores, dim=-1)

        # Weighted retrieval of the stored patterns: (B, H, T, head_dim).
        memory_output = torch.matmul(attn_weights, V)

        # Merge heads back to (B, T, D), project, and add the residual input.
        memory_output = memory_output.transpose(1, 2).contiguous()
        memory_output = memory_output.view(batch_size, seq_len, self.dim_model)
        return self.output_proj(memory_output) + x


# ============= 2. 对比模型1：神经随机时滞微分方程（SDDE-Net） =============
class SDDE_Net(nn.Module):
    """Delay-embedded latent-dynamics baseline (the author's "SDDE-Net").

    The design follows the neural stochastic delay differential equation model
    referenced by the original author (SDDE_net.pdf). As implemented:

    1. Every time step t >= tau is represented by the delay window
       ``[x_{t-tau}, ..., x_t]`` flattened into one vector and encoded.
    2. Only the FIRST encoded window (steps 0..tau) is used as the initial
       latent state ``z0``.
    3. ``z0`` is integrated with ``dz/dt = drift_net(z)`` (``dopri5``); no
       Brownian noise is injected, so the trajectory is deterministic.
    4. ``diffusion_net`` is evaluated along the path as a positive
       noise-intensity readout; it is returned but does not affect ``pred``.
    5. The decoded path is left-padded with ``tau`` zero frames.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Size of the latent state.
        tau: Number of past steps included in each delay window.
    """

    def __init__(self, input_dim, hidden_dim=32, tau=2):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.tau = tau  # delay length in time steps

        # Encoder for a flattened delay window of tau + 1 frames.
        self.lag_encoder = nn.Sequential(
            nn.Linear(input_dim * (tau + 1), hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Drift term: deterministic latent dynamics.
        self.drift_net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim)
        )
        # Diffusion term: noise intensity (used only as a readout).
        self.diffusion_net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Decoder from the latent space back to the input space.
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, input_dim)
        )

    def get_lag_state(self, x_seq):
        """Build the delay windows ``[x_{t-tau}, ..., x_t]`` for all t in [tau, T).

        Args:
            x_seq: Tensor of shape (B, T, D).

        Returns:
            Tensor of shape (B, T - tau, D * (tau + 1)); row ``i`` holds the
            frames ``x_seq[:, i : i + tau + 1]`` flattened frame by frame.

        Raises:
            ValueError: If T < tau + 1, i.e. not even one window fits.
        """
        batch_size, seq_len, dim = x_seq.shape
        window = self.tau + 1
        if seq_len < window:
            raise ValueError(
                f"sequence length {seq_len} is too short for tau={self.tau}; "
                f"need at least {window} time steps"
            )
        # unfold -> (B, T - tau, D, tau + 1); put the window axis before the
        # feature axis so flattening lays the frames out one after another.
        windows = x_seq.unfold(1, window, 1).transpose(-1, -2)
        # reshape (not view) so non-contiguous inputs are accepted as well.
        return windows.reshape(batch_size, seq_len - self.tau, window * dim)

    def sde_dynamics(self, t, z):
        """Return ``(drift, diffusion)`` of dz = drift(z) dt + diffusion(z) dW."""
        drift = self.drift_net(z)
        # softplus + epsilon keeps the noise intensity strictly positive.
        diffusion = F.softplus(self.diffusion_net(z)) + 1e-6
        return drift, diffusion

    def forward(self, x_seq, t_eval=None):
        """Predict a sequence aligned with ``x_seq``.

        Args:
            x_seq: Tensor of shape (B, T, input_dim).
            t_eval: Optional 1-D tensor of integration times. Defaults to
                ``T - tau`` points evenly spaced in [0, 1]; its length sets the
                number of decoded (non-padded) steps.

        Returns:
            ``(pred_full, diffusion)``: ``pred_full`` has shape
            (B, tau + len(t_eval), input_dim) with the first ``tau`` steps set
            to zero; ``diffusion`` has shape (B, len(t_eval), hidden_dim).
        """
        batch_size, seq_len, _ = x_seq.shape
        device = x_seq.device

        # 1. Delay-embed and encode every window: (B, T - tau, H).
        lag_states = self.get_lag_state(x_seq)
        encoded_lag = self.lag_encoder(lag_states)

        # 2. Only the first window seeds the latent trajectory.
        z0 = encoded_lag[:, 0, :]  # (B, H)
        if t_eval is None:
            t_eval = torch.linspace(0.0, 1.0, seq_len - self.tau, device=device)

        # 3. Integrate the drift only (deterministic). `odeint` is torchdiffeq's
        # `odeint_adjoint`; `adjoint_params` lists the drift parameters.
        def drift_func(t, z):
            return self.drift_net(z)

        z_path = odeint(
            func=drift_func,
            y0=z0,
            t=t_eval,
            method='dopri5',
            adjoint_params=list(self.drift_net.parameters())
        )  # (len(t_eval), B, H)
        z_path = z_path.transpose(0, 1)  # (B, len(t_eval), H)

        # 4. Noise-intensity readout; not used for the prediction.
        diffusion = F.softplus(self.diffusion_net(z_path)) + 1e-6

        # 5. Decode and left-pad with tau zero frames to align with the input.
        pred = self.decoder(z_path)  # (B, len(t_eval), D)
        pred_pad = torch.zeros(batch_size, self.tau, self.input_dim, device=device)
        pred_full = torch.cat([pred_pad, pred], dim=1)

        return pred_full, diffusion


# ============= 3. 对比模型2：MHNN（Hopfield+Transformer） =============
class TransformerEncoderBlock(nn.Module):
    """Post-norm Transformer encoder block for batch-first sequences.

    Self-attention and then a position-wise feed-forward MLP, each wrapped in a
    residual connection followed by LayerNorm.

    Args:
        dim_model: Feature size of the input/output sequence (B, T, dim_model).
        num_heads: Number of heads of the ``nn.MultiheadAttention`` layer.
        dim_feedforward: Hidden width of the feed-forward MLP.
    """

    def __init__(self, dim_model, num_heads=4, dim_feedforward=64):
        super().__init__()
        # Submodule creation order fixes the parameter-initialization sequence
        # under a given seed; keep it.
        self.self_attn = nn.MultiheadAttention(dim_model, num_heads, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(dim_model, dim_feedforward),
            nn.ReLU(inplace=True),
            nn.Linear(dim_feedforward, dim_model),
        )
        self.norm1 = nn.LayerNorm(dim_model)
        self.norm2 = nn.LayerNorm(dim_model)

    def forward(self, x):
        """Apply attention and feed-forward sublayers to ``x`` (B, T, dim_model)."""
        attended = self.norm1(x + self.self_attn(x, x, x)[0])
        return self.norm2(attended + self.feed_forward(attended))

class MHNN(nn.Module):
    """Hopfield memory layer followed by a Transformer encoder stack.

    A purely deterministic sequence model (no stochastic component): inputs are
    encoded, enriched by a ``HopfieldLayer``, passed through
    ``num_transformer_blocks`` ``TransformerEncoderBlock`` modules and decoded
    back to ``input_dim`` features at every time step.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Width of the hidden representation.
        num_heads: Heads for both the Hopfield layer and the attention blocks.
        num_transformer_blocks: Number of stacked Transformer blocks.
    """

    def __init__(self, input_dim, hidden_dim=32, num_heads=4, num_transformer_blocks=2):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Submodules are created in this order (input encoder, memory, stack,
        # decoder); the order fixes parameter initialization under a seed.
        self.input_encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.hopfield_layer = HopfieldLayer(hidden_dim, num_heads=num_heads)
        self.transformer_encoder = nn.Sequential(*(
            TransformerEncoderBlock(hidden_dim, num_heads)
            for _ in range(num_transformer_blocks)
        ))
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x_seq):
        """Map ``x_seq`` of shape (B, T, input_dim) to a prediction of the same shape."""
        hidden = self.input_encoder(x_seq)
        hidden = self.hopfield_layer(hidden)
        hidden = self.transformer_encoder(hidden)
        return self.decoder(hidden)


# ============= 4. 原有模型：Hopfield-SDE（基准模型） =============
class HopfieldSDENet(nn.Module):
    """Baseline: Hopfield-style memory encoder followed by latent drift dynamics.

    The input sequence is encoded and passed through a ``HopfieldLayer``. The
    memory state at the LAST time step becomes the initial latent state, which
    is integrated with ``dz/dt = drift_net(z)`` (``dopri5``) over ``t_eval`` and
    decoded at every evaluation time. No noise is injected during integration;
    ``diffusion_net`` only provides a positive noise-intensity readout.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Size of the latent state.
        num_heads: Number of heads of the Hopfield memory layer.
    """

    def __init__(self, input_dim, hidden_dim=32, num_heads=4):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Submodule creation order (memory, encoder, drift, diffusion, decoder)
        # fixes parameter initialization under a seed; keep it.
        # Hopfield memory over the encoded sequence history.
        self.memory_layer = HopfieldLayer(hidden_dim, num_heads=num_heads)

        # Encoder from the input space to the latent space.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Latent dynamics: drift (integrated) and diffusion (readout only).
        self.drift_net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.diffusion_net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Decoder from the latent space back to the input space.
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, input_dim)
        )

    def forward(self, x_seq, t_eval=None, return_path=False):
        """Decode the latent trajectory started from the final memory state.

        Args:
            x_seq: Tensor of shape (B, T, input_dim).
            t_eval: Optional 1-D tensor of integration times; defaults to T
                points evenly spaced in [0, 1].
            return_path: If True, also return the latent path and the
                diffusion readout.

        Returns:
            ``output`` of shape (B, len(t_eval), input_dim), or
            ``(output, z_path, diffusion)`` when ``return_path`` is True, where
            ``z_path`` and ``diffusion`` have shape (B, len(t_eval), hidden_dim).
        """
        _, seq_len, _ = x_seq.shape
        device = x_seq.device

        if t_eval is None:
            t_eval = torch.linspace(0.0, 1.0, seq_len, device=device)

        # Encode the sequence and extract memory features: (B, T, H).
        encoded = self.encoder(x_seq)
        memory = self.memory_layer(encoded)

        # The memory at the last time step is the initial latent state.
        z0 = memory[:, -1, :]  # (B, H)

        # Integrate the drift. `odeint` is torchdiffeq's `odeint_adjoint`;
        # `adjoint_params` lists the drift parameters.
        def drift_func(t, z):
            return self.drift_net(z)

        z_path = odeint(
            func=drift_func,
            y0=z0,
            t=t_eval,
            method='dopri5',
            adjoint_params=list(self.drift_net.parameters())
        )  # (len(t_eval), B, H)
        z_path = z_path.transpose(0, 1)  # (B, len(t_eval), H)

        output = self.decoder(z_path)  # (B, len(t_eval), D)
        if not return_path:
            # The diffusion readout would be discarded, so skip computing it.
            return output

        # Positive noise intensity (softplus + epsilon) along the path.
        diffusion = F.softplus(self.diffusion_net(z_path)) + 1e-6
        return output, z_path, diffusion


# ============= 5. 数据生成：Lorenz混沌系统（统一数据集） =============
def generate_lorenz_data(
    num_samples=200,
    seq_len=50,
    dt=0.01,
    sigma=10.0,
    rho=28.0,
    beta=8/3.0,
    noise_std=0.01,
    seed=None,
):
    """Simulate noisy Lorenz-63 trajectories.

    Each trajectory starts from a state drawn uniformly from [-10, 10]^3 and is
    advanced with classical fourth-order Runge-Kutta steps of size ``dt``.
    After every step, i.i.d. Gaussian noise with standard deviation
    ``noise_std`` is added to the state.

    Args:
        num_samples: Number of independent trajectories.
        seq_len: Number of recorded states per trajectory, including the
            initial state (at least one state is always recorded).
        dt: Runge-Kutta step size.
        sigma, rho, beta: Parameters of the Lorenz system.
        noise_std: Standard deviation of the per-step additive noise.
        seed: If None, draw from NumPy's global random state (``np.random``), as
            before. Otherwise use a dedicated ``np.random.default_rng(seed)`` so
            results are reproducible without touching global state.

    Returns:
        float32 array of shape (num_samples, max(seq_len, 1), 3).
    """
    rng = np.random if seed is None else np.random.default_rng(seed)

    def lorenz_dynamics(state):
        x, y, z = state
        return np.array([
            sigma * (y - x),
            x * (rho - z) - y,
            x * y - beta * z,
        ])

    def rk4_step(state):
        k1 = lorenz_dynamics(state)
        k2 = lorenz_dynamics(state + 0.5 * dt * k1)
        k3 = lorenz_dynamics(state + 0.5 * dt * k2)
        k4 = lorenz_dynamics(state + dt * k3)
        return state + (dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4)

    trajectories = []
    for _ in range(num_samples):
        state = rng.uniform(-10, 10, size=3)
        trajectory = [state]
        for _ in range(seq_len - 1):
            # `state` is rebound to a fresh array each step, so no copy needed.
            state = rk4_step(state) + rng.normal(0, noise_std, size=3)
            trajectory.append(state)
        trajectories.append(trajectory)

    # Explicit shape keeps the (N, T, 3) layout even when num_samples == 0.
    return np.array(trajectories, dtype=np.float32).reshape(
        num_samples, max(seq_len, 1), 3
    )


# ============= 6. 统一训练函数（三种模型复用） =============
def _unwrap_model_prediction(pred_output):
    """Return the prediction tensor from a model output (first item if it is a tuple)."""
    # Some models return e.g. (pred, diffusion); others return pred directly.
    return pred_output[0] if isinstance(pred_output, tuple) else pred_output


def _mean_loss_over_loader(model, loader, criterion, device):
    """Switch ``model`` to eval mode and return the sample-weighted mean loss over ``loader``."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            batch_pred = _unwrap_model_prediction(model(batch_x))
            # criterion is mean-reduced per batch; weight by batch size so the
            # result is the mean over the whole dataset.
            total += criterion(batch_pred, batch_y).item() * batch_x.size(0)
    return total / len(loader.dataset)


def train_unified_model(
    model,
    train_loader,
    val_loader,
    epochs=50,
    lr=1e-3,
    device='cpu',
    print_interval=10,
    model_name="Model",
    test_loader=None,
):
    """Train ``model`` with Adam + MSE and report per-epoch and final losses.

    The same procedure is used for every model in the comparison. Models may
    return either the prediction tensor or a tuple whose first element is the
    prediction.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to predictions.
        train_loader: Iterable of ``(x, y)`` batches for training; must expose
            ``.dataset`` for averaging.
        val_loader: Same, for validation.
        epochs: Number of training epochs.
        lr: Adam learning rate (weight decay is fixed at 1e-5).
        device: Device the model and batches are moved to.
        print_interval: Print progress every ``print_interval`` epochs
            (0 or None disables printing).
        model_name: Label used in progress messages.
        test_loader: Loader for the final "test" MSE. Defaults to
            ``val_loader`` (the original behaviour, used when data is scarce).

    Returns:
        ``(train_losses, val_losses, test_mse, model)``. The model is left in
        eval mode.
    """
    # Move the model first so the optimizer is built on the final parameters.
    model.to(device)
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-5)  # L2 regularization
    criterion = nn.MSELoss()

    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            optimizer.zero_grad()
            batch_pred = _unwrap_model_prediction(model(batch_x))
            loss = criterion(batch_pred, batch_y)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * batch_x.size(0)

        train_loss /= len(train_loader.dataset)
        train_losses.append(train_loss)

        val_loss = _mean_loss_over_loader(model, val_loader, criterion, device)
        val_losses.append(val_loss)

        if print_interval and (epoch + 1) % print_interval == 0:
            print(f"[{model_name}] Epoch [{epoch+1:02d}/{epochs:02d}] "
                  f"Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f}")

    # Final evaluation. The helper sets eval mode explicitly, so this is
    # correct even when epochs == 0.
    eval_loader = val_loader if test_loader is None else test_loader
    test_mse = _mean_loss_over_loader(model, eval_loader, criterion, device)

    return train_losses, val_losses, test_mse, model


# ============= 7. 主程序：三种模型训练+对比可视化 =============
if __name__ == '__main__':
    # 1. Shared configuration: every model gets the same hyper-parameters so
    #    the comparison is fair.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    input_dim = 3  # Lorenz state dimension (x, y, z)
    hidden_dim = 32  # shared hidden size
    epochs = 50  # shared number of epochs
    lr = 1e-3  # shared learning rate
    batch_size = 32  # shared batch size
    print(f"使用设备: {device} | 超参数统一：hidden_dim={hidden_dim}, epochs={epochs}, lr={lr}")

    # 2. Generate and preprocess one dataset shared by all three models.
    print("\nStep 1: 生成Lorenz混沌数据集...")
    lorenz_data = generate_lorenz_data(num_samples=200, seq_len=50)  # (200, 50, 3)
    # One-step-ahead prediction: steps 0..T-2 are inputs, steps 1..T-1 targets.
    X = lorenz_data[:, :-1, :]  # inputs: (200, 49, 3)
    y = lorenz_data[:, 1:, :]   # targets: (200, 49, 3)

    # Standardize with a single global mean/std computed over all of X.
    X_mean = X.mean()
    X_std = X.std()
    X = (X - X_mean) / (X_std + 1e-8)
    y = (y - X_mean) / (X_std + 1e-8)

    # 80/20 train/validation split by sample.
    split_idx = int(0.8 * len(X))
    X_train, X_val = X[:split_idx], X[split_idx:]
    y_train, y_val = y[:split_idx], y[split_idx:]

    # DataLoaders shared by all three models.
    train_dataset = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
    val_dataset = TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val))
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # 3. Build the three models with the same hidden size.
    print("\nStep 2: 初始化三种对比模型...")
    # Model 1: Hopfield-SDE (baseline).
    model_hopfield_sde = HopfieldSDENet(
        input_dim=input_dim, hidden_dim=hidden_dim, num_heads=4
    )
    # Model 2: SDDE-Net with a lag of tau=2 steps.
    model_sdde = SDDE_Net(
        input_dim=input_dim, hidden_dim=hidden_dim, tau=2
    )
    # Model 3: MHNN (Hopfield + Transformer blocks).
    model_mhnn = MHNN(
        input_dim=input_dim, hidden_dim=hidden_dim, num_heads=4, num_transformer_blocks=2
    )

    # Parameter counts as a rough complexity comparison.
    def count_params(model):
        return sum(p.numel() for p in model.parameters())
    print(f"Hopfield-SDE参数量: {count_params(model_hopfield_sde):,}")
    print(f"SDDE-Net参数量: {count_params(model_sdde):,}")
    print(f"MHNN参数量: {count_params(model_mhnn):,}")

    # 4. Train all three models with the shared training routine.
    print("\nStep 3: 训练三种模型...")
    print("\n=== 训练Hopfield-SDE ===")
    losses_hsde_train, losses_hsde_val, mse_hsde, model_hsde_trained = train_unified_model(
        model_hopfield_sde, train_loader, val_loader, epochs, lr, device, model_name="Hopfield-SDE"
    )
    print("\n=== 训练SDDE-Net ===")
    losses_sdde_train, losses_sdde_val, mse_sdde, model_sdde_trained = train_unified_model(
        model_sdde, train_loader, val_loader, epochs, lr, device, model_name="SDDE-Net"
    )
    print("\n=== 训练MHNN ===")
    losses_mhnn_train, losses_mhnn_val, mse_mhnn, model_mhnn_trained = train_unified_model(
        model_mhnn, train_loader, val_loader, epochs, lr, device, model_name="MHNN"
    )

    # 5. Visual comparison.
    print("\nStep 4: 可视化三种模型对比结果...")
    # DejaVu Sans has no CJK glyphs, so the Chinese titles/labels below would
    # render as empty boxes. Put common CJK fonts first (matplotlib picks from
    # this list in order) and keep DejaVu Sans as the last fallback.
    plt.rcParams['font.sans-serif'] = [
        'SimHei', 'Microsoft YaHei', 'Noto Sans CJK SC', 'WenQuanYi Zen Hei', 'DejaVu Sans'
    ]
    # Many CJK fonts lack the U+2212 minus glyph; draw negative ticks with ASCII '-'.
    plt.rcParams['axes.unicode_minus'] = False
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))
    models = ["Hopfield-SDE", "SDDE-Net", "MHNN"]
    train_losses = [losses_hsde_train, losses_sdde_train, losses_mhnn_train]
    val_losses = [losses_hsde_val, losses_sdde_val, losses_mhnn_val]
    test_mses = [mse_hsde, mse_sdde, mse_mhnn]
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c']

    # Subplot 1: train (solid) / validation (dashed) loss curves.
    for name, train_loss, val_loss, color in zip(models, train_losses, val_losses, colors):
        ax1.plot(range(1, epochs+1), train_loss, label=f'{name} (Train)', linewidth=2, color=color)
        ax1.plot(range(1, epochs+1), val_loss, label=f'{name} (Val)', linewidth=2, color=color, linestyle='--')
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('MSE Loss', fontsize=12)
    ax1.set_title('三种模型训练/验证损失对比', fontsize=14, fontweight='bold')
    ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3)

    # Subplot 2: final MSE per model (bar chart with value labels).
    bars = ax2.bar(models, test_mses, color=colors, alpha=0.7, edgecolor='black')
    for bar, mse in zip(bars, test_mses):
        ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1e-5,
                 f'{mse:.6f}', ha='center', va='bottom', fontsize=10)
    ax2.set_ylabel('测试集MSE', fontsize=12)
    ax2.set_title('三种模型预测精度对比（MSE越低越好）', fontsize=14, fontweight='bold')
    ax2.grid(True, alpha=0.3, axis='y')

    # Subplot 3: x-component of one validation trajectory, de-normalized.
    model_hsde_trained.eval()
    model_sdde_trained.eval()
    model_mhnn_trained.eval()
    with torch.no_grad():
        sample_idx = 0  # first validation sample
        x_sample = torch.from_numpy(X_val[sample_idx:sample_idx+1]).to(device)  # (1, 49, 3)
        y_true = y_val[sample_idx, :, 0] * X_std + X_mean  # ground truth, original scale

        # SDDE-Net returns a tuple whose first item is the prediction.
        pred_hsde = model_hsde_trained(x_sample)[0, :, 0].cpu().numpy() * X_std + X_mean
        pred_sdde = model_sdde_trained(x_sample)[0][0, :, 0].cpu().numpy() * X_std + X_mean
        pred_mhnn = model_mhnn_trained(x_sample)[0, :, 0].cpu().numpy() * X_std + X_mean

        time_steps = np.arange(len(y_true))
        ax3.plot(time_steps, y_true, label='真实值（X维度）', linewidth=3, color='black')
        ax3.plot(time_steps, pred_hsde, label='Hopfield-SDE', linewidth=2.5, color=colors[0])
        ax3.plot(time_steps, pred_sdde, label='SDDE-Net', linewidth=2.5, color=colors[1])
        ax3.plot(time_steps, pred_mhnn, label='MHNN', linewidth=2.5, color=colors[2])
        ax3.set_xlabel('时间步', fontsize=12)
        ax3.set_ylabel('Lorenz系统X维度值', fontsize=12)
        ax3.set_title('单样本预测轨迹对比', fontsize=14, fontweight='bold')
        ax3.legend(fontsize=10)
        ax3.grid(True, alpha=0.3)

    # Subplot 4: mean diffusion magnitude (only Hopfield-SDE and SDDE-Net
    # expose a diffusion term).
    with torch.no_grad():
        x_batch = torch.from_numpy(X_val[:5]).to(device)  # first 5 validation samples
        _, _, diff_hsde = model_hsde_trained(x_batch, return_path=True)
        mean_diff_hsde = diff_hsde.mean(dim=(0,1,2)).cpu().numpy()
        _, diff_sdde = model_sdde_trained(x_batch)
        mean_diff_sdde = diff_sdde.mean(dim=(0,1,2)).cpu().numpy()

        ax4.bar(["Hopfield-SDE", "SDDE-Net"], [mean_diff_hsde, mean_diff_sdde],
                color=colors[:2], alpha=0.7, edgecolor='black')
        ax4.set_ylabel('平均扩散项（不确定性强度）', fontsize=12)
        ax4.set_title('不确定性量化对比（MHNN无此功能）', fontsize=14, fontweight='bold')
        ax4.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig('three_models_comparison_lorenz.png', dpi=300, bbox_inches='tight')
    print("\n对比结果图保存至: three_models_comparison_lorenz.png")

    # 6. Text summary table.
    print("\nStep 5: 三种模型优劣总结")
    print("="*80)
    print(f"{'模型':<15} {'预测精度（MSE）':<20} {'时滞建模':<15} {'不确定性量化':<15} {'计算复杂度（参数）':<20}")
    print("="*80)
    print(f"{models[0]:<15} {test_mses[0]:<20.6f} {'无':<15} {'支持':<15} {count_params(model_hopfield_sde):<20,}")
    print(f"{models[1]:<15} {test_mses[1]:<20.6f} {'支持（tau=2）':<15} {'支持':<15} {count_params(model_sdde):<20,}")
    print(f"{models[2]:<15} {test_mses[2]:<20.6f} {'无（靠Transformer近似）':<15} {'不支持':<15} {count_params(model_mhnn):<20,}")
    print("="*80)
    print("\n训练与对比完成！")

In [None]:
##HNSDDE

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from torchdiffeq import odeint_adjoint as odeint  # 适配SDDE数值积分


# ============ 1. 基础组件：支持时滞特征的Hopfield记忆层 ============
class HopfieldLayer(nn.Module):
    """Multi-head self-attention "Hopfield" memory layer with a residual connection.

    Queries, keys and values are linear projections of the same input. Scores
    are ``Q K^T / sqrt(head_dim)`` multiplied by ``beta``; a larger ``beta``
    gives a sharper softmax. No attention mask is applied, so every position
    attends to every other position. The projected attention output is added
    back to the input.

    Args:
        dim_model: Feature size of the input and output.
        num_heads: Number of attention heads; must be positive and divide
            ``dim_model``.
        beta: Multiplier on the scaled attention scores.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``dim_model``.
    """

    def __init__(self, dim_model, num_heads=4, beta=1.0):
        super().__init__()
        # Validate with exceptions instead of ``assert`` so the checks survive
        # ``python -O`` and a zero head count does not surface as a
        # ZeroDivisionError.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if dim_model % num_heads != 0:
            raise ValueError("dim_model must be divisible by num_heads")
        self.dim_model = dim_model
        self.num_heads = num_heads
        self.beta = beta
        self.head_dim = dim_model // num_heads

        self.query_proj = nn.Linear(dim_model, dim_model)
        self.key_proj = nn.Linear(dim_model, dim_model)
        self.value_proj = nn.Linear(dim_model, dim_model)
        self.output_proj = nn.Linear(dim_model, dim_model)

    def forward(self, x):
        """Apply attention over the sequence axis.

        Args:
            x: Tensor of shape (batch, seq_len, dim_model).

        Returns:
            Tensor of the same shape: ``output_proj(attention(x)) + x``.
        """
        batch_size, seq_len, _ = x.shape

        def split_heads(t):
            # (B, T, dim_model) -> (B, num_heads, T, head_dim)
            return t.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        Q = split_heads(self.query_proj(x))
        K = split_heads(self.key_proj(x))
        V = split_heads(self.value_proj(x))

        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.head_dim)
        scores = scores * self.beta  # beta controls retrieval sharpness
        attn_weights = F.softmax(scores, dim=-1)

        # Merge heads back: (B, H, T, d) -> (B, T, H*d)
        memory_output = torch.matmul(attn_weights, V)
        memory_output = memory_output.transpose(1, 2).reshape(batch_size, seq_len, self.dim_model)
        return self.output_proj(memory_output) + x  # residual connection


# ============ 2. 核心模型：Hopfield-SDDE（Hopfield+神经时滞随机微分方程） ============
class HopfieldSDDENet(nn.Module):
    """Lag-window encoder + Hopfield memory + neural ODE in the hidden space.

    ``forward`` runs five stages:
      1. For each step t >= tau, flatten the window ``[x(t-tau), ..., x(t)]``
         into one feature vector (``extract_lag_features``).
      2. Encode the windows with an MLP and pass them through ``HopfieldLayer``.
      3. Use the first refined window as the initial state ``z0``.
      4. Integrate ``dz/dt = sdde_drift(z)`` with ``odeint`` (dopri5).
      5. Evaluate ``softplus(sdde_diffusion(z)) + 1e-6`` along the path as an
         uncertainty signal, and decode the path back to ``input_dim``.

    NOTE: despite the "SDDE" name, the integration is deterministic. Only the
    drift is integrated, the diffusion term is evaluated afterwards on the
    resulting path, and no noise is injected.

    Args:
        input_dim: Feature size D of the input sequence.
        hidden_dim: Hidden size H (the Hopfield layer requires it to be
            divisible by ``num_heads``).
        num_heads: Number of heads of the Hopfield memory layer.
        tau: Number of past steps concatenated with the current one (>= 0).

    Raises:
        ValueError: if ``tau`` is negative.
    """

    def __init__(self, input_dim, hidden_dim=32, num_heads=4, tau=2):
        super().__init__()
        if tau < 0:
            raise ValueError(f"tau must be non-negative, got {tau}")
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.tau = tau
        self.lag_feat_dim = input_dim * (tau + 1)  # D * (tau + 1)

        # Maps a flattened lag window (D * (tau + 1)) into the hidden space.
        self.lag_encoder = nn.Sequential(
            nn.Linear(self.lag_feat_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Attention-based memory over the sequence of encoded windows.
        self.hopfield_memory = HopfieldLayer(hidden_dim, num_heads=num_heads)

        # Drift network: the ODE right-hand side integrated in ``forward``.
        self.sdde_drift = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim)
        )
        # Diffusion network; its output is passed through softplus (+1e-6) to stay positive.
        self.sdde_diffusion = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Decodes hidden states back to the input feature size.
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, input_dim)
        )

    def extract_lag_features(self, x_seq):
        """Build sliding lag windows over the time axis.

        Args:
            x_seq: Tensor of shape (B, T, D).

        Returns:
            Tensor of shape (B, T - tau, D * (tau + 1)). Entry ``[:, k]`` is the
            window ``[x(k), ..., x(k + tau)]`` flattened step-major (all D
            features of the oldest step first), i.e. the window ending at
            original step ``k + tau``.

        Raises:
            ValueError: if ``T <= tau`` (no complete window exists).
        """
        batch_size, seq_len, feat_dim = x_seq.shape
        if seq_len <= self.tau:
            raise ValueError(
                f"sequence length {seq_len} must be greater than tau={self.tau}"
            )
        # unfold -> (B, T - tau, D, tau + 1). Move the window axis before the
        # feature axis so flattening is step-major. reshape (unlike view) also
        # accepts non-contiguous inputs.
        windows = x_seq.unfold(1, self.tau + 1, 1).permute(0, 1, 3, 2)
        return windows.reshape(batch_size, seq_len - self.tau, (self.tau + 1) * feat_dim)

    def sdde_dynamics(self, t, z):
        """Return ``(drift, diffusion)`` for hidden state ``z``.

        ``diffusion`` is ``softplus(sdde_diffusion(z)) + 1e-6``. This helper is
        not called by ``forward``, which integrates ``sdde_drift`` directly.
        """
        drift = self.sdde_drift(z)
        diffusion = F.softplus(self.sdde_diffusion(z)) + 1e-6
        return drift, diffusion

    def forward(self, x_seq, t_eval=None, return_diffusion=False):
        """Predict a sequence aligned with ``x_seq``.

        Args:
            x_seq: Tensor of shape (B, T, D) with T > tau.
            t_eval: Optional 1-D tensor of integration times. Defaults to
                ``T - tau`` evenly spaced points in [0, 1], which makes the
                output length equal to T.
            return_diffusion: If True, also return the diffusion values.

        Returns:
            ``pred_full`` of shape (B, tau + len(t_eval), D). Its first ``tau``
            steps are zeros because no complete lag window exists there. If
            requested, also ``diffusion`` of shape (B, len(t_eval), H).
        """
        batch_size, seq_len, _ = x_seq.shape
        device = x_seq.device

        # 1. Lag windows: (B, T', D*(tau+1)) with T' = T - tau.
        lag_feat = self.extract_lag_features(x_seq)
        seq_len_lag = lag_feat.shape[1]

        # 2. Encode and refine with the Hopfield memory: (B, T', H).
        # NOTE(review): HopfieldLayer applies no attention mask, so z0 below
        # depends on every window of the input, including later ones.
        encoded_lag = self.lag_encoder(lag_feat)
        memory_lag = self.hopfield_memory(encoded_lag)

        # 3. Integrate the drift from the first refined window.
        z0 = memory_lag[:, 0, :]  # (B, H)
        if t_eval is None:
            t_eval = torch.linspace(0.0, 1.0, seq_len_lag, device=device, dtype=memory_lag.dtype)

        def drift_func(t, z):
            return self.sdde_drift(z)

        z_path = odeint(
            func=drift_func,
            y0=z0,
            t=t_eval,
            method='dopri5',
            adjoint_params=list(self.sdde_drift.parameters())
        )  # (len(t_eval), B, H)
        z_path = z_path.transpose(0, 1)  # (B, len(t_eval), H)

        # 4. Diffusion evaluated on the deterministic path (uncertainty signal only).
        diffusion = F.softplus(self.sdde_diffusion(z_path)) + 1e-6

        # 5. Decode and left-pad tau zero steps so the output aligns with x_seq.
        pred_lag = self.decoder(z_path)  # (B, len(t_eval), D)
        pred_pad = pred_lag.new_zeros(batch_size, self.tau, self.input_dim)  # same dtype/device
        pred_full = torch.cat([pred_pad, pred_lag], dim=1)

        if return_diffusion:
            return pred_full, diffusion
        return pred_full


# ============ 3. 数据生成：带时滞特性的Lorenz混沌系统（适配SDDE） ============
def generate_lorenz_with_lag(
    num_samples=200,
    seq_len=50,
    dt=0.01,
    sigma=10.0,
    rho=28.0,
    beta=8/3.0,
    lag_strength=0.2,  # strength of the lagged feedback term
    noise_std=0.01,
    rng=None,
):
    """Generate Lorenz-like trajectories with an explicit lagged feedback term.

    Each trajectory starts from three independent states drawn uniformly from
    [-10, 10]^3; these are steps 0, 1 and 2. Every later state is one explicit
    Euler step of

        d(state)/dt = lorenz(state(t)) + lag_strength * state(t-1)

    followed by additive Gaussian noise with standard deviation ``noise_std``.

    NOTE(review): the feedback uses the state one step back (t-1), although
    the model using this data concatenates tau=2 past steps.

    Args:
        num_samples: Number of independent trajectories.
        seq_len: Number of states per trajectory (for ``seq_len < 3`` only the
            first ``seq_len`` random seed states are returned).
        dt: Euler step size.
        sigma, rho, beta: Lorenz system parameters.
        lag_strength: Coefficient of the lagged-state term.
        noise_std: Std of the additive Gaussian noise per step (0 disables it).
        rng: Optional ``numpy.random.Generator`` (anything with NumPy-style
            ``uniform``/``normal``) for reproducible sampling. Defaults to the
            global ``numpy.random`` state.

    Returns:
        float32 array of shape ``(num_samples, seq_len, 3)``.

    Raises:
        ValueError: if ``seq_len < 1``.
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")
    rng = np.random if rng is None else rng

    def lorenz_lag_dynamics(state, state_lag):
        """Lorenz vector field at ``state`` plus ``lag_strength * state_lag``."""
        x, y, z = state
        x_lag, y_lag, z_lag = state_lag
        dx = sigma * (y - x) + lag_strength * x_lag
        dy = x * (rho - z) - y + lag_strength * y_lag
        dz = x * y - beta * z + lag_strength * z_lag
        return np.array([dx, dy, dz])

    data = np.empty((num_samples, seq_len, 3), dtype=np.float32)
    for i in range(num_samples):
        # Three independent random seed states (steps 0, 1, 2), drawn in order.
        trajectory = [rng.uniform(-10, 10, size=3) for _ in range(3)]

        for _ in range(seq_len - 3):
            current, lagged = trajectory[-1], trajectory[-2]
            derivative = lorenz_lag_dynamics(current, lagged)
            state_next = current + dt * derivative
            # Additive noise as the stochastic component.
            state_next += rng.normal(0, noise_std, size=3)
            trajectory.append(state_next)

        # Truncation only matters when seq_len < 3.
        data[i] = np.stack(trajectory[:seq_len])

    return data


# ============ 4. 训练与可视化（适配Hopfield-SDDE） ============
def train_hopfield_sdde(model, train_loader, val_loader, epochs=50, lr=1e-3, device='cpu'):
    """Fit a lag-aware model with Adam + MSE, ignoring its first ``model.tau`` steps.

    The model's output is expected to be zero-padded for the first
    ``model.tau`` time steps, so those steps are excluded from every loss.

    Args:
        model: ``nn.Module`` with an integer ``tau`` attribute that returns a
            prediction tensor shaped like the targets.
        train_loader: Iterable of ``(x, y)`` batches; must expose ``.dataset``.
        val_loader: Same, for validation.
        epochs: Number of epochs; progress is printed every 10 epochs.
        lr: Adam learning rate (weight decay fixed at 1e-5).
        device: Device for the model and batches.

    Returns:
        ``(train_losses, val_losses, model)`` with one sample-weighted mean
        loss per epoch in each list.
    """
    opt = Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    mse = nn.MSELoss()
    history_train = []
    history_val = []
    model.to(device)

    def masked_loss(inputs, targets):
        # Skip the zero-padded warm-up steps on both prediction and target.
        skip = model.tau
        prediction = model(inputs)
        return mse(prediction[:, skip:, :], targets[:, skip:, :])

    for epoch_idx in range(epochs):
        model.train()
        running = 0.0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            opt.zero_grad()
            batch_loss = masked_loss(inputs, targets)
            batch_loss.backward()
            opt.step()
            running += batch_loss.item() * inputs.size(0)
        epoch_train = running / len(train_loader.dataset)
        history_train.append(epoch_train)

        model.eval()
        running = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                running += masked_loss(inputs, targets).item() * inputs.size(0)
        epoch_val = running / len(val_loader.dataset)
        history_val.append(epoch_val)

        if (epoch_idx + 1) % 10 == 0:
            print(f"Epoch [{epoch_idx+1:02d}/{epochs}] | Train Loss: {epoch_train:.6f} | Val Loss: {epoch_val:.6f}")

    return history_train, history_val, model


if __name__ == '__main__':
    # Script entry point: generate lagged Lorenz data, train HopfieldSDDENet,
    # and save a loss / prediction figure to hopfield_sdde_results.png.

    # 1. Device and hyper-parameters (lag window tau=2).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    input_dim = 3  # Lorenz state dimension (x, y, z)
    hidden_dim = 32
    tau = 2  # lag window length passed to HopfieldSDDENet
    epochs = 50
    batch_size = 32

    # 2. Generate Lorenz data with a lagged feedback term (lag_strength=0.2).
    print("Step 1: 生成带时滞特性的Lorenz数据...")
    lorenz_data = generate_lorenz_with_lag(num_samples=200, seq_len=50, lag_strength=0.2)
    X = lorenz_data[:, :-1, :]  # inputs: (200, 49, 3)
    y = lorenz_data[:, 1:, :]   # targets (next step): (200, 49, 3)

    # Standardize with one global mean/std computed over all of X (every
    # sample, step and feature); targets reuse the input statistics.
    X_mean = X.mean()
    X_std = X.std()
    X = (X - X_mean) / (X_std + 1e-8)
    y = (y - X_mean) / (X_std + 1e-8)

    # 80/20 split by sample (contiguous slices, no shuffling before the split).
    split_idx = int(0.8 * len(X))
    X_train, X_val = X[:split_idx], X[split_idx:]
    y_train, y_val = y[:split_idx], y[split_idx:]

    # DataLoaders (training batches are shuffled each epoch).
    train_dataset = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
    val_dataset = TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val))
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # 3. Build and train the Hopfield-SDDE model.
    print("\nStep 2: 初始化Hopfield-SDDE模型...")
    model = HopfieldSDDENet(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        num_heads=4,
        tau=tau  # lag window length
    )
    print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")

    print("\nStep 3: 训练Hopfield-SDDE模型...")
    # lr is not passed, so train_hopfield_sdde's default (1e-3) is used.
    train_losses, val_losses, trained_model = train_hopfield_sdde(
        model, train_loader, val_loader, epochs=epochs, device=device
    )

    # 4. Visualization.
    print("\nStep 4: 可视化结果...")
    plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Subplot 1: per-epoch training / validation loss.
    ax1.plot(range(1, epochs+1), train_losses, label='Train Loss', color='#1f77b4', linewidth=2)
    ax1.plot(range(1, epochs+1), val_losses, label='Val Loss', color='#ff7f0e', linewidth=2)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('MSE Loss')
    ax1.set_title('Hopfield-SDDE Training & Validation Loss', fontweight='bold')
    ax1.legend()
    ax1.grid(alpha=0.3)

    # Subplot 2: x-component of one validation trajectory. The model's first
    # tau output steps are zero padding, so they are skipped.
    trained_model.eval()
    with torch.no_grad():
        sample_idx = 0
        x_sample = torch.from_numpy(X_val[sample_idx:sample_idx+1]).to(device)
        # De-normalize with X_std (normalization divided by X_std + 1e-8; the
        # difference is negligible).
        y_true = y_val[sample_idx, tau:, 0] * X_std + X_mean  # ground truth, steps >= tau
        y_pred = trained_model(x_sample)[0, tau:, 0].cpu().numpy() * X_std + X_mean  # prediction, steps >= tau
        time_steps = np.arange(len(y_true)) + tau  # x-axis starts at step tau

        ax2.plot(time_steps, y_true, label='True X (with Lag)', color='black', linewidth=2.5)
        ax2.plot(time_steps, y_pred, label='Pred X (Hopfield-SDDE)', color='#d62728', linewidth=2.5)
        ax2.set_xlabel('Time Step')
        ax2.set_ylabel('Lorenz X Component')
        ax2.set_title(f'Hopfield-SDDE Prediction (Lag τ={tau})', fontweight='bold')
        ax2.legend()
        ax2.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig('hopfield_sdde_results.png', dpi=300, bbox_inches='tight')
    print("\n结果保存至: hopfield_sdde_results.png")
    print("\nHopfield-SDDE（时滞随机微分方程）训练完成！")