-
Notifications
You must be signed in to change notification settings - Fork 0
/
CustomGenerator.py
103 lines (75 loc) · 3.13 KB
/
CustomGenerator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from skimage.io import imread
from skimage.transform import resize
import numpy as np
from tensorflow.keras.utils import Sequence
from DataLoader import *
from skimage.color import lab2rgb
from Utils import preprocess_data
'''
Takes as input the list of filenames (paths) of the training set,
and reads the images from disk in batches (the training set is too
big to fit in memory at once).
'''
class CustomSequence(Sequence):
    """Keras ``Sequence`` that streams image patches from disk in batches.

    Only the training filenames are held in memory; each batch loads its
    patches on demand via ``load_patches_from_filenames`` and runs them
    through ``preprocess_data``.  The batch is returned as both input and
    target (``x == y``) — an autoencoder-style setup where no separate
    label is provided.
    """

    def __init__(self, filenames_in, batch_size, color_space='grey',
                 shuffle=True, max=102., patch_size=128, n_patches=200):
        """Store configuration and build the (optionally shuffled) index order.

        Parameters
        ----------
        filenames_in : list of str
            Paths of the training images.
        batch_size : int
            Number of files per batch.
        color_space : str
            Stored for external use; not read by this grayscale variant.
        shuffle : bool
            Whether to shuffle file order at construction and after each epoch.
        max : float
            Stored normalization bound; not read by this variant.
            NOTE: shadows the builtin ``max``; kept for caller compatibility.
        patch_size, n_patches : int
            Patch extraction settings forwarded to the loader.
        """
        self.max = max
        self.color_space = color_space
        self.batch_size = batch_size
        self.patch_size = patch_size
        self.n_patches = n_patches
        self.shuffle = shuffle
        self.x = filenames_in
        self.datalen = len(filenames_in)
        self.indexes = np.arange(self.datalen)
        # (removed dead `self.counter = 0` — written but never read)
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        """Load, preprocess and return batch ``index`` as ``(x, x)``."""
        # `self.indexes` is already an ndarray, so the slice needs no np.array() wrap.
        batch_indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        filenames_batch = [self.x[i] for i in batch_indexes]
        x_batch = load_patches_from_filenames(
            filenames_batch, self.patch_size, True, self.n_patches, grayscale=False)
        x_batch = preprocess_data(x_batch)
        return x_batch, x_batch

    def __len__(self):
        # Number of *full* batches per epoch; a trailing partial batch is dropped.
        return self.datalen // self.batch_size

    def on_epoch_end(self):
        # Rebuild and (optionally) reshuffle the index order between epochs.
        self.indexes = np.arange(self.datalen)
        if self.shuffle:
            np.random.shuffle(self.indexes)
# NOTE(review): dead code — an earlier CIELAB-normalizing variant of
# CustomSequence, disabled by wrapping it in a string literal.  Kept only
# for reference; consider deleting or moving to version control history.
# The string contents below are runtime data and are left byte-identical.
'''
class CustomSequence(Sequence):
def __init__(self, filenames_in, batch_size, color_space = 'cielab', shuffle = True, max = 102., patch_size=128, n_patches=200): #label not provided as x = y
self.max = max
self.color_space = color_space
self.batch_size = batch_size
self.patch_size = patch_size
self.n_patches = n_patches
self.shuffle = shuffle
self.x = filenames_in
self.datalen = len(filenames_in)
self.indexes = np.arange(self.datalen)
if self.shuffle:
np.random.shuffle(self.indexes)
def __getitem__(self, index):
batch_indexes = np.array(self.indexes[index*self.batch_size:(index+1)*self.batch_size])
filenames_batch = [self.x[i] for i in batch_indexes]
#filenames_batch = self.x[batch_indexes]
x_batch = load_patches_from_filenames(filenames_batch, self.patch_size, True, self.n_patches, grayscale=False)
#visualize_results(x_batch[0], x_batch[1], "a")
if self.color_space == 'cielab':
x_batch = prepare_dataset_colorssim(x_batch)
x_batch = x_batch / self.max
else:
x_batch = x_batch / 255.
#print("getting an item of shape :")
#print(x_batch.shape)
return x_batch, x_batch
def __len__(self):
return self.datalen // self.batch_size
def on_epoch_end(self):
self.indexes = np.arange(self.datalen)
if self.shuffle:
np.random.shuffle(self.indexes)
'''