In [2]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor
from tqdm import tqdm


In [3]:
def _load_rgb_tensor(path):
    """Open an image file, convert it to RGB and return it as a (3, H, W) float tensor in [0, 1].

    The file is opened in a ``with`` block so the OS handle is released as soon
    as the pixels have been copied into the tensor (a bare ``Image.open`` keeps
    it open until garbage collection).
    """
    with Image.open(path) as img:
        return to_tensor(img.convert('RGB'))


def _is_constant(tensor):
    """Return True when every element of ``tensor`` has the same value (max == min)."""
    return bool(torch.max(tensor) == torch.min(tensor))


def extract_and_save_patches(lr_dir_x2, lr_dir_x4, hr_dir, save_dir, lr_patch_size=32, num_patches=10, seed=None):
    """Cut spatially aligned random crops from x4-LR, x2-LR and HR images and save them as .mat files.

    Files in ``lr_dir_x2`` and ``lr_dir_x4`` are paired by sorted order. Each HR
    file name is derived from its x2 name by removing ``'x2'`` (DIV2K
    convention, e.g. ``0001x2.png`` -> ``0001.png``). For every image triple,
    ``num_patches`` random top-left corners are drawn on the x4 image. The same
    region is cut from the x2 image (offset and size x2) and from the HR image
    (offset and size x4). Each accepted triple is written to
    ``<save_dir>/<patch_id>.mat`` with keys ``'lr_patchx2'``, ``'lr_patchx4'``
    and ``'hr_patch'``, each a channel-first (3, h, w) array in [0, 1].
    ``patch_id`` starts at 1 and is consecutive across all images.

    Args:
        lr_dir_x2: Directory of x2-downscaled LR images.
        lr_dir_x4: Directory of x4-downscaled LR images.
        hr_dir: Directory of HR images.
        save_dir: Output directory; created if it does not exist.
        lr_patch_size: Side of the x4 crop. The x2 crop is 2x this and the HR crop 4x.
        num_patches: Crop attempts per image. Constant or out-of-bounds crops are
            skipped, so fewer files may be written.
        seed: Optional seed for a private ``random.Random``. ``None`` (default)
            uses the module-level ``random`` state, so an outer ``random.seed()``
            still controls the crops.

    Raises:
        ValueError: If ``lr_patch_size < 1``, or if the x2 and x4 directories
            contain a different number of files (sorted-order pairing would be wrong).
    """
    if lr_patch_size < 1:
        raise ValueError(f'lr_patch_size must be >= 1, got {lr_patch_size}')

    lr_filenamesx2 = sorted(os.listdir(lr_dir_x2))
    lr_filenamesx4 = sorted(os.listdir(lr_dir_x4))
    if len(lr_filenamesx2) != len(lr_filenamesx4):
        # zip() would silently truncate and shift every later pair.
        raise ValueError(
            f'{lr_dir_x2} has {len(lr_filenamesx2)} files but {lr_dir_x4} has '
            f'{len(lr_filenamesx4)}; cannot pair x2 and x4 images by sorted order'
        )
    # Derive each HR name from its own x2 name. Deliberately NOT re-sorted:
    # sorting the stripped names separately can reorder them relative to
    # lr_filenamesx2 (e.g. 'a2x2' vs 'ax21x2') and silently mis-pair images.
    hr_filenames = [f.replace('x2', '') for f in lr_filenamesx2]
    triples = list(zip(lr_filenamesx2, lr_filenamesx4, hr_filenames))

    os.makedirs(save_dir, exist_ok=True)  # savemat does not create missing directories
    rng = random if seed is None else random.Random(seed)

    size_x4 = lr_patch_size
    size_x2 = 2 * lr_patch_size
    size_hr = 4 * lr_patch_size
    patch_id = 1

    for lr_filenamex2, lr_filenamex4, hr_filename in tqdm(triples):
        lr_imagex2 = _load_rgb_tensor(os.path.join(lr_dir_x2, lr_filenamex2))
        lr_imagex4 = _load_rgb_tensor(os.path.join(lr_dir_x4, lr_filenamex4))
        hr_image = _load_rgb_tensor(os.path.join(hr_dir, hr_filename))

        # Largest valid top-left corner of an x4 crop; negative means the image
        # is smaller than the patch (randint would raise ValueError).
        max_x = lr_imagex4.shape[2] - size_x4
        max_y = lr_imagex4.shape[1] - size_x4
        if max_x < 0 or max_y < 0:
            tqdm.write(f'Skip {lr_filenamex4}: smaller than lr_patch_size={lr_patch_size}')
            continue

        for _ in range(num_patches):
            x_lrx4 = rng.randint(0, max_x)
            y_lrx4 = rng.randint(0, max_y)
            # The same scene region in the x2 and HR images: offsets scale by 2 and 4.
            x_lrx2, y_lrx2 = 2 * x_lrx4, 2 * y_lrx4
            x_hr, y_hr = 4 * x_lrx4, 4 * y_lrx4

            lr_patchx4 = lr_imagex4[:, y_lrx4:y_lrx4 + size_x4, x_lrx4:x_lrx4 + size_x4]
            lr_patchx2 = lr_imagex2[:, y_lrx2:y_lrx2 + size_x2, x_lrx2:x_lrx2 + size_x2]
            hr_patch = hr_image[:, y_hr:y_hr + size_hr, x_hr:x_hr + size_hr]

            # Slicing past an edge silently yields a smaller crop. Reject it so
            # every saved triple has exactly the 1x/2x/4x sizes.
            if (tuple(lr_patchx2.shape[1:]) != (size_x2, size_x2)
                    or tuple(hr_patch.shape[1:]) != (size_hr, size_hr)):
                tqdm.write(f'Skip {lr_filenamex2}: x2/HR crop out of bounds')
                continue

            # Constant crops have max == min. Presumably a later min-max
            # normalisation would divide by zero on them (hence the message).
            if _is_constant(lr_patchx4) or _is_constant(lr_patchx2) or _is_constant(hr_patch):
                tqdm.write("Skip due to divide by 0 error")
                continue

            patch_filename = os.path.join(save_dir, f'{patch_id}.mat')
            sio.savemat(patch_filename, {
                'lr_patchx2': lr_patchx2.numpy(),
                'lr_patchx4': lr_patchx4.numpy(),
                'hr_patch': hr_patch.numpy(),
            })
            patch_id += 1

