-
Notifications
You must be signed in to change notification settings - Fork 22
/
dataset.py
169 lines (134 loc) · 6.11 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
"""
This script shows:
1) how to simulate a dataset
2) apply a reconstruction algorithm
3) compute metrics
"""
import hydra
from hydra.utils import to_absolute_path
from lensless.utils.io import save_image
from lensless.utils.image import rgb2gray, resize
import numpy as np
from lensless import ADMM
from lensless.eval.metric import mse, psnr, ssim, lpips, LPIPS_MIN_DIM
from waveprop.simulation import FarFieldSimulator
import glob
import os
from tqdm import tqdm
from lensless.utils.io import load_image, load_psf
def _ensure_dataset(dataset):
    """Check that the ``dataset`` folder exists, offering to download a sample if it does not.

    Parameters
    ----------
    dataset : str
        Absolute path to the folder of ground-truth images.

    Returns
    -------
    bool
        ``True`` if the folder exists (or a download was performed), ``False``
        if the dataset is missing and will not be downloaded.
    """
    if os.path.isdir(dataset):
        return True

    print(f"No dataset found at {dataset}")
    try:
        from torchvision.datasets.utils import download_and_extract_archive
    except ImportError:
        print("torchvision is not installed, cannot download the sample dataset.")
        return False

    msg = "Do you want to download the sample CelebA dataset (764KB)?"
    # default to yes if no input is given
    valid = input("%s (Y/n) " % msg).lower() != "n"
    if not valid:
        return False

    url = "https://drive.switch.ch/index.php/s/Q5OdDQMwhucIlt8/download"
    filename = "celeb_mini.zip"
    # the archive is extracted next to (not into) `dataset`
    download_and_extract_archive(
        url, os.path.dirname(dataset), filename=filename, remove_finished=True
    )
    return True


def _png_name(fp):
    """Return the basename of ``fp`` with its (last) extension replaced by ``.png``."""
    return os.path.splitext(os.path.basename(fp))[0] + ".png"


def _simulate_files(files, simulator, config, save_dir):
    """Simulate a lensless measurement for every image in ``files``.

    Parameters
    ----------
    files : list of str
        Non-empty list of ground-truth image paths.
    simulator : FarFieldSimulator
        Prepared simulator object.
    config : DictConfig
        Hydra configuration (reads ``simulation.*`` and ``save.bool``).
    save_dir : str or None
        Output folder; only used when ``config.save.bool`` is true.

    Returns
    -------
    sensor_fps : list of str
        Paths of the sensor-plane PNGs written in this run (empty if saving is disabled).
    output_dim : tuple
        Shape of the last simulated sensor-plane image.
    """
    sensor_fps = []
    image_plane = None
    for fp in tqdm(files):
        # load image as numpy array
        image = load_image(fp)
        if config.simulation.grayscale and len(image.shape) == 3:
            image = rgb2gray(image)

        # simulate image
        image_plane, object_plane = simulator.propagate(image, return_object_plane=True)

        if config.save.bool:
            bn = _png_name(fp)

            # can serve as ground truth
            object_plane_fp = os.path.join(save_dir, "object_plane", bn)
            save_image(object_plane, object_plane_fp)  # use max range of 255

            # lensless image
            lensless_fp = os.path.join(save_dir, "sensor_plane", bn)
            save_image(image_plane, lensless_fp, max_val=config.simulation.max_val)
            sensor_fps.append(lensless_fp)

    return sensor_fps, image_plane.shape


