Skip to content

Commit 1d87fe8

Browse files
authored
Merge 675c212 into 912e6bc
2 parents 912e6bc + 675c212 commit 1d87fe8

File tree

9 files changed

+876
-12
lines changed

9 files changed

+876
-12
lines changed

aie_kernels/aie2p/softmax.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <aie_api/aie.hpp>
55
#include <stdint.h>
6+
#include <math.h>
67

78
#define SM_VEC_LEN 64 // 32
89
#define log2e 1.4453125 // 1.44269504089
@@ -30,7 +31,7 @@ void softmax_simple_bf16(bfloat16 *restrict input_vector, bfloat16 *restrict out
3031
aie::vector<bfloat16, SM_VEC_LEN> in_elems, exp_val, input_bf16, log2e_vec, max_val_vec;
3132
aie::accum<accfloat, SM_VEC_LEN> out_vals, exp_val_accum, scaled_accum, exp_in_accum;
3233

33-
float max_val = 0;
34+
float max_val = -INFINITY;
3435
float accum_exp_val = 0;
3536
float running_max = 0;
3637
bfloat16 col_sum_inv;

iron/common/compilation/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,7 @@ def compile(self, graph):
496496
str(self.aiecc_path),
497497
"-v",
498498
"-j1",
499+
"--dynamic-objFifos",
499500
"--no-compile-host",
500501
"--no-xchesscc",
501502
"--no-xbridge",

iron/common/fusion.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import ml_dtypes
66
import pyxrt
77
import ctypes
8+
import time
89
from . import compilation as comp
910
from .base import AIEOperatorBase, MLIROperator
1011
from .utils import XRTSubBuffer
@@ -42,8 +43,7 @@ def get_kernel_artifacts(self):
4243
"""Collect all kernel artifacts from child operators.
4344
4445
Returns:
45-
List of KernelObjectArtifact instances from all unique child operators,
46-
with filenames and symbol prefixes disambiguated per operator index.
46+
List of KernelObjectArtifact instances from all unique child operators.
4747
"""
4848
kernel_artifacts = []
4949
seen: dict[int, object] = {}
@@ -52,9 +52,6 @@ def get_kernel_artifacts(self):
5252
]
5353
for idx, op in enumerate(unique_operators):
5454
objs = op.get_kernel_artifacts()
55-
for obj in objs:
56-
obj.filename = f"op{idx}_{obj.filename}"
57-
obj.prefix_symbols = f"op{idx}_"
5855
kernel_artifacts.extend(objs)
5956
return kernel_artifacts
6057

@@ -82,8 +79,6 @@ def get_mlir_artifact(self):
8279
]
8380
for idx, op in enumerate(unique_operators):
8481
mlir_artifact = op.get_mlir_artifact()
85-
if len(op.get_kernel_artifacts()) > 0:
86-
mlir_artifact.generator.kwargs["func_prefix"] = f"op{idx}_"
8782
op_name = f"op{idx}_{op.__class__.__name__}"
8883
op_names[id(op)] = op_name
8984
operator_mlir_map[op_name] = mlir_artifact
@@ -290,8 +285,10 @@ def __call__(self, *args):
290285
for i, arg in enumerate(args):
291286
assert isinstance(arg, pyxrt.bo), f"Argument {i} is not a pyxrt.bo"
292287
run.set_arg(i, arg)
288+
t0 = time.perf_counter()
293289
run.start()
294290
ret_code = run.wait()
291+
self.last_elapsed = time.perf_counter() - t0
295292
if ret_code != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED:
296293
raise RuntimeError(f"Kernel execution failed with return code {ret_code}")
297294

@@ -371,10 +368,10 @@ def get_buffer(self, buffer_name):
371368
return sub_buffer
372369

373370
def __call__(self):
374-
self.input_buffer.to("npu")
371+
self.input_buffer._sync_to_device()
375372
super().__call__(
376373
self.input_buffer.buffer_object(),
377374
self.output_buffer.buffer_object(),
378375
self.scratch_buffer.buffer_object(),
379376
)
380-
self.output_buffer.to("cpu")
377+
self.output_buffer._sync_from_device()

iron/operators/gemm/design.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ def my_matmul(
299299
gemm_object,
300300
[C_l1_ty_internal],
301301
)
302-
matmul_func_name = f"matmul{scalar_suffix}_{dtype_in_str}_f32"
302+
matmul_func_name = f"{func_prefix}matmul{scalar_suffix}_{dtype_in_str}_f32"
303303
matmul_kernel = Kernel(
304304
matmul_func_name,
305305
gemm_object,
@@ -314,7 +314,9 @@ def my_matmul(
314314
gemm_object,
315315
[C_l1_ty],
316316
)
317-
matmul_func_name = f"matmul{scalar_suffix}_{dtype_in_str}_{dtype_out_str}"
317+
matmul_func_name = (
318+
f"{func_prefix}matmul{scalar_suffix}_{dtype_in_str}_{dtype_out_str}"
319+
)
318320
matmul_kernel = Kernel(
319321
matmul_func_name,
320322
gemm_object,
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0

0 commit comments

Comments
 (0)