
Commit 818b783

Merge b53e764 into 130b6ea
2 parents (130b6ea + b53e764), merge commit 818b783

12 files changed: +455 −162 lines changed


applications/llama_3.2_1b/configs/llama32_1b.json

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@
   "use_aie_residual": true,
   "use_aie_regular_mha": false,
   "use_aie_fused_mha": true,
-  "use_aie_final_gemm": false,
+  "use_aie_final_gemm": true,
   "rope_freq": {
     "factor": 32.0,
     "low_freq_factor": 1.0,

applications/llama_3.2_1b/inference.py

Lines changed: 1 addition & 1 deletion
@@ -400,7 +400,7 @@ def set_prefill_time():
     parser.add_argument(
         "--prompt_len",
         type=int,
-        default=64,
+        default=2048,
         help="Truncate prompt to this many tokens.",
     )
     parser.add_argument(

applications/llama_3.2_1b/src/block/gqa.py

Lines changed: 11 additions & 12 deletions
@@ -163,38 +163,37 @@ def forward(self, x, mask, angles, input_pos=None):
             # Decode phase with KV cache - use GEMV for single token
             # weight.T @ input, which is vector-matrix multiplication (So, is_mv=False)
             x_flat = x.reshape(1, -1)  # Shape: (1, d_in)
-            input_dtype = x.dtype
 
             queries_flat = self.aie_query_gemv(x_flat)
-            queries = queries_flat.reshape(b, num_tokens, self.d_out).to(input_dtype)
+            queries = queries_flat.reshape(b, num_tokens, self.d_out)
 
             keys_flat = self.aie_key_gemv(x_flat)
             keys = keys_flat.reshape(
                 b, num_tokens, self.num_kv_groups * self.head_dim
-            ).to(input_dtype)
+            )
 
             values_flat = self.aie_value_gemv(x_flat)
             values = values_flat.reshape(
                 b, num_tokens, self.num_kv_groups * self.head_dim
-            ).to(input_dtype)
+            )
 
         elif self.cfg["use_aie_attn_projection_gemm"]:
             # Prefill phase - use GEMM for multiple tokens
             x_flat = x.reshape(-1, d_in)
             input_dtype = x.dtype
 
             queries_flat = self.aie_query(x_flat)
-            queries = queries_flat.reshape(b, num_tokens, self.d_out).to(input_dtype)
+            queries = queries_flat.reshape(b, num_tokens, self.d_out)
 
             keys_flat = self.aie_key(x_flat)
             keys = keys_flat.reshape(
                 b, num_tokens, self.num_kv_groups * self.head_dim
-            ).to(input_dtype)
+            )
 
             values_flat = self.aie_value(x_flat)
             values = values_flat.reshape(
                 b, num_tokens, self.num_kv_groups * self.head_dim
-            ).to(input_dtype)
+            )
         else:
             queries = self.W_query(x)
             keys = self.W_key(x)
@@ -348,9 +347,9 @@ def apply_rope_and_transpose(tensor, num_heads_dim, angle_slice):
         def my_mha(queries, keys, values):
             inv_scale = 1 / np.sqrt(values.shape[-1])
             context_vec = torch.nn.functional.scaled_dot_product_attention(
-                queries.to(torch.bfloat16).to("cpu"),
-                keys.to(torch.bfloat16).to("cpu"),
-                values.to(torch.bfloat16).to("cpu"),
+                queries,
+                keys,
+                values,
                 dropout_p=0.0,
                 is_causal=True,
                 scale=inv_scale,
@@ -384,11 +383,11 @@ def my_mha(queries, keys, values):
         if self.cfg["use_kv_cache"] and is_decode and self.cfg["use_aie_gemv"]:
             context_vec_flat = context_vec.reshape(1, -1)
             output_flat = self.aie_out_proj_gemv(context_vec_flat)
-            context_vec = output_flat.reshape(b, num_tokens, self.d_out).to(input_dtype)
+            context_vec = output_flat.reshape(b, num_tokens, self.d_out)
         elif self.cfg["use_aie_attn_projection_gemm"]:
             context_vec_flat = context_vec.reshape(-1, self.d_out)
             output_flat = self.aie_out_proj(context_vec_flat)
-            context_vec = output_flat.reshape(b, num_tokens, self.d_out).to(input_dtype)
+            context_vec = output_flat.reshape(b, num_tokens, self.d_out)
         else:
             context_vec = self.out_proj(context_vec)

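Note on the deletions above: they drop the `.to(input_dtype)` casts after the AIE projections and the `.to(torch.bfloat16).to("cpu")` casts around `scaled_dot_product_attention`. A minimal standalone sketch (not repository code) of why removing a matching-dtype/device cast is behavior-preserving:

import torch

# Stand-in for an AIE projection output: already bfloat16 and already on the CPU.
x = torch.randn(1, 2048, dtype=torch.bfloat16)

# .to() with the tensor's current dtype/device is a no-op that returns the same object,
# so dropping such casts cannot change results when dtypes and devices already match.
assert x.to(torch.bfloat16) is x
assert x.to("cpu") is x

# scaled_dot_product_attention accepts bfloat16 inputs directly.
q = k = v = torch.randn(1, 8, 4, 64, dtype=torch.bfloat16)
out = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, dropout_p=0.0, is_causal=True, scale=1 / (64 ** 0.5)
)
print(out.dtype)  # torch.bfloat16
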
applications/llama_3.2_1b/src/model_with_json.py

Lines changed: 79 additions & 16 deletions
@@ -12,7 +12,10 @@
 from pathlib import Path
 from src.block.transformer import TransformerBlock
 from operators.rope.rope_utils import compute_rope_params
-from operators import AIERMSNorm
+from operators import (
+    AIERMSNorm,
+    AIEGEMM,
+)
 from rich.console import Console
 from rich.text import Text
 
@@ -169,13 +172,37 @@ def __init__(
             self.cfg["emb_dim"], eps=1e-5, dtype=self.cfg["dtype"]
         )
 
-        # Depedns on use_aie_final_gemm
-        self.out_head = nn.Linear(
-            self.cfg["emb_dim"],
-            self.cfg["vocab_size"],
-            bias=False,
-            dtype=self.cfg["dtype"],
-        )
+        # Offload final linear layer if enabled
+        if self.cfg.get("use_aie_final_gemm", False):
+            # Since this GEMM has such a large N dimension, partition the N dimension by 4,
+            # and GEMM will execute for a workload of that smaller N dimension across different buffers of B and C
+            aie_config_prefill = {
+                "num_aie_columns": 8,
+                "tile_m": 64,
+                "tile_k": 64,
+                "tile_n": 64,
+                "b_col_maj": True,
+                "use_static_weight": True,
+                "separate_c_tiles": True,
+                "partition_N": 4,
+            }
+            if self.cfg["use_kv_cache"]:
+                M_for_gemm = self.prompt_length
+            else:
+                M_for_gemm = self.prompt_length + self.num_tokens
+            self.out_head = AIEGEMM(
+                M=M_for_gemm,
+                K=self.cfg["emb_dim"],
+                N=self.cfg["vocab_size"],
+                **aie_config_prefill,
+            )
+        else:
+            self.out_head = nn.Linear(
+                self.cfg["emb_dim"],
+                self.cfg["vocab_size"],
+                bias=False,
+                dtype=self.cfg["dtype"],
+            )
 
         # Reusable utilities
         cos, sin = compute_rope_params(
@@ -194,6 +221,22 @@ def forward(self, in_idx, input_pos=None, use_kv_cache=False):
         tok_embeds = self.tok_emb(in_idx)
         x = tok_embeds
 
+        # Check if input is a vector (decode phase) or matrix (prefill phase)
+        # Handle 1D: (emb_dim,), 2D: (1, emb_dim), or 3D: (1, 1, emb_dim)
+        is_vector = (
+            len(x.shape) == 1
+            or (len(x.shape) == 2 and x.shape[0] == 1)
+            or (len(x.shape) == 3 and x.shape[0] == 1 and x.shape[1] == 1)
+        )
+
+        # (batch, sequence, embedding) where sequence=1 indicates decode
+        if len(x.shape) == 3:
+            is_decode_with_kv = (x.shape[1] == 1) and self.cfg["use_kv_cache"]
+        elif len(x.shape) == 2:
+            is_decode_with_kv = (x.shape[0] == 1) and self.cfg["use_kv_cache"]
+        else:
+            is_decode_with_kv = False
+
         num_tokens = x.shape[1]
 
         # During generation phase with KV cache, don't create a mask
@@ -219,19 +262,39 @@ def forward(self, in_idx, input_pos=None, use_kv_cache=False):
         else:
             x = self.final_norm(x)
 
-        logits = self.out_head(x.to(self.cfg["dtype"]))
+        if is_decode_with_kv and self.cfg["use_aie_gemv"]:
+            # TODO: Offload to NPU
+            # logits = self.aie_out_head_gemv(x)
+            logits = self.out_head(x)
+        else:
+            logits = self.out_head(x)
 
         return logits
 
-    def assign_weights(self, final_norm):
+    def assign_weights(self, final_norm, out_head, out_head_name):
         if self.cfg.get("use_aie_final_norm", False):
            self.aie_final_norm_prefill.weight = final_norm
            if self.cfg["use_kv_cache"]:
                self.aie_final_norm_decode.weight = final_norm
-            return
+        else:
+            self.final_norm.weight = assign(
+                self.final_norm.weight,
+                final_norm,
+                f"model.norm.weight",
+            )
 
-        self.final_norm.weight = assign(
-            self.final_norm.weight,
-            final_norm,
-            f"model.norm.weight",
-        )
+        # TODO: Offload GEMV to NPU
+        # if self.cfg["use_kv_cache"] and self.cfg["use_aie_gemv"]:
+        #     self.aie_out_head_gemv.weight = out_head
+        if self.cfg["use_aie_final_gemm"]:
+            # Want column-major for B
+            self.out_head.weight = out_head.T
+            # TODO: Create separate linear layers for prefill and decode (with gemm/gemv)
+            # if self.cfg["use_kv_cache"]:
+            #     self.out_head.weight = out_head.T
+        else:
+            self.out_head.weight = assign(
+                self.out_head.weight,
+                out_head,
+                out_head_name,
+            )

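The forward pass now distinguishes decode from prefill by input shape before choosing the output-head path. A standalone sketch of the same shape rule (illustrative values only; Llama 3.2 1B's emb_dim is 2048), assuming the KV cache is enabled:

import torch

def is_decode_with_kv(x, use_kv_cache=True):
    # (batch, sequence, embedding): sequence length 1 means a single decode token
    if len(x.shape) == 3:
        return x.shape[1] == 1 and use_kv_cache
    # (sequence, embedding): a single row is one decode token
    if len(x.shape) == 2:
        return x.shape[0] == 1 and use_kv_cache
    return False

emb_dim = 2048
print(is_decode_with_kv(torch.empty(1, 1, emb_dim)))   # True  -> decode path
print(is_decode_with_kv(torch.empty(1, 64, emb_dim)))  # False -> prefill path
print(is_decode_with_kv(torch.empty(1, emb_dim)))      # True  -> decode path

With the KV cache enabled, only the prefill pass reaches the AIEGEMM head (the decode path is earmarked for a future GEMV offload in the TODOs above), which is presumably why M_for_gemm is sized to the prompt length in that case.
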
applications/llama_3.2_1b/src/utils.py

Lines changed: 2 additions & 10 deletions
@@ -126,18 +126,10 @@ def load_weights_into_llama(model, param_config, params):
     )
 
     # Load output layer weights
-    model.assign_weights(params["model.norm.weight"])
-
     if "lm_head.weight" in params.keys():
-        model.out_head.weight = assign(
-            model.out_head.weight, params["lm_head.weight"], "lm_head.weight"
-        )
+        model.assign_weights(params["model.norm.weight"], params["lm_head.weight"], "lm_head.weight")
     else:
-        model.out_head.weight = assign(
-            model.out_head.weight,
-            params["model.embed_tokens.weight"],
-            "model.embed_tokens.weight",
-        )
+        model.assign_weights(params["model.norm.weight"], params["model.embed_tokens.weight"], "model.embed_tokens.weight")
 
 
 def text_to_token_ids(text, tokenizer):

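The loader now routes both the final-norm and output-head tensors through `model.assign_weights`; the `else` branch covers tied-embedding checkpoints, where `model.embed_tokens.weight` stands in for a missing `lm_head.weight`. A toy sketch of that dispatch (hypothetical dicts and shapes, not the real checkpoint):

import torch

emb_dim, vocab = 16, 32
tied = {
    "model.norm.weight": torch.ones(emb_dim),
    "model.embed_tokens.weight": torch.randn(vocab, emb_dim),
}
untied = dict(tied, **{"lm_head.weight": torch.randn(vocab, emb_dim)})

for name, params in [("tied", tied), ("untied", untied)]:
    head_name = (
        "lm_head.weight" if "lm_head.weight" in params else "model.embed_tokens.weight"
    )
    print(name, "->", head_name)
# tied   -> model.embed_tokens.weight  (the embedding matrix doubles as the output head)
# untied -> lm_head.weight
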
operators/common/aie_base.py

Lines changed: 45 additions & 8 deletions
@@ -98,18 +98,48 @@ def prepare_runtime(cls):
             )
 
             # If multiple buffers (of the same binned size) are used in the
-            # same kernel invocation, they require separate allocations.
+            # same kernel invocation OR across different invocations with shared
+            # buffers, they require separate allocations.
             conflicting_buffers = {}  # map buffer -> {set of conflicting buffers}
-            for kernel, *args in op.runlist:
+            buffer_to_runlist_entries = {}  # map buffer -> set of runlist entry indices
+
+            # First pass: track which buffers appear in which runlist entries
+            for idx, (kernel, *args) in enumerate(op.runlist):
+                for arg in args:
+                    buffer_to_runlist_entries.setdefault(arg, set()).add(idx)
+
+            # Second pass: determine conflicts
+            for idx, (kernel, *args) in enumerate(op.runlist):
                 for arg in args:
                     if arg in op.buffer_static_data:
                         # Static buffers never conflict
                         continue
-                    # Conflict only exists if buffers are in the same size pool
                     pool_sz = get_pool_sz(op.buffers[arg])
+
+                    # Buffers conflict if they're in the same runlist entry
                     conflicting_args = {
                         a for a in args if get_pool_sz(op.buffers[a]) == pool_sz
                     } - {arg}
+
+                    # Also conflict with buffers in other runlist entries that share
+                    # a buffer with this entry
+                    for other_arg in args:
+                        if other_arg == arg:
+                            continue
+                        for other_idx in buffer_to_runlist_entries.get(
+                            other_arg, set()
+                        ):
+                            if other_idx != idx:
+                                _, *other_args = op.runlist[other_idx]
+                                conflicting_args.update(
+                                    {
+                                        a
+                                        for a in other_args
+                                        if get_pool_sz(op.buffers[a]) == pool_sz
+                                        and a != arg
+                                    }
+                                )
+
                     conflicting_buffers[arg] = conflicting_buffers.get(
                         arg, set()
                     ).union(conflicting_args)
@@ -244,12 +274,19 @@ def add_to_runlist(self, kernel_name, *args):
     def get_bo(self, buffer_name):
         return self.buffer_bos[buffer_name]
 
-    def read_buffer(self, buffer_name, shape, dtype=bfloat16):
+    def read_buffer(self, buffer_name, shape, copy=False, dtype=bfloat16):
         """Read buffer and return values as a numpy array"""
-        size = np.prod(shape) * np.dtype(dtype).itemsize
-        output_bytes = self.get_bo(buffer_name).read(size, 0)
-        output_data_flat = np.frombuffer(output_bytes, dtype=dtype)
-        return output_data_flat.reshape(*shape)
+        # Total bytes
+        size = int(np.prod(shape)) * np.dtype(dtype).itemsize
+
+        # Map once; map() should return a Python buffer interface over the BO
+        mv = self.get_bo(buffer_name).map()
+
+        # Create a NumPy view over mapped memory (zero-copy)
+        arr = np.frombuffer(mv, dtype=dtype, count=np.prod(shape))
+        if copy:
+            return arr.copy()
+        return arr.reshape(shape)
 
     def read_buffer_as_torch(self, buffer_name, shape, dtype=bfloat16):
         return numpy_to_torch(self.read_buffer(buffer_name, shape, dtype))

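The first hunk widens conflict detection: a non-static buffer now conflicts not only with same-pool buffers in its own runlist entry but also with same-pool buffers in any other entry reachable through a shared buffer. A standalone toy sketch of that rule (plain dicts with hypothetical names, no XRT objects):

# Toy runlist: (kernel_name, buffer_args...). Pool size keyed per buffer.
runlist = [("gemm", "a", "b"), ("gemv", "b", "c")]
pool_sz = {"a": 1024, "b": 1024, "c": 1024}
static = set()  # buffers with static data never conflict

entries_of = {}  # buffer -> indices of the runlist entries it appears in
for idx, (_, *args) in enumerate(runlist):
    for arg in args:
        entries_of.setdefault(arg, set()).add(idx)

conflicts = {}
for idx, (_, *args) in enumerate(runlist):
    for arg in args:
        if arg in static:
            continue
        # Same entry, same pool size.
        found = {a for a in args if pool_sz[a] == pool_sz[arg]} - {arg}
        # Other entries reached through a shared buffer, same pool size.
        for other in args:
            if other == arg:
                continue
            for other_idx in entries_of[other] - {idx}:
                _, *other_args = runlist[other_idx]
                found |= {a for a in other_args
                          if pool_sz[a] == pool_sz[arg] and a != arg}
        conflicts[arg] = conflicts.get(arg, set()) | found

print(conflicts)  # "a" conflicts with {"b", "c"}: entry 1 is reached via the shared "b"

The second hunk replaces the explicit BO read with a mapped view, so `read_buffer` now returns zero-copy data; the new `copy=True` flag appears to be the escape hatch for callers that must keep the values after the underlying buffer is reused.
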
operators/common/discover_tests.py

Lines changed: 2 additions & 3 deletions
@@ -93,11 +93,10 @@ def generate_test_list(operators_dir, output_dir=None, extensive=False):
         test_script = test_parts[0]
         test_args = " ".join(test_parts[1:]) if len(test_parts) > 1 else ""
 
-        # Wrap command to run from /tmp to avoid sys.path issues
         if test_args:
-            wrapped_command = f"cd /tmp && python3 {test_script} {test_args}"
+            wrapped_command = f"cd {output_dir} && python3 {test_script} {test_args}"
         else:
-            wrapped_command = f"cd /tmp && python3 {test_script}"
+            wrapped_command = f"cd {output_dir} && python3 {test_script}"
 
         # Generate test file content
         content = f"""run = '{wrapped_command}'

operators/common/utils.py

Lines changed: 27 additions & 8 deletions
@@ -23,14 +23,33 @@
 
 
 def torch_to_numpy(tensor: torch.Tensor) -> np.ndarray:
-    if tensor.dtype == torch.bfloat16:
-        float_arr = tensor.float().detach().cpu().numpy()
-        return float_arr.astype(bfloat16)
-    return tensor.detach().cpu().numpy()
+    # Detach (to drop grad) and ensure on CPU
+    t = tensor.detach()
+    if t.device.type != 'cpu':
+        t = t.cpu()
+    # Ensure contiguous for safe view operations
+    if not t.is_contiguous():
+        t = t.contiguous()
+
+    if t.dtype == torch.bfloat16:
+        # Zero-copy reinterpret: view the same memory as uint16, then as NumPy bfloat16
+        # This avoids numeric conversion and extra passes over memory.
+        u16_np = t.view(torch.uint16).numpy()  # shares memory, zero-copy
+        return u16_np.view(np.dtype('bfloat16'))  # reinterpret, zero-copy
+
+    # For supported dtypes, this is already zero-copy
+    return t.numpy()
 
 
 def numpy_to_torch(array: np.ndarray) -> torch.Tensor:
-    device = torch.device("cpu")
-    if array.dtype == bfloat16:
-        return torch.from_numpy(array.astype(np.float32)).to(torch.bfloat16).to(device)
-    return torch.from_numpy(array).to(device)
+    # Ensure contiguous to let from_numpy create a view
+    if not array.flags['C_CONTIGUOUS']:
+        array = np.ascontiguousarray(array)
+
+    if array.dtype == np.dtype('bfloat16'):
+        # reinterpret the same memory as uint16, then view as torch.bfloat16
+        t_u16 = torch.from_numpy(array.view(np.uint16))  # zero-copy
+        return t_u16.view(torch.bfloat16)  # view, zero-copy
+
+    # For supported dtypes, from_numpy is already zero-copy
+    return torch.from_numpy(array)

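Both converters now reinterpret the raw 16-bit payload instead of round-tripping through float32. A hedged round-trip check of the same view trick, assuming the NumPy bfloat16 dtype comes from the ml_dtypes package (the module's actual import may differ) and a PyTorch build that has torch.uint16 (2.3+):

import numpy as np
import torch
from ml_dtypes import bfloat16  # assumption: this package registers NumPy's bfloat16 dtype

t = torch.randn(4, 8, dtype=torch.bfloat16)

# torch -> numpy: view the same 2-byte payload as uint16, then as bfloat16 (no numeric conversion)
a = t.view(torch.uint16).numpy().view(bfloat16)

# numpy -> torch: the reverse reinterpretation
t2 = torch.from_numpy(a.view(np.uint16)).view(torch.bfloat16)

# Bit-exact round trip; the intermediate array shares memory with the original tensor
assert torch.equal(t, t2)
print(a.dtype, a.shape)  # bfloat16 (4, 8)
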