# SRGAN Training on CelebA / DIV2K (PixelForge)
This notebook trains a Super-Resolution GAN (SRGAN) using the preprocessed LR/HR pairs.
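It assumes the dataset has been preprocessed into sharded LR/HR tensors and split into train/val. The setup cell looks for them first in `notebooks/outputs/final_tensors_sharded`, then in `data/celeba/final_tensors_sharded_fast`. Only the CelebA loader is wired up below; DIV2K data under `data/div2k/...` is not loaded unless you switch the dataset class in section 2.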


1. 环境 & 导入

In [12]:
import os
from pathlib import Path
import time

import torch
from torch import nn, optim
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

# Project root detection: starting from the current working directory, walk
# upward until a directory containing src/CelebASRDataset.py is found.
current_dir = Path(".").resolve()
BASE_DIR = current_dir

max_levels = 3  # inspect the working directory plus at most 3 ancestors
for test_dir in [current_dir, *current_dir.parents][: max_levels + 1]:
    if (test_dir / "src" / "CelebASRDataset.py").exists():
        BASE_DIR = test_dir
        break
else:
    # Falling back silently would only surface later as a confusing
    # ImportError in the dataset cell, so make the fallback visible here.
    print(
        f"WARNING: src/CelebASRDataset.py not found within {max_levels} levels above "
        f"{current_dir}; using the current working directory as project root"
    )

# The tensor shards may live in one of two places, checked in priority order:
#   1. notebooks/outputs/final_tensors_sharded   (requires manifest_train.json)
#   2. data/celeba/final_tensors_sharded_fast    (standard project location)
notebooks_data_dir = BASE_DIR / "notebooks" / "outputs" / "final_tensors_sharded"
standard_data_dir = BASE_DIR / "data" / "celeba" / "final_tensors_sharded_fast"

if (notebooks_data_dir / "manifest_train.json").exists():
    DATA_DIR = notebooks_data_dir.parent.parent  # resolves to BASE_DIR / "notebooks"
    TENSOR_DIR = notebooks_data_dir
elif standard_data_dir.exists():
    DATA_DIR = BASE_DIR / "data"
    TENSOR_DIR = standard_data_dir
else:
    # Neither location was found: default to the notebooks/outputs layout.
    # "Tensor dir exists" printed below will then report False.
    DATA_DIR = BASE_DIR / "notebooks"
    TENSOR_DIR = notebooks_data_dir

# Checkpoints, samples and logs for this run are written under EXP_DIR.
EXP_DIR = BASE_DIR / "experiments" / "celeba_srgan_run1"
EXP_DIR.mkdir(parents=True, exist_ok=True)

print("Current working dir:", Path(".").resolve())
print("Base dir (project root):", BASE_DIR)
print("Tensor data dir:", TENSOR_DIR)
print("Tensor dir exists:", TENSOR_DIR.exists())
print("Src dir exists:", (BASE_DIR / "src").exists())
print("Exp dir:", EXP_DIR)

# Seed torch's RNG so weight initialisation and DataLoader shuffling are
# repeatable after Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


Current working dir: /Users/liangwenlong/study/bme/3/DeepLearning/Topic/SGN/SuperResolution_Project/notebooks
Base dir (project root): /Users/liangwenlong/study/bme/3/DeepLearning/Topic/SGN/SuperResolution_Project
Tensor data dir: /Users/liangwenlong/study/bme/3/DeepLearning/Topic/SGN/SuperResolution_Project/notebooks/outputs/final_tensors_sharded
Tensor dir exists: True
Src dir exists: True
Exp dir: /Users/liangwenlong/study/bme/3/DeepLearning/Topic/SGN/SuperResolution_Project/experiments/celeba_srgan_run1
Device: cpu


2. Dataset 加载（用前面写过的 src/）

In [14]:
import sys
import importlib

# Make the project's src/ directory importable.
src_path = BASE_DIR / "src"
if str(src_path) not in sys.path:
    sys.path.insert(0, str(src_path))

print(f"Added to sys.path: {src_path}")
print(f"Src dir exists: {src_path.exists()}")

# Drop any previously imported copy so that edits to the module are picked up
# when this cell is re-run.
if 'CelebASRDataset' in sys.modules:
    del sys.modules['CelebASRDataset']
    print("Cleared CelebASRDataset from sys.modules")

# Remove the bytecode cache next to the sources before importing.
import shutil
pycache_path = src_path / "__pycache__"
if pycache_path.exists():
    shutil.rmtree(pycache_path)
    print("Cleared __pycache__")

# Import the dataset class. On failure, report whether the source file and the
# class definition are present, then re-raise the original ImportError.
try:
    from CelebASRDataset import CelebASRDataset
    print("✓ Successfully imported CelebASRDataset")
except ImportError as e:
    print(f"✗ Import error: {e}")
    print("Trying to inspect the module...")
    celeb_file = src_path / "CelebASRDataset.py"
    if celeb_file.exists():
        print(f"File exists: {celeb_file}")
        # Explicit encoding with errors="replace": this read is diagnostic only
        # and must never mask the ImportError with a decoding error.
        content = celeb_file.read_text(encoding="utf-8", errors="replace")
        if 'class CelebASRDataset' in content:
            print("✓ Class definition found in file")
        else:
            print("✗ Class definition NOT found in file")
    else:
        print(f"✗ File not found: {celeb_file}")
    raise

# If your file name differs, change the import here.
# from DIV2KSRDataset import DIV2KSRDataset

# Use the tensor directory detected by the setup cell.
print(f"\nUsing tensor data dir: {TENSOR_DIR}")
print(f"Tensor dir exists: {TENSOR_DIR.exists()}")

# Pick the dataset to train on.
train_dataset = CelebASRDataset(
    root_dir=TENSOR_DIR,
    split="train"
)

val_dataset = CelebASRDataset(
    root_dir=TENSOR_DIR,
    split="val"
)

BATCH_SIZE = 16
NUM_WORKERS = 2
# Pinned host memory only speeds up host-to-GPU copies; on a CPU-only machine
# it just produces a warning, so enable it only when training on CUDA.
PIN_MEMORY = device.type == "cuda"

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

print("\n✓ Datasets loaded successfully!")
print("Train batches:", len(train_loader))
print("Val batches:", len(val_loader))


Added to sys.path: /Users/liangwenlong/study/bme/3/DeepLearning/Topic/SGN/SuperResolution_Project/src
Src dir exists: True
Cleared CelebASRDataset from sys.modules
✓ Successfully imported CelebASRDataset

Using tensor data dir: /Users/liangwenlong/study/bme/3/DeepLearning/Topic/SGN/SuperResolution_Project/notebooks/outputs/final_tensors_sharded
Tensor dir exists: True

✓ Datasets loaded successfully!
Train batches: 11397
Val batches: 634


In [15]:
# Sanity-check the paths resolved by the setup cell.
import os

banner = "=" * 60
src_dir = BASE_DIR / "src"

print(banner)
print("Path Verification:")
print(banner)
print(f"BASE_DIR: {BASE_DIR}")
print(f"BASE_DIR exists: {BASE_DIR.exists()}")

print(f"\nSrc directory: {src_dir}")
src_found = src_dir.exists()
print(f"Src exists: {src_found}")
if src_found:
    print(f"Files in src: {os.listdir(src_dir)}")

print(f"\nTensor data directory: {TENSOR_DIR}")
tensor_found = TENSOR_DIR.exists()
print(f"Tensor dir exists: {tensor_found}")
if tensor_found:
    # Only the manifest_* files are listed.
    manifest_files = [name for name in os.listdir(TENSOR_DIR) if name.startswith("manifest_")]
    print(f"Manifest files: {manifest_files}")
print(banner)


Path Verification:
BASE_DIR: /Users/liangwenlong/study/bme/3/DeepLearning/Topic/SGN/SuperResolution_Project
BASE_DIR exists: True

Src directory: /Users/liangwenlong/study/bme/3/DeepLearning/Topic/SGN/SuperResolution_Project/src
Src exists: True
Files in src: ['transforms.py', '.DS_Store', 'CelebASRDataset.py', '__pycache__', 'dataset.py', 'DIV2KSRDataset.py']

Tensor data directory: /Users/liangwenlong/study/bme/3/DeepLearning/Topic/SGN/SuperResolution_Project/notebooks/outputs/final_tensors_sharded
Tensor dir exists: True
Manifest files: ['manifest_test.json', 'manifest_train.json', 'manifest_val.json']


3. 模型占位（Generator / Discriminator）

In [None]:
# ====== Generator (placeholder) ======
class Generator(nn.Module):
    """Placeholder generator: conv(9x9) -> PReLU -> conv(3x3).

    Every layer uses stride 1 with "same" padding, so the output has the same
    spatial size as the input; there is no upsampling stage.
    NOTE(review): the training loop compares the output to `hr` with an L1
    loss, so this presumably relies on LR and HR tensors having the same size.
    Confirm against the dataset.
    """

    def __init__(self):
        super().__init__()
        # TODO: replace with Yosr's SRGAN generator
        feature_conv = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        activation = nn.PReLU()
        output_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        # Layer order fixes the state_dict keys (net.0, net.1, net.2).
        self.net = nn.Sequential(feature_conv, activation, output_conv)

    def forward(self, x):
        """Run `x` through the three-layer stack and return the result."""
        out = self.net(x)
        return out

# ====== Discriminator (placeholder) ======
class Discriminator(nn.Module):
    """Placeholder discriminator producing one raw logit per image.

    Four stride-2 3x3 convolutions (3 -> 64 -> 128 -> 256 -> 512 channels)
    each halve the spatial size; BatchNorm follows every convolution except
    the first. Global average pooling then collapses whatever spatial size
    remains, so the final linear layer does not depend on the input size.

    Args:
        hr_size: kept in the signature but never read by this implementation,
            because the adaptive pooling removes the need for a fixed size.
    """

    def __init__(self, hr_size=128):
        super().__init__()
        # TODO: replace with Yosr's SRGAN discriminator
        # First stage has no BatchNorm. For a 128x128 input: 128 -> 64.
        stages = [
            nn.Conv2d(3, 64, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Remaining stages for a 128x128 input: 64 -> 32 -> 16 -> 8.
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            stages.append(nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
        # Head: (batch, 512, H, W) -> (batch, 512, 1, 1) -> (batch, 512) -> (batch, 1)
        stages.extend([
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, 1),
        ])
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return raw logits of shape (batch, 1) for the image batch `x`."""
        logits = self.net(x)
        return logits

# Build both networks (generator first) and move them to the selected device.
generator = Generator().to(device)
discriminator = Discriminator().to(device)


def _param_count_millions(model):
    """Return the total number of parameters of `model`, in millions."""
    return sum(p.numel() for p in model.parameters()) / 1e6


print("G params:", _param_count_millions(generator), "M")
print("D params:", _param_count_millions(discriminator), "M")


G params: 0.017348 M
D params: 0.018177 M


4. 损失函数 & 优化器
SRGAN 标配是 content loss (VGG) + adversarial loss。现在先放一个简化版，能跑就行。

In [17]:
# Simplified objectives: BCE-with-logits for the adversarial term and L1 as a
# stand-in for the content loss.
bce_loss = nn.BCEWithLogitsLoss().to(device)
l1_loss = nn.L1Loss().to(device)

# Both networks use Adam with identical hyper-parameters.
adam_settings = {"lr": 1e-4, "betas": (0.9, 0.999)}
g_optimizer = optim.Adam(generator.parameters(), **adam_settings)
d_optimizer = optim.Adam(discriminator.parameters(), **adam_settings)


In [None]:
# 5. 训练循环（简化版）

In [18]:
from collections import defaultdict

num_epochs = 5
sample_every = 200  # save one sample image every this many steps
log_every = 50      # print the current losses every this many steps

# Per-step loss values, keyed by "d_loss" / "g_loss".
history = defaultdict(list)

# The sample directory does not change during training, so create it once.
out_dir = EXP_DIR / "samples"
out_dir.mkdir(parents=True, exist_ok=True)

for epoch in range(num_epochs):
    generator.train()
    discriminator.train()
    for step, batch in enumerate(train_loader):
        # The dataset is expected to yield (lr, hr) pairs.
        lr, hr = batch
        lr = lr.to(device)
        hr = hr.to(device)

        # One generator forward pass per step. The discriminator update uses a
        # detached view and the generator update reuses the same graph. The
        # generator weights do not change in between, so this matches running
        # the forward pass twice, at half the cost.
        # NOTE(review): the L1 loss below presumably needs `sr` and `hr` to
        # have the same shape. Confirm once the real generator is plugged in.
        sr = generator(lr)

        # ---------------------
        # 1) Train Discriminator
        # ---------------------
        real_logits = discriminator(hr)
        fake_logits = discriminator(sr.detach())  # no gradient into the generator

        real_labels = torch.ones_like(real_logits)
        fake_labels = torch.zeros_like(fake_logits)

        d_loss_real = bce_loss(real_logits, real_labels)
        d_loss_fake = bce_loss(fake_logits, fake_labels)
        d_loss = (d_loss_real + d_loss_fake) * 0.5

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---------------------
        # 2) Train Generator
        # ---------------------
        # Score the fakes again with the freshly updated discriminator.
        fake_logits = discriminator(sr)
        adv_loss = bce_loss(fake_logits, torch.ones_like(fake_logits))
        content_loss = l1_loss(sr, hr)
        g_loss = content_loss + 1e-3 * adv_loss

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # log
        history["d_loss"].append(d_loss.item())
        history["g_loss"].append(g_loss.item())

        if (step + 1) % log_every == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] Step [{step+1}/{len(train_loader)}] "
                  f"D: {d_loss.item():.4f}  G: {g_loss.item():.4f}")

        # Save a visual sample of the first image in the batch.
        if (step + 1) % sample_every == 0:
            # No de-normalisation is applied (it depends on the dataset
            # transforms). Values are clamped to [0, 1], the range imshow
            # displays for float RGB, so the clipping is explicit.
            grid = sr[0].detach().cpu().clamp(0.0, 1.0).permute(1, 2, 0).numpy()
            fig, ax = plt.subplots()
            ax.imshow(grid)
            ax.set_title(f"epoch{epoch+1}_step{step+1}")
            ax.axis("off")
            fig.savefig(out_dir / f"e{epoch+1}_s{step+1}.png")
            plt.close(fig)  # release the figure so memory does not grow

    # Save the weights of both networks once per epoch.
    torch.save(generator.state_dict(), EXP_DIR / f"generator_epoch{epoch+1}.pth")
    torch.save(discriminator.state_dict(), EXP_DIR / f"discriminator_epoch{epoch+1}.pth")




RuntimeError: mat1 and mat2 shapes cannot be multiplied (16x262144 and 16384x1)

6. 验证 / PSNR / SSIM（可选）
这里的数据范围我写了 data_range=1.0，如果是 0–255 记得改回 255。

In [None]:
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def evaluate_psnr_ssim(model, dataloader, max_batches=10, data_range=1.0):
    """Compute mean PSNR and SSIM of `model` over the first batches of `dataloader`.

    Every image in each evaluated batch contributes to the averages.

    Args:
        model: network mapping an LR batch to an SR batch.
        dataloader: iterable yielding (lr, hr) tensor pairs; tensors are
            indexed as (batch, channel, height, width).
        max_batches: maximum number of batches to evaluate.
        data_range: value range of the images, passed to both metrics
            (1.0 for [0, 1] data, 255 for 8-bit data).

    Returns:
        Tuple (mean_psnr, mean_ssim) over all evaluated images.

    Raises:
        ValueError: if no image was evaluated (empty loader or max_batches <= 0).
    """
    was_training = model.training
    model.eval()
    psnrs, ssims = [], []
    try:
        with torch.no_grad():
            for i, (lr, hr) in enumerate(dataloader):
                if i >= max_batches:
                    break
                sr = model(lr.to(device))

                # (B, C, H, W) -> (B, H, W, C) numpy arrays for skimage.
                sr_np = sr.detach().cpu().permute(0, 2, 3, 1).numpy()
                hr_np = hr.detach().cpu().permute(0, 2, 3, 1).numpy()

                for hr_img, sr_img in zip(hr_np, sr_np):
                    psnrs.append(peak_signal_noise_ratio(hr_img, sr_img, data_range=data_range))
                    ssims.append(structural_similarity(hr_img, sr_img, channel_axis=2,
                                                       data_range=data_range))
    finally:
        # Put the model back in the mode it was in before evaluation.
        model.train(was_training)

    if not psnrs:
        raise ValueError("evaluate_psnr_ssim: no images were evaluated "
                         "(empty dataloader or max_batches <= 0)")

    return sum(psnrs)/len(psnrs), sum(ssims)/len(ssims)

# Quick validation pass limited to the first 5 validation batches.
psnr, ssim = evaluate_psnr_ssim(generator, val_loader, max_batches=5)
print(f"Val PSNR: {psnr}  Val SSIM: {ssim}")


7. 保存训练日志

In [None]:
import json

# Persist the per-step loss history as JSON inside the experiment directory.
log_path = EXP_DIR / "train_log.json"
log_path.write_text(json.dumps(history))

print("Saved train log to", log_path)
