In [3]:
import os
from pathlib import Path
from PIL import Image
import torchvision.transforms as T
from torch.utils.data import Dataset
from utils.common import DATA_ROOT

class HDRTrainTriplet(Dataset):
    """
    CMOS (RGB, full res), SPC (mono, half res), GT_HDR (RGB full res).
    Paired by identical filename stems.
    """

    def __init__(self):
        self.cmos_dir = (
            DATA_ROOT
            / "Training_1024x512"
            / "CMOS_sat_1024x512"
            / "dynamic_exposures_train_png_1024x512"
        )

        self.spc_dir = (
            DATA_ROOT
            / "Training_1024x512"
            / "SPC_512x256_train_png"
        )

        self.gt_dir = (
            DATA_ROOT
            / "GT_HDR_1024X512_train_png"
        )

        self.to_rgb = T.ToTensor()
        self.to_gray = T.ToTensor()

        self.items = self._pair_lists()

    def _pair_lists(self):
        cmos_files = {
            Path(f).stem: self.cmos_dir / f
            for f in os.listdir(self.cmos_dir)
            if f.lower().endswith(".png")
        }
        spc_files = {
            Path(f).stem: self.spc_dir / f
            for f in os.listdir(self.spc_dir)
            if f.lower().endswith(".png")
        }
        gt_files = {
            Path(f).stem: self.gt_dir / f
            for f in os.listdir(self.gt_dir)
            if f.lower().endswith(".png")
        }

        common = sorted(set(cmos_files) & set(spc_files) & set(gt_files))
        if not common:
            raise RuntimeError("No matching CMOS/SPC/GT triplets found")

        return [
            {
                "name": stem,
                "cmos_path": cmos_files[stem],
                "spc_path":  spc_files[stem],
                "gt_path":   gt_files[stem],
            }
            for stem in common
        ]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        e = self.items[idx]

        cmos_img = Image.open(e["cmos_path"]).convert("RGB")
        spc_img  = Image.open(e["spc_path"]).convert("L")
        gt_img   = Image.open(e["gt_path"]).convert("RGB")

        cmos_t = self.to_rgb(cmos_img)   # (3,512,1024)
        spc_t  = self.to_gray(spc_img)   # (1,256,512)
        gt_t   = self.to_rgb(gt_img)     # (3,512,1024)

        return {
            "name": e["name"],
            "cmos": cmos_t,
            "spc": spc_t,
            "gt": gt_t,
        }


In [4]:
import torch
from torch.utils.data import DataLoader, Subset
from utils.common import DEVICE, RECON_DIR
from models.feature_extractors import DilatedConvEncoder, HighResCMOSEncoder
from models.decoder import SmallDecoder
from models.fusion import SimpleFusion
from utils.metrics import mse_loss, psnr, ssim
from utils.viz import save_panel
from utils.save_intermediate import tensor_to_pngimg
from pathlib import Path

# Use the dataset with CMOS, SPC, and GT HDR
full_train = HDRTrainTriplet()
subset_ids = list(range(min(5, len(full_train))))  # use 5 samples to debug
subset = Subset(full_train, subset_ids)
loader = DataLoader(subset, batch_size=1, shuffle=True)

print("Subset size:", len(subset), "Device:", DEVICE)

spc_encoder  = DilatedConvEncoder(in_channels=1).to(DEVICE).eval()
cmos_encoder = HighResCMOSEncoder().to(DEVICE).eval()
fusion_head  = SimpleFusion().to(DEVICE)
decoder      = SmallDecoder(in_channels=6).to(DEVICE)

with torch.no_grad():
    fusion_head.fuse_conv.weight.zero_()
    fusion_head.fuse_conv.bias.zero_()
    eye = torch.eye(3).view(3,3,1,1)
    fusion_head.fuse_conv.weight[:, :3] = eye


params = list(fusion_head.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=5e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)

Subset size: 5 Device: cuda


In [5]:
EPOCHS = 600 # start with 50; you can reduce to 10 for quick tests
for ep in range(EPOCHS):
    running = 0.0
    for batch in loader:
        cmos = batch["cmos"].to(DEVICE)
        spc  = batch["spc"].to(DEVICE)
        gt   = batch["gt"].to(DEVICE)

        # Extract features (frozen)
        with torch.no_grad():
            cmos_feat = cmos_encoder(cmos)
            spc_feat  = spc_encoder(spc)

        fused_feat = fusion_head(cmos_feat, spc_feat)
        decoder_in = torch.cat([fused_feat, cmos], dim=1) 
        pred = decoder(decoder_in).clamp(0,1)

        loss = mse_loss(pred, gt)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running += loss.item()
    if (ep+1) % 50 == 0:
        current_lr = scheduler.get_last_lr()[0]
        print(f"Learning rate adjusted to: {current_lr:.6f}")
        print(f"[epoch {ep+1}/{EPOCHS}] avg MSE: {running/len(loader):.6f}")
    scheduler.step()


