# Data

## Dataloader

In [None]:
import lightning as L

from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader, ConcatDataset


In [None]:
class CustomDataset(Dataset):
    """Flat image dataset over every image file found recursively under ``path``.

    Files are collected once at construction time. Only regular files whose
    suffix (case-insensitive) is in ``extensions`` are kept, and the list is
    sorted so that index ``i`` maps to the same file on every run. Each item is
    loaded with PIL, converted to RGB and passed through ``transform``.

    Args:
        path: Root directory searched recursively.
        transform: Callable applied to each RGB ``PIL.Image``.
        extensions: Accepted file suffixes (with leading dot).
    """

    IMAGE_EXTENSIONS = (
        ".bmp", ".gif", ".jpeg", ".jpg", ".png", ".ppm", ".tif", ".tiff", ".webp",
    )

    def __init__(self, path, transform, extensions=IMAGE_EXTENSIONS):
        super().__init__()
        self.path = Path(path)
        self.transform = transform
        allowed = {ext.lower() for ext in extensions}
        # rglob order is filesystem-dependent: sort for reproducible indices.
        # is_file() excludes directories whose names happen to contain a dot.
        self.data = sorted(
            p for p in self.path.rglob("*")
            if p.is_file() and p.suffix.lower() in allowed
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # The context manager closes the underlying file once the RGB copy exists.
        with Image.open(fp=self.data[index]) as img:
            image = img.convert(mode="RGB")
        return self.transform(image)


In [None]:
class CustomDataModule(L.LightningDataModule):
    """LightningDataModule over four image splits (train / valid / bench / infer).

    Every visible immediate sub-folder of a split directory becomes its own
    ``CustomDataset``, in sorted name order. If a split directory has no such
    sub-folder, the directory itself is used as a single dataset.

    Loaders:
    - train and valid concatenate their sub-datasets into one loader (train shuffled);
    - test (bench) and predict (infer) return one loader per sub-folder, in the same sorted order.

    Args:
        train_dir, valid_dir, infer_dir, bench_dir: Split directories.
        transform: Callable applied to each RGB PIL image.
        batch_size: Batch size of every loader.
        num_workers: Worker processes per loader.
    """

    def __init__(self, train_dir, valid_dir, infer_dir, bench_dir, transform, batch_size=32, num_workers=4):
        super().__init__()
        self.train_dir = Path(train_dir)
        self.valid_dir = Path(valid_dir)
        self.infer_dir = Path(infer_dir)
        self.bench_dir = Path(bench_dir)

        self.transform = transform
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        # All splits are built regardless of ``stage``; a missing split directory raises.
        self.train_datasets = self._set_dataset(path=self.train_dir)
        self.valid_datasets = self._set_dataset(path=self.valid_dir)
        self.bench_datasets = self._set_dataset(path=self.bench_dir)
        self.infer_datasets = self._set_dataset(path=self.infer_dir)

    def _set_dataset(self, path):
        """Return one CustomDataset per visible sub-folder of ``path``, sorted by name.

        Hidden folders (name starting with ".", e.g. ``.ipynb_checkpoints``) are
        skipped. With no sub-folders, ``path`` itself is the single dataset.
        """
        folders = sorted(
            child for child in path.iterdir()
            if child.is_dir() and not child.name.startswith(".")
        )
        if not folders:
            folders = [path]
        return [CustomDataset(path=folder, transform=self.transform) for folder in folders]

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader using this module's batch size and worker settings."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True
        )

    def _set_dataloader(self, datasets, concat=False, shuffle=False):
        """Return one loader over the concatenation (``concat=True``) or a list with one loader per dataset."""
        if concat:
            return self._make_loader(ConcatDataset(datasets=datasets), shuffle)
        return [self._make_loader(dataset, shuffle) for dataset in datasets]

    def train_dataloader(self):
        return self._set_dataloader(datasets=self.train_datasets, concat=True, shuffle=True)

    def val_dataloader(self):
        return self._set_dataloader(datasets=self.valid_datasets, concat=True)

    def test_dataloader(self):
        return self._set_dataloader(datasets=self.bench_datasets)

    def predict_dataloader(self):
        return self._set_dataloader(datasets=self.infer_datasets)


## Utils

In [None]:
from torchvision import transforms


In [None]:
class DataTransform:
    """Callable preprocessing: square resize to ``image_size`` followed by ``ToTensor``.

    Args:
        image_size: Side length (pixels) both height and width are resized to.
    """

    def __init__(self, image_size=256):
        self.image_size = image_size
        self.transform = self._build_transform()

    def _build_transform(self):
        # Resize to (image_size, image_size), then convert the PIL image to a tensor.
        target = (self.image_size, self.image_size)
        steps = [transforms.Resize(size=target), transforms.ToTensor()]
        return transforms.Compose(transforms=steps)

    def __call__(self, image):
        pipeline = self.transform
        return pipeline(img=image)


# Utils

## Metrics

In [None]:
import cv2
import math
import numpy as np
import torch
import torch.nn as nn

from scipy.ndimage import convolve
from scipy.special import gamma
from torchmetrics.image import (
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
    LearnedPerceptualImagePatchSimilarity,
)


