# meanflow

In [3]:
import torch
import torch.nn.functional as F
from einops import rearrange
from functools import partial
import numpy as np
import torch.nn as nn


class Normalizer:
    # minmax for raw image, mean_std for vae latent
    def __init__(self, mode='minmax', mean=None, std=None):
        assert mode in ['minmax', 'mean_std'], "mode must be 'minmax' or 'mean_std'"
        self.mode = mode

        if mode == 'mean_std':
            if mean is None or std is None:
                raise ValueError("mean and std must be provided for 'mean_std' mode")
            self.mean = torch.tensor(mean).view(-1, 1, 1)
            self.std = torch.tensor(std).view(-1, 1, 1)

    @classmethod
    def from_list(cls, config):
        """
        config: [mode, mean, std]
        """
        mode, mean, std = config
        return cls(mode, mean, std)

    def norm(self, x):
        if self.mode == 'minmax':
            return x * 2 - 1
        elif self.mode == 'mean_std':
            return (x - self.mean.to(x.device)) / self.std.to(x.device)

    def unnorm(self, x):
        if self.mode == 'minmax':
            x = x.clip(-1, 1)
            return (x + 1) * 0.5
        elif self.mode == 'mean_std':
            return x * self.std.to(x.device) + self.mean.to(x.device)


def stopgrad(x):
    return x.detach()


def adaptive_l2_loss(error, gamma=0.5, c=1e-3):
    """
    Adaptive L2 loss: sg(w) * ||Δ||_2^2, where w = 1 / (||Δ||^2 + c)^p, p = 1 - γ
    Args:
        error: Tensor of shape (B, C, W, H)
        gamma: Power used in original ||Δ||^{2γ} loss
        c: Small constant for stability
    Returns:
        Scalar loss
    """
    delta_sq = torch.mean(error ** 2, dim=(1, 2, 3), keepdim=False)
    p = 1.0 - gamma
    w = 1.0 / (delta_sq + c).pow(p)
    loss = delta_sq  # ||Δ||^2
    return (stopgrad(w) * loss).mean()


