highway.py
"""
A [Highway layer](https://arxiv.org/abs/1505.00387) that does a gated combination of a linear
transformation and a non-linear transformation of its input.
"""
from typing import Callable
import torch
from overrides import overrides
class Highway(torch.nn.Module):
"""
A [Highway layer](https://arxiv.org/abs/1505.00387) does a gated combination of a linear
transformation and a non-linear transformation of its input. :math:`y = g * x + (1 - g) *
f(A(x))`, where :math:`A` is a linear transformation, :math:`f` is an element-wise
non-linearity, and :math:`g` is an element-wise gate, computed as :math:`sigmoid(B(x))`.
This module will apply a fixed number of highway layers to its input, returning the final
result.
# Parameters
input_dim : `int`, required
The dimensionality of :math:`x`. We assume the input has shape `(batch_size, ...,
input_dim)`.
num_layers : `int`, optional (default=`1`)
The number of highway layers to apply to the input.
activation : `Callable[[torch.Tensor], torch.Tensor]`, optional (default=`torch.nn.functional.relu`)
The non-linearity to use in the highway layers.
"""
    def __init__(
        self,
        input_dim: int,
        num_layers: int = 1,
        activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
    ) -> None:
        super().__init__()
        self._input_dim = input_dim
        self._layers = torch.nn.ModuleList(
            [torch.nn.Linear(input_dim, input_dim * 2) for _ in range(num_layers)]
        )
        self._activation = activation
        for layer in self._layers:
            # We should bias the highway layer to just carry its input forward. We do that by
            # setting the bias on `B(x)` to be positive, because that means `g` will be biased to
            # be high, so we will carry the input forward. The bias on `B(x)` is the second half
            # of the bias vector in each Linear layer.
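            # With this initialization each gate starts at sigmoid(1), roughly 0.73, so most of
            # the input is carried straight through before any training has happened.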
            layer.bias[input_dim:].data.fill_(1)

    @overrides
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        current_input = inputs
        for layer in self._layers:
            projected_input = layer(current_input)
            linear_part = current_input
            # NOTE: if you modify this, think about whether you should modify the initialization
            # above, too.
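            # The first half of the projection is `A(x)` (fed through the activation below); the
            # second half is `B(x)`, whose bias was set to 1 above to keep the gate open.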
            nonlinear_part, gate = projected_input.chunk(2, dim=-1)
            nonlinear_part = self._activation(nonlinear_part)
            gate = torch.sigmoid(gate)
            current_input = gate * linear_part + (1 - gate) * nonlinear_part
        return current_input
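

# A minimal usage sketch, assuming the default relu activation; the input dimension, layer
# count, and tensor shape below are illustrative. A stack of highway layers maps a tensor of
# shape (batch_size, ..., input_dim) to a tensor of the same shape.
if __name__ == "__main__":
    highway = Highway(input_dim=8, num_layers=2)
    x = torch.randn(4, 10, 8)  # (batch_size, sequence_length, input_dim)
    y = highway(x)
    assert y.shape == x.shape  # highway layers preserve the input dimensionality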