In [None]:
class ImageQualityMetrics(nn.Module):
    """Reference-based (PSNR, SSIM, LPIPS) and no-reference (NIQE, BRISQUE) image metrics.

    Inputs are image batches indexed as (B, C, H, W). The no-reference path
    clips them to [0, 1] and converts each image to 8-bit RGB before scoring.

    Args:
        device: Device the torchmetrics modules are moved to; ``forward`` moves
            its inputs there as well.
        data_range: Value range passed to PSNR and SSIM.
    """

    def __init__(self, device="cuda", data_range=1.0):
        super().__init__()

        self.device_type = device

        # Reference-based metrics.
        self.psnr = PeakSignalNoiseRatio(
            data_range=data_range).to(device=device)
        self.ssim = StructuralSimilarityIndexMeasure(
            data_range=data_range).to(device=device)
        # normalize=True: inputs are in [0, 1]; LPIPS otherwise expects [-1, 1].
        self.lpips = LearnedPerceptualImagePatchSimilarity(
            net_type='squeeze', normalize=True).to(device=device)

        # NIQE pristine-model statistics and Gaussian window, loaded from disk.
        self.niqe_stats = np.load(
            file="utils/files/niqe_params.npz", allow_pickle=True)
        self.mu_pris_param = self.niqe_stats['mu_pris_param']
        self.cov_pris_param = self.niqe_stats['cov_pris_param']
        self.gaussian_window = self.niqe_stats['gaussian_window']

    def forward(self, preds: torch.Tensor, targets: torch.Tensor):
        """Return PSNR, SSIM and LPIPS of ``preds`` against ``targets`` as Python floats."""
        preds = preds.to(device=self.device_type)
        targets = targets.to(device=self.device_type)

        return {
            "PSNR": self.psnr(preds, targets).item(),
            "SSIM": self.ssim(preds, targets).item(),
            "LPIPS": self.lpips(preds, targets).squeeze().mean().item(),
        }

    def no_ref(self, preds: torch.Tensor):
        """Return batch-mean NIQE and BRISQUE for (B, 3, H, W) RGB images in [0, 1].

        Each image is clipped to [0, 1], converted to an 8-bit (H, W, 3) array
        and scored; the per-image scores are averaged.
        """
        # The scoring runs in NumPy/OpenCV on the CPU; .float() also covers half dtypes.
        preds_np = preds.detach().float().cpu().numpy()
        preds_np = np.clip(preds_np, 0.0, 1.0)

        niqe_list = []
        brisque_list = []

        for img in preds_np:
            # (C, H, W) -> (H, W, C), then 8-bit with rounding.
            img_np = np.transpose(img, (1, 2, 0))
            img_np_uint8 = np.ascontiguousarray(
                np.round(img_np * 255.0).astype(np.uint8))

            # Both scorers expect [0, 255] values.
            niqe_list.append(self._compute_niqe(img_np=img_np_uint8))
            brisque_list.append(self._compute_brisque(img=img_np_uint8))

        return {
            "NIQE": float(np.mean(niqe_list)),
            "BRISQUE": float(np.mean(brisque_list)),
        }

    def full(self, preds, targets):
        """Return the reference and no-reference metrics merged into one dict."""
        ref_metrics = self.forward(preds=preds, targets=targets)
        no_ref_metrics = self.no_ref(preds=preds)
        return {**ref_metrics, **no_ref_metrics}

    # --------------------------
    # Custom no-reference metrics
    # --------------------------

    def _compute_niqe(self, img_np):
        """Return NIQE of an RGB (H, W, 3) image with values in [0, 255]."""
        img = img_np.astype(np.float32)
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        gray = gray.round()
        return self._niqe(img=gray)

    def _compute_brisque(self, img):
        """Return OpenCV's BRISQUE score of an RGB uint8 image (model files under utils/files/)."""
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        brisque_score = cv2.quality.QualityBRISQUE_compute(
            img, "utils/files/brisque_model.yaml", "utils/files/brisque_range.yaml")
        return brisque_score

    def _niqe(self, img, block_size_h=96, block_size_w=96):
        """Return NIQE of a 2-D grayscale image with values in [0, 255].

        Raises:
            ValueError: if the image is smaller than one block.
        """
        assert img.ndim == 2
        h, w = img.shape
        num_block_h = math.floor(h / block_size_h)
        num_block_w = math.floor(w / block_size_w)
        if num_block_h == 0 or num_block_w == 0:
            raise ValueError(
                f"NIQE needs an image of at least {block_size_h}x{block_size_w} "
                f"pixels, got {h}x{w}.")
        # Crop to a whole number of blocks.
        img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]

        distparam = []
        for scale in (1, 2):
            # Local Gaussian-weighted mean and std -> normalized coefficients.
            mu = convolve(
                input=img, weights=self.gaussian_window, mode='nearest')
            sigma = np.sqrt(np.abs(convolve(
                input=np.square(img), weights=self.gaussian_window, mode='nearest') - np.square(mu)))
            img_normalized = (img - mu) / (sigma + 1)

            feat = []
            for idx_w in range(num_block_w):
                for idx_h in range(num_block_h):
                    # Block size is divided by ``scale`` because the image is halved at scale 2.
                    block = img_normalized[
                        idx_h * block_size_h // scale:(idx_h + 1) * block_size_h // scale,
                        idx_w * block_size_w // scale:(idx_w + 1) * block_size_w // scale
                    ]
                    feat.append(self._compute_feature(block=block))

            distparam.append(np.array(feat))

            if scale == 1:
                # Downsample by 2 (bicubic, on [0, 1] values) for the second scale.
                img = cv2.resize(img / 255., dsize=(0, 0), fx=0.5,
                                 fy=0.5, interpolation=cv2.INTER_CUBIC) * 255.

        distparam = np.concatenate(distparam, axis=1)
        mu_distparam = np.nanmean(distparam, axis=0)
        # Blocks with any NaN feature are excluded from the covariance estimate.
        distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
        cov_distparam = np.cov(distparam_no_nan, rowvar=False)

        invcov_param = np.linalg.pinv(
            (self.cov_pris_param + cov_distparam) / 2)
        diff = self.mu_pris_param - mu_distparam
        quality = np.matmul(np.matmul(diff, invcov_param), np.transpose(diff))
        # .item() turns a 0-d or single-element result into a Python float.
        return float(np.sqrt(quality).item())

    def _compute_feature(self, block):
        """Return the 18 AGGD-based features (alpha/beta statistics) of one normalized block."""
        def estimate_aggd_param(block):
            # Fit an asymmetric generalized Gaussian by grid search over the shape alpha.
            block = block.flatten()
            gam = np.arange(start=0.2, stop=10.001, step=0.001)
            gam_reciprocal = np.reciprocal(gam)
            r_gam = np.square(gamma(gam_reciprocal * 2)) / (
                gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))

            left_std = np.sqrt(np.mean(block[block < 0]**2))
            right_std = np.sqrt(np.mean(block[block > 0]**2))
            gammahat = left_std / right_std
            rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
            rhatnorm = (rhat * (gammahat**3 + 1) *
                        (gammahat + 1)) / ((gammahat**2 + 1)**2)
            array_position = np.argmin((r_gam - rhatnorm)**2)
            alpha = gam[array_position]
            beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
            beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
            return (alpha, beta_l, beta_r)

        feat = []
        alpha, beta_l, beta_r = estimate_aggd_param(block=block)
        feat.extend([alpha, (beta_l + beta_r) / 2])
        # Products with horizontally, vertically and diagonally shifted neighbours.
        shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
        for shift in shifts:
            shifted = np.roll(block, shift=shift, axis=(0, 1))
            alpha, beta_l, beta_r = estimate_aggd_param(block=block * shifted)
            mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
            feat.extend([alpha, mean, beta_l, beta_r])
        return feat


## Utils

In [None]:
from pathlib import Path
from torchvision.utils import save_image
from torchinfo import summary


In [None]:
def make_dirs(path: str | Path):
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    return path


def print_metrics(metrics: dict, prefix: str = ""):
    """Print every metric as ``<prefix><name>: <value>`` with four decimals, one per line."""
    lines = (f"{prefix}{name}: {value:.4f}" for name, value in metrics.items())
    for line in lines:
        print(line)