class MeanFlow:
    def __init__(
        self,
        channels=1,
        image_size=32,
        num_classes=10,
        normalizer=['minmax', None, None],
        # mean flow settings
        flow_ratio=0.50,
        # time distribution, mu, sigma
        time_dist=['lognorm', -0.4, 1.0],
        cfg_ratio=0.10,
        # set scale as none to disable CFG distill
        cfg_scale=2.0,
        # experimental
        cfg_uncond='u',
        jvp_api='autograd',
    ):
        super().__init__()
        self.channels = channels
        self.image_size = image_size
        self.num_classes = num_classes
        self.use_cond = num_classes is not None

        self.normer = Normalizer.from_list(normalizer)

        self.flow_ratio = flow_ratio
        self.time_dist = time_dist
        self.cfg_ratio = cfg_ratio
        self.w = cfg_scale

        self.cfg_uncond = cfg_uncond
        self.jvp_api = jvp_api

        assert jvp_api in ['funtorch', 'autograd'], "jvp_api must be 'funtorch' or 'autograd'"
        if jvp_api == 'funtorch':
            self.jvp_fn = torch.func.jvp
            self.create_graph = False
        elif jvp_api == 'autograd':
            self.jvp_fn = torch.autograd.functional.jvp
            self.create_graph = True

    # fix: r should be always not larger than t
    def sample_t_r(self, batch_size, device):
        if self.time_dist[0] == 'uniform':
            samples = np.random.rand(batch_size, 2).astype(np.float32)

        elif self.time_dist[0] == 'lognorm':
            mu, sigma = self.time_dist[-2], self.time_dist[-1]
            normal_samples = np.random.randn(batch_size, 2).astype(np.float32) * sigma + mu
            samples = 1 / (1 + np.exp(-normal_samples))  # Apply sigmoid

        # Assign t = max, r = min, for each pair
        t_np = np.maximum(samples[:, 0], samples[:, 1])
        r_np = np.minimum(samples[:, 0], samples[:, 1])

        num_selected = int(self.flow_ratio * batch_size)
        indices = np.random.permutation(batch_size)[:num_selected]
        r_np[indices] = t_np[indices]

        t = torch.tensor(t_np, device=device)
        r = torch.tensor(r_np, device=device)
        return t, r

    def loss(self, model, x, c=None):
        batch_size = x.shape[0]
        device = x.device

        t, r = self.sample_t_r(batch_size, device)

        t_ = rearrange(t, "b -> b 1 1 1").detach().clone()
        r_ = rearrange(r, "b -> b 1 1 1").detach().clone()

        e = torch.randn_like(x)
        x = self.normer.norm(x)

        z = (1 - t_) * x + t_ * e
        v = e - x

        if c is not None:
            assert self.cfg_ratio is not None
            uncond = torch.ones_like(c) * self.num_classes
            cfg_mask = torch.rand_like(c.float()) < self.cfg_ratio
            c = torch.where(cfg_mask, uncond, c)
            if self.w is not None:
                with torch.no_grad():
                    u_t = model(z, t, t, uncond)
                v_hat = self.w * v + (1 - self.w) * u_t
                if self.cfg_uncond == 'v':
                    # In the unconditional case, v = w * v + (1 - w) * u,
                    # so if we're choosing to use 'v' for uncond settings, we can just keep v.
                    # Apply this only to the unconditional samples indicated by cfg_mask.
                    cfg_mask = rearrange(cfg_mask, "b -> b 1 1 1").bool()
                    v_hat = torch.where(cfg_mask, v, v_hat)
            else:
                v_hat = v

        # forward pass
        # u = model(z, t, r, y=c)
        model_partial = partial(model, y=c)
        jvp_args = (
            lambda z, t, r: model_partial(z, t, r),
            (z, t, r),
            (v_hat, torch.ones_like(t), torch.zeros_like(r)),
        )

        if self.create_graph:
            u, dudt = self.jvp_fn(*jvp_args, create_graph=True)
        else:
            u, dudt = self.jvp_fn(*jvp_args)

        u_tgt = v_hat - (t_ - r_) * dudt

        error = u - stopgrad(u_tgt)
        loss = adaptive_l2_loss(error)
        # loss = F.mse_loss(u, stopgrad(u_tgt))

        mse_val = (stopgrad(error) ** 2).mean()
        return loss, mse_val

    @torch.no_grad()
    def sample_each_class(self, model, n_per_class, classes=None,
                          sample_steps=5, device='cuda'):
        model.eval()

        if classes is None:
            c = torch.arange(self.num_classes, device=device).repeat(n_per_class)
        else:
            c = torch.tensor(classes, device=device).repeat(n_per_class)

        z = torch.randn(c.shape[0], self.channels,
                        self.image_size, self.image_size,
                        device=device)

        t_vals = torch.linspace(1.0, 0.0, sample_steps + 1, device=device)

        # print(t_vals)

        for i in range(sample_steps):
            t = torch.full((z.size(0),), t_vals[i], device=device)
            r = torch.full((z.size(0),), t_vals[i + 1], device=device)

            # print(f"t: {t[0].item():.4f};  r: {r[0].item():.4f}")

            t_ = rearrange(t, "b -> b 1 1 1").detach().clone()
            r_ = rearrange(r, "b -> b 1 1 1").detach().clone()

            v = model(z, t, r, c)
            z = z - (t_-r_) * v

        z = self.normer.unnorm(z)
        return z

## demo

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from functools import partial
import numpy as np
from torch.utils.data import TensorDataset, DataLoader

