In [1]:
import os
import torch
import numpy as np
import json
from easydict import EasyDict as edict
from test_utils import load_dataset, LoaderSampler, tensor2img
from torchvision.transforms import Compose, Resize, Normalize, ToTensor
from torchvision.datasets import ImageFolder
from torch.utils.data import Subset, DataLoader
from core.my_solver import Solver
from core.fid_score import calculate_frechet_distance
from core.my_metrics import get_Z_pushed_loader_stats, calculate_cost
import matplotlib.pyplot as plt

# Expose only GPU 0 to this process (has effect only if CUDA is not yet initialised).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


# Source domain (A) and target domain (B), each backed by an HDF5 image archive.
DATASET1 = 'handbag'
DATASET1_PATH = '../data/handbag_128.hdf5'
DATASET2 = 'shoes'
DATASET2_PATH = '../data/shoes_128.hdf5'

IMG_SIZE = 128

# Reference statistics ('mu', 'sigma') for the target domain's test split,
# presumably the Inception statistics used for FID — confirm against the stats script.
filename = 'stats/{}_{}_test.json'.format(DATASET2, IMG_SIZE)
with open(filename, 'r') as fp:
    data_stats = json.load(fp)
mu_data = data_stats['mu']
sigma_data = data_stats['sigma']
del data_stats  # only the two statistics are kept

device = 'cuda'
# Channel-first shape of a single image.
input_shape = (3, IMG_SIZE, IMG_SIZE)

# Seed both RNGs before the data loaders are built.
torch.manual_seed(0)
np.random.seed(0)

# Train/test loaders for both domains, at the same resolution and batch size.
train_loader_a, test_loader_a = load_dataset(
    DATASET1, DATASET1_PATH, img_size=IMG_SIZE, batch_size=32
)
train_loader_b, test_loader_b = load_dataset(
    DATASET2, DATASET2_PATH, img_size=IMG_SIZE, batch_size=32
)

# Number of batches in the shorter of the two training loaders.
n_batches = min(len(train_loader_a), len(train_loader_b))

# One sampler per loader, built in the order A-train, A-test, B-train, B-test.
X_sampler, X_test_sampler, Y_sampler, Y_test_sampler = (
    LoaderSampler(loader, device=device)
    for loader in (train_loader_a, test_loader_a, train_loader_b, test_loader_b)
)

# Re-seed so the fixed visualisation batches are reproducible; draw order matters.
torch.manual_seed(0)
np.random.seed(0)
X_fixed = X_sampler.sample(12)
Y_fixed = Y_sampler.sample(12)

X_test_fixed = X_test_sampler.sample(12)
Y_test_fixed = Y_test_sampler.sample(12)

# Hand-picked test images from domain A, stacked into one batch on `device`.
indices = [0, 243, 2, 35, 189, 246]
X = torch.stack([test_loader_a.dataset[idx][0].to(device) for idx in indices])

lambdas_arr = [0.0]

# One row for the inputs plus one row per lambda; one column per picked image.
fig, axes = plt.subplots(
    1 + len(lambdas_arr),
    len(indices),
    figsize=(2 * len(indices) + 1, 2 * len(lambdas_arr) + 2),
    dpi=200,
)


