In [1]:
from PIL import Image
import torch
from torchvision import transforms
import models
import os
from tqdm import tqdm

In [22]:
SZ = 256, 256
DEVICE = 'cuda'
IMAGES_DIR = 'test/satellite/'
SAVE_DIR = 'test/satellite_SR_512_77th_ep'

In [10]:
def open_and_process(img_path, sz):
    img = Image.open(img_path)
    img = img.convert('RGB')
    img = img.resize(sz)
    
    to_tensor = transforms.ToTensor()
    img_tensor = to_tensor(img)
    return img_tensor.unsqueeze(0)

In [11]:
def save_tensor(tensor, path):
    img_tensor = tensor.detach().cpu()[0]
    
    to_pil = transforms.ToPILImage()
    img = to_pil(img_tensor)
    
    img.save(path, 'JPEG')

In [19]:
G = models.GeneratorBNFirst(3, 3, upscale=2)

In [21]:
G.load_state_dict(torch.load('01-weights/01-G_epoch-0077_total-loss-0.003.pth', map_location='cpu'))

In [23]:
G.to(DEVICE)

GeneratorBNFirst(
  (first_conv): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (first_prelu): PReLU(num_parameters=1)
  (B): Sequential(
    (0): ResBlock(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (prelu1): PReLU(num_parameters=1)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (prelu2): PReLU(num_parameters=1)
    )
    (1): ResBlock(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (prelu1): PReLU(num_parameters=1)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1

In [24]:
if not os.path.exists(SAVE_DIR):
    os.mkdir(SAVE_DIR)
files = os.listdir(IMAGES_DIR)
for fname in tqdm(files):
    fpath = os.path.join(IMAGES_DIR, fname)
    tensor = open_and_process(fpath, SZ)
    tensor = tensor.to(DEVICE)
    sr = G(tensor)
    fn, ext = os.path.splitext(fname)
    save_path = os.path.join(SAVE_DIR, fn + '.jpg')
    save_tensor(sr, save_path)

100%|██████████| 24/24 [00:01<00:00, 13.80it/s]
