In [2]:
import os
import time

import torch
from torchvision.io import write_video
from tqdm import tqdm

from src.agent.cycle_gan import CycleGAN_Agent
from src.utils.datasets import get_dataloaders
from src.utils.train import read_config

In [3]:
root = "../../../results/CycleGAN_Cadis_Cataract/2022_08_04-10_12_21/"
conf, uneasy_conf = read_config(root + 'config.yml')
conf.device = 'cpu'
conf.data.seq_frames_test = -1

In [4]:
print("########## Loading data.")
train_dl, test_dl = get_dataloaders(conf)

########## Loading data.
Domain A --- Training: 3493 --- Testing: 3
Domain B --- Training: 169249 --- Testing: 10


In [5]:
print("########## Loading model.")
agent = CycleGAN_Agent(conf)
agent.netG_A2B.load_state_dict(torch.load(root + "checkpoints/Gen_A2B_ep280.PTH", map_location='cpu'))
agent.netG_B2A.load_state_dict(torch.load(root + "checkpoints/Gen_B2A_ep280.PTH", map_location='cpu'))

########## Loading model.


<All keys matched successfully>

In [5]:
print("########## Evaluating")
time.sleep(0.1)

sample = next(iter(test_dl))
real_seq_A = sample['A']
real_seq_B = sample['B']

########## Evaluating


In [6]:
real_seq_A.shape

torch.Size([1, 50, 3, 512, 512])

In [7]:
gen_seq_AB = None
gen_seq_BA = None
with torch.no_grad():
    for t in tqdm(range(real_seq_A.shape[1]-2)):
        gen_img_AB = agent.netG_A2B(real_seq_A[:, t])
        gen_img_BA = agent.netG_B2A(real_seq_B[:, t])

        gen_seq_AB = gen_img_AB.unsqueeze(1) if gen_seq_AB is None\
            else torch.cat([gen_seq_AB, gen_img_AB.unsqueeze(1)], dim=1)
        gen_seq_BA = gen_img_BA.unsqueeze(1) if gen_seq_BA is None\
            else torch.cat([gen_seq_BA, gen_img_BA.unsqueeze(1)], dim=1)

100%|██████████| 48/48 [01:42<00:00,  2.14s/it]


In [9]:
real_seq_A = (real_seq_A + 1.0)/2.0
real_seq_B = (real_seq_B + 1.0)/2.0
gen_seq_AB = (gen_seq_AB + 1.0)/2.0
gen_seq_BA = (gen_seq_BA + 1.0)/2.0

In [11]:
os.makedirs(root + 'videos/', exist_ok=True)
write_video(root + 'videos/seq_A.mp4', video_array=255*real_seq_A.squeeze(0).permute(0, 2, 3, 1), fps=5)
write_video(root + 'videos/seq_B.mp4', video_array=255*real_seq_B.squeeze(0).permute(0, 2, 3, 1), fps=5)
write_video(root + 'videos/seq_A2B.mp4', video_array=255*gen_seq_AB.squeeze(0).permute(0, 2, 3, 1), fps=5)
write_video(root + 'videos/seq_B2A.mp4', video_array=255*gen_seq_BA.squeeze(0).permute(0, 2, 3, 1), fps=5)