In [12]:
import os
import shutil
import time

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.io import write_video
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

from src.data.image_dataset import ImageDataset
from src.model.discriminator import UNIT_MsImageDis
from src.model.generator import UNIT_VAEGen
from src.utils.train import read_config
from src.utils.datasets import get_dataloaders

In [13]:
root = '../../../results/UNIT-2022_08_23-13_02_26/'
conf, uneasy_conf = read_config(root + 'config.yml')
conf.device = 'cpu'

In [14]:
print("########## Loading data.")
_, test_dl = get_dataloaders(conf)

########## Loading data.
Domain A size:  4670
Domain B size:  252660
Domain A size:  4670
Domain B size:  252660


In [15]:
print("########## Loading model.")
gen_A = UNIT_VAEGen(conf.data.in_dim_A, conf.model.gen).to(conf.device).eval()
gen_A.load_state_dict(torch.load(root + 'checkpoints/gen_A_epoch999.PTH', map_location='cpu'))
gen_B = UNIT_VAEGen(conf.data.in_dim_B, conf.model.gen).to(conf.device).eval()
gen_B.load_state_dict(torch.load(root + 'checkpoints/gen_B_epoch999.PTH', map_location='cpu'))
disc_A = UNIT_MsImageDis(conf.data.in_dim_A, conf.model.disc).to(conf.device).eval()
disc_A.load_state_dict(torch.load(root + 'checkpoints/disc_A_epoch999.PTH', map_location='cpu'))
disc_B = UNIT_MsImageDis(conf.data.in_dim_B, conf.model.disc).to(conf.device).eval()
disc_B.load_state_dict(torch.load(root + 'checkpoints/disc_B_epoch999.PTH', map_location='cpu'))

########## Loading model.


<All keys matched successfully>

In [18]:
print("########## Evaluation")
time.sleep(0.1)

with torch.no_grad():
    sample = next(iter(test_dl))
    real_image_A = sample["A"].to(conf.device)
    real_image_B = sample["B"].to(conf.device)

    N, T, C, H, W = real_image_A.shape

    h_a, n_a = gen_A.encode(real_image_A.view(N*T, C, H, W))
    h_b, n_b = gen_B.encode(real_image_B.view(N*T, C, H, W))

    # Decode (cross domain)
    x_ba = gen_A.decode(h_b + n_b).view((N, T, C, H, W))
    x_ab = gen_B.decode(h_a + n_a).view((N, T, C, H, W))

    # Decode (within domain)
    x_a_recon = gen_A.decode(h_a + n_a).view((N, T, C, H, W))
    x_b_recon = gen_B.decode(h_b + n_b).view((N, T, C, H, W))

    # Encode again
    h_b_recon, n_b_recon = gen_A.encode(x_ba.view(N*T, C, H, W))
    h_a_recon, n_a_recon = gen_B.encode(x_ab.view(N*T, C, H, W))

    # Decode again (if needed)
    x_aba = gen_A.decode(h_a_recon + n_a_recon).view((N, T, C, H, W)) \
        if conf.model.recon_x_cyc_weight > 0 else None
    x_bab = gen_B.decode(h_b_recon + n_b_recon).view((N, T, C, H, W)) \
        if conf.model.recon_x_cyc_weight > 0 else None

    # De-normalize
    real_image_A = 0.5 * (real_image_A + 1.0)
    real_image_B = 0.5 * (real_image_B + 1.0)
    x_a_recon = 0.5 * (x_a_recon + 1.0)
    x_b_recon = 0.5 * (x_b_recon + 1.0)
    x_ab = 0.5 * (x_ab + 1.0)
    x_ba = 0.5 * (x_ba + 1.0)

########## Evaluation


In [19]:
os.makedirs(root + "plots/")
os.makedirs(root + "videos/")
for n in range(real_image_A.shape[0]):

    nt = 10
    # fig, ax = plt.subplots(6, real_image_A.shape[1], figsize=(real_image_A.shape[1]*3, 18))
    fig, ax = plt.subplots(6, nt, figsize=(nt * 3, 18))

    # for t in range(real_image_A.shape[1]):
    for i, t in enumerate(np.linspace(start=0, stop=real_image_A.shape[1] - 1, num=nt, dtype=int)):

        print(t)

        ax[0, i].imshow(real_image_A[n, t].permute(1, 2, 0).cpu().numpy())
        ax[1, i].imshow(x_a_recon[n, t].permute(1, 2, 0).cpu().numpy())
        ax[2, i].imshow(x_ab[n, t].permute(1, 2, 0).cpu().numpy())
        ax[3, i].imshow(real_image_B[n, t].permute(1, 2, 0).cpu().numpy())
        ax[4, i].imshow(x_b_recon[n, t].permute(1, 2, 0).cpu().numpy())
        ax[5, i].imshow(x_ba[n, t].permute(1, 2, 0).cpu().numpy())

    plt.savefig(root + f"plots/n{n}.SVG")
    plt.close()

    write_video(filename=root + f"videos/n{n}_real_A.mp4", fps=5,
                video_array=(real_image_A[n]*255).cpu().permute(0, 2, 3, 1).to(torch.uint8))
    write_video(filename=root + f"videos/n{n}_real_B.mp4", fps=5,
                video_array=(real_image_B[n]*255).cpu().permute(0, 2, 3, 1).to(torch.uint8))
    write_video(filename=root + f"videos/n{n}_AB.mp4", fps=5,
                video_array=(x_ab[n]*255).cpu().permute(0, 2, 3, 1).to(torch.uint8))
    write_video(filename=root + f"videos/n{n}_BA.mp4", fps=5,
                video_array=(x_ba[n]*255).cpu().permute(0, 2, 3, 1).to(torch.uint8))

0
5
10
16
21
27
32
38
43
49
