In [3]:
import numba as nb
import numpy as np

class MyClass:
    """Holds two coefficients ``a`` and ``b`` alongside Numba-compiled static kernels.

    The kernels are ``@staticmethod`` + ``@nb.njit`` functions that take all of
    their inputs as explicit arguments (e.g. ``obj.a``, ``obj.b``) instead of
    reading them from ``self``.
    """

    def __init__(self, a, b):
        self.a = a
        self.b = b

    @staticmethod
    @nb.njit
    def my_method(x, a, b):
        """Return the affine combination ``x * a + b``."""
        return x * a + b

    @staticmethod
    @nb.njit
    def matvec(x, a):
        """Return the matrix-vector product of ``x`` and ``a`` using explicit loops.

        Parameters
        ----------
        x : 2-D array of shape (n, m)
        a : 1-D array of length m

        Returns
        -------
        1-D array of length n (one entry per row of ``x``) with the dtype of ``a``.

        Raises
        ------
        AssertionError
            If the number of columns of ``x`` differs from ``len(a)``.
        """
        n, m = x.shape
        assert m == len(a)
        # One output entry per ROW of x. The former np.zeros_like(a) had length m:
        # wrong for non-square x, and for n > m the unchecked b[i] writes went out
        # of bounds (Numba does not bounds-check by default).
        # The dtype still follows `a`, as before. With a float `x` and an integer
        # `a`, the products are cast to the integer dtype on store; pass a float
        # `a` in that case.
        b = np.zeros(n, dtype=a.dtype)
        for i in range(n):
            for j in range(m):
                b[i] += x[i, j] * a[j]
        return b
        

# Usage: the compiled kernels are static, so the instance's coefficients are
# handed to them explicitly.
obj = MyClass(1, 2)
result = obj.my_method(3, obj.a, obj.b)  # 3 * 1 + 2
print(result)

A = np.arange(3 * 3).reshape(3, 3)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
a = np.arange(A.shape[1])           # [0, 1, 2]
result = obj.matvec(A, a)
print(result)


5
[ 5 14 23]


- Passing a jitclass instance to an `@njit` function

In [3]:
from numba import jit, njit       
from numba import float64    
from numba.experimental import jitclass

spec = [
    ("a", float64),
    ("b", float64),
]
@jitclass(spec)
class params:
    """jitclass with two float64 fields, ``a`` and ``b``."""

    def __init__(self, a=1.1, b=2.3):
        self.a = a
        self.b = b

    def sum(self):
        """Return ``a + b``."""
        return self.a + self.b


@njit(float64(float64, params.class_type.instance_type))
def get_sum_3(c, someobj):
    """Accumulate ``c + someobj.sum()`` 1000 * 1000 times and return the total.

    The explicit signature ``float64(float64, params instance)`` makes Numba
    compile this function eagerly, at decoration time.
    """
    step = c + someobj.sum()  # same value on every iteration, so compute it once
    total = 0.0
    for _i in range(1000):
        for _j in range(1000):
            total += step
    return total

get_sum_3(1.1, params(1.0, 2.0))

4100000.00006511

In [15]:
from numba import njit, literal_unroll

@njit
def foo():
    """Print every element of a tuple whose elements have different types.

    ``literal_unroll`` makes the loop body be compiled once per element, so each
    iteration can see a different element type.
    """
    mixed = (1, 2j, 3.0, "a")
    for item in literal_unroll(mixed):
        print(item)

foo()

1
2j
3.0
a


In [10]:
from numba.experimental import jitclass
from numba import float64
import numpy as np

# Define the jitclass
@jitclass([('float_array', float64[:])])
class MyClass:
    """jitclass wrapping a single 1-D float64 array attribute, ``float_array``."""

    def __init__(self, size):
        """Allocate ``float_array`` as ``size`` zeros."""
        self.float_array = np.zeros(size, dtype=np.float64)

    def update_array(self, new_values):
        """Copy ``new_values`` element-wise into the existing ``float_array`` buffer."""
        self.float_array[:] = new_values

    def print_array(self):
        """Print ``float_array`` from compiled code."""
        print(self.float_array)

    def get_data(self):
        """Return a one-element list holding the ("float_array", array) pair."""
        pair = ("float_array", self.float_array)
        return [pair]

# Create a jitclass instance backed by five zeros
my_instance = MyClass(5)

# Overwrite the buffer with 1..5, print it from compiled code, then display it
new_values = np.arange(1.0, 6.0, dtype=np.float64)  # [1., 2., 3., 4., 5.]
my_instance.update_array(new_values)
my_instance.print_array()
my_instance.float_array


[1. 2. 3. 4. 5.]


array([1., 2., 3., 4., 5.])

In [18]:
from numba.experimental import jitclass
from numba import float64, int64
import numpy as np

