/
activations.py
63 lines (53 loc) · 2.29 KB
/
activations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
from torch import nn, sin, pow
from torch.nn import Parameter
from torch.distributions.exponential import Exponential
class Snake(nn.Module):
    r'''
    Implementation of the serpentine-like sine-based periodic activation function:

    .. math::
        Snake_a := x + \frac{1}{a} \sin^2(ax) = x - \frac{1}{2a}\cos(2ax) + \frac{1}{2a}

    This activation function is able to better extrapolate to previously unseen data,
    especially in the case of learning periodic functions.

    Shape:
        - Input: (N, *) where * means any number of additional dimensions
        - Output: (N, *), same shape as the input

    Parameters:
        - a - per-feature frequency parameter of shape ``in_features``
          (trained with the model unless ``trainable=False``)

    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195

    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''

    def __init__(self, in_features, a=None, trainable=True):
        '''
        Initialization.

        Args:
            in_features: shape of the parameter `a`; an int, or a list/tuple of
                ints. `a` is broadcast against the input in `forward`, so this
                shape must be broadcast-compatible with the input.
            a: initial value for every entry of `a`. If None (default), each
                entry is drawn independently from an exponential distribution
                with rate 0.1 (mean 10), i.e. a random mix of frequencies.
            trainable: if True (default), `a` requires gradients and is trained
                along with the rest of your model; if False it stays fixed.

        Higher values of `a` mean higher frequency. 5-50 is a good starting
        point if you already think your data is periodic; consider starting
        lower, e.g. 0.5, if you think not.
        '''
        super().__init__()
        # Normalize to a list of dims; tuples (e.g. torch.Size) are accepted too.
        if isinstance(in_features, (list, tuple)):
            self.in_features = list(in_features)
        else:
            self.in_features = [in_features]

        # Initialize `a`
        if a is not None:
            init = torch.ones(self.in_features) * a  # constant tensor filled with `a`
        else:
            m = Exponential(torch.tensor([0.1]))
            # rsample(shape) returns shape + batch_shape == in_features + [1].
            # Drop only that trailing batch dim so `a` has exactly the shape
            # in_features; a bare squeeze() would also drop size-1 feature dims.
            init = m.rsample(self.in_features).squeeze(-1)

        # requires_grad is the attribute autograd honours; it decides whether
        # `a` is updated during training.
        self.a = Parameter(init, requires_grad=trainable)

    def forward(self, x):
        '''
        Forward pass of the function.

        Applies the function to the input elementwise, broadcasting `a`
        against `x`:
            Snake := x + 1/a * sin^2(x * a)

        NOTE: there is no guard against zero entries in `a`; 1.0 / a yields
        inf (and possibly nan) outputs if any entry of `a` becomes 0.
        '''
        return x + (1.0 / self.a) * pow(sin(x * self.a), 2)