In [1]:
import tvm 
import numpy as np

## Describe Batchwise Computation

- Operators that have the same shape can be put together as the outputs of a single `tvm.compute` call when we want them to be scheduled together in the subsequent scheduling step.

In [2]:
# Symbolic extents; both inputs are declared with the same (m, n) shape
# so that they can be consumed by a single batched compute.
n = tvm.var("n")
m = tvm.var("m")
A0 = tvm.placeholder((m, n), name="A0")
A1 = tvm.placeholder((m, n), name="A1")

In [3]:
# fcompute returns a tuple, so the single operation "B" yields one output
# tensor per tuple element (B0 and B1).
B0, B1 = tvm.compute(
    (m, n),
    lambda i, j: (A0[i, j] + 2, A1[i, j] * 3),
    name="B",
)

In [6]:
# Lower the schedule and print the generated IR. B0 and B1 belong to the
# same operation, so scheduling B0.op covers both outputs.
s = tvm.create_schedule(B0.op)
print(tvm.lower(s, [A0, A1, B0, B1], simple_mode=True))

produce B {
  for (i, 0, m) {
    for (j, 0, n) {
      B.v0[((i*n) + j)] = (A0[((i*n) + j)] + 2.000000f)
      B.v1[((i*n) + j)] = (A1[((i*n) + j)]*3.000000f)
    }
  }
}



## Describe Reduction with Collaborative Inputs

- Sometimes a reduction operator needs multiple inputs that work together, e.g. `argmax`.
- During the reduction, `argmax` needs to compare the values of its operands and also keep track of the index of the winning operand. This can be expressed with `comm_reducer`.

In [7]:
# x and y are the operands of the reduction; each of them is a tuple of (index, value)
def fcombine(x, y):
    """Combine two partial argmax results.

    Both operands are indexed as (index, value): element 0 is the index and
    element 1 is the value being compared.

    Returns the (index, value) pair whose value is larger; on a tie the
    pair from ``x`` is kept.
    """
    # Evaluate the comparison once and reuse it for both tuple components.
    x_wins = x[1] >= y[1]
    best_index = tvm.select(x_wins, x[0], y[0])
    best_value = tvm.select(x_wins, x[1], y[1])
    return best_index, best_value

In [8]:
# The identity element also needs to be a tuple, so `fidentity` accepts
# two dtypes as inputs (one per tuple component).
def fidentity(t0, t1):
    """Return the identity element of the argmax reduction.

    t0 is the dtype of the index component and t1 the dtype of the value
    component. The identity is (-1, minimum value of t1).
    """
    initial_index = tvm.const(-1, dtype=t0)
    initial_value = tvm.min_value(t1)
    return initial_index, initial_value

In [9]:
# Build the reducer from the combine rule (fcombine) and its identity
# element (fidentity).
argmax = tvm.comm_reducer(fcombine, fidentity, name="argmax")

In [10]:
# Describe the reduction: for every row i, reduce over the column axis k
# and keep the (idx, val) pair selected by the argmax reducer.
m = tvm.var("m")
n = tvm.var("n")
idx = tvm.placeholder((m, n), dtype="int32", name="idx")
val = tvm.placeholder((m, n), dtype="int32", name="val")
k = tvm.reduce_axis((0, n), name="k")
T0, T1 = tvm.compute(
    (m, ),
    lambda i: argmax((idx[i, k], val[i, k]), axis=k),
    name="T",
)


In [11]:
# Lower the schedule and print the generated IR for the argmax reduction.
s = tvm.create_schedule(T0.op)
print(tvm.lower(s, [idx, val, T0, T1], simple_mode=True))

produce T {
  for (i, 0, m) {
    T.v0[i] = -1
    T.v1[i] = -2147483648
    for (k, 0, n) {
      T.v0[i] = tvm_if_then_else((T.v1[i] < val[((i*n) + k)]), idx[((i*n) + k)], T.v0[i])
      T.v1[i] = tvm_if_then_else((T.v1[i] < val[((i*n) + k)]), val[((i*n) + k)], T.v1[i])
    }
  }
}



## Schedule Operation with Tuple Inputs


- Although a single batch operation gives you multiple outputs, they can only be scheduled together, as one operation.

In [13]:
# Inputs: two placeholders with the same symbolic (m, n) shape.
n = tvm.var(name="n")
m = tvm.var(name="m")
A0 = tvm.placeholder(shape=(m, n), name="A0")
A1 = tvm.placeholder(shape=(m, n), name="A1")

# "B" is one operation with two outputs; C consumes only the first (B0).
B0, B1 = tvm.compute(
    shape=(m, n),
    fcompute=lambda i, j: (A0[i, j] + 2, A0[i, j] * 3),
    name="B",
)
C = tvm.compute(shape=(m, n), fcompute=lambda i, j: A1[i, j] + B0[i, j], name="C")

In [14]:
# Create a schedule for the operation that produces C.
s = tvm.create_schedule(C.op)

In [16]:
# Attach B0's stage to the first axis of C. B0 and B1 come from the same
# operation "B", so both outputs end up inside C's outer loop (see the
# lowered IR printed further down).
s[B0].compute_at(s[C], C.op.axis[0])

In [17]:
# Print the lowered IR: the whole "B" operation (both B.v0 and B.v1) is
# produced inside C's outer loop, even though C only reads B.v0.
print(tvm.lower(s, [A0, A1, C], simple_mode=True))

// attr [B.v0] storage_scope = "global"
allocate B.v0[float32 * 1 * n]
// attr [B.v1] storage_scope = "global"
allocate B.v1[float32 * 1 * n]
produce C {
  for (i, 0, m) {
    produce B {
      for (j, 0, n) {
        B.v0[j] = (A0[((i*n) + j)] + 2.000000f)
        B.v1[j] = (A0[((i*n) + j)]*3.000000f)
      }
    }
    for (j, 0, n) {
      C[((i*n) + j)] = (A1[((i*n) + j)] + B.v0[j])
    }
  }
}