Learning rate adjusted to: 0.005000
[epoch 50/600] avg MSE: 0.018118
Learning rate adjusted to: 0.005000
[epoch 100/600] avg MSE: 0.015765
Learning rate adjusted to: 0.005000
[epoch 150/600] avg MSE: 0.014501
Learning rate adjusted to: 0.005000
[epoch 200/600] avg MSE: 0.013085
Learning rate adjusted to: 0.002500
[epoch 250/600] avg MSE: 0.011285
Learning rate adjusted to: 0.002500
[epoch 300/600] avg MSE: 0.012057
Learning rate adjusted to: 0.002500
[epoch 350/600] avg MSE: 0.010235
Learning rate adjusted to: 0.002500
[epoch 400/600] avg MSE: 0.009873
Learning rate adjusted to: 0.001250
[epoch 450/600] avg MSE: 0.009598
Learning rate adjusted to: 0.001250
[epoch 500/600] avg MSE: 0.009331
Learning rate adjusted to: 0.001250
[epoch 550/600] avg MSE: 0.009516
Learning rate adjusted to: 0.001250
[epoch 600/600] avg MSE: 0.009132


In [6]:
decoder.eval()
fusion_head.eval()

results_table = []
for batch in loader:
    name = batch["name"][0]
    cmos = batch["cmos"].to(DEVICE)
    spc  = batch["spc"].to(DEVICE)
    gt   = batch["gt"].to(DEVICE)

    with torch.no_grad():
        cmos_feat = cmos_encoder(cmos)
        spc_feat  = spc_encoder(spc)
        fused_feat = fusion_head(cmos_feat, spc_feat)
        decoder_in = torch.cat([fused_feat, cmos], dim=1)  # (1,6,H,W)
        pred = decoder(decoder_in).clamp(0, 1)

    mse_val  = mse_loss(pred, gt).item()
    psnr_val = psnr(pred, gt, max_val=1.0)
    ssim_val = ssim(pred, gt)

    results_table.append((name, mse_val, psnr_val, ssim_val))

    recon_img = tensor_to_pngimg(pred[0].cpu())
    recon_path = RECON_DIR / f"{name}_recon.png"
    recon_img.save(recon_path)
    gt_img = tensor_to_pngimg(gt[0].cpu())
    gt_path = RECON_DIR / f"{name}_GT.png"
    gt_img.save(gt_path)


    panel_path = save_panel(
        cmos[0].cpu(),
        spc[0].cpu(),
        gt[0].cpu(),
        pred[0].cpu(),
        name
    )

    print(f"[EVAL] {name}: MSE={mse_val:.6f}, PSNR={psnr_val:.2f}, SSIM={ssim_val:.4f}")
    print(f"       recon -> {recon_path}")
    print(f"       panel -> {panel_path}")

results_table

from utils.save_intermediate import tensor_to_pngimg

gt_img = tensor_to_pngimg(gt[0].cpu())
gt_out_path = RECON_DIR / f"{name}_GT.png"
gt_img.save(gt_out_path)
print("       GT ->", gt_out_path)



[EVAL] 9C4A0001-5e832da4cc: MSE=0.002270, PSNR=26.44, SSIM=0.7350
       recon -> outputs\reconstructions\9C4A0001-5e832da4cc_recon.png
       panel -> outputs\visuals_for_review\9C4A0001-5e832da4cc_panel.png
[EVAL] 9C4A0001-beb39950ec: MSE=0.005068, PSNR=22.95, SSIM=0.7169
       recon -> outputs\reconstructions\9C4A0001-beb39950ec_recon.png
       panel -> outputs\visuals_for_review\9C4A0001-beb39950ec_panel.png
[EVAL] 9C4A0001-6fbef8172f: MSE=0.009346, PSNR=20.29, SSIM=0.7910
       recon -> outputs\reconstructions\9C4A0001-6fbef8172f_recon.png
       panel -> outputs\visuals_for_review\9C4A0001-6fbef8172f_panel.png
[EVAL] 9C4A0001-c6c6bf7c76: MSE=0.022056, PSNR=16.56, SSIM=0.6274
       recon -> outputs\reconstructions\9C4A0001-c6c6bf7c76_recon.png
       panel -> outputs\visuals_for_review\9C4A0001-c6c6bf7c76_panel.png
[EVAL] 9C4A0001-7c62497929: MSE=0.005467, PSNR=22.62, SSIM=0.7659
       recon -> outputs\reconstructions\9C4A0001-7c62497929_recon.png
       panel -> outputs\visu

In [7]:
with torch.no_grad():
    for batch in loader:
        cmos_feat = cmos_encoder(batch["cmos"].to(DEVICE))
        spc_feat  = spc_encoder(batch["spc"].to(DEVICE))
        fused_feat = fusion_head(cmos_feat, spc_feat)
        print("Feature mean/std:", fused_feat.mean().item(), fused_feat.std().item())
        break


Feature mean/std: 0.0010192279005423188 0.03301551565527916


In [8]:
from PIL import Image
import numpy as np

sample_path = list((DATA_ROOT / "GT_HDR_1024X512_train_png").glob("*.png"))[0]
arr = np.asarray(Image.open(sample_path)) / 255.0
print("GT HDR min/max/mean:", arr.min(), arr.max(), arr.mean())


GT HDR min/max/mean: 0.0 1.0 0.12313276802013118
