Training.py
from NeuralNetwork import Configuration
from NeuralNetwork.DataProvider import UNetDataProvider
from NeuralNetwork.Trainer import Trainer
from NeuralNetwork.UNet import UNet
from Setting import SEED, TRAINING_DATA_PATH, TRAINING_LOG_PATH, TRAINING_OUTPUT_PATH
import numpy as np
import os
import random
import tensorflow as tf

# Fix all random seeds (Python hashing, random, NumPy, TensorFlow) for reproducible training runs.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.set_random_seed(SEED)

# Load the network and training configuration from the current directory.
Configuration.LoadConfiguration('.')

# Build the U-Net: 3 input channels, Configuration.classNumber output classes,
# trained with a class-weighted cross-entropy cost.
networkCostParameter = {
    'class-weights': Configuration.classWeights
}
networkParameters = {
    'layers': Configuration.layerNumber,
    'featuresRoot': Configuration.featuresRoot
}
network = UNet(3, Configuration.classNumber, 'cross-entropy', networkCostParameter, **networkParameters)

# Data provider over the training patches; count the total number of training samples.
dataProvider = UNetDataProvider(Configuration.batchSize, 3, Configuration.classNumber, TRAINING_DATA_PATH)
dataNumber = 0
for patchName in dataProvider.patchData.keys():
    dataNumber += len(dataProvider.patchData[patchName]['file-names'])

# Adam optimizer; the learning rate decays every epoch's worth of batches.
trainerParameters = {
    'learning-rate': Configuration.learningRate,
    'decay-steps': dataNumber // Configuration.batchSize,
    'decay-rate': Configuration.learningRateDecay
}
trainer = Trainer(network, 'adam', False, **trainerParameters)

# Train for Configuration.epoch epochs, with dataNumber // batchSize iterations per epoch.
trainer.Train(dataProvider, dataNumber // Configuration.batchSize, TRAINING_OUTPUT_PATH, TRAINING_LOG_PATH, Configuration.epoch, 0.5, 100, True, False)