-
Notifications
You must be signed in to change notification settings - Fork 35
/
save_compared_images.py
60 lines (47 loc) · 2.21 KB
/
save_compared_images.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import sys
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from fanogan.save_compared_images import save_compared_images
def main(opt):
    """Save real-vs-reconstructed comparison images for the test split.

    Builds a deterministic test-time preprocessing pipeline, loads
    ``opt.test_root`` with ``ImageFolder`` (in order, no shuffling) and hands
    freshly constructed ``Generator`` / ``Encoder`` modules to
    ``fanogan.save_compared_images``.

    Args:
        opt: Parsed command-line namespace. This function reads
            ``img_size``, ``channels``, ``test_root`` and ``n_grid_lines``;
            the whole namespace is also forwarded to the model constructors
            and to ``save_compared_images``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # No random augmentation here: evaluation output should be reproducible
    # across runs, so images are only resized (no horizontal flipping).
    pipeline = [transforms.Resize([opt.img_size] * 2)]
    # ImageFolder's default loader returns RGB images; convert them when the
    # model is configured for single-channel input.
    if opt.channels == 1:
        pipeline.append(transforms.Grayscale(num_output_channels=1))
    pipeline.extend([
        transforms.ToTensor(),
        # Maps [0, 1] to [-1, 1] per channel; the statistics are sized by
        # opt.channels instead of being hard-coded to three channels.
        transforms.Normalize([0.5] * opt.channels, [0.5] * opt.channels),
    ])
    transform = transforms.Compose(pipeline)

    dataset = ImageFolder(opt.test_root, transform=transform)
    test_dataloader = DataLoader(dataset, batch_size=opt.n_grid_lines,
                                 shuffle=False)

    # Make the sibling ``mvtec_ad`` package importable relative to this file.
    sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
    from mvtec_ad.model import Generator, Encoder

    # NOTE(review): the modules are freshly constructed here; presumably
    # save_compared_images loads the trained weights itself - confirm.
    generator = Generator(opt)
    encoder = Encoder(opt)

    save_compared_images(opt, generator, encoder, test_dataloader, device)
"""
The code below is:
Copyright (c) 2018 Erik Linder-Norén
Licensed under MIT
(https://github.com/eriklindernoren/PyTorch-GAN/blob/master/LICENSE)
"""
if __name__ == "__main__":
    import argparse

    # Command-line options as (flags, add_argument keyword options), kept in
    # the order they are registered with the parser.
    # NOTE(review): --force_download and --n_iters are not read in this
    # file; presumably they are consumed downstream through ``opt`` - confirm.
    cli_options = (
        (("test_root",),
         dict(type=str, help="root name of your dataset in test mode")),
        (("--n_grid_lines",),
         dict(type=int, default=10,
              help="number of grid lines in the saved image")),
        (("--force_download", "-f"),
         dict(action="store_true", help="flag of force download")),
        (("--latent_dim",),
         dict(type=int, default=100,
              help="dimensionality of the latent space")),
        (("--img_size",),
         dict(type=int, default=64, help="size of each image dimension")),
        (("--channels",),
         dict(type=int, default=3, help="number of image channels")),
        (("--n_iters",),
         dict(type=int, default=None, help="value of stopping iterations")),
    )

    cli = argparse.ArgumentParser()
    for flags, options in cli_options:
        cli.add_argument(*flags, **options)

    main(cli.parse_args())