# Save padded image
Given an image path (`im_pth`), the code below will extract its file name (`file_name`) and generate the following images in the `transformed_images` directory:
* unpadded original image: `{file_name}_no_pad.png`
* image resized so that its longer side is `img_size` (2048) pixels, then centered on a black square canvas: `{file_name}_pad.png`

In [1]:
import os
from PIL import Image
from IPython.display import display

img_size = 2048
im_pth = '/home/aisinai/data/mimic/valid/p10228846/s01/view1_frontal.jpg'  # change to your image path
base = os.path.basename(im_pth)
file_name = os.path.splitext(base)[0]

os.makedirs('transformed_images', exist_ok=True)

im = Image.open(im_pth)
im.save(f'transformed_images/{file_name}_no_pad.png')

old_size = im.size  # im.size is in (width, height) format
ratio = float(img_size) / max(old_size)
new_size = tuple([int(x * ratio) for x in old_size])
# Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same filter
im = im.resize(new_size, Image.LANCZOS)

# create a new image for padding and paste the resized on it
new_im = Image.new("RGB", (img_size, img_size))
new_im.paste(im, ((img_size - new_size[0]) // 2,
                  (img_size - new_size[1]) // 2))
new_im.save(f'transformed_images/{file_name}_pad.png')
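
The resize-and-center arithmetic in the cell above can be checked without loading an image. The sketch below is illustrative only: `fit_and_center` is a hypothetical helper, not part of this repository, and the 4096 x 3072 input size is a made-up example chosen so that the scale factor is exactly 0.5.

```python
def fit_and_center(old_size, target):
    """Scale `old_size` to fit inside a target x target square.

    `old_size` is (width, height), as returned by PIL's `Image.size`.
    Returns the resized dimensions and the top-left paste offset that
    centers the resized image on the square canvas.
    """
    ratio = float(target) / max(old_size)
    new_size = tuple(int(x * ratio) for x in old_size)
    offset = ((target - new_size[0]) // 2, (target - new_size[1]) // 2)
    return new_size, offset


new_size, offset = fit_and_center((4096, 3072), 2048)
print(new_size)  # (2048, 1536)
print(offset)    # (0, 256)
```

For a landscape image the horizontal offset is 0 and the padding is split between the top and bottom edges; for a portrait image it is the other way around.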

# Save reconstructed images
There are four models, each identified by its training run. The first three models have an encoder output depth of 64.
* 0: Model A. 2 convolutions in the first / bottom layer | 1 convolution  in the second / top layer
* 1: Model B. 3 convolutions in the first / bottom layer | 2 convolutions in the second / top layer
* 3: Model C. 4 convolutions in the first / bottom layer | 2 convolutions in the second / top layer

The last model, Model D, has an encoder output depth of 1.
* embed1: Model D. 2 convolutions in the first / bottom layer | 1 convolution in the second / top layer
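
The run identifiers and constructor arguments of the four models can be collected in one table. The sketch below mirrors the values used in the reconstruction cell of this notebook; `MODEL_CONFIGS` and `checkpoint_path` are hypothetical names introduced for illustration, while the root directory and the checkpoint file name `vqvae_040.pt` are the ones this notebook loads.

```python
# Run identifier and VQVAE keyword arguments for each model, as used in this notebook.
MODEL_CONFIGS = {
    "A": {"run": "0", "kwargs": {"first_stride": 4, "second_stride": 2}},
    "B": {"run": "1", "kwargs": {"first_stride": 8, "second_stride": 4}},
    "C": {"run": "3", "kwargs": {"first_stride": 16, "second_stride": 4}},
    "D": {"run": "embed1", "kwargs": {"first_stride": 4, "second_stride": 2, "embed_dim": 1}},
}

CHECKPOINT_ROOT = "/home/aisinai/work/VQ-VAE2/20200422/vq_vae/CheXpert"


def checkpoint_path(model_name, root=CHECKPOINT_ROOT):
    """Build the checkpoint path for a model from its training-run identifier."""
    run = MODEL_CONFIGS[model_name]["run"]
    return f"{root}/{run}/checkpoint/vqvae_040.pt"


for name in MODEL_CONFIGS:
    print(name, checkpoint_path(name))
```

Keeping the arguments in a dictionary means the same values can be passed as `VQVAE(**MODEL_CONFIGS[name]["kwargs"])` whether the model runs on CPU or GPU.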

The code below takes as input the padded image that the previous code block wrote to the `transformed_images` directory, and generates the following images in the same directory:
* Padded input image, converted to grayscale: `{file_name}_original.png`
* Output from Model A: `{file_name}_recon_A.png`
* Output from Model B: `{file_name}_recon_B.png`
* Output from Model C: `{file_name}_recon_C.png`
* Output from Model D: `{file_name}_recon_D.png`

Note that the images are converted to grayscale with the formula `gray = 0.2989 * r + 0.5870 * g + 0.1140 * b`
to eliminate the blue tint that results from plotting the RGB output.
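
As a sketch of the conversion, the formula can be applied per pixel in plain Python. This is not the `rgb2gray` imported from `utilities` (its implementation is not shown here); `to_gray` and `image_to_gray` are hypothetical stand-ins. Because the three weights sum to 0.9999, a pixel with equal channel values keeps approximately that value.

```python
GRAY_WEIGHTS = (0.2989, 0.5870, 0.1140)  # weights for r, g, b from the formula above


def to_gray(r, g, b):
    """Weighted sum of the three color channels for a single pixel."""
    wr, wg, wb = GRAY_WEIGHTS
    return wr * r + wg * g + wb * b


def image_to_gray(rgb_rows):
    """Convert a nested list of (r, g, b) pixels to a nested list of gray values."""
    return [[to_gray(*pixel) for pixel in row] for row in rgb_rows]


# A 1 x 2 image: one white pixel and one pure-blue pixel.
tiny = [[(1.0, 1.0, 1.0), (0.0, 0.0, 1.0)]]
print(image_to_gray(tiny))
```

The blue channel has the smallest weight, so a blue cast in the RGB output contributes little to the gray value.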

In [2]:
import torch
from torch import nn
from torch.autograd import Variable
from networks import VQVAE
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
from utilities import rgb2gray

cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
normalization = transforms.Normalize(mean=mean, std=std)
transform_array = [transforms.Resize(img_size), transforms.CenterCrop(img_size), transforms.ToTensor(), normalization]
transform = transforms.Compose(transform_array)

image = torch.zeros((1, 3, img_size, img_size))  # img_size from above
image[0, :] = transform(Image.open(f'transformed_images/{file_name}_pad.png').convert('RGB'))  # file_name from above

mean = torch.FloatTensor([0.485, 0.456, 0.406]).reshape(3, 1, 1).type(Tensor)
std = torch.FloatTensor([0.229, 0.224, 0.225]).reshape(3, 1, 1).type(Tensor)

# Training-run identifier and constructor arguments for each model
checkpoint_root = '/home/aisinai/work/VQ-VAE2/20200422/vq_vae/CheXpert'
model_configs = {
    'A': ('0', dict(first_stride=4, second_stride=2)),
    'B': ('1', dict(first_stride=8, second_stride=4)),
    'C': ('3', dict(first_stride=16, second_stride=4)),
    'D': ('embed1', dict(first_stride=4, second_stride=2, embed_dim=1)),
}

for model_name, (run_id, model_kwargs) in model_configs.items():
    model_dir = f'{checkpoint_root}/{run_id}/checkpoint/vqvae_040.pt'
    # Use the same constructor arguments on CPU and GPU so the weights match the checkpoint
    model = VQVAE(**model_kwargs)
    if cuda:
        model = model.cuda()

    # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only machine;
    # load_state_dict then copies the weights onto the model's device
    model.load_state_dict(torch.load(model_dir, map_location='cpu'))
    n_gpu = torch.cuda.device_count()
    if n_gpu > 1:
        device_ids = list(range(n_gpu))
        model = nn.DataParallel(model, device_ids=device_ids)
    model.eval()
    original_img = Variable(image.type(Tensor))

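    with torch.no_grad():
        out, _ = model(original_img)  # reconstruction; the second return value is not needed here
        decoded_img = out  # same result, so the model does not need to run a second time
        # nn.DataParallel does not expose custom methods, so reach through .module when wrapped
        net = model.module if isinstance(model, nn.DataParallel) else model
        quant_t, quant_b, _, id_t, id_b = net.encode(original_img)
        upsample_t = net.upsample_t(quant_t)
        quant = torch.cat([upsample_t, quant_b], 1)  # concatenate along the channel dimension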

    # Undo transforms.Normalize: x_norm = (x - mean) / std, so x = x_norm * std + mean
    original_img = original_img * std + mean
    out = out * std + mean
    # Denormalized pixels lie in [0, 1] (reconstructions may overshoot slightly),
    # so clamp and let save_image write the values directly
    save_image(rgb2gray(original_img[0]).clamp(0, 1),
               f'transformed_images/{file_name}_original.png')
    save_image(rgb2gray(out[0]).clamp(0, 1),
               f'transformed_images/{file_name}_recon_{model_name}.png')
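
`transforms.Normalize` computes `(x - mean) / std` per channel, so mapping a normalized value back to the original [0, 1] pixel scale means multiplying by `std` and then adding `mean`. A minimal round-trip check in plain Python, using the same per-channel constants as the cell above (`normalize` and `denormalize` are hypothetical helpers for illustration, and the pixel value is made up):

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize(pixel):
    """Per-channel (x - mean) / std, the operation transforms.Normalize applies."""
    return tuple((x - m) / s for x, m, s in zip(pixel, MEAN, STD))


def denormalize(pixel):
    """Inverse of normalize: per-channel x * std + mean."""
    return tuple(x * s + m for x, m, s in zip(pixel, MEAN, STD))


pixel = (0.2, 0.5, 0.9)  # an (r, g, b) value on the [0, 1] scale
restored = denormalize(normalize(pixel))

# The round trip recovers the original pixel up to floating-point error.
assert all(abs(a - b) < 1e-9 for a, b in zip(pixel, restored))
print(restored)
```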

