/
new_train.py
122 lines (92 loc) · 4.18 KB
/
new_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from __future__ import print_function
import cv2
import numpy as np
from keras.models import Model
from keras.layers import Input, merge, Convolution2D, MaxPooling2D, UpSampling2D, BatchNormalization, Dropout, Activation, Reshape, Dense, Flatten
from keras.optimizers import Adam, SGD
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping
from keras import backend as K
from keras.regularizers import l2
from keras.utils.visualize_util import plot
from VggDNetGraphProvider import *
from data import load_train_data, load_test_data, random_crops
# Target spatial size (rows, cols) that every image and mask is resized to.
img_rows, img_cols = 256, 320
# Additive smoothing term used in the Dice coefficient (numerator and
# denominator), keeping the ratio defined when both masks sum to zero.
smooth = 1.0
def dice_coef(y_true, y_pred):
    """Return the smoothed Dice coefficient between two mask tensors.

    Both tensors are flattened (the whole batch becomes one vector), and the
    score is::

        (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)

    where ``smooth`` is the module-level smoothing constant.

    Args:
        y_true: Ground-truth mask tensor (any shape; it is flattened).
        y_pred: Predicted mask tensor with the same number of elements.

    Returns:
        A scalar backend tensor holding the Dice score.
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    # The element-wise product summed is the inner product of the two flattened
    # vectors. The previous K.dot(y_true_f, K.transpose(y_pred_f)) form only
    # works on backends whose dot accepts rank-1 tensors (Theano); the
    # TensorFlow backend routes K.dot through tf.matmul, which requires rank >= 2.
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
def dice_coef_loss(y_true, y_pred):
    """Loss for Keras: the Dice coefficient negated.

    Minimising this value maximises ``dice_coef(y_true, y_pred)``.
    """
    coefficient = dice_coef(y_true, y_pred)
    return -coefficient
def get_unet():
    """Build the VGG-based segmentation graph and compile it for Dice training.

    The graph comes from ``VggDNetGraphProvider.get_vgg_partial_graph`` for an
    input of ``img_rows`` x ``img_cols``. It is compiled with Adam (lr=1e-5),
    ``dice_coef_loss`` as the loss and ``dice_coef`` as a metric.

    Returns:
        The compiled graph object returned by the provider (presumably a
        Keras model -- TODO confirm against VggDNetGraphProvider).
    """
    # TODO: try adding batch_norm to these between the activation and the conv2d layers
    # TODO: add dropout layers?
    provider = VggDNetGraphProvider()
    net = provider.get_vgg_partial_graph(img_rows, img_cols)
    optimizer = Adam(lr=1e-5)
    net.compile(optimizer=optimizer, loss=dice_coef_loss, metrics=[dice_coef])
    return net
def preprocess(imgs, imgs_mask_train=None, number_augs_per_im=0):
    """Resize images (and optionally masks) to ``img_rows`` x ``img_cols``.

    Channel 0 of every sample ``imgs[i, 0]`` is resized with bicubic
    interpolation. When masks are given, ``number_augs_per_im`` extra passes of
    ``random_crops`` (crop size 80% of the input height/width) are appended
    after the originals.

    Args:
        imgs: Array indexed as ``imgs[i, 0]``; presumably (N, C, H, W) with a
            single channel -- TODO confirm. Only channel 0 is filled.
        imgs_mask_train: Optional masks aligned with ``imgs``. When None, no
            augmentation is done and only the images are returned.
        number_augs_per_im: Number of augmentation passes to append (forced to
            0 when ``imgs_mask_train`` is None).

    Returns:
        ``imgs_p`` when ``imgs_mask_train`` is None, otherwise the tuple
        ``(imgs_p, imgs_masks_p)``. Both are uint8 arrays of shape
        ``(N * (number_augs_per_im + 1), C, img_rows, img_cols)``. Originals
        occupy indices ``0..N-1``; sample ``i`` of augmentation pass ``j``
        is stored at index ``(j + 1) * N + i``.
    """
    # TODO: also rotational invariances?

    def _resize(img):
        # cv2.resize expects dsize as (width, height).
        return cv2.resize(img, (img_cols, img_rows), interpolation=cv2.INTER_CUBIC)

    has_masks = imgs_mask_train is not None
    if not has_masks:
        number_augs_per_im = 0

    n_imgs = imgs.shape[0]
    out_shape = (n_imgs * (number_augs_per_im + 1), imgs.shape[1], img_rows, img_cols)
    # Zero-initialise: only channel 0 is written, so np.ndarray/np.empty would
    # leave any other channel (and any unwritten slot) as garbage memory.
    imgs_p = np.zeros(out_shape, dtype=np.uint8)
    if has_masks:
        imgs_masks_p = np.zeros(out_shape, dtype=np.uint8)

    for i in range(n_imgs):
        imgs_p[i, 0] = _resize(imgs[i, 0])
        if has_masks:
            # NOTE(review): bicubic interpolation on masks yields non-binary
            # edge values; INTER_NEAREST may be intended -- confirm.
            imgs_masks_p[i, 0] = _resize(imgs_mask_train[i, 0])

    if not has_masks:
        return imgs_p

    crop_size = (int(imgs.shape[2] * .8), int(imgs.shape[3] * .8))
    for j in range(number_augs_per_im):
        au_img, au_msk = random_crops(imgs, imgs_mask_train, crop_size)
        # Each pass gets its own contiguous slab after the originals. The old
        # i*(j+2) indexing overwrote sample 0, collided across (i, j) pairs
        # and left slots unwritten.
        offset = (j + 1) * n_imgs
        for i in range(n_imgs):
            imgs_p[offset + i, 0] = _resize(au_img[i, 0])
            imgs_masks_p[offset + i, 0] = _resize(au_msk[i, 0])
    return imgs_p, imgs_masks_p
def train_and_predict():
    """Train the segmentation model, then predict and save masks for the test set.

    Steps:
      1. Load and preprocess the training data. Standardise the images with
         the training mean/std and scale the masks by 1/255.
      2. Build the model via ``get_unet``, and write its diagram to
         ``model.png``.
      3. Fit with early stopping. The best weights by ``val_loss`` are
         checkpointed to ``unet.hdf5``.
      4. Preprocess the test images with the same mean/std, and reload the best
         checkpoint.
      5. Predict and save the masks to ``imgs_mask_test.npy``.
    """
    rule = '-' * 30

    def announce(stage):
        # Print a stage header framed by dashed rules.
        print(rule)
        print(stage)
        print(rule)

    announce('Loading and preprocessing train data...')
    raw_imgs, raw_masks = load_train_data()
    x_train, y_train = preprocess(raw_imgs, raw_masks)
    x_train = x_train.astype('float32')
    # Statistics come from the training set only, and are reused for test data.
    train_mean = np.mean(x_train)
    train_std = np.std(x_train)
    # In-place ops keep the arrays float32.
    x_train -= train_mean
    x_train /= train_std
    y_train = y_train.astype('float32')
    y_train /= 255.

    announce('Creating and compiling model...')
    net = get_unet()
    checkpoint = ModelCheckpoint('unet.hdf5', monitor='val_loss', save_best_only=True)
    plot(net, to_file='model.png')

    announce('Fitting model...')
    stopper = EarlyStopping(monitor='val_loss', patience=5, verbose=1)
    net.fit(x_train, y_train,
            batch_size=1, nb_epoch=500, verbose=1, shuffle=True,
            callbacks=[checkpoint, stopper], validation_split=0.10)

    announce('Loading and preprocessing test data...')
    raw_test, _test_ids = load_test_data()
    x_test = preprocess(raw_test).astype('float32')
    x_test -= train_mean
    x_test /= train_std

    announce('Loading saved weights...')
    net.load_weights('unet.hdf5')

    announce('Predicting masks on test data...')
    predicted = net.predict(x_test, verbose=1)
    np.save('imgs_mask_test.npy', predicted)
# Run the full train-then-predict pipeline only when this file is executed as
# a script, not when it is imported.
if __name__ == '__main__':
    train_and_predict()