-
Notifications
You must be signed in to change notification settings - Fork 21
/
monarch_linear.py
52 lines (41 loc) · 2.12 KB
/
monarch_linear.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import math
import torch
import torch.nn as nn
from torch.nn import init
from einops import rearrange
from src.models.layers.structured_linear import StructuredLinear
from src.models.layers.blockdiag_butterfly_multiply import blockdiag_butterfly_multiply
from src.utils.utils import get_logger
# Module-level logger from the project's get_logger helper; MonarchLinear
# uses it to report its parameter saving when a layer is constructed.
logger = get_logger()
class MonarchLinear(StructuredLinear):
    """Linear layer whose weight is stored as two block-diagonal factors.

    Each factor is a stack of ``nblocks`` small matrices, held in the
    parameters ``blkdiag1`` and ``blkdiag2`` and applied to the input by
    ``blockdiag_butterfly_multiply``.  The feature sizes are rounded up to
    a multiple of ``nblocks`` and stored as ``in_features_extended`` and
    ``out_features_extended``.

    NOTE(review): the inherited ``preprocess`` / ``postprocess`` hooks
    presumably reconcile the real feature sizes with the rounded-up ones
    (pad on the way in, trim on the way out) -- confirm against
    ``StructuredLinear``.
    """

    def __init__(self, *args, nblocks=4, **kwargs):
        """Allocate the two block-diagonal factors and initialise them.

        Args:
            *args: Forwarded unchanged to ``StructuredLinear.__init__``.
            nblocks: Number of diagonal blocks in each factor.
            **kwargs: Forwarded unchanged to ``StructuredLinear.__init__``.
        """
        super().__init__(*args, **kwargs)

        # Per-block sizes, rounded up so nblocks blocks cover every feature.
        block_in = int(math.ceil(self.in_features / nblocks))
        block_out = int(math.ceil(self.out_features / nblocks))
        self.in_features_extended = block_in * nblocks
        self.out_features_extended = block_out * nblocks

        # The width shared by the two factors follows the input block size
        # when the layer expands and the output block size otherwise.
        expanding = self.in_features_extended < self.out_features_extended
        block_mid = block_in if expanding else block_out
        self.blkdiag1 = nn.Parameter(torch.empty(nblocks, block_mid, block_in))
        self.blkdiag2 = nn.Parameter(torch.empty(nblocks, block_out, block_mid))

        self.reset_parameters()
        logger.info(f'Linear class {self.__class__}: saving={self.saving}')

    def reset_parameters(self) -> None:
        """Re-draw both factors from a uniform distribution, then reset the bias.

        The bound for each factor is derived from a leaky-ReLU gain (negative
        slope ``sqrt(5)``) and a fan-in equal to the factor's last dimension.
        The bias is handled by the inherited ``reset_parameters_bias``.
        """
        # Mimic init.kaiming_uniform: https://github.com/pytorch/pytorch/blob/24087d07ca7ffa244575d259711dd7c99245a67a/torch/nn/init.py#L360
        leaky_gain = init.calculate_gain(nonlinearity='leaky_relu', param=math.sqrt(5))
        for factor in (self.blkdiag1, self.blkdiag2):
            std = leaky_gain / math.sqrt(factor.shape[-1])
            # Uniform limit that yields the standard deviation computed above.
            limit = math.sqrt(3.0) * std
            with torch.no_grad():
                factor.uniform_(-limit, limit)
        self.reset_parameters_bias()

    @property
    def saving(self):
        """Ratio of stored factor elements to ``in_features * out_features``."""
        stored = self.blkdiag1.numel() + self.blkdiag2.numel()
        dense = self.in_features * self.out_features
        return stored / dense

    def forward_matmul(self, x):
        """Apply both factors to ``x`` between the inherited pre/post hooks.

        Args:
            x: Input handed to ``self.preprocess`` before the multiply.

        Returns:
            The result of ``self.postprocess`` on the multiply output.
        """
        prepared = self.preprocess(x)
        product = blockdiag_butterfly_multiply(prepared, self.blkdiag1, self.blkdiag2)
        return self.postprocess(product)