In [None]:
import yaml

# Load the experiment configuration; later cells read its `model_params`,
# `trainer_params` and `data_params` sections.
with open("configs/ct_mcq_vae.yaml", 'r') as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        # Show the parse error, then re-raise: continuing would leave `config`
        # undefined and make the next cell fail with a misleading NameError.
        print(exc)
        raise

In [None]:
from models import CTMCQVAE
import torch
import re

# Build the model from the hyper-parameters stored in the config.
model = CTMCQVAE(**config['model_params'])

# map_location="cpu" lets a checkpoint that was saved from a GPU be read on a
# machine without CUDA; load_state_dict copies the values into the model's own
# parameters afterwards, so the model's device is unaffected.
checkpoint = torch.load(
    config['trainer_params']['resume_from_checkpoint'],
    map_location="cpu",
)

# The checkpoint keys carry a leading "model." prefix (presumably added by the
# training wrapper -- confirm against the trainer code). Strip only that
# leading prefix: the pattern is anchored with ^ and the dot is escaped, so a
# nested name such as "encoder.model.0.weight" is left intact.
state_dict = {
    re.sub(r'^model\.', '', key): value
    for key, value in checkpoint['state_dict'].items()
}

# Switch to evaluation mode, then restore the weights.
model.eval()
model.load_state_dict(state_dict)

In [None]:
from dataset import VAEDataset

# Instantiate the data module from the `data_params` section of the config and
# run its setup step before any dataloader is requested.
# NOTE(review): what setup() prepares is defined in the project's dataset
# module -- confirm there.
dataset = VAEDataset(**config['data_params'])
dataset.setup()

In [None]:
# Wrap the test dataloader in an iterator so that later cells can draw one
# batch at a time with next(loader).
loader = iter(dataset.test_dataloader())
# Alternative source: training batches. NOTE(review): this variant lacks the
# iter(...) wrapper; if train_dataloader() returns a plain dataloader object,
# next(loader) will need the same wrapping as above.
# loader = dataset.train_dataloader()

In [None]:
# Draw the next test batch. The first two entries are the input images and
# their labels; whatever follows is collected in `args`, whose first element is
# a dict of extra keyword arguments forwarded to the model.
batch = next(loader)
test_input, test_label, *args = batch

# Inference only: disable autograd so no computation graph is built or kept
# alive alongside the outputs that the plotting cells consume.
with torch.no_grad():
    test_output = model(test_input, labels = test_label, **args[0])

# "mode" is stored per sample; the first entry is taken as the mode of the
# whole batch.
mode = args[0]["mode"][0]
print("mode: ", mode)
if mode == "action":
    # Action batches also carry the ground-truth result of applying the action.
    test_y = args[0]["input_y"]

In [None]:
import torchvision
import matplotlib.pyplot as plt

# Convert every image of the batch (inputs and the model's first output) to PIL
# so matplotlib can display them.
to_pil = torchvision.transforms.functional.to_pil_image

input_images = [to_pil(image) for image in test_input]
output_images = [to_pil(image) for image in test_output[0]]

# Names of the factors an action can modify, indexed by action id. The
# commented-out lists are alternative vocabularies kept for other datasets.
# action_list = ["5_o_Clock_Shadow", "Arched_Eyebrows", "Attractive", "Bags_Under_Eyes", "Bald", "Bangs", "Big_Lips", "Big_Nose", "Black_Hair", "Blond_Hair", "Blurry", "Brown_Hair", "Bushy_Eyebrows", "Chubby", "Double_Chin", "Eyeglasses", "Goatee", "Gray_Hair", "Heavy_Makeup", "High_Cheekbones", "Male", "Mouth_Slightly_Open", "Mustache", "Narrow_Eyes", "No_Beard", "Oval_Face", "Pale_Skin", "Pointy_Nose", "Receding_Hairline", "Rosy_Cheeks", "Sideburns", "Smiling", "Straight_Hair", "Wavy_Hair", "Wearing_Earrings", "Wearing_Hat", "Wearing_Lipstick", "Wearing_Necklace", "Wearing_Necktie", "Young"]
# action_list = ['elevation', 'azimuth', 'object_type']
# action_list = ['shape', 'scale', 'orientation', 'position_x', 'position_y']
action_list = ['floor_hue', 'wall_hue', 'object_hue', 'scale', 'shape', 'orientation']
if mode == "action":
    y_images = [to_pil(image) for image in test_y]
    # Each action vector is reduced to the index of its largest entry.
    action_ids = [torch.argmax(action_vector).tolist() for action_vector in args[0]["action"]]

