require 'torch'
require 'xlua'
require 'optim'
require 'cunn'      -- needed for the :cuda() calls on the model and criterion
require 'MiniBatch'
do
  local Trainer = torch.class('Trainer')

  -- Store the dataset/model/criterion, move the model and criterion to the
  -- GPU, flatten the parameters for optim, and set up a per-class confusion
  -- matrix plus a MiniBatch sampler.
  function Trainer:__init(dataset, model, criterion, batchSize, scaleMin, scaleMax, sizeX, sizeY)
    self.dataset = dataset
    self.model = model
    self.criterion = criterion
    self.model:cuda()
    self.criterion:cuda()
    self.batchSize = batchSize
    self.scaleMin = scaleMin
    self.scaleMax = scaleMax
    self.sizeX = sizeX
    self.sizeY = sizeY
    self.parameters, self.gradParameters = model:getParameters()
    self.confusion = optim.ConfusionMatrix(self.dataset.classes)
    self.minibatch = MiniBatch(self.dataset, self.batchSize, self.scaleMin, self.scaleMax, self.sizeX, self.sizeY)
  end
  function Trainer:setAdamParam(learningRate, learningRateDecay, epsilon, beta1, beta2)
    self.optimState = {
      learningRate = learningRate,
      learningRateDecay = learningRateDecay,
      epsilon = epsilon,
      beta1 = beta1,
      beta2 = beta2
    }
    self.optimMethod = optim.adam
    print("Adam parameters:")
    print(self.optimState)
  end
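
  -- Example call (the standard Adam defaults; illustrative values, not taken
  -- from this repo): trainer:setAdamParam(1e-3, 0, 1e-8, 0.9, 0.999)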
  -- Reshape the b x d x h x w network output into a (b*h*w) x d matrix so
  -- the criterion and the confusion matrix see one row of class scores per
  -- pixel.
  local transpose = function(input)
    local res = input:new()
    res:resizeAs(input):copy(input)
    res = res:transpose(2,4):transpose(2,3):contiguous() -- bdhw -> bwhd -> bhwd
    res = res:view(res:size(1)*res:size(2)*res:size(3), res:size(4)):contiguous()
    return res
  end

  -- Undo that reshape on the criterion gradient so it matches the model
  -- output layout again.
  local transpose_back = function(input, grad)
    local res = grad:new()
    res:resizeAs(grad):copy(grad)
    res = res:view(input:size(1), input:size(3), input:size(4), input:size(2))
    res = res:transpose(2,3):transpose(2,4):contiguous() -- bhwd -> bwhd -> bdhw
    return res
  end
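
  --[[
  Shape walkthrough (illustrative numbers, not from this repo): for a batch
  of b=2 score maps with d=3 classes over an h=4 x w=5 grid, `transpose`
  turns the (2,3,4,5) bdhw output into a (2*4*5, 3) = (40, 3) matrix of
  per-pixel class scores, matching the flattened targets produced by
  targets:view(-1); `transpose_back` maps the (40, 3) gradient back to
  (2,3,4,5).
  ]]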
  -- Run nbIteration training mini-batches and return the confusion matrix.
  -- setAdamParam must have been called first; the globals `epoch` and
  -- `opt.silent` are expected to be set by the calling script.
  function Trainer:train(nbIteration, extra_ratio)
    local time = sys.clock()
    self.confusion:zero()
    self.model:training()
    collectgarbage()
    epoch = epoch or 1
    print("==> Doing epoch on training data:")
    print(string.format('==> epoch #%04d [batchSize = %d]', epoch, self.batchSize))
    for t = 1, nbIteration do
      if not opt.silent then
        xlua.progress(t, nbIteration)
      end
      -- Create mini-batch
      self.batch = self.minibatch:getTrainingBatch(extra_ratio)
      -- Closure for optim: returns the loss and the flattened gradients.
      local feval = function()
        -- reset gradients
        self.gradParameters:zero()
        local outputs = self.model:forward(self.batch.inputs)
        local t_outputs = transpose(outputs)
        local t_targets = self.batch.targets:view(-1):contiguous()
        local f = self.criterion:forward(t_outputs, t_targets)
        local df_do = self.criterion:backward(t_outputs, t_targets)
        local t_df_do = transpose_back(outputs, df_do)
        self.model:backward(self.batch.inputs, t_df_do)
        self.confusion:batchAdd(t_outputs, t_targets)
        return f, self.gradParameters
      end
      -- optimize on current mini-batch
      self.optimMethod(feval, self.parameters, self.optimState)
    end
    time = sys.clock() - time
    print(string.format('\tTime : %s', xlua.formatTime(time)))
    return self.confusion
  end
  -- Run nbIteration validation mini-batches (forward pass only, no weight
  -- updates) and return the confusion matrix.
  function Trainer:valid(nbIteration)
    local time = sys.clock()
    self.confusion:zero()
    self.model:evaluate()
    collectgarbage()
    epoch = epoch or 1
    print("==> Doing epoch on validation data:")
    print(string.format('==> epoch #%04d [batchSize = %d]', epoch, self.batchSize))
    for t = 1, nbIteration do
      if not opt.silent then
        xlua.progress(t, nbIteration)
      end
      -- Create mini-batch
      self.batch = self.minibatch:getValidationBatch()
      local outputs = self.model:forward(self.batch.inputs)
      self.confusion:batchAdd(transpose(outputs), self.batch.targets:view(-1))
    end
    time = sys.clock() - time
    print(string.format('\tTime : %s', xlua.formatTime(time)))
    return self.confusion
  end
end
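
--[[
Usage sketch (hypothetical names and values; `dataset`, `model`, `criterion`
and the global `opt` table with opt.silent come from the calling script, and
MiniBatch must be on the Lua path):

  local trainer = Trainer(dataset, model, criterion,
                          8,        -- batchSize
                          0.8, 1.2, -- scaleMin, scaleMax
                          256, 256) -- sizeX, sizeY
  trainer:setAdamParam(1e-3, 0, 1e-8, 0.9, 0.999)
  for e = 1, 30 do
    epoch = e  -- global read by train/valid for the printed header
    print(trainer:train(100, 0.5))
    print(trainer:valid(20))
  end
]]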