# 1. 模拟一个简单的U-Net模型 (用于图像)
# 注意: 这是一个非常简化的模型，仅用于演示目的。
class SimpleUnet(nn.Module):
    def __init__(self, in_channels=1, num_classes=10, time_emb_dim=32):
        super().__init__()
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim * 4),
            nn.SiLU(),
            nn.Linear(time_emb_dim * 4, time_emb_dim)
        )
        self.class_emb = nn.Embedding(num_classes + 1, time_emb_dim) # +1 for uncond

        self.down1 = nn.Conv2d(in_channels, 32, 3, padding=1)
        self.down2 = nn.Conv2d(32, 64, 3, padding=1)
        self.up1 = nn.Conv2d(64 + 32, 32, 3, padding=1)
        self.up2 = nn.Conv2d(32, in_channels, 3, padding=1)
        self.act = nn.SiLU()

    def _time_embedding(self, t, emb_dim):
        # t: [B]
        half_dim = emb_dim // 2
        emb = np.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        return torch.cat([emb.sin(), emb.cos()], dim=1)

    def forward(self, z, t, r, y): # 模型需要接受 z, t, r, 和 y (条件)
        time_emb = self._time_embedding(t, 32)
        time_emb = self.time_mlp(time_emb)

        class_emb = self.class_emb(y)
        
        cond_emb = time_emb + class_emb
        
        # 将条件信息融入网络 (简单地加到每个像素上)
        cond_emb = rearrange(cond_emb, 'b c -> b c 1 1')

        x1 = self.act(self.down1(z))
        x2 = self.act(self.down2(x1))
        
        x = x2 + cond_emb.expand(-1, -1, x2.shape[2], x2.shape[3]) # 融入条件
        
        x = F.interpolate(x, scale_factor=1) # 模拟上采样
        x = torch.cat([x, x1], dim=1)
        x = self.act(self.up1(x))
        x = self.up2(x)
        return x

# 2. 设置参数和模拟数据
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
IMG_SIZE = 32
CHANNELS = 1
NUM_CLASSES = 10
EPOCHS = 20 # 增加迭代次数以看到损失下降

