In [None]:
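# Setup (Kaggle): uncomment to clone the training repo and cd into it, so that
# `configs` and `train_gan_ddp.py` used by the cells below are importable/runnable.
# !git clone --branch Distributed-Data-Parallel https://github.com/TAYDOVAT/Cuoi_Ki_DL.git
# %cd /kaggle/working/Cuoi_Ki_DL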

# Train SRGAN x4 (DDP torchrun) - Single Config

- Notebook nay dung 1 config train duy nhat (khong candidate/sweep).
- Ban co the chon `G_LOSS_MODE = 'srgan'` hoac `'lpips_adv'` trong cell config.
- Dieu khien phase bang `TRAIN_PHASE = 'phase1'` hoac `TRAIN_PHASE = 'phase2'`.
- Quy uoc epoch tong:
  - `phase1`: train den epoch 60, `resume=False`
  - `phase2`: train tiep den epoch 120, `resume=True` tu `weights/srgan_60/checkpoint_srgan_60.pth`



In [None]:
from copy import deepcopy
from configs import CFG

cfg = deepcopy(CFG)
!pip install lpips

In [None]:
# Config override - Single config with 2-phase flow
from pathlib import Path

# Start from a fresh copy of the default CFG every time this cell runs, so the
# cell is idempotent: switching TRAIN_PHASE and re-running it must not keep keys
# such as 'init_gen_path' or 'checkpoint_path' set by a previous run of this cell.
# (deepcopy and CFG are imported in the previous cell.)
cfg = deepcopy(CFG)

# Auto-detect data_root (Kaggle + local): the first candidate that exists wins.
candidates = [
    Path('/kaggle/input/datasets/tyantran/anh-ve-tinh-2/Anh_ve_tinh_2'),
    Path('/kaggle/input/anh-ve-tinh-2/Anh_ve_tinh_2'),
    Path('../../input/anh-ve-tinh-2/Anh_ve_tinh_2'),
    Path('../input/anh-ve-tinh-2/Anh_ve_tinh_2'),
    Path('./datasets/anh-ve-tinh-2/Anh_ve_tinh_2'),
]

data_root = next((p.resolve() for p in candidates if p.exists()), None)
if data_root is None:
    raise FileNotFoundError('Khong tim thay data_root. Hay them candidate path phu hop may ban.')
print('Using data_root:', data_root)

# Paths: <data_root>/<split>/<split>_lr and <split>_hr
cfg['paths']['train_lr'] = str(data_root / 'train' / 'train_lr')
cfg['paths']['train_hr'] = str(data_root / 'train' / 'train_hr')
cfg['paths']['val_lr'] = str(data_root / 'val' / 'val_lr')
cfg['paths']['val_hr'] = str(data_root / 'val' / 'val_hr')
cfg['paths']['test_lr'] = str(data_root / 'test' / 'test_lr')
cfg['paths']['test_hr'] = str(data_root / 'test' / 'test_hr')

# Shared runtime settings
cfg['hr_crop'] = 96
cfg['gan']['train_batch_size'] = 32
cfg['gan']['val_batch_size'] = 12  # None -> auto by world_size
cfg['gan']['num_workers'] = 4
cfg['gan']['pin_memory'] = True
cfg['gan']['persistent_workers'] = True
cfg['gan']['use_amp'] = True
cfg['gan']['use_lpips'] = True

# Choose one loss mode: 'srgan' or 'lpips_adv'
G_LOSS_MODE = 'srgan'
cfg['gan']['g_loss_mode'] = G_LOSS_MODE

# Shared GAN dynamics
cfg['gan']['lr_g'] = 1e-4
cfg['gan']['lr_d'] = 3e-5
cfg['gan']['g_steps'] = 1
cfg['gan']['d_steps'] = 1
cfg['gan']['r1_weight'] = 2.0
cfg['gan']['r1_interval'] = 8
cfg['gan']['d_noise_std_start'] = 0.03
cfg['gan']['d_noise_std_end'] = 0.005
cfg['gan']['d_noise_decay_epochs'] = 60
cfg['gan']['real_label'] = 0.9
cfg['gan']['fake_label'] = 0.0
cfg['gan']['val_use_train_labels'] = True
cfg['gan']['scheduler_type'] = 'multistep'
cfg['gan']['milestones'] = [60, 90]
cfg['gan']['gamma'] = 0.5
cfg['gan']['pixel_weight'] = 0.0

# Mode-specific loss weights
if G_LOSS_MODE == 'srgan':
    cfg['gan']['perc_weight'] = 1.0
    cfg['gan']['adv_weight'] = 1e-3
    cfg['gan']['lpips_weight'] = 1.0  # compatibility only
elif G_LOSS_MODE == 'lpips_adv':
    cfg['gan']['perc_weight'] = 1.0  # kept for compatibility
    cfg['gan']['adv_weight'] = 3e-3
    cfg['gan']['lpips_weight'] = 1.0
else:
    raise ValueError(f'Unsupported G_LOSS_MODE: {G_LOSS_MODE}')

# 2-phase controls. Epoch counts are cumulative totals, not per-phase increments.
TRAIN_PHASE = 'phase1'  # 'phase1' | 'phase2'
PHASE1_EPOCHS = 60   # phase1 trains up to this epoch
PHASE2_EPOCHS = 120  # phase2 resumes from the phase1 checkpoint up to this epoch
INIT_GEN_PATH = str(Path('.') / 'weights' / 'srresnet_lpips_epoch_20.pth')
PHASE2_CKPT = str(
    Path('.') / 'weights' / f'srgan_{PHASE1_EPOCHS}' / f'checkpoint_srgan_{PHASE1_EPOCHS}.pth'
)

if TRAIN_PHASE == 'phase1':
    if not Path(INIT_GEN_PATH).exists():
        # Warn only: how train_gan_ddp.py reacts to a missing file is not visible here.
        print(f'WARNING: init generator weights not found: {INIT_GEN_PATH}')
    cfg['gan']['resume'] = False
    cfg['gan']['epochs'] = PHASE1_EPOCHS
    cfg['gan']['init_gen_path'] = INIT_GEN_PATH
elif TRAIN_PHASE == 'phase2':
    if not Path(PHASE2_CKPT).exists():
        raise FileNotFoundError(f'Phase2 checkpoint not found: {PHASE2_CKPT}')
    cfg['gan']['resume'] = True
    cfg['gan']['epochs'] = PHASE2_EPOCHS
    cfg['gan']['checkpoint_path'] = PHASE2_CKPT
else:
    raise ValueError(f'Unsupported TRAIN_PHASE: {TRAIN_PHASE}')

print('TRAIN_PHASE:', TRAIN_PHASE)
print('Resume:', cfg['gan']['resume'])
print('Epoch target:', cfg['gan']['epochs'])
print('G loss mode:', cfg['gan']['g_loss_mode'])
print(
    f"LR(G/D)=({cfg['gan']['lr_g']}, {cfg['gan']['lr_d']}) | "
    f"weights(perc={cfg['gan']['perc_weight']}, lpips={cfg['gan']['lpips_weight']}, adv={cfg['gan']['adv_weight']})"
)
print(
    f"G:D steps={cfg['gan']['g_steps']}:{cfg['gan']['d_steps']} | "
    f"R1={cfg['gan']['r1_weight']}@{cfg['gan']['r1_interval']} | "
    f"noise={cfg['gan']['d_noise_std_start']}->{cfg['gan']['d_noise_std_end']} ({cfg['gan']['d_noise_decay_epochs']})"
)
if cfg['gan']['resume']:
    print('Checkpoint path:', cfg['gan']['checkpoint_path'])
else:
    print('Init generator path:', cfg['gan']['init_gen_path'])



In [None]:
import os
import json

# Serialize the final config to disk; the torchrun launch cell passes this file
# to train_gan_ddp.py via --config.
config_path = 'configs/gan_ddp_single.json'
os.makedirs(os.path.dirname(config_path), exist_ok=True)
with open(config_path, 'w') as json_file:
    json.dump(cfg, json_file, indent=2)

# Values consumed by the launch and monitor cells.
ACTIVE_CONFIG_PATH = config_path
ACTIVE_TOTAL_EPOCHS = cfg['gan']['epochs']

print('Config saved to:', ACTIVE_CONFIG_PATH)
print('TOTAL_EPOCHS =', ACTIVE_TOTAL_EPOCHS)



## Run torchrun DDP

In [None]:
import os
import subprocess
import torch

# One torchrun worker per visible GPU.
nproc = torch.cuda.device_count()
if nproc < 1:
    raise RuntimeError('torchrun DDP requires at least 1 GPU')

# Refuse to start a second run while a background run launched earlier from this
# kernel is still alive; both would compete for the same GPUs.
_prev_proc = globals().get('proc')
if _prev_proc is not None and _prev_proc.poll() is None:
    raise RuntimeError(
        f'A training run is already in progress (PID {_prev_proc.pid}); '
        'wait for it to finish or kill it before relaunching.'
    )

cmd = [
    'torchrun', '--standalone', f'--nproc_per_node={nproc}',
    'train_gan_ddp.py', '--config', ACTIVE_CONFIG_PATH
]
print('Launching:', ' '.join(cmd))

RUN_IN_BACKGROUND = True
# A Popen child inherits the kernel's own stdout/stderr, which are not rendered
# in the notebook, so crash tracebacks would be lost. Send them to a file instead.
TRAIN_STDOUT_PATH = 'logs/train_stdout.log'
if RUN_IN_BACKGROUND:
    os.makedirs(os.path.dirname(TRAIN_STDOUT_PATH), exist_ok=True)
    # The child keeps its own duplicate of the file descriptor, so closing our
    # handle right after spawning does not stop its output.
    with open(TRAIN_STDOUT_PATH, 'w') as stdout_file:
        proc = subprocess.Popen(cmd, stdout=stdout_file, stderr=subprocess.STDOUT)
    print(f'Background PID: {proc.pid}')
    print(f'Training stdout/stderr: {TRAIN_STDOUT_PATH}')
else:
    subprocess.run(cmd, check=True)
    print('Run finished.')



## Monitor log and plots (epoch-level)

In [None]:
import os
import time
import csv
import math
import statistics
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Monitor settings: how often to poll, when to stop, and which CSV to read.
REFRESH_SEC = 30  # seconds to wait between polls of the log
TOTAL_EPOCHS = ACTIVE_TOTAL_EPOCHS  # epoch target set by the config-save cell
LOG_PATH = 'logs/gan_log.csv'


def read_log(path):
    """Read the training CSV log as a list of row dicts (one per CSV row).

    Returns [] when the file does not exist yet, or when it cannot currently
    be read (I/O error, malformed CSV, undecodable bytes), e.g. while it is
    being written. Any other exception propagates, so genuine bugs are not
    disguised as "no log yet".
    """
    if not os.path.exists(path):
        return []
    try:
        with open(path, 'r', newline='') as f:
            return list(csv.DictReader(f))
    except (OSError, csv.Error, UnicodeDecodeError):
        return []


def render_bar(cur, total, width=30):
    """Render a text progress bar such as '[#####.....] 5/10'.

    `cur` is clamped to `total`; a non-positive `total` yields '[?]'.
    """
    if total <= 0:
        return '[?]'
    shown = total if cur > total else cur
    n_done = int(shown * width / total)
    bar = '#' * n_done + '.' * (width - n_done)
    return '[' + bar + '] ' + f'{shown}/{total}'


def to_float(row, key, default=0.0):
    """Return row[key] as a float, or float(default) if it is missing or non-numeric.

    csv.DictReader yields '' for empty cells (float('') -> ValueError) and None
    for cells missing from a short row (float(None) -> TypeError). Only those
    two errors are treated as "missing value"; anything else propagates.
    """
    try:
        return float(row.get(key, default))
    except (TypeError, ValueError):
        return float(default)


# Poll the training log every REFRESH_SEC seconds, redraw progress and metric
# curves, and stop when the epoch target is reached OR the background training
# process has exited (so a crashed run does not leave this loop spinning forever).


def _series(rows, key):
    """Column `key` as floats; missing/blank cells become NaN rather than 0.0,
    so they show up as gaps in plots instead of fake zero values."""
    return [to_float(r, key, float('nan')) for r in rows]


def _finite(values):
    """Drop NaN entries (missing metrics) before aggregating with min/mean."""
    return [v for v in values if not math.isnan(v)]


def _parse_epoch(row):
    """Integer epoch of a log row, or None if the cell is missing/blank/non-integer."""
    try:
        return int(row.get('epoch'))
    except (TypeError, ValueError):
        return None


# (subplot title, train column, val column or None). noise_std is a single
# per-epoch value, so it gets one line instead of identical train/val lines.
PLOT_SPECS = [
    ('loss_g', 'train_loss_g', 'val_loss_g'),
    ('loss_d', 'train_loss_d', 'val_loss_d'),
    ('d_real_prob', 'train_d_real_prob', 'val_d_real_prob'),
    ('d_fake_prob', 'train_d_fake_prob', 'val_d_fake_prob'),
    ('psnr', 'train_psnr', 'val_psnr'),
    ('ssim', 'train_ssim', 'val_ssim'),
    ('lpips', 'train_lpips', 'val_lpips'),
    ('loss_adv', 'train_loss_adv', 'val_loss_adv'),
    ('noise_std', 'noise_std', None),
]

# Background process from the launch cell; None in foreground mode or when that
# cell was not run in this kernel.
train_proc = globals().get('proc')
nan = float('nan')

while True:
    # Poll BEFORE reading the log: if the process has already exited, the log
    # read right after is complete.
    exit_code = train_proc.poll() if train_proc is not None else None
    # Skip rows without a usable epoch (e.g. a partially written last line).
    rows = [r for r in read_log(LOG_PATH) if _parse_epoch(r) is not None]
    clear_output(wait=True)

    if not rows:
        if exit_code is not None:
            print(f'Training process exited with code {exit_code} before writing any log row. Stopping monitor.')
            break
        print('Chua co log. Doi...')
        time.sleep(REFRESH_SEC)
        continue

    last = rows[-1]
    epoch = _parse_epoch(last)
    finished = epoch >= TOTAL_EPOCHS

    print('Progress:', render_bar(epoch, TOTAL_EPOCHS))
    print(f'Epoch {epoch}/{TOTAL_EPOCHS}')
    print(
        f"Train Loss G: {to_float(last, 'train_loss_g', nan):.4f} | "
        f"Val Loss G: {to_float(last, 'val_loss_g', nan):.4f} | "
        f"Train Loss D: {to_float(last, 'train_loss_d', nan):.4f} | "
        f"Val Loss D: {to_float(last, 'val_loss_d', nan):.4f} | "
        f"Val PSNR: {to_float(last, 'val_psnr', nan):.2f} | "
        f"Val LPIPS: {to_float(last, 'val_lpips', nan):.4f}"
    )

    epochs = [_parse_epoch(r) for r in rows]

    ncols = 3
    nrows = math.ceil(len(PLOT_SPECS) / ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(15, 4 * nrows))
    axes = axes.flatten() if hasattr(axes, 'flatten') else [axes]

    for ax, (title, train_key, val_key) in zip(axes, PLOT_SPECS):
        ax.plot(epochs, _series(rows, train_key), label='train')
        if val_key is not None:
            ax.plot(epochs, _series(rows, val_key), label='val')
        ax.set_title(title)
        ax.set_xlabel('epoch')
        ax.grid(True, alpha=0.3)
        ax.legend()

    for ax in axes[len(PLOT_SPECS):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per refresh

    if finished or exit_code is not None:
        val_lpips = _finite(_series(rows, 'val_lpips'))
        val_psnr = _finite(_series(rows, 'val_psnr'))
        best_val_lpips = min(val_lpips) if val_lpips else nan
        mean_val_psnr_last5 = statistics.mean(val_psnr[-5:]) if val_psnr else nan

        if finished:
            print('Training completed. Stopping monitor.')
        else:
            print(
                f'Training process exited with code {exit_code} at epoch '
                f'{epoch}/{TOTAL_EPOCHS}. Stopping monitor.'
            )
        print(f'Best val LPIPS: {best_val_lpips:.4f}')
        print(f'Mean val PSNR (last 5 epochs): {mean_val_psnr_last5:.3f}')
        break

    time.sleep(REFRESH_SEC)

