In [1]:
import numpy as np
import ctypes

from llvmlite import ir, binding

In [2]:
# Bring up LLVM: core first, then the native target and its assembly printer.
for _llvm_init in (
    binding.initialize,
    binding.initialize_native_target,
    binding.initialize_native_asmprinter,
):
    _llvm_init()

# Target machine for the default triple of this host.
target = binding.Target.from_default_triple()
target_machine = target.create_target_machine()

# The MCJIT engine is created around a module; an empty one parsed from ""
# serves as the placeholder, real modules are added to the engine later.
backing_mod = binding.parse_assembly("")
engine = binding.create_mcjit_compiler(backing_mod, target_machine)

In [3]:
# Build IR module containing:  i32 dot_i32(i32* a, i32* b, i32 n)
# The function returns sum(a[k] * b[k] for k in range(n)) computed in i32.
module = ir.Module(name="jit_mod")
module.triple = binding.get_default_triple()
# Store the layout in its textual form: ir.Module only interpolates this value
# into the `target datalayout = "..."` line, so a plain string is what it needs.
module.data_layout = str(target_machine.target_data)

i32 = ir.IntType(32)
i32p = i32.as_pointer()
fn_ty = ir.FunctionType(i32, [i32p, i32p, i32])
fn = ir.Function(module, fn_ty, name="dot_i32")

# Name the arguments so the printed IR is readable.
a_ptr, b_ptr, n = fn.args
a_ptr.name = "a"
b_ptr.name = "b"
n.name = "n"

# CFG: entry -> loop (condition) -> body -> loop ... -> exit
entry = fn.append_basic_block("entry")
loop = fn.append_basic_block("loop")
body = fn.append_basic_block("body")
exitb = fn.append_basic_block("exit")

builder = ir.IRBuilder(entry)
zero = ir.Constant(i32, 0)
one = ir.Constant(i32, 1)

# entry: i = 0, s = 0
# The counter and accumulator live in stack slots, which avoids writing phi
# nodes by hand.
i_alloc = builder.alloca(i32, name="i")
s_alloc = builder.alloca(i32, name="s")
builder.store(zero, i_alloc)
builder.store(zero, s_alloc)
builder.branch(loop)

# loop: if i < n goto body else exit
# Signed comparison: since i starts at 0, n <= 0 skips the body entirely.
builder.position_at_end(loop)
i_val = builder.load(i_alloc, name="i_val")
cond = builder.icmp_signed("<", i_val, n, name="cond")
builder.cbranch(cond, body, exitb)

# body: s += a[i]*b[i]; i++
builder.position_at_end(body)
ai_ptr = builder.gep(a_ptr, [i_val], inbounds=True, name="ai_ptr")
bi_ptr = builder.gep(b_ptr, [i_val], inbounds=True, name="bi_ptr")
ai = builder.load(ai_ptr, name="ai")
bi = builder.load(bi_ptr, name="bi")
# Plain mul/add are emitted without overflow flags, so i32 overflow wraps
# around; the NumPy comparison at the end of the notebook relies on that.
prod = builder.mul(ai, bi, name="prod")

s_val = builder.load(s_alloc, name="s_val")
s_new = builder.add(s_val, prod, name="s_new")
builder.store(s_new, s_alloc)

i_new = builder.add(i_val, one, name="i_new")
builder.store(i_new, i_alloc)
builder.branch(loop)

# exit: return s
builder.position_at_end(exitb)
s_final = builder.load(s_alloc, name="s_final")
# Trailing semicolon keeps the notebook from echoing the Ret instruction repr.
builder.ret(s_final);

<ir.Ret '.12' of type 'void', opname 'ret', operands [<ir.LoadInstr 's_final' of type 'i32', opname 'load', operands [<ir.AllocaInstr 's' of type 'i32*', opname 'alloca', operands ()>]>]>

In [4]:
# Render the freshly built module as textual LLVM IR (no optimization applied
# yet) and show it under a banner. The text is kept in `llvm_ir` for reuse.
llvm_ir = str(module)
print("=== LLVM IR (unoptimized) ===", llvm_ir, sep="\n")

=== LLVM IR (unoptimized) ===
; ModuleID = "jit_mod"
target triple = "x86_64-unknown-linux-gnu"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"

define i32 @"dot_i32"(i32* %"a", i32* %"b", i32 %"n")
{
entry:
  %"i" = alloca i32
  %"s" = alloca i32
  store i32 0, i32* %"i"
  store i32 0, i32* %"s"
  br label %"loop"
loop:
  %"i_val" = load i32, i32* %"i"
  %"cond" = icmp slt i32 %"i_val", %"n"
  br i1 %"cond", label %"body", label %"exit"
body:
  %"ai_ptr" = getelementptr inbounds i32, i32* %"a", i32 %"i_val"
  %"bi_ptr" = getelementptr inbounds i32, i32* %"b", i32 %"i_val"
  %"ai" = load i32, i32* %"ai_ptr"
  %"bi" = load i32, i32* %"bi_ptr"
  %"prod" = mul i32 %"ai", %"bi"
  %"s_val" = load i32, i32* %"s"
  %"s_new" = add i32 %"s_val", %"prod"
  store i32 %"s_new", i32* %"s"
  %"i_new" = add i32 %"i_val", 1
  store i32 %"i_new", i32* %"i"
  br label %"loop"
exit:
  %"s_final" = load i32, i32* %"s"
  ret i32 %"s_final"
}



