In [None]:
import pandas as pd
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torchvision.transforms.functional import to_tensor
import sys

## Add the scripts folder to the path so the project-local model definition can be imported
sys.path.insert(0, '../scripts/')
from model_architecture import Generator

## Seed every RNG torch uses for reproducibility.
## torch.manual_seed seeds the CPU generator AND all CUDA devices, so it covers both the
## CPU-only and GPU cases (torch.cuda.manual_seed_all alone leaves the CPU generator unseeded).
SEED = 42
torch.manual_seed(SEED)

## cuDNN autotuning (benchmark=True) may select different convolution algorithms between runs,
## which defeats reproducibility; ask for deterministic cuDNN kernels instead.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
## Set the device: use the GPU when one is available, otherwise fall back to the CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

## Upscaling factor of the generator (matches the "4x" in the checkpoint file name)
UPSCALE_FACTOR = 4

## Build the generator and place it on DEVICE
model = Generator(upscale_factor=UPSCALE_FACTOR).to(DEVICE)

## Load the training checkpoint; map_location puts every stored tensor on DEVICE
## regardless of the device the checkpoint was saved from.
checkpoint = torch.load('../models/netG_4x_epoch5.pth.tar', map_location=torch.device(DEVICE))

## The generator weights are stored under the "model" key of the checkpoint dict
model.load_state_dict(checkpoint["model"])

## Switch to evaluation mode (disables training-only behaviour such as dropout and
## batch-norm statistic updates); the trailing semicolon suppresses the model repr output.
model.eval();

In [None]:
## Load the reference HR image
hr_image = Image.open('../assets/sample_hr_input.png').convert('RGB')

## Target sizes as (height, width). The HR size is derived from the LR size and the
## generator's 4x upscale factor so the two cannot drift apart.
## NOTE: resizing to a fixed square size ignores the input image's aspect ratio.
UPSCALE_FACTOR = 4
LR_SIZE = (256, 256)
HR_SIZE = tuple(side * UPSCALE_FACTOR for side in LR_SIZE)

## LR transformer: downsample the HR image with bicubic interpolation
lr_scale = transforms.Resize(LR_SIZE, interpolation=transforms.InterpolationMode.BICUBIC)

## Classical baseline transformer: upsample the LR image back to HR size with bicubic interpolation
hr_scale = transforms.Resize(HR_SIZE, interpolation=transforms.InterpolationMode.BICUBIC)

## Create the LR image from the original HR image and save it
lr_image = lr_scale(hr_image)
lr_image.save("../assets/sample_lr_input.png")

## Bicubic-restored HR image: the classical baseline to compare the model against
hr_restore_img = hr_scale(lr_image)

## Convert the LR image to a float tensor, add a batch dimension, and move it to DEVICE,
## the same device the model was placed on. Moving only the input to CUDA while the
## model stays elsewhere raises a device-mismatch error at inference time.
lr_tensor = to_tensor(lr_image).unsqueeze(0).to(DEVICE)

## Perform model inference without tracking gradients
with torch.no_grad():
    output = model(lr_tensor)

## Show input and output shapes as the cell's result
lr_tensor.shape, output.shape

In [None]:
## Post-process the SR output: drop the batch dimension, clamp to the [0, 1] range that
## ToPILImage expects for float tensors, and bring the tensor to the CPU.
## Without the clamp, out-of-range values are not saturated when scaled by 255 and cast
## to 8-bit pixels (out-of-range float->uint8 casts may wrap around).
## NOTE(review): assumes the Generator emits values on the same [0, 1] scale as to_tensor — confirm.
sr_tensor = output.squeeze(0).clamp(0.0, 1.0).cpu()

## Transform for converting a (C, H, W) float tensor into a PIL image
display_transform = transforms.ToPILImage()

## Convert and save the super-resolved image
sr_image = display_transform(sr_tensor)
sr_image.save("../assets/sample_sr_output.png")

## Display the result inline
sr_image