In [None]:
# Optional Kaggle environment setup. Everything below is intentionally
# commented out; uncomment the lines to run them in a fresh Kaggle session.
# Step 1: wipe /kaggle/working, clone the project repository, install `lpips`,
#         and change into the cloned repository.
# !rm -rf /kaggle/working/*
# %cd /kaggle/working
# !git clone https://github.com/TAYDOVAT/Cuoi_Ki_DL.git
# !pip install lpips
# %cd /kaggle/working/Cuoi_Ki_DL

# Step 2: remove the test/train/val folders inside the cloned repository.
# !rm -r /kaggle/working/Cuoi_Ki_DL/test
# !rm -r /kaggle/working/Cuoi_Ki_DL/train
# !rm -r /kaggle/working/Cuoi_Ki_DL/val

# Step 3: copy the test/train/val folders from the Kaggle input dataset
#         into the repository root.
# !cp -r "/kaggle/input/anh-ve-tinh/Ảnh vệ tinh/test" /kaggle/working/Cuoi_Ki_DL
# !cp -r "/kaggle/input/anh-ve-tinh/Ảnh vệ tinh/train" /kaggle/working/Cuoi_Ki_DL
# !cp -r "/kaggle/input/anh-ve-tinh/Ảnh vệ tinh/val" /kaggle/working/Cuoi_Ki_DL

# Train SRGAN x4

In [None]:
import os
import csv
import random
import torch
from pathlib import Path
from torch import optim
from torch.optim import lr_scheduler
from tqdm.auto import tqdm
from IPython.display import clear_output

from data import build_loader
from original_model import SRResNet, DiscriminatorForVGG
from losses import PixelLoss, PerceptualLoss, AdversarialLoss
from engine import (
    train_gan_epoch, val_gan_epoch,
    save_gan_checkpoint, load_gan_checkpoint,
    load_gan_history_from_log, rewrite_log_up_to_epoch
)
from vis import show_lr_sr_hr, plot_curves
import lpips

In [None]:
# Experiment settings -- edit the values in this cell to override them.
cfg = {
    'scale': 4,
    'hr_crop': 96,
    'gan': {
        'batch_size': 32,
        'num_workers': 4,
        'epochs': 300,
        'lr_g': 1e-5,           # Learning rate for the Generator
        'lr_d': 1e-5,           # Learning rate for the Discriminator
        'adv_weight': 5e-3,     # Adversarial loss weight
        'perc_weight': 1,       # Perceptual loss weight
        'pixel_weight': 0,      # Pixel loss weight
        'r1_weight': 10.0,      # R1 gradient penalty
        'd_steps': 1,           # Number of Discriminator steps per iteration
        'g_steps': 2,           # Number of Generator steps per iteration
        # ========== RESUME CONFIG ==========
        'resume': True,         # True: resume training, False: train from scratch
        'load_disc': True,      # True: also load the Discriminator, False: only load the Generator
        'checkpoint_path': 'weights/gan_checkpoint.pth',  # Path to checkpoint
    },
    # Relative placeholders; replaced by absolute paths once the dataset root is found.
    'paths': {
        f'{split}_{res}': f'{split}/{split}_{res}'
        for split in ('train', 'val', 'test')
        for res in ('lr', 'hr')
    },
}

# Locate the dataset root: the nearest directory, starting at the current
# working directory and walking up, that contains 'train/train_lr'.
cwd = Path.cwd().resolve()
base_dir = next(
    (
        candidate
        for candidate in (cwd, *cwd.parents)
        if (candidate / 'train' / 'train_lr').is_dir()
    ),
    None,
)
if base_dir is None:
    raise FileNotFoundError(f"Cannot find 'train/train_lr' from cwd: {cwd}")

# Point every dataset path at the discovered root.
cfg['paths'].update({
    f'{split}_{res}': str(base_dir / split / f'{split}_{res}')
    for split in ('train', 'val', 'test')
    for res in ('lr', 'hr')
})

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Output directory for checkpoints and weights.
Path('weights').mkdir(parents=True, exist_ok=True)

In [None]:
# Settings shared by the training and validation loaders.
loader_common = {
    'scale': cfg['scale'],
    'hr_crop': cfg['hr_crop'],
    'num_workers': cfg['gan']['num_workers'],
}

# Training split: uses the configured GAN batch size.
train_dataset, train_loader = build_loader(
    cfg['paths']['train_lr'],
    cfg['paths']['train_hr'],
    batch_size=cfg['gan']['batch_size'],
    train=True,
    **loader_common,
)

