# data

## dataloader

In [None]:
import lightning as L

from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader, ConcatDataset


In [None]:
class CustomDataset(Dataset):
    """Dataset that yields every image file located directly inside a folder.

    Args:
        path: Folder to scan. Only direct children whose name matches
            ``*.*`` (i.e. contains a dot) are collected.
        transform: Callable applied to the RGB ``PIL.Image`` before it is
            returned from ``__getitem__``.
    """

    def __init__(self, path, transform):
        super().__init__()
        self.path = Path(path)
        self.transform = transform
        # glob() yields entries in arbitrary file-system order; sorting keeps
        # index -> file stable across runs and machines. is_file() drops
        # directories whose name happens to contain a dot.
        self.data = sorted(
            entry for entry in self.path.glob('*.*') if entry.is_file()
        )

    def __len__(self):
        """Return the number of collected files."""
        return len(self.data)

    def __getitem__(self, index):
        """Load the file at ``index`` as RGB and return ``transform(image)``."""
        # convert() decodes the pixels into a new image object, so the
        # underlying file handle can be closed as soon as the block exits.
        with Image.open(fp=self.data[index]) as handle:
            image = handle.convert(mode="RGB")
        return self.transform(image)


In [None]:
class CustomDataModule(L.LightningDataModule):
    """Lightning data module that builds one ``CustomDataset`` per sub-folder.

    Each of the four root directories is expected to contain sub-folders;
    every sub-folder becomes its own dataset. Training and validation
    concatenate their datasets into a single loader, while test (benchmark)
    and predict (inference) return one loader per sub-folder.

    Args:
        train_dir: Root directory of the training sub-folders.
        valid_dir: Root directory of the validation sub-folders.
        infer_dir: Root directory of the inference (predict) sub-folders.
        bench_dir: Root directory of the benchmark (test) sub-folders.
        transform: Callable applied to every loaded image.
        batch_size: Batch size used by every loader.
        num_workers: Worker process count used by every loader.
    """

    def __init__(self, train_dir, valid_dir, infer_dir, bench_dir, transform, batch_size=32, num_workers=4):
        super().__init__()
        self.train_dir = Path(train_dir)
        self.valid_dir = Path(valid_dir)
        self.infer_dir = Path(infer_dir)
        self.bench_dir = Path(bench_dir)

        self.transform = transform
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Build the dataset lists for all four splits (``stage`` is ignored)."""
        self.train_datasets = self._set_dataset(path=self.train_dir)
        self.valid_datasets = self._set_dataset(path=self.valid_dir)
        self.bench_datasets = self._set_dataset(path=self.bench_dir)
        self.infer_datasets = self._set_dataset(path=self.infer_dir)

    def _set_dataset(self, path):
        """Return one ``CustomDataset`` per sub-directory of ``path``."""
        # iterdir() order is arbitrary; sorting keeps the position of each
        # folder (and therefore each test/predict dataloader index) stable.
        return [
            CustomDataset(path=folder, transform=self.transform)
            for folder in sorted(path.iterdir())
            if folder.is_dir()
        ]

    def _make_loader(self, dataset, shuffle):
        """Wrap a single dataset in a ``DataLoader`` with the shared settings."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True
        )

    def _set_dataloader(self, datasets, concat=False, shuffle=False):
        """Create loader(s) for ``datasets``.

        Returns a single loader over the concatenation when ``concat`` is
        true, otherwise a list with one loader per dataset.
        """
        if concat:
            return self._make_loader(
                dataset=ConcatDataset(datasets=datasets),
                shuffle=shuffle,
            )
        return [
            self._make_loader(dataset=dataset, shuffle=shuffle)
            for dataset in datasets
        ]

    def train_dataloader(self):
        """Single shuffled loader over all training sub-folders."""
        return self._set_dataloader(datasets=self.train_datasets, concat=True, shuffle=True)

    def val_dataloader(self):
        """Single unshuffled loader over all validation sub-folders."""
        return self._set_dataloader(datasets=self.valid_datasets, concat=True)

    def test_dataloader(self):
        """One loader per benchmark sub-folder."""
        return self._set_dataloader(datasets=self.bench_datasets)

    def predict_dataloader(self):
        """One loader per inference sub-folder."""
        return self._set_dataloader(datasets=self.infer_datasets)


## utils

In [None]:
from torchvision import transforms


In [None]:
class DataTransform:
    """Resize an image to a square and convert it to a float tensor.

    Args:
        image_size: Side length (pixels) of the square output.
    """

    def __init__(self, image_size=256):
        self.image_size = image_size

        self.transform = self._build_transform()

    @staticmethod
    def _to_float(tensor):
        """Cast ``tensor`` to float32.

        Kept as a named function rather than an inline ``lambda`` because
        lambdas cannot be pickled, which breaks DataLoader worker processes
        on platforms that start them with ``spawn``.
        """
        return tensor.float()

    def _build_transform(self):
        """Compose Resize -> ToTensor -> float cast."""
        base = [
            transforms.Resize(size=(self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambd=self._to_float),
        ]

        return transforms.Compose(transforms=base)

    def __call__(self, image):
        """Apply the composed pipeline to ``image``."""
        return self.transform(img=image)


# utils

## metrics

In [None]:
import cv2
import math
import numpy as np
import torch
import torch.nn as nn

from scipy.ndimage import convolve
from scipy.special import gamma
from torchmetrics.image import (
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
    LearnedPerceptualImagePatchSimilarity,
)


In [None]:
class ImageQualityMetrics(nn.Module):
    """Bundle of reference (PSNR/SSIM/LPIPS) and no-reference (NIQE/BRISQUE) metrics.

    Args:
        device: Device the torchmetrics modules are moved to and on which
            ``forward`` evaluates its inputs.
        data_range: Value range passed to PSNR and SSIM.

    The NIQE pristine statistics are read from ``utils/files/niqe_params.npz``
    and the BRISQUE model/range from ``utils/files/brisque_*.yaml``; both
    paths are relative to the current working directory.
    """

    def __init__(self, device="cuda", data_range=1.0):
        super().__init__()

        self.device_type = device

        # reference-based metrics
        self.psnr = PeakSignalNoiseRatio(
            data_range=data_range).to(device=device)
        self.ssim = StructuralSimilarityIndexMeasure(
            data_range=data_range).to(device=device)
        self.lpips = LearnedPerceptualImagePatchSimilarity(
            net_type='squeeze').to(device=device)

        # Pristine-image statistics used by NIQE.
        # NOTE(review): allow_pickle=True can execute code embedded in the
        # archive; only load a trusted parameter file.
        self.niqe_stats = np.load(
            file="utils/files/niqe_params.npz", allow_pickle=True)
        self.mu_pris_param = self.niqe_stats['mu_pris_param']
        self.cov_pris_param = self.niqe_stats['cov_pris_param']
        self.gaussian_window = self.niqe_stats['gaussian_window']

    def forward(self, preds: torch.Tensor, targets: torch.Tensor):
        """Return PSNR, SSIM and LPIPS of ``preds`` against ``targets`` as floats."""
        preds = preds.to(device=self.device_type)
        targets = targets.to(device=self.device_type)

        return {
            "PSNR": self.psnr(preds, targets).item(),
            "SSIM": self.ssim(preds, targets).item(),
            "LPIPS": self.lpips(preds, targets).squeeze().mean().item(),
        }

    def no_ref(self, preds: torch.Tensor):
        """Return batch-averaged NIQE and BRISQUE for ``preds``.

        ``preds`` is iterated along its first axis and each item is
        transposed from (C, H, W) to (H, W, C); values are clipped to [0, 1].
        """
        # Both metrics run on the CPU through NumPy/OpenCV, so the tensor is
        # brought to host memory directly (cast to float32 so half-precision
        # predictions are accepted by OpenCV).
        preds_np = preds.detach().float().cpu().numpy()
        preds_np = np.clip(preds_np, 0, 1)

        niqe_list = []
        brisque_list = []

        for img in preds_np:
            # (C, H, W) -> (H, W, C)
            img_np = np.transpose(img, (1, 2, 0))
            img_np_uint8 = (img_np * 255).astype(np.uint8)

            niqe_list.append(self._compute_niqe(img_np=img_np))
            brisque_list.append(self._compute_brisque(img=img_np_uint8))

        # float() takes no keyword arguments, so the value is passed positionally.
        return {
            "NIQE": float(np.mean(niqe_list)),
            "BRISQUE": float(np.mean(brisque_list)),
        }

    def full(self, preds, targets):
        """Return the union of the reference and no-reference metric dicts."""
        ref_metrics = self.forward(preds=preds, targets=targets)
        no_ref_metrics = self.no_ref(preds=preds)
        return {**ref_metrics, **no_ref_metrics}

    # --------------------------
    # Custom no-reference metrics
    # --------------------------

    def _compute_niqe(self, img_np):
        """Compute NIQE for one (H, W, 3) RGB image with values in [0, 1]."""
        img = img_np.astype(np.float32)
        # _niqe works on a 0-255 intensity scale (it divides by 255 before
        # resizing). Rounding the [0, 1] image directly would collapse it to
        # a binary mask, so rescale first.
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) * 255.0
        gray = gray.round()
        return self._niqe(img=gray)

    def _compute_brisque(self, img):
        """Compute the BRISQUE score of one (H, W, 3) uint8 RGB image."""
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        brisque_score = cv2.quality.QualityBRISQUE_compute(
            img, "utils/files/brisque_model.yaml", "utils/files/brisque_range.yaml")
        # OpenCV returns a multi-element scalar whose first entry is the
        # score; taking element 0 avoids averaging in the padding zeros.
        return float(np.asarray(brisque_score, dtype=np.float64).ravel()[0])

    def _niqe(self, img, block_size_h=96, block_size_w=96):
        """Compute NIQE for a 2-D grayscale image on a 0-255 scale.

        The image is cropped to a whole number of blocks, features are
        extracted per block at two scales, and the distance between the
        resulting Gaussian model and the pristine model is returned.

        Raises:
            ValueError: If the image is smaller than one block.
        """
        assert img.ndim == 2
        h, w = img.shape
        num_block_h = math.floor(h / block_size_h)
        num_block_w = math.floor(w / block_size_w)
        if num_block_h == 0 or num_block_w == 0:
            raise ValueError(
                f"NIQE needs an image of at least {block_size_h}x{block_size_w} "
                f"pixels, got {h}x{w}")
        img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]

        distparam = []
        for scale in (1, 2):
            # Local mean / standard deviation for MSCN normalisation.
            mu = convolve(
                input=img, weights=self.gaussian_window, mode='nearest')
            sigma = np.sqrt(np.abs(convolve(
                input=np.square(img), weights=self.gaussian_window, mode='nearest') - np.square(mu)))
            img_nomalized = (img - mu) / (sigma + 1)

            feat = []
            for idx_w in range(num_block_w):
                for idx_h in range(num_block_h):
                    block = img_nomalized[
                        idx_h * block_size_h // scale:(idx_h + 1) * block_size_h // scale,
                        idx_w * block_size_w // scale:(idx_w + 1) * block_size_w // scale
                    ]
                    feat.append(self._compute_feature(block=block))

            distparam.append(np.array(feat))

            if scale == 1:
                # Second pass runs on the half-resolution image.
                img = cv2.resize(img / 255., dsize=(0, 0), fx=0.5,
                                 fy=0.5, interpolation=cv2.INTER_CUBIC) * 255.

        distparam = np.concatenate(distparam, axis=1)
        mu_distparam = np.nanmean(distparam, axis=0)
        distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
        cov_distparam = np.cov(distparam_no_nan, rowvar=False)

        invcov_param = np.linalg.pinv(
            (self.cov_pris_param + cov_distparam) / 2)
        diff = self.mu_pris_param - mu_distparam
        quality = np.matmul(np.matmul(diff, invcov_param), np.transpose(diff))
        # squeeze() collapses a possible (1, 1) result before the scalar cast.
        return float(np.squeeze(np.sqrt(quality)))

    def _compute_feature(self, block):
        """Return the 18 AGGD features of one normalised block.

        Two features come from the block itself and four from each of the
        four shifted pairwise products.
        """
        # Lookup table for the shape parameter; built once per block instead
        # of once per AGGD fit.
        gam = np.arange(0.2, 10.001, 0.001)
        gam_reciprocal = np.reciprocal(gam)
        r_gam = np.square(gamma(gam_reciprocal * 2)) / (
            gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))

        def estimate_aggd_param(block):
            """Fit (alpha, beta_left, beta_right) by moment matching."""
            block = block.flatten()

            left_std = np.sqrt(np.mean(block[block < 0]**2))
            right_std = np.sqrt(np.mean(block[block > 0]**2))
            gammahat = left_std / right_std
            rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
            rhatnorm = (rhat * (gammahat**3 + 1) *
                        (gammahat + 1)) / ((gammahat**2 + 1)**2)
            array_position = np.argmin((r_gam - rhatnorm)**2)
            alpha = gam[array_position]
            beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
            beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
            return (alpha, beta_l, beta_r)

        feat = []
        alpha, beta_l, beta_r = estimate_aggd_param(block=block)
        feat.extend([alpha, (beta_l + beta_r) / 2])
        shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
        for shift in shifts:
            shifted = np.roll(block, shift, axis=(0, 1))
            alpha, beta_l, beta_r = estimate_aggd_param(block=block * shifted)
            mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
            feat.extend([alpha, mean, beta_l, beta_r])
        return feat


