In [1]:
import tvm
from tvm import te
import numpy as np

# Broadcast

In [11]:
# 2 dimension
def get_broadcast_abc(shape1, shape2):
    # Validation
    assert len(shape1) == 2 and len(shape2) == 2 , "Must to be 2-D"
    for i in range(len(shape1)):
        assert shape1[i] == shape2[i] or shape1[i] == 1 or shape2[i] ==1 , "Broadcast shape error"
    A = te.placeholder(shape1, dtype='float32',name='a')
    B = te.placeholder(shape2, dtype='float32',name='b')
    m = shape1[0] if shape2[0] == 1 else shape2[0]
    n = shape1[1] if shape2[1] == 1 else shape2[1]
    f = lambda i,j : A[ 0 if shape1[0] == 1 else i][0 if shape1[1] == 1 else j] + \
    B[0 if shape2[0] ==1 else i][0 if shape2[1] ==1 else j]
    C = te.compute((m,n), f, name='c')
    return A, B, C

In [13]:
m, n = [te.var(name) for name in ('m', 'n')]
shape1 = (m, 1)
shape2 = (m, n)
A, B, C = get_broadcast_abc(shape1, shape2)

In [15]:
s = te.create_schedule(C.op)
tvm.lower(s, [A,B,C], simple_mode=True)

IRModuleNode( {GlobalVar(main): PrimFunc([a, b, c]) attrs={"global_symbol": "main", "tir.noalias": (bool)1} {
  for (i, 0, m) {
    for (j, 0, n) {
      c[((i*stride) + (j*stride))] = (a[(i*stride)] + b[((i*stride) + (j*stride))])
    }
  }
}
})

In [37]:
a = tvm.nd.array(np.arange(3, dtype='float32').reshape(3,1))
b = tvm.nd.array(np.arange(4, dtype='float32').reshape((1,4)))
c = tvm.nd.array(np.empty((3,4), dtype='float32'))
mod = tvm.build(s, [A, B, C])
mod(a, b, c)
a, b, c

(<tvm.nd.NDArray shape=(3, 1), cpu(0)>
 array([[0.],
        [1.],
        [2.]], dtype=float32),
 <tvm.nd.NDArray shape=(1, 4), cpu(0)>
 array([[0., 1., 2., 3.]], dtype=float32),
 <tvm.nd.NDArray shape=(3, 4), cpu(0)>
 array([[0., 1., 2., 3.],
        [1., 2., 3., 4.],
        [2., 3., 4., 5.]], dtype=float32))