In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import sys
import os
from tqdm.auto import tqdm
import torch
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
from matplotlib import pyplot as plt

from torch_tools.data import UnannotatedDataset
from torch_tools.visualization import to_image, to_image_grid

# Model Loading and Data Initialization

In [None]:
from copy import deepcopy
from models.networks import define_G, define_D


def _load_prefixed(net, state, prefix):
    """Load `state` into `net` after stripping `prefix` from every key; raise if `net` ends up with missing keys."""
    # strict=False is still required: the same checkpoint dict also carries the *other*
    # network's weights, which surface as unexpected keys and are intentionally ignored.
    incompatible = net.load_state_dict(
        {k.replace(prefix, ''): val for k, val in state.items()},
        strict=False,
    )
    # ...but keys the network expects and did not receive would otherwise stay at their
    # random initialisation without any warning, so treat that as an error.
    if incompatible.missing_keys:
        raise RuntimeError(
            f'{len(incompatible.missing_keys)} keys of {type(net).__name__} were not found '
            f'in the checkpoint under prefix {prefix!r}, e.g. {incompatible.missing_keys[:5]}'
        )
    return net


def load_from_state(state_path, device='cuda'):
    """Build the CUT generator and discriminator and load both from a single checkpoint file.

    The checkpoint is one flat state dict used for both networks: generator entries are
    read with the ``netG.module.`` prefix stripped and discriminator entries with the
    ``netD.module.`` prefix stripped.

    Args:
        state_path: path to the ``state_dict.pt`` checkpoint.
        device: device the networks are moved to (default ``'cuda'``, matching the
            previous hard-coded ``.cuda()``).

    Returns:
        (gen, dis): the ``resnet_9blocks`` generator and the ``basic`` discriminator.

    Raises:
        RuntimeError: if either network has parameters/buffers absent from the checkpoint.
    """
    state = torch.load(state_path, map_location='cpu')

    gen = define_G(3, 3, 64, 'resnet_9blocks', norm='instance').to(device)
    _load_prefixed(gen, state, 'netG.module.')

    dis = define_D(3, 64, 'basic', norm='instance').to(device)
    _load_prefixed(dis, state, 'netD.module.')
    return gen, dis


# Which translation direction to evaluate; also selects the checkpoint sub-directory.
target_task = 'zebra2horse'
checkpoint_path = f'checkpoints/CUT_synth_aug/{target_task}/state_dict.pt'
gen, dis = load_from_state(checkpoint_path)

In [None]:
# Test-split directory of the *source* domain for each translation task.
TASK_TEST_DIRS = {
    'summer2winter': 'datasets/summer2winter_yosemite/testA',
    'winter2summer': 'datasets/summer2winter_yosemite/testB',
    'apple2orange': 'datasets/apple2orange/testA',
    'orange2apple': 'datasets/apple2orange/testB',
    'horse2zebra': 'datasets/horse2zebra/testA',
    'zebra2horse': 'datasets/horse2zebra/testB',
}
if target_task not in TASK_TEST_DIRS:
    raise KeyError(f'unknown target_task {target_task!r}; expected one of {sorted(TASK_TEST_DIRS)}')
ds_path = TASK_TEST_DIRS[target_task]

ds_source_test = UnannotatedDataset(ds_path)
# Rewrite only the last path component (e.g. testB -> trainB); a whole-path
# str.replace would also mangle any parent directory containing 'test'.
ds_dir, ds_split = os.path.split(ds_path)
ds_source_train = UnannotatedDataset(os.path.join(ds_dir, ds_split.replace('test', 'train')))

# Loaded Model Visualization

In [None]:
batch_size = 32
preview_seed = 0  # fixed so the shuffled preview batch is the same on every Run All

# Seeded generator makes the shuffle (and hence the displayed images) reproducible.
preview_loader = DataLoader(
    ds_source_test, batch_size, shuffle=True,
    generator=torch.Generator().manual_seed(preview_seed),
)
source = next(iter(preview_loader)).cuda()

with torch.no_grad():
    sample = gen(source)

# Top row: source images; bottom row: their translations.
to_image_grid(torch.cat([source, sample]), nrow=len(source))

In [None]:
from fid import fid

def compute_img2img_fid(gen, model=None, batch_size=16):
    """Compute FID between translated source test images and the real target-domain test images.

    Source images come from the notebook-level ``ds_source_test``. The real target set is
    the sibling ``testA``/``testB`` directory of the notebook-level ``ds_path``.

    Args:
        gen: image-to-image generator applied to each source batch (batches are moved to CUDA).
        model: unused; kept only so existing calls keep working.
        batch_size: batch size for both the generated and the real loaders
            (default 16, as previously hard-coded).

    Returns:
        The value returned by ``fid.calculate_fid_given_iterators``.
    """
    def it_gen():
        for batch in DataLoader(ds_source_test, batch_size=batch_size):
            yield gen(batch.cuda())

    # The other domain of the same dataset: testA <-> testB.
    if 'testA' in ds_path:
        ds_target = ds_path.replace('testA', 'testB')
    else:
        ds_target = ds_path.replace('testB', 'testA')
    it_real = DataLoader(UnannotatedDataset(ds_target), batch_size=batch_size)

    # One no_grad scope around the whole computation. Previously the generator yielded
    # from *inside* `with torch.no_grad()`, which left grad mode disabled in the
    # consumer's code between iterations (and indefinitely if iteration stopped early).
    with torch.no_grad():
        return fid.calculate_fid_given_iterators(
            it_gen(),
            it_real,
            compute_option=fid.FIDBackend.numpy,
            verbose=True,
            normalize_input=False,
        )


print(compute_img2img_fid(gen))

# StyleGAN2-ADA Samples

In [None]:
from StyleGAN2_ada.training.networks import Generator as StyleGAN2AdaGenerator


chkpts_dir = './checkpoints/StyleGAN2_ADA'
# os.listdir order is filesystem-dependent; sort so the row order is stable.
chkpts = sorted(os.listdir(chkpts_dir))
if not chkpts:
    raise FileNotFoundError(f'no checkpoints found in {chkpts_dir}')

samples_per_gen = 6
z_dim = 512
latent_seed = 0  # fixed so the sample grid is reproducible

# squeeze=False keeps `axs` 2-D even for a single checkpoint; otherwise plt.subplots
# returns a bare Axes and the zip below fails.
fig, axs = plt.subplots(
    len(chkpts), 1, figsize=(2 * samples_per_gen, 2 * len(chkpts)), squeeze=False,
)

# Draw the latents once and reuse them for every checkpoint so the rows are directly comparable.
z = torch.randn(
    [samples_per_gen, z_dim], generator=torch.Generator().manual_seed(latent_seed),
).cuda()

for ax, chkpt in zip(axs[:, 0], chkpts):
    # Positional args presumably (z_dim, c_dim, w_dim, img_resolution, img_channels) — verify
    # against StyleGAN2_ada.training.networks.Generator.
    G = StyleGAN2AdaGenerator(z_dim, 0, 512, 256, 3)
    G.load_state_dict(torch.load(f'{chkpts_dir}/{chkpt}', map_location='cpu')['G_ema'])
    G.eval().cuda()

    with torch.no_grad():
        imgs_orig = G(z, c=None, truncation_psi=1.0)

    ax.axis('off')
    ax.set_title(chkpt)
    ax.imshow(to_image_grid(imgs_orig))