In [None]:
# !git clone --branch Distributed-Data-Parallel https://github.com/TAYDOVAT/Cuoi_Ki_DL.git
# %cd ./working/Cuoi_Ki_DL

# Train SRGAN x4 (DDP torchrun)

- `resume=False`: load the SRResNet weights (`init_gen_path`) and initialize a fresh Discriminator.
- `resume=True`: load the full GAN checkpoint (`checkpoint_path`) to continue training.


In [None]:
from copy import deepcopy
from configs import CFG

cfg = deepcopy(CFG)
!pip install lpips

In [None]:
# Config override here
from pathlib import Path

# Auto-detect data_root (Kaggle + local): candidates are probed in order and the
# first one that exists on disk wins.
candidates = [
    Path('/kaggle/input/datasets/tyantran/anh-ve-tinh-2/Anh_ve_tinh_2'),
    Path('/kaggle/input/anh-ve-tinh-2/Anh_ve_tinh_2'),
    Path('../../input/anh-ve-tinh-2/Anh_ve_tinh_2'),
    Path('../input/anh-ve-tinh-2/Anh_ve_tinh_2'),
    Path('./datasets/anh-ve-tinh-2/Anh_ve_tinh_2'),
]

data_root = next(filter(Path.exists, candidates), None)
if data_root is None:
    raise FileNotFoundError('Khong tim thay data_root. Hay them candidate path phu hop may ban.')
data_root = data_root.resolve()
print('Using data_root:', data_root)

# Paths: each split lives at <data_root>/<split>/<split>_lr and <split>_hr.
cfg['paths'].update({
    f'{split}_{res}': str(data_root / split / f'{split}_{res}')
    for split in ('train', 'val', 'test')
    for res in ('lr', 'hr')
})

# Patch size config (HR patch); LR patch = hr_crop // scale
cfg['hr_crop'] = 96

# GAN train config overrides
cfg['gan'].update({
    'train_batch_size': 32,
    'val_batch_size': 12,  # None -> auto by world_size
    'num_workers': 4,
    'pin_memory': True,
    'persistent_workers': True,
    'epochs': 50,
    'lr_g': 1e-4,
    'lr_d': 1e-4,
    'use_amp': True,
    'use_lpips': True,  # Set True after: pip install lpips
    # Generator loss recipe: 'srgan' | 'lpips_adv'
    #   srgan     -> perc_weight * perc_loss + adv_weight * adv_loss
    #   lpips_adv -> lpips_weight * lpips_loss + adv_weight * adv_loss
    'g_loss_mode': 'srgan',
    'g_steps': 2,
    'd_steps': 1,
    'adv_weight': 1e-3,
    'perc_weight': 1,
    'lpips_weight': 1,
    'pixel_weight': 0,  # compatibility key (not used in g_loss_mode: srgan/lpips_adv)
    'r1_weight': 10,
    # Label override applies to TRAIN only. Validation keeps default labels (real=1.0, fake=0.0).
    'real_label': 0.9,
    'fake_label': 0.1,
})

# Resume mechanism (2 modes only)
cfg['gan'].update({
    'resume': False,
    'init_gen_path': str(Path('..', 'weights', 'srresnet_lpips_epoch_20.pth')),
    'checkpoint_path': 'weights/srgan_10/checkpoint_srgan_10.pth',
})


In [None]:
import os
import json
import torch

# Persist the effective config as JSON; the launch cell hands this file to
# train_gan_ddp.py through its --config argument.
config_path = 'configs/gan_ddp.json'
os.makedirs('configs', exist_ok=True)
with open(config_path, mode='w') as config_file:
    json.dump(cfg, config_file, indent=2)
print('Config saved to:', config_path)


## Run torchrun DDP

In [None]:
import subprocess

# Launch one torchrun worker per visible GPU as a background process, so the
# notebook stays free for the monitoring cell.
nproc = torch.cuda.device_count()
if nproc < 1:
    raise RuntimeError('torchrun DDP requires at least 1 GPU')

cmd = [
    'torchrun', '--standalone', f'--nproc_per_node={nproc}',
    'train_gan_ddp.py', '--config', config_path
]

# Re-running this cell must not start a second job while the first one is still
# alive: both would use the same GPUs and config, and rebinding `proc` would
# orphan the handle of the first job. poll() returns None while the child runs.
_running = globals().get('proc')
if _running is not None and _running.poll() is None:
    print(f'Training already running (PID {_running.pid}); not launching a second job.')
else:
    print('Launching:', ' '.join(cmd))
    proc = subprocess.Popen(cmd)
    print(f'Background PID: {proc.pid}')


## Monitor log and plots (epoch-level)

In [None]:
import time
import csv
import math
import matplotlib.pyplot as plt
from IPython.display import clear_output

# CSV file polled by the monitor loop in this cell.
# NOTE(review): presumably written once per epoch by train_gan_ddp.py — confirm the path against the training script.
LOG_PATH = 'logs/gan_log.csv'
# Epoch number at which the monitor treats training as finished; read from the same cfg that was saved to the config file.
TOTAL_EPOCHS = cfg['gan']['epochs']
# Seconds to sleep between two refreshes of the monitor.
REFRESH_SEC = 30

def read_log(path):
    """Load every row of the CSV log at ``path`` as a list of dicts.

    Reading is best-effort: a missing file, or any error raised while opening
    or parsing it, yields an empty list instead of an exception.
    """
    parsed_rows = []
    if os.path.exists(path):
        try:
            with open(path, 'r', newline='') as log_file:
                parsed_rows = [row for row in csv.DictReader(log_file)]
        except Exception:
            parsed_rows = []
    return parsed_rows

