In [1]:
import os, math, glob
import torch
from torch import nn
from torch.nn import functional as F

# device 选择：CUDA > XPU > CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch, "xpu") and torch.xpu.is_available():   # type: ignore
    device = torch.device("xpu")                           # type: ignore
else:
    device = torch.device("cpu")

print(">>> Using device:", device)


>>> Using device: xpu


In [2]:
class SiLU(nn.Module):
    def forward(self, x): return F.silu(x)

def timestep_embedding(timesteps, dim):
    """
    返回 (B, dim) 的 sin/cos 时间步嵌入
    """
    device = timesteps.device
    half = dim // 2
    freqs = torch.exp(
        -math.log(10000) * torch.arange(0, half, dtype=torch.float32, device=device) / half
    )
    args = timesteps[:, None].float() * freqs[None]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if dim % 2:  # 维度为奇数时补 0
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb

class FiLM(nn.Module):
    """把 time‑embedding 映射到 (scale, shift) 并加到特征图"""
    def __init__(self, in_channels, time_dim):
        super().__init__()
        self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_dim, 2 * in_channels))

    def forward(self, x, t_emb):
        gamma, beta = self.mlp(t_emb).chunk(2, dim=-1)
        gamma, beta = gamma[:, :, None, None], beta[:, :, None, None]
        return x * (1 + gamma) + beta

class ResBlock(nn.Module):
    def __init__(self, in_c, out_c, time_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, 3, padding=1)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)
        self.norm1 = nn.GroupNorm(32, out_c)
        self.norm2 = nn.GroupNorm(32, out_c)
        self.act   = SiLU()
        self.film  = FiLM(out_c, time_dim)
        self.skip  = nn.Conv2d(in_c, out_c, 1) if in_c != out_c else nn.Identity()

    def forward(self, x, t_emb):
        h = self.act(self.norm1(self.conv1(x)))
        h = self.film(h, t_emb)
        h = self.act(self.norm2(self.conv2(h)))
        return h + self.skip(x)

def center_crop_to(src, ref):
    """
    将 src 在 H,W 维度上中心裁剪 / 适量 pad 到 ref 的空间尺寸
    只会裁剪或对齐到更小边界，不改变 Batch/Channel 维
    """
    _, _, Hs, Ws = src.shape
    _, _, Hr, Wr = ref.shape

    # center crop
    if Hs > Hr:
        dh = (Hs - Hr) // 2
        src = src[:, :, dh:dh+Hr, :]
    if Ws > Wr:
        dw = (Ws - Wr) // 2
        src = src[:, :, :, dw:dw+Wr]

    # (可选) pad 到目标尺寸 —— 理论上 Hs>=Hr & Ws>=Wr 已满足
    if src.shape[2] < Hr or src.shape[3] < Wr:
        pad_h = Hr - src.shape[2]
        pad_w = Wr - src.shape[3]
        src = F.pad(src, (0, pad_w, 0, pad_h))

    return src


In [3]:
class ArtifactMiniEncoder(nn.Module):
    """
    仅输出两层特征：
      A5  (512, 22, 22)
      ABN (1024, 11, 11)
    """
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, 7, 2, 3), SiLU(),
            nn.Conv2d(64, 128, 3, 2, 1), SiLU(),
            nn.Conv2d(128, 256, 3, 2, 1), SiLU(),
            nn.Conv2d(256, 512, 3, 2, 1), SiLU()
        )
        self.down1024 = nn.Conv2d(512, 1024, 3, 2, 1)
        self.act = SiLU()

    def forward(self, x):
        h = self.backbone(x)                     # -> (*,512, H/16, W/16)
        A5 = F.adaptive_avg_pool2d(h, (22, 22)) # fixed grid
        h  = self.act(self.down1024(h))         # 1024
        ABN = F.adaptive_avg_pool2d(h, (11, 11))
        return A5, ABN


