forked from alexisbellot/Graphical-modelling-continuous-time
-
Notifications
You must be signed in to change notification settings - Fork 0
/
locally_connected.py
91 lines (71 loc) · 2.74 KB
/
locally_connected.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import math
import torch
import torch.nn as nn
class LocallyConnected(nn.Module):
    """Local linear layer, i.e. Conv1dLocal() with filter size 1.

    Applies an independent linear map to each of the ``num_linear`` slices
    along dimension 1 of the input (one [m1, m2] weight per slice).

    Args:
        num_linear: num of local linear layers, i.e. d.
        input_features: m1, size of each input slice.
        output_features: m2, size of each output slice.
        bias: whether to include bias or not.

    Shape:
        - Input: [n, d, m1]
        - Output: [n, d, m2]

    Attributes:
        weight: [d, m1, m2]
        bias: [d, m2], or None when constructed with bias=False
    """

    def __init__(self, num_linear, input_features, output_features, bias=True):
        super(LocallyConnected, self).__init__()
        self.num_linear = num_linear
        self.input_features = input_features
        self.output_features = output_features

        self.weight = nn.Parameter(torch.Tensor(num_linear, input_features, output_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(num_linear, output_features))
        else:
            # You should always register all possible parameters, but the
            # optional ones can be None if you want; this keeps `self.bias`
            # defined either way.
            self.register_parameter("bias", None)
        self.reset_parameters()

    @torch.no_grad()
    def reset_parameters(self):
        """Initialize weight (and bias) uniformly in [-1/sqrt(m1), 1/sqrt(m1)].

        Same fan-in-based scheme nn.Linear uses for its default init.
        """
        k = 1.0 / self.input_features
        bound = math.sqrt(k)
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input_: torch.Tensor):
        """Apply the d independent linear maps: [n, d, m1] -> [n, d, m2]."""
        # [n, d, 1, m2] = [n, d, 1, m1] @ [1, d, m1, m2]
        # (matmul broadcasts the weight over the batch dimension n)
        out = torch.matmul(input_.unsqueeze(dim=2), self.weight.unsqueeze(dim=0))
        out = out.squeeze(dim=2)
        if self.bias is not None:
            # [n, d, m2] += [d, m2]
            out += self.bias
        return out

    def extra_repr(self):
        # (Optional) Extra information shown when printing this module.
        # BUG FIX: the attributes are named input_features/output_features;
        # the original referenced self.in_features/self.out_features, which
        # raised AttributeError whenever the module was printed.
        return "num_linear={}, in_features={}, out_features={}, bias={}".format(
            self.num_linear, self.input_features, self.output_features, self.bias is not None
        )
def main():
    """Smoke test: LocallyConnected must match a per-slice NumPy matmul."""
    import numpy as np

    n, d, m1, m2 = 2, 3, 5, 7

    # NumPy reference: one independent [m1, m2] matmul per slice j along d.
    x_np = np.random.randn(n, d, m1)
    w_np = np.random.randn(d, m1, m2)
    ref = np.stack([x_np[:, j, :] @ w_np[j, :, :] for j in range(d)], axis=1)

    # Torch side: load the exact same weights into the layer, then forward.
    torch.set_default_dtype(torch.double)
    layer = LocallyConnected(d, m1, m2, bias=False)
    layer.weight.data[:] = torch.from_numpy(w_np)
    out = layer(torch.from_numpy(x_np))

    # Expected to print True.
    print(torch.allclose(out, torch.from_numpy(ref)))


if __name__ == "__main__":
    main()