gpt_neox.py
import torch
from torch import nn, einsum
from functools import partial
from einops import rearrange

# helpers

def exists(val):
    return val is not None

def cast_tuple(val, depth):
    if isinstance(val, tuple):
        return val
    return (val,) * depth
# classes

class PreNorm(nn.Module):
    def __init__(self, dim, norm_class, fn):
        super().__init__()
        self.fn = fn
        self.norm = norm_class(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)
# feedforward

class FeedForward(nn.Module):
    def __init__(self, dim, mult=4, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x, **kwargs):
        return self.net(x)
# attention

def dense_attn(q, k, v, attn_mask=None, dropout_fn=None):
    # standard scaled dot-product attention; q, k, v are (batch, heads, seq, dim_head)
    scale = q.shape[-1] ** -0.5
    sim = einsum('b h i d, b h j d -> b h i j', q, k) * scale

    if exists(attn_mask):
        # mask is additive: masked positions carry a large negative value
        sim = sim + attn_mask[None, None, :, :]

    attn = sim.softmax(dim=-1)

    if exists(dropout_fn):
        attn = dropout_fn(attn)

    out = einsum('b h i j, b h j d -> b h i d', attn, v)
    return out
class Attention(nn.Module):
    def __init__(self, dim, heads, seq_len, causal=True, dim_head=64, dropout=0., sparse_attn=False):
        super().__init__()
        inner_dim = heads * dim_head
        self.causal = causal
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.dropout = nn.Dropout(dropout)

        if sparse_attn:
            from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig
            sparsity_config = VariableSparsityConfig(
                num_heads=heads,
                attention=("unidirectional" if causal else "bidirectional")
            )
            self.attn_fn = SparseSelfAttention(
                sparsity_config=sparsity_config,
                max_seq_length=seq_len,
                attn_mask_mode='add'
            )
        else:
            self.attn_fn = partial(dense_attn, dropout_fn=self.dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, **kwargs):
        b, h, device = x.shape[0], self.heads, x.device
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        mask = None
        if self.causal:
            # build an additive causal mask: positions above the diagonal get a large negative value
            i, j = q.shape[-2], k.shape[-2]
            bool_mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
            mask = torch.zeros(i, j, device=device).to(q)
            mask_value = -torch.finfo(q.dtype).max
            mask.masked_fill_(bool_mask, mask_value)

        out = self.attn_fn(q, k, v, attn_mask=mask)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
class GPTNeoX(nn.Module):
    def __init__(self, *, num_tokens, dim, seq_len, depth, heads=8, dim_head=64, attn_dropout=0., ff_dropout=0., sparse_attn=False, use_fused_layernorm=False):
        super().__init__()
        if not use_fused_layernorm:
            norm_class = nn.LayerNorm
        else:
            from apex.normalization import FusedLayerNorm
            norm_class = FusedLayerNorm

        self.seq_len = seq_len

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(seq_len, dim)
        self.token_emb.weight.data.normal_(0, 0.02)
        self.pos_emb.weight.data.normal_(0, 0.02)

        self.layers = nn.ModuleList([])
        layers_sparse_attn = cast_tuple(sparse_attn, depth)

        for _, layer_sparse_attn in zip(range(depth), layers_sparse_attn):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, norm_class, Attention(dim=dim, heads=heads, seq_len=seq_len, dim_head=dim_head, dropout=attn_dropout, sparse_attn=layer_sparse_attn)),
                PreNorm(dim, norm_class, FeedForward(dim=dim, dropout=ff_dropout)),
            ]))

        self.norm = norm_class(dim)
        # output projection is tied to the token embedding weights
        self.to_logits = lambda t: t @ self.token_emb.weight.t()

    def forward(self, x, mask=None):
        n, device = x.shape[1], x.device
        x = self.token_emb(x)
        x = self.pos_emb(torch.arange(n, device=device)) + x

        for (attn, ff) in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        x = self.norm(x)
        return self.to_logits(x)
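
# a minimal usage sketch (illustrative only, not part of the original module);
# the hyperparameter values below are arbitrary assumptions:
#
#   model = GPTNeoX(num_tokens=20000, dim=512, seq_len=1024, depth=6, heads=8)
#   tokens = torch.randint(0, 20000, (1, 1024))   # (batch, seq_len) token ids
#   logits = model(tokens)                        # (1, 1024, 20000) logits over the vocabulary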