-
Notifications
You must be signed in to change notification settings - Fork 2
/
generate_measurement.py
118 lines (93 loc) · 4.33 KB
/
generate_measurement.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from functools import partial
import os
import argparse
import yaml
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from guided_diffusion.condition_methods import get_conditioning_method
from guided_diffusion.measurements import get_noise, get_operator
from guided_diffusion.unet import create_model
from guided_diffusion.gaussian_diffusion import create_sampler
from data.dataloader import get_dataset, get_dataloader
from util.img_utils import clear_color, mask_generator
from util.logger import get_logger
def load_yaml(file_path: str) -> dict:
with open(file_path) as f:
config = yaml.load(f, Loader=yaml.FullLoader)
return config
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--model_config', type=str)
parser.add_argument('--diffusion_config', type=str)
parser.add_argument('--task_config', type=str)
parser.add_argument('--data_root', type=str)
parser.add_argument('--gpu', type=int, default=0)
args = parser.parse_args()
# logger
logger = get_logger()
# Device setting
device_str = f"cuda:{args.gpu}" if torch.cuda.is_available() else 'cpu'
logger.info(f"Device set to {device_str}.")
device = torch.device(device_str)
# Load configurations
model_config = load_yaml(args.model_config)
diffusion_config = load_yaml(args.diffusion_config)
task_config = load_yaml(args.task_config)
task_config['data']['root'] = args.data_root
# assert model_config['learn_sigma'] == diffusion_config['learn_sigma'], \
# "learn_sigma must be the same for model and diffusion configuartion."
# Prepare Operator and noise
measure_config = task_config['measurement']
operator = get_operator(device=device, **measure_config['operator'])
noiser = get_noise(**measure_config['noise'])
logger.info(f"Operation: {measure_config['operator']['name']} / Noise: {measure_config['noise']['name']}")
# Load diffusion sampler
sampler = create_sampler(**diffusion_config)
# Prepare conditioning method
cond_config = task_config['conditioning']
cond_method = get_conditioning_method(cond_config['method'], operator, noiser, **cond_config['params'])
measurement_cond_fn = cond_method.conditioning
logger.info(f"Conditioning method : {task_config['conditioning']['method']}")
# Working directory
out_path = os.path.join(task_config['data']['root'], measure_config['operator']['name'])
os.makedirs(out_path, exist_ok=True)
os.makedirs(os.path.join(out_path, 'input'), exist_ok=True)
# Prepare dataloader
data_config = task_config['data']
transform = transforms.Compose([transforms.Resize([256, 256]), transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
dataset = get_dataset(**data_config, transforms=transform)
loader = get_dataloader(dataset, batch_size=1, num_workers=0, train=False)
# Exception) In case of inpainting, we need to generate a mask
if measure_config['operator']['name'] == 'inpainting':
mask_gen = mask_generator(
**measure_config['mask_opt']
)
# Do Inference
# for i, (ref_img, path) in enumerate(loader):
for i, ref_img in enumerate(loader):
# print(f'path:{path[0].split("/")[-1].split(".")[0]}')
logger.info(f"Inference for image {i}")
fname = str(i).zfill(5) + '.png'
tensor_name = str(i).zfill(5) + '.pt'
# tensor_name = path[0].split("/")[-1].split(".")[0] +'.pt'
ref_img = ref_img.to(device)
# Exception) In case of inpainging,
if measure_config['operator']['name'] == 'inpainting':
mask = mask_gen(ref_img)
mask = mask[:, 0, :, :].unsqueeze(dim=0)
measurement_cond_fn = partial(cond_method.conditioning, mask=mask)
# Forward measurement model (Ax + n)
y = operator.forward(ref_img, mask=mask)
y_n = noiser(y)
else:
# Forward measurement model (Ax + n)
y = operator.forward(ref_img)
y_n = noiser(y)
plt.imsave(os.path.join(out_path, 'input', fname), clear_color(y_n))
if measure_config['operator']['name'] == 'inpainting':
y_n = [y_n, mask]
torch.save(y_n, os.path.join(out_path, 'input', tensor_name))
if __name__ == '__main__':
main()