In [None]:
# !rm -rf /kaggle/working/*
# %cd /kaggle/working
# !git clone https://github.com/TAYDOVAT/Cuoi_Ki_DL.git
# %cd /kaggle/working/Cuoi_Ki_DL

# !rm -r /kaggle/working/Cuoi_Ki_DL/test
# !rm -r /kaggle/working/Cuoi_Ki_DL/train
# !rm -r /kaggle/working/Cuoi_Ki_DL/val

# !cp -r "/kaggle/input/anh-ve-tinh/Ảnh vệ tinh/test" /kaggle/working/Cuoi_Ki_DL
# !cp -r "/kaggle/input/anh-ve-tinh/Ảnh vệ tinh/train" /kaggle/working/Cuoi_Ki_DL
# !cp -r "/kaggle/input/anh-ve-tinh/Ảnh vệ tinh/val" /kaggle/working/Cuoi_Ki_DL

# Train SRGAN x4

In [1]:
import os
import csv
import random
import torch
from pathlib import Path
from torch import optim
from torch.optim import lr_scheduler
from tqdm.auto import tqdm
from IPython.display import clear_output

from data import build_loader
from original_model import SRResNet, DiscriminatorForVGG
from losses import PixelLoss, PerceptualLoss, AdversarialLoss
from engine import train_gan_epoch, val_gan_epoch
from vis import show_lr_sr_hr, plot_curves


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Config override here.
# Every value a reader might tune lives in `cfg`; later cells only read from it.
# The 'paths' entries are placeholders: below they are re-rooted under the first
# ancestor of the working directory that contains `train/train_lr`, so the
# notebook works whether it is launched from the repo root or a subfolder.
cfg = {
    'seed': 42,
    'scale': 4,
    'hr_crop': 96,
    'gan': {
        'batch_size': 32,
        'num_workers': 4,
        'epochs': 200,
        # NOTE(review): lr_g=1e-6 / lr_d=1e-8 are very small; with lr_d=1e-8 the
        # discriminator barely updates. Confirm this is an intentional,
        # ultra-conservative fine-tuning schedule.
        'lr_g': 1e-6,
        'lr_d': 1e-8,
        'adv_weight': 1e-3,
        'perc_weight': 1,
        'pixel_weight': 1,
        'load_d': True,
        'd_ckpt': 'weights/DiscriminatorForVGG_x4-ImageNet.pth.tar',
    },
    'paths': {
        'train_lr': 'train/train_lr',
        'train_hr': 'train/train_hr',
        'val_lr': 'val/val_lr',
        'val_hr': 'val/val_hr',
        'test_lr': 'test/test_lr',
        'test_hr': 'test/test_hr',
    },
}

# Seed Python's and torch's RNGs so reruns are reproducible (the validation
# preview uses `random`; data shuffling/augmentation presumably uses torch's RNG
# inside build_loader). manual_seed_all is a silent no-op without CUDA.
random.seed(cfg['seed'])
torch.manual_seed(cfg['seed'])
torch.cuda.manual_seed_all(cfg['seed'])

# Walk up from the working directory to find the dataset root.
cwd = Path.cwd().resolve()
base_dir = next(
    (parent for parent in (cwd, *cwd.parents) if (parent / 'train' / 'train_lr').is_dir()),
    None,
)
if base_dir is None:
    raise FileNotFoundError(f"Cannot find 'train/train_lr' from cwd: {cwd}")

# Re-root every '<split>_<res>' entry as '<base_dir>/<split>/<split>_<res>'.
for split in ('train', 'val', 'test'):
    for res in ('lr', 'hr'):
        cfg['paths'][f'{split}_{res}'] = str(base_dir / split / f'{split}_{res}')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Checkpoints are read from / written to 'weights/' relative to the working directory.
os.makedirs('weights', exist_ok=True)


In [None]:
# Build the training and validation datasets/loaders from the paths resolved in
# the config cell. `train=True/False` is forwarded to the project's build_loader
# (presumably toggling augmentation and shuffling; see data.py).
train_dataset, train_loader = build_loader(
    cfg['paths']['train_lr'], cfg['paths']['train_hr'],
    scale=cfg['scale'], hr_crop=cfg['hr_crop'],
    batch_size=cfg['gan']['batch_size'],
    num_workers=cfg['gan']['num_workers'],
    train=True,
)
# The validation batch size used to be a hard-coded 32; it can now be set via
# cfg['gan']['val_batch_size'] and keeps 32 as its default.
val_dataset, val_loader = build_loader(
    cfg['paths']['val_lr'], cfg['paths']['val_hr'],
    scale=cfg['scale'], hr_crop=cfg['hr_crop'],
    batch_size=cfg['gan'].get('val_batch_size', 32),
    num_workers=cfg['gan']['num_workers'],
    train=False,
)


In [4]:
def _load_state_dict(path, map_location):
    """Load a checkpoint and return a state dict whose keys match an unwrapped model.

    Args:
        path: checkpoint file path.
        map_location: forwarded to `torch.load` (e.g. the target device).

    Returns:
        The state dict. If the file holds a dict with a 'state_dict' entry (as the
        `.pth.tar` discriminator checkpoint does), that entry is used. Every
        '_orig_mod.' fragment (added by `torch.compile`) is removed, and a leading
        'module.' is stripped when *all* keys carry it (DataParallel/DDP wrapping).
        When no key changes, the loaded object is returned as-is.
    """
    ckpt = torch.load(path, map_location=map_location)
    state = ckpt['state_dict'] if isinstance(ckpt, dict) and 'state_dict' in ckpt else ckpt
    if not isinstance(state, dict) or not state:
        return state

    old_keys = list(state)
    new_keys = [key.replace('_orig_mod.', '') for key in old_keys]
    if all(key.startswith('module.') for key in new_keys):
        new_keys = [key[len('module.'):] for key in new_keys]
    if new_keys == old_keys:
        # Nothing to rename: keep the original object (and any `_metadata` it carries).
        return state
    return dict(zip(new_keys, state.values()))


