Skip to content

Commit

Permalink
added some initial stub impls for FullyConnectedLayer for PropagateEr…
Browse files Browse the repository at this point in the history
…rorSignal and ApplyErrorSignalCorrection traits.
  • Loading branch information
Robbepop committed Sep 20, 2017
1 parent df157fd commit 5b694da
Showing 1 changed file with 36 additions and 18 deletions.
54 changes: 36 additions & 18 deletions src/layer/fully_connected_layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,30 @@ use layer::weights_matrix::WeightsMatrix;
use layer::traits::{
ProcessInputSignal,
CalculateOutputErrorSignal,
PropagateErrorSignal,
ApplyErrorSignalCorrection,
HasOutputSignal,
HasErrorSignal,
};
use errors::{Result};
use utils::{LearnRate, LearnMomentum};

#[derive(Debug, Clone, PartialEq)]
pub struct FullyConnectedLayer {
weights : WeightsMatrix,
deltas : WeightsMatrix,
outputs : SignalBuffer,
gradients: ErrorSignalBuffer
weights : WeightsMatrix,
deltas : WeightsMatrix,
outputs : SignalBuffer,
error_signal: ErrorSignalBuffer
}

impl FullyConnectedLayer {
pub(crate) fn with_weights(weights: WeightsMatrix) -> Result<Self> {
let (inputs, outputs) = (weights.inputs(), weights.outputs());
Ok(FullyConnectedLayer{
weights,
deltas : WeightsMatrix::zeros(inputs, outputs)?,
outputs : SignalBuffer::zeros(outputs)?,
gradients: ErrorSignalBuffer::zeros(outputs)?,
deltas : WeightsMatrix::zeros(inputs, outputs)?,
outputs : SignalBuffer::zeros(outputs)?,
error_signal: ErrorSignalBuffer::zeros(outputs)?,
})
}

Expand All @@ -35,28 +38,43 @@ impl FullyConnectedLayer {
}

impl ProcessInputSignal for FullyConnectedLayer {
fn process_input_signal(&mut self, signal: &SignalBuffer) {
if self.output_signal().len() != signal.len() {
panic!("Error: unmatching signals to layer size") // TODO: Replace this with error. (Needs to change trait.)
fn process_input_signal(&mut self, input_signal: &SignalBuffer) {
if self.output_signal().len() != input_signal.len() {
panic!("Error: unmatching signals to layer size") // TODO: Replace this with error.
}
use ndarray::linalg::general_mat_vec_mul;
general_mat_vec_mul(1.0, &self.weights.view(), &signal.biased_view(), 1.0, &mut self.outputs.view_mut())
general_mat_vec_mul(1.0, &self.weights.view(), &input_signal.biased_view(), 1.0, &mut self.outputs.view_mut())
}
}

impl CalculateOutputErrorSignal for FullyConnectedLayer {
fn calculate_output_error_signal(&mut self, target_signal: &SignalBuffer) {
if self.output_signal().len() != target_signal.len() {
panic!("Error: unmatching signals to layer size") // TODO: Replace this with error. (Needs to change trait.)
panic!("Error: unmatching signals to layer size") // TODO: Replace this with error.
}
use ndarray::Zip;
Zip::from(&mut self.gradients.view_mut())
Zip::from(&mut self.error_signal.view_mut())
.and(&self.outputs.view())
.and(&target_signal.view())
.apply(|g, &t, &o| {
*g = t - o
.apply(|e, &o, &t| {
*e = t - o
}
);
)
}
}

impl PropagateErrorSignal for FullyConnectedLayer {
fn propagate_error_signal<P>(&mut self, _propagated: &mut P)
where P: HasErrorSignal
{
unimplemented!()
}
}

impl ApplyErrorSignalCorrection for FullyConnectedLayer {
fn apply_error_signal_correction(&mut self, _signal: &SignalBuffer, _lr: LearnRate, _lm: LearnMomentum) {
// Nothing to do here since there are no weights that could be updated!
unimplemented!()
}
}

Expand All @@ -72,10 +90,10 @@ impl HasOutputSignal for FullyConnectedLayer {

impl HasErrorSignal for FullyConnectedLayer {
fn error_signal(&self) -> &ErrorSignalBuffer {
&self.gradients
&self.error_signal
}

fn error_signal_mut(&mut self) -> &mut ErrorSignalBuffer {
&mut self.gradients
&mut self.error_signal
}
}

0 comments on commit 5b694da

Please sign in to comment.