def save_images(results, save_dir, prefix="infer", ext="png"):
    """Save nested prediction results as image files.

    ``results`` is iterated as groups (e.g. one per dataloader); group ``k``
    (1-based) is written to ``<save_dir>/batch<k>/``. Every element of a group
    is passed to torchvision's ``save_image`` (grid of 8 per row, values
    clamped to [0, 1]) as ``<prefix>_<index:04d>.<ext>``.
    """
    for group_idx, group in enumerate(results, start=1):
        out_dir = make_dirs(path=f"{save_dir}/batch{group_idx}")
        for item_idx, batch in enumerate(group):
            target = out_dir / f"{prefix}_{item_idx:04d}.{ext}"
            save_image(
                tensor=batch,
                fp=target,
                nrow=8,
                padding=2,
                normalize=True,
                value_range=(0, 1),
            )


def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s trainable (requires_grad) parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def summarize_model(model, input_size):
    """Return a torchinfo summary of ``model`` for ``input_size`` (depth 3; input/output sizes and parameter counts)."""
    columns = ["input_size", "output_size", "num_params"]
    return summary(
        model=model,
        input_size=input_size,
        depth=3,
        col_names=columns,
    )


# Engine

## Trainer

In [None]:
import os
from lightning import Trainer, seed_everything
from lightning.pytorch.callbacks import *
from lightning.pytorch.loggers import TensorBoardLogger


In [None]:
class LightningTrainer:
    """Build the datamodule, logger, callbacks and Lightning ``Trainer``, then fit a model.

    Args:
        model: LightningModule instance to train. When ``ckpt`` is given, weights
            are loaded with ``load_from_checkpoint`` called on the model's class
            (``model`` may be either an instance or the class itself).
        hparams: Config dict. Keys read here: ``seed``, ``epochs``, ``image_size``,
            ``batch_size``, ``train_data_path``, ``valid_data_path``,
            ``infer_data_path``, ``bench_data_path``, ``log_dir``, ``experiment_name``.
        ckpt: Optional checkpoint path used to initialise the model weights.
        transform: Optional image transform for every split; defaults to
            ``DataTransform(image_size=hparams["image_size"])``.
    """

    def __init__(self, model, hparams: dict, ckpt: Path | None = None, transform=None):
        self.hparams = hparams
        # Used by the datamodule for all splits.
        self.transform = (
            transform
            if transform is not None
            else DataTransform(image_size=hparams["image_size"])
        )
        seed_everything(seed=hparams["seed"], workers=True)

        # --- Model
        self.ckpt = ckpt
        if ckpt:
            # load_from_checkpoint is a classmethod; recent Lightning versions
            # raise TypeError when it is called on an instance.
            model_cls = model if isinstance(model, type) else type(model)
            self.model = model_cls.load_from_checkpoint(
                checkpoint_path=ckpt,
                map_location="cuda",
            )
        else:
            self.model = model

        # --- DataModule
        self.datamodule = self._build_datamodule()

        # --- Logging
        self.logger = self._build_logger()

        # --- Callbacks
        self.callbacks = self._build_callbacks()

        # --- Lightning Trainer
        self.trainer = Trainer(
            max_epochs=hparams["epochs"],
            accelerator="gpu",
            devices=1,
            precision="32",
            logger=self.logger,
            callbacks=self.callbacks,
            log_every_n_steps=5,
        )

    def _build_datamodule(self):
        return CustomDataModule(
            train_dir=self.hparams["train_data_path"],
            valid_dir=self.hparams["valid_data_path"],
            infer_dir=self.hparams["infer_data_path"],
            bench_dir=self.hparams["bench_data_path"],
            transform=self.transform,
            batch_size=self.hparams["batch_size"],
            # ~90% of the CPU cores; os.cpu_count() may return None.
            num_workers=int((os.cpu_count() or 1) * 0.9),
        )

    def _build_logger(self):
        return TensorBoardLogger(
            save_dir=self.hparams["log_dir"],
            name=self.hparams["experiment_name"]
        )

    def _build_callbacks(self):
        return [
            # Keep the single best checkpoint by the monitored validation value.
            ModelCheckpoint(
                monitor="valid/4_tot",
                save_top_k=1,
                mode="min",
                filename="best-{epoch:02d}",
            ),
            # Also keep a checkpoint for every epoch.
            ModelCheckpoint(
                every_n_epochs=1,
                save_top_k=-1,
                filename="epoch-{epoch:02d}",
            ),
            EarlyStopping(
                monitor="valid/4_tot",
                patience=4,
                mode="min",
                verbose=True,
            ),
            LearningRateMonitor(logging_interval="step"),
            RichProgressBar(),
        ]

    def run(self):
        print("[INFO] Start Training...")
        self.trainer.fit(
            model=self.model,
            datamodule=self.datamodule
        )
        print("[INFO] Training Completed.")


## Validator

In [None]:
import os

from lightning import Trainer
from pathlib import Path
from tqdm.auto import tqdm


In [None]:
class LightningValidater:
    """Run ``Trainer.validate`` on the validation split and print the results.

    Args:
        model: LightningModule instance (or class) to validate.
        trainer: Lightning Trainer used for validation.
        ckpt: Checkpoint to load. If falsy, ``model`` is used as given and
            ``ckpt_path="best"`` is passed to ``validate``. That requires ``trainer``
            to hold a best checkpoint from an earlier fit.
        hparams: Config dict with the data paths, ``image_size`` and ``batch_size``.
    """

    def __init__(self, model, trainer: Trainer, ckpt: Path, hparams: dict):
        self.hparams = hparams

        # --- Model
        if ckpt:
            # load_from_checkpoint is a classmethod; recent Lightning versions
            # raise TypeError when it is called on an instance.
            model_cls = model if isinstance(model, type) else type(model)
            self.model = model_cls.load_from_checkpoint(
                checkpoint_path=ckpt,
                map_location="cuda",
            )
            self.ckpt = ckpt
        else:
            self.model = model
            self.ckpt = "best"

        # --- Lightning Trainer
        self.trainer = trainer

        # --- DataModule
        self.datamodule = self._build_datamodule()

    def _build_datamodule(self):
        return CustomDataModule(
            train_dir=self.hparams["train_data_path"],
            valid_dir=self.hparams["valid_data_path"],
            infer_dir=self.hparams["infer_data_path"],
            bench_dir=self.hparams["bench_data_path"],
            transform=DataTransform(image_size=self.hparams["image_size"]),
            batch_size=self.hparams["batch_size"],
            # ~90% of the CPU cores; os.cpu_count() may return None.
            num_workers=int((os.cpu_count() or 1) * 0.9),
        )

    def run(self):
        print("[INFO] Start Validating...")
        results = self.trainer.validate(
            model=self.model,
            datamodule=self.datamodule,
            ckpt_path=self.ckpt
        )
        print("[VALIDATION RESULT]")
        # One metrics dict per validation dataloader.
        for res in results:
            print(res)

        print("[INFO] Validation Completed.")


