forked from royerlab/aydin
-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo_n2t.py
46 lines (31 loc) · 1.32 KB
/
demo_n2t.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# flake8: noqa
import numpy
import torch
from aydin.io.datasets import add_noise, normalise, camera
from aydin.nn.models.torch.torch_res_unet import ResidualUNetModel
from aydin.nn.models.torch.torch_unet import n2t_unet_train_loop, UNetModel
from aydin.nn.pytorch.it_ptcnn import to_numpy
def demo_supervised_2D_n2t(model_class, visualize=True):
    """Train a model supervised (noisy -> truth) on one 2D image and show the result.

    A 256x256 crop of the ``camera`` test image is normalised and used as the
    clean target; ``add_noise`` produces the noisy input from it. Both arrays
    get leading batch and channel axes of size 1, ``(H, W) -> (1, 1, H, W)``,
    and are converted to torch tensors. A model built with two U-Net levels
    and ``supervised=True`` is trained on this single (noisy, clean) pair via
    ``n2t_unet_train_loop``. Its prediction on the noisy input is then
    computed and, optionally, shown in napari next to the ground truth and
    the noisy image.

    Parameters
    ----------
    model_class
        Model constructor, called as
        ``model_class(nb_unet_levels=2, supervised=True, spacetime_ndim=2)``.
        ``UNetModel`` and ``ResidualUNetModel`` are the two options used by
        this script. The instance is assumed to be a ``torch.nn.Module``,
        because ``.eval()`` is called on it before inference.
    visualize : bool, optional
        If True (the default, same as the previous hard-coded behaviour),
        open a napari viewer with the three images and block in
        ``napari.run()``. Pass False to run headless, e.g. from automated
        tests.
    """
    # Clean target image, scaled by ``normalise``.
    clean_image = normalise(camera()[:256, :256])
    # Add leading batch and channel axes: (H, W) -> (1, 1, H, W).
    clean_image = clean_image[numpy.newaxis, numpy.newaxis]

    noisy_image = add_noise(clean_image)

    noisy_tensor = torch.tensor(noisy_image)
    clean_tensor = torch.tensor(clean_image)

    model = model_class(nb_unet_levels=2, supervised=True, spacetime_ndim=2)

    # Supervised training: the noisy tensor is the input, the clean tensor
    # is the target.
    n2t_unet_train_loop(noisy_tensor, clean_tensor, model)

    # Inference only: switch train-time layers (dropout / batch-norm, if the
    # model has any) to evaluation behaviour, and do not build an autograd
    # graph for the prediction.
    model.eval()
    with torch.no_grad():
        denoised = model(noisy_tensor)

    if visualize:
        # Imported lazily so headless runs do not need napari installed.
        import napari

        viewer = napari.Viewer()
        viewer.add_image(to_numpy(clean_tensor), name="groundtruth")
        viewer.add_image(to_numpy(noisy_tensor), name="noisy")
        viewer.add_image(to_numpy(denoised), name="denoised")
        napari.run()

    # TODO: re-enable output sanity checks once verified for every model class:
    # assert denoised.shape == noisy_tensor.shape
    # assert denoised.dtype == noisy_tensor.dtype
if __name__ == '__main__':
    # Architecture to demo; ResidualUNetModel is the alternative this script
    # was written to switch to.
    model_class = UNetModel  # alternative: ResidualUNetModel
    demo_supervised_2D_n2t(model_class=model_class)