In [None]:
import os
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

In [None]:
# DIV2K high-resolution images. Download the archive from
# https://www.kaggle.com/datasets/soumikrakshit/div2k-high-resolution-images?resource=download
# and extract both folders inside the zip (DIV2K_train_HR and DIV2K_valid_HR) into datasets/.
# folder_path repeats the folder name because that is the layout this notebook expects
# after extraction — adjust it if your copy is laid out differently.
folder_path = 'datasets/DIV2K_train_HR/DIV2K_train_HR'  # source HR training images
output_folder = 'datasets/DIV2K_patches'  # destination for the extracted patches

# Fail early with an actionable message instead of a bare FileNotFoundError
# raised by os.listdir in the next cell.
if not os.path.isdir(folder_path):
    raise FileNotFoundError(
        f"DIV2K training images not found at {folder_path!r}; "
        "download and extract the dataset as described in this cell's comments."
    )
os.makedirs(output_folder, exist_ok=True)  # exist_ok keeps the cell safe to re-run

In [None]:
# Gather every PNG in the training folder (extension match is case-insensitive).
# os.listdir returns entries in arbitrary order, so sort for a reproducible
# processing order and therefore reproducible patch filenames.
image_files = sorted(
    os.path.join(folder_path, entry)
    for entry in os.listdir(folder_path)
    if entry.lower().endswith('.png')
)

print(f" Found {len(image_files)} images in DIV2K dataset")

# Quick sanity check: show the first few filenames.
for f in image_files[:10]:
    print(os.path.basename(f))

In [None]:
def extract_patches(img, patch_size=128, stride=64):
    """Cut an image into square patches on a regular grid.

    Patches are taken row by row (top to bottom, left to right), starting at
    the top-left corner and stepping ``stride`` pixels along both axes. Only
    patches that lie entirely inside the image are returned, so any right or
    bottom margin too narrow for another full patch is dropped.

    Args:
        img: Array indexed as ``img[row, col, ...]`` (e.g. an H x W x C image
            from ``cv2.imread``); only the first two axes are sliced.
        patch_size: Side length of each square patch, in pixels.
        stride: Offset between the top-left corners of neighbouring patches.
            ``stride == patch_size`` gives non-overlapping tiles; a smaller
            stride gives overlapping patches.

    Returns:
        List of ``patch_size x patch_size`` slices of ``img``. For a numpy
        input these are views, not copies, so writing into a patch also
        modifies ``img``. The list is empty when the image is smaller than
        ``patch_size`` in either dimension.

    Raises:
        ValueError: If ``patch_size`` or ``stride`` is not positive.
    """
    # Non-positive sizes would otherwise yield empty/degenerate slices (or an
    # empty list) that only fail later, far from the cause.
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    h, w = img.shape[:2]
    patches = []
    # `+ 1` so a patch ending exactly on the image border is still included.
    for y in range(0, h - patch_size + 1, stride):
        for x in range(0, w - patch_size + 1, stride):
            patches.append(img[y:y + patch_size, x:x + patch_size])
    return patches

In [None]:
# Patch geometry: stride == patch_size tiles each image without overlap;
# a smaller stride yields overlapping patches (more samples per image).
patch_size = 128  # Change accordingly
stride = 128  # Change accordingly
count = 0  # number of patches successfully written to output_folder

for img_path in tqdm(image_files, desc="Extracting patches"):
    img = cv2.imread(img_path)
    if img is None:
        print(f"Skipping unreadable file: {img_path}")
        continue

    # OpenCV decodes to BGR channel order; convert so `img` holds RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    patches = extract_patches(img, patch_size, stride)

    # Zero-pad the patch index to at least 3 digits (the existing naming), but
    # widen it when an image yields more than 1000 patches so filenames still
    # sort in extraction order.
    index_width = max(3, len(str(len(patches) - 1)))

    # Save each extracted patch
    base_name = os.path.splitext(os.path.basename(img_path))[0]
    for i, patch in enumerate(patches):
        save_path = os.path.join(output_folder, f"{base_name}_patch{i:0{index_width}d}.png")
        # imwrite expects BGR again. It signals failure by returning False
        # rather than raising, so only count patches that were actually saved.
        if cv2.imwrite(save_path, cv2.cvtColor(patch, cv2.COLOR_RGB2BGR)):
            count += 1
        else:
            print(f"Failed to write patch: {save_path}")

In [None]:
# Report how many patches were written and where; sep="\n" puts the folder
# path on its own line.
print(f"\n Done! Created {count} image patches and saved them in:", output_folder, sep="\n")