Fix: mean shape in compatible with input shape #5719

Open
wants to merge 1 commit into
from
Jump to file or symbol
Failed to load files and symbols.
+7 −1
Split
View
@@ -256,7 +256,13 @@ def set_mean(self, in_, mean):
if len(ms) != 3:
raise ValueError('Mean shape invalid')
if ms != self.inputs[in_][1:]:
- raise ValueError('Mean shape incompatible with input shape.')
+ print(self.inputs[in_])
+ in_shape = self.inputs[in_][1:]
+ m_min, m_max = mean.min(), mean.max()
+ normal_mean = (mean - m_min) / (m_max - m_min)
+ mean = resize_image(normal_mean.transpose((1,2,0)),
+ in_shape[1:]).transpose((2,0,1)) * \
+ (m_max - m_min) + m_min
self.mean[in_] = mean
def set_input_scale(self, in_, scale):