In [None]:
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
import os


In [None]:
from torch.utils.data import DataLoader


In [None]:
from data.dataloader import CustomDataset
from data.utils import DataTransform
from model.blocks.contourlet import LaplacianPyramid, DirectionalFilterBank, ContourletTransform


In [None]:
# Tunable inputs for this cell, kept together so they are easy to spot.
IMAGE_SIZE = 640  # passed to DataTransform as image_size
DATASET_ROOT = "data/1_train/1_LOLdataset"  # relative to the notebook's working directory

# Pre-processing applied by the dataset to each sample it yields
# (exact behaviour is defined in data.utils.DataTransform).
transform = DataTransform(image_size=IMAGE_SIZE)

dataset = CustomDataset(path=DATASET_ROOT, transform=transform)


In [None]:
# Use roughly 90% of the available CPU cores for data loading, but never fewer
# than one worker:
#   * os.cpu_count() may return None when the core count cannot be determined,
#     which would make the multiplication raise TypeError;
#   * on a single-core machine int(1 * 0.9) == 0, and DataLoader refuses
#     persistent_workers=True when num_workers is 0.
num_workers = max(1, int((os.cpu_count() or 1) * 0.9))

dataloader = DataLoader(
    dataset=dataset,
    batch_size=16,
    shuffle=True,
    num_workers=num_workers,
    persistent_workers=True,  # keep worker processes alive between iterators
    pin_memory=True
)


In [None]:
# Pull a single mini-batch from the loader; `data` is reused by the cells below.
# Because the loader shuffles, a different batch is drawn each time this runs.
data = next(iter(dataloader))
# Sanity check of the batch layout (assumes the dataset yields an object with
# a `.shape`, e.g. a tensor — TODO confirm against CustomDataset).
print(data.shape)


In [None]:
def show_batch(images, ncols=8):
    """Plot every image of a batch on a grid, one subplot per image.

    Args:
        images: Indexable batch whose first dimension is the image count
            (``images.shape[0]``); each ``images[i]`` is converted with
            ``torchvision.transforms.functional.to_pil_image`` before drawing.
        ncols: Number of grid columns. Rows are added as needed.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    count = images.shape[0]

    # Ceiling division: a partially filled last row still needs its own row.
    full_rows, leftover = divmod(count, ncols)
    nrows = full_rows + (1 if leftover else 0)

    # Each grid cell gets a 3x3 inch tile.
    fig = plt.figure(figsize=(ncols * 3, nrows * 3))

    for idx in range(count):
        # Subplot positions are 1-based.
        ax = fig.add_subplot(nrows, ncols, idx + 1)
        ax.imshow(F.to_pil_image(images[idx]))
        ax.axis('off')
        ax.set_title(f"Image {idx}")

    fig.tight_layout()
    plt.show()

# Visual check of the batch fetched from the dataloader.
show_batch(data)


In [None]:
# Constructor settings for the Laplacian pyramid, gathered in one place.
lp_settings = {
    "in_channels": 3,
    "num_levels": 4,
    "filter_size": 5,
    "sigma": 1.0,
}
lp = LaplacianPyramid(**lp_settings)

# The module returns two collections, inspected in the following cells
# (their exact contents are defined by LaplacianPyramid's forward pass).
py, lc = lp(data)


In [None]:
# Number of entries in each of the two collections returned by the pyramid.
num_py_entries, num_lc_entries = len(py), len(lc)
print(num_py_entries, num_lc_entries)


In [None]:
# Show the two pyramid outputs level by level: first the `py` entry, then the
# matching `lc` entry. The unused enumerate index was dropped and the
# ambiguous loop name `l` replaced with descriptive names.
# NOTE: zip stops at the shorter sequence, so if `py` and `lc` differ in
# length the trailing entries of the longer one are not displayed.
for py_level, lc_level in zip(py, lc):
    show_batch(images=py_level)
    show_batch(images=lc_level)


In [None]:
# Settings common to both directional filter banks; the two instances differ
# only in their number of levels.
dfb_shared = dict(
    in_channels=3,
    filter_size=5,
    sigma=1.0,
    omega_x=0.25,
    omega_y=0.25,
)

dfb_1 = DirectionalFilterBank(num_levels=1, **dfb_shared)
dfb_4 = DirectionalFilterBank(num_levels=4, **dfb_shared)

def _show_subbands(bands):
    """Print how many sub-bands were produced, then plot each one as a batch."""
    print(len(bands))
    for band in bands:
        show_batch(images=band)


# 1-level filter bank applied to the first entry of `lc`.
subband_1 = dfb_1(lc[0])
_show_subbands(subband_1)

# 4-level filter bank applied to the last entry of `lc`.
subband_4 = dfb_4(lc[-1])
_show_subbands(subband_4)


In [None]:
# Constructor settings for the combined contourlet transform.
contourlet_settings = {
    "in_channels": 3,
    "num_levels": 4,
    "filter_size": 5,
    "sigma": 1.0,
    "omega_x": 0.25,
    "omega_y": 0.25,
}
contourlet = ContourletTransform(**contourlet_settings)

# The transform returns two collections whose shapes are printed in the
# cells below.
pyramid, subbands = contourlet(data)


In [None]:
# Report the shape of every entry in `pyramid`, each preceded by a separator.
for level in pyramid:
    print("--------")
    print("p shape :", level.shape)


In [None]:
# Report the shape of every sub-band, with one separator line per group.
# The original inner loop reused `i`, shadowing the outer loop index; neither
# index was used, so both loops now iterate directly.
for subband_group in subbands:
    print("--------")
    for band in subband_group:
        print("s shape :", band.shape)
