In [1]:
import os
import glob
import h5py
import numpy as np
from PIL import Image
from tqdm import tqdm

# --- Parameters ---
HR_IMAGE_DIR = 'Data/T91' # Folder with your high-res training images (e.g., T91 dataset)
H5_OUTPUT_FILE = 'train_data.h5'
UPSCALE_FACTOR = 2
PATCH_SIZE = 32  # The size of the HR patches to be extracted
STRIDE = 14      # The step size to move across the image for patch extraction

# --- Script ---
def create_training_dataset(hr_image_dir, h5_output_file):
    """
    Preprocesses a directory of HR images to create an HDF5 dataset of LR/HR patch pairs.
    """
    hr_image_paths = glob.glob(os.path.join(hr_image_dir, '*.png'))
    
    lr_patches = []
    hr_patches = []

    print(f"Processing {len(hr_image_paths)} images...")
    
    for img_path in tqdm(hr_image_paths):
        try:
            # Open HR image and convert to Y channel
            hr_img = Image.open(img_path).convert('RGB')
            hr_img_ycbcr = hr_img.convert('YCbCr')
            hr_img_y, _, _ = hr_img_ycbcr.split()
            
            width, height = hr_img_y.size

            # Extract patches
            for i in range(0, height - PATCH_SIZE + 1, STRIDE):
                for j in range(0, width - PATCH_SIZE + 1, STRIDE):
                    # Define the box for cropping the HR patch
                    box = (j, i, j + PATCH_SIZE, i + PATCH_SIZE)
                    hr_patch = hr_img_y.crop(box)
                    
                    # Create the LR patch
                    # 1. Downscale
                    lr_size = (PATCH_SIZE // UPSCALE_FACTOR, PATCH_SIZE // UPSCALE_FACTOR)
                    lr_patch_downscaled = hr_patch.resize(lr_size, Image.BICUBIC)
                    # 2. Upscale back to original patch size (this is the model input)
                    lr_patch_upscaled = lr_patch_downscaled.resize(hr_patch.size, Image.BICUBIC)
                    
                    # Convert to numpy arrays and append
                    hr_patches.append(np.array(hr_patch))
                    lr_patches.append(np.array(lr_patch_upscaled))

        except Exception as e:
            print(f"Could not process {img_path}: {e}")

    print(f"Generated {len(hr_patches)} patches.")

    # Convert lists to numpy arrays
    lr_patches = np.array(lr_patches)
    hr_patches = np.array(hr_patches)
    
    # Add channel dimension (required by PyTorch Conv2d)
    # Shape becomes (num_patches, 1, height, width)
    lr_patches = np.expand_dims(lr_patches, axis=1)
    hr_patches = np.expand_dims(hr_patches, axis=1)

    # Save to HDF5 file
    with h5py.File(h5_output_file, 'w') as hf:
        hf.create_dataset('lr_images', data=lr_patches, compression="gzip", chunks=True)
        hf.create_dataset('hr_images', data=hr_patches, compression="gzip", chunks=True)
        
    print(f"Successfully saved dataset to {h5_output_file}")


# --- Run the function ---
if __name__ == '__main__':
    # Make sure the HR_IMAGE_DIR exists
    if not os.path.isdir(HR_IMAGE_DIR) or not os.listdir(HR_IMAGE_DIR):
         print(f"Error: Directory '{HR_IMAGE_DIR}' is empty or does not exist.")
         print("Please download a dataset like T91 and place the images there.")
    else:
        create_training_dataset(HR_IMAGE_DIR, H5_OUTPUT_FILE)

Processing 91 images...


100%|██████████| 91/91 [00:03<00:00, 27.23it/s]


Generated 22227 patches.
Successfully saved dataset to train_data.h5


In [2]:
# Define the model
import torch
from torch import nn

class SRCNN(nn.Module):
    def __init__(self):
        super(SRCNN, self).__init__()
        # Layer 1: Patch extraction and representation
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=9, padding=4)
        # Layer 2: Non-linear mapping
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, padding=0)
        # Layer 3: Reconstruction
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=5, padding=2)
        
        # Activation function
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        return x

