In [None]:
import os
import re
import time
from glob import glob

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision.utils import save_image, make_grid

import networks

# This notebook only runs inference, so switch autograd off globally:
# no computation graphs are recorded for any tensor operation below.
torch.autograd.set_grad_enabled(False)

In [None]:
# Shared PIL-image -> tensor converter used when loading the sample images.
# Per torchvision's documentation, ToTensor returns a float (C, H, W) tensor,
# scaled to [0, 1] for 8-bit images such as the RGB PNGs read in this notebook.
tensor_transform = transforms.ToTensor()

In [None]:
def _checkpoint_path(p, file_name):
    """Return the path of checkpoint ``file_name`` inside directory ``p``.

    ``os.path.join`` only inserts a separator when ``p`` lacks one, so
    "../CheckPoints" and "../CheckPoints/" resolve to the same file. Plain
    string concatenation turned the former into "../CheckPointsSRGAN_...".
    """
    return os.path.join(p, file_name)

def srganPath(p):
    """Return the path of the SRGAN generator checkpoint in directory ``p``."""
    return _checkpoint_path(p, "SRGAN_generator.pth")

def metaSrganPath(p):
    """Return the path of the MetaSRGAN generator checkpoint in directory ``p``."""
    return _checkpoint_path(p, "MetaSRGAN_generator.pth")

def autoencoderPath(p):
    """Return the path of the DNO autoencoder checkpoint in directory ``p``."""
    return _checkpoint_path(p, "DNO_Autoencoder.pth")

In [None]:
def save_sample(imgs, path):
    """Save a side-by-side image grid as the next numbered PNG in ``path``.

    Each entry of ``imgs`` becomes one column: its first ``batch_size`` items
    (``batch_size`` is taken from ``imgs[0].shape[0]``) are stacked with
    ``make_grid(nrow=1)``. The columns are then concatenated along the last
    tensor dimension.

    Args:
        imgs: non-empty sequence of image batches. Presumably each is
            (B, C, H, W) with matching C and H -- TODO confirm for new callers.
        path: output directory. It is created if missing, and a trailing
            separator is optional.

    Returns:
        The path of the PNG that was written
        ("MetaSRGAN_Inference_<n>.png" inside ``path``).

    Raises:
        ValueError: if ``imgs`` is empty.
    """
    if not imgs:
        raise ValueError("save_sample() needs at least one image batch")

    batch_size = imgs[0].shape[0]
    columns = [make_grid(img_seq[0:batch_size], nrow=1, normalize=False) for img_seq in imgs]
    imgs_all = torch.cat(columns, dim=-1)

    out_dir = path or "."
    os.makedirs(out_dir, exist_ok=True)

    # Use (highest existing index + 1) rather than the number of matching files.
    # After a file is deleted, the count can equal an index still in use, and
    # that sample would be silently overwritten. Listing the directory also
    # avoids glob-pattern surprises when the path contains "[" or "*".
    name_re = re.compile(r"MetaSRGAN_Inference_(\d+)\.png")
    used = [int(m.group(1)) for m in map(name_re.fullmatch, os.listdir(out_dir)) if m]
    sample_num = max(used, default=-1) + 1

    out_file = os.path.join(out_dir, "MetaSRGAN_Inference_%d.png" % sample_num)
    save_image(imgs_all, out_file)
    return out_file

In [None]:
# Input/output locations, relative to the notebook's working directory.
# The trailing "/" is kept so these also work with plain string concatenation.
saved_model_path, src, dst = (
    "../CheckPoints/",       # generator / autoencoder checkpoint files
    "../Samples/Example5/",  # Image.png, Normal.png, Depth.png, Object.png
    "../Outputs/",           # destination for MetaSRGAN_Inference_<n>.png
)

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on CUDA-less machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_eval_model(model, checkpoint_path):
    """Move ``model`` to ``device``, load its state dict, and return it in eval mode.

    Args:
        model: a freshly constructed ``torch.nn.Module``.
        checkpoint_path: file holding the state dict for ``model``.
    """
    model = model.to(device)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return model.eval()

metaSrgan = load_eval_model(networks.GeneratorResNet(),     metaSrganPath(saved_model_path))
srgan     = load_eval_model(networks.GeneratorSlowNet(),    srganPath(saved_model_path))
encoder   = load_eval_model(networks.MetaAutoencoderTail(), autoencoderPath(saved_model_path))

In [None]:
# perf_counter is the monotonic, high-resolution clock intended for intervals.
start_time = time.perf_counter()

# Put the inputs on whatever device the generator weights were loaded onto.
input_device = next(metaSrgan.parameters()).device

def load_rgb(file_name):
    """Read ``file_name`` from ``src`` as RGB and return a 1-image batch on ``input_device``."""
    # The context manager closes the file handle; convert() returns an
    # independent image that stays valid after the file is closed.
    with Image.open(os.path.join(src, file_name)) as pil_img:
        rgb = pil_img.convert(mode="RGB")
    return tensor_transform(rgb).unsqueeze(0).to(input_device)

# Explicit no_grad keeps this cell cheap even if the global grad switch was not run.
with torch.no_grad():
    img = load_rgb("Image.png")

    # Auxiliary buffers go through the SRGAN generator before being fused.
    normal = srgan(load_rgb("Normal.png"))
    depth  = srgan(load_rgb("Depth.png"))
    obj    = srgan(load_rgb("Object.png"))

    # Keep only the first channel (dim 1) of depth and object, and all of normal.
    # Presumably depth/object are single-channel data stored as RGB -- TODO confirm.
    metas = torch.cat((depth[:, 0:1], normal, obj[:, 0:1]), dim=1)

    metaVec = encoder(metas)
    inference = metaSrgan(img, metaVec)

save_sample([inference], dst)

print("Inference complete in %f seconds" % (time.perf_counter() - start_time))