## Inferencer

In [None]:
import os
import yaml

from lightning import Trainer
from pathlib import Path


In [None]:
class LightningInferencer:
    """Run ``Trainer.predict`` on the inference split and save the predictions as images.

    Args:
        model: LightningModule instance (or class) used for prediction.
        trainer: Lightning Trainer used for prediction.
        ckpt: Checkpoint to load. If falsy, ``model`` is used as given and
            ``ckpt_path="best"`` is passed to ``predict``.
        hparams: Path to a YAML config with the data paths, ``image_size``,
            ``batch_size`` and ``inference`` (output sub-directory name).
    """

    def __init__(self, model, trainer: Trainer, ckpt: Path, hparams: Path):
        # safe_load: the config only needs plain YAML and must not construct Python objects.
        with open(hparams, encoding="utf-8") as f:
            hparams = yaml.safe_load(f)
        self.hparams = hparams

        if ckpt:
            # load_from_checkpoint is a classmethod; recent Lightning versions
            # raise TypeError when it is called on an instance.
            model_cls = model if isinstance(model, type) else type(model)
            self.model = model_cls.load_from_checkpoint(
                checkpoint_path=str(ckpt),
                map_location="cuda",
            )
            self.ckpt = ckpt
        else:
            self.model = model
            self.ckpt = "best"

        # --- Lightning Trainer
        self.trainer = trainer

        # --- DataModule
        self.datamodule = self._build_datamodule()

        self.save_dir = self._resolve_run_dir(ckpt) / hparams["inference"]
        print(f"save_dir: {self.save_dir}")

    def _resolve_run_dir(self, ckpt):
        """Return the directory under which outputs are saved.

        With a checkpoint this is its grandparent directory (e.g. the run
        directory when the file lives in ``<run>/checkpoints/``). Without one,
        the trainer's ``log_dir`` is used.
        """
        if ckpt:
            return Path(ckpt).parents[1]
        log_dir = self.trainer.log_dir
        if log_dir is None:
            raise ValueError(
                "No checkpoint given and the trainer has no log_dir to save outputs under.")
        return Path(log_dir)

    def _build_datamodule(self):
        return CustomDataModule(
            train_dir=self.hparams["train_data_path"],
            valid_dir=self.hparams["valid_data_path"],
            infer_dir=self.hparams["infer_data_path"],
            bench_dir=self.hparams["bench_data_path"],
            transform=DataTransform(image_size=self.hparams["image_size"]),
            batch_size=self.hparams["batch_size"],
            # ~90% of the CPU cores; os.cpu_count() may return None.
            num_workers=int((os.cpu_count() or 1) * 0.9),
        )

    def run(self):
        print("[INFO] Start Inferencing...")
        results = self.trainer.predict(
            model=self.model,
            datamodule=self.datamodule,
            ckpt_path=self.ckpt
        )
        save_images(results=results, save_dir=self.save_dir)
        print("[INFO] Inference Completed.")


## Benchmarker

In [None]:
import os

from lightning import Trainer
from pathlib import Path
from tqdm.auto import tqdm


In [None]:
class LightningBenchmarker:
    """Test a model on the benchmark split, print the metrics and save predicted images.

    Args:
        model: LightningModule instance (or class) to benchmark.
        trainer: Lightning Trainer used for testing and prediction.
        ckpt: Checkpoint to load. If falsy, ``model`` is used as given and
            ``ckpt_path="best"`` is passed to the trainer.
        hparams: Config dict with the data paths, ``image_size``, ``batch_size``
            and ``benchmark`` (output sub-directory name).
    """

    def __init__(self, model, trainer: Trainer, ckpt: Path, hparams: dict):
        self.hparams = hparams

        if ckpt:
            # load_from_checkpoint is a classmethod; recent Lightning versions
            # raise TypeError when it is called on an instance.
            model_cls = model if isinstance(model, type) else type(model)
            self.model = model_cls.load_from_checkpoint(
                checkpoint_path=ckpt,
                map_location="cuda",
            )
            self.ckpt = ckpt
        else:
            self.model = model
            self.ckpt = "best"

        # --- Lightning Trainer
        self.trainer = trainer

        # --- DataModule
        self.datamodule = self._build_datamodule()

        self.save_dir = self._resolve_run_dir(ckpt) / hparams["benchmark"]
        print(f"save_dir: {self.save_dir}")

        # --- Evaluation metrics
        self.metric = ImageQualityMetrics(device="cuda")
        self.metric.eval()

    def _resolve_run_dir(self, ckpt):
        """Return the directory under which outputs are saved.

        With a checkpoint this is its grandparent directory (e.g. the run
        directory when the file lives in ``<run>/checkpoints/``). Without one,
        the trainer's ``log_dir`` is used.
        """
        if ckpt:
            return Path(ckpt).parents[1]
        log_dir = self.trainer.log_dir
        if log_dir is None:
            raise ValueError(
                "No checkpoint given and the trainer has no log_dir to save outputs under.")
        return Path(log_dir)

    def _build_datamodule(self):
        datamodule = CustomDataModule(
            train_dir=self.hparams["train_data_path"],
            valid_dir=self.hparams["valid_data_path"],
            infer_dir=self.hparams["infer_data_path"],
            bench_dir=self.hparams["bench_data_path"],
            transform=DataTransform(image_size=self.hparams["image_size"]),
            batch_size=self.hparams["batch_size"],
            # ~90% of the CPU cores; os.cpu_count() may return None.
            num_workers=int((os.cpu_count() or 1) * 0.9),
        )
        datamodule.setup()  # Set up eagerly so the benchmark datasets are available.
        return datamodule

    def run(self):
        print("[INFO] Start benchmarking")

        results = self.trainer.test(
            model=self.model,
            datamodule=self.datamodule,
            ckpt_path=self.ckpt
        )
        print("\n[FINAL BENCHMARK RESULT]")
        # One metrics dict per benchmark dataloader.
        for idx, metrics in enumerate(results):
            print(f"[dataloader {idx}]")
            for k, v in metrics.items():
                print(f"{k}: {v:.4f}")

        # trainer.test returns metric dicts, not images. Run a predict pass over the
        # same benchmark loaders to obtain images to save. This assumes the
        # model's predict_step returns image tensors, as the inference path relies on.
        predictions = self.trainer.predict(
            model=self.model,
            dataloaders=self.datamodule.test_dataloader(),
            ckpt_path=self.ckpt,
        )
        save_images(results=predictions, save_dir=self.save_dir)
        print("[INFO] Benchmark Completed.")