# 创建假的图像数据和标签
dummy_images = torch.rand(BATCH_SIZE * 5, CHANNELS, IMG_SIZE, IMG_SIZE, device=DEVICE)
dummy_labels = torch.randint(0, NUM_CLASSES, (BATCH_SIZE * 5,), device=DEVICE)
dataset = TensorDataset(dummy_images, dummy_labels)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# 3. 初始化所有组件
model = SimpleUnet(in_channels=CHANNELS, num_classes=NUM_CLASSES).to(DEVICE)
mean_flow = MeanFlow(
    channels=CHANNELS,
    image_size=IMG_SIZE,
    num_classes=NUM_CLASSES,
    cfg_scale=3.0, # 使用分类器无关引导
    jvp_api='autograd'
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 4. 训练循环
print("Starting training demo...")
for epoch in range(EPOCHS):
    total_loss = 0
    for i, (images, labels) in enumerate(dataloader):
        optimizer.zero_grad()
        
        loss, mse = mean_flow.loss(model, images, labels)
        
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{EPOCHS}, Average Loss: {avg_loss:.4f}")

print("Training finished.")

# 5. 生成样本
print("\nGenerating samples...")
generated_samples = mean_flow.sample_each_class(
    model,
    n_per_class=2, # 每个类别生成2个样本
    sample_steps=10,
    device=DEVICE
)

print(f"Generated a tensor of shape: {generated_samples.shape}")
# 在真实场景中，你会在这里将张量保存为图像或进行可视化
# import torchvision
# torchvision.utils.save_image(generated_samples, "generated_samples.png", nrow=NUM_CLASSES)

Starting training demo...


RuntimeError: The size of tensor a (64) must match the size of tensor b (32) at non-singleton dimension 1

# 1d

In [7]:
def adaptive_l2_loss_1d(error, gamma=0.5, c=1e-3):
    """
    适用于 (B, C, L) 数据的自适应L2损失
    """
    # 原来的维度是 (1, 2, 3) for (C, H, W)
    # 新的维度是 (1, 2) for (C, L)
    delta_sq = torch.mean(error ** 2, dim=(1, 2), keepdim=False)
    p = 1.0 - gamma
    w = 1.0 / (delta_sq + c).pow(p)
    loss = delta_sq
    return (stopgrad(w) * loss).mean()
class Normalizer1D(Normalizer):
    def __init__(self, mode='minmax', mean=None, std=None):
        assert mode in ['minmax', 'mean_std'], "mode must be 'minmax' or 'mean_std'"
        self.mode = mode

        if mode == 'mean_std':
            if mean is None or std is None:
                raise ValueError("mean and std must be provided for 'mean_std' mode")
            # 原来是 view(-1, 1, 1) for (C, H, W)
            # 现在是 view(-1, 1) for (C, L)
            self.mean = torch.tensor(mean).view(-1, 1)
            self.std = torch.tensor(std).view(-1, 1)
class MeanFlow1D(MeanFlow):
    def __init__(self, channels=1, sequence_length=1024, **kwargs):
        # 用 sequence_length 替换 image_size
        super().__init__(channels=channels, image_size=sequence_length, **kwargs)
        self.sequence_length = sequence_length
        # 使用修改后的 Normalizer 和 loss
        self.normer = Normalizer1D.from_list(kwargs.get('normalizer', ['minmax', None, None]))

    def loss(self, model, x, c=None): # x 的形状是 (B, C, L)
        batch_size = x.shape[0]
        device = x.device

        t, r = self.sample_t_r(batch_size, device)

        # 关键修改：rearrange 从 "b -> b 1 1 1" 到 "b -> b 1 1"
        t_ = rearrange(t, "b -> b 1 1").detach().clone()
        r_ = rearrange(r, "b -> b 1 1").detach().clone()

        e = torch.randn_like(x)
        x = self.normer.norm(x)

        z = (1 - t_) * x + t_ * e
        v = e - x
        v_hat = v

        if c is not None:
            assert self.cfg_ratio is not None
            uncond = torch.ones_like(c) * self.num_classes
            cfg_mask = torch.rand(c.shape[0], device=c.device) < self.cfg_ratio
            c = torch.where(cfg_mask, uncond, c)
            if self.w is not None:
                with torch.no_grad():
                    u_t = model(z, t, t, c)
                v_hat = self.w * v + (1 - self.w) * u_t
                if self.cfg_uncond == 'v':
                    # 关键修改：rearrange
                    cfg_mask_expanded = rearrange(cfg_mask, "b -> b 1 1")
                    v_hat = torch.where(cfg_mask_expanded, v, v_hat)

        model_partial = partial(model, y=c)
        jvp_args = (
            lambda z_arg, t_arg, r_arg: model_partial(z_arg, t_arg, r_arg),
            (z, t, r),
            (v_hat, torch.ones_like(t), torch.zeros_like(r)),
        )

        if self.create_graph:
            u, dudt = self.jvp_fn(*jvp_args, create_graph=True)
        else:
            u, dudt = self.jvp_fn(*jvp_args)

        u_tgt = v_hat - (t_ - r_) * dudt
        error = u - stopgrad(u_tgt)
        
        # 使用1D版本的loss
        loss = adaptive_l2_loss_1d(error)
        mse_val = (stopgrad(error) ** 2).mean()
        return loss, mse_val

    @torch.no_grad()
    def sample_each_class(self, model, n_per_class, classes=None,
                            sample_steps=5, device='cuda'):
        model.eval()

        if classes is None:
            c = torch.arange(self.num_classes, device=device).repeat(n_per_class)
        else:
            c = torch.tensor(classes, device=device).repeat(n_per_class)
            
        # 关键修改：噪声形状
        z = torch.randn(c.shape[0], self.channels, self.sequence_length, device=device)

        t_vals = torch.linspace(1.0, 0.0, sample_steps + 1, device=device)

        for i in range(sample_steps):
            t = torch.full((z.size(0),), t_vals[i], device=device)
            r = torch.full((z.size(0),), t_vals[i + 1], device=device)

            # 关键修改：rearrange
            t_ = rearrange(t, "b -> b 1 1")
            r_ = rearrange(r, "b -> b 1 1")
            
            if self.w is not None and self.use_cond:
                uncond = torch.ones_like(c) * self.num_classes
                v_cond = model(z, t, r, c)
                v_uncond = model(z, t, r, uncond)
                v = (1 + self.w) * v_cond - self.w * v_uncond
            else:
                 v = model(z, t, r, c)

            z = z - (t_ - r_) * v

        z = self.normer.unnorm(z)
        # 如果你需要 (B, L, C) 格式，在这里进行转换
        # z = z.permute(0, 2, 1)
        return z

##  network

In [10]:
class SimpleUnet1D(nn.Module):
    def __init__(self, in_channels=1, num_classes=10, time_emb_dim=32, sequence_length=1024):
        super().__init__()
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim * 4),
            nn.SiLU(),
            nn.Linear(time_emb_dim * 4, time_emb_dim)
        )
        self.class_emb = nn.Embedding(num_classes + 1, time_emb_dim)

        # 使用1D卷积
        self.down1 = nn.Conv1d(in_channels, 32, kernel_size=3, padding=1)
        self.down2 = nn.Conv1d(32, 64, kernel_size=3, padding=1)
        
        # 上采样可以使用 ConvTranspose1d 或 Interpolate
        self.up1 = nn.Conv1d(64 + 32, 32, kernel_size=3, padding=1)
        self.up2 = nn.Conv1d(32, in_channels, kernel_size=3, padding=1)
        self.act = nn.SiLU()

    def _time_embedding(self, t, emb_dim):
        half_dim = emb_dim // 2
        emb = np.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        return torch.cat([emb.sin(), emb.cos()], dim=1)

    def forward(self, z, t, r, y): # z: (B, C, L)
        time_emb = self._time_embedding(t, 64)
        time_emb = self.time_mlp(time_emb)

        class_emb = self.class_emb(y)
        cond_emb = time_emb + class_emb
        
        # 调整条件形状以匹配序列
        cond_emb = rearrange(cond_emb, 'b c -> b c 1')

        x1 = self.act(self.down1(z))
        x2 = self.act(self.down2(x1))
        
        # 融入条件信息
        x = x2 + cond_emb.expand(-1, -1, x2.shape[2])
        
        # 使用插值进行上采样
        x = F.interpolate(x, size=x1.shape[-1])
        x = torch.cat([x, x1], dim=1)
        x = self.act(self.up1(x))
        x = self.up2(x)
        return x

