In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch import nn
import torch.nn.functional as F

In [23]:
# CIFAR-10

batch_size = 128

norm_mean = torch.tensor([0.4914, 0.4822, 0.4465])
norm_std = torch.tensor([0.2470, 0.2435, 0.2616])
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std)
])
train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=10)
test_dataloader = DataLoader(test_set, batch_size=batch_size, num_workers=10)

Files already downloaded and verified
Files already downloaded and verified


In [15]:
next(iter(train_dataloader))[0].shape

torch.Size([128, 3, 32, 32])

# Training EBMs using MLE with MCMC sampling

In [30]:
# Model p(x) = e^{-E(x)} / Z

class EBM(nn.Module):
    def __init__(self, hidden_dim=256):
        super(EBM, self).__init__()
        
        # 卷积 Backbone
        self.conv_net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),  # 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # 8x8
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),  # 4x4
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )
        
        # 映射到标量能量
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512*4*4, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, 1)   # 输出 Energy
        )
    
    def forward(self, x):
        feat = self.conv_net(x)
        energy = self.fc(feat)
        return energy.squeeze(-1)  # 输出 shape: (batch,)
    

In [39]:
def parallel_langevin_mcmc(log_prob_fn, initial_samples, step_size=0.01, n_steps=1000):
    """
    并行多链 Langevin MCMC 函数
    Args:
        log_prob_fn: 目标分布的对数概率函数 log p(x)
                     输入 shape (n_chains, dim)
                     输出 shape (n_chains,)
        initial_samples: 初始样本 tensor, shape (n_chains, dim)
        step_size: Langevin 步长
        n_steps: 每条链的采样步数
    Returns:
        samples: shape (n_chains, n_steps, dim)
    """
    x = initial_samples.clone()
    n_chains, dim = x.shape
    samples = torch.zeros(n_chains, n_steps, dim, device=x.device)

    for t in range(n_steps):
        x.requires_grad_(True)
        log_prob = log_prob_fn(x)              # shape: (n_chains,)
        grad = torch.autograd.grad(log_prob.sum(), x)[0]
        noise = torch.randn_like(x)
        x = x + 0.5 * step_size * grad + torch.sqrt(torch.tensor(step_size)) * noise
        samples[:, t, :] = x.detach()
    
    return samples

def langevin_sampling(model, x_init, n_steps=200, step_size=0.01, noise_scale=0.01, clamp=(-1, 1)):
    """
    从能量模型采样
    Args:
        x_init: 初始图像, shape (B, C, H, W)
        model: EBM
        n_steps: Langevin 步数
        step_size: 梯度步长
        noise_scale: 噪声尺度
        clamp: 图像值范围
    Returns:
        x_neg: 采样后的图像
    """
    x = x_init.clone().detach()
    for _ in range(n_steps):
        x.requires_grad_(True)
        energy = model(x)
        grad = torch.autograd.grad(energy.sum(), x)[0]
        # Langevin 更新
        noise = torch.randn_like(x) * noise_scale
        x = x - 0.5 * step_size * grad + noise
        # 保持像素范围 [0,1]
        x = x.clamp(*clamp).detach()
    return x


In [None]:
# train
epochs = 5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
learning_rate = 1e-4

# MCMC
sample_batchsize = 100
step_size = 0.01
n_steps = 1000

image_size = (3, 32, 32)
model = EBM(256).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(epochs):
    print("Epoch: ", epoch)
    epoch_nums = len(train_dataloader)
    for idx, (X, _) in enumerate(train_dataloader):
        X = X.to(device)
        E = model(X).mean()
        x0 = torch.randn(sample_batchsize, *image_size).to(device)
        samples = langevin_sampling(model, x0)
        logZ = -model(samples).mean()
        loss = E + 1.0 / batch_size * logZ
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if idx % 50 == 0 or idx == epoch_nums - 1:
            print(f"loss: {E:>7f}  [{idx:>3d}/{epoch_nums:>3d}]")

Epoch:  0
loss: 0.087524  [  0/391]
loss: -166.681183  [ 50/391]
loss: -521.161743  [100/391]
loss: -997.901367  [150/391]
loss: -1607.284424  [200/391]
loss: -2361.301270  [250/391]
loss: -3270.023438  [300/391]
loss: -4343.725098  [350/391]
loss: -5326.732422  [390/391]
Epoch:  1
loss: -5352.578125  [  0/391]
loss: -6747.754395  [ 50/391]
loss: -8328.219727  [100/391]
loss: -10105.615234  [150/391]
loss: -12083.794922  [200/391]
loss: -14267.084961  [250/391]
loss: -16660.207031  [300/391]
loss: -19268.425781  [350/391]
loss: -21512.683594  [390/391]
Epoch:  2
loss: -21572.201172  [  0/391]
loss: -24582.423828  [ 50/391]
loss: -27818.328125  [100/391]
loss: -31282.498047  [150/391]
loss: -34988.542969  [200/391]
loss: -38932.148438  [250/391]
loss: -43013.000000  [300/391]
loss: -47532.394531  [350/391]
loss: -51254.445312  [390/391]
Epoch:  3
loss: -51350.460938  [  0/391]
loss: -56219.882812  [ 50/391]
loss: -61355.187500  [100/391]
loss: -66726.539062  [150/391]
loss: -72371.69531