import tensorflow as tf
-from tensorflow.layers import conv2d, dropout, max_pooling2d, conv2d_transpose
+from tensorflow.layers import conv2d, max_pooling2d, conv2d_transpose
from PIL import Image
import numpy as np


class DenoisingAutoEncoder:

-    def __init__(self, input_shape: tuple, optimizer, is_training: bool):
+    def __init__(self, input_shape: tuple, batch_input_shape: tuple, optimizer, is_training: bool):
        self.sess = tf.Session()
        self.input_shape = input_shape
-        self.input_image = tf.placeholder(tf.float32, shape=(None, 1240, 1240, 3), name="input_image")
-        self.target_image = tf.placeholder(tf.float32, shape=(None, 1240, 1240, 3), name="target_image")
+        self.input_image = tf.placeholder(tf.float32, shape=batch_input_shape, name="input_image")
+        self.target_image = tf.placeholder(tf.float32, shape=batch_input_shape, name="target_image")
+        self.training = is_training

        with tf.name_scope('Encoder'):
-            self.conv1 = tf.nn.leaky_relu(conv2d(self.input_image, 16, (7, 7), padding='same', use_bias=False))
-            self.pool1 = max_pooling2d(self.conv1, (4, 4), (4, 4))
-            self.dropout1 = dropout(self.pool1, 0.2, training=is_training)
-            self.conv2 = tf.nn.leaky_relu(conv2d(self.dropout1, 20, (5, 5), padding='same', use_bias=False))
-            self.pool2 = max_pooling2d(self.conv2, (2, 2), (2, 2))
-            self.dropout2 = dropout(self.pool2, 0.3, training=is_training)
-            self.conv3 = tf.nn.leaky_relu(conv2d(self.dropout2, 32, (5, 5), padding='same', use_bias=False))
+            self.conv1 = tf.nn.leaky_relu(conv2d(self.input_image, 16, (5, 5), padding='same'))
+            self.pool1 = max_pooling2d(self.conv1, (2, 2), (2, 2))
+            self.conv2 = tf.nn.leaky_relu(conv2d(self.pool1, 32, (3, 3), padding='same'))
+            self.pool2 = max_pooling2d(self.conv2, (5, 5), (5, 5))
+            self.conv3 = tf.nn.leaky_relu(conv2d(self.pool2, 64, (3, 3), padding='same'))
            self.pool3 = max_pooling2d(self.conv3, (5, 5), (5, 5))
-            self.dropout3 = dropout(self.pool3, 0.3, training=is_training)
-            self.latent_repr = tf.nn.leaky_relu(conv2d(self.dropout3, 128, (3, 3), padding='same', use_bias=False))
-            # self.latent_repr = max_pooling2d(self.conv4, (5, 5), (5, 5))
+            self.latent_repr = tf.nn.leaky_relu(conv2d(self.pool3, 256, (3, 3), padding='same'))

        with tf.name_scope('Decoder'):
-            self.upsampling1 = tf.image.resize_images(self.latent_repr, (31, 31),
+            self.upsampling1 = tf.image.resize_images(self.latent_repr, (5, 5),
                                                       tf.image.ResizeMethod.BICUBIC)
            self.conv5 = tf.nn.leaky_relu(
-                conv2d_transpose(self.upsampling1, 32, (3, 3), padding='same', use_bias=False))
-            self.dropout4 = dropout(self.conv5, 0.3, training=is_training)
-            self.upsampling2 = tf.image.resize_images(self.dropout4, (155, 155), tf.image.ResizeMethod.BICUBIC)
+                conv2d_transpose(self.upsampling1, 64, (3, 3), padding='same'))
+            self.upsampling2 = tf.image.resize_images(self.conv5, (25, 25), tf.image.ResizeMethod.BICUBIC)
            self.conv6 = tf.nn.leaky_relu(
-                conv2d_transpose(self.upsampling2, 16, (5, 5), padding='same', use_bias=False))
-            self.upsampling3 = tf.image.resize_images(self.conv6, (310, 310), tf.image.ResizeMethod.BICUBIC)
-            self.conv7 = tf.nn.leaky_relu(conv2d_transpose(self.upsampling3, 3, (5, 5), padding='same', use_bias=False))
-            self.upsampling4 = tf.image.resize_images(self.conv7, (1240, 1240), tf.image.ResizeMethod.BICUBIC)
-            self.conv8 = tf.nn.leaky_relu(conv2d_transpose(self.upsampling4, 3, (1, 1), padding='same', use_bias=True))
+                conv2d_transpose(self.upsampling2, 32, (5, 5), padding='same'))
+            self.upsampling3 = tf.image.resize_images(self.conv6, (50, 50), tf.image.ResizeMethod.BICUBIC)
+            self.conv7 = tf.nn.leaky_relu(conv2d_transpose(self.upsampling3, 3, (5, 5), padding='same'))

-        self.output_image = tf.nn.sigmoid(self.conv8)
+        self.output_image = tf.nn.sigmoid(self.conv7)
        self.loss = tf.losses.mean_squared_error(self.target_image, self.output_image)
        self.batch_loss = tf.reduce_mean(self.loss)

        self.train_step = optimizer.minimize(self.batch_loss)
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver()
+        self.__load_weights()
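
Aside, not part of the commit: with the new layer parameters a 50x50 patch is downsampled 50 -> 25 -> 5 -> 1 through the pools and upsampled 5 -> 25 -> 50 through the bicubic resizes, so the placeholder shapes only work out for 50x50 inputs. A minimal sketch to verify the shapes, assuming the class is importable and a batch_input_shape of (None, 50, 50, 3); the constructor arguments here are assumptions:

    # Illustrative shape check only.
    import tensorflow as tf

    dae = DenoisingAutoEncoder((50, 50, 3), (None, 50, 50, 3), tf.train.AdamOptimizer(), True)
    for t in [dae.conv1, dae.pool1, dae.conv2, dae.pool2, dae.conv3, dae.pool3, dae.latent_repr,
              dae.upsampling1, dae.conv5, dae.upsampling2, dae.conv6, dae.upsampling3, dae.conv7]:
        print(t.name, t.shape)
    # Expected spatial sizes: 50, 25, 25, 5, 5, 1, 1, then 5, 5, 25, 25, 50, 50.
    dae.close_session()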

-    def train(self, epochs: int, ckpt_every: int):
+    def validate(self):
+        noisy_batch, target_batch = self.input_fn('Data/val.tfrecords', False, 1024)
+        val_loss = 0
+        n_batch = 0
+        while True:
+            try:
+                noisies, targets = self.sess.run([noisy_batch, target_batch])
+                n_batch += 1
+                noisies /= 255
+                targets /= 255
+
+                l = self.sess.run([self.batch_loss], feed_dict={self.input_image: noisies,
+                                                                self.target_image: targets})
+                print(l)
+                val_loss += l[0]
+            except tf.errors.OutOfRangeError:
+                val_loss = val_loss / n_batch
+                return val_loss
+
+    def train(self, epochs: int, ckpt_every: int, validate: bool):
        for e in range(1, epochs + 1):
-            noisy_batch, target_batch = self.input_fn('Data/train.tfrecords', True, 2)
+            noisy_batch, target_batch = self.input_fn('Data/train.tfrecords', True, 1024)
            epoch_loss = self.train_epoch(noisy_batch, target_batch)
            if e % ckpt_every == 0:
                self.checkpoint(e, epoch_loss)
-            print('Epoch Loss = {}, epoch={}'.format(epoch_loss, e))
+            if validate:
+                print('Epoch {}, train_loss={}, val_loss={}'.format(e, epoch_loss, self.validate()))
+            else:
+                print('Epoch Loss = {}, epoch={}'.format(epoch_loss, e))
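
For reference, a minimal training driver against the new signature might look like this; the optimizer choice, hyperparameters, and batch placeholder shape are assumptions, and Data/train.tfrecords plus Data/val.tfrecords must already exist:

    # Assumed training driver; not part of the commit.
    dae = DenoisingAutoEncoder((50, 50, 3), (None, 50, 50, 3), tf.train.AdamOptimizer(), True)
    dae.train(epochs=10, ckpt_every=5, validate=True)
    dae.close_session()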

    def train_epoch(self, noisy_batch, target_batch):
        epoch_loss = 0
@@ -71,23 +88,43 @@ def train_epoch(self, noisy_batch, target_batch):
            except tf.errors.OutOfRangeError:
                return epoch_loss / n_batch

+    def __load_weights(self):
+        weights_file = "Checkpoints/weights-epoch-6loss-0.005/weights-epoch-6loss-0.005.ckpt"
+        if not self.training:
+            print('Loaded weights')
+            self.saver.restore(self.sess, weights_file)
+
    def checkpoint(self, epoch, loss):
        epoch = str(epoch)
        loss = "{:.3f}".format(loss)
        file_name = 'weights-epoch-' + epoch + 'loss-' + loss
        save_path = self.saver.save(self.sess, 'Checkpoints/' + file_name + "/" + file_name + '.ckpt')
+
        print('Checkpoint for epoch {}, loss {} saved in {}'.format(epoch, loss, save_path))

-    def load(self, ckpt_path):
-        self.saver.restore(self.sess, ckpt_path)
+    def load(self, saved_path):
+        self.saver.restore(self.sess, saved_path)

-    def denoise(self, noisy_image):
-        latent, output_t = self.sess.run([self.conv8, self.output_image], feed_dict={self.input_image: noisy_image})
-        print(latent)
+    def denoise_patch(self, image_patch):
+        image_patch = image_patch.reshape(1, 50, 50, 3)
+        latent, output_t = self.sess.run([self.conv7, self.output_image], feed_dict={self.input_image: image_patch})
        output_t = np.array(output_t) * 255.0
        output_t = output_t.reshape(self.input_shape)
-        # print(output_t)
-        return Image.fromarray(output_t.astype('uint8')).convert('RGB')
+        return output_t
+
+    def denoise(self, image_array):
+        d_image = np.zeros(shape=image_array.shape)
+        for x in range(50, 3000, 50):
+            for y in range(50, 3000, 50):
+                patch = image_array[x - 50:x, y - 50:y, :]
+
+                if patch.shape[0] != 50 or patch.shape[1] != 50:
+                    continue
+                patch = self.denoise_patch(patch)
+                d_image[x - 50:x, y - 50:y, :] = patch
+
+        # print(d_image)
+        return Image.fromarray(d_image.astype('uint8')).convert('RGB')
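
Note that denoise walks the image in 50-pixel steps up to a hard-coded 3000-pixel bound and leaves any remainder (and anything past 3000 pixels) as zeros in d_image. A hypothetical variant, not in the commit, that derives the loop bounds from the array itself:

    # Sketch: cover every full 50x50 tile of the actual image.
    h, w = image_array.shape[:2]
    for x in range(50, h + 1, 50):
        for y in range(50, w + 1, 50):
            d_image[x - 50:x, y - 50:y, :] = self.denoise_patch(image_array[x - 50:x, y - 50:y, :])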

    def close_session(self):
        self.sess.close()
@@ -101,10 +138,10 @@ def parser(record):
        parsed = tf.parse_single_example(record, keys_to_feature)
        target_image = tf.decode_raw(parsed['reference'], tf.uint8)
        target_image = tf.cast(target_image, tf.float32)
-        target_image = tf.reshape(target_image, [1240, 1240, 3])
+        target_image = tf.reshape(target_image, [50, 50, 3])
        noisy_image = tf.decode_raw(parsed['noisy'], tf.uint8)
        noisy_image = tf.cast(noisy_image, tf.float32)
-        noisy_image = tf.reshape(noisy_image, [1240, 1240, 3])
+        noisy_image = tf.reshape(noisy_image, [50, 50, 3])
        return noisy_image, target_image

    def input_fn(self, filename, train, batch_size=4, buffer_size=2048):
@@ -116,15 +153,3 @@ def input_fn(self, filename, train, batch_size=4, buffer_size=2048):
        iterator = dataset.make_one_shot_iterator()
        noisy_batch, target_batch = iterator.get_next()
        return noisy_batch, target_batch
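
The parser above expects each record to hold raw uint8 bytes under the keys 'reference' and 'noisy', each reshaping to 50x50x3. A minimal writer sketch consistent with that layout; the function name, the file path, and the patch source are assumptions:

    # Assumed writer; 'pairs' yields (noisy, reference) uint8 arrays of shape (50, 50, 3).
    def write_pairs(pairs, path='Data/train.tfrecords'):
        with tf.python_io.TFRecordWriter(path) as writer:
            for noisy, reference in pairs:
                example = tf.train.Example(features=tf.train.Features(feature={
                    'noisy': tf.train.Feature(bytes_list=tf.train.BytesList(value=[noisy.tobytes()])),
                    'reference': tf.train.Feature(bytes_list=tf.train.BytesList(value=[reference.tobytes()])),
                }))
                writer.write(example.SerializeToString())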
-
-
-d = DenoisingAutoEncoder((1240, 1240, 3), tf.train.AdamOptimizer(), True)
-# d.train(100, 10)
-d.load('Checkpoints/weights-epoch-100loss-0.033/weights-epoch-100loss-0.033.ckpt')
-sample_img = Image.open('/home/aftaab/Datasets/Mi3_Aligned/Batch_017//IMG_20151116_151714Noisy.bmp').convert(
-    'RGB').resize([1240, 1240])
-sample_img_t = np.array(sample_img).reshape((1, 1240, 1240, 3)) / 255.0
-d_img = d.denoise(sample_img_t)
-d_img.save('denoised.png', 'PNG')
-sample_img.save('noisy.png', 'PNG')
-d.close_session()
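
The full-image driver above was removed along with the 1240x1240 graph. Under the new patch-based API an equivalent inference driver might look like this; the file names and the 1250x1250 resize (any multiple of 50 within the 3000-pixel bound) are assumptions, and with is_training=False the constructor restores the hard-coded checkpoint via __load_weights():

    # Assumed inference driver, mirroring the deleted one; not part of the commit.
    d = DenoisingAutoEncoder((50, 50, 3), (None, 50, 50, 3), tf.train.AdamOptimizer(), False)
    sample_img = Image.open('noisy.bmp').convert('RGB').resize([1250, 1250])
    d_img = d.denoise(np.array(sample_img))  # denoise() tiles the array into 50x50 patches
    d_img.save('denoised.png', 'PNG')
    d.close_session()

One caveat worth checking: validate() scales inputs to [0, 1] before feeding them, while denoise_patch feeds raw 0-255 values, so one of the two presumably needs the division by 255.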