In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import os
# Restrict this process to the CUDA device with index 0.
# NOTE(review): this runs after `import torch` above; the variable only has an
# effect if CUDA has not been initialised yet at this point — consider setting
# it before importing torch to make that independent of import side effects.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
# ---------------------------------------------------------------------------
# Model checkpoint
# ---------------------------------------------------------------------------
# Checkpoint file handed to torch.load() when the model is built.
model_path = "D:/luke/edsr-baseline-lte.pth"
# Alternative checkpoint, kept for quick switching (uncomment to use):
# model_path = "D:/luke/lte_geo/save/_train_swinir-lte_geo/230516-1847_resident_crumble_3867/resident_crumble_3867_epoch-last.pth"

# ---------------------------------------------------------------------------
# Demo image pair and observation window
# ---------------------------------------------------------------------------
# Low-resolution input and its ground-truth counterpart.
lr_path, gt_path = './demo/Urban100_img012x2.png', './demo/Urban100_img012.png'

# Factor by which (xx, yy) are integer-divided when indexing the feature maps.
scale = 2

# Centre of the observation window in ground-truth pixels: xx is used as the
# row (vertical) index and yy as the column (horizontal) index further down.
xx, yy = 700, 550

# Side length, in ground-truth pixels, of the square drawn around (xx, yy).
obs_size = 120


In [None]:
import models

from torchvision import transforms
from PIL import Image

from test import reshape
from utils import to_pixel_samples

# Load the checkpoint and rebuild the network from the spec stored under 'model'.
# map_location="cpu" deserialises every tensor on the CPU, so the load does not
# depend on which GPU index the checkpoint was saved from (only device 0 is
# visible to this process); the assembled model is moved to the GPU afterwards.
model_spec = torch.load(model_path, map_location="cpu")['model']
model = models.make(model_spec, load_sd=True).cuda()

# Read the LR and GT images as RGB and convert them to tensors.
img_lr = transforms.ToTensor()(Image.open(lr_path).convert('RGB'))
img_gt = transforms.ToTensor()(Image.open(gt_path).convert('RGB'))
# Add a leading batch axis to the GT tensor; later cells read its last two
# dimensions as the GT height and width.
img_gt = img_gt.unsqueeze(0)



In [None]:
# Prepare the evaluation inputs: the normalised LR batch plus one query
# coordinate and one cell size per ground-truth sample.

# Add a batch axis, move to the GPU and apply (x - 0.5) / 0.5.
inp = (img_lr.unsqueeze(0).cuda() - 0.5) / 0.5

# Per-sample coordinates and values of the GT image (hr_val is not used below).
hr_coord, hr_val = to_pixel_samples(img_gt)

# Cell size per query: 2 / H in column 0 and 2 / W in column 1, where H and W
# are the last two dimensions of the GT tensor.
gt_h, gt_w = img_gt.shape[-2:]
hr_cell = torch.ones_like(hr_coord)
hr_cell[:, 0].mul_(2 / gt_h)
hr_cell[:, 1].mul_(2 / gt_w)

# Prepend the batch axis expected by the model call.
hr_coord, hr_cell = hr_coord[None], hr_cell[None]

print(img_gt.shape, inp.shape, hr_coord.shape, hr_cell.shape)


In [None]:
model.eval()
with torch.no_grad():
    # Forward pass on the LR input flipped along its second-to-last axis; the
    # prediction is detached and moved to the CPU.
    sr = (
        model(
            inp.flip(-2),
            hr_coord.to("cuda", non_blocking=True),
            hr_cell.to("cuda", non_blocking=True),
        )
        .detach()
        .cpu()
    )
    # Read the maps cached on the model by the forward pass and flip them back
    # along the same axis.
    # Fix: the two attributes were previously assigned to the opposite names
    # (freq <- model.coeff, coef <- model.freqq), so the plot below used the
    # coefficients as frequencies and vice versa.
    # NOTE(review): assumes model.freqq holds the frequency map and model.coeff
    # the coefficient map, mirroring the reference lines below — confirm
    # against the model class.
    freq = model.freqq.flip(-2)
    coef = model.coeff.flip(-2)

    # Reference (equivalent computation straight from the feature map):
    # model.gen_feat(inp.flip(-2)) # due to a digital image coordinate conventions (https://blogs.mathworks.com/steve/2011/08/26/digital-image-processing-using-matlab-digital-image-representation/)
    # freq = model.freq(model.feat).flip(-2)
    # coef = model.coef(model.feat).flip(-2)


