In [None]:
import warnings

warnings.filterwarnings("ignore")

In [None]:
import os
import sys

# Make the project's `src/` folder (one level above the current working directory)
# importable, so that `from config import ...` works. The path is resolved against
# the CWD, so the kernel must be started from the notebook's own folder.
MODULE_PATH = os.path.abspath(os.path.join(os.pardir, "src"))

# Only append once, so re-running this cell does not grow sys.path.
if MODULE_PATH not in sys.path:
    sys.path.append(MODULE_PATH)

In [None]:
import shutil

import numpy as np
from PIL import Image
from tqdm import tqdm

from natsort import natsorted

import torch
import torch.nn.functional as F
from torchvision import transforms

from config import EXTERNAL_DATA_DIR, INTERIM_DATA_DIR, PROCESSED_DATA_DIR, PATCH_SIZE # type: ignore

In [None]:
# Each pipeline stage keeps its copy of the Kodak images in a "kodak" sub-folder:
# raw originals (external), resized copies (interim), final copies (processed).
DATA_PATH, INTERIM_DATA_PATH, PROCESSED_DATA_PATH = (
    os.path.join(stage_dir, "kodak")
    for stage_dir in (EXTERNAL_DATA_DIR, INTERIM_DATA_DIR, PROCESSED_DATA_DIR)
)

In [None]:
# Enumerate the raw Kodak images, naturally sorted (e.g. kodim2 before kodim10).
# The list comes from DATA_PATH, the source folder. INTERIM_DATA_PATH is only created
# and filled further down, so listing it here fails on a fresh kernel. The next cell
# also opens these names from DATA_PATH.
images = natsorted([img for img in os.listdir(DATA_PATH) if img.endswith(('.png', '.jpg'))])

print(f'The total number of images are: {len(images)}.')

In [None]:
# Print the numpy array shape of every raw image in DATA_PATH.
print('-'*50)
print('The shape of the images in the Kodak dataset are:')
print('-'*50, end='\n\n')

for img in images:
    # Image.open is lazy and may keep the file open. The with-block closes it
    # deterministically once the pixels have been read into the array.
    with Image.open(os.path.join(DATA_PATH, img)) as source:
        shape = np.array(source).shape
    print(f'{img}:{shape}', end=' | ')

In [None]:
# Folder that receives the resized images. exist_ok=True makes re-runs a no-op without
# a separate exists() check, which could race with directory creation.
os.makedirs(INTERIM_DATA_PATH, exist_ok=True)

In [None]:
# Side length, in pixels, that every image is resized to. Non-square inputs are
# squashed, not cropped, so their aspect ratio is not preserved.
IMAGE_SIZE = 256

for img in tqdm(natsorted(os.listdir(DATA_PATH)), desc="Processing images"):
    # Skip anything that is not an image (e.g. hidden/system files) instead of crashing in Image.open.
    if not img.endswith(('.png', '.jpg')):
        continue

    with Image.open(os.path.join(DATA_PATH, img)) as source:
        # For palette ('P') images, np.array would return raw palette indices, which
        # end up as a grayscale picture. Expand the palette to RGB first.
        decoded = source.convert('RGB') if source.mode == 'P' else source
        pixels = np.array(decoded)

    # Force values into the 0-255 range and uint8 dtype before rebuilding the PIL image.
    pixels = np.clip(pixels, 0, 255).astype(np.uint8)
    resized = Image.fromarray(pixels).resize((IMAGE_SIZE, IMAGE_SIZE)).convert('RGB')
    resized.save(os.path.join(INTERIM_DATA_PATH, img))

In [None]:
# Print the numpy array shape of every resized image in INTERIM_DATA_PATH.
print('-'*50)
print('The shape of the Semi-Processed images in the Kodak dataset are:')
print('-'*50, end='\n\n')

for img in images:
    # Image.open is lazy and may keep the file open. The with-block closes it
    # deterministically once the pixels have been read into the array.
    with Image.open(os.path.join(INTERIM_DATA_PATH, img)) as source:
        shape = np.array(source).shape
    print(f'{img}:{shape}', end=' | ')

In [None]:
# Folder that receives the final copies of the processed images. exist_ok=True makes
# re-runs a no-op without a separate exists() check.
os.makedirs(PROCESSED_DATA_PATH, exist_ok=True)

In [None]:
# Root folder for the per-image patch sub-folders written by the patch-extraction cell.
PATCH_OUTPUT_DIR = os.path.join(PROCESSED_DATA_DIR, 'patches')

# exist_ok=True makes re-runs a no-op without a separate exists() check.
os.makedirs(PATCH_OUTPUT_DIR, exist_ok=True)

In [None]:
# Copy the finished images byte-for-byte into the processed folder. Opening and
# re-saving them with PIL would re-encode each file, which is lossy for JPEGs, and
# would leave file handles to the garbage collector.
for img in tqdm(natsorted(os.listdir(INTERIM_DATA_PATH)), desc="Final Copy of Processed Images"):
    if img.endswith(('.png', '.jpg')):
        shutil.copy2(
            os.path.join(INTERIM_DATA_PATH, img),
            os.path.join(PROCESSED_DATA_PATH, img),
        )

In [None]:
# Distance in pixels between the top-left corners of neighbouring patches.
# NOTE(review): the previous code passed 4 as F.unfold's third *positional* argument,
# which is `dilation`, not `stride`. 4 is assumed to have been meant as the stride;
# confirm the intended value.
PATCH_STRIDE = 4

# The transforms are stateless, so build them once instead of once per image/patch.
to_tensor = transforms.ToTensor()
to_pil = transforms.ToPILImage()

for img in tqdm(natsorted(os.listdir(INTERIM_DATA_PATH)), desc="Extracting and Saving Patches"):
    if not img.endswith(('.png', '.jpg')):
        continue

    # Load as a (1, 3, H, W) float tensor. The with-block closes the file handle.
    with Image.open(os.path.join(INTERIM_DATA_PATH, img)) as source:
        tensor = to_tensor(source.convert('RGB')).unsqueeze(0)
    B, C = tensor.shape[:2]

    # F.unfold returns (B, C*k*k, L): one column per sliding block, channel-major.
    # Reshape each block to (C, k, k) so it is a real C x k x k image. A (C, k*k)
    # tensor would be read by ToPILImage as a 2-D, C-pixel-tall grayscale strip.
    patches = F.unfold(tensor, kernel_size=PATCH_SIZE, stride=PATCH_STRIDE)
    patches = patches.transpose(1, 2).reshape(B, -1, C, PATCH_SIZE, PATCH_SIZE).squeeze(0)

    # One sub-folder per source image, named after the image file.
    folder = os.path.join(PATCH_OUTPUT_DIR, img)
    os.makedirs(folder, exist_ok=True)

    for i, patch in enumerate(patches):
        to_pil(patch).save(os.path.join(folder, f'patch-{i}.png'))