PumasAI/SimpleChains.jl

main
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Code

Latest commit

 

Git stats

Files

Permalink
Failed to load latest commit information.
Type
Name
Latest commit message
Commit time
June 28, 2023 15:09
June 28, 2023 15:09
src
August 28, 2023 14:36
August 28, 2023 18:31
September 6, 2021 10:55
February 16, 2022 17:46
August 12, 2022 15:12

SimpleChains

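SimpleChains.jl only supports simple chains, that is, sequential stacks of layers, but it is intended to be fast for small problems on the CPU. Currently, `valgrad!` is the only means of extracting gradient information: `valgrad!(g, chain, x, p)` evaluates the loss of `chain` on the input `x` with parameters `p` and writes the gradient with respect to `p` into the preallocated buffer `g`.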
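
To make that calling convention concrete, here is a minimal plain-Julia sketch, assuming the pattern is "compute the loss, write the gradient into a caller-supplied buffer". The function `toy_valgrad!` and its linear model with squared-error loss are hypothetical and do not reflect how SimpleChains implements `valgrad!`; the point is only that preallocated buffers (`g` for the gradient, `r` for the residuals) let repeated calls reuse memory.

```julia
using LinearAlgebra

# Hypothetical sketch, not SimpleChains' implementation: a valgrad!-style
# function for the linear model ŷ = W * x with loss sum((ŷ .- y) .^ 2).
# It returns the loss and writes ∂loss/∂W into the preallocated buffer `g`;
# `r` is a preallocated buffer for the residuals.
function toy_valgrad!(g::AbstractMatrix, r::AbstractMatrix,
                      W::AbstractMatrix, x::AbstractMatrix, y::AbstractMatrix)
    mul!(r, W, x)               # r = W * x, computed in place
    r .-= y                     # r = W * x - y (residuals)
    loss = sum(abs2, r)         # squared-error loss
    mul!(g, r, x', 2.0, 0.0)    # g = 2 * r * x', the gradient with respect to W
    return loss
end

W = randn(2, 24)
x = rand(24, 200)
y = randn(2, 200)
g = similar(W)                  # gradient buffer, same shape as W
r = similar(y)                  # residual buffer, same shape as y
loss = toy_valgrad!(g, r, W, x, y)
```

The five-argument `mul!(C, A, B, α, β)` computes `C = A * B * α + C * β` in place, which is why no temporary is needed for the gradient.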

```julia
using SimpleChains, BenchmarkTools

# 24 covariates for each of 200 observations
x = rand(24, 200);

# 2 responses for each of 200 observations
y = Matrix{Float64}(undef, 2, 200) .= randn.() .* 10;

schain = SimpleChain(
  static(24), # input dimension (optional)
  TurboDense{true}(tanh, 8), # dense layer with bias that maps to 8 outputs and applies `tanh` activation
  SimpleChains.Dropout(0.2), # dropout layer
  TurboDense{false}(identity, 2), # dense layer without bias that maps to 2 outputs and `identity` activation
  SquaredLoss(y) # squared-error loss against the targets `y`
);

p = SimpleChains.init_params(schain) # initial parameters
g = similar(p); # buffer that valgrad! fills with the gradient

# Entirely in-place evaluation
@benchmark valgrad!($g, $schain, $x, $p) # dropout active
```
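
For readers new to these layer types, the following plain-Julia sketch spells out what this chain computes on a forward pass, with dropout omitted. The parameter names `W1`, `b1`, `W2` and the helper functions are hypothetical: SimpleChains keeps all parameters in the single array `p` returned by `init_params` and uses its own kernels. The constant factor in `SquaredLoss` is also an assumption here; consult its docstring for the exact definition.

```julia
# Hypothetical plain-Julia forward pass mirroring the chain above (dropout omitted).
# W1, b1, W2 are illustrative names, not SimpleChains' parameter layout.

# Like TurboDense{true}(tanh, 8): affine map with bias, then tanh.
dense_bias_tanh(W1, b1, x) = tanh.(W1 * x .+ b1)

# Like TurboDense{false}(identity, 2): linear map without bias, identity activation.
dense_nobias_identity(W2, h) = W2 * h

# Sum of squared errors; SquaredLoss may use a different constant factor.
squared_loss(ŷ, y) = sum(abs2, ŷ .- y)

x  = rand(24, 200)              # 24 inputs × 200 observations
y  = randn(2, 200) .* 10        # 2 responses × 200 observations
W1 = randn(8, 24) ./ sqrt(24)   # 8×24 weights of the first layer
b1 = zeros(8)                   # bias of the first layer
W2 = randn(2, 8) ./ sqrt(8)     # 2×8 weights of the second layer

h    = dense_bias_tanh(W1, b1, x)     # 8×200 hidden activations
ŷ    = dense_nobias_identity(W2, h)   # 2×200 predictions
loss = squared_loss(ŷ, y)
```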

For comparison, using Flux, we would write:

```julia
using Flux

chain = Chain(
  Dense(24, 8, tanh; bias = true),
  Flux.Dropout(0.2),
  Dense(8, 2, identity; bias = false)
);
chain.layers[2].active = true # activate dropout

ya = Array(y);

@benchmark gradient(Flux.params($chain)) do
  Flux.mse($chain($x), $ya)
end
```
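
The line `chain.layers[2].active = true` forces Flux's `Dropout` layer on, so that the Flux gradient, like the SimpleChains benchmark marked `# dropout active`, includes dropout. As a rough illustration of what an active dropout layer does, the sketch below implements standard inverted dropout: each entry is zeroed with probability `p` and the survivors are scaled by `1 / (1 - p)`. `toy_dropout` is a hypothetical helper; it does not reproduce either library's internals.

```julia
using Random

# Hypothetical sketch of standard inverted dropout (not Flux's or SimpleChains' code).
# When inactive it is the identity; when active it zeroes each entry with
# probability p and rescales the survivors by 1 / (1 - p).
function toy_dropout(rng::AbstractRNG, x::AbstractArray, p::Real; active::Bool = true)
    active || return x                  # inactive: pass input through unchanged
    keep = rand(rng, size(x)...) .>= p  # true with probability 1 - p
    return x .* keep ./ (1 - p)         # preserves the expected value of each entry
end

rng = MersenneTwister(0)
x = ones(8, 200)
x_train = toy_dropout(rng, x, 0.2)                  # roughly 20% zeros, the rest ≈ 1.25
x_eval  = toy_dropout(rng, x, 0.2; active = false)  # identical to x
```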

Benchmark results:

```julia
julia> @benchmark valgrad!($g, $schain, $x, $p) # dropout active
BenchmarkTools.Trial: 10000 samples with 6 evaluations.
 Range (min … max):  5.274 μs …  33.075 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     5.657 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   5.646 μs ± 349.777 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%
 Memory estimate: 0 bytes, allocs estimate: 0.

julia> @benchmark gradient(Flux.params($chain)) do
         Flux.mse($chain($x), $ya)
       end
BenchmarkTools.Trial: 10000 samples with 1 evaluations.
 Range (min … max):   83.674 μs …   4.865 ms  ┊ GC (min … max): 0.00% … 93.21%
 Time  (median):      96.430 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   106.897 μs ± 197.689 μs  ┊ GC (mean ± σ):  7.96% ±  4.22%
 Memory estimate: 182.55 KiB, allocs estimate: 316.
```
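
To put the two measurements side by side, the snippet below divides the Flux timings by the SimpleChains timings reported above. It uses only the numbers printed in these results, so the ratios describe this benchmark on the machine that produced it rather than a general speedup: roughly 15.9× at the minimum, 17.0× at the median and 18.9× at the mean, with `valgrad!` allocating 0 bytes against 182.55 KiB in 316 allocations for Flux.

```julia
# Timings copied from the benchmark output above, in microseconds.
simplechains = (min = 5.274, median = 5.657, mean = 5.646)
flux         = (min = 83.674, median = 96.430, mean = 106.897)

# Ratio of Flux time to SimpleChains time for each statistic.
speedup = map(/, flux, simplechains)

for (stat, ratio) in pairs(speedup)
    println(stat, ": ", round(ratio; digits = 2), "×")
end
# prints:
# min: 15.87×
# median: 17.05×
# mean: 18.93×
```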