In [None]:
from PIL import ImageDraw

# Show the ground-truth image with the observation window outlined in red.
# The box is given as [x0, y0, x1, y1]: yy supplies the horizontal coordinate
# and xx the vertical one.
half_box = obs_size // 2
window_box = [yy - half_box, xx - half_box, yy + half_box, xx + half_box]

im = Image.open(gt_path).convert('RGB')
draw = ImageDraw.Draw(im)
draw.rectangle(window_box, outline="red", width=3)
display(im)


In [None]:
# Reshape the flat prediction into image layout.
# NOTE(review): this overwrites `sr` with its own reshaped version, so running
# the cell twice feeds an already-reshaped tensor back into reshape(); re-run
# the inference cell first if this cell has to be executed again.
sr, batch = reshape(dict(inp=inp, coord=hr_coord, gt=img_gt), 0, 0, hr_coord, sr)

import matplotlib.pyplot as plt


def _to_display(img):
    """Turn a single-image batch in channel-first layout into an HxWx3 array for imshow.

    Undoes the (x - 0.5) / 0.5 normalisation applied to the network input and
    clamps to [0, 1], because imshow clips float RGB values outside that range.
    """
    return (img.detach().cpu().squeeze(0) * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).numpy()


# `inp` was normalised to [-1, 1] when it was built, so it must be mapped back
# before display.
# NOTE(review): `sr` is assumed to be in the same normalised range as `inp`
# — confirm against the model / training pipeline.
# NOTE(review): the model was fed the vertically flipped input, so the SR image
# may appear flipped relative to the LR image shown here — confirm.
for title, tensor in (("LR input", inp), ("SR output", sr)):
    fig, ax = plt.subplots()
    ax.imshow(_to_display(tensor))
    ax.set_title(title)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Fourier feature space at one location: every point is one frequency pair,
# coloured by the sum of two squared coefficient channels.
plt.rcParams["figure.figsize"] = (8, 8)

# (xx, yy) are integer-divided by `scale` to index the feature maps.
feat_row, feat_col = xx // scale, yy // scale
n_pairs = freq.shape[1] // 2

# Regroup the channel axis into consecutive pairs:
# (B, C, H, W) -> (B, 2, C // 2, H, W); index 1 of each pair is plotted on the
# x axis and index 0 on the y axis.
freq_pairs = torch.stack(torch.split(freq, 2, dim=1), dim=2)
freq_x = freq_pairs[0, 1, :, feat_row, feat_col].cpu().numpy()
freq_y = freq_pairs[0, 0, :, feat_row, feat_col].cpu().numpy()

# Pair channel i of the first half of `coef` with channel i of the second half.
coef_at_point = coef[0, :, feat_row, feat_col]
mag = (coef_at_point[:n_pairs] ** 2 + coef_at_point[n_pairs:] ** 2).cpu().numpy()

# The colour scale saturates at a quarter of the largest value.
sc = plt.scatter(freq_x, freq_y, c=mag, vmin=0, vmax=max(mag) / 4, s=None, cmap='bwr')
# plt.colorbar(sc)
axis_ticks = np.linspace(-1.5, 1.5, 5)
plt.xticks(axis_ticks)
plt.yticks(axis_ticks)

plt.tight_layout()
plt.show()