## utils

In [None]:
from pathlib import Path
from torchvision.utils import save_image
from torchinfo import summary


In [None]:
def make_dirs(path: str | Path):
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    return path


def print_metrics(metrics: dict, prefix: str = ""):
    """Print every ``name: value`` pair of ``metrics`` on its own line.

    Values are formatted with four decimals and each line starts with ``prefix``.
    """
    for name in metrics:
        value = metrics[name]
        line = f"{prefix}{name}: {value:.4f}"
        print(line)


def save_images(results, save_dir, prefix="infer", ext="png"):
    """Write every element of a nested result collection to disk as an image grid.

    ``results`` is iterated twice: each outer element gets its own
    ``<save_dir>/batch<k>`` folder (k starts at 1) and each inner element is
    saved there as ``<prefix>_<index>.<ext>`` with a zero-padded 4-digit index.
    """
    for group_number, datasets in enumerate(results, start=1):
        save_path = make_dirs(path=f"{save_dir}/batch{group_number}")
        for batch_index, batch in enumerate(datasets):
            target = save_path / f"{prefix}_{batch_index:04d}.{ext}"
            save_image(
                tensor=batch,
                fp=target,
                nrow=8,
                padding=2,
                normalize=True,
                value_range=(0, 1),
            )


def count_parameters(model):
    """Return the total element count of the parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def summarize_model(model, input_size):
    """Return ``summary`` of ``model`` for ``input_size``.

    The report is limited to depth 3 and shows the input size, output size
    and parameter count columns.
    """
    columns = ["input_size", "output_size", "num_params"]
    return summary(
        model=model,
        input_size=input_size,
        depth=3,
        col_names=columns,
    )


# engine

## trainer

In [None]:
import os
from lightning import Trainer, seed_everything
from lightning.pytorch.callbacks import *
from lightning.pytorch.loggers import TensorBoardLogger


In [None]:
class LightningTrainer:
    """Wire a model, data module, logger and callbacks into a Lightning ``Trainer``.

    Args:
        model: Model instance to train. When ``ckpt`` is given,
            ``model.load_from_checkpoint`` is called on it and the result is
            trained instead.
        hparams: Settings dict. Keys read here: ``seed``, ``epochs``,
            ``image_size``, ``batch_size``, ``log_dir``, ``experiment_name``
            and the four ``*_data_path`` entries.
        ckpt: Optional checkpoint to load the model from.
        transform: Optional image transform for the data module. Defaults to
            ``DataTransform(image_size=hparams["image_size"])``.
    """

    def __init__(self, model, hparams: dict, ckpt: Path = None, transform=None):
        self.hparams = hparams
        # The supplied transform used to be stored but never handed to the
        # data module; it is now the one actually used for loading.
        if transform is not None:
            self.transform = transform
        else:
            self.transform = DataTransform(image_size=hparams["image_size"])
        seed_everything(seed=hparams["seed"], workers=True)

        # --- model
        # Always define self.ckpt so the attribute exists in both branches.
        self.ckpt = ckpt
        if ckpt:
            self.model = model.load_from_checkpoint(
                checkpoint_path=ckpt,
                map_location="cuda",
            )
        else:
            self.model = model

        # --- data module
        self.datamodule = self._build_datamodule()

        # --- logging
        self.logger = self._build_logger()

        # --- callbacks
        self.callbacks = self._build_callbacks()

        # --- Lightning Trainer
        self.trainer = Trainer(
            max_epochs=hparams["epochs"],
            accelerator="gpu",
            devices=1,
            precision="16",
            logger=self.logger,
            callbacks=self.callbacks,
            log_every_n_steps=5,
        )

    def _build_datamodule(self):
        """Create the data module from the paths and sizes in ``hparams``."""
        # os.cpu_count() may return None when the count cannot be determined.
        cpu_total = os.cpu_count() or 1
        return CustomDataModule(
            train_dir=self.hparams["train_data_path"],
            valid_dir=self.hparams["valid_data_path"],
            infer_dir=self.hparams["infer_data_path"],
            bench_dir=self.hparams["bench_data_path"],
            transform=self.transform,
            batch_size=self.hparams["batch_size"],
            num_workers=int(cpu_total * 0.9),
        )

    def _build_logger(self):
        """Create the TensorBoard logger under ``log_dir/experiment_name``."""
        return TensorBoardLogger(
            save_dir=self.hparams["log_dir"],
            name=self.hparams["experiment_name"]
        )

    def _build_callbacks(self):
        """Create checkpointing, early-stopping, LR-monitor and progress callbacks."""
        return [
            # Best checkpoint according to the monitored validation value.
            ModelCheckpoint(
                monitor="valid/06_total",
                save_top_k=1,
                mode="min",
                filename="best-{epoch:02d}",
            ),
            # Periodic snapshot every 10 epochs.
            ModelCheckpoint(
                every_n_epochs=10,
                save_top_k=-1,  # keep all of them
                filename="epoch-{epoch:02d}",
            ),
            EarlyStopping(
                monitor="valid/06_total",
                patience=10,
                mode="min",
                verbose=True,
            ),
            LearningRateMonitor(logging_interval="step"),
            RichProgressBar(),
        ]

    def run(self):
        """Fit the model on the data module."""
        print("[INFO] Start training...")
        self.trainer.fit(
            model=self.model,
            datamodule=self.datamodule
        )
        print("[INFO] Training completed!")


## validater

In [None]:
import os

from lightning import Trainer
from pathlib import Path
from tqdm.auto import tqdm


In [None]:
class LightningValidater:
    """Run the validation loop of an existing ``Trainer`` on the project data.

    Args:
        model: Model instance. When ``ckpt`` is given,
            ``model.load_from_checkpoint`` is called on it.
        trainer: Trainer used to run validation.
        ckpt: Checkpoint path, or a falsy value to use the trainer's
            ``"best"`` checkpoint.
        hparams: Settings dict (``image_size``, ``batch_size`` and the four
            ``*_data_path`` entries are read).
    """

    def __init__(self, model, trainer: Trainer, ckpt: Path, hparams: dict):
        self.hparams = hparams

        # --- model
        if ckpt:
            self.model = model.load_from_checkpoint(
                checkpoint_path=ckpt,
                map_location="cuda",
            )
            self.ckpt = ckpt
        else:
            self.model = model
            self.ckpt = "best"

        # --- Lightning Trainer
        self.trainer = trainer

        # --- data module
        self.datamodule = self._build_datamodule()

    def _build_datamodule(self):
        """Create the data module from the paths and sizes in ``hparams``."""
        # os.cpu_count() may return None when the count cannot be determined.
        cpu_total = os.cpu_count() or 1
        return CustomDataModule(
            train_dir=self.hparams["train_data_path"],
            valid_dir=self.hparams["valid_data_path"],
            infer_dir=self.hparams["infer_data_path"],
            bench_dir=self.hparams["bench_data_path"],
            transform=DataTransform(image_size=self.hparams["image_size"]),
            batch_size=self.hparams["batch_size"],
            num_workers=int(cpu_total * 0.9),
        )

    def run(self):
        """Validate the model and print each returned result entry."""
        print("[INFO] Start validating...")
        results = self.trainer.validate(
            model=self.model,
            datamodule=self.datamodule,
            ckpt_path=self.ckpt
        )
        print("[VALIDATION RESULT]")
        # Plain loop: a progress bar around print() only interleaves bar
        # redraws with the result lines.
        for res in results:
            print(res)


## inferencer

In [None]:
import os
import yaml

from lightning import Trainer
from pathlib import Path


In [None]:
class LightningInferencer:
    """Run prediction with an existing ``Trainer`` and save the outputs as images.

    Args:
        model: Model instance. When ``ckpt`` is given,
            ``model.load_from_checkpoint`` is called on it.
        trainer: Trainer used to run prediction.
        ckpt: Checkpoint path, or a falsy value to use the trainer's
            ``"best"`` checkpoint.
        hparams: Path to a YAML file holding the settings. Keys read:
            ``inference``, ``image_size``, ``batch_size`` and the four
            ``*_data_path`` entries.

    Outputs go to ``<two levels above ckpt>/<hparams["inference"]>``; when no
    checkpoint is given they go to ``hparams["inference"]`` relative to the
    current working directory.
    """

    def __init__(self, model, trainer: Trainer, ckpt: Path, hparams: Path):
        # NOTE(review): yaml.load with FullLoader is more permissive than
        # yaml.safe_load; only feed it configuration files you trust.
        with open(file=hparams, encoding="utf-8") as f:
            hparams = yaml.load(stream=f, Loader=yaml.FullLoader)
        self.hparams = hparams

        if ckpt:
            self.model = model.load_from_checkpoint(
                checkpoint_path=str(ckpt),
                map_location="cuda",
            )
            self.ckpt = ckpt
        else:
            self.model = model
            self.ckpt = "best"

        # --- Lightning Trainer
        self.trainer = trainer

        # --- data module
        self.datamodule = self._build_datamodule()

        # ckpt may be None (handled above) or a plain string, so it cannot be
        # assumed to expose .parents.
        if ckpt:
            self.save_dir = Path(ckpt).parents[1] / hparams["inference"]
        else:
            self.save_dir = Path(hparams["inference"])
        print(f"save_dir: {self.save_dir}")

    def _build_datamodule(self):
        """Create the data module from the paths and sizes in ``hparams``."""
        # os.cpu_count() may return None when the count cannot be determined.
        cpu_total = os.cpu_count() or 1
        return CustomDataModule(
            train_dir=self.hparams["train_data_path"],
            valid_dir=self.hparams["valid_data_path"],
            infer_dir=self.hparams["infer_data_path"],
            bench_dir=self.hparams["bench_data_path"],
            transform=DataTransform(image_size=self.hparams["image_size"]),
            batch_size=self.hparams["batch_size"],
            num_workers=int(cpu_total * 0.9),
        )

    def run(self):
        """Predict on the inference data and write the results to ``save_dir``."""
        print("[INFO] Start inference...")
        results = self.trainer.predict(
            model=self.model,
            datamodule=self.datamodule,
            ckpt_path=self.ckpt
        )
        save_images(results=results, save_dir=self.save_dir)
        print("[INFO] Inference completed.")


## benchmarker

In [None]:
import os

from lightning import Trainer
from pathlib import Path
from tqdm.auto import tqdm


In [None]:
class LightningBenchmarker:
    """Run the test loop of an existing ``Trainer`` on the benchmark data.

    Args:
        model: Model instance. When ``ckpt`` is given,
            ``model.load_from_checkpoint`` is called on it.
        trainer: Trainer used to run the test loop.
        ckpt: Checkpoint path, or a falsy value to use the trainer's
            ``"best"`` checkpoint.
        hparams: Settings dict (``image_size``, ``batch_size`` and the four
            ``*_data_path`` entries are read).
    """

    def __init__(self, model, trainer: Trainer, ckpt: Path, hparams: dict):
        self.hparams = hparams

        if ckpt:
            self.model = model.load_from_checkpoint(
                checkpoint_path=ckpt,
                map_location="cuda",
            )
            self.ckpt = ckpt
        else:
            self.model = model
            self.ckpt = "best"

        # --- Lightning Trainer
        self.trainer = trainer

        # --- data module
        self.datamodule = self._build_datamodule()

        # --- evaluation metrics
        self.metric = ImageQualityMetrics(device="cuda")
        self.metric.eval()

    def _build_datamodule(self):
        """Create the data module and run ``setup()`` on it immediately."""
        # os.cpu_count() may return None when the count cannot be determined.
        cpu_total = os.cpu_count() or 1
        datamodule = CustomDataModule(
            train_dir=self.hparams["train_data_path"],
            valid_dir=self.hparams["valid_data_path"],
            infer_dir=self.hparams["infer_data_path"],
            bench_dir=self.hparams["bench_data_path"],
            transform=DataTransform(image_size=self.hparams["image_size"]),
            batch_size=self.hparams["batch_size"],
            num_workers=int(cpu_total * 0.9),
        )
        datamodule.setup()  # set up in advance so the benchmark datasets are available
        return datamodule

    def run(self):
        """Run the test loop and print every logged metric."""
        print("[INFO] Start benchmarking image quality metrics...")

        outputs = self.trainer.test(
            model=self.model,
            datamodule=self.datamodule,
            ckpt_path=self.ckpt
        )
        print("outputs", outputs)
        print("\n[FINAL BENCHMARK RESULT]")
        # Trainer.test returns a list with one metrics dict per test
        # dataloader, so .items() has to be called on each entry rather than
        # on the list itself.
        for metrics in outputs:
            for k, v in metrics.items():
                print(f"{k}: {v:.4f}")


# model

## block

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


### study

In [None]:
class RGB2YCrCb(nn.Module):
    """Convert an RGB batch (B, 3, H, W) into separate Y, Cr and Cb planes.

    Uses ``Y = 0.299 R + 0.587 G + 0.114 B``, ``Cr = 0.713 (R - Y) + offset``
    and ``Cb = 0.564 (B - Y) + offset``, which is the forward transform
    matching the coefficients used by ``YCrCb2RGB``.

    Args:
        offset: Value added to both chroma planes.
    """

    def __init__(self, offset=0.5, ):
        super().__init__()
        self.offset = offset

        luma = torch.tensor([0.299, 0.587, 0.114], dtype=torch.float32)
        red = torch.tensor([1.0, 0.0, 0.0], dtype=torch.float32)
        blue = torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32)
        # The chroma rows are scaled differences against the luma row
        # (R - Y, B - Y). The previous rows encoded (R - G) and (B - G),
        # so a round trip through YCrCb2RGB did not return the input.
        self.register_buffer(
            name='weights',
            tensor=torch.stack(
                [
                    luma,                    # Y
                    0.713 * (red - luma),    # Cr
                    0.564 * (blue - luma),   # Cb
                ]
            )
        )

    def forward(self, x):
        """Return ``(Y, Cr, Cb)``, each of shape (B, 1, H, W)."""
        out = torch.einsum('bchw,oc->bohw', x, self.weights)

        Y = out[:, 0:1, :, :]  # (B,1,H,W)
        Cr = out[:, 1:2, :, :] + self.offset  # (B,1,H,W)
        Cb = out[:, 2:3, :, :] + self.offset  # (B,1,H,W)
        return Y, Cr, Cb


class YCrCb2RGB(nn.Module):
    """Recombine Y, Cr and Cb planes into a 3-channel RGB batch.

    Args:
        offset: Value subtracted from both chroma planes before mixing.
    """

    def __init__(self, offset=0.5,):
        super().__init__()
        self.offset = offset

        # Rows produce R, G, B; columns weight Y, Cr, Cb.
        mixing = torch.tensor(
            data=[
                [1.000, 1.403, 0.000],
                [1.000, -0.714, -0.344],
                [1.000, 0.000, 1.773]
            ],
            dtype=torch.float32
        )
        self.register_buffer(name='weights', tensor=mixing)

    def forward(self, Y, Cr, Cb):
        """Return the RGB tensor (B, 3, H, W) for three (B, 1, H, W) planes."""
        centered = torch.cat(
            tensors=[Y, Cr - self.offset, Cb - self.offset],
            dim=1,
        )  # (B,3,H,W)
        return torch.einsum('bchw,oc->bohw', centered, self.weights)


class HomomorphicSeparation(nn.Module):
    """Split an image into low-frequency and high-frequency factors.

    The input is moved to the log domain, filtered in the Fourier domain with
    a Gaussian low-pass mask (and its complement), transformed back and
    exponentiated, so ``illumination * detail`` reproduces ``x + eps``.

    Args:
        size: Side length of the square mask precomputed as a buffer.
        cutoff: Standard deviation of the Gaussian, in normalised frequency
            units where the Nyquist frequency is 1.
        eps: Constant added before the logarithm.
    """

    def __init__(self, size=128, cutoff=0.1, eps=1e-6):
        super().__init__()
        self.size = size
        self.cutoff = cutoff
        self.eps = eps

        self.register_buffer(
            name='filter_mask',
            tensor=self._build_gaussian_low_pass_filter(
                size=self.size,
                cutoff=self.cutoff
            )
        )

    def _build_gaussian_low_pass_filter(self, size, cutoff, width=None):
        """Return a (size, width) Gaussian low-pass mask in FFT index order.

        ``width`` defaults to ``size``. Frequencies come from ``fftfreq`` so
        the zero frequency sits at index [0, 0], exactly where ``fft2`` puts
        it. A mask centred in the middle of the array (as built from
        ``linspace(-1, 1)``) would instead attenuate the DC term unless the
        spectrum were shifted first.
        """
        width = size if width is None else width
        # fftfreq spans [-0.5, 0.5); scale so that Nyquist maps to 1.
        freq_y = torch.fft.fftfreq(size) * 2.0
        freq_x = torch.fft.fftfreq(width) * 2.0
        y, x = torch.meshgrid(freq_y, freq_x, indexing='ij')
        return torch.exp(-(x ** 2 + y ** 2) / (2 * cutoff ** 2))

    def forward(self, x):
        """Return ``(illumination, detail)``, both shaped like ``x`` (B, C, H, W)."""
        B, C, H, W = x.shape

        # 1. log transform
        x_log = torch.log(input=x + self.eps)

        # 2. FFT over the two trailing (spatial) dimensions
        x_fft = torch.fft.fft2(x_log)

        # 3. low-pass / high-pass split; the (H, W) mask broadcasts over B and C
        filter_mask = self.filter_mask
        if tuple(filter_mask.shape) != (H, W):
            # Input differs from the configured size: build a matching mask.
            filter_mask = self._build_gaussian_low_pass_filter(
                size=H, cutoff=self.cutoff, width=W
            ).to(device=x.device, dtype=filter_mask.dtype)

        low_fft = x_fft * filter_mask
        high_fft = x_fft * (1 - filter_mask)

        # 4. inverse FFT, keep the real part
        low_spatial = torch.real(input=torch.fft.ifft2(low_fft))
        high_spatial = torch.real(input=torch.fft.ifft2(high_fft))

        # 5. undo the log transform
        illumination = torch.exp(input=low_spatial)
        detail = torch.exp(input=high_spatial)

        return illumination, detail


### embedding

In [None]:
class PatchEmbedding(nn.Module):
    """Turn an image into a sequence of patch tokens with a strided convolution.

    Args:
        image_size: Side length used to derive ``grid_size`` and ``num_patches``.
        in_channels: Number of input channels.
        embed_dim: Size of each output token.
        patch_size: Kernel size and stride of the projection.
        bias: Whether the projection has a bias term.
    """

    def __init__(self, image_size=512, in_channels=1, embed_dim=768, patch_size=4, bias=True):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        self.bias = bias

        # Patches per side and in total.
        self.grid_size = image_size // patch_size
        self.num_patches = self.grid_size * self.grid_size

        # Non-overlapping patches: stride equals kernel size.
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=bias,
        )

    def forward(self, x):
        """Map (B, C, H, W) to (B, N, embed_dim) with N the number of patches."""
        feature_map = self.proj(x)                 # (B, embed_dim, H/p, W/p)
        tokens = feature_map.flatten(start_dim=2)  # (B, embed_dim, N)
        return tokens.transpose(1, 2)              # (B, N, embed_dim)


class PositionalEmbedding(nn.Module):
    """Fixed 2-D sine/cosine positional embedding for a ``size x size`` grid.

    The table is computed once at construction and stored as a
    non-persistent buffer: it follows the module across ``.to()`` /
    ``.cuda()`` calls but is not written to ``state_dict``, so existing
    checkpoints load unchanged.

    Args:
        embed_dim: Embedding width; must be divisible by 4 (half per axis,
            each half split into sine and cosine parts).
        size: Number of grid positions per side.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by 4.
    """

    def __init__(self, embed_dim=768, size=128):
        super().__init__()
        if embed_dim % 4 != 0:
            raise ValueError(
                f"embed_dim must be divisible by 4, got {embed_dim}")
        self.embed_dim = embed_dim
        self.size = size

        self.register_buffer(
            name='pos_embed',
            tensor=self._build_sincos_embedding(size=self.size),
            persistent=False,
        )

    @property
    def device(self):
        """Device the embedding table currently lives on."""
        return self.pos_embed.device

    def _build_sincos_embedding(self, size):
        """Return the (1, size*size, embed_dim) embedding table."""
        # NOTE(review): coordinates are normalised to [0, 1] rather than
        # integer grid indices, so every sine/cosine argument stays below
        # 1 rad - confirm this is intended.
        grid = torch.linspace(start=0, end=1, steps=size)
        grid_y, grid_x = torch.meshgrid(grid, grid, indexing='ij')  # (H, W)

        grid = torch.stack(tensors=[grid_y, grid_x], dim=0)  # (2, H, W)

        pos_embed = self._get_2d_sincos_pos_embed_from_grid(
            embed_dim=self.embed_dim,
            grid=grid
        )
        pos_embed = pos_embed.view(1, size * size, self.embed_dim)  # (1, N, D)
        return pos_embed

    def _get_1d_sincos_pos_embed_from_grid(self, embed_dim, pos):
        """Return (H*W, embed_dim) sine/cosine features of the positions ``pos``."""
        omega = torch.arange(
            end=embed_dim // 2, dtype=torch.float32, device=pos.device)
        omega = 1. / (10000 ** (omega / (embed_dim / 2)))

        pos = pos.flatten()  # (H*W,)
        out = torch.einsum('m,d->md', pos, omega)  # (H*W, embed_dim//2)

        emb_sin = torch.sin(input=out)
        emb_cos = torch.cos(input=out)

        return torch.cat(tensors=[emb_sin, emb_cos], dim=1)  # (H*W, embed_dim)

    def _get_2d_sincos_pos_embed_from_grid(self, embed_dim, grid):
        """Concatenate the row (y) and column (x) embeddings, half the width each."""
        emb_h = self._get_1d_sincos_pos_embed_from_grid(
            embed_dim=embed_dim // 2, pos=grid[0])  # y
        emb_w = self._get_1d_sincos_pos_embed_from_grid(
            embed_dim=embed_dim // 2, pos=grid[1])  # x
        return torch.cat(tensors=[emb_h, emb_w], dim=1)  # (H*W, embed_dim)

    def forward(self):
        """Return the cached (1, size*size, embed_dim) table.

        The same tensor is returned on every call; do not modify it in place.
        """
        return self.pos_embed


class TimeEmbedding(nn.Module):
    """Embed scalar timesteps with sinusoidal features followed by an MLP.

    Args:
        hidden_size: Output width of the MLP.
        embed_dim: Width of the sinusoidal feature vector fed to the MLP.
        max_period: Controls the lowest frequency of the sinusoids.
        bias: Whether the linear layers have bias terms.
    """

    def __init__(self, hidden_size=768, embed_dim=256, max_period=10000, bias=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.embed_dim = embed_dim
        self.max_period = max_period
        self.bias = bias

        self.mlp = nn.Sequential(
            nn.Linear(
                in_features=self.embed_dim,
                out_features=self.hidden_size,
                bias=self.bias
            ),
            nn.SiLU(),
            nn.Linear(
                in_features=self.hidden_size,
                out_features=self.hidden_size,
                bias=self.bias
            )
        )

    def timestep_embedding(self, t, embed_dim):
        """Return (B, embed_dim) features ``[cos(t*f), sin(t*f)]`` for a 1-D ``t``.

        Frequencies decay geometrically from 1 towards ``1 / max_period``.
        """
        half_dim = embed_dim // 2
        # The exponent must be negative: without the minus sign the
        # frequencies grow up to max_period instead of decaying, which makes
        # neighbouring timesteps alias. float32 is used regardless of the
        # dtype of t (which may be an integer or half-precision tensor).
        freqs = torch.exp(
            input=-math.log(self.max_period) * torch.arange(
                end=half_dim,
                dtype=torch.float32,
                device=t.device
            ) / half_dim
        )
        args = t[:, None].float() * freqs[None, :]
        emb = torch.cat(
            tensors=[
                torch.cos(input=args),
                torch.sin(input=args)
            ],
            dim=1
        )
        if embed_dim % 2:
            # Odd width: pad one zero column so the result has embed_dim columns.
            emb = torch.cat(
                tensors=[emb, torch.zeros_like(emb[:, :1])], dim=1)
        return emb

    def forward(self, t):
        """Return the (B, hidden_size) embedding of the timesteps ``t``."""
        t_freq = self.timestep_embedding(t=t, embed_dim=self.embed_dim)
        t_emb = self.mlp(t_freq)
        return t_emb  # (B, hidden_size)


### attention

In [None]:
class MultiLayerPerceptron(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU(tanh) -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Width of the input.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: Width of the output; falls back to ``in_features``
            when falsy.
        bias: Whether both linear layers have bias terms.
        drop: Dropout probability used after each linear stage.
    """

    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features if out_features else in_features
        self.hidden_features = hidden_features if hidden_features else in_features

        self.fc1 = nn.Linear(in_features, self.hidden_features, bias=bias)
        self.act = nn.GELU(approximate="tanh")
        self.drop1 = nn.Dropout(p=drop)
        self.fc2 = nn.Linear(self.hidden_features, self.out_features, bias=bias)
        self.drop2 = nn.Dropout(p=drop)

    def forward(self, x):
        """Apply the five stages in order and return the result."""
        stages = (self.fc1, self.act, self.drop1, self.fc2, self.drop2)
        for stage in stages:
            x = stage(x)
        return x


