Skip to content

Commit

Permalink
Added ELU class - closes #51
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Apr 5, 2024
1 parent a982efb commit 8637c7a
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 0 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 0.15.1 (unreleased)

- Added `ELU` class

## 0.15.0 (2024-02-28)

- Updated LibTorch to 2.2.0
Expand Down
1 change: 1 addition & 0 deletions lib/torch.rb
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@
require_relative "torch/nn/feature_alpha_dropout"

# nn activations
require_relative "torch/nn/elu"
require_relative "torch/nn/hardshrink"
require_relative "torch/nn/leaky_relu"
require_relative "torch/nn/log_sigmoid"
Expand Down
20 changes: 20 additions & 0 deletions lib/torch/nn/elu.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module Torch
module NN
class ELU < Module
def initialize(alpha: 1, inplace: false)
super()
@alpha = alpha
@inplace = inplace
end

def forward(input)
F.elu(input, alpha: @alpha, inplace: @inplace)
end

def extra_inspect
inplace_str = @inplace ? ", inplace: true" : ""
format("alpha: %s", @alpha) + inplace_str
end
end
end
end
8 changes: 8 additions & 0 deletions lib/torch/nn/functional.rb
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,14 @@ def pad(input, pad, mode: "constant", value: 0)

# activation layers

def elu(input, alpha: 1, inplace: false)
if inplace
NN.elu!(input, alpha)
else
NN.elu(input, alpha)
end
end

def hardshrink(input, lambd = 0.5)
Torch.hardshrink(input, lambd)
end
Expand Down
6 changes: 6 additions & 0 deletions test/nn/activations_test.rb
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
require_relative "../test_helper"

class ActivationsTest < Minitest::Test
def test_elu
m = Torch::NN::ELU.new
input = Torch.randn(2)
_output = m.call(input)
end

def test_hardshrink
m = Torch::NN::Hardshrink.new
input = Torch.randn(2)
Expand Down

0 comments on commit 8637c7a

Please sign in to comment.