In [3]:
import numpy as np
import numba as nb
from numba.extending import overload
from numba import types
def explicit_indexing_mult1(x: np.ndarray, a, b):
    """Return ``(x * a + b) * b`` with `a` applied per row and `b` applied per column.

    Each row ``i`` of `x` is multiplied by ``a[i]`` (or by the scalar `a`, or element-wise by a
    2d `a`); then ``b[j]`` (or the scalar `b`) is added to column ``j`` and the column is
    multiplied by ``b[j]`` again.

    :param x: A 2d array. It is not modified; the computation runs on a copy.
    :param a: A scalar, a 1d array of size ``x.shape[0]`` (one factor per row),
        or a 2d array of shape ``x.shape``.
    :param b: A scalar, or a 1d array of size ``x.shape[1]`` (one value per column).
        Any non-1d value is handed to NumPy broadcasting unchanged.
    :return: A new array with the same shape and dtype as `x`.
    """
    # Work on a copy so the caller's array is left untouched (the demos reuse the same `x`).
    out = x.copy()

    # A 1d `a` holds one factor per row, so view it as a column (n_rows, 1);
    # a 1d `b` holds one value per column, so view it as a row (1, n_cols).
    a_bc = np.asarray(a)[:, None] if np.ndim(a) == 1 else a
    b_bc = np.asarray(b)[None, :] if np.ndim(b) == 1 else b

    out *= a_bc
    out += b_bc
    out *= b_bc
    return out

# Shared inputs for every demo in this notebook.
x = np.array([[4., 5., 3., 2.],
              [1., 1., 1., 1.],
              [4., 3., 2., 1.]])
a = np.array([[2., 2., 2., 2.],
              [1., 1., 1., 1.],
              [1., 2., 3., 4.]])
b = np.full(4, 2.0)
t1 = (a, b)          # 2d a, 1d b: second row of the result should be all 6
t2 = (a[0, :3], b)   # 1d a (one factor per row): second row should be all 8
t3 = (0.5, 3.0)      # scalar a and b: second row should be all 10.5

tuple(explicit_indexing_mult1(x, *args) for args in (t1, t2, t3))

(array([[20., 24., 16., 12.],
        [ 6.,  6.,  6.,  6.],
        [12., 16., 16., 12.]]),
 array([[20., 24., 16., 12.],
        [ 8.,  8.,  8.,  8.],
        [20., 16., 12.,  8.]]),
 array([[15. , 16.5, 13.5, 12. ],
        [10.5, 10.5, 10.5, 10.5],
        [15. , 13.5, 12. , 10.5]]))

In [4]:
@overload(explicit_indexing_mult1)
def explicit_indexing_mult1_overload(x, a, b):
    """Numba typing hook: pick an njit implementation of `explicit_indexing_mult1`.

    Called at compile time with numba *types* (not values) and returns an implementation
    specialised on the rank of `a` and `b`. The rules mirror the NumPy reference:

    * `a`: non-array or 2d -> used as-is; 1d -> one factor per row, viewed as ``(n_rows, 1)``.
    * `b`: non-array or 2d -> used as-is; 1d -> one value per column, viewed as ``(1, n_cols)``.

    Uses ``[:, None]`` / ``[None, :]`` views rather than ``reshape``, because numba's
    ``reshape`` only accepts contiguous arrays and would reject a strided 1d `a`/`b`
    (e.g. a column slice). Any other rank returns None, so numba reports no matching
    implementation.
    """
    a_dim = a.ndim if isinstance(a, types.Array) else 0
    b_dim = b.ndim if isinstance(b, types.Array) else 0

    if a_dim not in (0, 1, 2) or b_dim not in (0, 1, 2):
        return None  # unsupported rank: let numba raise its typing error

    if a_dim == 1 and b_dim == 1:
        def impl(x, a, b):
            out = x.copy()
            out *= a[:, None]   # broadcast a down the rows
            b_row = b[None, :]  # broadcast b across the columns
            out += b_row
            out *= b_row
            return out
    elif a_dim == 1:
        def impl(x, a, b):
            out = x.copy()
            out *= a[:, None]
            out += b
            out *= b
            return out
    elif b_dim == 1:
        def impl(x, a, b):
            out = x.copy()
            out *= a
            b_row = b[None, :]
            out += b_row
            out *= b_row
            return out
    else:
        def impl(x, a, b):
            out = x.copy()
            out *= a
            out += b
            out *= b
            return out
    return impl
    
@nb.njit
def explicit_indexing_mult1_jit(x, a, b):
    """njit entry point for `explicit_indexing_mult1`.

    Inside njit code the call resolves to the `@overload` registered for
    `explicit_indexing_mult1`, specialised on the types of `a` and `b`.
    """
    result = explicit_indexing_mult1(x, a, b)
    return result

# Cross-check: the njit path (through the @overload) must agree with the NumPy reference.
for args in (t1, t2, t3):
    np.testing.assert_allclose(explicit_indexing_mult1_jit(x, *args),
                               explicit_indexing_mult1(x, *args))

(explicit_indexing_mult1_jit(x, *t1),
 explicit_indexing_mult1_jit(x, *t2),
 explicit_indexing_mult1_jit(x, *t3))

(array([[20., 24., 16., 12.],
        [ 6.,  6.,  6.,  6.],
        [12., 16., 16., 12.]]),
 array([[20., 24., 16., 12.],
        [ 8.,  8.,  8.,  8.],
        [20., 16., 12.,  8.]]),
 array([[15. , 16.5, 13.5, 12. ],
        [10.5, 10.5, 10.5, 10.5],
        [15. , 13.5, 12. , 10.5]]))

In [7]:
def l_12_d(x, i1=0, i2=0, d=0):
    """Pure-Python version of the `l_12_d` lookup ("lower" `x` by up to two indexes).

    Uses as many of (i1, i2) as the rank of `x` allows while keeping at least `d` dimensions:
    rank >= d + 2 -> ``x[i1, i2]``; rank == d + 1 -> ``x[i1]``; otherwise (including
    anything that is not exactly an ``np.ndarray``) `x` is returned unchanged.
    A numba Literal `d` is unwrapped to its value first.
    """
    if isinstance(d, nb.types.Literal):
        d = d.literal_value
    if type(x) is not np.ndarray:
        return x
    rank = len(x.shape)
    if rank >= d + 2:
        return x[i1, i2]
    if rank == d + 1:
        return x[i1]
    return x

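# Naming convention: `l` stands for "lower" (index an array down to a lower rank).
# The digits in `l_<digits>_` give the order in which indexes are dropped, read as a stack
#  (the last index dropped is listed first). Equivalently: when the array has fewer dimensions
#  than there are indexes, the index listed last is removed first (l_12: 2d -> x[i1, i2], 1d -> x[i1]).
#  See example in next cell.
# `d` is the smallest number of dimensions the result may be lowered to; d=0 means indexing may go
#  all the way down to a single value.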
def l_12_0(x, i1=0, i2=0):
    """Pure-Python version of `l_12_0`: like ``l_12_d`` with ``d=0`` fixed.

    rank >= 2 -> ``x[i1, i2]``; rank == 1 -> ``x[i1]``; otherwise (scalars, 0-d arrays,
    non-arrays) `x` is returned unchanged. This mirrors the numba overload, so the function
    also gives the same answer outside njit code instead of silently returning None.
    """
    if isinstance(x, np.ndarray):
        if x.ndim >= 2:
            return x[i1, i2]
        if x.ndim == 1:
            return x[i1]
    return x

@overload(l_12_0)
def _l_12_0(x, i1=0, i2=0):
    """Numba typing hook for `l_12_0`: the `d` of `l_12_d` is fixed at 0.

    No literal `d` has to be resolved, so compilation is a little quicker.
    Non-arrays (and 0-d arrays) are returned as-is.
    """
    rank = x.ndim if isinstance(x, types.Array) else 0
    if rank >= 2:
        def _impl(x, i1=0, i2=0):
            return x[i1, i2]
    elif rank == 1:
        def _impl(x, i1=0, i2=0):
            return x[i1]
    else:
        def _impl(x, i1=0, i2=0):
            return x
    return _impl


_verbs = False  # set True to print diagnostics from the @overload typing functions

@overload(l_12_d)
def _l_12_d(x, i1=0, i2=0, d=0):
    """Numba typing hook for `l_12_d`: choose an implementation from the rank of `x` and `d`.

    `d` must be known at compile time. When it arrives as a plain integer type, the hook
    returns a stub that calls ``nb.literally(d)`` so numba retypes the call with a literal.
    """
    def _identity(x, i1=0, i2=0, d=0):
        return x

    if not isinstance(x, types.Array):
        return _identity

    if not isinstance(d, (nb.types.Literal, int)):
        if _verbs:
            print('Requesting literal value for d')
        return lambda x, i1=0, i2=0, d=0: nb.literally(d)

    dv = d if type(d) is int else d.literal_value
    if _verbs:
        print('d is ', dv)

    if x.ndim >= 2 + dv:
        def _index2(x, i1=0, i2=0, d=0):
            return x[i1, i2]
        return _index2
    if x.ndim == 1 + dv:
        def _index1(x, i1=0, i2=0, d=0):
            return x[i1]
        return _index1
    return _identity

@nb.njit
def implicit_indexing_mult_jit(x, a, b):
    """Compute ``(x * a + b) * b`` element by element, looking up `a` and `b` per element.

    :param x: A 2d array; the result is computed on a copy and `x` is not modified.
    :param a: A scalar, a 1d array with one factor per row (size ``x.shape[0]``),
        or a 2d array of shape ``x.shape``.
    :param b: A scalar or a 1d array with one value per column (size ``x.shape[1]``).
        A 2d `b` is not handled here: ``l_12_0(b, j)`` would read ``b[j, 0]``.
    :return: A new array shaped like `x`.
    """
    x = x.copy()
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            # l_12_0 gives a[i, j] for a 2d `a`, a[i] for a 1d `a` (one factor per row), and
            # `a` itself for a scalar. Lowering `a` by row first with l_12_d(a, i, d=1) and then
            # indexing the result by j is wrong for a 1d `a`. That call returns the whole 1d
            # array, which then gets indexed by the column j. It also reads past its end when
            # it has fewer than x.shape[1] elements.
            av = l_12_0(a, i, j)
            bv = l_12_0(b, j)  # 1d b -> b[j]; scalar -> b
            x[i, j] = (x[i, j] * av + bv) * bv
    return x


# Same three input sets as the explicit version; the outputs should match it.
tuple(implicit_indexing_mult_jit(x, *args) for args in (t1, t2, t3))

(array([[20., 24., 16., 12.],
        [ 6.,  6.,  6.,  6.],
        [12., 16., 16., 12.]]),
 array([[20., 24., 16., 12.],
        [ 8.,  8.,  8.,  8.],
        [20., 16., 12.,  8.]]),
 array([[15. , 16.5, 13.5, 12. ],
        [10.5, 10.5, 10.5, 10.5],
        [15. , 13.5, 12. , 10.5]]))

In [6]:
def l_213_d(x, i1=0, i2=0, i3=0, d=0):
    """Pure-Python version of the `l_213_d` lookup (drop order: i3, then i1, then i2).

    The signature matches the numba overload ``_l_213_d``, including `i3`. Uses as many of
    (i1, i2, i3) as the rank of `x` allows while keeping at least `d` dimensions:

    * rank >= 3 + d -> ``x[i1, i2, i3]``
    * rank == 2 + d -> ``x[i1, i2]``
    * rank == 1 + d -> ``x[i2]``  (e.g. a 1d `b` indexed by the column)
    * otherwise (including non-arrays) -> `x` unchanged.
    """
    if isinstance(x, np.ndarray):
        if x.ndim >= 3 + d:
            return x[i1, i2, i3]
        if x.ndim == 2 + d:
            return x[i1, i2]
        if x.ndim == 1 + d:
            return x[i2]
    return x

def l_321_d(x, i1=0, i2=0, i3=0, d=0):
    """Pure-Python "lower" lookup for the 321 drop order (drop i1, then i2, then i3).

    Follows the notebook's ``l_<digits>_d`` naming (digits list dropped indexes as a stack,
    last dropped first). The `i3` parameter is placed before `d`, as in ``l_213_d``:

    * rank >= 3 + d -> ``x[i1, i2, i3]``
    * rank == 2 + d -> ``x[i2, i3]``
    * rank == 1 + d -> ``x[i3]``
    * otherwise (including non-arrays) -> `x` unchanged.

    NOTE(review): no ``@overload(l_321_d)`` appears in this notebook section, so calling it from
    ``@nb.njit`` code presumably still fails to type. Add one mirroring ``_l_213_d`` if needed.
    """
    if isinstance(x, np.ndarray):
        if x.ndim >= 3 + d:
            return x[i1, i2, i3]
        if x.ndim == 2 + d:
            return x[i2, i3]
        if x.ndim == 1 + d:
            return x[i3]
    return x

@overload(l_213_d)
def _l_213_d(x, i1=0, i2=0, i3=0, d=0):
    """Numba typing hook for `l_213_d` (drop order: i3, then i1, then i2).

    rank >= 3+d -> x[i1, i2, i3]; rank == 2+d -> x[i1, i2]; rank == 1+d -> x[i2];
    otherwise `x` unchanged. `d` must be a compile-time literal. When it arrives as a plain
    integer type, the hook returns a stub calling ``nb.literally(d)`` so numba retypes the call.
    """
    def _identity(x, i1=0, i2=0, i3=0, d=0):
        return x

    if not isinstance(x, types.Array):
        return _identity

    if not isinstance(d, (nb.types.Literal, int)):
        if _verbs:
            print('Requesting literal value for d')
        return lambda x, i1=0, i2=0, i3=0, d=0: nb.literally(d)

    dv = d if type(d) is int else d.literal_value
    if _verbs:
        print('d is ', dv)

    if x.ndim >= 3 + dv:
        def _index3(x, i1=0, i2=0, i3=0, d=0):
            return x[i1, i2, i3]
        return _index3
    if x.ndim == 2 + dv:
        def _index12(x, i1=0, i2=0, i3=0, d=0):
            return x[i1, i2]
        return _index12
    if x.ndim == 1 + dv:
        def _index2(x, i1=0, i2=0, i3=0, d=0):
            return x[i2]
        return _index2
    return _identity

@nb.njit
def implicit_indexing_mult2_jit(x, a, b):
    """Compute ``(x * a + b) * b`` element by element, resolving `a` and `b` per element.

    :param x: A 2d array; the result is computed on a copy and `x` is not modified.
    :param a: A scalar, a 1d array with one factor per row (size ``x.shape[0]``),
        or a 2d array of shape ``x.shape``.
    :param b: A scalar, a 1d array with one value per column (size ``x.shape[1]``),
        or a 2d array of shape ``x.shape``.
    :return: A new array shaped like `x`.
    """
    out = x.copy()
    for r in range(out.shape[0]):
        for c in range(out.shape[1]):
            # l_12_d:  2d a -> a[r, c], 1d a -> a[r].
            # l_213_d: 2d b -> b[r, c], 1d b -> b[c].
            a_rc = l_12_d(a, r, c)
            b_rc = l_213_d(b, r, c)
            out[r, c] = (out[r, c] * a_rc + b_rc) * b_rc
    return out

t4 = (a, a)  # 2d a and 2d b: second row of the result should be all 2s
print(*(implicit_indexing_mult2_jit(x, *args) for args in (t1, t2, t3, t4)), sep='\n')

# Index bounds are not checked in njit code (numba's boundscheck is off by default), so be wary:
# b = a[1, :3] has only 3 elements but x has 4 columns, so b[3] reads past the end of the view.
# Here that memory is a[1, 3] (== 1), which is why the second row still comes out as all 2s.
t5 = (a, a[1, :3])
print('\n', implicit_indexing_mult2_jit(x, *t5))

[[20. 24. 16. 12.]
 [ 6.  6.  6.  6.]
 [12. 16. 16. 12.]]
[[20. 24. 16. 12.]
 [ 8.  8.  8.  8.]
 [20. 16. 12.  8.]]
[[15.  16.5 13.5 12. ]
 [10.5 10.5 10.5 10.5]
 [15.  13.5 12.  10.5]]
[[20. 24. 16. 12.]
 [ 2.  2.  2.  2.]
 [ 5. 16. 27. 32.]]

 [[ 9. 11.  7.  5.]
 [ 2.  2.  2.  2.]
 [ 5.  7.  7.  5.]]