# Plain-Python class (no decorator). The same class is compiled into a
# jitclass later in this cell via jitclass(jitclass_spec)(MyClass).
class MyClass:
    """Holds a float64 zero array of length ``size`` together with ``size`` itself."""

    def __init__(self, size):
        # Create a float64 array with the specified size
        self.float_array = np.zeros(size, dtype=np.float64)
        self.size = size

    def get_params(self):
        """Return a snapshot dict of the instance attributes.

        A shallow copy of ``self.__dict__`` is returned, so callers cannot add,
        remove, or rebind attributes on the instance through the result. The
        attribute values themselves (e.g. the array) are shared, not copied.

        NOTE(review): this relies on the instance ``__dict__``. It is presumably
        not usable on the jitclass-compiled version of this class; confirm
        before calling it there.
        """
        return dict(self.__dict__)

jitclass_spec = [
    ('float_array', float64[:]),
    ('size', int64),
]

# Plain-Python instance: attributes live in an ordinary __dict__.
obj = MyClass(5)
print(obj.get_params())

# The same class compiled as a jitclass; fields are typed by jitclass_spec.
obj = jitclass(jitclass_spec)(MyClass)(5)
(obj.float_array, obj.size)


{'float_array': array([0., 0., 0., 0., 0.]), 'size': 5}


(array([0., 0., 0., 0., 0.]), 5)

In [31]:
# using jitclass with multiprocessing

import os
import numba
from time import sleep
from numba.experimental import jitclass
from multiprocessing import Pool

jit_class = jitclass({'a': numba.float64})
@jit_class
class JitClass:
    """jitclass holding a single float64 field ``a``."""

    def __init__(self, a):
        self.a = a

    def square_value(self):
        """Return ``a`` raised to the power 2."""
        return self.a ** 2

# Build one instance and call square_value() once in this (parent) process.
# The return value is discarded. NOTE(review): presumably a warm-up so the
# method is compiled before the worker pool below is started; confirm.
jit_class_instance = JitClass(3.0)
jit_class_instance.square_value()

def wrapper(value):
    """Pool task: pause 1 s, print this worker's PID, then return ``value`` squared.

    The square is computed by a fresh ``JitClass`` instance, so the result is a
    float64 even for integer ``value``.
    """
    sleep(1)  # presumably to make the distribution across workers visible
    pid = os.getpid()
    print("id: {} \n".format(pid))
    return JitClass(value).square_value()

# The __main__ guard is the idiom the multiprocessing docs require for the
# "spawn"/"forkserver" start methods, where each worker re-imports the main
# module; without it, running this code as a script could start pools
# recursively. Inside a notebook kernel the condition is always true, so the
# notebook's behavior is unchanged.
# NOTE(review): with those start methods, workers also cannot import functions
# defined only in a notebook (like `wrapper`); this demo presumably relies on
# the "fork" start method.
if __name__ == "__main__":
    with Pool(2) as p:
        data = p.map(wrapper, range(5))

    print(data)

id: 1506509 
id: 1506510 


id: 1506510 
id: 1506509 


id: 1506510 

[0.0, 1.0, 4.0, 9.0, 16.0]


In [37]:
# using jitclass with multiprocessing, pass jitclass instance as argument

import os
import numba
from time import sleep
from numba.experimental import jitclass
from multiprocessing import Pool

jit_class = jitclass({
    'a': numba.float64,
    'b': numba.float64,
})
@jit_class
class JitClass:
    """jitclass holding two float64 fields, ``a`` and ``b``."""

    def __init__(self, a, b):
        self.a = a
        self.b = b

class A:
    """Plain-Python holder of (a, b) that also packs them into a ``JitClass`` as ``P``.

    ``multiply`` is a ``numba.njit`` static kernel that receives ``P`` explicitly.
    """

    def __init__(self, a, b):
        self.a = a
        self.b = b
        self.P = JitClass(a, b)  # jitclass copy of (a, b) that compiled code can accept

    @staticmethod
    @numba.njit
    def multiply(P):
        """Return the product of the jitclass fields ``P.a * P.b``."""
        return P.a * P.b

def wrapper(a, b):
    """Pool task: pause 1 s, print this worker's PID, then return ``a * b``.

    The product is computed by ``A.multiply`` on the ``JitClass`` instance that
    ``A`` builds.
    """
    sleep(1)
    print("id: {} \n".format(os.getpid()))
    return A.multiply(A(a, b).P)

# The rest of this cell's imports (os, numba, sleep, jitclass, Pool) are at its
# top; numpy was missing and only resolved through an earlier cell.
import numpy as np

# Five (a, b) parameter pairs from a seeded generator, so reruns are reproducible.
par_list = np.random.default_rng(0).random((5, 2)).tolist()

# Guard needed when workers re-import the main module (spawn/forkserver start
# methods). Inside a notebook kernel the condition is always true.
if __name__ == "__main__":
    with Pool(2) as p:
        data = p.starmap(wrapper, par_list)

    # Show the products; previously `data` was computed but never displayed.
    print(data)

id: 1508707 
id: 1508708 


id: 1508708 
id: 1508707 


id: 1508708 

