In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from time import time

from dataloader import ImageDataset
from model import *

# Seed NumPy's and PyTorch's global RNGs with the same value so that repeated
# runs of this notebook are reproducible.
SEED = 69
for seed_rng in (np.random.seed, torch.manual_seed):
    seed_rng(SEED)

# Inference is deliberately pinned to the CPU. To use a GPU when one is
# available, assign torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# to `device` instead of the string below.
device = 'cpu'
print(f'Using: {device}')
if str(device) == 'cuda':
    # Only reachable once `device` is switched to CUDA above.
    print(torch.cuda.get_device_name())

In [None]:
# Evaluation settings: image size handed to ImageDataset and the checkpoint
# file name (looked up under models/).
IMAGE_SIZE = 1024
model_name = 'SRTransformer6_best_86.pth'

# Test split. The second ImageDataset argument (2) is presumably the
# super-resolution scale factor — TODO confirm in dataloader.ImageDataset.
# batch_size=1 so that averaging per-batch metrics in the evaluation cell is a
# per-image average, which is what a correct mean PSNR requires.
test_set = ImageDataset("data/test/", 2, size=IMAGE_SIZE)
test_loader = DataLoader(dataset=test_set, shuffle=False, batch_size=1)

# Build the network directly on the chosen device (Module.to returns the module).
model = SRTransformer6().to(device)

# Metrics used by the evaluation cell: L2 (MSE, also the basis for PSNR) and L1.
loss_fn = nn.MSELoss()
loss_l1 = nn.L1Loss()

# Restore the trained weights. NOTE(review): torch.load unpickles the file, so
# only load checkpoints that come from a trusted source.
checkpoint_path = f'models/{model_name}'
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # switch train-only layers (e.g. dropout, if any) to inference behaviour

from utils import count_trainable
print('Model parameters:', count_trainable(model))

In [None]:
from math import inf, log10

# Average the per-image metrics (MSE, L1, PSNR) over the whole test set.
# PSNR uses a peak value of 1, which presumes pixel values scaled to [0, 1]
# — TODO confirm against ImageDataset. test_loader uses batch_size=1, so each
# MSE below belongs to a single image and the mean PSNR is a per-image mean.
n_batches = len(test_loader)
if n_batches == 0:
    raise ValueError('test_loader is empty; nothing to evaluate')

with torch.no_grad():
    t0 = time()
    test_loss = 0.0
    test_loss_l1 = 0.0
    psnr = 0.0
    for batch in test_loader:
        # load data to the device
        x, y = batch[0].to(device), batch[1].to(device)
        # Call the module rather than .forward() so any registered hooks run.
        out = model(x)
        loss = loss_fn(out, y)
        mse = loss.item()
        test_loss += mse
        test_loss_l1 += loss_l1(out, y).item()
        # A perfect reconstruction (MSE == 0) has infinite PSNR; guard it
        # instead of raising ZeroDivisionError.
        psnr += inf if mse == 0 else 10 * log10(1 / mse)
    test_loss /= n_batches
    test_loss_l1 /= n_batches
    psnr /= n_batches
    print(f'PSNR: {psnr:.04f} | Loss (L1): {test_loss_l1:.06f} | Loss (L2): {test_loss:.06f} | Test time: {time()-t0:.2f}')

In [None]:
# Run the model on the first test sample for the visual comparison below.
# The loader is not shuffled, so this is always the same image.
batch = next(iter(test_loader))  # raises StopIteration if the test set is empty
x, y = batch[0].to(device), batch[1].to(device)
# No gradients are needed for a forward pass used only for plotting; without
# no_grad an autograd graph would be kept for the full-resolution output.
with torch.no_grad():
    out = model(x)
# Drop the batch dimension and move channels last for matplotlib:
# (C, H, W) -> (H, W, C). Assumes NCHW tensors — TODO confirm.
prediction = out[0].permute(1, 2, 0).cpu().numpy()
real = y[0].permute(1, 2, 0).cpu().numpy()

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from utils import save_plot

# Zoom window: rows start1:end1, columns start2:end2. Large images use a fixed
# 150x150 region; smaller ones use a 50x50 crop centred in the image.
if IMAGE_SIZE > 550:
    start1, end1 = 100, 250
    start2, end2 = 400, 550
else:
    start1, end1 = IMAGE_SIZE//2 - 25, IMAGE_SIZE//2 + 25
    start2, end2 = IMAGE_SIZE//2 - 25, IMAGE_SIZE//2 + 25

# Top row: full images with the zoom window outlined; bottom row: the crops.
# Left column is the model prediction, right column the ground truth.
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
for col, (image, title) in enumerate([(prediction, 'Prediction'), (real, 'Original')]):
    full_ax, crop_ax = axes[0, col], axes[1, col]

    full_ax.imshow(image)
    full_ax.set_title(title)
    full_ax.axis('off')
    # Rectangle takes (x, y) = (column, row), hence (start2, start1).
    full_ax.add_patch(Rectangle((start2, start1), end2 - start2, end1 - start1,
                                linewidth=1, edgecolor='r', facecolor='none'))

    crop_ax.imshow(image[start1:end1, start2:end2, :])
    crop_ax.set_title(title)
    crop_ax.axis('off')

# Save before showing: a blocking/inline show() may close the figure.
# NOTE(review): assumes save_plot only writes `fig` to disk and does not close it.
save_plot(fig, f'prediction_image_{IMAGE_SIZE}')
plt.show()