def _reconstruct_and_evaluate(psf, sensor_fps, output_dim, config, save_dir):
    """Reconstruct each saved measurement with ADMM, save it, and print average metrics.

    Parameters
    ----------
    psf : np.ndarray
        PSF as returned by ``load_psf`` (depth dimension first).
    sensor_fps : list of str
        Sensor-plane PNGs to reconstruct; the matching ground truth is read
        from ``<save_dir>/object_plane`` under the same basename.
    output_dim : tuple
        Shape of a simulated measurement, used to decide whether LPIPS applies.
    config : DictConfig
        Hydra configuration (reads ``admm.*`` and ``simulation.*``).
    save_dir : str
        Output folder containing ``object_plane`` and ``reconstruction``.
    """
    print("\nReconstructing lensless measurements...")
    if config.simulation.output_dim is not None:
        # best would be to incorporate downsampling in the reconstruction
        # for now downsample the PSF
        print("-- Resizing PSF to", config.simulation.output_dim, "for reconstruction.")
        psf = resize(psf, shape=config.simulation.output_dim)

    # -- initialize reconstruction object
    recon = ADMM(psf, **config.admm)

    # -- metrics; LPIPS is only computed for RGB images that are large enough
    mse_vals = []
    psnr_vals = []
    ssim_vals = []
    if not config.simulation.grayscale and np.min(output_dim[:2]) >= LPIPS_MIN_DIM:
        lpips_vals = []
    else:
        lpips_vals = None

    for fp in tqdm(sensor_fps):
        lensless = load_image(fp, as_4d=True)
        lensless = lensless / np.max(lensless)

        recon.set_data(lensless)
        res, _ = recon.apply(n_iter=config.admm.n_iter, disp_iter=config.admm.disp_iter)
        recovered = res[0]  # first depth

        bn = os.path.basename(fp)
        recon_fp = os.path.join(save_dir, "reconstruction", bn)
        save_image(recovered, recon_fp, max_val=config.simulation.max_val)

        # compute metrics against the saved ground truth
        object_plane = load_image(os.path.join(save_dir, "object_plane", bn))
        if config.simulation.output_dim is not None:
            # ground truth must match the (downsampled) reconstruction size
            object_plane = resize(object_plane, shape=config.simulation.output_dim)

        mse_vals.append(mse(object_plane, recovered))
        psnr_vals.append(psnr(object_plane, recovered))
        if config.simulation.grayscale:
            ssim_vals.append(ssim(object_plane, recovered, channel_axis=None))
        else:
            ssim_vals.append(ssim(object_plane, recovered))
        if lpips_vals is not None:
            lpips_vals.append(lpips(object_plane, recovered))

    print("\nMSE (avg)", np.mean(mse_vals))
    print("PSNR (avg)", np.mean(psnr_vals))
    print("SSIM (avg)", np.mean(ssim_vals))
    if lpips_vals is not None:
        print("LPIPS (avg)", np.mean(lpips_vals))


@hydra.main(version_base=None, config_path="../../configs", config_name="simulate_dataset")
def simulate(config):
    """Simulate a lensless dataset and optionally reconstruct it and report metrics.

    Steps: (1) ensure the ground-truth dataset exists (offering a sample
    download), (2) simulate a sensor-plane measurement per image with a
    ``FarFieldSimulator`` built from the PSF, saving object- and sensor-plane
    PNGs when ``save.bool`` is set, (3) if ``admm.bool`` is set, reconstruct the
    saved measurements with ADMM and print average MSE / PSNR / SSIM (/ LPIPS).

    Raises
    ------
    ValueError
        If reconstruction is requested without saving (it reads back the saved PNGs).
    FileNotFoundError
        If the PSF file does not exist or no dataset images are found.
    """
    # set seed
    np.random.seed(config.seed)

    # reconstruction reads the saved sensor/object-plane PNGs back from disk,
    # so fail before the (slow) simulation rather than after it
    if config.admm.bool and not config.save.bool:
        raise ValueError("Reconstruction (admm.bool) requires saving (save.bool) to be enabled.")

    dataset = to_absolute_path(config.files.dataset)
    if not _ensure_dataset(dataset):
        return

    psf_fp = to_absolute_path(config.files.psf)
    if not os.path.exists(psf_fp):
        raise FileNotFoundError(f"PSF {psf_fp} does not exist.")

    save_dir = None
    if config.save.bool:
        save_dir = to_absolute_path(config.save.dir)
        # exist_ok makes re-runs into an existing folder safe
        for sub in ("sensor_plane", "object_plane", "reconstruction"):
            os.makedirs(os.path.join(save_dir, sub), exist_ok=True)

    # load psf as numpy array
    print("\nPSF:")
    psf = load_psf(psf_fp, verbose=True, downsample=config.simulation.downsample)
    # remove depth dimension, as 3D not supported by FarFieldSimulator
    psf_sim = psf[0]
    if config.simulation.grayscale and len(psf_sim.shape) == 3:
        psf_sim = rgb2gray(psf_sim)
    if config.simulation.downsample > 1:
        print(f"Downsampled to {psf_sim.shape}.")

    # prepare simulator object
    simulator = FarFieldSimulator(psf=psf_sim, **config.simulation)

    # loop over files in dataset; sort so the n_files subset is reproducible
    print("\nSimulating dataset...")
    files = sorted(glob.glob(os.path.join(dataset, f"*.{config.files.image_ext}")))
    if config.files.n_files is not None:
        files = files[: config.files.n_files]
    if not files:
        raise FileNotFoundError(f"No '*.{config.files.image_ext}' images found in {dataset}")

    sensor_fps, output_dim = _simulate_files(files, simulator, config, save_dir)

    # reconstruction
    if config.admm.bool:
        _reconstruct_and_evaluate(psf, sensor_fps, output_dim, config, save_dir)
if __name__ == "__main__":
    # The hydra.main decorator builds `config` from the YAML config and CLI overrides.
    simulate()