In [None]:
import sys

import torch

# Seed torch's global RNG so random draws made through torch in later cells
# (e.g. any shuffling/augmentation in the data pipeline, torch.randn inputs)
# repeat across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Make the project root importable. ".." is resolved against the current working
# directory (normally the notebook's folder). The membership check keeps this cell
# idempotent: re-running it does not append duplicate entries to sys.path.
if ".." not in sys.path:
    sys.path.append("..")


In [None]:
from data.dataloader import LowLightDataModule
from model.blocks.homomorphic import ImageComposition, ImageDecomposition
from utils.utils import show_batch, summarize_model


In [None]:
# Locations of each dataset split, relative to the notebook's working directory.
split_dirs = dict(
    train_dir="../data/1_train",
    valid_dir="../data/2_valid",
    bench_dir="../data/3_bench",
    infer_dir="../data/4_infer",
)

# Single-image batches at 512px, loaded with 4 worker processes.
data_module = LowLightDataModule(
    **split_dirs,
    image_size=512,
    batch_size=1,
    num_workers=4,
)

# Run the module's "fit"-stage setup before any of its loaders are requested.
data_module.setup(stage="fit")


In [None]:
# Training-split loader from the data module configured above (batch_size=1).
train_dataloader = data_module.train_dataloader()


In [None]:
# Pull one (low-light, normal-light) pair from the training loader and inspect it.
train_low, train_high = next(iter(train_dataloader))

sample_pair = (train_low, train_high)
for sample_images in sample_pair:
    print(sample_images.shape)
for sample_images in sample_pair:
    show_batch(images=sample_images)


In [None]:
# The same offset is handed to both blocks so that composition mirrors the
# decomposition; presumably required for a faithful round trip — confirm in
# model.blocks.homomorphic.
HOMOMORPHIC_OFFSET = 0.5
# Meaning of `cutoff` is defined by ImageDecomposition (not visible here).
DECOMPOSITION_CUTOFF = 0.1

decompose = ImageDecomposition(offset=HOMOMORPHIC_OFFSET, cutoff=DECOMPOSITION_CUTOFF)
compose = ImageComposition(offset=HOMOMORPHIC_OFFSET)


In [None]:
# Use the GPU when present, but fall back to CPU so the notebook still runs
# end to end on machines without CUDA (the original `.cuda()` calls would raise).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_low = train_low.to(device)
# nn.Module.to() moves parameters/buffers and returns the module itself.
decompose = decompose.to(device)
compose = compose.to(device)


In [None]:
# Visualization-only pass: run without autograd so no graph is recorded and the
# outputs do not require grad (in case show_batch converts them to NumPy, which
# fails on tensors that require grad).
with torch.no_grad():
    luminance, chroma_red, chroma_blue, illuminance, reflectance = decompose(train_low)
    # Recompose from the chroma channels and luminance (argument order: cr, cb, y).
    rgb = compose(chroma_red, chroma_blue, luminance)


In [None]:
# Original input, then every decomposition output concatenated along dim 0
# (presumably the batch axis) into one grid, then the recomposed image.
show_batch(images=train_low)

decomposed_components = [luminance, chroma_red, chroma_blue, illuminance, reflectance]
show_batch(images=torch.cat(decomposed_components, dim=0))

show_batch(images=rgb)


In [None]:
# Layer summary of the decomposition block for one dummy input
# (assumes (batch, channels, height, width) layout — TODO confirm with summarize_model).
summarize_model(input_size=(1, 3, 256, 256), model=decompose)


In [None]:
# Dummy single-channel inputs for the composition block, ordered (cr, cb, y) to
# match how `compose` is called on real data. They are created on the same device
# the pipeline ran on, instead of a hard-coded "cuda", so this cell also works
# when the notebook runs on CPU.
summary_device = luminance.device
input_tensors = [
    torch.randn(1, 1, 256, 256, device=summary_device),  # cr
    torch.randn(1, 1, 256, 256, device=summary_device),  # cb
    torch.randn(1, 1, 256, 256, device=summary_device),  # y
]
summarize_model(model=compose, input_data=input_tensors)
