In [None]:
import sys

import torch

# Fix the global torch RNG so weight initialisation, dropout and the random
# summary inputs further down are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED);

# Make the parent of the kernel's working directory importable so the project
# packages (`data`, `model`, `utils`) imported in the next cell resolve. The
# membership check keeps re-runs of this cell from stacking duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")


In [None]:
from data.dataloader import LowLightDataModule
from model.blocks.featurerestorer import FeatureRestorationBlock
from model.blocks.homomorphic import ImageDecomposition
from utils.utils import show_batch, summarize_model


In [None]:
# Data module over the low-light dataset. Its train loader yields
# (low, high) pairs, which are unpacked in a later cell.
data_module = LowLightDataModule(
    # Loader settings. NOTE(review): image_size is presumably the resize/crop
    # target in pixels; confirm in LowLightDataModule.
    image_size=512,
    batch_size=1,
    num_workers=4,
    # One directory per split, relative to the kernel's working directory.
    train_dir="../data/1_train",
    valid_dir="../data/2_valid",
    bench_dir="../data/3_bench",
    infer_dir="../data/4_infer",
)

# Prepare the splits needed for the "fit" stage (presumably train/valid).
data_module.setup(stage="fit")


In [None]:
# Training DataLoader from the data module set up above; only its first batch is inspected below.
train_dataloader = data_module.train_dataloader()


In [None]:
# Peek at the first training pair: the low-light input and its (presumably
# well-lit) "high" counterpart. Print both shapes one per line, then plot each.
train_low, train_high = next(iter(train_dataloader))
print(train_low.shape, train_high.shape, sep="\n")
show_batch(images=train_low)
show_batch(images=train_high)


In [None]:
# Image decomposition module from `model.blocks.homomorphic`. When called later on a batch, its
# result is unpacked into luminance, two chroma planes, illuminance and reflectance.
# NOTE(review): the meaning of `offset` and `cutoff` is defined by ImageDecomposition;
# `cutoff` is presumably the homomorphic filter's frequency cutoff. Confirm before tuning.
decompose = ImageDecomposition(offset=0.5, cutoff=0.1)


In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise, so this
# cell no longer hard-requires CUDA. On a CUDA machine this matches `.cuda()`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_low = train_low.to(device)
decompose = decompose.to(device)


In [None]:
# Split the low-light batch into its components. This is an inspection-only
# forward pass (no backward pass is run here), so skip autograd bookkeeping.
# That also keeps the outputs from carrying a graph if the module has parameters.
with torch.no_grad():
    luminance, chroma_red, chroma_blue, illuminance, reflectance = decompose(train_low)


In [None]:
# Chroma restoration block. It is later applied to (chroma_red, chroma_blue), so
# it is configured as single-channel in and out.
restore = FeatureRestorationBlock(
    # Presumably attention/MLP hyperparameters: embedding width, head count,
    # MLP expansion factor and dropout probability. Confirm in the block.
    embed_dim=32,
    num_heads=8,
    mlp_ratio=4,
    dropout_ratio=0.2,
    # Channel counts of each chroma plane.
    in_channels=1,
    out_channels=1,
)


In [None]:
# Put the restoration block on the same device as the chroma planes it will
# consume, instead of assuming CUDA. This is identical to `.cuda()` when the
# inputs are on the GPU.
restore = restore.to(chroma_red.device)


In [None]:
# Forward pass for visual inspection only:
# - eval() switches off dropout (the block was built with dropout_ratio=0.2),
#   assuming it uses standard nn.Dropout, so the plotted output is deterministic;
# - no_grad() keeps the outputs from requiring grad, which could otherwise
#   break a `.numpy()` conversion inside show_batch.
# NOTE(review): this overwrites the chroma inputs, so re-running this cell on
# its own applies the block twice. Re-run from the decomposition cell instead.
restore.eval()
with torch.no_grad():
    chroma_red, chroma_blue = restore(chroma_red, chroma_blue)


In [None]:
# Stack the two chroma planes along dim 0 (presumably the batch dimension) and display them together.
show_batch(images=torch.cat((chroma_red, chroma_blue)))


In [None]:
# Dummy (Cr, Cb) inputs for the model summary. They are created on whatever
# device the block's parameters live on, so the summary no longer hard-requires
# CUDA. This is identical to device="cuda" when the model is on the GPU.
# NOTE(review): 256x256 differs from the image_size=512 given to the data module,
# so these shapes may not match the real chroma planes. Confirm this is intended.
summary_device = next(restore.parameters()).device
input_tensors = [
    torch.randn(1, 1, 256, 256, device=summary_device),  # cr
    torch.randn(1, 1, 256, 256, device=summary_device),  # cb
]
summarize_model(model=restore, input_data=input_tensors)