class Attention(nn.Module):
    """Multi-head self-attention over a (B, N, C) token sequence.

    Args:
        dim: Token width ``C``.
        num_heads: Number of attention heads; each head gets ``dim // num_heads`` channels.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Standard 1/sqrt(head_dim) scaling of the dot products.
        self.scale = self.head_dim ** -0.5

        # One projection producing queries, keys and values together.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(p=attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(p=proj_drop)

    def forward(self, x):
        """Return the attended sequence, same shape as ``x``."""
        batch, tokens, channels = x.shape

        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        packed = self.qkv(x).reshape(
            batch, tokens, 3, self.num_heads, channels // self.num_heads)
        packed = packed.permute(2, 0, 3, 1, 4)
        q, k, v = packed[0], packed[1], packed[2]

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(torch.softmax(scores, dim=-1))

        # Merge the heads back into the channel dimension.
        mixed = torch.matmul(weights, v).transpose(1, 2).reshape(
            batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))


### core

In [None]:
class DiTBlock(nn.Module):
    """Transformer block with adaptive layer-norm (adaLN) conditioning.

    The conditioning vector ``c`` is mapped to six modulation tensors
    (shift / scale / gate for the attention branch and for the MLP branch).

    Args:
        hidden_size: Token width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``hidden_size``.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.mlp_hidden_dim = int(self.hidden_size * self.mlp_ratio)

        # The norms carry no learnable affine: scale/shift come from adaLN.
        self.norm1 = nn.LayerNorm(
            normalized_shape=self.hidden_size,
            elementwise_affine=False,
            eps=1e-6
        )
        self.attn = Attention(
            dim=self.hidden_size,
            num_heads=self.num_heads,
            qkv_bias=True
        )
        self.norm2 = nn.LayerNorm(
            normalized_shape=self.hidden_size,
            elementwise_affine=False,
            eps=1e-6
        )
        self.mlp = MultiLayerPerceptron(
            in_features=self.hidden_size,
            hidden_features=self.mlp_hidden_dim,
            drop=0
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                in_features=self.hidden_size,
                out_features=6 * self.hidden_size,
                bias=True
            )
        )

    def forward(self, x, c):
        """Apply the block to tokens ``x`` (B, N, D) conditioned on ``c`` (B, D)."""
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
            c).chunk(6, dim=1)

        # The modulated, normalised tokens feed each branch only; the
        # residual is taken from the unmodified block input. Overwriting x
        # with the modulated tensor (as before) removed the identity path
        # through the block.
        attn_in = self.norm1(x) * (1 + scale_msa.unsqueeze(1)) + \
            shift_msa.unsqueeze(1)
        x = x + gate_msa.unsqueeze(1) * self.attn(attn_in)

        mlp_in = self.norm2(x) * (1 + scale_mlp.unsqueeze(1)) + \
            shift_mlp.unsqueeze(1)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(mlp_in)
        return x


