model.py
import sys

import torch
from torch import nn
from beartype import beartype
from beartype.typing import Optional, Union, Tuple, Dict, Any

# Import the vendored x-transformers fork shipped under third_party.
sys.path.append("../../../third_party/x-transformers")
from x_transformers import Encoder, Decoder, TransformerWrapper, AutoregressiveWrapper


class llama(nn.Module):
    """LLaMA-style decoder-only language model built on x-transformers."""

    @beartype
    def __init__(
        self,
        dim=4096,
        num_text_tokens=1024,
        text_max_seq_len=4096,
        decoder_depth=32,
        attn_dim_head=128,  # 128 = dim / attn_heads = 4096 / 32
        attn_heads=32,
        kv_heads=8,
        attn_layers_kwargs: dict = dict(),
        flash_attn=True,
        text_forgetful_causal_mask_prob=0.1,
        autoregressive_wrapper_kwargs: dict = dict(
            pad_value = 0,
            ignore_index = -100
        ),
        scaling=1
    ):
        super().__init__()
        # assert decoder_depth * scaling % 1 == 0, "llama layer num should be integer"
        self.decoder = TransformerWrapper(
            num_tokens = num_text_tokens + 1,
            max_seq_len = text_max_seq_len,
            attn_layers = Decoder(
                dim = dim,
                depth = int(decoder_depth * scaling),
                dim_head = attn_dim_head,
                heads = attn_heads,
                kv_heads = kv_heads,
                num_mem_kv = 0,
                cross_attend = False,
                rotary_pos_emb = True,
                flash_attn = flash_attn,
                **attn_layers_kwargs
            ).to(torch.bfloat16)
        )
        self.wrapped_decoder = AutoregressiveWrapper(
            self.decoder,
            mask_prob = text_forgetful_causal_mask_prob,
            **autoregressive_wrapper_kwargs
        )

    @beartype
    def forward(self, x, seq_len):
        # Autoregressive generation: `x` holds the prompt token ids and
        # `seq_len` is the number of new tokens to sample.
        x = self.wrapped_decoder.generate(x, seq_len=seq_len)
        return x


class llama_sliced(nn.Module):
    """Variant of `llama` whose decoder holds only one slice (1 / sliced_num) of the layers."""

    @beartype
    def __init__(
        self,
        dim=4096,
        num_text_tokens=1024,
        text_max_seq_len=4096,
        decoder_depth=32,
        attn_dim_head=128,  # 128 = dim / attn_heads = 4096 / 32
        attn_heads=32,
        attn_layers_kwargs: dict = dict(),
        flash_attn=True,
        text_forgetful_causal_mask_prob=0.1,
        autoregressive_wrapper_kwargs: dict = dict(
            pad_value = 0,
            ignore_index = -100
        ),
        sliced_id: int = 0,
        sliced_num: int = 1,
    ):
        assert decoder_depth % sliced_num == 0, \
            f"decoder_depth ({decoder_depth}) is not divisible by sliced_num ({sliced_num})"
        super().__init__()
        self.decoder = TransformerWrapper(
            num_tokens = num_text_tokens + 1,
            max_seq_len = text_max_seq_len,
            attn_layers = Decoder(
                dim = dim,
                depth = int(decoder_depth / sliced_num),
                dim_head = attn_dim_head,
                heads = attn_heads,
                num_mem_kv = 1,
                cross_attend = False,
                rotary_pos_emb = True,
                flash_attn = flash_attn,
                **attn_layers_kwargs
            )
        )
        self.wrapped_decoder = AutoregressiveWrapper(
            self.decoder,
            mask_prob = text_forgetful_causal_mask_prob,
            **autoregressive_wrapper_kwargs
        )

    @beartype
    def forward(self, x, seq_len):
        # Autoregressive generation with the sliced decoder.
        x = self.wrapped_decoder.generate(x, seq_len=seq_len)
        return x
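

# A minimal smoke-test sketch (not part of the original file). It assumes the
# vendored x-transformers fork exposes the same TransformerWrapper /
# AutoregressiveWrapper behavior as upstream, and it shrinks every
# hyperparameter to illustrative toy sizes so the model can run on CPU.
# `llama` is constructed the same way, but note its attention layers are cast
# to bfloat16, so run it with dtypes/hardware that support bfloat16.
if __name__ == "__main__":
    model = llama_sliced(
        dim = 64,
        num_text_tokens = 256,
        text_max_seq_len = 128,
        decoder_depth = 4,
        attn_dim_head = 16,
        attn_heads = 4,
        flash_attn = False,  # skip flash attention for this CPU-only sketch
        sliced_id = 0,
        sliced_num = 2,      # this slice keeps 4 / 2 = 2 decoder layers
    )
    prompt = torch.randint(0, 256, (1, 8))   # batch of one 8-token prompt
    generated = model(prompt, seq_len=16)    # sample 16 new tokens autoregressively
    print(generated.shape)                   # torch.Size([1, 16]) with upstream generate() semantics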