In [1]:
# %% 
import os, random, math, json, glob, re
from pathlib import Path
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from einops import rearrange, repeat
from PIL import Image

device = torch.device("xpu" if torch.xpu.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: xpu


In [2]:
missing_src, missing_other = [], []
for role_path in Path("../Data/Target_Graph").glob("* 肖形.png"):
    name = role_path.stem.split(" ")[0]
    src_path   = Path("../Data/Source_Graph")/f"{name}.png"
    other_path = Path("../Data/Target_Graph")/f"{name} 写照.png"
    if not src_path.exists():   missing_src.append(name)
    if not other_path.exists(): missing_other.append(name)
print("文物缺失:", missing_src)
print("写照缺失:", missing_other)


文物缺失: ['可爱花束']
写照缺失: []


In [3]:
# %%
def pad_to_square_pil(im: Image.Image, fill=0):
    w, h = im.size
    l = max(w, h)
    new = Image.new("RGB", (l, l), (fill,)*3)
    new.paste(im, ((l-w)//2, (l-h)//2))
    return new

def pad_square_1024(im: Image.Image):
    if max(im.size) > 1024:
        ratio = 1024 / max(im.size)
        im = im.resize((round(im.size[0]*ratio), round(im.size[1]*ratio)),
                       Image.BICUBIC)
    w, h = im.size
    new = Image.new("RGB", (1024, 1024), (0,0,0))
    new.paste(im, ((1024-w)//2, (1024-h)//2))
    return new


In [4]:
# %%
class ArtifactAnimeDataset(Dataset):
    def __init__(self, root_src="../Data/Source_Graph",
                 root_tgt="../Data/Target_Graph"):
        self.src_root = Path(root_src)
        self.tgt_root = Path(root_tgt)

        self.pairs = []
        for role_path in self.tgt_root.glob("* 肖形.png"):
            name  = role_path.stem.split(" ")[0]
            src   = self.src_root/f"{name}.png"
            other = self.tgt_root/f"{name} 写照.png"
            if src.exists() and other.exists():
                self.pairs.append((src, role_path, other))

        self.tgt_tf = transforms.Compose([
            transforms.Lambda(pad_to_square_pil),
            transforms.Resize((700, 700), interpolation=Image.BICUBIC),
            transforms.ToTensor()
        ])
        self.src_tf = transforms.Compose([
            transforms.Lambda(pad_square_1024),
            transforms.ToTensor()
        ])

    def __len__(self): return len(self.pairs)

    def __getitem__(self, idx):
        src_path, role_a, role_b = self.pairs[idx]
        role_path = random.choice([role_a, role_b])
        role_img  = Image.open(role_path).convert("RGB")
        src_img   = Image.open(src_path).convert("RGB")
        return self.src_tf(src_img), self.tgt_tf(role_img)


In [5]:
dataset = ArtifactAnimeDataset()
print("Aligned samples:", len(dataset))
loader  = DataLoader(dataset, batch_size=2, shuffle=True,
                     num_workers=4, pin_memory=True)


Aligned samples: 99


In [6]:
class BetaSchedule:
    def __init__(self, n_steps=1000, max_beta=0.02, device="cpu"):
        self.N = n_steps
        betas  = torch.linspace(1e-4, max_beta, n_steps)
        alphas = 1. - betas
        self.register_buffer = lambda n, x: setattr(self, n, x.to(device))
        self.register_buffer('betas', betas)
        self.register_buffer('alphas_bar', torch.cumprod(alphas, 0))

    def get_params(self, t):
        return self.betas[t], self.alphas_bar[t]

schedule = BetaSchedule(device=device)

In [7]:
# %% 
# --- 4.1 时间嵌入 ---
class TimeEmbedding(nn.Module):
    def __init__(self, time_dim=256):
        super().__init__()
        self.time_dim = time_dim
        self.mlp = nn.Sequential(
            nn.Linear(time_dim, time_dim*4),
            nn.SiLU(),
            nn.Linear(time_dim*4, time_dim)
        )

    def forward(self, t):
        half = self.time_dim // 2
        # 位置编码
        freqs = torch.exp(
            torch.arange(half, device=t.device) * -(math.log(10000.0) / (half - 1))
        )
        args = t.float().unsqueeze(1) * freqs.unsqueeze(0)          # (B, half)
        emb  = torch.cat([torch.sin(args), torch.cos(args)], dim=1) # (B, time_dim)
        return self.mlp(emb)                                        # (B, time_dim)


In [8]:
# %% 
# --- 4.2 ResBlock with FiLM ---
class ResBlock(nn.Module):
    def __init__(self, in_c, out_c, time_dim=None, act='silu'):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, 3, padding=1)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)
        self.act   = getattr(F, act)
        self.use_film = time_dim is not None and time_dim > 0
        if self.use_film:
            self.time_proj = nn.Linear(time_dim, out_c * 2)
        else:
            self.time_proj = None
        self.skip = nn.Conv2d(in_c, out_c, 1) if in_c != out_c else nn.Identity()

    def forward(self, x, t_emb=None):
        h = self.act(self.conv1(x))
        if self.use_film and t_emb is not None:
            γ, β = self.time_proj(t_emb).chunk(2, dim=1)  # (B, C)
            h = h * (1 + γ[..., None, None]) + β[..., None, None]
        h = self.conv2(h)
        return self.act(h + self.skip(x))


In [9]:
# %% 
# --- 4.3 文物 Adaptive FPN Encoder ---
class CondEncoder(nn.Module):
    def __init__(self, channels=(64,128,256,512,768)):
        super().__init__()
        self.stages = nn.ModuleList()
        in_c=3
        for c in channels:
            self.stages.append(nn.Sequential(
                ResBlock(in_c,c,0),       # no time emb
                ResBlock(c,c,0),
                nn.Conv2d(c,c,2,stride=2) # Down2
            ))
            in_c=c
        self.target_sizes = [(350,350),(175,175),(88,88),(44,44)]
        self.proj = nn.ModuleList([nn.Conv2d(c,c,1) for c in channels[:-1]])
    def forward(self, x):
        feats=[]
        for i,stage in enumerate(self.stages):
            x = stage(x)
            if i<4:
                pooled = F.adaptive_avg_pool2d(x, self.target_sizes[i])
                feats.append(self.proj[i](pooled))
        feats.append(x)  # 最深层原尺寸44x44
        return feats   # len=5


In [10]:
# %% 
# --- 4.4 Role Encoder / Decoder with Skip-Attn 融合 ---
class SkipAttnFusion(nn.Module):
    def __init__(self, c, heads):
        super().__init__()
        self.q = nn.Conv2d(c, c, 1)
        self.kv = nn.Conv2d(2*c, c*2,1)
        self.fc = nn.Conv2d(c, c,1)
        self.heads=heads; self.scale = (c//heads)**-0.5
    def forward(self, dec, role, art):
        B,C,H,W = dec.shape
        q = self.q(dec).reshape(B,self.heads,C//self.heads,H*W)
        kv = self.kv(torch.cat([role,art],1))
        k,v = kv.chunk(2,1)
        k = k.reshape(B,self.heads,C//self.heads,-1)
        v = v.reshape(B,self.heads,C//self.heads,-1)
        attn = (q.transpose(-2,-1) @ k)*self.scale
        attn = attn.softmax(-1)
        out  = (v @ attn.transpose(-2,-1)).reshape(B,C,H,W)
        return dec + self.fc(out)


In [11]:
# %%
class DownBlock(nn.Module):
    def __init__(self, in_c, out_c, time_dim, is_last=False):
        super().__init__()
        self.res1 = ResBlock(in_c,  out_c, time_dim)
        self.res2 = ResBlock(out_c, out_c, time_dim)
        self.down = None
        if not is_last:                      # bottleneck 不再下采样
            self.down = nn.Conv2d(out_c, out_c, 2, stride=2)

    def forward(self, x, t_emb):
        x = self.res1(x, t_emb)
        x = self.res2(x, t_emb)
        if self.down is not None:
            x = self.down(x)
        return x


class UpBlock(nn.Module):
    def __init__(self, in_c, out_c, time_dim):
        super().__init__()
        self.up   = nn.Upsample(scale_factor=2, mode='nearest')
        self.res1 = ResBlock(in_c,  out_c, time_dim)
        self.res2 = ResBlock(out_c, out_c, time_dim)

    def forward(self, x, t_emb):
        x = self.up(x)
        x = self.res1(x, t_emb)
        x = self.res2(x, t_emb)
        return x


In [12]:
# %% 
class UNetDiff(nn.Module):
    def __init__(self, time_dim=256, base_c=(64,128,256,512,768), act='silu'):
        super().__init__()
        self.time_emb = TimeEmbedding(time_dim=256)
        # Role encoder
        self.down_blocks = nn.ModuleList()
        in_c = 3
        for i, c in enumerate(base_c):
            is_last = (i == len(base_c)-1)
            self.down_blocks.append(DownBlock(in_c, c, time_dim, is_last=is_last))
            in_c = c
        # Cond encoder
        self.cond_enc = CondEncoder(base_c)
        # Skip Attention
        self.skip_fus = nn.ModuleList([SkipAttnFusion(c,max(4,c//64)) for c in base_c[::-1]])
        # Up path
        self.up_blocks = nn.ModuleList()
        rev = base_c[::-1]
        for idx in range(len(rev)-1):
            self.up_blocks.append(UpBlock(rev[idx], rev[idx+1], time_dim))
        # Bottleneck extra block
        self.mid = ResBlock(base_c[-1],base_c[-1],time_dim,act)
        # Output
        self.out_head = nn.Sequential(nn.Conv2d(base_c[0],3,1))
    def forward(self, role_noisy, artifact, t):
        t_emb = self.time_emb(t)
        role_feats = []
        x = role_noisy
        for blk in self.down_blocks:
            x = blk(x, t_emb)
            role_feats.append(x)
        art_feats = self.cond_enc(artifact)
        # bottleneck
        d = self.mid(x, t_emb)
        for i, up in enumerate(self.up_blocks):
            d = up(d, t_emb)
        d = self.skip_fus[-1](d, role_feats[0], art_feats[0])
        return self.out_head(d)


In [13]:
# %% 
def train_epoch(model, loader, opt, sc, epoch):
    model.train()
    for art, role in tqdm(loader, desc=f"Epoch {epoch}"):
        art, role = art.to(device), role.to(device)
        B = role.size(0)
        t = torch.randint(0, sc.N, (B,), device=device).long()
        β, αbar = sc.get_params(t.cpu())
        β, αbar = β.to(device), αbar.to(device)
        noise = torch.randn_like(role)
        noisy = (αbar.sqrt()[...,None,None,None]*role +
                 (1-αbar).sqrt()[...,None,None,None]*noise)
        ε_pred = model.forward(noisy, art, t)
        loss = F.mse_loss(ε_pred, noise)
        opt.zero_grad(); loss.backward(); opt.step()


In [14]:
# %% 
model = UNetDiff().to(device)
opt   = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
loader= DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0, pin_memory=False)

for epoch in range(10):
    train_epoch(model, loader, opt, schedule, epoch)
    torch.save(model.state_dict(), f"ckpt_epoch{epoch}.pt")


Epoch 0:   0%|          | 0/50 [00:20<?, ?it/s]


OutOfMemoryError: XPU out of memory. Tried to allocate 1728.08 GiB. GPU 0 has a total capacity of 11.60 GiB. Of the allocated memory 18.63 GiB is allocated by PyTorch, and 773.27 MiB is reserved by PyTorch but unallocated. Please use `empty_cache` to release all unoccupied cached memory.