class FinalLayer(nn.Module):
    """Output head: adaLN-modulated layer norm followed by a linear projection.

    Args:
        hidden_size: Token width of the input.
        patch_size: Side length of a patch.
        out_channels: Channels per output pixel; each token is projected to
            ``patch_size * patch_size * out_channels`` values.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.out_channels = out_channels

        values_per_token = patch_size * patch_size * out_channels

        # No learnable affine in the norm: shift/scale come from the conditioning.
        self.norm_final = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False)
        self.linear = nn.Linear(hidden_size, values_per_token, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Project tokens ``x`` (B, N, D) conditioned on ``c`` (B, D)."""
        modulation = self.adaLN_modulation(c)
        shift, scale = modulation.chunk(2, dim=1)
        # Broadcast the per-sample modulation over the token axis.
        normed = self.norm_final(x)
        modulated = normed * (1 + scale[:, None, :]) + shift[:, None, :]
        return self.linear(modulated)


## losses

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg16


### basic

In [None]:
class L_col(nn.Module):
    """Colour-constancy penalty on the per-channel spatial means of an RGB batch.

    For each image the three channel means are compared pairwise; the result
    is ``sqrt(Drg**2 + Drb**2 + Dgb**2)`` where each ``D`` is a squared
    difference of two channel means.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return a (B, 1, 1, 1) tensor for an input of shape (B, 3, H, W)."""
        channel_means = x.mean(dim=(2, 3), keepdim=True)
        mean_r = channel_means[:, 0:1]
        mean_g = channel_means[:, 1:2]
        mean_b = channel_means[:, 2:3]

        # Squared pairwise differences between the channel means.
        d_rg = (mean_r - mean_g) ** 2
        d_rb = (mean_r - mean_b) ** 2
        d_gb = (mean_b - mean_g) ** 2

        return (d_rg ** 2 + d_rb ** 2 + d_gb ** 2) ** 0.5


class L_spa(nn.Module):
    """Spatial-consistency penalty between an original and an enhanced image.

    Both images are averaged over channels and average-pooled with a 4x4
    window. Each pooled map is then differenced against its left, right,
    upper and lower neighbour using fixed 3x3 kernels, and the squared
    mismatch between the two images' neighbour differences is summed over
    the four directions.
    """

    def __init__(self):
        super().__init__()
        # The kernels are created without forcing a device: calling .cuda()
        # here made construction fail on machines without a GPU. They remain
        # frozen parameters, so parameter names and state_dict keys are kept.
        self.weight_l = self._fixed_kernel([[0, 0, 0], [-1, 1, 0], [0, 0, 0]])
        self.weight_r = self._fixed_kernel([[0, 0, 0], [0, 1, -1], [0, 0, 0]])
        self.weight_u = self._fixed_kernel([[0, -1, 0], [0, 1, 0], [0, 0, 0]])
        self.weight_d = self._fixed_kernel([[0, 0, 0], [0, 1, 0], [0, -1, 0]])
        self.pool = nn.AvgPool2d(kernel_size=4)

    @staticmethod
    def _fixed_kernel(rows):
        """Wrap a 3x3 stencil as a frozen (1, 1, 3, 3) convolution weight."""
        kernel = torch.tensor(rows, dtype=torch.float32)
        kernel = kernel.unsqueeze(dim=0).unsqueeze(dim=0)
        return nn.Parameter(data=kernel, requires_grad=False)

    def forward(self, org, enh):
        """Return the per-location spatial-consistency map.

        Args:
            org: Original image batch with channels on dim 1.
            enh: Enhanced image batch with the same shape as ``org``.

        Returns:
            Tensor with one channel and the pooled spatial size holding the
            summed squared mismatch of the four directional differences.
        """
        org_pool = self.pool(torch.mean(input=org, dim=1, keepdim=True))
        enh_pool = self.pool(torch.mean(input=enh, dim=1, keepdim=True))

        s = None
        directions = (self.weight_l, self.weight_r, self.weight_u, self.weight_d)
        for weight in directions:
            # Follow the input's device/dtype so the loss also works when the
            # module itself was never moved (a no-op when they already match).
            weight = weight.to(device=org_pool.device, dtype=org_pool.dtype)
            d_org = F.conv2d(input=org_pool, weight=weight, padding=1)
            d_enh = F.conv2d(input=enh_pool, weight=weight, padding=1)
            term = torch.pow(input=d_org - d_enh, exponent=2)
            s = term if s is None else s + term
        return s


class L_exp(nn.Module):
    """Exposure penalty: squared distance of patch brightness from a target.

    Args:
        patch_size: Window (and stride) of the average pooling that forms the
            patches.
        mean_val: Target brightness each patch mean is compared against.
    """

    def __init__(self, patch_size=16, mean_val=0.6):
        super().__init__()
        self.pool = nn.AvgPool2d(kernel_size=patch_size)
        self.mean_val = mean_val

    def forward(self, x):
        """Return the scalar exposure loss of ``x``.

        Args:
            x: Image batch with channels on dim 1.

        Returns:
            0-dim tensor: mean over all patches of
            ``(patch_mean - mean_val) ** 2``.
        """
        gray = torch.mean(input=x, dim=1, keepdim=True)
        patch_mean = self.pool(gray)

        # Subtract the scalar directly. The previous per-call
        # ``torch.FloatTensor([...]).cuda()`` failed on CPU-only machines and
        # forced a host-to-device copy on every step.
        e = torch.mean(
            input=torch.pow(input=patch_mean - self.mean_val, exponent=2)
        )
        return e


class L_TV(nn.Module):
    """Total-variation penalty on squared neighbour differences.

    Args:
        weight: Multiplier applied to the returned loss.
    """

    def __init__(self, weight=1):
        super().__init__()
        self.weight = weight

    def forward(self, x):
        """Return the scalar total-variation loss of a 4-D tensor ``x``.

        The squared differences along dim 2 and dim 3 are each summed over
        the whole batch, divided by the number of neighbour pairs in one
        plane, added, doubled, scaled by ``weight`` and divided by the batch
        size.
        """
        batch, _, height, width = x.shape

        vertical_step = x[:, :, 1:, :] - x[:, :, :-1, :]
        horizontal_step = x[:, :, :, 1:] - x[:, :, :, :-1]

        h_tv = (vertical_step ** 2).sum()
        w_tv = (horizontal_step ** 2).sum()

        # Number of neighbour pairs per plane in each direction.
        pairs_vertical = (height - 1) * width
        pairs_horizontal = height * (width - 1)

        return (
            self.weight * 2
            * ((h_tv / pairs_vertical) + (w_tv / pairs_horizontal))
            / batch
        )


### Feature Attribution

In [None]:
class L_bri(nn.Module):
    """Penalty tying the mean brightness gain to the normalised timestep.

    Args:
        timestep_range: Divisor that maps a timestep to its target gain.
        weight: Multiplier applied to the returned loss.
    """

    def __init__(self, timestep_range=10000, weight=1.):
        super().__init__()
        self.timestep_range = timestep_range
        self.weight = weight

    def forward(self, i_x, n_Y, t):
        """Return the weighted MSE between actual and target brightness gain.

        Args:
            i_x: 4-D tensor; its per-sample mean is the starting brightness.
            n_Y: 4-D tensor; its per-sample mean is the resulting brightness.
            t: Per-sample timesteps; ``t / timestep_range`` is the target gain.
        """
        reduce_dims = [1, 2, 3]
        brightness_before = i_x.mean(dim=reduce_dims)
        brightness_after = n_Y.mean(dim=reduce_dims)

        wanted_gain = t.float() / self.timestep_range
        observed_gain = brightness_after - brightness_before

        return self.weight * F.mse_loss(input=observed_gain, target=wanted_gain)


class L_mod(nn.Module):
    """Penalty on negative entries of the collected modulation tensors.

    Args:
        weight: Multiplier applied to the returned loss.
    """

    def __init__(self, weight=1.):
        super().__init__()
        self.weight = weight

    def forward(self, modulations):
        """Return the weighted sum of mean magnitudes of negative entries.

        Args:
            modulations: Iterable of iterables of tensors (one inner group
                per block).

        Returns:
            0-dim float32 tensor. It stays attached to the autograd graph of
            the modulation tensors, so the penalty contributes gradients.
            When nothing is negative (or ``modulations`` is empty) a zero
            tensor is returned.
        """
        loss = None
        first = None

        for modulation_set in modulations:
            for modulation in modulation_set:
                if first is None:
                    first = modulation
                negative = modulation[modulation < 0]
                if negative.numel() > 0:
                    term = torch.mean(input=torch.abs(input=negative))
                    loss = term if loss is None else loss + term

        if loss is None:
            # Nothing to penalise: return a plain zero on the inputs' device
            # (default device when no tensor was supplied at all).
            device = first.device if first is not None else None
            return torch.zeros((), dtype=torch.float32, device=device)

        # Do not rebuild the result with torch.tensor(...): copy-constructing
        # from a tensor detaches it, which silently disabled this loss.
        return (loss * self.weight).to(dtype=torch.float32)


