In [None]:
# !rm -rf /kaggle/working/*
# %cd /kaggle/working
# !git clone --branch Distributed-Data-Parallel https://github.com/TAYDOVAT/Cuoi_Ki_DL.git
# !pip install lpips
# %cd /kaggle/working/Cuoi_Ki_DL

# !rm -r /kaggle/working/Cuoi_Ki_DL/test
# !rm -r /kaggle/working/Cuoi_Ki_DL/train
# !rm -r /kaggle/working/Cuoi_Ki_DL/val


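# Train SRGAN x4 (DDP torchrun)

This notebook proceeds in three steps:

1. It locates the dataset (via `DATA_ROOT`, the Kaggle input mount, or the working directory) and writes `configs_gan.json`.
2. It launches `train_gan_ddp.py` with torchrun on all visible GPUs as a background process. At least 2 GPUs are required.
3. It polls `logs/gan_log.csv` and live-plots the per-epoch metrics.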

In [None]:
import json
import os
import warnings
from pathlib import Path

import torch


In [None]:
# Config override here.
# Hyper-parameters for SRGAN x4 training. The dataset root is located below, every
# entry of cfg['paths'] is rewritten to an absolute path under it, and the result is
# written to JSON for train_gan_ddp.py.
cfg = {
    'scale': 4,
    'hr_crop': 96,
    'gan': {
        'batch_size': 32,
        'num_workers': 4,
        'epochs': 300,
        'lr_g': 1e-5,           # Learning rate for the Generator
        'lr_d': 1e-5,           # Learning rate for the Discriminator
        'adv_weight': 1e-2,     # Adversarial loss weight
        'perc_weight': 1,       # Perceptual loss weight
        'pixel_weight': 0,      # Pixel loss weight
        'r1_weight': 10.0,      # R1 gradient penalty
        'use_amp': True,        # Automatic Mixed Precision
        'd_steps': 1,           # Number of D training steps per iteration
        'g_steps': 2,           # Number of G training steps per iteration
        # ========== RESUME CONFIG ==========
        'resume': False,        # True: resume training, False: train from scratch
        'load_disc': False,     # True: also load the Discriminator, False: load only the Generator
        'checkpoint_path': 'weights/gan_checkpoint.pth',  # Path to checkpoint
    },
    # Only 'train_lr' is consulted (as a hint for finding the dataset root); all six
    # entries are overwritten with absolute paths under that root further down.
    'paths': {
        'train_lr': 'train/train_lr',
        'train_hr': 'train/train_hr',
        'val_lr': 'val/val_lr',
        'val_hr': 'val/val_hr',
        'test_lr': 'test/test_lr',
        'test_hr': 'test/test_hr',
    },
}

# Dataset layout: <root>/<split>/<split>_lr and <root>/<split>/<split>_hr.
DATA_SPLITS = ('train', 'val', 'test')


def _has_train_lr(root):
    """Return True if `root` contains a `train/train_lr` directory (the dataset-root marker)."""
    return (root / 'train' / 'train_lr').is_dir()


data_root = os.environ.get('DATA_ROOT')
kaggle_root = Path('/kaggle/input/anh-ve-tinh-2/Anh_ve_tinh_2')
cwd = Path.cwd().resolve()
base_dir = None

# 1) An explicit DATA_ROOT wins; otherwise try the Kaggle input mount if it exists.
candidate = None
if data_root:
    candidate = Path(data_root).expanduser().resolve()
elif kaggle_root.is_dir():
    candidate = kaggle_root

if candidate is not None and _has_train_lr(candidate):
    base_dir = candidate
else:
    if data_root:
        # Don't ignore an explicit setting silently: falling back could train on a
        # different dataset than the one DATA_ROOT points to.
        warnings.warn(
            f"DATA_ROOT={data_root!r} does not contain train/train_lr; falling back to auto-detection"
        )
    # 2) The configured train_lr path (absolute, or relative to cwd) implies the root
    #    two levels up (<root>/train/train_lr).
    train_lr_path = Path(cfg['paths']['train_lr'])
    if not train_lr_path.is_absolute():
        train_lr_path = (cwd / train_lr_path).resolve()
    if train_lr_path.is_dir():
        base_dir = train_lr_path.parents[1]

# 3) Last resort: walk up from cwd looking for a directory that holds train/train_lr.
if base_dir is None:
    base_dir = next((parent for parent in [cwd, *cwd.parents] if _has_train_lr(parent)), None)

if base_dir is None:
    raise FileNotFoundError(
        f"Cannot find dataset root. Set DATA_ROOT or update cfg['paths'] (cwd: {cwd})"
    )

for split in DATA_SPLITS:
    for res in ('lr', 'hr'):
        cfg['paths'][f'{split}_{res}'] = str(base_dir / split / f'{split}_{res}')

# Only train_lr was verified above; report the other directories early instead of
# letting the training script fail on them later.
missing_dirs = [p for p in cfg['paths'].values() if not Path(p).is_dir()]
if missing_dirs:
    warnings.warn('Dataset directories not found: ' + ', '.join(missing_dirs))

config_path = 'configs_gan.json'
with open(config_path, 'w') as f:
    json.dump(cfg, f, indent=2)
print(f"Wrote {config_path}")


## Launch DDP training with torchrun (background process, requires >=2 GPUs)

In [None]:
import subprocess
import sys

# Guard against a duplicate job: re-running this cell while an earlier launch from this
# kernel is still alive would start a second training run on the same GPUs.
_previous = globals().get('proc')
if isinstance(_previous, subprocess.Popen) and _previous.poll() is None:
    raise RuntimeError(
        f'Training already running (PID {_previous.pid}); call proc.terminate() before relaunching'
    )

nproc = torch.cuda.device_count()
if nproc < 2:
    raise RuntimeError('torchrun DDP requires >=2 GPUs')

# `python -m torch.distributed.run` is the module behind the `torchrun` console script.
# Launching it through sys.executable pins the launcher (and its workers) to this
# kernel's Python environment instead of whichever `torchrun` is first on PATH.
cmd = [
    sys.executable, '-m', 'torch.distributed.run',
    '--standalone', f'--nproc_per_node={nproc}',
    'train_gan_ddp.py', '--config', config_path,
]
print('Launching:', ' '.join(cmd))
# Non-blocking: training keeps running while the monitor cell below polls the log.
# stdout/stderr are not redirected, so the child inherits the kernel process's streams.
proc = subprocess.Popen(cmd)
print(f'Background PID: {proc.pid}')


## Monitor the training log and plot per-epoch metrics

In [None]:
import time
import csv
import math
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Epoch-level CSV log polled by the monitor (path relative to the notebook's cwd;
# presumably written by train_gan_ddp.py — confirm against the training script).
LOG_PATH = 'logs/gan_log.csv'
# Total epochs from the config above; used for the progress bar and "Epoch x/N" display.
TOTAL_EPOCHS = cfg['gan']['epochs']
# Seconds to wait between log polls / plot refreshes.
REFRESH_SEC = 5

def read_log(path):
    """Read the training CSV log as a list of row dicts (column name -> string value).

    Returns [] if the file does not exist yet or cannot be read at the moment.
    Rows with fewer fields than the header, such as a last line the trainer is still
    writing, are dropped so callers never receive ``None`` values.
    """
    if not os.path.exists(path):
        return []
    try:
        with open(path, 'r', newline='') as f:
            rows = list(csv.DictReader(f))
    except (OSError, csv.Error, UnicodeDecodeError):
        # Best-effort polling: the file may be mid-write or briefly unavailable; the
        # caller simply tries again on its next refresh.
        return []
    # DictReader fills missing trailing fields with None (its default restval).
    return [row for row in rows if None not in row.values()]

def render_bar(cur, total, width=30):
    """Render a text progress bar such as ``[###.......] 3/10``.

    `cur` is capped at `total`. A non-positive `total` yields the placeholder ``[?]``.
    """
    if total <= 0:
        return '[?]'
    shown = total if cur > total else cur
    n_done = int(width * shown / total)
    bar = '#' * n_done + '.' * (width - n_done)
    return f"[{bar}] {shown}/{total}"

# Metrics plotted as train/val curves; each key K maps to CSV columns train_K and val_K.
PLOT_KEYS = ('loss_g', 'loss_d', 'psnr', 'ssim', 'lpips', 'd_real_prob', 'd_fake_prob')
# Popen handle from the launch cell if it ran in this kernel; None when monitoring a
# run that was started elsewhere.
TRAIN_PROC = globals().get('proc')


def _parse_row(row):
    """Convert one raw CSV row to numbers: {'epoch': int, 'train_K'/'val_K': float}.

    Returns None for a row that is not fully written yet (empty or None values), so a
    half-flushed last line does not crash the monitor. A missing column still raises
    KeyError, because that is a log-schema mismatch rather than a timing issue.
    """
    try:
        parsed = {'epoch': int(row['epoch'])}
        for key in PLOT_KEYS:
            for split in ('train', 'val'):
                parsed[f'{split}_{key}'] = float(row[f'{split}_{key}'])
    except (TypeError, ValueError):
        return None
    return parsed


def _training_done():
    """True once the launched training process has exited (False if unknown or still running)."""
    return TRAIN_PROC is not None and TRAIN_PROC.poll() is not None


while True:
    # Sample the process state BEFORE reading the log so its final rows are still shown.
    finished = _training_done()
    history = [p for p in map(_parse_row, read_log(LOG_PATH)) if p is not None]
    clear_output(wait=True)

    if not history:
        if finished:
            print(f'Training process exited (code {TRAIN_PROC.returncode}) without writing a log.')
            break
        print('Chua co log. Doi...')
        time.sleep(REFRESH_SEC)
        continue

    last = history[-1]
    epoch = last['epoch']
    print('Progress:', render_bar(epoch, TOTAL_EPOCHS))
    print(f"Epoch {epoch}/{TOTAL_EPOCHS}")
    print(
        f"Train G: {last['train_loss_g']:.4f} | "
        f"Val G: {last['val_loss_g']:.4f} | "
        f"LPIPS Val: {last['val_lpips']:.4f}"
    )

    epochs = [h['epoch'] for h in history]
    ncols = 3
    nrows = math.ceil(len(PLOT_KEYS) / ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(15, 4 * nrows))
    axes = axes.flatten() if hasattr(axes, 'flatten') else [axes]

    for ax, key in zip(axes, PLOT_KEYS):
        ax.plot(epochs, [h[f'train_{key}'] for h in history], label='train')
        ax.plot(epochs, [h[f'val_{key}'] for h in history], label='val')
        ax.set_title(key)
        ax.set_xlabel('epoch')
        ax.grid(True, alpha=0.3)
        ax.legend()

    # Hide any unused subplots
    for ax in axes[len(PLOT_KEYS):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()
    plt.close(fig)  # a new figure is created on every refresh; release the old one

    # Stop once the launched process has exited. Without a process handle, fall back
    # to the configured epoch count.
    if finished or (TRAIN_PROC is None and epoch >= TOTAL_EPOCHS):
        if finished:
            print(f'Training process exited with code {TRAIN_PROC.returncode}.')
        break

    time.sleep(REFRESH_SEC)
