In [None]:
import os
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from MESR_TEST import MESR

In [None]:
def load_model(model_path, device, in_channels=3, mid_channels=64, out_channels=3, num_blocks=12):
    """Build an MESR network, load trained weights into it, and prepare it for inference.

    Args:
        model_path: path to a checkpoint file holding the model's ``state_dict``.
        device: torch device used as ``map_location`` for the weights and that
            the model is moved onto.
        in_channels, mid_channels, out_channels, num_blocks: MESR constructor
            arguments. The defaults reproduce the original hard-coded
            configuration; they must match the architecture the checkpoint was
            trained with, otherwise ``load_state_dict`` raises on mismatched keys
            or shapes.

    Returns:
        The MESR model, on ``device`` and in eval mode.
    """
    model = MESR(
        in_channels=in_channels,
        mid_channels=mid_channels,
        out_channels=out_channels,
        num_blocks=num_blocks,
    )
    # NOTE(security): torch.load unpickles the file. On PyTorch versions where
    # weights_only does not default to True, a malicious checkpoint can execute
    # arbitrary code. Only load trusted checkpoints, or pass weights_only=True
    # if the installed PyTorch supports it.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    # Evaluation mode: layers such as BatchNorm/Dropout (if MESR has any) use
    # their inference behavior, e.g. fixed running statistics.
    model.eval()
    print("Model loaded successfully!")
    return model

def preprocess_image(image_path, device, size=(256, 256)):
    """Load an image from disk and turn it into a normalized, batched model input.

    Applies the preprocessing stages used in training: resize to ``size``,
    convert to a float tensor in [0, 1], then normalize each channel with
    mean 0.5 / std 0.5, which maps [0, 1] to [-1, 1].

    Args:
        image_path: path to the image file.
        device: torch device the returned tensor is moved to.
        size: (height, width) target passed to ``transforms.Resize``. Defaults
            to (256, 256); it should match what the model was trained on.

    Returns:
        tuple: ``(input_tensor, image)``. ``input_tensor`` has a leading batch
        dimension of 1 and lives on ``device``. ``image`` is the RGB PIL image
        at its original, un-resized resolution.
    """
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # Image.open is lazy and keeps the underlying file open. convert() returns
    # a new, fully loaded image, so the source can be closed right away
    # instead of leaking the file handle until garbage collection.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    input_tensor = transform(image).unsqueeze(0)
    return input_tensor.to(device), image

def postprocess_output(output_tensor):
    """Convert a normalized model output back into a PIL image.

    Reverses the ``Normalize(mean=0.5, std=0.5)`` applied in preprocessing,
    mapping [-1, 1] back to [0, 1], and converts the result to a PIL image.

    Args:
        output_tensor: model output with a leading batch dimension of size 1
            (presumably (1, C, H, W) — TODO confirm against MESR's output).

    Returns:
        PIL.Image.Image: the enhanced image.
    """
    # Drop the batch dimension, detach from any autograd graph, move to CPU.
    output_tensor = output_tensor.squeeze(0).detach().cpu()
    # Undo Normalize: x_norm = (x - 0.5) / 0.5  =>  x = x_norm * 0.5 + 0.5
    output_tensor = output_tensor * 0.5 + 0.5
    # The network output is not guaranteed to stay within [-1, 1]. ToPILImage
    # multiplies float tensors by 255 and casts to uint8, so values outside
    # [0, 1] would wrap around and appear as speckle artifacts. Clamp first.
    output_tensor = output_tensor.clamp(0.0, 1.0)
    return transforms.ToPILImage()(output_tensor)

def visualize_results(input_image, enhanced_image):
    """Draw the input and enhanced images side by side in one 10x5 figure.

    Args:
        input_image: image shown in the left panel.
        enhanced_image: image shown in the right panel.
    """
    panels = (
        ("Low-Resolution Input", input_image),
        ("High-Resolution Enhanced", enhanced_image),
    )

    plt.figure(figsize=(10, 5))
    # One axes per panel, numbered 1..2 in a 1x2 grid; axis ticks hidden.
    for position, (caption, picture) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(caption)
        plt.imshow(picture)
        plt.axis("off")
    plt.show()

def test_model(model_path, image_path, device):
    """Run the full inference pipeline on one image and display the result.

    Loads the model, preprocesses the image, runs a forward pass without
    gradient tracking, converts the output back to an image, and shows it
    next to the input.

    Args:
        model_path: checkpoint path forwarded to ``load_model``.
        image_path: path of the image to enhance.
        device: torch device used for both the model and the input tensor.

    Returns:
        PIL.Image.Image: the enhanced image, so callers can save or inspect it
        instead of only seeing it on screen.
    """
    model = load_model(model_path, device)
    input_tensor, input_image = preprocess_image(image_path, device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output_tensor = model(input_tensor)

    enhanced_image = postprocess_output(output_tensor)
    visualize_results(input_image, enhanced_image)
    return enhanced_image

In [None]:
'''
Flow looks like this:
instantiate the model -> load the trained weights -> switch to eval mode -> apply the preprocessing transforms to the input ->

preprocess_image() yields the input tensor and the original image -> feed the tensor to the model to get the output tensor -> postprocess ->

get the output image -> display it alongside the input image

'''

In [None]:
if __name__ == "__main__":

    # Use the GPU when PyTorch can see one; otherwise run on the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Checkpoint produced by training, and the image to enhance.
    model_path = "./model_output/best_model.pth"
    image_path = "./sample_image.jpg"

    test_model(model_path, image_path, device)