Skip to content

Commit 0df96f5

Browse files
authored
Merge 4bba04b into a7f0f07
2 parents a7f0f07 + 4bba04b commit 0df96f5

File tree

11 files changed

+580
-80
lines changed

11 files changed

+580
-80
lines changed

applications/llama_3.2_1b/configs/llama32_1b.json

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
"rope_base": 500000.0,
1313
"dtype": "bfloat16",
1414
"use_aie_final_norm": true,
15-
"use_aie_ffn_gemm": true,
16-
"use_aie_ffn_silu": true,
17-
"use_aie_ffn_mul": true,
15+
"use_aie_ffn_gemm": false,
16+
"use_aie_ffn_silu": false,
17+
"use_aie_ffn_mul": false,
18+
"use_aie_ffn_swiglu": true,
1819
"use_aie_attn_projection_gemm": true,
1920
"use_aie_rope": true,
2021
"use_aie_norm1": true,

applications/llama_3.2_1b/src/block/feed_forward.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
from src.operator.aie_gemm import AIEGEMM
1414
from src.operator.aie_gemv import AIEGEMV
1515
from src.operator.aie_silu import AIESiLU
16+
from src.operator.aie_swiglu_prefill import AIESwiGLUPrefill
17+
from src.operator.aie_swiglu_decode import AIESwiGLUDecode
18+
from ml_dtypes import bfloat16
1619

1720

