In [1]:
import pandas as pd
import torch
import torchvision
from models import Generator
import torchvision.transforms as transforms
from PIL import Image
from torchvision.transforms.functional import to_tensor

## Set the seed for reproducibility
torch.backends.cudnn.benchmark = True
torch.cuda.manual_seed_all(42)

In [2]:
## Set the device
DEVICE = "cuda"

## Load the model
model = Generator(upscale_factor=4).to(DEVICE)

## Load the model weights state dict
state_dict = torch.load('../train_2_112220/netG_4x_epoch5.pth.tar')  # Load model weights

## Load the model from state dict
model.load_state_dict(state_dict["model"])

## Set the model to evaluation mode
model.eval()

Generator(
  (initial): ConvBlock(
    (cnn): SeperableConv2d(
      (depthwise): Conv2d(3, 3, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4), groups=3)
      (pointwise): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
    )
    (bn): Identity()
    (act): PReLU(num_parameters=64)
  )
  (residual): Sequential(
    (0): ResidualBlock(
      (block1): ConvBlock(
        (cnn): SeperableConv2d(
          (depthwise): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
          (pointwise): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): PReLU(num_parameters=64)
      )
      (block2): ConvBlock(
        (cnn): SeperableConv2d(
          (depthwise): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
          (pointwise): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=Fa

In [4]:
# Load an image
hr_image = Image.open('../dataset/valid/00001273_000.png').convert('RGB')

## Create the LR image transformer by downsampling the HR image and applying bicubic interpolation
lr_scale = transforms.Resize((256,256), interpolation=Image.BICUBIC)

## Create the restored HR image tranformer (simple classical method) by upsampling the LR image and applying bicubic interpolation
hr_scale = transforms.Resize((1024,1024), interpolation=Image.BICUBIC)

## Create the LR Image from the original HR Image using the LR Image transformer
lr_image = lr_scale(hr_image)

## Create the restored HR Image from the LR Image using the classical method of restored HR Image transforms
hr_restore_img = hr_scale(lr_image)

# Move the image and model to GPU if available
lr_image = to_tensor(lr_image).cuda()
lr_image = lr_image.unsqueeze(0)

lr_image.shape

# Perform model inference
with torch.no_grad():
    output = model(lr_image)

In [5]:
## Remove the batch dimension
out = output.squeeze(0)

## Transforms for displaying the images
display_transform = transforms.Compose([
    transforms.ToPILImage(),
])

## Transform the output image
out = display_transform(out)

## Save the output image
out.save("output.png")
hr_image.save("input.png")
hr_restore_img.save("naive.png")

## Create data split

In [5]:
import random
import os
import glob
import pickle

In [6]:
## Get list of all images in full dataset
all_images_list = glob.glob(f"./data/*/*/*.png", recursive=False)

## Shuffle the data
random.shuffle(all_images_list)

In [7]:
train_images = all_images_list[:90000]
test_images = all_images_list[90000:]

In [8]:
## Save the train and test split
with open("./dataset/train_images.pkl", "wb") as fp:
    pickle.dump(train_images, fp)
with open("./dataset/val_images.pkl", "wb") as fp:
    pickle.dump(test_images, fp)

In [9]:
all_images_list

['./data/images_004/images/00006769_016.png',
 './data/images_008/images/00016739_000.png',
 './data/images_002/images/00001409_002.png',
 './data/images_004/images/00007185_013.png',
 './data/images_011/images/00027765_000.png',
 './data/images_006/images/00012219_002.png',
 './data/images_011/images/00027875_002.png',
 './data/images_002/images/00003348_006.png',
 './data/images_010/images/00021441_002.png',
 './data/images_008/images/00017835_004.png',
 './data/images_003/images/00004808_090.png',
 './data/images_006/images/00012094_007.png',
 './data/images_003/images/00004944_000.png',
 './data/images_005/images/00009465_002.png',
 './data/images_004/images/00006674_008.png',
 './data/images_008/images/00018055_039.png',
 './data/images_003/images/00005580_001.png',
 './data/images_002/images/00003510_005.png',
 './data/images_007/images/00014846_000.png',
 './data/images_002/images/00003665_010.png',
 './data/images_009/images/00018489_000.png',
 './data/images_004/images/0000882