Commit e6f33ac ("Finished")

1 parent 31d4b28 commit e6f33ac

16 files changed: +293 additions, -61 deletions

.gitignore

Lines changed: 1 addition & 2 deletions

@@ -1,4 +1,3 @@
 .idea/
-*.png
 Data/
-Checkpoints/
+__pycache__/
Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+model_checkpoint_path: "weights-epoch-6loss-0.005.ckpt"
+all_model_checkpoint_paths: "../weights-epoch-2loss-0.009/weights-epoch-2loss-0.009.ckpt"
+all_model_checkpoint_paths: "../weights-epoch-3loss-0.007/weights-epoch-3loss-0.007.ckpt"
+all_model_checkpoint_paths: "../weights-epoch-4loss-0.006/weights-epoch-4loss-0.006.ckpt"
+all_model_checkpoint_paths: "../weights-epoch-5loss-0.005/weights-epoch-5loss-0.005.ckpt"
+all_model_checkpoint_paths: "weights-epoch-6loss-0.005.ckpt"
Binary file not shown.
Binary file not shown.
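
The file added above is TensorFlow's checkpoint index protobuf: tf.train.Saver writes it next to the weight files so TF 1.x utilities can locate the most recent checkpoint. A minimal read-back sketch (the directory path is an assumption inferred from the paths listed in the file, not confirmed by this diff):

    import tensorflow as tf

    # Reads the `checkpoint` protobuf in the given directory (assumed location).
    state = tf.train.get_checkpoint_state('Checkpoints/weights-epoch-6loss-0.005')
    if state is not None:
        print(state.model_checkpoint_path)       # weights-epoch-6loss-0.005.ckpt
        print(state.all_model_checkpoint_paths)  # earlier epochs, relative paths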

README.md

Lines changed: 45 additions & 2 deletions

@@ -1,2 +1,45 @@
-### Denoising Autoencoder
-Implementation of a denoising autoencoder trained on the RENOIR dataset.
+## Denoising Autoencoder
+Implementation of a denoising autoencoder trained on the RENOIR dataset (Mi3 images).
+
+## Setting up locally
+
+    pip install -r requirements.txt
+
+## Dataset
+50x50px patches were taken from the reference and noisy images in the dataset. I've serialised these into TFRecords, which can be downloaded with:
+
+    python download_data.py
+
+This downloads the train and validation records required for training.
+
+## Training and inference
+1. For training, run:
+
+       python train.py -e <num_of_epochs> -c <checkpoint_after> -v <validation_enabled, 1 or 0>
+
+   Example:
+
+       python train.py -e 50 -c 5 -v 1
+
+   The defaults train for 10 epochs, checkpointing every epoch, with validation enabled.
+
+2. For inference:
+
+       python predict.py -i <input_file> -o <output_file>
+
+## Results
+I've trained the model for only 6 epochs (a very small fraction of what most papers recommend), so the results aren't particularly good.
+
+1. Reference:
+![Reference Image](https://github.com/Aftaab99/DenoisingAutoencoder/blob/master/images/reference.bmp "Reference Image")
+
+2. Noisy:
+![Noisy Image](https://github.com/Aftaab99/DenoisingAutoencoder/blob/master/images/noisy.png "Noisy Image")
+
+3. Denoised:
+![Denoised Image](https://github.com/Aftaab99/DenoisingAutoencoder/blob/master/images/denoised.png "Denoised Image")
+
+### References
+1. J. Anaya, A. Barbu. RENOIR - A Dataset for Real Low-Light Image Noise Reduction. ([arXiv](https://arxiv.org/abs/1409.8230))
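
The train.py and predict.py scripts the README refers to are not part of this diff excerpt. As a purely hypothetical sketch, flag parsing consistent with the documented -e/-c/-v options and defaults (10 epochs, checkpoint every epoch, validation on) might look like:

    import argparse

    # Hypothetical: mirrors the README's CLI, not the actual train.py source.
    parser = argparse.ArgumentParser(description='Train the denoising autoencoder')
    parser.add_argument('-e', type=int, default=10, dest='epochs')
    parser.add_argument('-c', type=int, default=1, dest='ckpt_every')
    parser.add_argument('-v', type=int, default=1, choices=[0, 1], dest='validate')
    args = parser.parse_args()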

__pycache__/model.cpython-36.pyc

-2.31 KB
Binary file not shown.

create_datasets.py

Lines changed: 33 additions & 9 deletions

@@ -3,27 +3,52 @@
 from PIL import Image
 from sklearn.model_selection import train_test_split
 import tensorflow as tf
+import numpy as np
+from random import shuffle


 def _bytes_feature(value):
     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


-def load_image(addr):
-    img = Image.open(addr).resize([1240, 1240])
-    return img
+def extract_patches(ref, noisy):
+    patch_list = []
+    for x in range(50, 3000, 50):
+        for y in range(50, 3000, 50):
+            patch_ref = ref[x - 50:x, y - 50:y, :]
+            patch_noisy = noisy[x - 50:x, y - 50:y, :]
+            if patch_ref.shape[0] != 50 or patch_ref.shape[1] != 50:
+                continue
+            patch_list.append({'ref': patch_ref, 'noisy': patch_noisy})
+    return patch_list
+
+
+def get_patches(addr):
+    ref_img = Image.open(addr['reference']).convert('RGB')
+    noisy_image = Image.open(addr['noisy']).convert('RGB')
+    ref_img_t = np.array(ref_img)
+    noisy_image_t = np.array(noisy_image)
+    patch_list = extract_patches(ref_img_t, noisy_image_t)
+
+    return patch_list


 def create_data_record(out_filename, addrs):
     # open the TFRecords file
     writer = tf.python_io.TFRecordWriter(out_filename)

-    for addr, i in zip(addrs, range(1, len(addrs)+1)):
-        ref = load_image(addr['reference'])
-        noisy = load_image(addr['noisy'])
+    patch_list = []
+
+    for addr, i in zip(addrs, range(1, len(addrs) + 1)):
+        patch_list = patch_list + get_patches(addr)
+
+    shuffle(patch_list)
+
+    for item, i in zip(patch_list, range(1, len(patch_list) + 1)):
         # Create a feature
+        ref = Image.fromarray(item['ref'])
+        noisy = Image.fromarray(item['noisy'])
         feature = {
             'reference': _bytes_feature(ref.tobytes()),
             'noisy': _bytes_feature(noisy.tobytes())
@@ -33,14 +58,13 @@ def create_data_record(out_filename, addrs):

         # Serialize to string and write on the file
         writer.write(example.SerializeToString())
-        print('Image {} wrote to record'.format(i))
+        print('Patch {} written to record'.format(i))
     writer.close()
     sys.stdout.flush()


 base_path = '/home/aftaab/Datasets/Mi3_Aligned'

-
 addrs = []

 for directory in os.listdir(base_path):
@@ -55,4 +79,4 @@ def create_data_record(out_filename, addrs):
 train_addrs, test_addrs = train_test_split(addrs, test_size=0.2)

 create_data_record('Data/train.tfrecords', train_addrs)
-create_data_record('Data/test.tfrecords', test_addrs)
+create_data_record('Data/val.tfrecords', test_addrs)
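
Since the patches are serialised with PIL's tobytes(), the records store raw pixel bytes with no shape metadata, so a reader has to re-impose the 50x50x3 shape. A minimal read-back sketch (TF 1.x, matching the feature keys above), useful for verifying a generated record:

    import tensorflow as tf

    def parse_patch(record):
        features = {'reference': tf.FixedLenFeature([], tf.string),
                    'noisy': tf.FixedLenFeature([], tf.string)}
        parsed = tf.parse_single_example(record, features)
        ref = tf.reshape(tf.decode_raw(parsed['reference'], tf.uint8), [50, 50, 3])
        noisy = tf.reshape(tf.decode_raw(parsed['noisy'], tf.uint8), [50, 50, 3])
        return noisy, ref

    # Pull one patch pair out of the record and check its shape.
    dataset = tf.data.TFRecordDataset('Data/train.tfrecords').map(parse_patch)
    noisy, ref = dataset.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
        n, r = sess.run([noisy, ref])
        print(n.shape, r.shape)  # (50, 50, 3) (50, 50, 3)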

download_data.py

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+import requests
+import os
+
+
+def download_file_from_google_drive(id, destination):
+    URL = "https://drive.google.com/uc?export=download"
+
+    session = requests.Session()
+
+    response = session.get(URL, params={'id': id}, stream=True)
+    token = get_confirm_token(response)
+
+    if token:
+        params = {'id': id, 'confirm': token}
+        response = session.get(URL, params=params, stream=True)
+
+    save_response_content(response, destination)
+
+
+def get_confirm_token(response):
+    for key, value in response.cookies.items():
+        if key.startswith('download_warning'):
+            return value
+
+    return None
+
+
+def save_response_content(response, destination):
+    CHUNK_SIZE = 32768
+
+    with open(destination, "wb") as f:
+        for chunk in response.iter_content(CHUNK_SIZE):
+            if chunk:  # filter out keep-alive chunks
+                f.write(chunk)
+
+
+if __name__ == "__main__":
+
+    train_file_id = '12ctvUrf-Jivr0P9kHThW2GxIZzQf-R3I'
+    train_file_dest = 'Data/train.tfrecords'
+    val_file_id = '1YovsQgVVNeUyDGyhD0XpAf83HxThkKAZ'
+    val_file_dest = 'Data/val.tfrecords'
+
+    if not os.path.exists('./Data'):
+        os.mkdir('./Data')
+    print('Downloading train records...')
+    download_file_from_google_drive(train_file_id, train_file_dest)
+
+    print('Downloading validation records...')
+    download_file_from_google_drive(val_file_id, val_file_dest)
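
For context on the confirm-token round trip above: when a file is too large for Google Drive to virus-scan, the first request returns an HTML warning page instead of the file and sets a cookie whose name starts with download_warning; re-requesting with that cookie's value as the confirm parameter skips the warning page so the actual bytes stream back.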

images/denoised.png

4.63 MB

images/noisy.png

19.4 MB

images/reference.bmp

25.7 MB
Binary file not shown.

model.py

Lines changed: 73 additions & 48 deletions

@@ -1,59 +1,76 @@
 import tensorflow as tf
-from tensorflow.layers import conv2d, dropout, max_pooling2d, conv2d_transpose
+from tensorflow.layers import conv2d, max_pooling2d, conv2d_transpose
 from PIL import Image
 import numpy as np


 class DenoisingAutoEncoder:

-    def __init__(self, input_shape: tuple, optimizer, is_training: bool):
+    def __init__(self, input_shape: tuple, batch_input_shape: tuple, optimizer, is_training: bool):
         self.sess = tf.Session()
         self.input_shape = input_shape
-        self.input_image = tf.placeholder(tf.float32, shape=(None, 1240, 1240, 3), name="input_image")
-        self.target_image = tf.placeholder(tf.float32, shape=(None, 1240, 1240, 3), name="target_image")
+        self.input_image = tf.placeholder(tf.float32, shape=batch_input_shape, name="input_image")
+        self.target_image = tf.placeholder(tf.float32, shape=batch_input_shape, name="target_image")
+        self.training = is_training

         with tf.name_scope('Encoder'):
-            self.conv1 = tf.nn.leaky_relu(conv2d(self.input_image, 16, (7, 7), padding='same', use_bias=False))
-            self.pool1 = max_pooling2d(self.conv1, (4, 4), (4, 4))
-            self.dropout1 = dropout(self.pool1, 0.2, training=is_training)
-            self.conv2 = tf.nn.leaky_relu(conv2d(self.dropout1, 20, (5, 5), padding='same', use_bias=False))
-            self.pool2 = max_pooling2d(self.conv2, (2, 2), (2, 2))
-            self.dropout2 = dropout(self.pool2, 0.3, training=is_training)
-            self.conv3 = tf.nn.leaky_relu(conv2d(self.dropout2, 32, (5, 5), padding='same', use_bias=False))
+            self.conv1 = tf.nn.leaky_relu(conv2d(self.input_image, 16, (5, 5), padding='same'))
+            self.pool1 = max_pooling2d(self.conv1, (2, 2), (2, 2))
+            self.conv2 = tf.nn.leaky_relu(conv2d(self.pool1, 32, (3, 3), padding='same'))
+            self.pool2 = max_pooling2d(self.conv2, (5, 5), (5, 5))
+            self.conv3 = tf.nn.leaky_relu(conv2d(self.pool2, 64, (3, 3), padding='same'))
             self.pool3 = max_pooling2d(self.conv3, (5, 5), (5, 5))
-            self.dropout3 = dropout(self.pool3, 0.3, training=is_training)
-            self.latent_repr = tf.nn.leaky_relu(conv2d(self.dropout3, 128, (3, 3), padding='same', use_bias=False))
-            # self.latent_repr = max_pooling2d(self.conv4, (5, 5), (5, 5))
+            self.latent_repr = tf.nn.leaky_relu(conv2d(self.pool3, 256, (3, 3), padding='same'))

         with tf.name_scope('Decoder'):
-            self.upsampling1 = tf.image.resize_images(self.latent_repr, (31, 31),
+            self.upsampling1 = tf.image.resize_images(self.latent_repr, (5, 5),
                                                       tf.image.ResizeMethod.BICUBIC)
             self.conv5 = tf.nn.leaky_relu(
-                conv2d_transpose(self.upsampling1, 32, (3, 3), padding='same', use_bias=False))
-            self.dropout4 = dropout(self.conv5, 0.3, training=is_training)
-            self.upsampling2 = tf.image.resize_images(self.dropout4, (155, 155), tf.image.ResizeMethod.BICUBIC)
+                conv2d_transpose(self.upsampling1, 64, (3, 3), padding='same'))
+            self.upsampling2 = tf.image.resize_images(self.conv5, (25, 25), tf.image.ResizeMethod.BICUBIC)
             self.conv6 = tf.nn.leaky_relu(
-                conv2d_transpose(self.upsampling2, 16, (5, 5), padding='same', use_bias=False))
-            self.upsampling3 = tf.image.resize_images(self.conv6, (310, 310), tf.image.ResizeMethod.BICUBIC)
-            self.conv7 = tf.nn.leaky_relu(conv2d_transpose(self.upsampling3, 3, (5, 5), padding='same', use_bias=False))
-            self.upsampling4 = tf.image.resize_images(self.conv7, (1240, 1240), tf.image.ResizeMethod.BICUBIC)
-            self.conv8 = tf.nn.leaky_relu(conv2d_transpose(self.upsampling4, 3, (1, 1), padding='same', use_bias=True))
+                conv2d_transpose(self.upsampling2, 32, (5, 5), padding='same'))
+            self.upsampling3 = tf.image.resize_images(self.conv6, (50, 50), tf.image.ResizeMethod.BICUBIC)
+            self.conv7 = tf.nn.leaky_relu(conv2d_transpose(self.upsampling3, 3, (5, 5), padding='same'))

-        self.output_image = tf.nn.sigmoid(self.conv8)
+        self.output_image = tf.nn.sigmoid(self.conv7)
         self.loss = tf.losses.mean_squared_error(self.target_image, self.output_image)
         self.batch_loss = tf.reduce_mean(self.loss)

         self.train_step = optimizer.minimize(self.batch_loss)
         self.sess.run(tf.global_variables_initializer())
         self.saver = tf.train.Saver()
+        self.__load_weights()

-    def train(self, epochs: int, ckpt_every: int):
+    def validate(self):
+        noisy_batch, target_batch = self.input_fn('Data/val.tfrecords', False, 1024)
+        val_loss = 0
+        n_batch = 0
+        while True:
+            try:
+                noisies, targets = self.sess.run([noisy_batch, target_batch])
+                n_batch += 1
+                noisies /= 255
+                targets /= 255
+
+                l = self.sess.run([self.batch_loss], feed_dict={self.input_image: noisies,
+                                                                self.target_image: targets})
+                print(l)
+                val_loss += l[0]
+            except tf.errors.OutOfRangeError:
+                val_loss = val_loss / n_batch
+                return val_loss
+
+    def train(self, epochs: int, ckpt_every: int, validate: bool):
         for e in range(1, epochs + 1):
-            noisy_batch, target_batch = self.input_fn('Data/train.tfrecords', True, 2)
+            noisy_batch, target_batch = self.input_fn('Data/train.tfrecords', True, 1024)
             epoch_loss = self.train_epoch(noisy_batch, target_batch)
             if e % ckpt_every == 0:
                 self.checkpoint(e, epoch_loss)
-            print('Epoch Loss = {}, epoch={}'.format(epoch_loss, e))
+            if validate:
+                print('Epoch {}, train_loss={}, val_loss={}'.format(e, epoch_loss, self.validate()))
+            else:
+                print('Epoch Loss = {}, epoch={}'.format(epoch_loss, e))

     def train_epoch(self, noisy_batch, target_batch):
         epoch_loss = 0
@@ -71,23 +88,43 @@ def train_epoch(self, noisy_batch, target_batch):
         except tf.errors.OutOfRangeError:
             return epoch_loss / n_batch

+    def __load_weights(self):
+        weights_file = "Checkpoints/weights-epoch-6loss-0.005/weights-epoch-6loss-0.005.ckpt"
+        if not self.training:
+            print('Loaded weights')
+            self.saver.restore(self.sess, weights_file)
+
     def checkpoint(self, epoch, loss):
         epoch = str(epoch)
         loss = "{:.3f}".format(loss)
         file_name = 'weights-epoch-' + epoch + 'loss-' + loss
         save_path = self.saver.save(self.sess, 'Checkpoints/' + file_name + "/" + file_name + '.ckpt')
+
         print('Checkpoint for epoch {}, loss {} saved in {}'.format(epoch, loss, save_path))

-    def load(self, ckpt_path):
-        self.saver.restore(self.sess, ckpt_path)
+    def load(self, saved_path):
+        self.saver.restore(self.sess, saved_path)

-    def denoise(self, noisy_image):
-        latent, output_t = self.sess.run([self.conv8, self.output_image], feed_dict={self.input_image: noisy_image})
-        print(latent)
+    def denoise_patch(self, image_patch):
+        image_patch = image_patch.reshape(1, 50, 50, 3)
+        latent, output_t = self.sess.run([self.conv7, self.output_image], feed_dict={self.input_image: image_patch})
         output_t = np.array(output_t) * 255.0
         output_t = output_t.reshape(self.input_shape)
-        # print(output_t)
-        return Image.fromarray(output_t.astype('uint8')).convert('RGB')
+        return output_t
+
+    def denoise(self, image_array):
+        d_image = np.zeros(shape=image_array.shape)
+        for x in range(50, 3000, 50):
+            for y in range(50, 3000, 50):
+                patch = image_array[x - 50:x, y - 50:y, :]
+
+                if patch.shape[0] != 50 or patch.shape[1] != 50:
+                    continue
+                patch = self.denoise_patch(patch)
+                d_image[x - 50:x, y - 50:y, :] = patch
+
+        # print(d_image)
+        return Image.fromarray(d_image.astype('uint8')).convert('RGB')

     def close_session(self):
         self.sess.close()
@@ -101,10 +138,10 @@ def parser(record):
         parsed = tf.parse_single_example(record, keys_to_feature)
         target_image = tf.decode_raw(parsed['reference'], tf.uint8)
         target_image = tf.cast(target_image, tf.float32)
-        target_image = tf.reshape(target_image, [1240, 1240, 3])
+        target_image = tf.reshape(target_image, [50, 50, 3])
         noisy_image = tf.decode_raw(parsed['noisy'], tf.uint8)
         noisy_image = tf.cast(noisy_image, tf.float32)
-        noisy_image = tf.reshape(noisy_image, [1240, 1240, 3])
+        noisy_image = tf.reshape(noisy_image, [50, 50, 3])
         return noisy_image, target_image

     def input_fn(self, filename, train, batch_size=4, buffer_size=2048):
@@ -116,15 +153,3 @@ def input_fn(self, filename, train, batch_size=4, buffer_size=2048):
         iterator = dataset.make_one_shot_iterator()
         noisy_batch, target_batch = iterator.get_next()
         return noisy_batch, target_batch
-
-
-d = DenoisingAutoEncoder((1240, 1240, 3), tf.train.AdamOptimizer(), True)
-# d.train(100, 10)
-d.load('Checkpoints/weights-epoch-100loss-0.033/weights-epoch-100loss-0.033.ckpt')
-sample_img = Image.open('/home/aftaab/Datasets/Mi3_Aligned/Batch_017//IMG_20151116_151714Noisy.bmp').convert(
-    'RGB').resize([1240, 1240])
-sample_img_t = np.array(sample_img).reshape((1, 1240, 1240, 3)) / 255.0
-d_img = d.denoise(sample_img_t)
-d_img.save('denoised.png', 'PNG')
-sample_img.save('noisy.png', 'PNG')
-d.close_session()
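
The driver code removed at the bottom of model.py now presumably lives in train.py/predict.py, which are not shown in this excerpt. A hedged sketch of how the new patch-based API would be driven, with argument values inferred from the diff rather than taken from the actual scripts:

    import numpy as np
    import tensorflow as tf
    from PIL import Image
    from model import DenoisingAutoEncoder

    # Shapes inferred from the diff: patches are 50x50x3, and denoise_patch feeds
    # single patches, so the batch placeholder presumably has a free batch
    # dimension. is_training=False triggers __load_weights(), which restores the
    # epoch-6 checkpoint.
    dae = DenoisingAutoEncoder(input_shape=(50, 50, 3),
                               batch_input_shape=(None, 50, 50, 3),
                               optimizer=tf.train.AdamOptimizer(),
                               is_training=False)

    noisy = np.array(Image.open('noisy_input.bmp').convert('RGB'))  # hypothetical input file
    denoised = dae.denoise(noisy)  # denoises 50x50 patches and stitches them back together
    denoised.save('denoised.png', 'PNG')
    dae.close_session()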