## model

In [None]:
import torch
import lightning as L

from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts
from transformers.optimization import get_cosine_schedule_with_warmup


#### basic

In [None]:
class HomomorphicDit(nn.Module):
    """DiT-style enhancer that rewrites one component of the luma channel.

    The input is split into Y/Cr/Cb, ``HomomorphicSeparation`` splits Y into
    two maps (``i_x`` and ``d_x``; presumably illumination and detail --
    confirm against that class), the first map is patch-embedded and run
    through timestep-conditioned ``DiTBlock`` layers, and the un-patchified
    result is multiplied with ``d_x`` and recombined with the original Cr/Cb.

    Args:
        image_size: Forwarded to ``HomomorphicSeparation`` and
            ``PatchEmbedding``.
        hidden_size: Token width used by every transformer component.
        patch_size: Side length of a patch.
        depth: Number of ``DiTBlock`` layers.
        num_heads: Attention heads per block.
        in_channels: Stored on the module; not used elsewhere in this class.
        out_channels: Channels produced per patch pixel by the final layer.
    """

    def __init__(
        self,
        image_size=512,
        hidden_size=768,
        patch_size=4,
        depth=12,
        num_heads=12,
        in_channels=3,
        out_channels=1,
    ):
        super().__init__()
        self.image_size = image_size
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.depth = depth
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.rgb2ycrcb = RGB2YCrCb()
        self.homo_separate = HomomorphicSeparation(
            size=self.image_size
        )
        self.illum_embedding = PatchEmbedding(
            image_size=self.image_size,
            embed_dim=self.hidden_size,
            patch_size=self.patch_size,
            bias=True
        )
        self.pos_embedding = PositionalEmbedding(
            embed_dim=self.hidden_size,
            size=self.illum_embedding.grid_size
        )
        self.t_embedding = TimeEmbedding(
            hidden_size=self.hidden_size,
            bias=True
        )
        self.blocks = nn.ModuleList(modules=[
            DiTBlock(
                hidden_size=self.hidden_size,
                num_heads=self.num_heads,
            )
            for _ in range(self.depth)
        ])
        self.final_layer = FinalLayer(
            hidden_size=self.hidden_size,
            patch_size=self.patch_size,
            out_channels=self.out_channels,
        )
        self.ycrcb2rgb = YCrCb2RGB()
        self.initialize_weights()

    def initialize_weights(self):
        """Initialise all layers; modulation and output layers start at zero."""
        # Xavier-initialise every linear layer and zero its bias.
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(tensor=module.weight)
                if module.bias is not None:
                    nn.init.constant_(tensor=module.bias, val=0)
        self.apply(fn=_basic_init)

        # Initialise the patch projection like nn.Linear (instead of
        # nn.Conv2d) by treating its weight as a 2-D matrix.
        w_i = self.illum_embedding.proj.weight.data
        nn.init.xavier_uniform_(tensor=w_i.view([w_i.shape[0], -1]))
        nn.init.constant_(tensor=self.illum_embedding.proj.bias, val=0)

        # Timestep-embedding MLP: small normal weights.
        nn.init.normal_(tensor=self.t_embedding.mlp[0].weight, std=0.02)
        nn.init.normal_(tensor=self.t_embedding.mlp[2].weight, std=0.02)

        # Zero the adaLN modulation layer of every block.
        for block in self.blocks:
            nn.init.constant_(tensor=block.adaLN_modulation[-1].weight, val=0)
            nn.init.constant_(tensor=block.adaLN_modulation[-1].bias, val=0)

        # Zero the output head so the network initially emits zeros.
        nn.init.constant_(
            tensor=self.final_layer.adaLN_modulation[-1].weight, val=0)
        nn.init.constant_(
            tensor=self.final_layer.adaLN_modulation[-1].bias, val=0)
        nn.init.constant_(tensor=self.final_layer.linear.weight, val=0)
        nn.init.constant_(tensor=self.final_layer.linear.bias, val=0)

    def unpatchify(self, x):
        """Fold a token sequence back into an image.

        Args:
            x: Tensor of shape (batch, tokens, patch * patch * out_channels)
                where ``tokens`` is a perfect square.

        Returns:
            Tensor of shape (batch, out_channels, side * patch, side * patch)
            with ``side = sqrt(tokens)``.

        Raises:
            ValueError: If the number of tokens is not a perfect square.
        """
        c = self.out_channels
        p = self.illum_embedding.patch_size
        num_tokens = x.shape[1]
        # int() takes its argument positionally only (Python 3.7+); the former
        # ``int(x=...)`` call raised TypeError on every forward pass.
        h = w = int(num_tokens ** 0.5)
        if h * w != num_tokens:
            raise ValueError(
                f"cannot unpatchify {num_tokens} tokens: not a square grid"
            )

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        x = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return x

    def forward(self, x, t):
        """Enhance ``x`` conditioned on timestep ``t``.

        Args:
            x: Input image batch accepted by ``RGB2YCrCb``.
            t: Per-sample timesteps accepted by ``TimeEmbedding``.

        Returns:
            The enhanced image produced by ``YCrCb2RGB``.
        """
        Y, Cr, Cb = self.rgb2ycrcb(x)
        i_x, d_x = self.homo_separate(Y)
        i_emb = self.illum_embedding(i_x) + self.pos_embedding()

        # The timestep embedding is the only conditioning signal.
        cond = self.t_embedding(t)

        for block in self.blocks:
            i_emb = block(i_emb, cond)

        i_emb = self.final_layer(i_emb, cond)
        i_x = self.unpatchify(x=i_emb)
        # Recombine the predicted map with the untouched second component.
        n_Y = i_x * d_x
        dit_enh_img = self.ycrcb2rgb(n_Y, Cr, Cb)
        return dit_enh_img


#### Feature Attribution

In [None]:
class HomomorphicDit(nn.Module):
    """Feature-attribution variant of the homomorphic DiT enhancer.

    Besides the enhanced image, ``forward`` also returns the modulation
    chunks of every block and the map before and after the transformer so
    that auxiliary losses can inspect them.

    Args:
        image_size: Forwarded to ``HomomorphicSeparation`` and
            ``PatchEmbedding``.
        hidden_size: Token width used by every transformer component.
        patch_size: Side length of a patch.
        depth: Number of ``DiTBlock`` layers.
        num_heads: Attention heads per block.
        in_channels: Stored on the module; not used elsewhere in this class.
        out_channels: Channels produced per patch pixel by the final layer.
    """

    def __init__(
        self,
        image_size=512,
        hidden_size=768,
        patch_size=4,
        depth=12,
        num_heads=12,
        in_channels=3,
        out_channels=1,
    ):
        super().__init__()
        self.image_size = image_size
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.depth = depth
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Colour conversion and decomposition of the luma channel.
        self.rgb2ycrcb = RGB2YCrCb()
        self.homo_separate = HomomorphicSeparation(size=image_size)

        # Token, position and timestep embeddings.
        self.illum_embedding = PatchEmbedding(
            image_size=image_size,
            embed_dim=hidden_size,
            patch_size=patch_size,
            bias=True,
        )
        self.pos_embedding = PositionalEmbedding(
            embed_dim=hidden_size,
            size=self.illum_embedding.grid_size,
        )
        self.t_embedding = TimeEmbedding(hidden_size=hidden_size, bias=True)

        # Transformer trunk and output head.
        trunk = [
            DiTBlock(hidden_size=hidden_size, num_heads=num_heads)
            for _ in range(depth)
        ]
        self.blocks = nn.ModuleList(trunk)
        self.final_layer = FinalLayer(
            hidden_size=hidden_size,
            patch_size=patch_size,
            out_channels=out_channels,
        )
        self.ycrcb2rgb = YCrCb2RGB()
        self.initialize_weights()

    def initialize_weights(self):
        """Initialise all layers; modulation and output layers start at zero."""
        def _init_linear(layer):
            # Xavier weights and zero bias for every linear layer.
            if not isinstance(layer, nn.Linear):
                return
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

        self.apply(_init_linear)

        # Treat the patch projection as a linear map over flattened patches.
        proj = self.illum_embedding.proj
        proj_weight = proj.weight.data
        nn.init.xavier_uniform_(proj_weight.view([proj_weight.shape[0], -1]))
        nn.init.constant_(proj.bias, 0)

        # Timestep-embedding MLP: small normal weights on both linear layers.
        for index in (0, 2):
            nn.init.normal_(self.t_embedding.mlp[index].weight, std=0.02)

        # Zero every modulation layer and the output projection.
        zero_layers = [block.adaLN_modulation[-1] for block in self.blocks]
        zero_layers.append(self.final_layer.adaLN_modulation[-1])
        zero_layers.append(self.final_layer.linear)
        for layer in zero_layers:
            nn.init.constant_(layer.weight, 0)
            nn.init.constant_(layer.bias, 0)

    def unpatchify(self, x):
        """Fold (batch, tokens, patch*patch*channels) tokens into an image.

        The token count is taken to be a square grid; the result has shape
        (batch, channels, side * patch, side * patch).
        """
        batch, num_tokens = x.shape[0], x.shape[1]
        channels = self.out_channels
        patch = self.illum_embedding.patch_size
        side = int(num_tokens ** 0.5)

        grid = x.reshape(shape=(batch, side, side, patch, patch, channels))
        # (n, h, w, p, q, c) -> (n, c, h, p, w, q)
        grid = grid.permute(0, 5, 1, 3, 2, 4)
        return grid.reshape(shape=(batch, channels, side * patch, side * patch))

    def forward(self, x, t):
        """Enhance ``x`` and expose intermediate quantities.

        Returns:
            Tuple ``(enhanced image, modulations, i_x, i_x_out)`` where
            ``modulations`` holds one 6-tuple of modulation chunks per block,
            ``i_x`` is the map fed to the transformer and ``i_x_out`` the map
            it produced.
        """
        Y, Cr, Cb = self.rgb2ycrcb(x)
        i_x, d_x = self.homo_separate(Y)
        tokens = self.illum_embedding(i_x) + self.pos_embedding()

        # The timestep embedding is the only conditioning signal.
        cond = self.t_embedding(t)

        # One entry per block, collected for the auxiliary losses.
        modulations = []
        for block in self.blocks:
            chunks = block.adaLN_modulation(cond).chunk(6, dim=1)
            (shift_msa, scale_msa, gate_msa,
             shift_mlp, scale_mlp, gate_mlp) = chunks
            modulations.append(
                (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp)
            )
            tokens = block(tokens, cond)

        tokens = self.final_layer(tokens, cond)
        i_x_out = self.unpatchify(tokens)
        n_Y = i_x_out * d_x
        dit_enh_img = self.ycrcb2rgb(n_Y, Cr, Cb)

        return dit_enh_img, modulations, i_x, i_x_out


## lightning

#### adam