## demo

In [11]:
# 振动信号生成任务的示例设置
SEQ_LEN = 1024
CHANNELS = 3  # 假设是3轴振动信号
NUM_CLASSES = 5 # 假设有5种工况

# 1. 实例化1D模型
model_1d = SimpleUnet1D(
    in_channels=CHANNELS, 
    num_classes=NUM_CLASSES, 
    sequence_length=SEQ_LEN
).to(DEVICE)

# 2. 实例化 MeanFlow1D
mean_flow_1d = MeanFlow1D(
    channels=CHANNELS,
    sequence_length=SEQ_LEN,
    num_classes=NUM_CLASSES,
    cfg_scale=3.0
)

# 3. 创建模拟的振动信号数据 (B, C, L)
# 生成多谐波信号 (多个正弦波叠加)
t = torch.linspace(0, 2 * np.pi, SEQ_LEN, device=DEVICE)
dummy_signals = torch.zeros(BATCH_SIZE * 5, CHANNELS, SEQ_LEN, device=DEVICE)

for i in range(BATCH_SIZE * 5):
    for c in range(CHANNELS):
        # 为每个通道生成不同频率的多谐波信号
        signal = torch.zeros_like(t)
        # 基频和谐波
        freqs = [1, 2, 3, 4]  # 基频和3个谐波
        amps = [1.0, 0.5, 0.3, 0.2]  # 对应幅度
        phases = torch.rand(len(freqs)) * 2 * np.pi  # 随机相位
        
        for freq, amp, phase in zip(freqs, amps, phases):
            signal += amp * torch.sin(freq * (c + 1) * t + phase)
        
        # 添加少量噪声
        signal += 0.1 * torch.randn_like(signal)
        dummy_signals[i, c] = signal
dummy_labels = torch.randint(0, NUM_CLASSES, (BATCH_SIZE * 5,), device=DEVICE)
dataset_1d = TensorDataset(dummy_signals, dummy_labels)
dataloader_1d = DataLoader(dataset_1d, batch_size=BATCH_SIZE, shuffle=True)

# 4. 运行训练循环 (与之前类似，但使用1D组件)
optimizer = torch.optim.Adam(model_1d.parameters(), lr=1e-4)

for epoch in range(EPOCHS):
    for signals, labels in dataloader_1d:
        optimizer.zero_grad()
        loss, mse = mean_flow_1d.loss(model_1d, signals, labels)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item():.4f}")



RuntimeError: mat1 and mat2 shapes cannot be multiplied (16x64 and 32x128)

## generate

In [None]:
# 5. 生成信号
generated_signals = mean_flow_1d.sample_each_class(
    model_1d, 
    n_per_class=1, 
    device=DEVICE
)
print(f"Generated signals shape: {generated_signals.shape}") # (B, C, L)

# 如果你的最终目标是 (B, L, C)
generated_signals_final_shape = generated_signals.permute(0, 2, 1)
print(f"Final shape after permute: {generated_signals_final_shape.shape}")