In [1]:
import torch
from utils import *
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt

In [2]:
# Data parameters
srgan_checkpoint = "/home/mw/project/checkpoint_srgan.pth.tar"
srresnet_checkpoint = '/home/mw/project/checkpoint_srresnet.pth.tar'

HR_image_path = '/home/mw/input/dataset76853/benchmark/benchmark/Set14/HR/baboon.png'
LR_image_path = '/home/mw/input/dataset76853/benchmark/benchmark/Set14/LR_bicubic/X4/baboonx4.png'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu")

In [3]:
# Load srresnet models
srresnet = torch.load(srresnet_checkpoint)['model'].to(device)
srresnet.eval()

  out_channels = (n_channels if i is 0 else in_channels * 2) if i % 2 is 0 else in_channels
  out_channels = (n_channels if i is 0 else in_channels * 2) if i % 2 is 0 else in_channels
  stride=1 if i % 2 is 0 else 2, batch_norm=i is not 0, activation='LeakyReLu'))
  stride=1 if i % 2 is 0 else 2, batch_norm=i is not 0, activation='LeakyReLu'))


SRResNet(
  (conv_block1): ConvolutionalBlock(
    (conv_block): Sequential(
      (0): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
      (1): PReLU(num_parameters=1)
    )
  )
  (residual_blocks): Sequential(
    (0): ResidualBlock(
      (conv_block1): ConvolutionalBlock(
        (conv_block): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): PReLU(num_parameters=1)
        )
      )
      (conv_block2): ConvolutionalBlock(
        (conv_block): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (1): ResidualBlock(
      (conv_block1): ConvolutionalBlock(
        (conv_block): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1)

In [4]:
# Load srgan models
srgan_generator = torch.load(srgan_checkpoint)['generator'].to(device)
srgan_generator.eval()

Generator(
  (net): SRResNet(
    (conv_block1): ConvolutionalBlock(
      (conv_block): Sequential(
        (0): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
        (1): PReLU(num_parameters=1)
      )
    )
    (residual_blocks): Sequential(
      (0): ResidualBlock(
        (conv_block1): ConvolutionalBlock(
          (conv_block): Sequential(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): PReLU(num_parameters=1)
          )
        )
        (conv_block2): ConvolutionalBlock(
          (conv_block): Sequential(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
      )
      (1): ResidualBlock(
        (conv_block1): ConvolutionalBlock(
          (conv_block): Seque

In [5]:
hr_img = Image.open(HR_image_path, mode="r")
hr_img = hr_img.convert('RGB')
lr_img = Image.open(LR_image_path, mode="r")
lr_img = lr_img.convert('RGB')

In [6]:
# Bicubic Upsampling
bicubic_img = lr_img.resize((hr_img.width, hr_img.height), Image.BICUBIC)
# Super-resolution (SR) with SRGAN
sr_img_srgan = srgan_generator(convert_image(lr_img, source='pil', target='imagenet-norm', device=device).unsqueeze(0).to(device))
sr_img_srgan = sr_img_srgan.squeeze(0).detach()
sr_img_srgan = sr_img_srgan.cpu()  # 将Tensor从GPU转移到CPU  
sr_img_srgan = convert_image(sr_img_srgan, source='[-1, 1]', target='pil', device=device)

In [7]:
# Super-resolution (SR) with SRResNet
sr_img_srresnet = srresnet(convert_image(lr_img, source='pil', target='imagenet-norm', device=device).unsqueeze(0).to(device))
sr_img_srresnet = sr_img_srresnet.squeeze(0).detach()
sr_img_srresnet = sr_img_srresnet.cpu()  # 将Tensor从GPU转移到CPU 
sr_img_srresnet = convert_image(sr_img_srresnet, source='[-1, 1]', target='pil', device=device)

In [8]:
# Visualization
plt.figure('Bicubic')
plt.imshow(bicubic_img)
plt.axis('off')
plt.title('Bicubic')
plt.figure('SRRESNET')
plt.imshow(sr_img_srresnet)
plt.axis('off')
plt.title('SRRESNET')
plt.figure('SRGAN')
plt.imshow(sr_img_srgan)
plt.axis('off')
plt.title('SRGAN')
plt.figure('Original HR')
plt.imshow(hr_img)
plt.axis('off')
plt.title('Original HR')
plt.show()