# Validation split: fixed batch size of 8.
val_dataset, val_loader = build_loader(
    cfg['paths']['val_lr'],
    cfg['paths']['val_hr'],
    batch_size=8,
    train=False,
    **loader_common,
)


In [None]:
# ==================== Initialize Models ====================
generator = SRResNet(upscale=cfg['scale']).to(device)
discriminator = DiscriminatorForVGG().to(device)

# ==================== Initialize Optimizers & Schedulers ====================
optimizer_g = optim.Adam(generator.parameters(), lr=cfg['gan']['lr_g'])
optimizer_d = optim.Adam(discriminator.parameters(), lr=cfg['gan']['lr_d'])
# NOTE(review): the training loop steps these schedulers once per epoch, so
# with step_size=100000 the learning rate never decays within the configured
# number of epochs -- confirm that a constant learning rate is intended.
scheduler_g = lr_scheduler.StepLR(optimizer_g, step_size=100000, gamma=0.5)
scheduler_d = lr_scheduler.StepLR(optimizer_d, step_size=100000, gamma=0.5)

# ==================== Loss Criteria ====================
pixel_criterion = PixelLoss().to(device)
perceptual_criterion = PerceptualLoss().to(device)
adversarial_criterion = AdversarialLoss().to(device)
lpips_metric = lpips.LPIPS(net='vgg').to(device)

# Weights passed to the train/val epoch helpers together with the criteria.
weights = {
    'pixel': cfg['gan']['pixel_weight'],
    'perceptual': cfg['gan']['perc_weight'],
    'adversarial': cfg['gan']['adv_weight'],
}

# ==================== Resume or Fresh Start ====================
log_path = os.path.join('logs', 'gan_log.csv')
# Create the log directory up front for BOTH branches: the training loop
# appends to log_path every epoch, and a resumed run started without an
# existing 'logs' folder would otherwise fail on that write unless one of the
# engine helpers happens to create the folder.
os.makedirs('logs', exist_ok=True)

if cfg['gan']['resume']:
    # RESUME: load the checkpoint with all saved states
    start_epoch, best_lpips = load_gan_checkpoint(
        generator=generator,
        discriminator=discriminator,
        optimizer_g=optimizer_g,
        optimizer_d=optimizer_d,
        scheduler_g=scheduler_g,
        scheduler_d=scheduler_d,
        path=cfg['gan']['checkpoint_path'],
        load_disc=cfg['gan']['load_disc'],
        device=device
    )
    # Load the metric history from the log file
    history = load_gan_history_from_log(log_path, start_epoch)
    # Rewrite the log file so it is clean (avoids duplicate rows)
    rewrite_log_up_to_epoch(log_path, history, start_epoch)
