In [None]:
# -*- coding: utf-8 -*-
"""
AGE-CGAN High — Young → Senior (Balanced, Mid+)
- 256px / ResNet-6 + Self-Attention / Multi-Scale SN PatchGAN
- Hinge GAN + (optional) R1 GP, Feature Matching, Identity, Cycle, Perceptual (VGG19 ON)
- Optional DiffAug (color/translation), Replay Buffer, EMA
- Epoch=40, steps/epoch ≤ 400 (not lightweight, but raised without going overboard)

Usage is the same as before. Only tune the hyperparameters to fit your quality/VRAM budget.
"""

import os, random, time, copy
from pathlib import Path
from PIL import Image
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import torchvision.utils as vutils
from torch.nn.utils import spectral_norm as SN
from tqdm.auto import tqdm

# ===== Basic setup =====
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


set_seed(42)

# cuDNN autotuning plus TF32 kernels for matmul and cuDNN ops.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Run on the GPU when CUDA is available; mixed precision follows CUDA availability.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
USE_AMP = torch.cuda.is_available()
print("DEVICE:", DEVICE, "| AMP:", USE_AMP)

# Input folders: one sub-directory per domain.
DATA_DIR = Path("data")
YOUNG_DIR = DATA_DIR / "Young"
SENIOR_DIR = DATA_DIR / "Senior"

# Output folders are created eagerly so later saves cannot fail on a missing directory.
OUT_DIR = Path("outputs_high")
OUT_DIR.mkdir(parents=True, exist_ok=True)
CKPT_DIR = OUT_DIR / "ckpt"
CKPT_DIR.mkdir(parents=True, exist_ok=True)
SAMPLE_DIR = OUT_DIR / "samples"
SAMPLE_DIR.mkdir(parents=True, exist_ok=True)

# ===== Hyperparameters (upgraded settings) =====
IMG_SIZE   = 256
BATCH_SIZE = 4              # 256px + multi-scale D → mind the VRAM (6~8 if there is headroom)
EPOCHS     = 40
MAX_STEPS_PER_EPOCH = 400   # None → iterate over the whole loader
LR       = 2e-4
BETAS    = (0.5, 0.999)
DECAY_FROM = 25             # linear LR decay from this epoch index onwards

GEN_BLOCKS = 6              # 4→6 (a bit more capacity)
NGF = 64; NDF = 64          # 48→64
LAMBDA_CYC = 10.0
LAMBDA_IDT = 3.0
USE_PERCEPTUAL = True
LAMBDA_PER = 0.2
LAMBDA_FM  = 10.0           # feature-matching weight

USE_DIFFAUG = True          # simple color/translation DiffAug
USE_R1 = True               # hinge + R1 (optional)
R1_LAMBDA = 10.0
R1_INTERVAL = 16            # apply R1 every N steps

REPLAY_POOL_SIZE = 50
EMA_DECAY = 0.998

SAVE_EVERY_EPOCHS = 2
LOG_EVERY_STEPS   = 50

# ===== Utilities =====
def denorm(t):
    """Map a tensor from the [-1, 1] range to [0, 1], clamping anything outside."""
    rescaled = t * 0.5 + 0.5
    return rescaled.clamp(0, 1)

def is_img(p: Path):
    """Return True when the path's extension (case-insensitive) is a supported image type."""
    extension = p.suffix.lower()
    return extension in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}

def safe_open(path: Path, img_size: int):
    """Load an image as RGB, falling back to a black placeholder on failure.

    Args:
        path: image file to read.
        img_size: side length of the square black placeholder returned when
            the file cannot be opened or decoded.

    Returns:
        A PIL RGB image. Failures are reported with a printed warning instead
        of being raised (deliberate best-effort so one bad file does not stop
        a run).
    """
    try:
        # The context manager releases the file handle even if decoding fails
        # part-way; convert() hands back an image that no longer needs the file.
        with Image.open(path) as im:
            return im.convert("RGB")
    except Exception as e:
        print(f"[WARN] bad image: {path} ({e})")
        return Image.new("RGB", (img_size, img_size), (0,0,0))

