From 1f5f6b87260f7b7c2d202d59cedc686c9c3f6b1e Mon Sep 17 00:00:00 2001
From: Hobofan
Date: Wed, 9 Mar 2016 11:24:31 +0100
Subject: [PATCH] fix/solvers: remove CUDA build flag

Removes the CUDA build flag requirement for the solvers.
---
 src/lib.rs         |  2 --
 src/solvers/mod.rs | 39 ++++++++++++---------------------------
 2 files changed, 12 insertions(+), 29 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 6f27a6d3..797944ee 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -118,9 +118,7 @@ extern crate collenchyma_blas as coblas;
 extern crate collenchyma_nn as conn;
 
 pub mod layer;
 pub mod layers;
-#[cfg(feature="cuda")]
 pub mod solver;
-#[cfg(feature="cuda")]
 pub mod solvers;
 pub mod weight;
diff --git a/src/solvers/mod.rs b/src/solvers/mod.rs
index 468bf217..fcda65f0 100644
--- a/src/solvers/mod.rs
+++ b/src/solvers/mod.rs
@@ -35,7 +35,7 @@ use co::{IBackend, MemoryType, SharedTensor};
 use conn::NN;
 use solver::*;
 use layer::*;
-use util::{ArcLock, native_backend, LayerOps, SolverOps};
+use util::*;
 
 trait SGDSolver<SolverB: IBackend + SolverOps<f32>, NetB: IBackend + LayerOps<f32>> : ISolver<SolverB, NetB> {
     fn compute_update_value(&mut self,
@@ -74,11 +74,13 @@ trait SGDSolver<SolverB: IBackend + SolverOps<f32>, NetB: IBackend + LayerOps<f32>> : ISolver<SolverB, NetB> {
                 let mut result = SharedTensor::<f32>::new(IBackend::device(backend), &1).unwrap();
                 match result.add_device(native.device()) { _ => result.sync(native.device()).unwrap() }
-                if let &MemoryType::Native(ref sumsq_result) = result.get(native.device()).unwrap() {
-                    let sumsq_diff_slice = sumsq_result.as_slice::<f32>();
-                    sumsq_diff += sumsq_diff_slice[0];
-                } else {
-                    panic!();
+                match result.get(native.device()).unwrap() {
+                    &MemoryType::Native(ref sumsq_result) => {
+                        let sumsq_diff_slice = sumsq_result.as_slice::<f32>();
+                        sumsq_diff += sumsq_diff_slice[0];
+                    },
+                    #[cfg(any(feature = "opencl", feature = "cuda"))]
+                    _ => {}
                 }
             }
             let l2norm_diff = sumsq_diff.sqrt();
@@ -90,13 +92,7 @@ trait SGDSolver<SolverB: IBackend + SolverOps<f32>, NetB: IBackend + LayerOps<f32>> : ISolver<SolverB, NetB> {
-                let mut scale_shared = SharedTensor::<f32>::new(native.device(), &1).unwrap();
-                if let &mut MemoryType::Native(ref mut scale) = scale_shared.get_mut(native.device()).unwrap() {
-                    let scale_slice = scale.as_mut_slice::<f32>();
-                    scale_slice[0] = scale_factor;
-                } else {
-                    panic!();
-                }
+                let mut scale_shared = native_scalar(scale_factor);
 
                 for weight_gradient in net_gradients {
                     let mut gradient = weight_gradient.write().unwrap();
@@ -117,13 +113,8 @@ trait SGDSolver<SolverB: IBackend + SolverOps<f32>, NetB: IBackend + LayerOps<f32>> : ISolver<SolverB, NetB> {
-            let mut scale_factor_shared = SharedTensor::<f32>::new(native.device(), &1).unwrap();
-            if let &mut MemoryType::Native(ref mut scale) = scale_factor_shared.get_mut(native.device()).unwrap() {
-                let scale_slice = scale.as_mut_slice::<f32>();
-                scale_slice[0] = scale_factor;
-            } else {
-                panic!();
-            }
+
+            let mut scale_factor_shared = native_scalar(scale_factor);
             // self.backend().scal_plain(&scale_factor_shared, &mut gradient).unwrap();
             self.backend().scal(&mut scale_factor_shared, &mut gradient).unwrap();
         }
@@ -141,13 +132,7 @@ trait SGDSolver<SolverB: IBackend + SolverOps<f32>, NetB: IBackend + LayerOps<f32>> : ISolver<SolverB, NetB> {
             RegularizationMethod::L2 => {
                 let native = native_backend();
-                let mut decay_shared = SharedTensor::<f32>::new(native.device(), &1).unwrap();
-                if let &mut MemoryType::Native(ref mut decay) = decay_shared.get_mut(native.device()).unwrap() {
-                    let decay_slice = decay.as_mut_slice::<f32>();
-                    decay_slice[0] = local_decay;
-                } else {
-                    panic!();
-                }
+                let decay_shared = native_scalar(local_decay);
                 let gradient = &mut weight_gradient.write().unwrap();
                 // gradient.regularize_l2(self.backend(), &decay_shared);
                 // backend.axpy_plain(&decay_shared, &self.data, &mut self.diff).unwrap();
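
Note: every hunk in src/solvers/mod.rs replaces the same hand-rolled pattern (allocate a one-element SharedTensor on the native backend, then write a scalar into its native memory) with a call to native_scalar, which is pulled in through the widened `use util::*;` import. The helper itself is not part of this patch; a minimal sketch of what it presumably looks like, assuming it simply factors the deleted code out into src/util.rs next to the existing native_backend() helper, would be:

    use co::{MemoryType, SharedTensor};

    /// Hypothetical sketch of the `native_scalar` helper this patch relies on:
    /// wraps a single f32 in a one-element SharedTensor allocated on the native
    /// (host) backend, so it can be handed to BLAS calls such as `scal`/`axpy`.
    pub fn native_scalar(scalar: f32) -> SharedTensor<f32> {
        // `native_backend()` is the existing helper in this module.
        let native = native_backend();
        let mut shared = SharedTensor::<f32>::new(native.device(), &1).unwrap();
        if let &mut MemoryType::Native(ref mut mem) = shared.get_mut(native.device()).unwrap() {
            // Write the scalar into the tensor's host memory.
            mem.as_mut_slice::<f32>()[0] = scalar;
        } else {
            panic!("the native backend should always yield native memory");
        }
        shared
    }

With the scalar setup no longer written against a specific device, the `#[cfg(any(feature = "opencl", feature = "cuda"))]` arm in clip_gradients keeps the match exhaustive when a GPU feature adds non-native MemoryType variants, which is what lets the `#[cfg(feature="cuda")]` gates come off the solver modules in lib.rs.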