In [3]:
# Load data
import h5py
from torch.utils.data import Dataset, DataLoader

# Assume you have created an HDF5 file named 'train_data.h5'
# with datasets 'lr_images' and 'hr_images'

class SRDataset(Dataset):
    def __init__(self, h5_file_path):
        super(SRDataset, self).__init__()
        self.h5_file = h5py.File(h5_file_path, 'r')
        self.inputs = self.h5_file['lr_images']
        self.labels = self.h5_file['hr_images']

    def __getitem__(self, index):
        # Convert to PyTorch tensors
        input_tensor = torch.from_numpy(self.inputs[index]).float()
        label_tensor = torch.from_numpy(self.labels[index]).float()
        return input_tensor, label_tensor

    def __len__(self):
        return len(self.inputs)

In [None]:
# Train model on loaded data
import torch.optim as optim
from tqdm import tqdm # For a nice progress bar

# --- Hyperparameters ---
LEARNING_RATE = 1e-4
BATCH_SIZE = 16
NUM_EPOCHS = 100

# --- Setup ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = SRCNN().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# --- Data Loading ---
# Replace with the actual path to your HDF5 file
train_dataset = SRDataset(h5_file_path='train_data.h5')
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# --- Training Loop ---
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    
    # Use tqdm for a progress bar
    loop = tqdm(train_loader, leave=True)
    for data, target in loop:
        data, target = data.to(device), target.to(device)

        # Forward pass
        output = model(data)
        loss = criterion(output, target)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        
        # Update progress bar
        loop.set_description(f"Epoch [{epoch+1}/{NUM_EPOCHS}]")
        loop.set_postfix(loss=loss.item())

    print(f"Epoch {epoch+1}, Average Loss: {running_loss / len(train_loader):.4f}")

# --- Save the model ---
torch.save(model.state_dict(), 'srcnn_model.pth')
print("Training complete. Model saved.")

Using device: cuda


Epoch [1/100]:  11%|█         | 150/1390 [02:04<14:59,  1.38it/s, loss=1.01e+3]

In [None]:
from PIL import Image
import torchvision.transforms as transforms
import numpy as np

# --- Load the trained model ---
model = SRCNN().to(device)
model.load_state_dict(torch.load('srcnn_model.pth'))
model.eval()

# --- Prepare the input image ---
# 1. Open a low-resolution image
lr_image_path = 'path/to/your/low_res_image.png'
img = Image.open(lr_image_path).convert('YCbCr')
y, cb, cr = img.split()

# 2. Upscale using bicubic interpolation (this is the model's input)
# Note: The original paper does this step before feeding to the network
y_bicubic = y.resize( (y.width * 2, y.height * 2), Image.BICUBIC) # Example upscale factor of 2

# 3. Convert the Y channel to a tensor
transform = transforms.Compose([
    transforms.ToTensor(),
])
input_tensor = transform(y_bicubic).unsqueeze(0).to(device)


# --- Run the model ---
with torch.no_grad():
    output_tensor = model(input_tensor)

# --- Post-process the output ---
output_tensor = output_tensor.cpu().squeeze(0)
output_img_data = output_tensor.numpy().clip(0, 1) * 255.0
output_img = Image.fromarray(np.uint8(output_img_data[0]), mode='L')

# --- Merge with original color channels (upscaled) ---
cb_bicubic = cb.resize(output_img.size, Image.BICUBIC)
cr_bicubic = cr.resize(output_img.size, Image.BICUBIC)

final_img = Image.merge('YCbCr', [output_img, cb_bicubic, cr_bicubic]).convert('RGB')
final_img.save('output_high_res_image.png')
print("Super-resolution complete. Image saved.")