# ===== Preprocessing: Resize(286) → RandomCrop(256) + HFlip =====
class RandomJitter256:
    """Resize to 286, take a random IMG_SIZE x IMG_SIZE crop, then flip horizontally half the time."""

    def __call__(self, img):
        resized = TF.resize(img, 286, interpolation=T.InterpolationMode.BICUBIC)
        top, left, height, width = T.RandomCrop.get_params(resized, (IMG_SIZE, IMG_SIZE))
        cropped = TF.crop(resized, top, left, height, width)
        # The flip decision is drawn after the crop parameters, as before.
        return TF.hflip(cropped) if random.random() < 0.5 else cropped

# ===== DiffAug (simple) =====
def diffaug(x):
    """Randomly color-jitter and/or translate a whole image batch.

    Each of the two augmentations is applied with probability 0.5, and one set
    of random parameters is shared by every image in the batch.

    Args:
        x: float image batch of shape (B, C, H, W) with values in [-1, 1].

    Returns:
        The augmented batch with the same shape as the input.
    """
    if random.random() < 0.5:
        # Color change (brightness/contrast/saturation/hue). The torchvision
        # helpers operate on [0, 1] images, so convert there and back.
        x = (x+1)/2
        b = (1.0 + (random.random()*0.4 - 0.2))
        c = (1.0 + (random.random()*0.4 - 0.2))
        s = (1.0 + (random.random()*0.4 - 0.2))
        h = (random.random()*0.1 - 0.05)
        x = TF.adjust_brightness(x, b)
        x = TF.adjust_contrast(x, c)
        x = TF.adjust_saturation(x, s)
        x = TF.adjust_hue(x, h)
        x = x*2-1
    if random.random() < 0.5:
        # Small translation
        B, C, H, W = x.shape
        # tx moves along the width and ty along the height, so each range is
        # derived from the axis it is applied to (identical for square inputs).
        tx = int(np.random.randint(-W//20, W//20))
        ty = int(np.random.randint(-H//20, H//20))
        grid_y, grid_x = torch.meshgrid(torch.linspace(-1,1,H,device=x.device),
                                        torch.linspace(-1,1,W,device=x.device), indexing='ij')
        grid = torch.stack((grid_x, grid_y), -1).unsqueeze(0).repeat(B,1,1,1)
        grid[...,0] = grid[...,0] + (2*tx/W)
        grid[...,1] = grid[...,1] + (2*ty/H)
        # grid_sample needs the grid in the input's dtype (no-op for float32 inputs).
        grid = grid.to(x.dtype)
        x = F.grid_sample(x, grid, padding_mode='reflection', align_corners=True)
    return x

# ===== Dataset =====
class UnpairedDogs(Dataset):
    """Unpaired Young/Senior image dataset.

    The young image is selected by index (wrapping around); the senior image
    is drawn at random on every access.
    """

    def __init__(self, young_dir, senior_dir, augment=True):
        def list_images(folder):
            return sorted(f for f in Path(folder).glob("*") if is_img(f))

        self.young_paths = list_images(young_dir)
        self.senior_paths = list_images(senior_dir)
        assert self.young_paths and self.senior_paths, "Young/Senior 폴더 확인!"
        self.Ny = len(self.young_paths)
        self.Ns = len(self.senior_paths)
        # Random jitter for training, plain resize otherwise; both end in [-1, 1] tensors.
        first_op = RandomJitter256() if augment else T.Resize((IMG_SIZE, IMG_SIZE))
        self.tf = T.Compose([first_op, T.ToTensor(), T.Normalize((0.5,)*3, (0.5,)*3)])

    def __len__(self):
        return max(self.Ny, self.Ns)

    def __getitem__(self, idx):
        young_path = self.young_paths[idx % self.Ny]
        y = self.tf(safe_open(young_path, IMG_SIZE))
        # The senior index is drawn only after the young image has been transformed,
        # which keeps the random-number consumption order unchanged.
        senior_idx = random.randint(0, self.Ns-1)
        s = self.tf(safe_open(self.senior_paths[senior_idx], IMG_SIZE))
        return {"young": y, "senior": s}

# ===== Self-Attention =====
class SelfAttention(nn.Module):
    """SAGAN-style self-attention over the spatial positions of a feature map.

    Three spectral-normalised 1x1 convolutions produce f, g (in_ch // 8
    channels) and h (in_ch channels). The attended map is added to the input
    scaled by a learnable ``gamma`` that starts at zero, so the layer is an
    identity at initialisation.

    Args:
        in_ch: number of input (and output) channels.
    """
    def __init__(self, in_ch):
        super().__init__()
        self.f = SN(nn.Conv2d(in_ch, in_ch//8, 1))
        self.g = SN(nn.Conv2d(in_ch, in_ch//8, 1))
        self.h = SN(nn.Conv2d(in_ch, in_ch, 1))
        self.gamma = nn.Parameter(torch.zeros(1))
    def forward(self, x):
        B,C,H,W = x.shape
        f = self.f(x).view(B, -1, H*W)         # B, C/8, N
        g = self.g(x).view(B, -1, H*W)         # B, C/8, N
        # scores[b, i, j] = <f_i, g_j>
        scores = torch.bmm(f.transpose(1,2), g)  # B, N, N
        # Normalise over the source index i (dim=1): the bmm below sums over i,
        # so the weights feeding each output position j must sum to 1.
        beta = torch.softmax(scores, dim=1)
        h = self.h(x).view(B, C, H*W)          # B, C, N
        # o[b, c, j] = sum_i h[b, c, i] * beta[b, i, j]
        o = torch.bmm(h, beta).view(B, C, H, W)
        return x + self.gamma * o

# ===== Models =====
class ResnetBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 conv + InstanceNorm stages with a skip connection."""

    def __init__(self, dim):
        super().__init__()
        first_stage = [nn.ReflectionPad2d(1), nn.Conv2d(dim, dim, 3), nn.InstanceNorm2d(dim), nn.ReLU(True)]
        second_stage = [nn.ReflectionPad2d(1), nn.Conv2d(dim, dim, 3), nn.InstanceNorm2d(dim)]
        # Same layer order (and therefore the same state-dict keys) as a flat Sequential.
        self.block = nn.Sequential(*first_stage, *second_stage)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class ResnetGenerator(nn.Module):
    """ResNet generator: 7x7 stem, two stride-2 downsamples, residual blocks, two upsamples, tanh output.

    A SelfAttention layer is inserted right after the residual block with
    index ``n_blocks // 2``.
    """

    def __init__(self, in_c=3, out_c=3, n_blocks=GEN_BLOCKS, ngf=NGF):
        super().__init__()
        wide = ngf*4
        layers = [nn.ReflectionPad2d(3), nn.Conv2d(in_c, ngf, 7), nn.InstanceNorm2d(ngf), nn.ReLU(True)]
        # Encoder: two stride-2 convolutions.
        for c_in, c_out in ((ngf, ngf*2), (ngf*2, wide)):
            layers += [nn.Conv2d(c_in, c_out, 3, 2, 1), nn.InstanceNorm2d(c_out), nn.ReLU(True)]
        # Bottleneck: residual blocks with attention at the downsampled resolution.
        attention_after = n_blocks//2
        for idx in range(n_blocks):
            layers.append(ResnetBlock(wide))
            if idx == attention_after:
                layers.append(SelfAttention(wide))
        # Decoder: two stride-2 transposed convolutions.
        for c_in, c_out in ((wide, ngf*2), (ngf*2, ngf)):
            layers += [nn.ConvTranspose2d(c_in, c_out, 3, 2, 1, output_padding=1),
                       nn.InstanceNorm2d(c_out), nn.ReLU(True)]
        layers += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, out_c, 7), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# Multi-scale PatchGAN (two scales) that also returns its feature maps
class PatchDiscriminator(nn.Module):
    """Spectral-normalised PatchGAN returning patch logits and the three intermediate feature maps."""

    def __init__(self, in_c=3, ndf=NDF):
        super().__init__()

        def make_stage(c_in, c_out, norm=True, stride=2):
            conv = SN(nn.Conv2d(c_in, c_out, 4, stride, 1))
            mods = [conv, nn.InstanceNorm2d(c_out)] if norm else [conv]
            mods.append(nn.LeakyReLU(0.2, True))
            return nn.Sequential(*mods)

        # Each stage halves the spatial size (128 / 64 / 32 for a 256px input).
        self.b1 = make_stage(in_c, ndf, norm=False)
        self.b2 = make_stage(ndf, ndf*2)
        self.b3 = make_stage(ndf*2, ndf*4)
        self.out = SN(nn.Conv2d(ndf*4, 1, 4, 1, 1))

    def forward(self, x):
        feats = []
        for stage in (self.b1, self.b2, self.b3):
            x = stage(x)
            feats.append(x)
        return self.out(x), feats

class MultiScaleD(nn.Module):
    """Runs one PatchDiscriminator per scale, average-pooling the input between scales."""

    def __init__(self, in_c=3, ndf=NDF, scales=2):
        super().__init__()
        self.scales = nn.ModuleList([PatchDiscriminator(in_c, ndf) for _ in range(scales)])
        self.pool = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)

    def forward(self, x):
        preds = []
        feats = []
        current = x
        for disc in self.scales:
            logit, fmap = disc(current)
            preds.append(logit)
            feats.append(fmap)
            # Downsample for the next (coarser) discriminator.
            current = self.pool(current)
        return preds, feats  # one entry per scale

def weights_init_normal(m):
    """Initialise one module in place; meant to be passed to ``nn.Module.apply``.

    Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02) and biases set
    to 0. InstanceNorm2d affine weights (when present) are drawn from
    N(1, 0.02) and biases set to 0. Other modules are left untouched.

    Args:
        m: the module visited by ``apply``.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # Under torch.nn.utils.spectral_norm the learnable tensor is
        # `weight_orig`; `weight` is recomputed from it on every forward, so
        # initialising `weight` would be silently discarded.
        weight = getattr(m, "weight_orig", m.weight)
        nn.init.normal_(weight.data, 0.0, 0.02)
        if m.bias is not None: nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.InstanceNorm2d):
        if m.weight is not None: nn.init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None: nn.init.constant_(m.bias.data, 0)

# ===== Losses =====
def hinge_g_loss(fake_logits_list):
    """Generator hinge loss: sum over scales of the mean negated fake logits."""
    per_scale = ((-logits).mean() for logits in fake_logits_list)
    return sum(per_scale, 0.0)

def hinge_d_loss(real_logits_list, fake_logits_list):
    """Discriminator hinge loss summed over scales: relu(1 - real) + relu(1 + fake), each averaged."""
    total = 0.0
    for real_logits, fake_logits in zip(real_logits_list, fake_logits_list):
        real_term = F.relu(1.0 - real_logits).mean()
        fake_term = F.relu(1.0 + fake_logits).mean()
        total = total + real_term + fake_term
    return total

def r1_reg(D, real):
    """R1 gradient penalty on real images.

    Computes the gradient of the summed discriminator logits (over all
    scales) with respect to the input and returns the batch mean of its
    squared L2 norm.

    Args:
        D: callable returning ``(list_of_logit_tensors, features)``.
        real: batch of real images. It is not modified.

    Returns:
        Scalar tensor built with ``create_graph=True`` so it can be
        back-propagated into D's parameters.
    """
    # Work on a detached leaf copy: the caller's tensor keeps its own
    # requires_grad flag, and non-leaf inputs are accepted too.
    real = real.detach().requires_grad_(True)
    preds, _ = D(real)
    # Sum every scale into a single scalar
    s = 0.0
    for p in preds: s = s + p.sum()
    grad = torch.autograd.grad(s, real, create_graph=True, retain_graph=True)[0]
    # reshape (not view) so non-contiguous gradients are handled as well
    return grad.reshape(grad.size(0), -1).pow(2).sum(1).mean()

def cycle_loss(x, xr):
    """Mean absolute error between an image batch `x` and its cycle reconstruction `xr`."""
    return F.l1_loss(xr, x)
def idt_loss(x, y_same):
    """Mean absolute error between `x` and `y_same`, the generator output for an input already in its target domain."""
    return F.l1_loss(y_same, x)

# (Optional) perceptual loss
if USE_PERCEPTUAL:
    try:
        from torchvision.models import vgg19, VGG19_Weights
        # Frozen VGG19 feature extractor (first 16 layers of `features`), eval mode.
        VGG = vgg19(weights=VGG19_Weights.DEFAULT).features[:16].to(DEVICE).eval()
        for p in VGG.parameters():
            p.requires_grad = False
        print("VGG19 loaded for perceptual loss.")

        def perceptual_loss(x, y):
            """L1 distance between the VGG features of two image batches given in [-1, 1]."""
            # ImageNet channel statistics, applied after mapping back to [0, 1].
            mean = torch.tensor([0.485,0.456,0.406], device=x.device).view(1,3,1,1)
            std  = torch.tensor([0.229,0.224,0.225], device=x.device).view(1,3,1,1)
            feats_x = VGG((denorm(x) - mean) / std)
            feats_y = VGG((denorm(y) - mean) / std)
            return F.l1_loss(feats_x, feats_y)
    except Exception as e:
        # Best-effort: without the VGG weights training continues with the loss switched off.
        USE_PERCEPTUAL = False
        print("[WARN] VGG unavailable, perceptual loss disabled.", e)

# ===== Feature Matching Loss =====
def feature_matching_loss(real_feats_scales, fake_feats_scales):
    """Sum of L1 distances between matching real and fake discriminator feature maps.

    Both arguments are lists over scales, each holding that scale's list of
    feature maps. Real features are detached, so gradients reach only the fake side.
    """
    pairs = (
        (real_map, fake_map)
        for real_scale, fake_scale in zip(real_feats_scales, fake_feats_scales)
        for real_map, fake_map in zip(real_scale, fake_scale)
    )
    return sum((F.l1_loss(fake_map, real_map.detach()) for real_map, fake_map in pairs), 0.0)

# ===== Replay Buffer =====
class ImagePool:
    """History buffer of generated images used when querying fakes for the discriminator.

    Until the pool is full every image is stored and returned unchanged.
    Afterwards each image is, with probability one half, swapped with a
    randomly chosen stored image (the stored one is returned); otherwise it
    is returned as is.
    """

    def __init__(self, size=REPLAY_POOL_SIZE):
        self.size = size
        self.pool = []

    def query(self, imgs):
        selected = []
        for img in imgs:
            img = img.detach().unsqueeze(0)
            if len(self.pool) < self.size:
                self.pool.append(img)
                selected.append(img)
                continue
            if random.random() > 0.5:
                slot = random.randint(0, len(self.pool)-1)
                selected.append(self.pool[slot].clone())
                self.pool[slot] = img
            else:
                selected.append(img)
        return torch.cat(selected, 0)

# ===== EMA =====
class EMAWrap:
    """Holds an exponential-moving-average copy of a model.

    Args:
        model: network to track; it is deep-copied, put in eval mode and frozen.
        decay: EMA decay applied on every ``update`` call.
    """
    def __init__(self, model, decay=EMA_DECAY):
        self.decay=decay; self.ema=copy.deepcopy(model).eval()
        for p in self.ema.parameters(): p.requires_grad=False
    @torch.no_grad()
    def update(self, model):
        """Blend the live model's parameters into the EMA copy and sync its buffers."""
        for pe, p in zip(self.ema.parameters(), model.parameters()):
            pe.mul_(self.decay).add_(p.detach(), alpha=1.0-self.decay)
        # Buffers are state rather than weights (e.g. the spectral-norm
        # power-iteration vectors), so they are copied instead of averaged.
        # Left untouched, the EMA copy would keep its initial buffers while
        # its weights keep moving.
        for be, b in zip(self.ema.buffers(), model.buffers()):
            be.copy_(b)

# ===== Data loaders =====
# Training data: augmented, shuffled, incomplete last batch dropped.
train_ds = UnpairedDogs(YOUNG_DIR, SENIOR_DIR, augment=True)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=0, drop_last=True, persistent_workers=False,
                          pin_memory=torch.cuda.is_available())
print(f"~{len(train_loader)*BATCH_SIZE} imgs/epoch (batch={BATCH_SIZE}, size={IMG_SIZE})")

# Fixed sample batch of 4, drawn once from a non-augmented copy of the dataset
# and reused for the preview grids saved during training.
fix_batch = next(iter(DataLoader(UnpairedDogs(YOUNG_DIR, SENIOR_DIR, augment=False),
                                 batch_size=4, shuffle=True, num_workers=0, drop_last=True)))
fix_y = fix_batch["young"].to(DEVICE, non_blocking=True)
fix_s = fix_batch["senior"].to(DEVICE, non_blocking=True)

# ===== Models / optimizers / schedulers / scalers =====
# Two generators (young→senior, senior→young) and one multi-scale discriminator per domain.
G_y2s = ResnetGenerator().to(DEVICE); G_s2y = ResnetGenerator().to(DEVICE)
D_y   = MultiScaleD().to(DEVICE);     D_s   = MultiScaleD().to(DEVICE)
# Only the generators receive the custom normal init; the discriminators keep their default init.
for m in (G_y2s,G_s2y):
    m.apply(weights_init_normal)

# Both generators share opt_G and both discriminators share opt_D.
opt_G = torch.optim.Adam(list(G_y2s.parameters())+list(G_s2y.parameters()), lr=LR, betas=BETAS)
opt_D = torch.optim.Adam(list(D_y.parameters())+list(D_s.parameters()), lr=LR, betas=BETAS)

def lr_lambda(ep):
    """Learning-rate multiplier for epoch index `ep` (used with LambdaLR).

    Returns 1.0 before DECAY_FROM, then decays linearly. The denominator
    carries a +1 so the multiplier is still positive during the final
    training epoch (index EPOCHS - 1) and only reaches 0.0 once training is
    over; it also avoids a division by zero when EPOCHS == DECAY_FROM.
    """
    if ep < DECAY_FROM:
        return 1.0
    return max(0.0, 1.0 - (ep + 1 - DECAY_FROM) / (EPOCHS - DECAY_FROM + 1))
# LambdaLR scales each optimizer's base LR by lr_lambda; both schedulers are
# stepped once per epoch in the training loop.
sch_G = torch.optim.lr_scheduler.LambdaLR(opt_G, lr_lambda)
sch_D = torch.optim.lr_scheduler.LambdaLR(opt_D, lr_lambda)
# Separate AMP loss scalers for the generator and discriminator updates
# (constructed with enabled=False when USE_AMP is off).
scaler_G = torch.cuda.amp.GradScaler(enabled=USE_AMP)
scaler_D = torch.cuda.amp.GradScaler(enabled=USE_AMP)

# EMA copies of both generators, and one replay pool per generated domain.
ema_y2s = EMAWrap(G_y2s); ema_s2y = EMAWrap(G_s2y)
pool_s, pool_y = ImagePool(), ImagePool()

# ===== Checkpoints =====
def save_ckpt(epoch):
    """Write model, EMA, optimizer and scheduler state for `epoch` to CKPT_DIR/e{epoch:04d}.pt."""
    stateful = {
        "G_y2s": G_y2s, "G_s2y": G_s2y,
        "D_y": D_y, "D_s": D_s,
        "EMA_y2s": ema_y2s.ema, "EMA_s2y": ema_s2y.ema,
        "opt_G": opt_G, "opt_D": opt_D,
        "sch_G": sch_G, "sch_D": sch_D,
    }
    payload = {key: obj.state_dict() for key, obj in stateful.items()}
    payload["epoch"] = epoch
    torch.save(payload, CKPT_DIR/f"e{epoch:04d}.pt")

# ===== Training =====
def _set_requires_grad(nets, flag):
    """Switch gradient tracking on or off for every parameter of the given networks."""
    for net in nets:
        for param in net.parameters():
            param.requires_grad_(flag)


global_step = 0
for epoch in range(EPOCHS):
    G_y2s.train(); G_s2y.train(); D_y.train(); D_s.train()
    running={"G":0.0,"D":0.0,"cyc":0.0,"idt":0.0,"per":0.0,"fm":0.0,"r1":0.0,"adv":0.0}
    n_steps=0; t0=time.time()

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)
    for batch in pbar:
        real_y = batch["young"].to(DEVICE, non_blocking=True)
        real_s = batch["senior"].to(DEVICE, non_blocking=True)

        if USE_DIFFAUG:
            real_y = diffaug(real_y); real_s = diffaug(real_s)

        # -- G --
        # The discriminators only score images during the generator step.
        # Freezing them skips weight gradients that the D step would throw
        # away; gradients still flow through them to the generators.
        _set_requires_grad((D_y, D_s), False)
        opt_G.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=USE_AMP):
            fake_s = G_y2s(real_y); rec_y = G_s2y(fake_s)
            fake_y = G_s2y(real_s); rec_s = G_y2s(fake_y)
            idt_s  = G_y2s(real_s); idt_y = G_s2y(real_y)

            # Discriminator preds for adv + FM
            pred_s_fake, feats_s_fake = D_s(fake_s)
            pred_y_fake, feats_y_fake = D_y(fake_y)

            # adv (hinge, generator part)
            loss_adv = hinge_g_loss(pred_s_fake) + hinge_g_loss(pred_y_fake)

            # cycle / identity
            loss_cyc = LAMBDA_CYC*(cycle_loss(real_y, rec_y) + cycle_loss(real_s, rec_s))
            loss_idt = LAMBDA_IDT*(idt_loss(real_s, idt_s) + idt_loss(real_y, idt_y))

            # perceptual (optional): each translation is compared with its own
            # source image. The batches are unpaired, so an image from the
            # other domain shows unrelated content and cannot serve as a
            # feature target.
            if USE_PERCEPTUAL:
                loss_per = LAMBDA_PER*(perceptual_loss(fake_s, real_y) + perceptual_loss(fake_y, real_s))
            else:
                loss_per = torch.tensor(0.0, device=DEVICE)

        # FM needs real features → pass the real images through D as well.
        # They are used only as detached targets, so no graph is recorded.
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=USE_AMP):
            _, feats_s_real = D_s(real_s)
            _, feats_y_real = D_y(real_y)

        with torch.cuda.amp.autocast(enabled=USE_AMP):
            loss_fm = LAMBDA_FM*(feature_matching_loss(feats_s_real, feats_s_fake) +
                                 feature_matching_loss(feats_y_real, feats_y_fake))

            loss_G = loss_adv + loss_cyc + loss_idt + loss_per + loss_fm

        scaler_G.scale(loss_G).backward()
        scaler_G.step(opt_G); scaler_G.update()
        _set_requires_grad((D_y, D_s), True)

        ema_y2s.update(G_y2s); ema_s2y.update(G_s2y)

        # -- D --
        opt_D.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=USE_AMP):
            # Fakes come from the replay pools, which mix in older generations.
            fake_s_pool = pool_s.query(fake_s)
            fake_y_pool = pool_y.query(fake_y)

            pred_s_real, _ = D_s(real_s.detach())
            pred_s_fake, _ = D_s(fake_s_pool.detach())
            pred_y_real, _ = D_y(real_y.detach())
            pred_y_fake, _ = D_y(fake_y_pool.detach())

            loss_D = hinge_d_loss(pred_s_real, pred_s_fake) + hinge_d_loss(pred_y_real, pred_y_fake)

            # R1 is applied only every R1_INTERVAL steps.
            loss_R1 = torch.tensor(0.0, device=DEVICE)
            if USE_R1 and (global_step % R1_INTERVAL == 0):
                loss_R1 = (R1_LAMBDA * 0.5) * (r1_reg(D_s, real_s.detach()) + r1_reg(D_y, real_y.detach()))

            loss_D_total = loss_D + loss_R1

        scaler_D.scale(loss_D_total).backward()
        scaler_D.step(opt_D); scaler_D.update()

        # Accumulate running sums for logging
        running["G"]  += float(loss_G)
        running["D"]  += float(loss_D)
        running["cyc"]+= float(loss_cyc)
        running["idt"]+= float(loss_idt)
        running["per"]+= float(loss_per)
        running["fm"] += float(loss_fm)
        running["adv"]+= float(loss_adv)
        running["r1"] += float(loss_R1)
        n_steps+=1; global_step+=1

        if n_steps % LOG_EVERY_STEPS == 0:
            pbar.set_postfix(G=running["G"]/n_steps, D=running["D"]/n_steps, adv=running["adv"]/n_steps)

        if (MAX_STEPS_PER_EPOCH is not None) and (n_steps >= MAX_STEPS_PER_EPOCH):
            break

    # Turn the running sums into per-step averages for the epoch summary.
    for k in running: running[k]/=max(1,n_steps)
    print(f"[E{epoch+1:03d}] G={running['G']:.3f} D={running['D']:.3f} adv={running['adv']:.3f} "
          f"cyc={running['cyc']:.3f} idt={running['idt']:.3f} per={running['per']:.3f} "
          f"fm={running['fm']:.3f} r1={running['r1']:.3f} | {n_steps} steps | {time.time()-t0:.1f}s")

    sch_G.step(); sch_D.step()

    if (epoch+1)%SAVE_EVERY_EPOCHS==0:
        # Preview grid from the EMA generators on the fixed batch: input,
        # translation and cycle reconstruction for each direction.
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=USE_AMP):
            Gs = ema_y2s.ema.eval(); Gy = ema_s2y.ema.eval()
            f_s = Gs(fix_y); r_y = Gy(f_s)
            f_y = Gy(fix_s); r_s = Gs(f_y)
            grid = vutils.make_grid(torch.cat([denorm(fix_y),denorm(f_s),denorm(r_y),
                                               denorm(fix_s),denorm(f_y),denorm(r_s)],0), nrow=4)
            vutils.save_image(grid, SAMPLE_DIR/f"e{epoch+1:04d}.jpg")
        save_ckpt(epoch+1)

