Add a first-class global load op/helper #548

@fsx950223

Summary

FlyDSL currently has no direct high-level op/helper for loading a scalar or vector value from a global tensor at a computed offset. Kernel code that needs a simple global memory load has to spell out the lower-level plumbing (pointer extraction or iterator recasts, offset arithmetic, temporary views), which is verbose and easy to get subtly wrong.

A first-class global_load op/helper would make kernels easier to write and would keep pointer/addressing details in one place.

Current Implementations

Today we have at least two ways to express this pattern.

1. Extract raw pointer, GEP, then LLVM load

This is the direct low-level approach:

# Assumes ir, T, fx, llvm and buffer_ops are already imported in the kernel module.
def _extract_global_ptr(tensor):
    from flydsl._mlir.dialects import fly as _fly

    # Unwrap a FlyDSL tensor wrapper to its ir.Value; raw ir.Values pass through unchanged.
    raw = tensor.ir_value() if hasattr(tensor, "ir_value") and not isinstance(tensor, ir.Value) else tensor
    # Address space 1 is global memory.
    ptr_type = ir.Type.parse("!llvm.ptr<1>")
    return _fly.extract_aligned_pointer_as_index(ptr_type, raw)


def _global_load_i64x2(global_ptr, byte_offset_i64):
    # Byte-addressed GEP (elem_type=i8), then a 16-byte-aligned 128-bit load.
    ptr = buffer_ops.get_element_ptr(global_ptr, byte_offset=fx.Int64(byte_offset_i64), elem_type=T.i8)
    return llvm.LoadOp(T.i64x2, ptr, alignment=16).result


def _global_load_i32(global_ptr, elem_offset_i32):
    # Convert the element offset to bytes (sizeof(i32) == 4) before the i8 GEP.
    byte_offset_i64 = fx.Int64(elem_offset_i32) * fx.Int64(4)
    ptr = buffer_ops.get_element_ptr(global_ptr, byte_offset=byte_offset_i64, elem_type=T.i8)
    return llvm.LoadOp(T.i32, ptr, alignment=4).result
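The unit handling in these helpers is where mistakes creep in. As a hedged illustration in plain Python (not FlyDSL), the host-side reference model below mimics what each helper reads from a little-endian buffer standing in for global memory. The `ref_*` names are hypothetical; a model like this could serve as a test oracle for kernel output.

```python
import struct


def ref_load_i64x2(mem: bytes, byte_offset: int) -> tuple:
    """Host-side model of _global_load_i64x2: two signed 64-bit ints at a raw byte offset."""
    # The kernel's llvm.load claims alignment=16; an under-aligned address there is
    # undefined behavior, so this model rejects it instead (assumes mem starts aligned).
    if byte_offset % 16 != 0:
        raise ValueError(f"byte_offset {byte_offset} is not 16-byte aligned")
    # "<2q": two little-endian signed 64-bit integers.
    return struct.unpack_from("<2q", mem, byte_offset)


def ref_load_i32(mem: bytes, elem_offset: int) -> int:
    """Host-side model of _global_load_i32: the offset is in i32 elements, not bytes."""
    byte_offset = elem_offset * 4  # sizeof(i32) == 4, matching fx.Int64(4) in the helper
    return struct.unpack_from("<i", mem, byte_offset)[0]


# Usage: eight i32 values 0..7 laid out contiguously.
mem = struct.pack("<8i", *range(8))
print(ref_load_i32(mem, 3))     # element 3 -> 3
print(ref_load_i64x2(mem, 16))  # bytes 16..31 hold i32 4,5,6,7 -> (5 << 32 | 4, 7 << 32 | 6)
```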

Pros:

  • Emits exactly the intended pointer arithmetic and llvm.load.
  • Works for raw byte offsets and arbitrary result types.

Cons:

  • Requires importing/using llvm directly from kernel code.
  • Requires manual pointer extraction and byte-offset handling.
  • Bypasses the higher-level FlyDSL tensor/view vocabulary.

2. Recast tensor iterator, make a Fly view, then memref_load_vec

This avoids direct llvm.LoadOp in kernel code:

def _recast_tensor_iter(tensor, elem_type):
    # Reinterpret the tensor's iterator as a pointer to elem_type,
    # keeping the source address space and alignment.
    src_iter = fx.get_iter(tensor)
    src_ptr_type = fx.PointerType(src_iter.type)
    ptr_type = fx.PointerType.get(
        elem_ty=elem_type,
        address_space=src_ptr_type.address_space,
        alignment=src_ptr_type.alignment,
    )
    return fx.recast_iter(ptr_type, src_iter)


def _global_load_i64x2(tensor, byte_offset_i64):
    # Byte offset on an i8-typed iterator, 16-element i8 view, then bitcast 16xi8 -> 2xi64.
    ptr = fx.add_offset(_recast_tensor_iter(tensor, T.i8), fx.make_int_tuple(fx.Int64(byte_offset_i64)))
    view = fx.Tensor(fx.make_view(ptr, fx.make_layout((16,), (1,))))
    raw = fx.memref_load_vec(view)
    return vector.bitcast(T.i64x2, raw)


def _global_load_i32(tensor, elem_offset_i32):
    # Element offset on an i32-typed iterator, 1-element view, then extract lane 0.
    ptr = fx.add_offset(_recast_tensor_iter(tensor, T.i32), fx.make_int_tuple(fx.Int32(elem_offset_i32)))
    view = fx.Tensor(fx.make_view(ptr, fx.make_layout((1,), (1,))))
    raw = fx.memref_load_vec(view)
    return vector.extract(raw, static_position=[0], dynamic_position=[])
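The view-based path should read exactly the same bits as approach 1. The plain-Python model below (hypothetical names, little-endian byte order assumed) mirrors its two steps, the typed view load followed by `vector.bitcast` or `vector.extract`, and checks the result against a direct typed read. It assumes, as the helpers above suggest, that `fx.add_offset` counts in units of the recast iterator's element type.

```python
import struct


def ref_view_load_i64x2(mem: bytes, byte_offset: int) -> tuple:
    """Model of approach 2 for i64x2: 16xi8 view load, then bitcast to 2xi64."""
    # memref_load_vec on a (16,):(1,) i8 view starting at byte_offset -> vector<16xi8>.
    raw = mem[byte_offset:byte_offset + 16]
    if len(raw) != 16:
        raise IndexError("read past the end of the buffer")
    # vector.bitcast vector<16xi8> -> vector<2xi64>: the same 128 bits,
    # reinterpreted as two 64-bit lanes (little-endian byte order assumed).
    lo = int.from_bytes(raw[:8], "little", signed=True)
    hi = int.from_bytes(raw[8:], "little", signed=True)
    return lo, hi


def ref_view_load_i32(mem: bytes, elem_offset: int) -> int:
    """Model of approach 2 for i32: 1-element i32 view load, then extract lane 0."""
    # On an i32-typed iterator the offset moves in i32 elements (4 bytes each).
    start = elem_offset * 4
    vec = [int.from_bytes(mem[start:start + 4], "little", signed=True)]  # vector<1xi32>
    return vec[0]  # vector.extract at static position 0


mem = struct.pack("<8i", *range(8))
# Both paths should match a direct typed load of the same bytes.
assert ref_view_load_i64x2(mem, 16) == struct.unpack_from("<2q", mem, 16)
assert ref_view_load_i32(mem, 5) == 5
```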

Pros:

  • Stays within FlyDSL iterator/view/memref abstractions.
  • Avoids direct llvm.LoadOp in kernel code.
  • Can be lowered through existing Fly memref/vector-load machinery.

Cons:

  • Still very verbose for a simple global load.
  • Requires callers to know when to recast to byte-addressed i8 vs element-typed pointers.
  • Requires manually constructing one-off layouts and extracting/bitcasting results.
  • vector.load cannot be used directly on kernel fx.Tensor arguments because they are !fly.memref, not standard MLIR memref.

Proposal

Add a first-class FlyDSL global load op/helper, for example one or both of:

fx.global_load(result_type, tensor, byte_offset, *, alignment=None)
fx.global_load_elem(result_type, tensor, elem_offset, *, alignment=None)

or a lower-level dialect op such as:

fly.global_load %tensor[%offset] : !fly.memref<...> -> vector<...>

The helper/op should support:

  • Scalar loads, e.g. i32, f32.
  • Vector loads, e.g. vector<2xi64>, vector<16xi8>.
  • Byte offsets for packed/raw access patterns.
  • Element offsets for typed access patterns.
  • Optional alignment and cache modifier metadata if useful.
  • Lowering to the appropriate LLVM/global memory load without requiring kernel authors to manually extract pointers or build temporary Fly views.
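One possible shape for the helper's front end is a single entry point with keyword-only `byte_offset`/`elem_offset`, merging the two signatures above. The sketch below is plain Python and entirely hypothetical (`LoadType`, `resolve_global_load`, and the default-alignment rule are not existing FlyDSL APIs); it only shows the offset normalization and alignment checks such a helper might perform before emitting the load.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class LoadType:
    """Hypothetical stand-in for a result type such as i32 (lanes=1) or vector<2xi64>."""
    elem_bits: int
    lanes: int = 1


def resolve_global_load(
    result_type: LoadType,
    tensor_elem_bytes: int,
    *,
    byte_offset: Optional[int] = None,
    elem_offset: Optional[int] = None,
    alignment: Optional[int] = None,
) -> Tuple[int, int]:
    """Return the (byte_offset, alignment) pair a lowering would hand to the global load."""
    # Exactly one offset unit must be given, so byte/element mix-ups fail loudly.
    if (byte_offset is None) == (elem_offset is None):
        raise ValueError("pass exactly one of byte_offset or elem_offset")
    if elem_offset is not None:
        # Typed access: scale by the source tensor's element size.
        byte_offset = elem_offset * tensor_elem_bytes
    if alignment is None:
        # Assumed conservative default: the alignment of one result element.
        alignment = max(result_type.elem_bits // 8, 1)
    if alignment <= 0 or alignment & (alignment - 1):
        raise ValueError(f"alignment {alignment} is not a positive power of two")
    # Assumes the tensor base pointer is itself at least `alignment`-aligned.
    if byte_offset % alignment != 0:
        raise ValueError(f"byte_offset {byte_offset} violates alignment {alignment}")
    return byte_offset, alignment


I32 = LoadType(32)
I64X2 = LoadType(64, lanes=2)
print(resolve_global_load(I64X2, 4, byte_offset=32, alignment=16))  # (32, 16)
print(resolve_global_load(I32, 4, elem_offset=3))                   # (12, 4)
```

Requiring exactly one offset keyword makes the byte-versus-element choice explicit at every call site, which targets the unit mix-ups described in the Motivation section.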

Motivation

Paged attention kernels currently need this pattern for K/V cache reads and metadata reads. The low-level spelling makes code harder to review and makes it easy to mix up byte offsets, dword offsets, and element offsets.

A dedicated op/helper would make the intent obvious:

k2 = fx.global_load(T.i64x2, key_cache_ptr, byte_offset=ka_dw * fx.Int64(4), alignment=16)
context_len = fx.global_load_elem(T.i32, context_lengths_ptr, elem_offset=batch_idx, alignment=4)

This would also make it easier to consistently tune lowering behavior in one place later.
