In [None]:
import torch
import numpy as np
from PIL import Image
from networks import VGGEncoder, VGGDecoder

# Load an image

In [None]:
# Absolute path from the original author's machine; point this at any local image.
IMAGE_PATH = '/home/gpu1/datasets/COCO/images/train2017/000000001006.jpg'

# Force three channels. Grayscale, palette or RGBA files would otherwise yield
# a 2-D or 4-channel array and break the `permute(2, 0, 1)` in the encode cell.
image = Image.open(IMAGE_PATH).convert('RGB')
w, h = image.size
# Round each side down to a multiple of 8
# (presumably required by the network's pooling stages; confirm in networks.py).
image = image.resize(((w // 8) * 8, (h // 8) * 8), Image.LANCZOS)
image

# Load the trained models

In [None]:
# map_location='cpu' lets checkpoints saved from a GPU deserialize on any
# machine; the input tensor built later in this notebook lives on the CPU too.
e = VGGEncoder()
e.load_state_dict(torch.load('models/encoder.pth', map_location='cpu'))
e.eval()  # inference only: disable any train-time behaviour

# One decoder per feature level, keyed by that level (1-4).
d = {i: VGGDecoder(i) for i in [1, 2, 3, 4]}
for i, m in d.items():
    m.load_state_dict(torch.load(f'models/decoder{i}.pth', map_location='cpu'))
    m.eval()

# Encode and decode

In [None]:
# Divide the pixel array by 255 (assumes 8-bit channels), cast to float32,
# move the last axis to the front and prepend a batch axis of size 1.
x = (
    torch.from_numpy(np.array(image) / 255.0)
    .float()
    .permute(2, 0, 1)
    .unsqueeze(0)
)

print(x.shape)

In [None]:
level = 4  # which decoder to use: 1, 2, 3 or 4

with torch.no_grad():
    # The encoder returns features indexable by level, plus the pooling
    # indices that the decoder takes as its second argument.
    features, pooling_indices = e(x, level)
    y = d[level](features[level], pooling_indices)

    # Clamp to [0, 1], scale to [0, 255] and round to nearest. A bare
    # astype('uint8') truncates, which biases the image darker and can turn
    # a value that should be k into k-1 through float error.
    y = y.clamp(0.0, 1.0).mul(255.0).round()
    # Drop the batch axis and move the first remaining axis to the end for PIL.
    y = y[0].permute(1, 2, 0).numpy().astype('uint8')

print(y.shape)

# See the result

In [None]:
# Wrap the uint8 array in a PIL image; as the cell's last expression it is
# rendered as the cell output.
reconstructed = Image.fromarray(y)
reconstructed