# Model

## Losses

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [None]:
class L_col(nn.Module):
    """Color-constancy loss on per-image channel means.

    For channel means (mr, mg, mb) it returns
    ``sqrt((mr-mg)^4 + (mr-mb)^4 + (mb-mg)^4)`` with shape (B, 1, 1, 1).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        mean_rgb = torch.mean(
            input=x,
            dim=[2, 3],
            keepdim=True
        )  # (B, 3, 1, 1)
        mr, mg, mb = torch.split(
            tensor=mean_rgb,
            split_size_or_sections=1,
            dim=1
        )

        Drg = torch.pow(input=mr - mg, exponent=2)
        Drb = torch.pow(input=mr - mb, exponent=2)
        Dgb = torch.pow(input=mb - mg, exponent=2)

        # sqrt(Drg^2 + Drb^2 + Dgb^2) is the L2 norm of (Drg, Drb, Dgb). vector_norm
        # backpropagates a zero gradient when all channel means coincide (e.g. gray
        # images), whereas pow(0, 0.5) yields inf * 0 = NaN gradients.
        c = torch.linalg.vector_norm(
            torch.cat(tensors=[Drg, Drb, Dgb], dim=1),
            ord=2,
            dim=1,
            keepdim=True
        )
        return c


class L_spa(nn.Module):
    """Spatial-consistency loss between an original and an enhanced image.

    Both images are averaged over channels and 4x4 average-pooled. Left, right,
    up and down neighbour differences are then taken with fixed 3x3 kernels, and
    the squared mismatch of those differences is summed. The output has shape
    (B, 1, H/4, W/4).
    """

    def __init__(self):
        super().__init__()
        kernel_l = torch.tensor([
            [0, 0, 0],
            [-1, 1, 0],
            [0, 0, 0]
        ], dtype=torch.float32).unsqueeze(dim=0).unsqueeze(dim=0)
        kernel_r = torch.tensor([
            [0, 0, 0],
            [0, 1, -1],
            [0, 0, 0]
        ], dtype=torch.float32).unsqueeze(dim=0).unsqueeze(dim=0)
        kernel_u = torch.tensor([
            [0, -1, 0],
            [0, 1, 0],
            [0, 0, 0]
        ], dtype=torch.float32).unsqueeze(dim=0).unsqueeze(dim=0)
        kernel_d = torch.tensor([
            [0, 0, 0],
            [0, 1, 0],
            [0, -1, 0]
        ], dtype=torch.float32).unsqueeze(dim=0).unsqueeze(dim=0)

        # Frozen parameters: they move with the module (.to / Lightning device
        # placement) and keep their state_dict keys.
        self.weight_l = nn.Parameter(data=kernel_l, requires_grad=False)
        self.weight_r = nn.Parameter(data=kernel_r, requires_grad=False)
        self.weight_u = nn.Parameter(data=kernel_u, requires_grad=False)
        self.weight_d = nn.Parameter(data=kernel_d, requires_grad=False)
        self.pool = nn.AvgPool2d(kernel_size=4)

    def forward(self, org, enh):
        org_pool = self.pool(torch.mean(input=org, dim=1, keepdim=True))
        enh_pool = self.pool(torch.mean(input=enh, dim=1, keepdim=True))

        s = 0
        for weight in (self.weight_l, self.weight_r, self.weight_u, self.weight_d):
            # Match the input's device/dtype so the loss also works when the
            # module was not moved explicitly (a no-op when they already match).
            weight = weight.to(device=org_pool.device, dtype=org_pool.dtype)
            d_org = F.conv2d(input=org_pool, weight=weight, padding=1)
            d_enh = F.conv2d(input=enh_pool, weight=weight, padding=1)
            s = s + torch.pow(input=d_org - d_enh, exponent=2)
        return s


class L_exp(nn.Module):
    """Exposure loss: mean squared distance of patch-mean intensity from ``mean_val``.

    Args:
        patch_size: Side of the non-overlapping average-pooling patches.
        mean_val: Target mean intensity of each patch.
    """

    def __init__(self, patch_size=16, mean_val=0.8):
        super().__init__()
        self.pool = nn.AvgPool2d(kernel_size=patch_size)
        self.mean_val = mean_val

    def forward(self, x):
        x = torch.mean(input=x, dim=1, keepdim=True)
        mean = self.pool(x)

        # Subtracting a Python scalar broadcasts on whatever device ``x`` lives on.
        e = torch.mean(
            input=torch.pow(
                input=mean - self.mean_val,
                exponent=2
            )
        )
        return e


## Block

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


In [None]:
class RGB2YCrCb(nn.Module):
    """Linear RGB -> (Y, Cr, Cb) conversion; chroma channels are shifted by ``offset``.

    ``forward`` takes a (B, 3, H, W) tensor and returns three (B, 1, H, W) tensors.
    """

    def __init__(self, offset=0.5):
        super().__init__()
        self.offset = offset
        matrix = torch.tensor(
            data=[
                [0.299,  0.587,  0.114],   # Y
                [0.713, -0.713,  0.000],   # Cr
                [0.000, -0.564,  0.564],   # Cb
            ],
            dtype=torch.float32
        )
        self.register_buffer(name='weights', tensor=matrix)

    def forward(self, x):
        # Apply the 3x3 colour matrix per pixel.
        ycrcb = torch.einsum('bchw,oc->bohw', x, self.weights)
        luma, cr, cb = ycrcb.split(1, dim=1)
        return luma, cr + self.offset, cb + self.offset


class YCrCb2RGB(nn.Module):
    """Linear (Y, Cr, Cb) -> RGB conversion; ``offset`` is removed from the chroma first.

    ``forward`` takes three (B, 1, H, W) tensors and returns a (B, 3, H, W) tensor.
    """

    def __init__(self, offset=0.5,):
        super().__init__()
        self.offset = offset
        matrix = torch.tensor(
            data=[
                [1.000, 1.403, 0.000],
                [1.000, -0.714, -0.344],
                [1.000, 0.000, 1.773]
            ],
            dtype=torch.float32
        )
        self.register_buffer(name='weights', tensor=matrix)

    def forward(self, Y, Cr, Cb):
        centred = torch.cat(tensors=[Y, Cr - self.offset, Cb - self.offset], dim=1)
        return torch.einsum('bchw,oc->bohw', centred, self.weights)


class HomomorphicSeparation(nn.Module):
    """Split a single-channel image into illumination and detail via homomorphic filtering.

    The image is log-transformed and a Gaussian low-pass mask with a learnable
    cutoff is applied in the 2-D Fourier domain. Exponentiating the low- and
    high-pass parts gives ``illumination`` and ``detail``, whose product
    reconstructs ``x + eps`` up to numerical error.

    Args:
        size: Spatial size the frequency grid is built for (input must be size x size).
        init_cutoff: Initial cutoff in (0, 1), in normalized frequency units
            where 1 is the Nyquist frequency.
        eps: Added before the log to avoid log(0).
    """

    def __init__(self, size=256, init_cutoff=0.1, eps=1e-6):
        super().__init__()
        self.size = size
        self.eps = eps

        # Store the cutoff as a logit so sigmoid keeps it in (0, 1).
        init_p = float(init_cutoff)
        raw_init = torch.log(input=torch.tensor(data=init_p / (1.0 - init_p)))
        self.raw_cutoff = nn.Parameter(data=raw_init)

        # Radial frequency of every bin in torch.fft.fft2's unshifted layout
        # (DC at index [0, 0]). fftfreq gives cycles/sample in [-0.5, 0.5);
        # scaling by 2 maps Nyquist to 1. ``d`` is a persistent buffer, so
        # loading a checkpoint restores whatever grid was saved with it.
        coord = torch.fft.fftfreq(size) * 2
        y, x = torch.meshgrid(coord, coord, indexing='ij')
        d = torch.sqrt(input=x**2 + y**2)  # (H, W)
        self.register_buffer(name='d', tensor=d)

    def forward(self, x):
        B, C, H, W = x.shape
        cutoff = torch.sigmoid(input=self.raw_cutoff)

        # Gaussian low-pass mask, shared across the batch.
        mask2d = torch.exp(input=-(self.d ** 2) / (2 * (cutoff ** 2)))  # (H, W)
        mask = mask2d.unsqueeze(dim=0).expand(B, H, W)  # (B, H, W)

        # 1. Log transform (multiplicative model -> additive).
        x_log = torch.log(input=x + self.eps).squeeze(dim=1)  # (B, H, W) for C == 1

        # 2. FFT over the last two dimensions.
        x_fft = torch.fft.fft2(x_log)  # (B, H, W), complex

        # 3. Low-pass / high-pass split.
        low_fft = x_fft * mask
        high_fft = x_fft * (1 - mask)

        # 4. Inverse FFT; keep the real part (drops imaginary round-off).
        low_spatial = torch.real(
            input=torch.fft.ifft2(
                low_fft
            )
        )  # (B, H, W)
        high_spatial = torch.real(
            input=torch.fft.ifft2(
                high_fft
            )
        )  # (B, H, W)

        # 5. Undo the log and restore the channel dimension.
        illumination = torch.exp(input=low_spatial).unsqueeze(dim=1)  # (B, 1, H, W)
        detail = torch.exp(input=high_spatial).unsqueeze(dim=1)  # (B, 1, H, W)

        return illumination, detail


In [None]:
class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> SiLU) stages followed by Dropout2d(p=0.2).

    Spatial size is preserved (padding=1). The layers live in ``self.block`` at
    the same indices as before, so state_dict keys are unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        layers = []
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(
                    in_channels=stage_in,
                    out_channels=out_channels,
                    kernel_size=3,
                    padding=1,
                ),
                nn.BatchNorm2d(num_features=out_channels),
                nn.SiLU(),
            ]
        layers.append(nn.Dropout2d(p=0.2))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


class Down(nn.Module):
    """U-Net encoder step: 2x2 max-pool (halves H and W) followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        pool = nn.MaxPool2d(kernel_size=2)
        conv = DoubleConv(in_channels=in_channels, out_channels=out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """U-Net decoder step: 2x transposed-conv upsampling, pad to the skip size, concat, DoubleConv.

    Args:
        in_channels: Channels of the low-resolution input ``x1``; the DoubleConv
            also consumes this many channels (skip channels + ``out_channels``).
        out_channels: Channels after upsampling and after the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.up = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=2,
            stride=2,
        )
        self.conv = DoubleConv(in_channels=in_channels, out_channels=out_channels)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # Pad (as evenly as possible) so the upsampled map matches x2 spatially.
        dy = x2.size(2) - upsampled.size(2)
        dx = x2.size(3) - upsampled.size(3)
        top, left = dy // 2, dx // 2
        upsampled = F.pad(input=upsampled, pad=[left, dx - left, top, dy - top])
        merged = torch.cat(tensors=[x2, upsampled], dim=1)
        return self.conv(merged)


class UNet(nn.Module):
    """Four-level U-Net (64-128-256-512-1024 channels) with a sigmoid output head.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output map; values lie in (0, 1) because of
            the final Sigmoid.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.inc = DoubleConv(in_channels=self.in_channels, out_channels=64)
        self.down1 = Down(in_channels=64, out_channels=128)
        self.down2 = Down(in_channels=128, out_channels=256)
        self.down3 = Down(in_channels=256, out_channels=512)
        self.down4 = Down(in_channels=512, out_channels=1024)
        self.up1 = Up(in_channels=1024, out_channels=512)
        self.up2 = Up(in_channels=512, out_channels=256)
        self.up3 = Up(in_channels=256, out_channels=128)
        self.up4 = Up(in_channels=128, out_channels=64)
        # Output head: extra DoubleConv, 1x1 projection, sigmoid.
        self.outc = nn.Sequential(
            DoubleConv(in_channels=64, out_channels=64),
            nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x_i = self.inc(x)           # 64
        d_1 = self.down1(x_i)       # 128
        d_2 = self.down2(d_1)       # 256
        d_3 = self.down3(d_2)       # 512
        d_4 = self.down4(d_3)       # 1024
        u_4 = self.up1(d_4, d_3)    # 512
        u_3 = self.up2(u_4, d_2)    # 256
        u_2 = self.up3(u_3, d_1)    # 128
        u_1 = self.up4(u_2, x_i)    # 64
        x_o = self.outc(u_1)
        return x_o


In [None]:
class ResidualConv(nn.Module):
    """Residual block: (conv3x3 -> BN -> SiLU -> conv3x3 -> BN) + skip, then SiLU.

    The skip is a 1x1 convolution when the channel count changes and an identity
    otherwise. Spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        conv3x3 = dict(kernel_size=3, padding=1)
        self.block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **conv3x3),
            nn.BatchNorm2d(num_features=out_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, **conv3x3),
            nn.BatchNorm2d(num_features=out_channels),
        )

        self.residual_conv = (
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
            if in_channels != out_channels
            else nn.Identity()
        )

        # Attribute keeps its historical name ``relu`` but holds a SiLU.
        self.relu = nn.SiLU()

    def forward(self, x):
        skip = self.residual_conv(x)
        return self.relu(self.block(x) + skip)


class Resnet(nn.Module):
    """Stack of seven ResidualConv blocks with channel widths
    in -> 256 -> 256 -> 256 -> 512 -> 512 -> 512 -> out.

    The blocks are registered in order as ``rc1`` .. ``rc7``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        widths = [in_channels, 256, 256, 256, 512, 512, 512, out_channels]
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"rc{idx}", ResidualConv(in_channels=c_in, out_channels=c_out))

    def forward(self, x):
        for idx in range(1, 8):
            x = getattr(self, f"rc{idx}")(x)
        return x


## Model

In [None]:
import torch
import lightning as L

from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts
from transformers.optimization import get_cosine_schedule_with_warmup


In [None]:
class HomomorphicUnet(nn.Module):
    """Enhancement network built around a homomorphic split of the luma channel.

    Pipeline (see ``forward``): the RGB input is converted to Y/Cr/Cb. Y is
    split by ``HomomorphicSeparation`` into ``x_i`` and ``x_d``. ``x_i`` is
    refined by a UNet followed by a Resnet. The refined map ``n_i`` is
    multiplied by ``x_d`` and clamped to [0, 1] to form the new luma, which is
    recombined with the untouched Cr/Cb into RGB.

    Args:
        image_size: size forwarded to ``HomomorphicSeparation``.
        in_channels: channels of ``x_i``, i.e. the UNet input.
        out_channels: channels produced by the UNet and by the Resnet.
        offset: offset forwarded to the RGB<->YCrCb converters.
        init_cutoff: initial cutoff forwarded to ``HomomorphicSeparation``.
    """

    def __init__(self, image_size, in_channels, out_channels, offset, init_cutoff):
        super().__init__()
        self.image_size = image_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.offset = offset
        self.init_cutoff = init_cutoff

        self.rgb2ycrcb = RGB2YCrCb(
            offset=self.offset
        )
        self.homo_separate = HomomorphicSeparation(
            size=self.image_size,
            init_cutoff=self.init_cutoff
        )
        self.unet = UNet(
            in_channels=self.in_channels,
            out_channels=self.out_channels
        )
        # The Resnet consumes the UNet output, so its input width is the UNet's
        # out_channels. Wiring in_channels here only worked when
        # in_channels == out_channels.
        self.resnet = Resnet(
            in_channels=self.out_channels,
            out_channels=self.out_channels
        )
        self.ycrcb2rgb = YCrCb2RGB(
            offset=self.offset
        )

    def forward(self, x):
        """Enhance ``x`` (RGB) and return intermediates for logging/losses.

        Returns:
            tuple ``(enh_img, n_i, x_i, x_d)``: the enhanced RGB image, the
            refined map from UNet+Resnet, and the two outputs of the
            homomorphic separation of Y.
        """
        Y, Cr, Cb = self.rgb2ycrcb(x)
        x_i, x_d = self.homo_separate(Y)
        n_i = self.unet(x_i)
        n_i = self.resnet(n_i)
        # Recombine and keep the new luma in the valid [0, 1] range.
        n_Y = torch.clamp(input=n_i * x_d, min=0, max=1)
        enh_img = self.ycrcb2rgb(n_Y, Cr, Cb)
        return enh_img, n_i, x_i, x_d


## Lightning

In [None]:
class HomomorphicUnetLightning(L.LightningModule):
    """LightningModule training ``HomomorphicUnet`` with reference-free losses.

    Training and validation minimise a weighted sum of three terms:
    ``L_spa`` (computed from the enhanced image and the input), and ``L_col``
    and ``L_exp`` (computed from the enhanced image only). Testing logs the
    scores returned by ``ImageQualityMetrics.full`` for the enhanced image
    against the input.

    Args:
        hparams: mapping with at least ``image_size``, ``in_channels``,
            ``out_channels``, ``offset``, ``init_cutoff``, ``lambda_spa``,
            ``lambda_col``, ``lambda_exp``, ``lr`` and ``decay``. These
            optional keys configure ``CosineAnnealingWarmRestarts``:
            ``sched_t0`` (default 10), ``sched_t_mult`` (default 2) and
            ``sched_eta_min`` (default 1e-7).
    """

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

        self.model = HomomorphicUnet(
            image_size=hparams['image_size'],
            in_channels=hparams['in_channels'],
            out_channels=hparams['out_channels'],
            offset=hparams['offset'],
            init_cutoff=hparams["init_cutoff"],
        )

        self.spa_loss = L_spa()
        self.col_loss = L_col()
        self.exp_loss = L_exp()

        self.lambda_spa = hparams["lambda_spa"]
        self.lambda_col = hparams["lambda_col"]
        self.lambda_exp = hparams["lambda_exp"]

        # ImageQualityMetrics receives its device at construction time. Fall
        # back to CPU so the module can still be built without a GPU.
        metric_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.metric = ImageQualityMetrics(device=metric_device)
        self.metric.eval()

    def forward(self, x):
        """Return ``(enh_img, n_i, x_i, x_d)`` from the wrapped model."""
        return self.model(x)

    def _compute_losses(self, x, enh_img):
        """Return the weighted (spa, col, exp) losses and their sum."""
        loss_spa = self.lambda_spa * torch.mean(
            input=self.spa_loss(enh_img, x)
        )
        loss_col = self.lambda_col * torch.mean(
            input=self.col_loss(enh_img)
        )
        loss_exp = self.lambda_exp * torch.mean(
            input=self.exp_loss(enh_img)
        )
        total = loss_spa + loss_col + loss_exp
        return loss_spa, loss_col, loss_exp, total

    def _log_images(self, images):
        """Write ``{tag: tensor}`` image batches to the logger, if it supports it.

        Skips silently when there is no logger or its ``experiment`` has no
        ``add_images`` (e.g. a non-TensorBoard logger), instead of crashing.
        """
        experiment = getattr(self.logger, "experiment", None)
        if experiment is None or not hasattr(experiment, "add_images"):
            return
        for tag, tensor in images.items():
            experiment.add_images(tag, tensor, self.global_step)

    def training_step(self, batch, batch_idx):
        x = batch.to(self.device)
        enh_img, n_i, x_i, x_d = self(x)

        loss_spa, loss_col, loss_exp, total = self._compute_losses(x, enh_img)

        self.log_dict(dictionary={
            "train/1_spa": loss_spa,
            "train/2_col": loss_col,
            "train/3_exp": loss_exp,
            "train/4_tot": total,
        }, prog_bar=True)

        # Periodically dump the input, intermediates and output for inspection.
        if batch_idx % 100 == 0:
            self._log_images({
                "train/1_input": x,
                "train/2_x_i": x_i,
                "train/3_x_d": x_d,
                "train/4_enh_img": enh_img,
                "train/5_n_i": n_i,
            })
        return total

    def validation_step(self, batch, batch_idx):
        x = batch.to(self.device)
        enh_img, n_i, x_i, x_d = self(x)

        loss_spa, loss_col, loss_exp, total = self._compute_losses(x, enh_img)

        self.log_dict(dictionary={
            "valid/1_spa": loss_spa,
            "valid/2_col": loss_col,
            "valid/3_exp": loss_exp,
            "valid/4_tot": total,
        }, prog_bar=True)
        return total

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        x = batch.to(self.device)
        enh_img, n_i, x_i, x_d = self(x)

        # The input doubles as the "target": the benchmark images have no
        # separate ground truth in this pipeline.
        metrics = self.metric.full(preds=enh_img, targets=x)

        self.log_dict(dictionary={
            "bench/1_PSNR": metrics["PSNR"],
            "bench/2_SSIM": metrics["SSIM"],
            "bench/3_LPIPS": metrics["LPIPS"],
            "bench/4_NIQE": metrics["NIQE"],
            "bench/5_BRISQUE": metrics["BRISQUE"],
        }, prog_bar=True)
        return metrics

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        x = batch.to(self.device)
        enh_img, n_i, x_i, x_d = self(x)
        return enh_img

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            params=self.parameters(),
            lr=self.hparams['lr'],
            weight_decay=self.hparams['decay'],
        )

        scheduler = CosineAnnealingWarmRestarts(
            optimizer=optimizer,
            T_0=self.hparams.get("sched_t0", 10),             # epochs in the first cycle
            T_mult=self.hparams.get("sched_t_mult", 2),       # cycle-length multiplier after each restart
            eta_min=self.hparams.get("sched_eta_min", 1e-7),  # minimum learning rate
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",   # step the scheduler once per epoch
                "frequency": 1,
            }
        }


