-
Notifications
You must be signed in to change notification settings - Fork 3
/
eval_generator.py
114 lines (92 loc) · 4.44 KB
/
eval_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""Evaluate image regression with a trained Generator.

Works on a finished run stored in ``output/generator_testbed/<expname>`` and,
depending on the command-line flags:

* ``--psnr``: loads the checkpointed Generator (``model.pt``) and prints its
  average PSNR over the dataset configured for the run;
* ``--image-evolution``: stitches the saved ``imgs/samples_*.png`` into a video;
* ``--spectrum-evolution``: stitches the saved ``plots/spectrum_*.png`` into a
  video;
* ``--spectrum-error-evolution``: plots the error of the logged
  ``logs/spectrum_*.pkl`` spectra against the cached real-data spectrum.

Generated files are written to ``<run_dir>/eval``.
"""
import argparse
import os
import re
import torch
from glob import glob
import pickle
from tqdm import tqdm
from PIL import Image
from utils import CheckpointIO
import utils.misc as misc
import utils.plot as plot
import utils.metrics as metrics
from dataset import get_dataset


def _natural_sort_key(path):
    """Sort key that compares embedded integers numerically (``x_9`` < ``x_10``).

    ``re.split`` with a capturing group alternates non-digit / digit tokens, so
    odd positions are always digit runs; comparing them as ints gives frame
    order even when the iteration numbers are not zero-padded. For zero-padded
    names the order is the same as plain lexicographic sorting.
    """
    return [int(tok) if i % 2 else tok for i, tok in enumerate(re.split(r'(\d+)', path))]


def _select_device():
    """Use the first CUDA GPU when one is available, otherwise the CPU."""
    return torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def _load_images(paths):
    """Read every image in *paths* fully into memory and close its file."""
    images = []
    for path in paths:
        # Image.open is lazy and keeps the file open until pixels are read;
        # copy() forces the read so the handle is released immediately instead
        # of holding one open descriptor per video frame.
        with Image.open(path) as im:
            images.append(im.copy())
    return images


def _make_evolution_video(pattern, out_file, **video_kwargs):
    """Stitch the images matching glob *pattern* (in frame order) into *out_file*.

    *video_kwargs* are forwarded unchanged to ``misc.make_video``.
    Returns True if a video was written, False if no images matched.
    """
    files = sorted(glob(pattern), key=_natural_sort_key)
    if not files:
        print(f'No images matching {pattern} found; skipping.')
        return False
    misc.make_video(_load_images(files), out_file, **video_kwargs)
    return True


def _evaluate_psnr(cfg, run_dir, dataloader, device):
    """Load the checkpointed Generator from *run_dir*; return its mean PSNR over *dataloader*.

    Assumes the dataloader yields ``(img, z)`` pairs and that ``metrics.psnr``
    returns a tensor that can be concatenated across batches (as the original
    script did) — TODO confirm against utils.metrics.
    """
    checkpoint_io = CheckpointIO(checkpoint_dir=run_dir)

    # Create the model and register it so the checkpoint restores its weights.
    common_kwargs = misc.EasyDict(resolution=cfg.data.resolution)
    model = misc.construct_class_by_name(**cfg.model, **common_kwargs).to(device)
    checkpoint_io.register_modules(
        model=model,
    )
    load_dict = checkpoint_io.load('model.pt')
    print(f'Using checkpoint from iteration {load_dict["it"]}.')

    # Inference only: switch to eval mode once and disable autograd. Without
    # no_grad the stored per-batch PSNR tensors could keep each batch's forward
    # graph alive, making memory grow with the size of the dataset.
    model.eval().requires_grad_(False)
    batch_psnr = []
    with torch.no_grad():
        for img, z in tqdm(dataloader):
            img, z = img.to(device), z.to(device)
            batch_psnr.append(metrics.psnr(model(z), img))
    return torch.cat(batch_psnr).mean().item()


def _plot_spectrum_error_evolution(dataset, log_dir, out_dir):
    """Plot the error of every logged generated spectrum against the real one.

    Returns True if the plot was written, False if no spectra were logged.
    """
    spec_files = sorted(glob(os.path.join(log_dir, 'spectrum_*.pkl')), key=_natural_sort_key)
    if not spec_files:
        print(f'No spectrum_*.pkl files found in {log_dir}; skipping.')
        return False

    # NOTE: pickle can execute arbitrary code on load — only run this on run
    # directories you produced yourself.
    spec_file_real = os.path.join(dataset.root, f'spectrum{dataset.resolution}_N{len(dataset)}.pkl')  # Loads cache file from training
    with open(spec_file_real, 'rb') as f:
        spec_real = pickle.load(f)

    spec_gen_all = []
    for spec_file in spec_files:
        with open(spec_file, 'rb') as f:
            spec_gen_all.append(pickle.load(f))

    filename = os.path.join(out_dir, 'spectrum_error_evolution.png')
    plot.plot_spectrum_error_evolution(spec_real, spec_gen_all, dataset.resolution, filename)
    return True


def main():
    """Parse the command line and run the requested evaluations."""
    import matplotlib.pyplot as plt
    plt.switch_backend('Agg')  # non-interactive backend: plots are only saved to files
    plt.ioff()

    # Arguments
    parser = argparse.ArgumentParser(
        description='Evaluate image regression with a trained Generator.'
    )
    parser.add_argument('expname', type=str, help='Name of experiment.')
    parser.add_argument('--psnr', action='store_true', help='Evaluate PSNR of regressed images.')
    parser.add_argument('--image-evolution', action='store_true', help='Create video of image evolution.')
    parser.add_argument('--spectrum-evolution', action='store_true', help='Create video of spectrum evolution.')
    parser.add_argument('--spectrum-error-evolution', action='store_true',
                        help='Create image of spectrum error evolution.')
    args = parser.parse_args()

    run_dir = os.path.join('output/generator_testbed', args.expname)
    cfg = misc.load_config(os.path.join(run_dir, 'config.yaml'))

    # Fix the random seed before building the dataset so that the same latent
    # codes as in training are sampled.
    torch.manual_seed(cfg['training']['seed'])
    torch.cuda.manual_seed_all(cfg['training']['seed'])
    device = _select_device()

    out_dir = os.path.join(run_dir, 'eval')
    log_dir = os.path.join(run_dir, 'logs')
    img_dir = os.path.join(run_dir, 'imgs')
    plot_dir = os.path.join(run_dir, 'plots')
    os.makedirs(out_dir, exist_ok=True)

    # The dataset is needed for PSNR (images + latent codes) and for locating
    # the cached real-data spectrum.
    dataset = get_dataset(cfg) if (args.psnr or args.spectrum_error_evolution) else None

    if args.psnr:
        print('Evaluate PSNR...')
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg['training']['batch_size'],
            shuffle=True,
            num_workers=cfg['training']['nworkers'],
            pin_memory=(device.type == 'cuda'),  # pinning only helps host->GPU copies
            drop_last=False,
        )
        psnr = _evaluate_psnr(cfg, run_dir, dataloader, device)
        print(f'Average PSNR: {psnr:.1f}.')

    if args.image_evolution:
        print('Plot image evolution...')
        if _make_evolution_video(os.path.join(img_dir, 'samples_*.png'),
                                 os.path.join(out_dir, 'image_evolution.mp4'),
                                 fps=20, quality=8):
            print('Done.')

    if args.spectrum_evolution:
        print('Plot spectrum evolution...')
        if _make_evolution_video(os.path.join(plot_dir, 'spectrum_*.png'),
                                 os.path.join(out_dir, 'spectrum_evolution.mp4'),
                                 fps=20, quality=8, macro_block_size=None):
            print('Done.')

    if args.spectrum_error_evolution:
        print('Plot spectrum error evolution...')
        if _plot_spectrum_error_evolution(dataset, log_dir, out_dir):
            print('Done.')


if __name__ == '__main__':
    main()