In [5]:
# Parse + verify
# `llvm_ir` is the textual IR produced from the llvmlite.ir module; parsing it
# through the binding layer yields the module object (`mod`) that the
# optimizer and the execution engine operate on in the following cells.
mod = binding.parse_assembly(llvm_ir)
# Ask LLVM to check the parsed module before doing anything else with it.
mod.verify()

In [6]:
# Configure a pass-manager builder for optimization level 3 with both the loop
# vectorizer and the SLP vectorizer switched on.
pmb = binding.PassManagerBuilder()
pmb.opt_level = 3
pmb.loop_vectorize = True
pmb.slp_vectorize = True

# Fill a module-level pass manager from that configuration and run it on the
# parsed module; `mod` is rewritten in place.
mpm = binding.ModulePassManager()
pmb.populate(mpm)
mpm.run(mod)

# Show the result (print() stringifies the module itself).
print("\n=== LLVM IR (optimized) ===", mod, sep="\n")


=== LLVM IR (optimized) ===
; ModuleID = '<string>'
source_filename = "<string>"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

; Function Attrs: nofree norecurse nosync nounwind readonly
define i32 @dot_i32(i32* nocapture readonly %a, i32* nocapture readonly %b, i32 %n) local_unnamed_addr #0 {
entry:
  %cond1 = icmp sgt i32 %n, 0
  br i1 %cond1, label %body.preheader, label %exit

body.preheader:                                   ; preds = %entry
  %wide.trip.count = zext i32 %n to i64
  br label %body

body:                                             ; preds = %body.preheader, %body
  %indvars.iv = phi i64 [ 0, %body.preheader ], [ %indvars.iv.next, %body ]
  %s.02 = phi i32 [ 0, %body.preheader ], [ %s_new, %body ]
  %ai_ptr = getelementptr inbounds i32, i32* %a, i64 %indvars.iv
  %bi_ptr = getelementptr inbounds i32, i32* %b, i64 %indvars.iv
  %ai = load i32, i32* %ai_ptr, align 4
  %bi = lo

In [7]:
# Add module to engine and finalize
# Hand the optimized module over to the MCJIT engine created at the top of the
# notebook.
engine.add_module(mod)
# Finalize the engine's modules; this is done before any function address is
# requested from the engine in the next cell.
engine.finalize_object()
# Run the module's static constructors. NOTE(review): this IR defines only
# dot_i32, so this is presumably a no-op here and kept for completeness.
engine.run_static_constructors()

In [8]:
# Get function pointer
# Look up the native address of the compiled symbol. The value is a plain
# integer that the next cell turns into a callable via ctypes.
func_ptr = engine.get_function_address("dot_i32")
# Calling through a null address from ctypes would take the whole kernel down,
# so fail with a normal Python exception instead.
if not func_ptr:
    raise RuntimeError("JIT lookup failed: symbol 'dot_i32' has no address")

In [9]:
# Describe the native signature  int32 dot_i32(int32*, int32*, int32)  as a
# ctypes prototype, then bind that prototype to the JIT-compiled address.
_c_i32_ptr = ctypes.POINTER(ctypes.c_int32)
_dot_i32_proto = ctypes.CFUNCTYPE(
    ctypes.c_int32,  # return value
    _c_i32_ptr,      # a
    _c_i32_ptr,      # b
    ctypes.c_int32,  # n
)
cfunc = _dot_i32_proto(func_ptr)

In [10]:
# Exercise the compiled kernel on two deterministic int32 vectors holding both
# negative and positive values.
_ramp = np.arange(100000, dtype=np.int32)
a = (_ramp % 1000) - 500
b = (_ramp % 777) - 300

# Pass the arrays' underlying buffers as int32* together with the element count.
_i32_ptr = ctypes.POINTER(ctypes.c_int32)
_a_buf = a.ctypes.data_as(_i32_ptr)
_b_buf = b.ctypes.data_as(_i32_ptr)
res_jit = cfunc(_a_buf, _b_buf, a.size)

# Reference value with 64-bit products and accumulation, so the NumPy side of
# the check cannot overflow.
_a64 = a.astype(np.int64)
_b64 = b.astype(np.int64)
res_np = int((_a64 * _b64).sum())

In [11]:
# Our function uses int32 math, so match that:
# multiply in int32 and force the reduction to accumulate in int32 as well, so
# the sum wraps around exactly like the JIT-compiled loop. Passing dtype
# explicitly avoids depending on NumPy's platform-specific default accumulator
# and on an implicit out-of-range scalar conversion.
res_np_i32 = (a * b).sum(dtype=np.int32).item()

print("\nJIT result:", res_jit)
print("NumPy int32 result:", res_np_i32)

# Make the comparison an actual check rather than something to eyeball.
assert res_jit == res_np_i32, (
    f"JIT result {res_jit} does not match NumPy int32 result {res_np_i32}"
)


JIT result: -25558656
NumPy int32 result: -25558656