else:
    # FRESH START: initialise the Generator from the pretrained SRResNet weights
    start_epoch = 1
    best_lpips = 100.0  # sentinel; replaced when an epoch's validation LPIPS is lower
    generator.load_state_dict(torch.load('weights/best_srresnet.pth', map_location=device))
    print("[INFO] Loaded Generator from 'weights/best_srresnet.pth'")
    print("[INFO] Initialized fresh Discriminator")
    
    # Empty history: one train list and one val list per tracked metric
    history = {
        'loss_g': {'train': [], 'val': []},
        'loss_d': {'train': [], 'val': []},
        'psnr': {'train': [], 'val': []},
        'ssim': {'train': [], 'val': []},
        'lpips': {'train': [], 'val': []},
        'd_real_prob': {'train': [], 'val': []},
        'd_fake_prob': {'train': [], 'val': []},
    }
    # Start a new log file containing only the header row
    with open(log_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([
            'epoch', 'train_loss_g', 'val_loss_g', 'train_loss_d', 'val_loss_d',
            'train_d_real_prob', 'val_d_real_prob', 'train_d_fake_prob', 'val_d_fake_prob',
            'train_psnr', 'val_psnr', 'train_ssim', 'val_ssim', 'train_lpips', 'val_lpips',
        ])

print(f"\n{'='*50}")
print(f"Starting from epoch {start_epoch}, best LPIPS: {best_lpips:.4f}")
print(f"Resume: {cfg['gan']['resume']}, Load Disc: {cfg['gan']['load_disc']}")
print(f"{'='*50}")

In [None]:
epochs = cfg['gan']['epochs']

# Metrics tracked per epoch, in the column order of the CSV log header
# (each metric contributes a train column followed by a val column).
metric_keys = ('loss_g', 'loss_d', 'd_real_prob', 'd_fake_prob', 'psnr', 'ssim', 'lpips')

for epoch in range(start_epoch, epochs + 1):
    # ==================== Training ====================
    train_pbar = tqdm(train_loader, desc=f'Epoch {epoch}/{epochs} [Train]')
    train_stats = train_gan_epoch(
        generator, discriminator, train_pbar,
        optimizer_g, optimizer_d, device,
        pixel_criterion, perceptual_criterion, adversarial_criterion,
        weights,
        lpips_metric=lpips_metric,
        g_steps=cfg['gan'].get('g_steps', 1),
        d_steps=cfg['gan'].get('d_steps', 1),
        r1_weight=cfg['gan'].get('r1_weight', 0.0),
    )

    # ==================== Validation ====================
    val_pbar = tqdm(val_loader, desc=f'Epoch {epoch}/{epochs} [Val]')
    val_stats = val_gan_epoch(
        generator, discriminator, val_pbar, device,
        pixel_criterion, perceptual_criterion, adversarial_criterion,
        weights,
        lpips_metric=lpips_metric
    )

    # ==================== Step Schedulers ====================
    # Schedulers advance once per epoch (not once per batch).
    scheduler_g.step()
    scheduler_d.step()

    # ==================== Update History ====================
    for key in metric_keys:
        history[key]['train'].append(train_stats[key])
        history[key]['val'].append(val_stats[key])

    # ==================== Append to Log ====================
    log_row = [epoch]
    for key in metric_keys:
        log_row.extend((train_stats[key], val_stats[key]))
    with open(log_path, 'a', newline='') as f:
        csv.writer(f).writerow(log_row)

    # ==================== Save Best Model ====================
    # Done BEFORE the checkpoints below so the best_lpips stored in them
    # already includes this epoch. Otherwise a run resumed right after a best
    # epoch would restore a stale best value and could overwrite best_gan.pth
    # with a worse model.
    is_new_best = val_stats['lpips'] < best_lpips
    if is_new_best:
        best_lpips = val_stats['lpips']
        torch.save(generator.state_dict(), 'weights/best_gan.pth')
        torch.save(discriminator.state_dict(), 'weights/best_disc.pth')

    # ==================== Save Checkpoint (every epoch) ====================
    epoch_dir = os.path.join('weights', f'epoch_{epoch}')
    os.makedirs(epoch_dir, exist_ok=True)

    # Rolling checkpoint for resume
    save_gan_checkpoint(
        generator=generator,
        discriminator=discriminator,
        optimizer_g=optimizer_g,
        optimizer_d=optimizer_d,
        scheduler_g=scheduler_g,
        scheduler_d=scheduler_d,
        epoch=epoch,
        best_lpips=best_lpips,
        path=cfg['gan']['checkpoint_path']
    )

    # Epoch snapshot checkpoint
    save_gan_checkpoint(
        generator=generator,
        discriminator=discriminator,
        optimizer_g=optimizer_g,
        optimizer_d=optimizer_d,
        scheduler_g=scheduler_g,
        scheduler_d=scheduler_d,
        epoch=epoch,
        best_lpips=best_lpips,
        path=os.path.join(epoch_dir, f'gan_checkpoint_{epoch}.pth')
    )

    # Save individual weights (for compatibility)
    torch.save(generator.state_dict(), 'weights/last_gan.pth')
    torch.save(discriminator.state_dict(), 'weights/last_disc.pth')

    # Save individual weights per epoch
    torch.save(generator.state_dict(), os.path.join(epoch_dir, f'generator_{epoch}.pth'))
    torch.save(discriminator.state_dict(), os.path.join(epoch_dir, f'discriminator_{epoch}.pth'))

    # ==================== Visualization ====================
    clear_output(wait=True)

    # Super-resolve one random validation sample for a visual check.
    rand_idx = random.randint(0, len(val_dataset) - 1)
    lr_sample, hr_sample = val_dataset[rand_idx]
    lr_in = lr_sample.unsqueeze(0).to(device)
    # Run the preview in eval mode in case the generator was left in train
    # mode, then restore whatever mode it was in before.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            sr_sample = generator(lr_in).cpu()
    finally:
        generator.train(was_training)
    show_lr_sr_hr(lr_sample, sr_sample, hr_sample)

    plot_curves(history)

    # Print info
    print(f"Epoch {epoch}/{epochs} | LR_G: {scheduler_g.get_last_lr()[0]:.6f} | LR_D: {scheduler_d.get_last_lr()[0]:.6f}")
    print(f"Best LPIPS: {best_lpips:.4f}")
    # Printed after clear_output so the message is not wiped immediately.
    if is_new_best:
        print(f"[NEW BEST] LPIPS: {best_lpips:.4f}")

print("\n" + "="*50)
print("GAN Training Completed!")
print(f"Best LPIPS: {best_lpips:.4f}")
print("="*50)

