In [12]:
!pip install torch



In [8]:
import os
import cv2
import torch
import numpy as np
import RRDBNet_arch as arch

# Define paths for uploaded and result images
upload_folder = 'uploads/'
result_folder = 'results/' # Path to save the super-resolution results
os.makedirs(result_folder, exist_ok=True)


In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Path to the model directory (adjust based on your path)
models_dir = 'models'

def load_model(model_name, model_dir, device):
    model_path = os.path.join(model_dir, model_name)
    model = arch.RRDBNet(3, 3, 64, 23, gc=32)  # Define the architecture
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()  # Set the model to evaluation mode
    model = model.to(device)
    return model

# Load the model
model_name = 'RRDB_ESRGAN_x4.pth'  # Use the ESRGAN model
model = load_model(model_name, models_dir, device)


  model.load_state_dict(torch.load(model_path), strict=True)


In [10]:
def super_resolution(path_img, device, model):
    base = os.path.splitext(os.path.basename(path_img))[0]  # Get the base name of the image
    img = cv2.imread(path_img)  # Load the low-resolution image
    img = img * 1.0 / 255  # Normalize the image
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()  # Convert to torch tensor
    LR = img.unsqueeze(0).to(device)  # Add a batch dimension and move to the correct device (CPU/GPU)

    # Perform super-resolution
    with torch.no_grad():
        result = model(LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
    
    # Rearrange channels and scale back to image format
    result = np.transpose(result[[2, 1, 0], :, :], (1, 2, 0))
    result = (result * 255.0).round()
    
    # Save the result image
    result_path = os.path.join(result_folder, f'{base}_sr.png')
    cv2.imwrite(result_path, result)
    return result


In [11]:
# Loop through all images in the upload folder and perform super-resolution
for img_file in os.listdir(upload_folder):
    if img_file.endswith('.jpg'):  # Only process JPG images
        img_path = os.path.join(upload_folder, img_file)  # Get the full path of the image
        super_resolution(img_path, device, model)  # Apply super-resolution