# main

In [None]:
import os
import random
import torch

from pathlib import Path


## hparams

In [None]:
def get_hparams():
    """Return the default experiment configuration as a new dict.

    Keys are grouped as model structure, loss weights, optimisation/training,
    dataset paths and logging. A fresh dict is built on every call, so callers
    may mutate the result freely.
    """
    # Model structure
    model_cfg = {
        "image_size": 256,
        "in_channels": 1,
        "out_channels": 1,
        "offset": 0.5,
        "init_cutoff": 0.1,
    }
    # Loss weights (as defined in losses.py)
    loss_weights = {
        "lambda_col": 10.0,
        "lambda_exp": 1,
        "lambda_spa": 100.0,
    }
    # Optimisation and training settings
    optim_cfg = {
        "lr": 1e-4,
        "decay": 1e-5,
        "epochs": 10,
        "batch_size": 16,
        "seed": 290,
    }
    # Dataset locations
    data_paths = {
        "train_data_path": "data/1_train",
        "valid_data_path": "data/2_valid",
        "bench_data_path": "data/3_bench",
        "infer_data_path": "data/4_infer",
    }
    # Logging / output directories
    logging_cfg = {
        "log_dir": "./runs/HomomorphicUResnet",
        "experiment_name": "cutoff_parameter",
        "inference": "inference",
        "benchmark": "benchmark",
    }
    return {**model_cfg, **loss_weights, **optim_cfg, **data_paths, **logging_cfg}


