In [None]:
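# Optional setup (e.g. on a fresh Kaggle session): uncomment both lines to clone
# the Distributed-Data-Parallel branch and make the repo the working directory.
# !git clone --branch Distributed-Data-Parallel https://github.com/TAYDOVAT/Cuoi_Ki_DL.git
# %cd /kaggle/working/Cuoi_Ki_DL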


In [None]:
!pip install lpips

# Train SRResNet x4 (DDP torchrun)

In [None]:
from copy import deepcopy
from configs import CFG

# Work on a deep copy so the overrides applied in the following cells never
# mutate the CFG object imported from `configs`.
cfg = deepcopy(CFG)


In [None]:
# Notebook-level overrides applied on top of the copied default config
from pathlib import Path

# Dataset root (change if needed); relative to the current working directory
data_root = Path('..', '..', 'input', 'anh-ve-tinh-2', 'Anh_ve_tinh_2')

# Low-/high-resolution image folders for each split
cfg['paths'].update({
    'train_lr': str(data_root / 'train' / 'train_lr'),
    'train_hr': str(data_root / 'train' / 'train_hr'),
    'val_lr': str(data_root / 'val' / 'val_lr'),
    'val_hr': str(data_root / 'val' / 'val_hr'),
    'test_lr': str(data_root / 'test' / 'test_lr'),
    'test_hr': str(data_root / 'test' / 'test_hr'),
})

# Model / data settings
cfg.update({'scale': 4, 'hr_crop': 96})

# Training hyper-parameters and checkpoint handling
cfg['train'].update({
    'batch_size': 32,
    'num_workers': 4,
    'epochs': 300,
    'lr': 1e-4,
    'loss': 'l1',  # 'l1' | 'ssim' | 'lpips'
    'val_batch_size': 12,
    'use_amp': True,
    'resume': False,  # set True to resume from last checkpoint
    'load_pretrained_model': True,
    'pretrained_path': 'weights/SRResNet_x4-ImageNet.pth.tar',
    # '{loss}' stays a literal placeholder here — presumably filled in by the
    # training script; TODO confirm.
    'checkpoint_path': 'weights/srresnet_{loss}_checkpoint.pth',
    # NOTE(review): not visible from this notebook whether lr_step counts
    # epochs or iterations — confirm against the training script.
    'lr_step': 100000,
    'lr_gamma': 0.5,
})


In [None]:
import os
import json
from pathlib import Path
import torch

# Write the effective config to disk as JSON; the resulting path is what the
# launch cell hands to the training script through --config.
loss_name = cfg['train']['loss']
config_path = f"configs/srresnet_{loss_name}.json"
Path('configs').mkdir(parents=True, exist_ok=True)
with open(config_path, mode='w') as config_file:
    json.dump(cfg, config_file, indent=2)
print('Config saved to:', config_path)


## Run torchrun DDP

In [None]:
import subprocess
import sys

# One worker process per visible GPU; this notebook only launches multi-GPU runs.
nproc = torch.cuda.device_count()
if nproc < 2:
    raise RuntimeError('torchrun DDP requires >=2 GPUs')

# `torchrun` is the console-script entry point of the `torch.distributed.run`
# module. Invoking the module through sys.executable guarantees the launcher
# runs in the same Python environment as this kernel instead of whatever
# `torchrun` is first on PATH.
cmd = [
    sys.executable, '-m', 'torch.distributed.run',
    '--standalone', f'--nproc_per_node={nproc}',
    'train_srresnet_ddp.py', '--config', config_path
]

# Re-running this cell must not start a second training job on top of one that
# is still alive (both would use the same GPUs and the same config).
_previous = globals().get('proc')
if _previous is not None and _previous.poll() is None:
    print(f'Training already running (PID {_previous.pid}); not launching again.')
else:
    print('Launching:', ' '.join(cmd))
    # Non-blocking: the process keeps running in the background while later
    # cells execute.
    proc = subprocess.Popen(cmd)
    print(f'Background PID: {proc.pid}')


## Monitor log and plots (epoch-level)

In [None]:
import time
import csv
import math
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Epoch-level CSV log polled by the monitor loop below — presumably written by
# train_srresnet_ddp.py; TODO confirm the path matches the training script.
LOG_PATH = f'logs/srresnet_{loss_name}_log.csv'
# Epoch number at which the monitor treats training as finished.
TOTAL_EPOCHS = cfg['train']['epochs']
# Seconds to sleep between two consecutive polls of the log.
REFRESH_SEC = 30

def read_log(path):
    """Load the CSV log at ``path`` as a list of row dicts (values are strings).

    Designed for polling a file that another process may still be writing:

    * returns ``[]`` if the file does not exist yet, cannot be read, or cannot
      be decoded/parsed, so the caller can simply try again later;
    * drops rows whose field count differs from the header (e.g. a trailing
      line caught mid-write), because ``csv.DictReader`` would otherwise hand
      them back with ``None`` values that break numeric conversion downstream.
    """
    try:
        with open(path, 'r', newline='') as f:
            rows = list(csv.DictReader(f))
    except (OSError, UnicodeError, csv.Error):
        # Missing/unreadable/garbled file: treat as "no data yet".
        return []
    # Short rows get None values (restval); long rows get a None key (restkey).
    return [
        row for row in rows
        if None not in row and None not in row.values()
    ]

def render_bar(cur, total, width=30):
    """Build a text progress bar such as ``[##......] 1/4``.

    ``cur`` is capped at ``total`` before rendering, ``width`` is the number of
    cells between the brackets, and a non-positive ``total`` yields ``'[?]'``.
    """
    if total <= 0:
        return '[?]'
    shown = min(cur, total)
    done_cells = int(width * shown / total)
    bar = '#' * done_cells + '.' * (width - done_cells)
    return f"[{bar}] {shown}/{total}"

def _row_is_complete(row):
    """Return True if ``row`` has a parseable epoch and all eight metric values.

    The log is polled while it may still be growing, so the trailing row can be
    cut off (missing or empty fields). A missing *column* still raises KeyError
    so that a log-schema mismatch is not silently hidden.
    """
    try:
        int(row['epoch'])
        for split in ('train', 'val'):
            for metric in ('loss', 'psnr', 'ssim', 'lpips'):
                float(row[f'{split}_{metric}'])
    except (TypeError, ValueError):
        return False
    return True


# Poll the CSV log every REFRESH_SEC seconds, print the latest epoch metrics and
# redraw the train/val curves. The loop ends when the last logged epoch reaches
# TOTAL_EPOCHS, or when the background training process is found to have exited.
while True:
    # Check the trainer BEFORE reading the log: if it has already exited, the
    # log read below is its final state and nothing more will be appended.
    # `proc` only exists if the launch cell was run in this kernel.
    _trainer = globals().get('proc')
    exit_code = _trainer.poll() if _trainer is not None else None

    rows = [r for r in read_log(LOG_PATH) if _row_is_complete(r)]
    clear_output(wait=True)
    if not rows:
        if exit_code is not None:
            print(f'Training process exited with code {exit_code} before any epoch was logged. Stopping monitor.')
            break
        print('Chua co log. Doi...')
        time.sleep(REFRESH_SEC)
        continue

    last = rows[-1]
    epoch = int(last['epoch'])
    finished = epoch >= TOTAL_EPOCHS
    print('Progress:', render_bar(epoch, TOTAL_EPOCHS))
    print(f'Epoch {epoch}/{TOTAL_EPOCHS}')
    print(
        f"Train Loss: {float(last['train_loss']):.4f} | "
        f"Val Loss: {float(last['val_loss']):.4f} | "
        f"Train PSNR: {float(last['train_psnr']):.2f} | "
        f"Val PSNR: {float(last['val_psnr']):.2f} | "
        f"Train SSIM: {float(last['train_ssim']):.4f} | "
        f"Val SSIM: {float(last['val_ssim']):.4f} | "
        f"Train LPIPS: {float(last['train_lpips']):.4f} | "
        f"Val LPIPS: {float(last['val_lpips']):.4f}"
    )

    epochs = [int(r['epoch']) for r in rows]

    train_loss = [float(r['train_loss']) for r in rows]
    val_loss = [float(r['val_loss']) for r in rows]
    train_psnr = [float(r['train_psnr']) for r in rows]
    val_psnr = [float(r['val_psnr']) for r in rows]
    train_ssim = [float(r['train_ssim']) for r in rows]
    val_ssim = [float(r['val_ssim']) for r in rows]
    train_lpips = [float(r['train_lpips']) for r in rows]
    val_lpips = [float(r['val_lpips']) for r in rows]

    plots = [
        ('loss', train_loss, val_loss),
        ('psnr', train_psnr, val_psnr),
        ('ssim', train_ssim, val_ssim),
        ('lpips', train_lpips, val_lpips),
    ]

    # Lay the curves out on a grid three panels wide; unused panels are hidden.
    ncols = 3
    nrows = math.ceil(len(plots) / ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(15, 4 * nrows))
    axes = axes.flatten() if hasattr(axes, 'flatten') else [axes]

    for i, (title, train_vals, val_vals) in enumerate(plots):
        ax = axes[i]
        ax.plot(epochs, train_vals, label='train')
        ax.plot(epochs, val_vals, label='val')
        ax.set_title(title)
        ax.set_xlabel('epoch')
        ax.grid(True, alpha=0.3)
        ax.legend()

    for j in range(len(plots), len(axes)):
        axes[j].axis('off')

    plt.tight_layout()
    plt.show()
    # A new figure is created on every refresh; release it explicitly so
    # figures cannot pile up regardless of the active matplotlib backend.
    plt.close(fig)

    if finished:
        print('Training completed. Stopping monitor.')
        break

    if exit_code is not None:
        # The trainer is gone but never reached the final epoch: stop instead
        # of polling a log that will not change any more.
        print(f'Training process exited with code {exit_code} at epoch {epoch}/{TOTAL_EPOCHS}. Stopping monitor.')
        break

    time.sleep(REFRESH_SEC)