# Generator is initialised from the pretrained SRResNet (PSNR-oriented) weights.
generator = SRResNet(upscale=cfg['scale']).to(device)
generator.load_state_dict(_load_state_dict('weights/best_srresnet.pth', device))

discriminator = DiscriminatorForVGG().to(device)
if cfg['gan'].get('load_d', False):
    disc_path = cfg['gan'].get('d_ckpt', 'weights/DiscriminatorForVGG_x4-ImageNet.pth.tar')
    discriminator.load_state_dict(_load_state_dict(disc_path, device))

pixel_criterion = PixelLoss().to(device)
perceptual_criterion = PerceptualLoss().to(device)
adversarial_criterion = AdversarialLoss().to(device)

optimizer_g = optim.Adam(generator.parameters(), lr=cfg['gan']['lr_g'])
optimizer_d = optim.Adam(discriminator.parameters(), lr=cfg['gan']['lr_d'])

# Both learning rates decay by `lr_gamma` every `lr_step` epochs (defaults: halve every 50).
lr_step = cfg['gan'].get('lr_step', 50)
lr_gamma = cfg['gan'].get('lr_gamma', 0.5)
scheduler_g = lr_scheduler.StepLR(optimizer_g, step_size=lr_step, gamma=lr_gamma)
scheduler_d = lr_scheduler.StepLR(optimizer_d, step_size=lr_step, gamma=lr_gamma)

# Loss-term weights consumed by train_gan_epoch / val_gan_epoch.
weights = {
    'pixel': cfg['gan']['pixel_weight'],
    'perceptual': cfg['gan']['perc_weight'],
    'adversarial': cfg['gan']['adv_weight'],
}


In [5]:
# Metrics reported per epoch by train_gan_epoch / val_gan_epoch. This single
# tuple fixes both the `history` key order and the CSV column order
# (epoch, train_loss_g, val_loss_g, train_loss_d, val_loss_d, ...).
METRICS = ('loss_g', 'loss_d', 'psnr', 'ssim')
SPLITS = ('train', 'val')

history = {metric: {split: [] for split in SPLITS} for metric in METRICS}

log_dir = 'logs'
os.makedirs(log_dir, exist_ok=True)
log_path = os.path.join(log_dir, 'gan_log.csv')


def _append_csv_row(path, row, mode='a'):
    """Write one row to the CSV file at `path` (append by default; mode='w' truncates)."""
    with open(path, mode, newline='') as f:
        csv.writer(f).writerow(row)


def _show_random_val_sample(model, dataset, device):
    """Super-resolve one random dataset item and display LR / SR / HR side by side.

    The model runs in eval mode for this forward pass, so layers such as
    BatchNorm, if present, use running statistics instead of single-image batch
    statistics and do not update those statistics. Its previous train/eval mode
    is restored afterwards.
    """
    lr_img, hr_img = dataset[random.randrange(len(dataset))]
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # [0] drops the batch dim so SR has the same layout as the LR/HR samples.
            sr_img = model(lr_img.unsqueeze(0).to(device))[0].cpu()
    finally:
        model.train(was_training)
    show_lr_sr_hr(lr_img, sr_img, hr_img)


# NOTE: an existing log file is appended to rather than replaced. Every run
# restarts at epoch 1 from best_srresnet.pth, so rows from earlier runs will
# repeat epoch numbers.
if not os.path.exists(log_path):
    _append_csv_row(
        log_path,
        ['epoch'] + [f'{split}_{metric}' for metric in METRICS for split in SPLITS],
        mode='w',
    )

# "Best" checkpoint = highest validation PSNR. NOTE(review): adversarial
# fine-tuning usually trades PSNR for perceptual quality, so this selection
# favours the least adversarially-shaped generator; confirm that is the intent.
best_psnr = -1.0
epochs = cfg['gan']['epochs']

for epoch in range(1, epochs + 1):
    train_stats = train_gan_epoch(
        generator, discriminator,
        tqdm(train_loader, desc=f'Epoch {epoch}/{epochs} [Train]'),
        optimizer_g, optimizer_d, device,
        pixel_criterion, perceptual_criterion, adversarial_criterion,
        weights,
    )
    val_stats = val_gan_epoch(
        generator, discriminator,
        tqdm(val_loader, desc=f'Epoch {epoch}/{epochs} [Val]'),
        device,
        pixel_criterion, perceptual_criterion, adversarial_criterion,
        weights,
    )

    scheduler_g.step()
    scheduler_d.step()

    stats = {'train': train_stats, 'val': val_stats}
    for metric in METRICS:
        for split in SPLITS:
            history[metric][split].append(stats[split][metric])
    _append_csv_row(
        log_path,
        [epoch] + [stats[split][metric] for metric in METRICS for split in SPLITS],
    )

    torch.save(generator.state_dict(), 'weights/last_gan.pth')
    # Also keep the latest discriminator: the generator alone is not enough to
    # continue adversarial training after an interruption.
    torch.save(discriminator.state_dict(), 'weights/last_gan_d.pth')
    if val_stats['psnr'] > best_psnr:
        best_psnr = val_stats['psnr']
        torch.save(generator.state_dict(), 'weights/best_gan.pth')

    clear_output(wait=True)
    _show_random_val_sample(generator, val_dataset, device)
    plot_curves(history)


Epoch 1/200 [Train]:   0%|          | 0/1 [00:10<?, ?it/s]


KeyboardInterrupt: 