## main

In [None]:
def main():
    """Build the Lightning model from the default hparams and run training.

    Also publishes ``hparams``, ``model_class`` and ``lightning_trainer`` as
    module-level globals so later notebook cells can reuse them.
    """
    global hparams, model_class, lightning_trainer

    hparams = get_hparams()
    model_class = HomomorphicUnetLightning

    print("[RUNNING] Trainer...")
    runner = LightningTrainer(
        model=model_class(hparams=hparams),
        hparams=hparams,
    )
    # Keep a handle on the underlying trainer for the inference cells.
    lightning_trainer = runner.trainer
    runner.run()


## train

In [None]:
# Launch training. main() also sets the globals `hparams`, `model_class` and
# `lightning_trainer`, which the inference/benchmark cells below read.
main()


[RUNNING] Trainer...


Seed set to 290
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


[INFO] Start Training...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Output()

RuntimeError: NVML_SUCCESS == DriverAPI::get()->nvmlInit_v2_() INTERNAL ASSERT FAILED at "/pytorch/c10/cuda/CUDACachingAllocator.cpp":983, please report a bug to PyTorch. 

## inference

In [None]:
# Run inference for each "best*" checkpoint saved under the trainer's log dir.
# The globs are materialised as sorted lists: a bare Path.glob() generator is
# exhausted after one pass (so any later reuse iterates nothing), and raw glob
# order is filesystem-dependent.
# NOTE(review): this rebinds the global `hparams` (set by main()) to a list of
# hparams.yaml paths.
path = Path(f"{lightning_trainer.log_dir}")
ckpts = sorted(path.glob(pattern="checkpoints/best*.ckpt"))
hparams = sorted(path.glob(pattern="hparams.yaml"))