print("Training finished.")

# ===== Inference =====
@torch.no_grad()
def load_for_infer(ckpt_path=None):
    """Build a generator in eval mode and load young→senior weights from a checkpoint.

    When `ckpt_path` is None the last file (in sorted order) matching
    CKPT_DIR/e*.pt is used. EMA weights are preferred; the raw generator
    weights are the fallback.
    """
    g = ResnetGenerator().to(DEVICE).eval()
    if ckpt_path is None:
        candidates = list(CKPT_DIR.glob("e*.pt"))
        assert candidates, "체크포인트가 없습니다."
        ckpt_path = max(candidates)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    data = torch.load(ckpt_path, map_location=DEVICE)
    state = data.get("EMA_y2s", None)
    if not state:
        state = data["G_y2s"]
    g.load_state_dict(state)
    print("Loaded:", Path(ckpt_path).name)
    return g

# Inference preprocessing: plain resize to IMG_SIZE, then scale to [-1, 1].
infer_tf = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

@torch.no_grad()
def infer_single(input_path, out_path=None, ckpt_path=None):
    """Translate one image and save the result.

    The output goes to `out_path`, or to OUT_DIR/pred_<stem>.jpg when it is
    not given. Returns the path that was written.
    """
    G = load_for_infer(ckpt_path)
    src = Path(input_path)
    x = infer_tf(safe_open(src, IMG_SIZE)).unsqueeze(0).to(DEVICE, non_blocking=True)
    with torch.cuda.amp.autocast(enabled=USE_AMP):
        y = denorm(G(x))
    out_path = out_path or OUT_DIR/f"pred_{Path(input_path).stem}.jpg"
    vutils.save_image(y, out_path)
    print("Saved:", out_path)
    return out_path

@torch.no_grad()
def infer_dir(in_dir, out_dir=None, ckpt_path=None):
    """Translate every supported image in `in_dir`, writing <stem>_senior.jpg files.

    Results go to `out_dir`, or to OUT_DIR/inferred when it is not given; the
    folder is created if needed. The generator is loaded once for the whole run.
    """
    G = load_for_infer(ckpt_path)
    out_dir = Path(out_dir or OUT_DIR/"inferred")
    out_dir.mkdir(parents=True, exist_ok=True)
    paths = sorted(f for f in Path(in_dir).glob("*") if is_img(f))
    for src in paths:
        x = infer_tf(safe_open(src, IMG_SIZE)).unsqueeze(0).to(DEVICE, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=USE_AMP):
            y = denorm(G(x))
        vutils.save_image(y, out_dir/f"{src.stem}_senior.jpg")
    print(f"Saved {len(paths)} files to {out_dir}")
