|
| 1 | +import glob |
| 2 | +import sys, os |
| 3 | +from PIL import Image |
| 4 | +from sklearn.model_selection import train_test_split |
| 5 | +import tensorflow as tf |
| 6 | + |
| 7 | + |
| 8 | +def _bytes_feature(value): |
| 9 | + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) |
| 10 | + |
| 11 | + |
| 12 | +def load_image(addr): |
| 13 | + img = Image.open(addr).resize([1240, 1240]) |
| 14 | + return img |
| 15 | + |
| 16 | + |
| 17 | +def createDataRecord(out_filename, addrs): |
| 18 | + # open the TFRecords file |
| 19 | + |
| 20 | + writer = tf.python_io.TFRecordWriter(out_filename) |
| 21 | + for addr, i in zip(addrs, range(1, len(addrs)+1)): |
| 22 | + |
| 23 | + ref = load_image(addr['reference']) |
| 24 | + noisy = load_image(addr['noisy']) |
| 25 | + |
| 26 | + # Create a feature |
| 27 | + feature = { |
| 28 | + 'reference': _bytes_feature(ref.tobytes()), |
| 29 | + 'noisy': _bytes_feature(noisy.tobytes()) |
| 30 | + } |
| 31 | + # Create an example protocol buffer |
| 32 | + example = tf.train.Example(features=tf.train.Features(feature=feature)) |
| 33 | + |
| 34 | + # Serialize to string and write on the file |
| 35 | + writer.write(example.SerializeToString()) |
| 36 | + print('Image {} wrote to record'.format(i)) |
| 37 | + writer.close() |
| 38 | + sys.stdout.flush() |
| 39 | + |
| 40 | + |
| 41 | +base_path = '/home/aftaab/Datasets/Mi3_Aligned' |
| 42 | + |
| 43 | + |
| 44 | +addrs = [] |
| 45 | + |
| 46 | +for directory in os.listdir(base_path): |
| 47 | + if not os.path.isdir(os.path.join(base_path, directory)): |
| 48 | + continue |
| 49 | + ref_path = base_path + "/" + directory + "/*Reference.bmp" |
| 50 | + noisy_path = base_path + "/" + directory + "/*Noisy.bmp" |
| 51 | + ref_image = glob.glob(ref_path)[0] |
| 52 | + noisy_image = glob.glob(noisy_path)[0] |
| 53 | + addrs.append({'reference': ref_image, 'noisy': noisy_image}) |
| 54 | + |
| 55 | +train_addrs, test_addrs = train_test_split(addrs, test_size=0.2) |
| 56 | + |
| 57 | +createDataRecord('Data/train.tfrecords', train_addrs) |
| 58 | +createDataRecord('Data/test.tfrecords', test_addrs) |
0 commit comments