if not ckpts or not hparams:
    print(f"[WARN] No best*.ckpt / hparams.yaml found under {path}")

for ckpt, hparam in zip(ckpts, hparams):
    print("[RUNNING] Inferencer...")
    inferencer = LightningInferencer(
        model=model_class,
        trainer=lightning_trainer,
        ckpt=ckpt,
        hparams=hparam,
    )
    inferencer.run()


[RUNNING] Inferencer...
save_dir: runs/HomomorphicUnet/base/version_0/inference
[INFO] Start training...


Restoring states from the checkpoint path at runs/HomomorphicUnet/base/version_0/checkpoints/best-epoch=03.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loaded model weights from the checkpoint at runs/HomomorphicUnet/base/version_0/checkpoints/best-epoch=03.ckpt


Output()

[INFO] Inference completed.


In [None]:
# Benchmark each "best*" checkpoint. The checkpoint / hparams lists are rebuilt
# here rather than reusing iterators from an earlier cell: a Path.glob()
# generator can be consumed only once, so reusing an exhausted one turned this
# loop into a silent no-op.
bench_root = Path(f"{lightning_trainer.log_dir}")
bench_ckpts = sorted(bench_root.glob(pattern="checkpoints/best*.ckpt"))
bench_hparams = sorted(bench_root.glob(pattern="hparams.yaml"))

if not bench_ckpts or not bench_hparams:
    print(f"[WARN] No best*.ckpt / hparams.yaml found under {bench_root}")

for ckpt, hparam in zip(bench_ckpts, bench_hparams):
    print("[RUNNING] Benchmarker...")
    benchmarker = LightningBenchmarker(
        model=model_class,
        trainer=lightning_trainer,
        ckpt=ckpt,
        hparams=hparam,
    )
    benchmarker.run()
