import tensorflow as tf
from tensorflow.layers import conv2d, dropout, max_pooling2d, conv2d_transpose
from PIL import Image
import numpy as np


class DenoisingAutoEncoder:

    def __init__(self, input_shape: tuple, optimizer, is_training: bool):
        self.sess = tf.Session()
        self.input_shape = input_shape
        self.input_image = tf.placeholder(tf.float32, shape=(None, 1240, 1240, 3), name="input_image")
        self.target_image = tf.placeholder(tf.float32, shape=(None, 1240, 1240, 3), name="target_image")

        with tf.name_scope('Encoder'):
            # Spatial dims: 1240 -> 310 -> 155 -> 31 -> 1, ending in a 1x1x64 latent code.
            self.conv1 = tf.nn.leaky_relu(conv2d(self.input_image, 16, (7, 7), padding='same', use_bias=False))
            self.pool1 = max_pooling2d(self.conv1, (4, 4), (4, 4))
            self.dropout1 = dropout(self.pool1, 0.2, training=is_training)
            self.conv2 = tf.nn.leaky_relu(conv2d(self.dropout1, 20, (5, 5), padding='same', use_bias=False))
            self.pool2 = max_pooling2d(self.conv2, (2, 2), (2, 2))
            self.dropout2 = dropout(self.pool2, 0.3, training=is_training)
            self.conv3 = tf.nn.leaky_relu(conv2d(self.dropout2, 32, (5, 5), padding='same', use_bias=False))
            self.pool3 = max_pooling2d(self.conv3, (5, 5), (5, 5))
            self.dropout3 = dropout(self.pool3, 0.3, training=is_training)
            self.conv4 = tf.nn.leaky_relu(conv2d(self.dropout3, 64, (3, 3), padding='same', use_bias=False))
            self.latent_repr = max_pooling2d(self.conv4, (31, 31), (31, 31))

        with tf.name_scope('Decoder'):
            # Mirror the encoder with bicubic upsampling: 1 -> 31 -> 155 -> 310 -> 1240.
            self.upsampling1 = tf.image.resize_images(self.latent_repr, (31, 31),
                                                      tf.image.ResizeMethod.BICUBIC)
            self.conv5 = tf.nn.leaky_relu(
                conv2d_transpose(self.upsampling1, 32, (3, 3), padding='same', use_bias=False))
            self.dropout4 = dropout(self.conv5, 0.3, training=is_training)
            self.upsampling2 = tf.image.resize_images(self.dropout4, (155, 155), tf.image.ResizeMethod.BICUBIC)
            self.conv6 = tf.nn.leaky_relu(
                conv2d_transpose(self.upsampling2, 16, (5, 5), padding='same', use_bias=False))
            self.upsampling3 = tf.image.resize_images(self.conv6, (310, 310), tf.image.ResizeMethod.BICUBIC)
            self.conv7 = tf.nn.leaky_relu(conv2d_transpose(self.upsampling3, 3, (5, 5), padding='same', use_bias=False))
            self.upsampling4 = tf.image.resize_images(self.conv7, (1240, 1240), tf.image.ResizeMethod.BICUBIC)
            # Keep the final 1x1 conv linear: conv8 serves as the raw logits for
            # both the sigmoid output and the cross-entropy loss below.
            self.conv8 = conv2d_transpose(self.upsampling4, 3, (1, 1), padding='same', use_bias=True)

        self.output_image = tf.nn.sigmoid(self.conv8)
        # sigmoid_cross_entropy_with_logits applies the sigmoid internally, so
        # it must be fed the raw logits (conv8), not the sigmoid output.
        self.loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=self.target_image, logits=self.conv8)
        self.batch_loss = tf.reduce_mean(self.loss)

        self.train_step = optimizer.minimize(self.batch_loss)
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver()

    def train(self, epochs: int, ckpt_every: int):
        for e in range(1, epochs + 1):
            # A fresh one-shot iterator is built each epoch; train_epoch runs
            # it until the dataset raises OutOfRangeError.
            noisy_batch, target_batch = self.input_fn('Data/train.tfrecords', True, 2)
            epoch_loss = self.train_epoch(noisy_batch, target_batch)
            if e % ckpt_every == 0:
                self.checkpoint(e, epoch_loss)
            print('Epoch loss = {}, epoch = {}'.format(epoch_loss, e))

    def train_epoch(self, noisy_batch, target_batch):
        epoch_loss = 0
        n_batch = 0
        while True:
            try:
                noisies, targets = self.sess.run([noisy_batch, target_batch])
                n_batch += 1
                # Scale pixels from [0, 255] to [0, 1] to match the sigmoid output range.
                noisies /= 255
                targets /= 255

                _, l = self.sess.run([self.train_step, self.batch_loss],
                                     feed_dict={self.input_image: noisies, self.target_image: targets})
                epoch_loss += l
            except tf.errors.OutOfRangeError:
                # Dataset exhausted: report the mean batch loss for the epoch.
                return epoch_loss / n_batch

    def checkpoint(self, epoch, loss):
        file_name = 'weights-epoch-{}loss-{:.3f}'.format(epoch, loss)
        save_path = self.saver.save(self.sess, 'Checkpoints/' + file_name + '/' + file_name + '.ckpt')
        print('Checkpoint for epoch {}, loss {:.3f} saved in {}'.format(epoch, loss, save_path))

    def load(self, ckpt_path):
        self.saver.restore(self.sess, ckpt_path)

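    # Hedged usage sketch: checkpoint() saves each run under its own directory,
    # so the newest checkpoint there can be located with tf.train.latest_checkpoint.
    # The directory name below is hypothetical, for illustration only:
    #
    #     d.load(tf.train.latest_checkpoint('Checkpoints/weights-epoch-10loss-0.500'))
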
    def denoise(self, noisy_image):
        # noisy_image is expected as a (1, 1240, 1240, 3) float array scaled to [0, 1].
        output_t = self.sess.run(self.output_image, feed_dict={self.input_image: noisy_image})
        # Rescale to [0, 255] and drop the batch dimension for PIL.
        output_t = np.array(output_t) * 255.0
        output_t = output_t.reshape(self.input_shape)
        return Image.fromarray(output_t.astype('uint8')).convert('RGB')

    def close_session(self):
        self.sess.close()

    @staticmethod
    def parser(record):
        # Each record stores the raw uint8 bytes of a clean/noisy image pair.
        keys_to_feature = {
            "reference": tf.FixedLenFeature([], tf.string),
            "noisy": tf.FixedLenFeature([], tf.string)
        }
        parsed = tf.parse_single_example(record, keys_to_feature)
        target_image = tf.decode_raw(parsed['reference'], tf.uint8)
        target_image = tf.cast(target_image, tf.float32)
        target_image = tf.reshape(target_image, [1240, 1240, 3])
        noisy_image = tf.decode_raw(parsed['noisy'], tf.uint8)
        noisy_image = tf.cast(noisy_image, tf.float32)
        noisy_image = tf.reshape(noisy_image, [1240, 1240, 3])
        return noisy_image, target_image

    def input_fn(self, filename, train, batch_size=4, buffer_size=2048):
        dataset = tf.data.TFRecordDataset(filename)
        dataset = dataset.map(self.parser)
        if train:
            dataset = dataset.shuffle(buffer_size=buffer_size)
        dataset = dataset.batch(batch_size)
        iterator = dataset.make_one_shot_iterator()
        noisy_batch, target_batch = iterator.get_next()
        return noisy_batch, target_batch


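# Hedged sketch (not part of the original script): one way to build the
# 'Data/train.tfrecords' file in the layout parser() expects. Each Example
# stores the raw uint8 bytes of a 1240x1240x3 reference/noisy pair under the
# keys 'reference' and 'noisy'; the pair paths are assumptions for illustration.
def write_tfrecords(pairs, out_path='Data/train.tfrecords'):
    def _bytes_feature(img):
        arr = np.array(img.convert('RGB').resize([1240, 1240]), dtype=np.uint8)
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[arr.tobytes()]))

    with tf.python_io.TFRecordWriter(out_path) as writer:
        for noisy_path, reference_path in pairs:
            example = tf.train.Example(features=tf.train.Features(feature={
                'reference': _bytes_feature(Image.open(reference_path)),
                'noisy': _bytes_feature(Image.open(noisy_path))
            }))
            writer.write(example.SerializeToString())

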
d = DenoisingAutoEncoder((1240, 1240, 3), tf.train.AdamOptimizer(), True)
d.train(10, 5)
# d.load('Checkpoints/weights-epoch-30loss-0.709/weights-epoch-30loss-0.709.ckpt')
sample_img = Image.open('/home/aftaab/Datasets/Mi3_Aligned/Batch_001//IMG_20160202_015247Noisy.bmp').convert(
    'RGB').resize([1240, 1240])
sample_img_t = np.array(sample_img).reshape((1, 1240, 1240, 3)) / 255.0
d_img = d.denoise(sample_img_t)
d_img.save('denoised.png', 'PNG')
sample_img.save('noisy.png', 'PNG')
d.close_session()