|
|
@@ -256,7 +256,13 @@ def set_mean(self, in_, mean): |
|
|
if len(ms) != 3:
|
|
|
raise ValueError('Mean shape invalid')
|
|
|
if ms != self.inputs[in_][1:]:
|
|
|
- raise ValueError('Mean shape incompatible with input shape.')
|
|
|
+ print(self.inputs[in_])
|
|
|
+ in_shape = self.inputs[in_][1:]
|
|
|
+ m_min, m_max = mean.min(), mean.max()
|
|
|
+ normal_mean = (mean - m_min) / (m_max - m_min)
|
|
|
+ mean = resize_image(normal_mean.transpose((1,2,0)),
|
|
|
+ in_shape[1:]).transpose((2,0,1)) * \
|
|
|
+ (m_max - m_min) + m_min
|
|
|
self.mean[in_] = mean
|
|
|
|
|
|
def set_input_scale(self, in_, scale):
|
|
|
|