In [4]:
import os
import glob
import cv2
import numpy as np
import torch
import RRDBNet_arch as arch
import sys

# Resolve the run configuration into a plain dict, `args`, so the rest of the
# cell can use `args['key']` whether it runs in Jupyter or as a script.
_DEFAULT_ARGS = {
    'model_path': 'models/RRDB_ESRGAN_x4.pth',  # pre-trained RRDBNet weights
    'input_path': 'LR/*',                       # glob pattern for the low-resolution inputs
    'output_path': 'results',                   # directory for the super-resolved outputs
}

if 'ipykernel' in sys.modules:
    # Running inside a Jupyter kernel: sys.argv describes the kernel process,
    # not user-supplied options, so use the defaults instead of argparse.
    args = dict(_DEFAULT_ARGS)
else:
    # Running as a plain script: allow each setting to be overridden on the
    # command line, falling back to the same defaults as the Jupyter branch.
    import argparse
    parser = argparse.ArgumentParser(description="Super-Resolution using RRDBNet")
    parser.add_argument('--model_path', type=str, default=_DEFAULT_ARGS['model_path'], help="Path to the pre-trained model.")
    parser.add_argument('--input_path', type=str, default=_DEFAULT_ARGS['input_path'], help="Path to the low-resolution images.")
    parser.add_argument('--output_path', type=str, default=_DEFAULT_ARGS['output_path'], help="Directory to store the super-resolved images.")
    # parse_args() returns an argparse.Namespace, which does not support
    # subscripting; vars() converts it to a dict so `args['model_path']`
    # works exactly as in the Jupyter branch.
    args = vars(parser.parse_args())

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the generator. The positional arguments are presumably
# (in_nc, out_nc, nf, nb) as in the ESRGAN reference code; TODO confirm
# against RRDBNet_arch.
model = arch.RRDBNet(3, 3, 64, 23, gc=32)

# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU can still be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints (torch >= 1.13 also accepts weights_only=True to restrict this).
state_dict = torch.load(args['model_path'], map_location=device)
model.load_state_dict(state_dict, strict=True)  # strict: keys must match exactly
model.eval()  # inference mode (affects layers such as dropout/batch-norm, if present)
model = model.to(device)

print(f'Model path: {args["model_path"]}. \nTesting...')

# Ensure the output directory exists before writing any results.
os.makedirs(args['output_path'], exist_ok=True)

# glob returns matches in arbitrary (filesystem-dependent) order; sort them so
# the numbering and processing order are reproducible across runs.
input_paths = sorted(glob.glob(args['input_path']))
if not input_paths:
    print(f'No input images matched: {args["input_path"]}')

for idx, path in enumerate(input_paths, start=1):
    base = os.path.splitext(os.path.basename(path))[0]  # file name without extension
    print(f'{idx} - {base}')

    # Read as a 3-channel BGR image. cv2.imread returns None instead of raising
    # for missing or undecodable files, so skip those rather than crash below.
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    if img is None:
        print(f'Skipping unreadable image: {path}')
        continue

    # Scale to [0, 1], reorder BGR -> RGB, and convert HWC -> CHW float tensor.
    img = img * 1.0 / 255
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
    img_LR = img.unsqueeze(0).to(device)  # add batch dimension: (1, C, H, W)

    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        # squeeze(0) drops only the batch axis, leaving (C, H, W).
        output = model(img_LR).squeeze(0).float().cpu().clamp_(0, 1).numpy()

    # Reorder RGB -> BGR for OpenCV and convert CHW -> HWC.
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
    # Values are clamped to [0, 1], so scaling to [0, 255] fits uint8 exactly.
    output = (output * 255.0).round().astype(np.uint8)

    # Save the super-resolved image; cv2.imwrite signals failure by returning
    # False rather than raising, so check it before reporting success.
    output_path = os.path.join(args['output_path'], f'{base}_rlt.png')
    if cv2.imwrite(output_path, output):
        print(f'Saved super-resolved image: {output_path}')
    else:
        print(f'Failed to write image: {output_path}')


Model path: models/RRDB_ESRGAN_x4.pth. 
Testing...