In [None]:
class HomomorphicDiTLightning(L.LightningModule):
    """Lightning wrapper training ``HomomorphicDit`` with plain Adam.

    The objective is the weighted sum of the spatial (``L_spa``), colour
    (``L_col``) and exposure (``L_exp``) losses of the enhanced image.

    Args:
        hparams: Mapping with the keys ``image_size``, ``hidden_size``,
            ``patch_size``, ``depth``, ``num_heads``, ``in_channels``,
            ``out_channels``, ``lambda_spa``, ``lambda_col``, ``lambda_exp``,
            ``timestep_range`` and ``lr``.
    """

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

        self.model = HomomorphicDit(
            image_size=hparams['image_size'],
            hidden_size=hparams['hidden_size'],
            patch_size=hparams['patch_size'],
            depth=hparams['depth'],
            num_heads=hparams['num_heads'],
            in_channels=hparams["in_channels"],
            out_channels=hparams["out_channels"],
        )

        self.spa_loss = L_spa()
        self.col_loss = L_col()
        self.exp_loss = L_exp()

        self.lambda_spa = hparams["lambda_spa"]
        self.lambda_col = hparams["lambda_col"]
        self.lambda_exp = hparams["lambda_exp"]

        self.timestep_range = hparams['timestep_range']

        # Was hard-coded to "cuda"; unchanged when CUDA is present.
        # NOTE(review): assumes ImageQualityMetrics accepts "cpu" -- confirm.
        metric_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.metric = ImageQualityMetrics(device=metric_device)
        self.metric.eval()

    def forward(self, x, t):
        """Run the wrapped model on images ``x`` at timesteps ``t``."""
        return self.model(x, t)

    def _loss_step(self, batch, stage):
        """Compute, log (under ``<stage>/...``) and return the total loss."""
        x = batch.to(self.device)
        # One random timestep per sample in [0, timestep_range).
        t = torch.randint(
            low=0,
            high=self.timestep_range,
            size=(x.size(0),),
            device=self.device
        )
        enh_img = self(x, t)

        loss_spa = self.lambda_spa * \
            torch.mean(input=self.spa_loss(x, enh_img))
        loss_col = self.lambda_col * torch.mean(input=self.col_loss(enh_img))
        loss_exp = self.lambda_exp * torch.mean(input=self.exp_loss(enh_img))

        total = (loss_spa + loss_col + loss_exp)

        self.log_dict(dictionary={
            f"{stage}/spa": loss_spa,
            f"{stage}/col": loss_col,
            f"{stage}/exp": loss_exp,
            f"{stage}/total": total,
        }, prog_bar=True)
        return total

    def _enhance_at_zero(self, batch):
        """Return ``(input, enhanced)`` with every timestep fixed to 0."""
        x = batch.to(self.device)
        t = torch.zeros(
            size=(x.size(0),),
            dtype=torch.long,
            device=self.device
        )
        return x, self(x, t)

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch (logged as ``train/*``)."""
        return self._loss_step(batch, stage="train")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch (logged as ``valid/*``)."""
        return self._loss_step(batch, stage="valid")

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Score the enhancement at ``t = 0`` against the input batch."""
        x, enh_img = self._enhance_at_zero(batch)

        metrics = self.metric.full(preds=enh_img, targets=x)

        self.log_dict(dictionary={
            "bench/PSNR": metrics["PSNR"],
            "bench/SSIM": metrics["SSIM"],
            "bench/LPIPS": metrics["LPIPS"],
            "bench/NIQE": metrics["NIQE"],
            "bench/BRISQUE": metrics["BRISQUE"],
        }, prog_bar=True)
        return metrics

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the enhanced images for ``batch`` at ``t = 0``."""
        _, enh_img = self._enhance_at_zero(batch)
        return enh_img

    def configure_optimizers(self):
        """Return an Adam optimizer using the ``lr`` hyper-parameter."""
        optimizer = torch.optim.Adam(
            params=self.parameters(),
            lr=self.hparams['lr']
        )
        return optimizer


#### get cosine schedule with warmup

In [None]:
class HomomorphicDiTLightning(L.LightningModule):
    """Lightning wrapper: Adam with a warm-up + cosine learning-rate schedule.

    The objective is the weighted sum of the spatial (``L_spa``), colour
    (``L_col``) and exposure (``L_exp``) losses of the enhanced image.

    Args:
        hparams: Mapping with the keys ``image_size``, ``hidden_size``,
            ``patch_size``, ``depth``, ``num_heads``, ``in_channels``,
            ``out_channels``, ``lambda_spa``, ``lambda_col``, ``lambda_exp``,
            ``timestep_range`` and ``lr``. The optional key ``warmup_steps``
            (default 2600) sets the length of the warm-up.
    """

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

        self.model = HomomorphicDit(
            image_size=hparams['image_size'],
            hidden_size=hparams['hidden_size'],
            patch_size=hparams['patch_size'],
            depth=hparams['depth'],
            num_heads=hparams['num_heads'],
            in_channels=hparams["in_channels"],
            out_channels=hparams["out_channels"],
        )

        self.spa_loss = L_spa()
        self.col_loss = L_col()
        self.exp_loss = L_exp()

        self.lambda_spa = hparams["lambda_spa"]
        self.lambda_col = hparams["lambda_col"]
        self.lambda_exp = hparams["lambda_exp"]

        self.timestep_range = hparams['timestep_range']

        # Was hard-coded to "cuda"; unchanged when CUDA is present.
        # NOTE(review): assumes ImageQualityMetrics accepts "cpu" -- confirm.
        metric_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.metric = ImageQualityMetrics(device=metric_device)
        self.metric.eval()

    def forward(self, x, t):
        """Run the wrapped model on images ``x`` at timesteps ``t``."""
        return self.model(x, t)

    def _loss_step(self, batch, stage):
        """Compute, log (under ``<stage>/...``) and return the total loss."""
        x = batch.to(self.device)
        # One random timestep per sample in [0, timestep_range).
        t = torch.randint(
            low=0,
            high=self.timestep_range,
            size=(x.size(0),),
            device=self.device
        )
        enh_img = self(x, t)

        loss_spa = self.lambda_spa * \
            torch.mean(input=self.spa_loss(x, enh_img))
        loss_col = self.lambda_col * torch.mean(input=self.col_loss(enh_img))
        loss_exp = self.lambda_exp * torch.mean(input=self.exp_loss(enh_img))

        total = (loss_spa + loss_col + loss_exp)

        self.log_dict(dictionary={
            f"{stage}/spa": loss_spa,
            f"{stage}/col": loss_col,
            f"{stage}/exp": loss_exp,
            f"{stage}/total": total,
        }, prog_bar=True)
        return total

    def _enhance_at_zero(self, batch):
        """Return ``(input, enhanced)`` with every timestep fixed to 0."""
        x = batch.to(self.device)
        t = torch.zeros(
            size=(x.size(0),),
            dtype=torch.long,
            device=self.device
        )
        return x, self(x, t)

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch (logged as ``train/*``)."""
        return self._loss_step(batch, stage="train")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch (logged as ``valid/*``)."""
        return self._loss_step(batch, stage="valid")

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Score the enhancement at ``t = 0`` against the input batch."""
        x, enh_img = self._enhance_at_zero(batch)

        metrics = self.metric.full(preds=enh_img, targets=x)

        self.log_dict(dictionary={
            "bench/PSNR": metrics["PSNR"],
            "bench/SSIM": metrics["SSIM"],
            "bench/LPIPS": metrics["LPIPS"],
            "bench/NIQE": metrics["NIQE"],
            "bench/BRISQUE": metrics["BRISQUE"],
        }, prog_bar=True)
        return metrics

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the enhanced images for ``batch`` at ``t = 0``."""
        _, enh_img = self._enhance_at_zero(batch)
        return enh_img

    def configure_optimizers(self):
        """Return Adam plus a per-step warm-up/cosine schedule."""
        total_steps = self.trainer.estimated_stepping_batches
        # Formerly a fixed 2600 (about 1-2 epochs' worth of steps); it can now
        # be overridden through the optional ``warmup_steps`` hyper-parameter.
        warmup_steps = self.hparams.get("warmup_steps", 2600)

        optimizer = torch.optim.Adam(
            params=self.parameters(),
            lr=self.hparams['lr']
        )

        scheduler = get_cosine_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            }
        }


#### Cosine Annealing Warm Restarts

In [None]:
class HomomorphicDiTLightning(L.LightningModule):
    """Lightning wrapper: Adam with cosine annealing and warm restarts.

    The objective is the weighted sum of the spatial (``L_spa``), colour
    (``L_col``) and exposure (``L_exp``) losses of the enhanced image.

    Args:
        hparams: Mapping with the keys ``image_size``, ``hidden_size``,
            ``patch_size``, ``depth``, ``num_heads``, ``in_channels``,
            ``out_channels``, ``lambda_spa``, ``lambda_col``, ``lambda_exp``,
            ``timestep_range`` and ``lr``. The optional keys ``T_0``
            (default 10), ``T_mult`` (default 2) and ``eta_min``
            (default 1e-7) configure the scheduler.
    """

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

        self.model = HomomorphicDit(
            image_size=hparams['image_size'],
            hidden_size=hparams['hidden_size'],
            patch_size=hparams['patch_size'],
            depth=hparams['depth'],
            num_heads=hparams['num_heads'],
            in_channels=hparams["in_channels"],
            out_channels=hparams["out_channels"],
        )

        self.spa_loss = L_spa()
        self.col_loss = L_col()
        self.exp_loss = L_exp()

        self.lambda_spa = hparams["lambda_spa"]
        self.lambda_col = hparams["lambda_col"]
        self.lambda_exp = hparams["lambda_exp"]

        self.timestep_range = hparams['timestep_range']

        # Was hard-coded to "cuda"; unchanged when CUDA is present.
        # NOTE(review): assumes ImageQualityMetrics accepts "cpu" -- confirm.
        metric_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.metric = ImageQualityMetrics(device=metric_device)
        self.metric.eval()

    def forward(self, x, t):
        """Run the wrapped model on images ``x`` at timesteps ``t``."""
        return self.model(x, t)

    def _loss_step(self, batch, stage):
        """Compute, log (under ``<stage>/...``) and return the total loss."""
        x = batch.to(self.device)
        # One random timestep per sample in [0, timestep_range).
        t = torch.randint(
            low=0,
            high=self.timestep_range,
            size=(x.size(0),),
            device=self.device
        )
        enh_img = self(x, t)

        loss_spa = self.lambda_spa * \
            torch.mean(input=self.spa_loss(x, enh_img))
        loss_col = self.lambda_col * torch.mean(input=self.col_loss(enh_img))
        loss_exp = self.lambda_exp * torch.mean(input=self.exp_loss(enh_img))

        total = (loss_spa + loss_col + loss_exp)

        self.log_dict(dictionary={
            f"{stage}/spa": loss_spa,
            f"{stage}/col": loss_col,
            f"{stage}/exp": loss_exp,
            f"{stage}/total": total,
        }, prog_bar=True)
        return total

    def _enhance_at_zero(self, batch):
        """Return ``(input, enhanced)`` with every timestep fixed to 0."""
        x = batch.to(self.device)
        t = torch.zeros(
            size=(x.size(0),),
            dtype=torch.long,
            device=self.device
        )
        return x, self(x, t)

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch (logged as ``train/*``)."""
        return self._loss_step(batch, stage="train")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch (logged as ``valid/*``)."""
        return self._loss_step(batch, stage="valid")

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Score the enhancement at ``t = 0`` against the input batch."""
        x, enh_img = self._enhance_at_zero(batch)

        metrics = self.metric.full(preds=enh_img, targets=x)

        self.log_dict(dictionary={
            "bench/PSNR": metrics["PSNR"],
            "bench/SSIM": metrics["SSIM"],
            "bench/LPIPS": metrics["LPIPS"],
            "bench/NIQE": metrics["NIQE"],
            "bench/BRISQUE": metrics["BRISQUE"],
        }, prog_bar=True)
        return metrics

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the enhanced images for ``batch`` at ``t = 0``."""
        _, enh_img = self._enhance_at_zero(batch)
        return enh_img

    def configure_optimizers(self):
        """Return Adam plus a per-epoch warm-restart cosine schedule."""
        optimizer = torch.optim.Adam(
            params=self.parameters(),
            lr=self.hparams['lr']
        )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer=optimizer,
            # Length of the first cycle, in epochs (restart after 10 by default).
            T_0=self.hparams.get("T_0", 10),
            # Each following cycle is this many times longer than the last.
            T_mult=self.hparams.get("T_mult", 2),
            # Lowest learning rate reached within a cycle.
            eta_min=self.hparams.get("eta_min", 1e-7),
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",  # step the scheduler once per epoch
                "frequency": 1,
            }
        }


#### dynamic reweight

In [None]:
class HomomorphicDiTLightning(L.LightningModule):
    """Lightning wrapper that re-balances the three losses on every step.

    Each loss is weighted inversely to its share of the current (detached)
    loss total, and the weights are normalised to sum to one. Training uses
    Adam with cosine annealing and warm restarts.

    Args:
        hparams: Mapping with the keys ``image_size``, ``hidden_size``,
            ``patch_size``, ``depth``, ``num_heads``, ``in_channels``,
            ``out_channels``, ``lambda_spa``, ``lambda_col``, ``lambda_exp``,
            ``timestep_range`` and ``lr``. The optional keys ``T_0``
            (default 10), ``T_mult`` (default 2) and ``eta_min``
            (default 1e-7) configure the scheduler.
    """

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

        self.model = HomomorphicDit(
            image_size=hparams['image_size'],
            hidden_size=hparams['hidden_size'],
            patch_size=hparams['patch_size'],
            depth=hparams['depth'],
            num_heads=hparams['num_heads'],
            in_channels=hparams["in_channels"],
            out_channels=hparams["out_channels"],
        )

        self.spa_loss = L_spa()
        self.col_loss = L_col()
        self.exp_loss = L_exp()

        # Stored for callers; the steps below use the dynamic weights instead.
        self.lambda_spa = hparams["lambda_spa"]
        self.lambda_col = hparams["lambda_col"]
        self.lambda_exp = hparams["lambda_exp"]

        self.timestep_range = hparams['timestep_range']

        # Was hard-coded to "cuda"; unchanged when CUDA is present.
        # NOTE(review): assumes ImageQualityMetrics accepts "cpu" -- confirm.
        metric_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.metric = ImageQualityMetrics(device=metric_device)
        self.metric.eval()

    def dynamic_reweight(self, losses):
        """Return per-loss weights that sum to one.

        Args:
            losses: Mapping from loss name to its current (detached) value.

        Returns:
            Mapping with the same keys; a loss that makes up a larger share
            of the total receives a smaller weight.
        """
        total = sum(losses.values())
        weights = {}
        for key, value in losses.items():
            # Share of the total; the epsilons guard against division by zero.
            normalized = value / (total + 1e-8)
            weights[key] = 1.0 / (normalized + 1e-6)

        weight_sum = sum(weights.values())
        return {key: weight / weight_sum for key, weight in weights.items()}

    def forward(self, x, t):
        """Run the wrapped model on images ``x`` at timesteps ``t``."""
        return self.model(x, t)

    def _reweighted_step(self, batch, stage):
        """Compute, log (under ``<stage>/...``) and return the weighted loss."""
        x = batch.to(self.device)
        # One random timestep per sample in [0, timestep_range).
        t = torch.randint(
            low=0,
            high=self.timestep_range,
            size=(x.size(0),),
            device=self.device
        )
        enh_img = self(x, t)

        loss_spa = torch.mean(self.spa_loss(x, enh_img))
        loss_col = torch.mean(self.col_loss(enh_img))
        loss_exp = torch.mean(self.exp_loss(enh_img))

        # Weights come from detached values so they act as constants in the
        # backward pass.
        dynamic_lambdas = self.dynamic_reweight({
            "spa": loss_spa.detach(),
            "col": loss_col.detach(),
            "exp": loss_exp.detach()
        })

        total = (
            dynamic_lambdas["spa"] * loss_spa +
            dynamic_lambdas["col"] * loss_col +
            dynamic_lambdas["exp"] * loss_exp
        )

        self.log_dict({
            f"{stage}/spa": loss_spa,
            f"{stage}/col": loss_col,
            f"{stage}/exp": loss_exp,
            f"{stage}/total": total,
            f"{stage}/weight_spa": dynamic_lambdas["spa"],
            f"{stage}/weight_col": dynamic_lambdas["col"],
            f"{stage}/weight_exp": dynamic_lambdas["exp"],
        }, prog_bar=True)

        return total

    def _enhance_at_zero(self, batch):
        """Return ``(input, enhanced)`` with every timestep fixed to 0."""
        x = batch.to(self.device)
        t = torch.zeros(
            size=(x.size(0),),
            dtype=torch.long,
            device=self.device
        )
        return x, self(x, t)

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch (logged as ``train/*``)."""
        return self._reweighted_step(batch, stage="train")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch (logged as ``valid/*``)."""
        return self._reweighted_step(batch, stage="valid")

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Score the enhancement at ``t = 0`` against the input batch."""
        x, enh_img = self._enhance_at_zero(batch)

        metrics = self.metric.full(enh_img, x)

        self.log_dict({
            "bench/PSNR": metrics["PSNR"],
            "bench/SSIM": metrics["SSIM"],
            "bench/LPIPS": metrics["LPIPS"],
            "bench/NIQE": metrics["NIQE"],
            "bench/BRISQUE": metrics["BRISQUE"],
        }, prog_bar=True)
        return metrics

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the enhanced images for ``batch`` at ``t = 0``."""
        _, enh_img = self._enhance_at_zero(batch)
        return enh_img

    def configure_optimizers(self):
        """Return Adam plus a per-epoch warm-restart cosine schedule."""
        optimizer = torch.optim.Adam(
            params=self.parameters(),
            lr=self.hparams['lr']
        )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer=optimizer,
            # Length of the first cycle, in epochs (restart after 10 by default).
            T_0=self.hparams.get("T_0", 10),
            # Each following cycle is this many times longer than the last.
            T_mult=self.hparams.get("T_mult", 2),
            # Lowest learning rate reached within a cycle.
            eta_min=self.hparams.get("eta_min", 1e-7),
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",  # step the scheduler once per epoch
                "frequency": 1,
            }
        }


#### Feature Attribution

In [None]:
class HomomorphicDiTLightning(L.LightningModule):
    """Lightning module that trains and evaluates a ``HomomorphicDit`` model.

    Training and validation combine five loss terms (spa, col, exp, bri, mod)
    using weights that :meth:`dynamic_reweight` recomputes on every step.
    Testing and prediction run the model with all timesteps fixed to 0.
    """

    def __init__(self, hparams):
        """Build the model, the loss modules and the benchmark metric.

        Args:
            hparams: Mapping of hyperparameters. The keys read here are
                ``image_size``, ``hidden_size``, ``patch_size``, ``depth``,
                ``num_heads``, ``in_channels``, ``out_channels``,
                ``lambda_spa``, ``lambda_col``, ``lambda_exp``,
                ``lambda_bri``, ``lambda_mod`` and ``timestep_range``.
                ``lr`` is read later by :meth:`configure_optimizers`.
        """
        super().__init__()
        self.save_hyperparameters(hparams)

        self.model = HomomorphicDit(
            image_size=hparams['image_size'],
            hidden_size=hparams['hidden_size'],
            patch_size=hparams['patch_size'],
            depth=hparams['depth'],
            num_heads=hparams['num_heads'],
            in_channels=hparams["in_channels"],
            out_channels=hparams["out_channels"],
        )

        self.spa_loss = L_spa()
        self.col_loss = L_col()
        self.exp_loss = L_exp()
        self.bri_loss = L_bri()
        self.mod_loss = L_mod()

        # Static loss weights from hparams. They are stored here but not
        # referenced by any method of this class; the steps use the weights
        # returned by dynamic_reweight instead.
        self.lambda_spa = hparams["lambda_spa"]
        self.lambda_col = hparams["lambda_col"]
        self.lambda_exp = hparams["lambda_exp"]
        self.lambda_bri = hparams["lambda_bri"]
        self.lambda_mod = hparams["lambda_mod"]

        # Exclusive upper bound for the timesteps sampled in train/valid steps.
        self.timestep_range = hparams['timestep_range']

        # NOTE(review): the metric device is hard-coded to "cuda" — confirm
        # this is intended when the module is created on a CPU-only host.
        self.metric = ImageQualityMetrics(device="cuda")
        self.metric.eval()

    def dynamic_reweight(self, losses):
        """Compute normalized weights that are inverse to each loss's share.

        Each loss is divided by the sum of all losses, the reciprocal of that
        share becomes its raw weight, and the raw weights are then scaled so
        they sum to 1. Small epsilons guard both divisions against zero.

        Args:
            losses: Mapping from loss name to its (detached) value.

        Returns:
            Dict with the same keys as ``losses`` holding the weights.
        """
        total = sum(losses.values())
        weights = {
            key: 1.0 / (value / (total + 1e-8) + 1e-6)
            for key, value in losses.items()
        }

        weight_sum = sum(weights.values())
        return {key: weight / weight_sum for key, weight in weights.items()}

    def forward(self, x, t):
        """Delegate to the wrapped model with images ``x`` and timesteps ``t``."""
        return self.model(x, t)

    def _weighted_losses(self, batch):
        """Run the model at random timesteps and combine the five losses.

        Shared by training_step and validation_step, which differ only in the
        prefix of the keys they log.

        Returns:
            Tuple ``(losses, weights, total)``: the per-term mean losses, the
            weights from :meth:`dynamic_reweight`, and the weighted sum.
        """
        x = batch.to(self.device)
        # One random timestep per image, drawn from [0, timestep_range).
        t = torch.randint(
            low=0,
            high=self.timestep_range,
            size=(x.size(0),),
            device=self.device
        )
        enh_img, modulations, illumination_in, illumination_out = self(x, t)

        losses = {
            "spa": torch.mean(input=self.spa_loss(x, enh_img)),
            "col": torch.mean(input=self.col_loss(enh_img)),
            "exp": torch.mean(input=self.exp_loss(enh_img)),
            "bri": torch.mean(
                input=self.bri_loss(illumination_in, illumination_out, t)
            ),
            "mod": torch.mean(input=self.mod_loss(modulations)),
        }

        # --- Dynamic Reweighting
        # The weights are computed from detached values, so no gradient flows
        # through the weights themselves.
        weights = self.dynamic_reweight(
            losses={key: value.detach() for key, value in losses.items()}
        )

        total = (
            weights["spa"] * losses["spa"] +
            weights["col"] * losses["col"] +
            weights["exp"] * losses["exp"] +
            weights["bri"] * losses["bri"] +
            weights["mod"] * losses["mod"]
        )
        return losses, weights, total

    def _log_losses(self, stage, losses, weights, total):
        """Log the losses, their total and the weights under ``<stage>/NN_name``."""
        self.log_dict(
            dictionary={
                f"{stage}/01_spa": losses["spa"],
                f"{stage}/02_col": losses["col"],
                f"{stage}/03_exp": losses["exp"],
                f"{stage}/04_bri": losses["bri"],
                f"{stage}/05_mod": losses["mod"],
                f"{stage}/06_total": total,
                f"{stage}/07_weight_spa": weights["spa"],
                f"{stage}/08_weight_col": weights["col"],
                f"{stage}/09_weight_exp": weights["exp"],
                f"{stage}/10_weight_bri": weights["bri"],
                f"{stage}/11_weight_mod": weights["mod"],
            },
            prog_bar=True
        )

    def _enhance_at_zero(self, batch):
        """Run the model on ``batch`` with every timestep set to 0.

        Returns:
            Tuple ``(x, enh_img)``: the batch on this module's device and the
            first of the four values returned by the model.
        """
        x = batch.to(self.device)
        t = torch.zeros(
            size=(x.size(0),),
            dtype=torch.long,
            device=self.device
        )
        enh_img, _modulations, _illumination_in, _illumination_out = self(x, t)
        return x, enh_img

    def training_step(self, batch, batch_idx):
        """Compute, log (``train/`` keys) and return the weighted total loss."""
        losses, weights, total = self._weighted_losses(batch)
        self._log_losses("train", losses, weights, total)
        return total

    def validation_step(self, batch, batch_idx):
        """Compute, log (``valid/`` keys) and return the weighted total loss."""
        losses, weights, total = self._weighted_losses(batch)
        self._log_losses("valid", losses, weights, total)
        return total

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Enhance a benchmark batch at timestep 0 and log quality metrics.

        The input batch is passed to ``self.metric.full`` as ``targets`` and
        the enhanced image as ``preds``.

        Returns:
            The mapping returned by ``self.metric.full``.
        """
        x, enh_img = self._enhance_at_zero(batch)

        metrics = self.metric.full(preds=enh_img, targets=x)

        self.log_dict(
            dictionary={
                "bench/01_PSNR": metrics["PSNR"],
                "bench/02_SSIM": metrics["SSIM"],
                "bench/03_LPIPS": metrics["LPIPS"],
                "bench/04_NIQE": metrics["NIQE"],
                "bench/05_BRISQUE": metrics["BRISQUE"],
            },
            prog_bar=True
        )
        return metrics

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the enhanced image for one batch evaluated at timestep 0."""
        _x, enh_img = self._enhance_at_zero(batch)
        return enh_img

    def configure_optimizers(self):
        """Create the Adam optimizer and its cosine warm-restart LR scheduler.

        Returns:
            Dict with the optimizer under ``"optimizer"`` and the scheduler
            configuration under ``"lr_scheduler"``.
        """
        optimizer = torch.optim.Adam(
            params=self.parameters(),
            lr=self.hparams['lr']
        )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer=optimizer,
            T_0=10,           # length of the first cycle, counted in epochs
            T_mult=2,         # every following cycle is twice as long
            eta_min=1e-7      # minimum learning rate
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",  # step the scheduler once per epoch
                "frequency": 1,
            }
        }


In [None]:
class HomomorphicDiTLightning(L.LightningModule):
    """Lightning module that trains and evaluates a ``HomomorphicDit`` model.

    Training and validation combine five loss terms (spa, col, exp, bri, mod)
    using weights that :meth:`dynamic_reweight` recomputes on every step.
    Testing and prediction run the model with all timesteps fixed to 0.
    """

    def __init__(self, hparams):
        """Build the model, the loss modules and the benchmark metric.

        Args:
            hparams: Mapping of hyperparameters. The keys read here are
                ``image_size``, ``hidden_size``, ``patch_size``, ``depth``,
                ``num_heads``, ``in_channels``, ``out_channels``,
                ``lambda_spa``, ``lambda_col``, ``lambda_exp``,
                ``lambda_bri``, ``lambda_mod`` and ``timestep_range``.
                ``lr`` is read later by :meth:`configure_optimizers`.
        """
        super().__init__()
        self.save_hyperparameters(hparams)

        self.model = HomomorphicDit(
            image_size=hparams['image_size'],
            hidden_size=hparams['hidden_size'],
            patch_size=hparams['patch_size'],
            depth=hparams['depth'],
            num_heads=hparams['num_heads'],
            in_channels=hparams["in_channels"],
            out_channels=hparams["out_channels"],
        )

        self.spa_loss = L_spa()
        self.col_loss = L_col()
        self.exp_loss = L_exp()
        self.bri_loss = L_bri()
        self.mod_loss = L_mod()

        # Static loss weights from hparams. They are stored here but not
        # referenced by any method of this class; the steps use the weights
        # returned by dynamic_reweight instead.
        self.lambda_spa = hparams["lambda_spa"]
        self.lambda_col = hparams["lambda_col"]
        self.lambda_exp = hparams["lambda_exp"]
        self.lambda_bri = hparams["lambda_bri"]
        self.lambda_mod = hparams["lambda_mod"]

        # Exclusive upper bound for the timesteps sampled in train/valid steps.
        self.timestep_range = hparams['timestep_range']

        # NOTE(review): the metric device is hard-coded to "cuda" — confirm
        # this is intended when the module is created on a CPU-only host.
        self.metric = ImageQualityMetrics(device="cuda")
        self.metric.eval()

    def dynamic_reweight(self, losses):
        """Compute normalized weights that are inverse to each loss's share.

        Each loss is divided by the sum of all losses, the reciprocal of that
        share becomes its raw weight, and the raw weights are then scaled so
        they sum to 1. Small epsilons guard both divisions against zero.

        Args:
            losses: Mapping from loss name to its (detached) value.

        Returns:
            Dict with the same keys as ``losses`` holding the weights.
        """
        total = sum(losses.values())
        weights = {
            key: 1.0 / (value / (total + 1e-8) + 1e-6)
            for key, value in losses.items()
        }

        weight_sum = sum(weights.values())
        return {key: weight / weight_sum for key, weight in weights.items()}

    def forward(self, x, t):
        """Delegate to the wrapped model with images ``x`` and timesteps ``t``."""
        return self.model(x, t)

    def _shared_step(self, batch, stage):
        """Compute the weighted total loss for ``batch`` and log it.

        Holds the logic common to training_step and validation_step, which
        differ only in the ``stage`` prefix of the logged keys.

        Args:
            batch: Tensor of input images; its first dimension is the batch size.
            stage: Prefix for the logged keys (``"train"`` or ``"valid"``).

        Returns:
            The weighted sum of the five loss terms.
        """
        x = batch.to(self.device)
        # One random timestep per image, drawn from [0, timestep_range).
        t = torch.randint(
            low=0,
            high=self.timestep_range,
            size=(x.size(0),),
            device=self.device
        )
        dit_enh_img, modulations, i_x, i_x_out = self(x, t)

        loss_spa = torch.mean(input=self.spa_loss(x, dit_enh_img))
        loss_col = torch.mean(input=self.col_loss(dit_enh_img))
        loss_exp = torch.mean(input=self.exp_loss(dit_enh_img))

        loss_bri = torch.mean(input=self.bri_loss(i_x, i_x_out, t))
        loss_mod = torch.mean(input=self.mod_loss(modulations))

        # --- Dynamic Reweighting
        # The weights are computed from detached values, so no gradient flows
        # through the weights themselves.
        losses = {
            "spa": loss_spa.detach(),
            "col": loss_col.detach(),
            "exp": loss_exp.detach(),
            "bri": loss_bri.detach(),
            "mod": loss_mod.detach(),
        }
        dynamic_lambdas = self.dynamic_reweight(losses=losses)

        total = (
            dynamic_lambdas["spa"] * loss_spa +
            dynamic_lambdas["col"] * loss_col +
            dynamic_lambdas["exp"] * loss_exp +
            dynamic_lambdas["bri"] * loss_bri +
            dynamic_lambdas["mod"] * loss_mod
        )

        self.log_dict(
            dictionary={
                f"{stage}/01_spa": loss_spa,
                f"{stage}/02_col": loss_col,
                f"{stage}/03_exp": loss_exp,
                f"{stage}/04_bri": loss_bri,
                f"{stage}/05_mod": loss_mod,
                f"{stage}/06_total": total,
                f"{stage}/07_weight_spa": dynamic_lambdas["spa"],
                f"{stage}/08_weight_col": dynamic_lambdas["col"],
                f"{stage}/09_weight_exp": dynamic_lambdas["exp"],
                f"{stage}/10_weight_bri": dynamic_lambdas["bri"],
                f"{stage}/11_weight_mod": dynamic_lambdas["mod"],
            },
            prog_bar=True
        )

        return total

    def _enhance_at_zero(self, batch):
        """Run the model on ``batch`` with every timestep set to 0.

        Returns:
            Tuple ``(x, dit_enh_img)``: the batch on this module's device and
            the first of the four values returned by the model.
        """
        x = batch.to(self.device)
        t = torch.zeros(
            size=(x.size(0),),
            dtype=torch.long,
            device=self.device
        )
        dit_enh_img, _modulations, _i_x, _i_x_out = self(x, t)
        return x, dit_enh_img

    def training_step(self, batch, batch_idx):
        """Compute, log (``train/`` keys) and return the weighted total loss."""
        return self._shared_step(batch, stage="train")

    def validation_step(self, batch, batch_idx):
        """Compute, log (``valid/`` keys) and return the weighted total loss."""
        return self._shared_step(batch, stage="valid")

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Enhance a benchmark batch at timestep 0 and log quality metrics.

        The input batch is passed to ``self.metric.full`` as ``targets`` and
        the enhanced image as ``preds``.

        Returns:
            The mapping returned by ``self.metric.full``.
        """
        x, dit_enh_img = self._enhance_at_zero(batch)

        metrics = self.metric.full(preds=dit_enh_img, targets=x)

        self.log_dict(
            dictionary={
                "bench/01_PSNR": metrics["PSNR"],
                "bench/02_SSIM": metrics["SSIM"],
                "bench/03_LPIPS": metrics["LPIPS"],
                "bench/04_NIQE": metrics["NIQE"],
                "bench/05_BRISQUE": metrics["BRISQUE"],
            },
            prog_bar=True
        )
        return metrics

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the enhanced image for one batch evaluated at timestep 0."""
        _x, dit_enh_img = self._enhance_at_zero(batch)
        return dit_enh_img

    def configure_optimizers(self):
        """Create the Adam optimizer and its cosine warm-restart LR scheduler.

        Returns:
            Dict with the optimizer under ``"optimizer"`` and the scheduler
            configuration under ``"lr_scheduler"``.
        """
        optimizer = torch.optim.Adam(
            params=self.parameters(),
            lr=self.hparams['lr']
        )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer=optimizer,
            T_0=10,           # length of the first cycle, counted in epochs
            T_mult=2,         # every following cycle is twice as long
            eta_min=1e-7      # minimum learning rate
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",  # step the scheduler once per epoch
                "frequency": 1,
            }
        }


# main

In [None]:
import os
import random
import torch

from pathlib import Path


## hparams

In [None]:
def get_hparams():
    """Return a fresh dict with every hyperparameter for one run.

    The ``seed`` entry is drawn with ``random.randint`` from [0, 1000] on each
    call; all other values are constants.
    """
    # Model architecture
    architecture = {
        "image_size": 512,
        "hidden_size": 768,
        "patch_size": 16,  # (original note: 32)
        "depth": 12,
        "num_heads": 12,
        "in_channels": 3,
        "out_channels": 1,
    }

    # Loss-function weights (as defined for losses.py)
    loss_weights = {
        "lambda_col": 10.0,
        "lambda_exp": 0,
        "lambda_spa": 1000.0,
        "lambda_bri": 0,
        "lambda_mod": 1.0,
    }

    # timestep range (diffusion time embedding)
    diffusion = {"timestep_range": 10000}

    # Optimization and training settings
    optimization = {
        "lr": 1e-4,
        "epochs": 100,
        "batch_size": 16,
        "seed": random.randint(a=0, b=1000),
    }

    # Data paths
    data_paths = {
        "train_data_path": "data/1_train",
        "valid_data_path": "data/2_valid",
        "bench_data_path": "data/3_bench",
        "infer_data_path": "data/4_infer",
    }

    # Logging settings
    logging_options = {
        "log_dir": "./runs2",
        "experiment_name": "HomomorphicDiT",
        "inference": "inference",
    }

    # Merge in the listed order so the key order of the result is stable.
    return {
        **architecture,
        **loss_weights,
        **diffusion,
        **optimization,
        **data_paths,
        **logging_options,
    }


## main

In [None]:
def main():
    """Create the model and its trainer wrapper, then call ``trainer.run()``.

    ``hparams``, ``model_class`` and ``lightning_trainer`` are assigned as
    module-level globals so that code outside this function can read them.
    """
    global hparams, model_class, lightning_trainer

    hparams = get_hparams()
    model_class = HomomorphicDiTLightning

    print("[RUNNING] Trainer...")
    model = model_class(hparams=hparams)
    trainer = LightningTrainer(model=model, hparams=hparams)

    # Keep a handle on the wrapped trainer object before starting the run.
    lightning_trainer = trainer.trainer
    trainer.run()


## train

In [None]:
# Run the entry point; main() also sets the hparams, model_class and
# lightning_trainer globals.
main()


[RUNNING] Trainer...


Seed set to 127
/home/user/anaconda3/envs/jih_icicic/lib/python3.10/site-packages/lightning/fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


[INFO] Start training...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Output()

Metric valid/06_total improved. New best score: 0.000


## inference

In [None]:
# Directory that holds the logs of the current run.
path = Path(f"{lightning_trainer.log_dir}")

# Path.glob returns a one-shot generator. Materialize both results as lists so
# that `ckpts` and `hparams` can be iterated again after the loop below (the
# benchmark cell reuses them) instead of being silently empty.
# Note: `hparams` here is a list of hparams.yaml paths and rebinds the global
# of the same name that main() sets to the hyperparameter dict.
ckpts = list(path.glob(pattern="checkpoints/best*.ckpt"))
hparams = list(path.glob(pattern="hparams.yaml"))

# zip pairs checkpoints with hparams files one-to-one and stops at the shorter
# of the two lists.
for ckpt, hparam in zip(ckpts, hparams):
    print("[RUNNING] Inferencer...")
    inferencer = LightningInferencer(
        model=model_class,
        trainer=lightning_trainer,
        ckpt=ckpt,
        hparams=hparam,
    )
    inferencer.run()


[RUNNING] Inferencer...
save_dir: runs2/HomomorphicDiT/version_2/inference
[INFO] Start training...


Restoring states from the checkpoint path at runs2/HomomorphicDiT/version_2/checkpoints/best-epoch=00.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loaded model weights from the checkpoint at runs2/HomomorphicDiT/version_2/checkpoints/best-epoch=00.ckpt


Output()

[INFO] Inference completed.


In [None]:
# Collect the checkpoint and hparams files in this cell rather than relying on
# `ckpts` / `hparams` left over from an earlier cell: Path.glob returns a
# one-shot generator, so a result that has already been iterated would yield
# nothing here and the benchmark would be skipped without any error.
path = Path(f"{lightning_trainer.log_dir}")
ckpts = list(path.glob(pattern="checkpoints/best*.ckpt"))
hparams = list(path.glob(pattern="hparams.yaml"))

# zip pairs checkpoints with hparams files one-to-one and stops at the shorter
# of the two lists.
for ckpt, hparam in zip(ckpts, hparams):
    print("[RUNNING] Inferencer...")
    inferencer = LightningBenchmarker(
        model=model_class,
        trainer=lightning_trainer,
        ckpt=ckpt,
        hparams=hparam,
    )
    inferencer.run()
