import math
import numpy as np
import torch
from torch.nn import functional as F
from einops import rearrange


def blockdiag_butterfly_multiply_reference(x, w1_bfly, w2_bfly, version=2):
    """
    This implementation is slow but more likely to be correct.
    There are 3 implementations, which should all yield the same answer.
    Arguments:
        x: (batch, n)
        w1_bfly: (k, q, p), where k = n / p
        w2_bfly: (l, s, r), where l = k * q / r = n * q / (p * r)
    Outputs:
        out: (batch, m), where m = l * s = n * s * q / (p * r)
    """
if version not in [1, 2, 3]:
raise NotImplementedError('version must be either 1, 2, or 3')
batch, n = x.shape
k, q, p = w1_bfly.shape
l, s, r = w2_bfly.shape
assert k * p == n
assert l * r == k * q
x_reshaped = rearrange(x, 'b (k p) -> b k p', k=k)
    if version == 1:  # Implementation 1 (only works when k = q = p = l = s = r = sqrt(n))
assert k == q == p == l == s == r == int(math.sqrt(n))
return torch.einsum('bkp,kqp,qlk->blq', x_reshaped, w1_bfly, w2_bfly).reshape(batch, n)
elif version == 2: # Implementation 2
out1 = torch.einsum('kqp,bkp->bkq', w1_bfly, x_reshaped)
out1 = rearrange(rearrange(out1, 'b k q -> b (k q)'), 'b (r l) -> b l r', l=l)
return torch.einsum('lsr,blr->bsl', w2_bfly, out1).reshape(batch, s * l)
# Implementation 3: most likely to be correct, but it's the slowest
elif version == 3:
w1_dense = torch.block_diag(*torch.unbind(w1_bfly, dim=0))
out1 = F.linear(x, w1_dense)
out1 = rearrange(out1, 'b (r l) -> b (l r)', l=l)
w2_dense = torch.block_diag(*torch.unbind(w2_bfly, dim=0))
out2 = F.linear(out1, w2_dense)
out2 = rearrange(out2, 'b (l s) -> b (s l)', l=l)
return out2
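

# Illustrative sanity check (added here; the helper name is hypothetical and
# not part of the original module): in the square case
# k = q = p = l = s = r = sqrt(n), all three reference versions above should
# produce the same output. Shapes are arbitrary example values.
def _reference_versions_agree(batch=2, n=16, atol=1e-5):
    sqrt_n = int(math.sqrt(n))
    x = torch.randn(batch, n)
    w1_bfly = torch.randn(sqrt_n, sqrt_n, sqrt_n)  # (k, q, p)
    w2_bfly = torch.randn(sqrt_n, sqrt_n, sqrt_n)  # (l, s, r)
    outs = [blockdiag_butterfly_multiply_reference(x, w1_bfly, w2_bfly, version=v)
            for v in (1, 2, 3)]
    return all(torch.allclose(outs[0], out, atol=atol) for out in outs[1:])

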
class BlockdiagButterflyMultiply(torch.autograd.Function):
"""This is a faster implementation, with careful memory copies for the fastest
bmm performance.
The backward pass is also written manually with careful memory copies.
Arguments:
x: (batch, n)
w1_bfly: (k, q, p), where k = n / p
w2_bfly: (l, s, r), where l = k * q / r = n * q / (p * r)
Outputs:
out: (batch, m), where m = l * s = n * s * q / (p * r)
"""

    @staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float16)
def forward(ctx, x, w1_bfly, w2_bfly):
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = np.prod(batch_shape)
k, q, p = w1_bfly.shape
l, s, r = w2_bfly.shape
assert k * p == n
assert l * r == k * q
        # First block-diagonal multiply: split x into k chunks of size p and
        # apply each (q, p) block of w1_bfly with a single batched matmul.
        x_reshaped = x.reshape(batch_dim, k, p).transpose(0, 1)
        out1 = torch.empty(batch_dim, k, q, device=x.device, dtype=x.dtype).transpose(0, 1)
        out1 = torch.bmm(x_reshaped, w1_bfly.transpose(-1, -2), out=out1)
        # Butterfly permutation: reinterpret the (k q) intermediate as (r l)
        # and transpose to (l r) so the second multiply mixes across blocks.
        out1 = out1.transpose(0, 1).reshape(batch_dim, r, l).transpose(-1, -2).contiguous().transpose(0, 1)
        # Second block-diagonal multiply with the (s, r) blocks of w2_bfly.
        out2 = torch.empty(batch_dim, l, s, device=x.device, dtype=x.dtype).transpose(0, 1)
        out2 = torch.bmm(out1, w2_bfly.transpose(-1, -2), out=out2)
        out2 = out2.permute(1, 2, 0).reshape(*batch_shape, s * l)
ctx.save_for_backward(x, w1_bfly, w2_bfly, out1)
return out2

    @staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, dout):
x, w1_bfly, w2_bfly, out1 = ctx.saved_tensors
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = np.prod(batch_shape)
k, q, p = w1_bfly.shape
l, s, r = w2_bfly.shape
# assert k * p == n
# assert l * r == k * q
dx, dw1_bfly, dw2_bfly = None, None, None
# dout_reshaped = dout.reshape(batch_dim, sqrtn, sqrtn).permute(2, 1, 0).contiguous()
        # Undo the output reshuffle: view dout as (batch, s, l) per block of
        # w2_bfly, then move the block dimension l to the front for bmm.
        dout_reshaped = dout.reshape(batch_dim, s, l).transpose(-1, -2).contiguous()
        dout_reshaped = dout_reshaped.transpose(0, 1)
if ctx.needs_input_grad[2]:
# dw2_bfly = torch.empty(l, s, r, device=w2_bfly.device, dtype=w2_bfly.dtype)
# dw2_bfly = torch.bmm(dout_reshaped.transpose(-1, -2), out1, out=dw2_bfly)
            # Gradient w.r.t. w2_bfly, accumulated over the batch per block.
            dw2_bfly = torch.bmm(dout_reshaped.transpose(-1, -2), out1.conj())
if ctx.needs_input_grad[1] or ctx.needs_input_grad[0]:
            # Backprop through the second bmm, then invert the butterfly
            # permutation to recover the (k, q) layout of the intermediate.
            dout1 = torch.empty(batch_dim, l, r, device=x.device, dtype=x.dtype).transpose(0, 1)
            dout1 = torch.bmm(dout_reshaped, w2_bfly.conj(), out=dout1)
            dout1 = dout1.transpose(0, 1).transpose(-1, -2).contiguous().reshape(batch_dim, k, q).transpose(0, 1)
# dout1 = dout1.permute(1, 2, 0).contiguous().transpose(0, 1)
if ctx.needs_input_grad[0]:
            # Backprop through the first bmm back to the input layout.
            dx = torch.empty(batch_dim, k, p, device=x.device, dtype=x.dtype)
            dx = torch.bmm(dout1, w1_bfly.conj(), out=dx.transpose(0, 1)).transpose(0, 1).reshape(*batch_shape, n)
if ctx.needs_input_grad[1]:
            x_reshaped = x.reshape(batch_dim, k, p).transpose(0, 1)
            # Gradient w.r.t. w1_bfly, accumulated over the batch per block.
            dw1_bfly = torch.bmm(dout1.transpose(-1, -2), x_reshaped.conj())
return dx, dw1_bfly, dw2_bfly


blockdiag_butterfly_multiply = BlockdiagButterflyMultiply.apply
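

# Minimal usage sketch (added for illustration; shapes are arbitrary example
# values, chosen so that k * p == n and l * r == k * q). The fast autograd
# path should match reference version 3, and gradcheck exercises the
# hand-written backward in double precision on CPU.
if __name__ == '__main__':
    torch.manual_seed(0)
    batch, n = 3, 12
    k, q, p = 4, 5, 3   # k * p == n
    l, s, r = 10, 6, 2  # l * r == k * q
    x = torch.randn(batch, n, dtype=torch.float64, requires_grad=True)
    w1_bfly = torch.randn(k, q, p, dtype=torch.float64, requires_grad=True)
    w2_bfly = torch.randn(l, s, r, dtype=torch.float64, requires_grad=True)
    out_fast = blockdiag_butterfly_multiply(x, w1_bfly, w2_bfly)
    out_ref = blockdiag_butterfly_multiply_reference(x, w1_bfly, w2_bfly, version=3)
    print('forward matches reference:', torch.allclose(out_fast, out_ref))
    print('gradcheck passes:', torch.autograd.gradcheck(
        blockdiag_butterfly_multiply, (x, w1_bfly, w2_bfly)))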