1821
class FeedForward(nn.Module):
@@ -25,6 +28,16 @@ def __init__(
2528
super().__init__()
2629
self.cfg = cfg.copy()
2730

31+
assert (
32+
cfg["use_aie_ffn_swiglu"]
33+
and not (
34+
cfg["use_aie_ffn_silu"]
35+
or cfg["use_aie_ffn_gemm"]
36+
or cfg["use_aie_ffn_mul"]
37+
)
38+
or not cfg["use_aie_ffn_swiglu"]
39+
), "Cannot mix fused SwiGLU with individual AIE operators."
40+
2841
self.emb_dim = cfg["emb_dim"]
2942
self.hidden_dim = cfg["hidden_dim"]
3043

@@ -36,10 +49,17 @@ def __init__(
3649
else:
3750
self.silu = nn.SiLU()
3851

39-
self.emb_dim = cfg["emb_dim"]
40-
self.hidden_dim = cfg["hidden_dim"]
52+
if self.cfg["use_aie_ffn_swiglu"]:
53+
self.aie_swiglu_prefill = AIESwiGLUPrefill(
54+
seq_len=prompt_length,
55+
embedding_dim=self.emb_dim,
56+
hidden_dim=self.hidden_dim,
57+
)
58+
if self.cfg["use_kv_cache"]:
59+
self.aie_swiglu_decode = AIESwiGLUDecode(
60+
embedding_dim=self.emb_dim, hidden_dim=self.hidden_dim
61+
)
4162

42-
# Initialize FFN up and down projections
4363
if self.cfg["use_aie_ffn_gemm"]:
4464
if self.cfg["use_kv_cache"]:
4565
M_prefill = prompt_length
@@ -108,8 +128,15 @@ def forward(self, x):
108128
or (len(x.shape) == 3 and x.shape[0] == 1 and x.shape[1] == 1)
109129
)
110130

131+
is_prefill = not is_vector or not self.cfg["use_kv_cache"]
111132
is_decode_with_kv = is_vector and self.cfg["use_kv_cache"]
112133

134+
if self.cfg["use_aie_ffn_swiglu"]:
135+
if is_prefill:
136+
return self.aie_swiglu_prefill(x)
137+
else:
138+
return self.aie_swiglu_decode(x)
139+
113140
if is_decode_with_kv and self.cfg["use_aie_gemv"]:
114141
x_fc1 = self.aie_fc1_gemv(x)
115142
x_fc2 = self.aie_fc2_gemv(x)
@@ -131,6 +158,21 @@ def forward(self, x):
131158
return self.fc3(x).view(original_shape)
132159

133160
def assign_weights(self, l, fc1, fc2, fc3):
161+
if self.cfg["use_kv_cache"] and self.cfg["use_aie_gemv"]:
162+
self.aie_fc1_gemv.weight = fc1
163+
self.aie_fc2_gemv.weight = fc2
164+
self.aie_fc3_gemv.weight = fc3
165+
166+
if self.cfg["use_aie_ffn_swiglu"]:
167+
self.aie_swiglu_prefill.weights_1 = fc1
168+
self.aie_swiglu_prefill.weights_2 = fc2
169+
self.aie_swiglu_prefill.weights_3 = fc3
170+
if self.cfg["use_kv_cache"]:
171+
self.aie_swiglu_decode.weights_1 = fc1
172+
self.aie_swiglu_decode.weights_2 = fc2
173+
self.aie_swiglu_decode.weights_3 = fc3
174+
return
175+
134176
self.fc1.weight = assign(
135177
self.fc1.weight,
136178
fc1,
@@ -146,8 +188,3 @@ def assign_weights(self, l, fc1, fc2, fc3):
146188
fc3,
147189
f"model.layers.{l}.mlp.down_proj.weight",
148190
)
149-
150-
if self.cfg["use_kv_cache"] and self.cfg["use_aie_gemv"]:
151-
self.aie_fc1_gemv.weight = fc1
152-
self.aie_fc2_gemv.weight = fc2
153-
self.aie_fc3_gemv.weight = fc3

applications/llama_3.2_1b/src/compilation.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,13 @@ class SourceArtifact(CompilationArtifact):
113113

114114

115115
class XclbinArtifact(CompilationArtifact):
116-
def __init__(self, path, depends, kernel_name="MLIR_AIE", extra_flags=None):
116+
def __init__(
117+
self, path, depends, kernel_name="MLIR_AIE", extra_flags=None, xclbin_input=None
118+
):
117119
super().__init__(path, depends)
118120
self.kernel_name = kernel_name
119121
self.extra_flags = extra_flags if extra_flags is not None else []
122+
self.xclbin_input = xclbin_input
120123

121124

122125
class InstsBinArtifact(CompilationArtifact):
@@ -295,6 +298,10 @@ def compile(self, artifacts):
295298
"--xclbin-name=" + str(first_xclbin.path),
296299
"--xclbin-kernel-name=" + first_xclbin.kernel_name,
297300
]
301+
if first_xclbin.xclbin_input is not None:
302+
compile_cmd += [
303+
"--xclbin-input=" + str(first_xclbin.xclbin_input.path)
304+
]
298305
if do_compile_insts_bin:
299306
first_insts_bin = mlir_sources_to_insts_bins[mlir_source][
300307
0
@@ -414,7 +421,7 @@ def _rename_symbols(self, artifact):
414421
result = subprocess.run(cmd, capture_output=True, text=True)
415422

416423
if result.returncode == 0:
417-
logging.info(f"Successfully renamed symbols in: {artifact.path.name}")
424+
logging.debug(f"Successfully renamed symbols in: {artifact.path.name}")
418425
else:
419426
raise RuntimeError(f"Symbol renaming failed: {result.stderr}")
420427

applications/llama_3.2_1b/src/model_with_json.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def dtype_from_string(inp):
3737
"use_aie_ffn_gemm": (bool, False, "[FFN] GEMM"),
3838
"use_aie_ffn_mul": (bool, False, "[FFN] Elementwise Mul"),
3939
"use_aie_ffn_silu": (bool, False, "[FFN] SiLU"),
40+
"use_aie_ffn_swiglu": (bool, False, "[FFN] Runlist-based SwiGLU"),
4041
"use_aie_residual": (bool, False, "[Transformer] Residual Addition"),
4142
"use_aie_norm1": (bool, False, "[Transformer] Pre Norm"),
4243
"use_aie_norm2": (bool, False, "[Transformer] Post Norm"),
@@ -81,6 +82,14 @@ def format_option(name, value):
8182
dont_print |= {"use_aie_regular_mha"}
8283
else:
8384
dont_print |= {"use_aie_fused_mha"}
85+
if cfg["use_aie_ffn_swiglu"]:
86+
dont_print |= {
87+
"use_aie_ffn_gemm",
88+
"use_aie_ffn_mul",
89+
"use_aie_ffn_silu",
90+
}
91+
else:
92+
dont_print |= {"use_aie_ffn_swiglu"}
8493

8594
console.print(
8695
"AIE Configuration ([green]✔[/green] = AIE NPU / [red]✘[/red] = CPU):",

applications/llama_3.2_1b/src/operator/aie_base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ def prepare_runtime(cls):
6060
cls.static_data_pool[buffer_data] = bo
6161

6262
for op in cls.registered_operators:
63+
if len(op.kernels) == 0:
64+
# Operator likely is used as a sub-operator in another operator and does need any setup.
65+
continue
6366
logging.info(f"Preparing runtime for AIE operator: {op.__class__.__name__}")
6467

6568
# Set up for each kernel

applications/llama_3.2_1b/src/operator/aie_elementwise_mul.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,21 @@
1515
PythonGeneratedMLIRArtifact,
1616
)
1717
from ..utils import torch_to_numpy, numpy_to_torch
18+
from pathlib import Path
1819

1920

2021
class AIEElementwiseMul(AIEOperatorBase):
2122
"""AIE-accelerated element-wise multiplication"""
2223

23-
def __init__(self, size, num_columns=None, num_channels=None, tile_size=None):
24+
def __init__(
25+
self,
26+
size,
27+
num_columns=None,
28+
num_channels=None,
29+
tile_size=None,
30+
trace_size=0,
31+
do_set_up=True,
32+
):
2433
self.size = size
2534

2635
# Enforce ShimDMA limits for elementwise_mul (uses 2 inputs per core)
@@ -37,12 +46,13 @@ def __init__(self, size, num_columns=None, num_channels=None, tile_size=None):
3746
self.num_columns = num_columns
3847
self.num_channels = num_channels
3948
self.tile_size = tile_size
49+
self.trace_size = trace_size
50+
self.do_set_up = do_set_up
4051

4152
AIEOperatorBase.__init__(self)
4253

43-
def set_up(self):
44-
# Compilation artifacts
45-
file_name_base = f"mul_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t"
54+
def get_artifacts(self, prefix="eltwise_mul_"):
55+
file_name_base = f"{prefix}{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t"
4656

4757
mlir_artifact = PythonGeneratedMLIRArtifact.new(
4858
f"{file_name_base}.mlir",
@@ -57,7 +67,7 @@ def set_up(self):
5767
self.num_columns,
5868
self.num_channels,
5969
self.tile_size,
60-
0,
70+
self.trace_size,
6171
],
6272
)
6373

@@ -75,6 +85,20 @@ def set_up(self):
7585
f"{file_name_base}.bin", depends=[mlir_artifact]
7686
)
7787

88+
return xclbin_artifact, insts_artifact
89+
90+
def set_up(self):
91+
# If this operator is only used as a sub-operator in another operator that sets it up, we should skip the setup here as those artifacts and buffers may not be needed.
92+
if not self.do_set_up:
93+
return
94+
95+
# Compilation artifacts
96+
xclbin_artifact, insts_artifact = self.get_artifacts()
97+
98+
# Override device_type in the mlir_artifact's callback_args if needed
99+
mlir_artifact = xclbin_artifact.depends[0]
100+
mlir_artifact.callback_args[0] = self.device_manager.device_type
101+
78102
artifacts = [xclbin_artifact, insts_artifact]
79103
self.add_artifacts(artifacts)
80104

0 commit comments

Comments
 (0)