/
attn.py
119 lines (98 loc) · 4.52 KB
/
attn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import numpy as np
from math import sqrt
class FullAttention(nn.Module):
    '''
    Scaled dot-product attention over inputs that are already split into heads.

    No masking is applied: every query attends to every key.

    Args:
        scale: multiplier applied to the raw query-key scores. ``None``
            selects the standard ``1/sqrt(E)``, where E is the per-head
            query/key width. Any explicit value, including ``0.0``, is used
            exactly as given.
        attention_dropout: dropout probability applied to the attention
            weights after the softmax.
    '''
    def __init__(self, scale=None, attention_dropout=0.1):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values):
        '''
        Args:
            queries: tensor of shape (B, L, H, E).
            keys: tensor of shape (B, S, H, E).
            values: tensor of shape (B, S, H, D).

        Returns:
            Contiguous tensor of shape (B, L, H, D). For each query position
            and head, it holds the attention-weighted sum of ``values`` over
            the S key positions.
        '''
        E = queries.shape[-1]
        # Use `is None` rather than `or`, so an explicit scale of 0.0 is
        # honoured instead of being replaced by the default.
        scale = 1. / sqrt(E) if self.scale is None else self.scale
        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        # Normalise over the key axis (S, the last dim of `scores`), then
        # apply dropout to the resulting attention weights.
        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)
        return V.contiguous()
class AttentionLayer(nn.Module):
    '''
    Multi-head attention layer (MSA).

    The layer projects queries/keys/values from ``d_model`` into ``n_heads``
    heads, runs ``FullAttention`` on them, and projects the concatenated head
    outputs back to ``d_model``.

    Args:
        d_model: feature width of the inputs and of the output.
        n_heads: number of attention heads.
        d_keys: per-head query/key width (defaults to ``d_model // n_heads``).
        d_values: per-head value width (defaults to ``d_model // n_heads``).
        mix: if True, transpose the attention output to head-major order
            before flattening the heads (see ``forward``).
        dropout: dropout probability on the attention weights.
    '''
    def __init__(self, d_model, n_heads, d_keys=None, d_values=None, mix=True, dropout=0.1):
        super(AttentionLayer, self).__init__()
        # Total projection widths across all heads. The `or` fallback is kept
        # lazy so `d_model // n_heads` is only evaluated when it is needed.
        key_width = (d_keys or d_model // n_heads) * n_heads
        value_width = (d_values or d_model // n_heads) * n_heads

        # Sub-modules are created in a fixed order, which keeps both the
        # parameter initialisation (RNG consumption) and the state_dict keys stable.
        self.inner_attention = FullAttention(scale=None, attention_dropout=dropout)
        self.query_projection = nn.Linear(d_model, key_width)
        self.key_projection = nn.Linear(d_model, key_width)
        self.value_projection = nn.Linear(d_model, value_width)
        self.out_projection = nn.Linear(value_width, d_model)
        self.n_heads = n_heads
        self.mix = mix

    def forward(self, queries, keys, values):
        '''
        Args:
            queries: tensor of shape (B, L, d_model).
            keys: tensor of shape (B, S, d_model).
            values: tensor of shape (B, S, d_model).

        Returns:
            Tensor of shape (B, L, d_model).
        '''
        batch, q_len, _ = queries.shape
        _, kv_len, _ = keys.shape
        heads = self.n_heads

        # Split each projection into heads: (B, len, H, per_head_width).
        q = self.query_projection(queries).view(batch, q_len, heads, -1)
        k = self.key_projection(keys).view(batch, kv_len, heads, -1)
        v = self.value_projection(values).view(batch, kv_len, heads, -1)

        attended = self.inner_attention(q, k, v)  # (B, L, H, D)
        if self.mix:
            # NOTE(review): the tensor becomes (B, H, L, D) and is then flattened
            # to (B, L, H*D), so flattened row l does not in general hold query
            # position l's heads. This is kept as-is because changing it would
            # change model outputs.
            attended = attended.transpose(2, 1).contiguous()
        return self.out_projection(attended.view(batch, q_len, -1))
class TwoStageAttentionLayer(nn.Module):
    '''
    Two Stage Attention (TSA) layer.

    Stage 1 (cross-time) runs multi-head self-attention along the segment axis,
    independently for every (sample, data dimension) pair.

    Stage 2 (cross-dimension) exchanges information between data dimensions
    through ``factor`` learnable router vectors per segment. The routers first
    attend over all dimensions (``dim_sender``); each dimension then attends
    over the routers (``dim_receiver``).

    Each stage is followed by a residual MLP, and every residual addition is
    followed by LayerNorm.

    Input/output shape: [batch_size, Data_dim(D), Seg_num(L), d_model].
    The input's Seg_num must equal the ``seg_num`` given here, because the
    router parameter is allocated per segment and tiled over the batch.
    '''
    def __init__(self, seg_num, factor, d_model, n_heads, d_ff=None, dropout=0.1):
        super(TwoStageAttentionLayer, self).__init__()
        hidden = d_ff or 4 * d_model

        # Creation order matters: it fixes parameter initialisation (RNG)
        # and the state_dict key order.
        self.time_attention = AttentionLayer(d_model, n_heads, dropout=dropout)
        self.dim_sender = AttentionLayer(d_model, n_heads, dropout=dropout)
        self.dim_receiver = AttentionLayer(d_model, n_heads, dropout=dropout)
        self.router = nn.Parameter(torch.randn(seg_num, factor, d_model))

        self.dropout = nn.Dropout(dropout)

        self.norm1, self.norm2, self.norm3, self.norm4 = (
            nn.LayerNorm(d_model) for _ in range(4)
        )

        def feed_forward():
            # Position-wise MLP: d_model -> hidden -> d_model.
            return nn.Sequential(
                nn.Linear(d_model, hidden),
                nn.GELU(),
                nn.Linear(hidden, d_model),
            )

        self.MLP1 = feed_forward()
        self.MLP2 = feed_forward()

    def forward(self, x):
        n_batch = x.shape[0]

        # Stage 1: cross-time attention, one sequence per (sample, dimension).
        per_dim = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model')
        attended = self.time_attention(per_dim, per_dim, per_dim)
        mixed = self.norm1(per_dim + self.dropout(attended))
        mixed = self.norm2(mixed + self.dropout(self.MLP1(mixed)))

        # Stage 2: cross-dimension attention, one sequence per (sample, segment).
        per_seg = rearrange(mixed, '(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model', b = n_batch)
        routers = repeat(self.router, 'seg_num factor d_model -> (repeat seg_num) factor d_model', repeat = n_batch)
        # The routers gather from every dimension; each dimension then reads
        # back from the routers.
        gathered = self.dim_sender(routers, per_seg, per_seg)
        distributed = self.dim_receiver(per_seg, gathered, gathered)
        out = self.norm3(per_seg + self.dropout(distributed))
        out = self.norm4(out + self.dropout(self.MLP2(out)))

        return rearrange(out, '(b seg_num) ts_d d_model -> b ts_d seg_num d_model', b = n_batch)