diff --git a/src/layer/fully_connected_layer.rs b/src/layer/fully_connected_layer.rs index 9e2dbcf..476c6eb 100644 --- a/src/layer/fully_connected_layer.rs +++ b/src/layer/fully_connected_layer.rs @@ -4,17 +4,20 @@ use layer::weights_matrix::WeightsMatrix; use layer::traits::{ ProcessInputSignal, CalculateOutputErrorSignal, + PropagateErrorSignal, + ApplyErrorSignalCorrection, HasOutputSignal, HasErrorSignal, }; use errors::{Result}; +use utils::{LearnRate, LearnMomentum}; #[derive(Debug, Clone, PartialEq)] pub struct FullyConnectedLayer { - weights : WeightsMatrix, - deltas : WeightsMatrix, - outputs : SignalBuffer, - gradients: ErrorSignalBuffer + weights : WeightsMatrix, + deltas : WeightsMatrix, + outputs : SignalBuffer, + error_signal: ErrorSignalBuffer } impl FullyConnectedLayer { @@ -22,9 +25,9 @@ impl FullyConnectedLayer { let (inputs, outputs) = (weights.inputs(), weights.outputs()); Ok(FullyConnectedLayer{ weights, - deltas : WeightsMatrix::zeros(inputs, outputs)?, - outputs : SignalBuffer::zeros(outputs)?, - gradients: ErrorSignalBuffer::zeros(outputs)?, + deltas : WeightsMatrix::zeros(inputs, outputs)?, + outputs : SignalBuffer::zeros(outputs)?, + error_signal: ErrorSignalBuffer::zeros(outputs)?, }) } @@ -35,28 +38,43 @@ impl FullyConnectedLayer { } impl ProcessInputSignal for FullyConnectedLayer { - fn process_input_signal(&mut self, signal: &SignalBuffer) { - if self.output_signal().len() != signal.len() { - panic!("Error: unmatching signals to layer size") // TODO: Replace this with error. (Needs to change trait.) + fn process_input_signal(&mut self, input_signal: &SignalBuffer) { + if self.output_signal().len() != input_signal.len() { + panic!("Error: unmatching signals to layer size") // TODO: Replace this with error. } use ndarray::linalg::general_mat_vec_mul; - general_mat_vec_mul(1.0, &self.weights.view(), &signal.biased_view(), 1.0, &mut self.outputs.view_mut()) + general_mat_vec_mul(1.0, &self.weights.view(), &input_signal.biased_view(), 1.0, &mut self.outputs.view_mut()) } } impl CalculateOutputErrorSignal for FullyConnectedLayer { fn calculate_output_error_signal(&mut self, target_signal: &SignalBuffer) { if self.output_signal().len() != target_signal.len() { - panic!("Error: unmatching signals to layer size") // TODO: Replace this with error. (Needs to change trait.) + panic!("Error: unmatching signals to layer size") // TODO: Replace this with error. } use ndarray::Zip; - Zip::from(&mut self.gradients.view_mut()) + Zip::from(&mut self.error_signal.view_mut()) .and(&self.outputs.view()) .and(&target_signal.view()) - .apply(|g, &t, &o| { - *g = t - o + .apply(|e, &o, &t| { + *e = t - o } - ); + ) + } +} + +impl PropagateErrorSignal for FullyConnectedLayer { + fn propagate_error_signal

(&mut self, _propagated: &mut P) + where P: HasErrorSignal + { + unimplemented!() + } +} + +impl ApplyErrorSignalCorrection for FullyConnectedLayer { + fn apply_error_signal_correction(&mut self, _signal: &SignalBuffer, _lr: LearnRate, _lm: LearnMomentum) { + // Nothing to do here since there are no weights that could be updated! + unimplemented!() } } @@ -72,10 +90,10 @@ impl HasOutputSignal for FullyConnectedLayer { impl HasErrorSignal for FullyConnectedLayer { fn error_signal(&self) -> &ErrorSignalBuffer { - &self.gradients + &self.error_signal } fn error_signal_mut(&mut self) -> &mut ErrorSignalBuffer { - &mut self.gradients + &mut self.error_signal } }