In [4]:
# Extract aligned x2/x4/HR training patches from the DIV2K bicubic LR tracks.
# Paths are relative to the notebook's working directory.
SEED = 42
random.seed(SEED)  # fix the crop positions so Restart & Run All reproduces the same patch set

lr_dir_x2 = './DIV2K_train_LR_bicubic/X2/'
lr_dir_x4 = './DIV2K_train_LR_bicubic/X4/'
hr_dir = './DIV2K_train_HR/'
save_dir = './Train_Patches'
os.makedirs(save_dir, exist_ok=True)  # scipy.io.savemat will not create a missing directory
extract_and_save_patches(lr_dir_x2, lr_dir_x4, hr_dir, save_dir, lr_patch_size=64, num_patches=20)

13it [00:06,  2.14it/s]

Skip due to divide by 0 error
Skip due to divide by 0 error


40it [00:19,  2.02it/s]

In [None]:
# Sanity-check one saved patch. Patches are stored channel-first (C, H, W), as
# sliced from the to_tensor output; imshow expects channel-last (H, W, C).
patch = sio.loadmat('./Train_Patches/11.mat')['hr_patch']
patch = np.transpose(patch, (1, 2, 0))

fig, ax = plt.subplots()
ax.imshow(patch)
ax.set_title('Train_Patches/11.mat: hr_patch')
ax.axis('off')
plt.show()

NameError: name 'sio' is not defined