Skip to content

Commit 3e852f4

Browse files
committed
created datasetrecords
0 parents  commit 3e852f4

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.idea/

create_datasets.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import glob
2+
import sys, os
3+
from PIL import Image
4+
from sklearn.model_selection import train_test_split
5+
import tensorflow as tf
6+
7+
8+
def _bytes_feature(value):
9+
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
10+
11+
12+
def load_image(addr):
13+
img = Image.open(addr).resize([1240, 1240])
14+
return img
15+
16+
17+
def createDataRecord(out_filename, addrs):
18+
# open the TFRecords file
19+
20+
writer = tf.python_io.TFRecordWriter(out_filename)
21+
for addr, i in zip(addrs, range(1, len(addrs)+1)):
22+
23+
ref = load_image(addr['reference'])
24+
noisy = load_image(addr['noisy'])
25+
26+
# Create a feature
27+
feature = {
28+
'reference': _bytes_feature(ref.tobytes()),
29+
'noisy': _bytes_feature(noisy.tobytes())
30+
}
31+
# Create an example protocol buffer
32+
example = tf.train.Example(features=tf.train.Features(feature=feature))
33+
34+
# Serialize to string and write on the file
35+
writer.write(example.SerializeToString())
36+
print('Image {} wrote to record'.format(i))
37+
writer.close()
38+
sys.stdout.flush()
39+
40+
41+
base_path = '/home/aftaab/Datasets/Mi3_Aligned'
42+
43+
44+
addrs = []
45+
46+
for directory in os.listdir(base_path):
47+
if not os.path.isdir(os.path.join(base_path, directory)):
48+
continue
49+
ref_path = base_path + "/" + directory + "/*Reference.bmp"
50+
noisy_path = base_path + "/" + directory + "/*Noisy.bmp"
51+
ref_image = glob.glob(ref_path)[0]
52+
noisy_image = glob.glob(noisy_path)[0]
53+
addrs.append({'reference': ref_image, 'noisy': noisy_image})
54+
55+
train_addrs, test_addrs = train_test_split(addrs, test_size=0.2)
56+
57+
createDataRecord('Data/train.tfrecords', train_addrs)
58+
createDataRecord('Data/test.tfrecords', test_addrs)

0 commit comments

Comments
 (0)