def render_bar(cur, total, width=30):
    """Render a fixed-width text progress bar such as ``[###.......] 3/10``.

    Args:
        cur: Units completed so far. Clamped into ``[0, total]`` so the bar is
            always exactly ``width`` characters wide.
        total: Units expected overall.
        width: Number of characters between the brackets.

    Returns:
        The bar followed by ``cur/total`` (using the clamped ``cur``), or
        ``'[?]'`` when ``total`` is not positive.
    """
    if total <= 0:
        return '[?]'
    # Clamp on both sides: a negative value would otherwise give a negative
    # `filled` count and a bar longer than `width`.
    cur = max(0, min(cur, total))
    filled = int(width * cur / total)
    return f"[{'#' * filled}{'.' * (width - filled)}] {cur}/{total}"

# Live monitor: on every refresh re-read the epoch-level CSV log, then redraw the
# progress bar, the latest metrics and the train/val curves. The loop ends when the
# last epoch has been logged, or when the training process started by the launch
# cell has exited (otherwise a crashed job would leave this cell spinning forever).
while True:
    # Check the trainer BEFORE reading the log: if it has already exited, the log
    # read just below is guaranteed to be its final state.
    # `proc` is absent when this cell is used on its own to watch an external job.
    trainer = globals().get('proc')
    trainer_exited = trainer is not None and trainer.poll() is not None

    rows = read_log(LOG_PATH)
    # csv.DictReader fills missing trailing fields with None, so a None value in the
    # last row means that line is still being written; ignore it for this refresh.
    if rows and None in rows[-1].values():
        rows = rows[:-1]

    clear_output(wait=True)
    if not rows:
        if trainer_exited:
            print(f'Training process exited with code {trainer.returncode} before logging any epoch. Stopping monitor.')
            break
        print('Chua co log. Doi...')
        time.sleep(REFRESH_SEC)
        continue

    last = rows[-1]
    epoch = int(last['epoch'])
    finished = epoch >= TOTAL_EPOCHS
    print('Progress:', render_bar(epoch, TOTAL_EPOCHS))
    print(f'Epoch {epoch}/{TOTAL_EPOCHS}')
    print(
        f"Train Loss G: {float(last['train_loss_g']):.4f} | "
        f"Val Loss G: {float(last['val_loss_g']):.4f} | "
        f"Train Loss D: {float(last['train_loss_d']):.4f} | "
        f"Val Loss D: {float(last['val_loss_d']):.4f} | "
        f"Val PSNR: {float(last['val_psnr']):.2f} | "
        f"Val LPIPS: {float(last['val_lpips']):.4f}"
    )

    # One list per logged column, aligned with `epochs`.
    epochs = [int(r['epoch']) for r in rows]

    train_loss_g = [float(r['train_loss_g']) for r in rows]
    val_loss_g = [float(r['val_loss_g']) for r in rows]
    train_loss_d = [float(r['train_loss_d']) for r in rows]
    val_loss_d = [float(r['val_loss_d']) for r in rows]
    train_d_real_prob = [float(r['train_d_real_prob']) for r in rows]
    val_d_real_prob = [float(r['val_d_real_prob']) for r in rows]
    train_d_fake_prob = [float(r['train_d_fake_prob']) for r in rows]
    val_d_fake_prob = [float(r['val_d_fake_prob']) for r in rows]
    train_psnr = [float(r['train_psnr']) for r in rows]
    val_psnr = [float(r['val_psnr']) for r in rows]
    train_ssim = [float(r['train_ssim']) for r in rows]
    val_ssim = [float(r['val_ssim']) for r in rows]
    train_lpips = [float(r['train_lpips']) for r in rows]
    val_lpips = [float(r['val_lpips']) for r in rows]

    # (panel title, train series, val series)
    plots = [
        ('loss_g', train_loss_g, val_loss_g),
        ('loss_d', train_loss_d, val_loss_d),
        ('d_real_prob', train_d_real_prob, val_d_real_prob),
        ('d_fake_prob', train_d_fake_prob, val_d_fake_prob),
        ('psnr', train_psnr, val_psnr),
        ('ssim', train_ssim, val_ssim),
        ('lpips', train_lpips, val_lpips),
    ]

    ncols = 3
    nrows = math.ceil(len(plots) / ncols)
    # squeeze=False always returns a 2-D array of Axes, so it can be flattened
    # without special-casing a single panel.
    fig, axes = plt.subplots(nrows, ncols, figsize=(15, 4 * nrows), squeeze=False)
    axes = axes.ravel()

    for i, (title, train_vals, val_vals) in enumerate(plots):
        ax = axes[i]
        ax.plot(epochs, train_vals, label='train')
        ax.plot(epochs, val_vals, label='val')
        ax.set_title(title)
        ax.set_xlabel('epoch')
        ax.grid(True, alpha=0.3)
        ax.legend()

    # Hide the unused panels of the grid.
    for j in range(len(plots), len(axes)):
        axes[j].axis('off')

    plt.tight_layout()
    plt.show()
    # A new figure is created on every refresh; release it so figures do not
    # accumulate in memory over a long run.
    plt.close(fig)

    if finished:
        print('Training completed. Stopping monitor.')
        break

    if trainer_exited:
        print(f'Training process exited with code {trainer.returncode} at epoch {epoch}/{TOTAL_EPOCHS}. Stopping monitor.')
        break

    time.sleep(REFRESH_SEC)