In [4]:
class CrossAttention(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_kv = nn.Linear(dim, dim * 2, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, q, kv):
        B, N, C = q.shape
        h = self.heads
        q = self.to_q(q).view(B, N, h, C // h).transpose(1, 2) * self.scale
        k, v = self.to_kv(kv).chunk(2, dim=-1)
        k = k.view(B, -1, h, C // h).transpose(1, 2)
        v = v.view(B, -1, h, C // h).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)).softmax(dim=-1)
        out = (att @ v).transpose(1, 2).reshape(B, N, C)
        return self.to_out(out)


In [5]:
class DiffusionUNet(nn.Module):
    def __init__(self, time_dim=256):
        super().__init__()
        self.time_mlp = nn.Sequential(
            nn.Linear(time_dim, time_dim * 4),
            nn.SiLU(),
            nn.Linear(time_dim * 4, time_dim)
        )

        # =========== Encoder ===========
        self.res0 = ResBlock(3, 64, time_dim)
        self.res0b = ResBlock(64, 64, time_dim)
        self.down1 = nn.Conv2d(64, 128, 4, 2, 1)
        self.res1a = ResBlock(128, 128, time_dim)
        self.res1b = ResBlock(128, 128, time_dim)
        self.down2 = nn.Conv2d(128, 256, 4, 2, 1)
        self.res2a = ResBlock(256, 256, time_dim)
        self.res2b = ResBlock(256, 256, time_dim)
        self.down3 = nn.Conv2d(256, 512, 4, 2, 1)
        self.res3a = ResBlock(512, 512, time_dim)
        self.res3b = ResBlock(512, 512, time_dim)
        self.down4 = nn.Conv2d(512, 512, 4, 2, 1)
        self.res4a = ResBlock(512, 512, time_dim)
        self.res4b = ResBlock(512, 512, time_dim)
        self.down5 = nn.Conv2d(512, 512, 4, 2, 1)
        self.res5a = ResBlock(512, 512, time_dim)
        self.res5b = ResBlock(512, 512, time_dim)

        # Artifact attention @ 22×22
        self.att5 = CrossAttention(512)

        # bottleneck 11×11
        self.down_bn = nn.Conv2d(512, 1024, 4, 2, 1)
        self.res_bn1 = ResBlock(1024, 1024, time_dim)

        self.att_bn = CrossAttention(1024)
        self.res_bn2 = ResBlock(1024, 1024, time_dim)

        # =========== Decoder ===========
        self.up5 = nn.ConvTranspose2d(1024, 512, 4, 2, 1)
        self.up5_res1 = ResBlock(1024, 512, time_dim)   # 512+512
        self.up5_res2 = ResBlock(512, 512, time_dim)

        self.up4 = nn.ConvTranspose2d(512, 512, 4, 2, 1)
        self.up4_res1 = ResBlock(1024, 512, time_dim)   # 512+512
        self.up4_res2 = ResBlock(512, 512, time_dim)

        self.up3 = nn.ConvTranspose2d(512, 256, 4, 2, 1)
        self.up3_res1 = ResBlock(768, 256, time_dim)    # 256+512  ★
        self.up3_res2 = ResBlock(256, 256, time_dim)

        self.up2 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.up2_res1 = ResBlock(384, 128, time_dim)    # 128+256  ★
        self.up2_res2 = ResBlock(128, 128, time_dim)

        self.up1 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.up1_res1 = ResBlock(192, 64, time_dim)     # 64+128   ★
        self.up1_res2 = ResBlock(64, 64, time_dim)

        self.out_conv = nn.Conv2d(64, 3, 1)

        # Artifact encoder
        self.art_enc = ArtifactMiniEncoder()

    # --------------------------
    def forward(self, x, art, t):
        t_emb = self.time_mlp(timestep_embedding(t, self.time_mlp[0].in_features))

        # Encoder
        h0 = self.res0b(self.res0(x, t_emb), t_emb)
        h1 = self.res1b(self.res1a(self.down1(h0), t_emb), t_emb)
        h2 = self.res2b(self.res2a(self.down2(h1), t_emb), t_emb)
        h3 = self.res3b(self.res3a(self.down3(h2), t_emb), t_emb)
        h4 = self.res4b(self.res4a(self.down4(h3), t_emb), t_emb)
        h5 = self.res5b(self.res5a(self.down5(h4), t_emb), t_emb)   # 22×22

        # Artifact features
        A5, ABN = self.art_enc(art)
        B, C, H, W = h5.shape
        att_in = torch.cat([h5.flatten(2).transpose(1,2), A5.flatten(2).transpose(1,2)], dim=1)
        h5 = self.att5(h5.flatten(2).transpose(1,2), att_in).transpose(1,2).view(B,C,H,W)

        # Bottleneck 11×11
        bn = self.res_bn1(self.down_bn(h5), t_emb)
        B, Cb, Hb, Wb = bn.shape
        att_in_bn = torch.cat([bn.flatten(2).transpose(1,2), ABN.flatten(2).transpose(1,2)], dim=1)
        bn = self.att_bn(bn.flatten(2).transpose(1,2), att_in_bn).transpose(1,2).view(B,Cb,Hb,Wb)
        bn = self.res_bn2(bn, t_emb)

        # Decoder
        up5_dec = self.up5(bn)
        up5_dec = center_crop_to(up5_dec, h5)      # ★ new
        up5 = torch.cat([up5_dec, h5], dim=1)
        up5 = self.up5_res2(self.up5_res1(up5, t_emb), t_emb)

        up4_dec = self.up4(up5)
        up4_dec = center_crop_to(up4_dec, h4)      # ★ new
        up4 = torch.cat([up4_dec, h4], dim=1)
        up4 = self.up4_res2(self.up4_res1(up4, t_emb), t_emb)

        up3_dec = self.up3(up4)
        up3_dec = center_crop_to(up3_dec, h3)      # ★ new
        up3 = torch.cat([up3_dec, h3], dim=1)
        up3 = self.up3_res2(self.up3_res1(up3, t_emb), t_emb)

        up2_dec = self.up2(up3)
        up2_dec = center_crop_to(up2_dec, h2)      # ★ new
        up2 = torch.cat([up2_dec, h2], dim=1)
        up2 = self.up2_res2(self.up2_res1(up2, t_emb), t_emb)

        up1_dec = self.up1(up2)
        up1_dec = center_crop_to(up1_dec, h1)      # ★ new
        up1 = torch.cat([up1_dec, h1], dim=1)
        up1 = self.up1_res2(self.up1_res1(up1, t_emb), t_emb)

        out = self.out_conv(up1)
        return out


In [6]:
from PIL import Image
from torchvision import transforms

class ArtifactCharacterDataset(torch.utils.data.Dataset):
    def __init__(self, root_art='../Data/Source_Graph',
                       root_role='../Data/Target_Graph'):
        super().__init__()
        self.art_files = sorted(glob.glob(os.path.join(root_art, '*.png')))
        self.role_dir = root_role
        self.trans_art = transforms.Compose([
            transforms.ToTensor()
        ])
        self.trans_role = transforms.Compose([
            transforms.Resize((700,700)),
            transforms.ToTensor()
        ])

    def __len__(self): return len(self.art_files)

    def __getitem__(self, idx):
        art_path = self.art_files[idx]
        name = os.path.splitext(os.path.basename(art_path))[0]
        role1 = Image.open(os.path.join(self.role_dir, f'{name} 肖形.png')).convert('RGB')
        role2 = Image.open(os.path.join(self.role_dir, f'{name} 写照.png')).convert('RGB')
        art_img = Image.open(art_path).convert('RGB')

        # 这里仅返回 role1，可自行随机选择两张
        return (self.trans_role(role1),
                self.trans_art(art_img),
                torch.randint(0, 1000, (1,)))  # dummy time‑step


In [7]:
model = DiffusionUNet().to(device)
dummy_role = torch.randn(1,3,700,700, device=device)
dummy_art  = torch.randn(1,3,800,512, device=device)   # 随意尺寸
dummy_t    = torch.randint(0,1000,(1,), device=device)

with torch.no_grad():
    out = model(dummy_role, dummy_art, dummy_t)
print('ε‑pred output shape:', out.shape)

# ---- torchinfo summary (可选) ----
try:
    from torchinfo import summary
    summary(model, input_data=(dummy_role, dummy_art, dummy_t), depth=2)
except ImportError:
    print("安装 torchinfo 后可获得层级参数统计： pip install torchinfo")


ε‑pred output shape: torch.Size([1, 3, 350, 350])


In [8]:
# =======  1. 线性 beta 调度  =======
T = 1000                                           # diffusion steps
beta_start, beta_end = 1e-4, 2e-2
betas  = torch.linspace(beta_start, beta_end, T, dtype=torch.float32, device=device)
alphas = 1.0 - betas
alphas_cum = torch.cumprod(alphas, dim=0)          # \bar α_t

# =======  2. q_sample  =======
def q_sample(x_start, t, noise=None):
    """
    根据公式  x_t = sqrt(abar_t) x0 + sqrt(1-abar_t) ε
    """
    if noise is None:
        noise = torch.randn_like(x_start)
    abar = alphas_cum[t].view(-1, 1, 1, 1)
    return (abar.sqrt() * x_start + (1.0 - abar).sqrt() * noise)

# =======  3. 随机时间步生成器  =======
def sample_timesteps(batch_size):
    return torch.randint(0, T, (batch_size,), device=device, dtype=torch.long)


In [9]:
from torch.utils.data import DataLoader
dataset = ArtifactCharacterDataset()
loader  = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, pin_memory=True)

model    = DiffusionUNet().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
criterion = nn.MSELoss()

# 可选：学习率调度
lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100*len(loader))


In [10]:
EPOCHS = 10
save_every = 2        # 每 N epoch 保存一次
ckpt_dir = "./checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)

for epoch in range(1, EPOCHS+1):
    model.train()
    running_loss = 0.0
    
    for role_img, art_img, _ in loader:
        role_img = role_img.to(device)
        art_img  = art_img.to(device)
        bsz = role_img.size(0)

        # 1. 采样时间步 & 真实噪声
        t     = sample_timesteps(bsz)
        noise = torch.randn_like(role_img)

        # 2. x_t 输入模型
        x_t = q_sample(role_img, t, noise)
        pred_noise = model(x_t, art_img, t)

        # 3. 计算 MSE 损失
        loss = criterion(pred_noise, noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_sched.step()

        running_loss += loss.item() * bsz

    avg_loss = running_loss / len(dataset)
    print(f"Epoch {epoch:02d}/{EPOCHS} | loss={avg_loss:.6f}")

    # --------- checkpoint ----------
    if epoch % save_every == 0:
        ckpt_path = os.path.join(ckpt_dir, f"epoch_{epoch:03d}.pt")
        torch.save({
            "model_state": model.state_dict(),
            "optim_state": optimizer.state_dict(),
            "epoch": epoch
        }, ckpt_path)
        print("Checkpoint saved:", ckpt_path)


RuntimeError: stack expects each tensor to be equal size, but got [3, 532, 368] at entry 0 and [3, 460, 690] at entry 1