In [None]:
import yaml

# Load the experiment configuration with yaml.safe_load (never yaml.load on files).
# A parse error is reported and then re-raised. Swallowing it would leave `config` undefined,
# or silently stale from a previous run, so every later cell would fail confusingly or
# use the wrong parameters.
CONFIG_PATH = "configs/ct_mcq_vae.yaml"

with open(CONFIG_PATH, "r", encoding="utf-8") as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        print(exc)
        raise

In [None]:
from models import CTMCQVAE
import torch
import re

# Build the model from the config and restore its weights from the training checkpoint.
model = CTMCQVAE(**config['model_params'])

# Nothing in this notebook moves the model to a GPU, so load the checkpoint onto the CPU.
# Without map_location, a checkpoint saved from a GPU fails to load on a CPU-only machine.
checkpoint = torch.load(config['trainer_params']['resume_from_checkpoint'], map_location="cpu")

# The checkpoint's state_dict keys carry a leading "model." prefix (presumably added by the
# training wrapper that holds the network as `self.model` -- TODO confirm). Strip only that
# leading prefix. The previous pattern 'model.' was unanchored and had an unescaped '.', so it
# also rewrote inner occurrences, e.g. "x.submodel.w" -> "x.subw" or "model_x" -> "x".
state_dict = {
    re.sub(r'^model\.', '', key): value
    for key, value in checkpoint['state_dict'].items()
}
model.eval()
model.load_state_dict(state_dict)

In [None]:
from dataset import VAEDataset

# Create the data module from the "data_params" section of the config and let it
# prepare its splits before any dataloader is requested.
data_params = config['data_params']
dataset = VAEDataset(**data_params)
dataset.setup()

In [None]:
# Split to draw the inspection batches from: "test", "val" or "train".
# Every split is wrapped in iter() so the `next(loader)` calls in the following cell work the
# same for all of them. The previously commented-out train variant lacked iter() and would
# have failed there.
SPLIT = "test"
loader = iter(getattr(dataset, f"{SPLIT}_dataloader")())

In [None]:
# Advance through `loader` until a batch in "action" mode is found, then run the model on it.
# NOTE(review): only the first entry of args[0]["mode"] is inspected, which assumes every
# sample in a batch shares the same mode -- TODO confirm against VAEDataset.
mode = ""
while mode != "action":
    batch = next(loader, None)
    if batch is None:
        raise RuntimeError('loader was exhausted before an "action"-mode batch was found; '
                           're-run the cell that creates `loader`')
    test_input, test_label, *args = batch
    mode = args[0]["mode"][0]

# Inference only: the visualisations below never need an autograd graph.
with torch.no_grad():
    test_output = model(test_input, labels = test_label, **args[0])
test_y = args[0]["input_y"]

print("mode: ", mode)

In [None]:
import torchvision
import matplotlib.pyplot as plt

to_pil = torchvision.transforms.functional.to_pil_image

# Convert every batch element (input and reconstruction) to a PIL image for display.
input_images = [to_pil(img) for img in test_input]
output_images = [to_pil(img) for img in test_output[0]]

dataset_name = config["data_params"]["dataset_name"]
# Factor names per dataset. Their order is used below to decode action ids into factor names.
factors = {
    # "TCeleba": ["5_o_Clock_Shadow", "Arched_Eyebrows", "Attractive", "Bags_Under_Eyes", "Bald", "Bangs", "Big_Lips", "Big_Nose", "Black_Hair", "Blond_Hair", "Blurry", "Brown_Hair", "Bushy_Eyebrows", "Chubby", "Double_Chin", "Eyeglasses", "Goatee", "Gray_Hair", "Heavy_Makeup", "High_Cheekbones", "Male", "Mouth_Slightly_Open", "Mustache", "Narrow_Eyes", "No_Beard", "Oval_Face", "Pale_Skin", "Pointy_Nose", "Receding_Hairline", "Rosy_Cheeks", "Sideburns", "Smiling", "Straight_Hair", "Wavy_Hair", "Wearing_Earrings", "Wearing_Hat", "Wearing_Lipstick", "Wearing_Necklace", "Wearing_Necktie", "Young"],
    "TCeleba": ["5_o_Clock_Shadow", "Arched_Eyebrows", "Attractive", "Bags_Under_Eyes", "Bald", "Bangs", "Big_Lips", "Big_Nose", "Black_Hair", "Blond_Hair"],
    "TCars3D": ['elevation', 'azimuth', 'object_type'],
    "TShapes3D": ['floor_hue', 'wall_hue', 'object_hue', 'scale', 'shape', 'orientation'],
    "TDSprites": ['shape', 'scale', 'orientation', 'position_x', 'position_y'],
    "TSprites": ['bottomwear', 'topwear', 'hair', 'eyes', 'shoes', 'body', 'action', 'rotation', 'frame']
}
action_list = factors[dataset_name]
n_factors = len(action_list)