# Rebuild the domain-A test loader from an image folder and replace X_test_sampler.
# Images are resized to IMG_SIZE x IMG_SIZE, converted to tensors and normalised
# with mean = std = 0.5 per channel (ToTensor's [0, 1] range -> [-1, 1]).
#
# NOTE(review): this path is a CelebA (celeba2anime) female test split under an
# OT_competitors checkout, while every other path in this notebook is handbag2shoes.
# X_test_sampler.loader is what the L1/L2 transport costs are computed on, so a
# handbag->shoes model would be scored on face images. Confirm this is intended and
# not a copy-paste leftover from another experiment.
transform = Compose([
    Resize((IMG_SIZE, IMG_SIZE)),
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
test_dataset_a = ImageFolder(
    '/cache/selikhanovych/OT_competitors/stargan-v2/data/celeba2anime_test/celeba_female',
    transform,
)
test_loader_a = DataLoader(
    dataset=test_dataset_a,
    batch_size=20,
    num_workers=4,
    pin_memory=True,
    shuffle=False,  # deterministic order so cost estimates are comparable
)
# Pass the device explicitly, consistent with every other LoaderSampler in this file.
X_test_sampler = LoaderSampler(test_loader_a, device=device)


def _build_eval_args(lambda_id):
    """Return the StarGAN-v2 config for the handbag->shoes run trained with ``lambda_id``.

    Only ``lambda_id`` and the two paths derived from it (``checkpoint_dir`` and
    ``OUTPUT_PATH``) vary; every other field is fixed.
    """
    return edict({
        'img_size': 128,
        'num_domains': 2,
        'latent_dim': 16,
        'hidden_dim': 512,
        'style_dim': 64,
        'lambda_reg': 1.0,
        'lambda_cyc': 1.0,
        'lambda_sty': 1.0,
        'lambda_ds': 1.0,
        'lambda_id': lambda_id,
        'ds_iter': 100000,
        'w_hpf': -1.0,
        'randcrop_prob': 0.5,
        'total_iters': 100000,
        'resume_iter': 0,
        'batch_size': 16,
        'val_batch_size': 32,
        'lr': 1e-4,
        'f_lr': 1e-6,
        'beta1': 0.0,
        'beta2': 0.99,
        'weight_decay': 1e-4,
        'num_outs_per_domain': 10,
        'mode': 'train',
        'num_workers': 4,
        'seed': 777,
        'train_img_dir': '/cache/selikhanovych/extremal_ot/stargan-v2/data/handbag2shoes_train',
        'val_img_dir': '/cache/selikhanovych/extremal_ot/stargan-v2/data/handbag2shoes_test',
        'sample_dir': 'expr/samples',
        'checkpoint_dir': f'/cache/selikhanovych/extremal_ot/stargan-v2/handbag2shoes_exps/checkpoints_lambda_{lambda_id}',
        'eval_dir': 'expr/eval',
        'result_dir': 'expr/results',
        'src_dir': 'assets/representative/celeba_hq/src',
        'ref_dir': 'assets/representative/celeba_hq/ref',
        'inp_dir': 'assets/representative/custom/female',
        'out_dir': 'assets/representative/celeba_hq/src/female',
        'wing_path': 'expr/checkpoints/wing.ckpt',
        'lm_path': 'expr/checkpoints/celeba_lm_mean.npz',
        'print_every': 10,
        'sample_every': 4000,
        'save_every': 4000,
        'eval_every': 4000,
        'device': 0,

        'target_dataset': 'shoes',
        'OUTPUT_PATH': f'/cache/selikhanovych/extremal_ot/stargan-v2/handbag2shoes_exps/checkpoints_lambda_{lambda_id}',
        'domains': {
            'source': 'handbag',
            'target': 'shoes',
        },
        'train_a': '/cache/selikhanovych/extremal_ot/stargan-v2/data/handbag2shoes_train/handbag',
        'train_b': '/cache/selikhanovych/extremal_ot/stargan-v2/data/handbag2shoes_train/shoes',
        'test_a': '/cache/selikhanovych/extremal_ot/stargan-v2/data/handbag2shoes_test/handbag',
        'test_b': '/cache/selikhanovych/extremal_ot/stargan-v2/data/handbag2shoes_test/shoes',

        'n_epochs': 1,
    })


def _find_domain_indices(val_img_dir, src_domain, trg_domain):
    """Return ``(domains, trg_idx, src_idx)`` for the entries of ``val_img_dir``.

    ``domains`` is the sorted listing of ``val_img_dir``. ``trg_idx`` is the
    target's position in it, and ``src_idx`` is the source's position in that
    list with the target removed.

    Raises:
        ValueError: if the target or source name is not found (including the
            case where both names are equal).
    """
    domains = sorted(os.listdir(val_img_dir))
    if trg_domain not in domains:
        raise ValueError(f"target domain {trg_domain!r} not found in {val_img_dir}: {domains}")
    src_domains = [d for d in domains if d != trg_domain]
    if src_domain not in src_domains:
        raise ValueError(f"source domain {src_domain!r} not found among {src_domains}")
    return domains, domains.index(trg_domain), src_domains.index(src_domain)


for k, lambda_val in enumerate(lambdas_arr):
    # The config paths and log lines are keyed by `lambda_id`; bind it from the
    # loop variable (previously it was undefined and raised NameError).
    lambda_id = lambda_val
    args = _build_eval_args(lambda_id)

    eval_trg_domain = args.domains['target']
    eval_src_domain = args.domains['source']

    # Look the target index up directly. The old second search broke out of its
    # outer loop unconditionally, which left trg_idx == 0 (the first sorted domain)
    # instead of the index of the target domain.
    domains, trg_idx, src_idx = _find_domain_indices(
        args.val_img_dir, eval_src_domain, eval_trg_domain)
    num_domains = len(domains)
    final_trg_index, final_src_index = trg_idx, src_idx  # kept for later cells
    print(f"trg_idx = {trg_idx}, trg_domain = {eval_trg_domain}, src_domain = {eval_src_domain}")
    print(f"trg_idx = {trg_idx}, lambda_id = {args.lambda_id}")

    solver = Solver(args)
    # Checkpoint steps to evaluate: 4001, 8001, ..., 96001.
    iterations = [1 + 4000 * (i + 1) for i in range(24)]
    best_fid = np.inf
    best_iter = 0
    for iteration in iterations:
        print(f"iteration = {iteration}")
        solver._load_checkpoint(iteration)

        nets_ema = solver.nets_ema
        nets_ema.mapping_network.eval()
        nets_ema.generator.eval()

        # Same RNG state for every checkpoint so their scores are comparable.
        torch.manual_seed(0xBADBEEF)
        np.random.seed(0xBADBEEF)

        num_calculation_fid = 1
        fid_values = []
        l2_cost_values = []
        l1_cost_values = []

        for _ in range(num_calculation_fid):
            mu, sigma = get_Z_pushed_loader_stats(nets_ema, args.domains, args, device,
                                                  batch_size=37, n_epochs=args.n_epochs)
            # NOTE(review): reference statistics come from the solver; the
            # mu_data/sigma_data loaded from stats/*.json earlier are unused here.
            fid = calculate_frechet_distance(solver.mu_data, solver.sigma_data, mu, sigma)
            print(f"FID = {fid}")
            fid_values.append(fid)

            l1_cost = calculate_cost(nets_ema, args, trg_idx, X_test_sampler.loader, device,
                                     cost_type='l1', verbose=True)
            print(f"l1 = {l1_cost}")
            l1_cost_values.append(l1_cost)

            l2_cost = calculate_cost(nets_ema, args, trg_idx, X_test_sampler.loader, device,
                                     cost_type='mse', verbose=True)
            print(f"l2 = {l2_cost}")
            l2_cost_values.append(l2_cost)

        # Select on the mean over all repeats, not only the last one.
        mean_fid = sum(fid_values) / len(fid_values)
        mean_l1 = sum(l1_cost_values) / len(l1_cost_values)
        mean_l2 = sum(l2_cost_values) / len(l2_cost_values)
        if mean_fid < best_fid:
            best_fid = mean_fid
            best_iter = iteration
            print(f"best fid = {best_fid}, best_iter = {best_iter}, lambda = {lambda_val}, l1 = {mean_l1}, l2 = {mean_l2}")

Keys: <KeysViewHDF5 ['imgs']>


  "See the documentation of nn.Upsample for details.".format(mode))


Keys: <KeysViewHDF5 ['imgs']>


KeyboardInterrupt: 