From 8637c7ab7c392f69f5fa058ea02937797b6849a2 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Fri, 5 Apr 2024 08:51:47 -0700 Subject: [PATCH] Added ELU class - closes #51 --- CHANGELOG.md | 4 ++++ lib/torch.rb | 1 + lib/torch/nn/elu.rb | 20 ++++++++++++++++++++ lib/torch/nn/functional.rb | 8 ++++++++ test/nn/activations_test.rb | 6 ++++++ 5 files changed, 39 insertions(+) create mode 100644 lib/torch/nn/elu.rb diff --git a/CHANGELOG.md b/CHANGELOG.md index 56a20df..0f7b3d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.15.1 (unreleased) + +- Added `ELU` class + ## 0.15.0 (2024-02-28) - Updated LibTorch to 2.2.0 diff --git a/lib/torch.rb b/lib/torch.rb index 4aa34ff..7d8aa2d 100644 --- a/lib/torch.rb +++ b/lib/torch.rb @@ -123,6 +123,7 @@ require_relative "torch/nn/feature_alpha_dropout" # nn activations +require_relative "torch/nn/elu" require_relative "torch/nn/hardshrink" require_relative "torch/nn/leaky_relu" require_relative "torch/nn/log_sigmoid" diff --git a/lib/torch/nn/elu.rb b/lib/torch/nn/elu.rb new file mode 100644 index 0000000..55305e5 --- /dev/null +++ b/lib/torch/nn/elu.rb @@ -0,0 +1,20 @@ +module Torch + module NN + class ELU < Module + def initialize(alpha: 1, inplace: false) + super() + @alpha = alpha + @inplace = inplace + end + + def forward(input) + F.elu(input, alpha: @alpha, inplace: @inplace) + end + + def extra_inspect + inplace_str = @inplace ? ", inplace: true" : "" + format("alpha: %s", @alpha) + inplace_str + end + end + end +end diff --git a/lib/torch/nn/functional.rb b/lib/torch/nn/functional.rb index 43c51fd..3c85d75 100644 --- a/lib/torch/nn/functional.rb +++ b/lib/torch/nn/functional.rb @@ -174,6 +174,14 @@ def pad(input, pad, mode: "constant", value: 0) # activation layers + def elu(input, alpha: 1, inplace: false) + if inplace + NN.elu!(input, alpha) + else + NN.elu(input, alpha) + end + end + def hardshrink(input, lambd = 0.5) Torch.hardshrink(input, lambd) end diff --git a/test/nn/activations_test.rb b/test/nn/activations_test.rb index 67064e2..73d3652 100644 --- a/test/nn/activations_test.rb +++ b/test/nn/activations_test.rb @@ -1,6 +1,12 @@ require_relative "../test_helper" class ActivationsTest < Minitest::Test + def test_elu + m = Torch::NN::ELU.new + input = Torch.randn(2) + _output = m.call(input) + end + def test_hardshrink m = Torch::NN::Hardshrink.new input = Torch.randn(2)