In [4]:
import os
import time

import torch
from torchvision.io import write_video
from tqdm import tqdm

from src.agent.recycle_gan import ReCycleGAN
from src.utils.datasets import get_dataloaders
from src.utils.train import read_config

In [5]:
# Results directory of the training run under evaluation; the run's
# config.yml is read from this directory.
root = '../../../results/RecyleGAN_Cadis_Cataract/2022_09_15-12_36_49/'
conf, uneasy_conf = read_config(root + 'config.yml')

# Notebook-specific overrides of the loaded config: T is the test sequence
# length (in frames) and is reused by the evaluation loop; the device is
# forced to CPU.
T = 50
conf.data.seq_frames_test = T
conf.device = 'cpu'

In [6]:
# get_dataloaders returns two loaders; only the second one (the test
# loader) is used in this notebook, so the first is bound to a throwaway name.
print("########## Loading data.")
_unused_dl, test_dl = get_dataloaders(conf)

########## Loading data.
Domain A size:  4670
Domain B size:  252660
Domain A size:  4670
Domain B size:  252660


In [7]:
print("########## Loading model.")
agent = ReCycleGAN(conf)

# Restore both generators from the epoch-280 checkpoints of this run.
# The weights are deserialized onto the CPU before being copied into the
# networks.
state_B2A = torch.load(root + 'checkpoints/Gen_B2A_ep280.PTH', map_location='cpu')
agent.netG_B2A.load_state_dict(state_B2A)

state_A2B = torch.load(root + 'checkpoints/Gen_A2B_ep280.PTH', map_location='cpu')
# Left as the cell's final bare expression so the key-matching report of
# load_state_dict is displayed as the cell result.
agent.netG_A2B.load_state_dict(state_A2B)

########## Loading model.


<All keys matched successfully>

In [8]:
# Translate test frames in both directions and collect four sequences:
#   seq_A / seq_B         -- input frames from domain A / B
#   gen_seq_A / gen_seq_B -- generator outputs B->A / A->B
# All values are mapped with 0.5 * (x + 1), i.e. from [-1, 1] to [0, 1]
# (assumes the loader and generators work in [-1, 1] -- TODO confirm).
#
# Frames are appended to lists and stacked once after the loop; growing a
# tensor with torch.cat on every iteration re-copies the whole sequence each
# time (quadratic in the number of frames).
frames_A, gen_frames_A, frames_B, gen_frames_B = [], [], [], []

dl = iter(test_dl)
dt = 1  # number of batches drawn from the loader per output frame (last one is used)

print("########## Evaluating")
time.sleep(0.1)  # presumably gives the print time to flush before tqdm draws its bar

# Inference only: without no_grad every generator forward pass would build an
# autograd graph that is discarded right away (the original `.data` only
# detached the result after the graph had already been recorded).
with torch.no_grad():
    for t in tqdm(range(T-2)):

        # Advance the loader by `dt` batches; frame `t` is then taken from
        # the most recently drawn batch.
        for _ in range(dt):
            sample = next(dl)

        real_A = sample['A'][:, t]
        real_B = sample['B'][:, t]

        sample_A = 0.5 * (real_A + 1.0)
        gen_sample_A = 0.5 * (agent.netG_B2A(real_B.to(conf.device)) + 1.0).cpu()
        sample_B = 0.5 * (real_B + 1.0)
        gen_sample_B = 0.5 * (agent.netG_A2B(real_A.to(conf.device)) + 1.0).cpu()

        frames_A.append(sample_A)
        gen_frames_A.append(gen_sample_A)
        frames_B.append(sample_B)
        gen_frames_B.append(gen_sample_B)

# Stack along a new dim 1 (the frame axis), matching the layout produced by
# concatenating `unsqueeze(1)`-ed frames.
seq_A = torch.stack(frames_A, dim=1)
gen_seq_A = torch.stack(gen_frames_A, dim=1)
seq_B = torch.stack(frames_B, dim=1)
gen_seq_B = torch.stack(gen_frames_B, dim=1)

print(seq_A.shape)

########## Evaluating


100%|██████████| 48/48 [01:15<00:00,  1.58s/it]

torch.Size([1, 48, 3, 256, 256])





In [9]:
os.makedirs(root + 'videos/', exist_ok=True)


def _to_video_frames(seq):
    """Convert a float image sequence into uint8 frames for ``write_video``.

    Args:
        seq: tensor laid out as (1, T, C, H, W) with values nominally in
            [0, 1]; the leading batch dimension is squeezed away.

    Returns:
        uint8 tensor laid out as (T, H, W, C) with values in [0, 255].
    """
    frames = 255 * seq.squeeze(0).permute(0, 2, 3, 1)
    # write_video expects uint8 data. Round instead of relying on an implicit
    # truncating cast (float round-off can leave a value just below an
    # integer, which truncation would drop by one level), and clamp so that
    # out-of-range values saturate instead of being cast out of range.
    return frames.round().clamp(0, 255).to(torch.uint8)


# Output path (relative to the run directory) -> sequence to encode.
videos = {
    'videos/seq_A.mp4': seq_A,
    'videos/seq_B.mp4': seq_B,
    'videos/seq_A2B.mp4': gen_seq_B,
    'videos/seq_B2A.mp4': gen_seq_A,
}

for rel_path, seq in videos.items():
    write_video(root + rel_path, video_array=_to_video_frames(seq), fps=5)