if mode == "action":
    y_images = [to_pil(img) for img in test_y]
    # The applied action is the argmax of each sample's action vector.
    action_ids = [torch.argmax(action_vector).tolist() for action_vector in args[0]["action"]]

# Each sample takes three consecutive panels (input, reconstruction, true y) on a grid with
# PANELS_PER_ROW columns. Ceiling division avoids the spurious empty trailing row that
# `3n // 12 + 1` produced whenever 3n was a multiple of 12.
PANELS_PER_ROW = 12
n_rows = max(1, -(-3 * len(input_images) // PANELS_PER_ROW))

fig = plt.figure(figsize=(50, 14))
for i in range(len(input_images)):
    ax_in = fig.add_subplot(n_rows, PANELS_PER_ROW, 3*i + 1)
    ax_in.imshow(input_images[i])
    ax_in.axis('off')
    ax_out = fig.add_subplot(n_rows, PANELS_PER_ROW, 3*i + 2)
    ax_out.imshow(output_images[i])
    ax_out.axis('off')
    if mode == "action":
        # Ids [0, n_factors) are the "+" actions; ids [n_factors, 2*n_factors) are the
        # "-" actions on the same factors.
        factor_idx, sign = action_ids[i], "+"
        if factor_idx >= n_factors:
            factor_idx, sign = factor_idx - n_factors, "-"
        ax_out.set_title(f"action {factor_idx}: {action_list[factor_idx]} {sign}")
        ax_y = fig.add_subplot(n_rows, PANELS_PER_ROW, 3*i + 3)
        ax_y.imshow(y_images[i])
        ax_y.set_title("true y")
        ax_y.axis('off')
plt.show()



In [None]:
import os
import torchvision.utils as vutils
import torch.nn.functional as F

# Take one test image, apply every action to it repeatedly for N_STEPS steps (feeding each
# output back in), and save the input plus every intermediate result under
# images/<dataset>_<run_id>/.
run_id = "test5_noiseexo"  # renamed from `id`, which shadowed the builtin
nb_actions = 2*len(action_list)  # one "+" and one "-" action per factor
sample_idx = 0
N_STEPS = 5

out_dir = os.path.join("images", f"{dataset_name}_{run_id}")
# save_image does not create missing directories, so create the output folder up front.
os.makedirs(out_dir, exist_ok=True)

sample = test_input[sample_idx].unsqueeze(0)
print(test_label[sample_idx])
vutils.save_image(sample.data,
                  os.path.join(out_dir, f"CT-VAE-V5_input_action_{dataset_name}_{run_id}.png"),
                  normalize=True,
                  nrow=12)
# One copy of the image per action, so a single forward pass applies every action.
sample = sample.repeat(nb_actions, 1, 1, 1)

# Row k of this one-hot matrix selects action k. It is the same at every step, so it is built once.
action = F.one_hot(torch.arange(nb_actions), num_classes=nb_actions).to(dtype=torch.float32)

for step in range(N_STEPS):
    sample_args = {
        "mode": ["action"] * nb_actions,
        "action": action,
        "input_y": sample
    }
    # Inference only. Without no_grad, the graph keeps growing as outputs are fed back in.
    with torch.no_grad():
        sample = model(sample, labels = None, **sample_args)[0]

    for sign_id, sign in enumerate(["+", "-"]):
        for action_id, action_name in enumerate(action_list):
            vutils.save_image(sample[sign_id * len(action_list) + action_id].data,
                              os.path.join(out_dir, f"CT-VAE-V5_output_action_{dataset_name}_{action_name}_{sign}{step}_{run_id}.png"),
                              normalize=True,
                              nrow=12)


In [None]:
import pandas as pd

import torch.nn.functional as F  # imported here so this cell does not depend on the image-saving cell

# Apply each action to the whole test batch N_CAUSAL_STEPS times (feeding outputs back in).
# After every step, record the model's causal accuracies. `d` maps
# "<factor>_<sign>_<metric>" -> list with one value per step.
N_CAUSAL_STEPS = 1
nb_actions = 2 * len(action_list)  # recomputed locally; previously leaked from the image-saving cell
d = {'_'.join([a, sign, metric]): [] for a in action_list for metric in ["causal_acc", "causal_nodir_acc"] for sign in ["+", "-"]}
action = F.one_hot(torch.arange(nb_actions), num_classes=nb_actions).to(dtype=torch.float32)
batch_size = test_input.size(0)
sample_in = [test_input for _ in range(nb_actions)]

# Inference only. Without no_grad, feeding sample_out back in chains autograd graphs across steps.
with torch.no_grad():
    for step in range(N_CAUSAL_STEPS):
        for a in range(nb_actions):
            action_batch = action[a].unsqueeze(0).repeat(batch_size, 1)
            sample_args = {
                "mode": ["action"] * batch_size,
                "action": action_batch,
                "input_y": sample_in[a]
            }
            sample_out = model(sample_in[a], labels = None, **sample_args)[0]
            sample_out_args = {
                "mode": ["causal"] * batch_size,
                "action": action_batch,
                "input_y": sample_out
            }
            metrics = model(test_input, labels = None, **sample_out_args)[4]
            # Ids [0, n) are "+" actions and [n, 2n) are "-" actions on the same factors.
            factor = action_list[a % len(action_list)]
            sign = ["+", "-"][a // len(action_list)]
            d[f"{factor}_{sign}_causal_acc"].append(metrics["causal_acc"].tolist())
            d[f"{factor}_{sign}_causal_nodir_acc"].append(metrics["causal_nodir_acc"].tolist())
            sample_in[a] = sample_out

print(pd.DataFrame(d))

In [None]:
import matplotlib.pyplot as plt

# Plot per-step causal accuracy for every factor whose name contains the token
# PLOT_FACTOR_TOKEN (e.g. "Hair" matches Black_Hair and Blond_Hair). Keys of `d` have the form
# "<factor>_<sign>_<metric>", and factor names may themselves contain underscores.
PLOT_FACTOR_TOKEN = "Hair"

fig, ax = plt.subplots()
ax.set_ylim(0, 1)
ax.set_xlabel("Steps")
ax.set_ylabel("Acc.")
for name, values in d.items():
    metric = "causal_nodir_acc" if name.endswith("_causal_nodir_acc") else "causal_acc"
    factor, sign = name[:-len(metric) - 1].rsplit("_", 1)
    if PLOT_FACTOR_TOKEN not in factor.split("_"):
        continue
    kind = "Factor" if metric == "causal_nodir_acc" else "Action"
    # The factor name is part of the label because several factors can match the token.
    # Bare "Action (+)" labels gave indistinguishable duplicate legend entries.
    ax.plot(range(len(values)), values, label=f"{factor} {kind} ({sign})")
# Skip the legend when nothing matched (matplotlib would warn about no labelled artists).
if ax.get_legend_handles_labels()[0]:
    ax.legend()
plt.show()

In [None]:
from tqdm import tqdm

# Evaluate causal accuracy over the whole test split.
# Loop variables use their own names, so the test_input / test_label / args / mode /
# test_output inspected by earlier cells are not clobbered. The split also gets its own
# loader instead of overwriting (and exhausting) `loader`. Passing the DataLoader itself to
# tqdm lets it show a total.
test_loader = dataset.test_dataloader()
val_acc = []
val_nodir_acc = []

with torch.no_grad():  # inference only; no autograd graph per batch
    for batch_input, batch_label, *batch_args in tqdm(test_loader):
        if batch_args[0]["mode"][0] != "causal":
            continue
        batch_metrics = model(batch_input, labels = batch_label, **batch_args[0])[4]
        val_acc.append(batch_metrics["causal_acc"])
        val_nodir_acc.append(batch_metrics["causal_nodir_acc"])

# torch.stack([]) would raise an opaque error; say what actually went wrong.
if not val_acc:
    raise RuntimeError('no "causal"-mode batches found in the test split')
val_acc = torch.stack(val_acc)
val_nodir_acc = torch.stack(val_nodir_acc)

print("val acc: ", val_acc)
print("val nodir acc: ", val_nodir_acc)
# NOTE: unweighted mean over batches -- a smaller final batch counts as much as a full one.
print("val acc: mean: ", val_acc.mean())
print("val nodir acc mean: ", val_nodir_acc.mean())