fig = plt.figure(figsize=(50, 14))

# Every sample occupies three consecutive grid slots (input, output, true y),
# laid out 12 slots per row. Ceiling division gives exactly the rows needed;
# the previous `// 12 + 1` added an empty row whenever 3*N was a multiple of 12.
n_cols = 12
n_rows = -(-3 * len(input_images) // n_cols)
# Kept under its original name for any later cell that reads it; note this is
# the subplot grid shape (rows, cols), not a size in inches.
figsize = (n_rows, n_cols)

for i in range(len(input_images)):
    ax = fig.add_subplot(n_rows, n_cols, 3*i+1)
    ax.imshow(input_images[i])
    ax.axis('off')

    ax = fig.add_subplot(n_rows, n_cols, 3*i+2)
    ax.imshow(output_images[i])
    ax.axis('off')

    if mode == "action":
        # Ids in the first half of the range are labelled "+", ids in the
        # second half map back onto the same factor names and are labelled "-".
        action = action_ids[i]
        direction = "+"  # renamed from `dir`, which shadowed the builtin
        if action >= len(action_list):
            action = action - len(action_list)
            direction = "-"
        # The action caption goes on the model-output panel.
        ax.set_title(f"action {action}: {action_list[action]} {direction}")

        ax = fig.add_subplot(n_rows, n_cols, 3*i+3)
        ax.imshow(y_images[i])
        ax.set_title("true y")
        ax.axis('off')



In [None]:
import torch.nn.functional as F
import torch
# Apply every possible action once to a single test image: one one-hot action
# vector per id in range(2 * len(action_list)).
# (A stray `args[0]["action"]` expression used to open this cell; it displayed
# nothing because it was not the last statement, and would fail on a batch
# without an "action" entry, so it has been dropped.)
nb_iter = 2*len(action_list)
# nb_iter = len(args[0]["action"])

# Index of the image to transform; the current batch must hold at least
# sample_index + 1 images.
sample_index = 6
sample = test_input[sample_index].unsqueeze(0).repeat(nb_iter,1,1,1)

# Row k is the one-hot vector for action id k. num_classes is passed explicitly
# instead of being inferred from the largest index.
action = F.one_hot(torch.arange(nb_iter), num_classes=nb_iter).to(dtype=torch.float32)
# action = args[0]["action"]
sample_args = {
    "mode": ["action"] * nb_iter,
    "action": action,
    "input_y": sample
}

# Inference only: no autograd graph is needed for the generated images.
with torch.no_grad():
    sample_output = model(sample, labels = None, **sample_args)

In [None]:
import os
import torchvision.utils as vutils

# Tag appended to the file names to tell successive exports apart.
it=5

# save_image does not create missing directories, so make sure the target
# folder exists before writing.
output_dir = "images/"
os.makedirs(output_dir, exist_ok=True)

# NOTE(review): `mode` in the file names is the mode of the last test batch,
# while the sweep itself was always run with mode "action" -- confirm this
# naming is intended.
for tag, images in (("input", sample), ("output", sample_output[0])):
    vutils.save_image(images.data,
                      os.path.join(output_dir,
                                   f"CT-VAE-V3_{tag}_{mode}_3dshapes_{it}.png"),
                      normalize=True,
                      nrow=12)