# Conditional Expression: `if-then-else`
:label:`ch_if_then_else`

Conditional expressions are supported in TVM through `te.if_then_else`. In this section, 
we introduce this expression by computing the lower triangle of a matrix as an example.


In [1]:
import tvm
from tvm import te
import numpy as np
import d2ltvm

In NumPy, we can easily use `np.tril` to obtain the lower triangle.


In [2]:
a = np.arange(12, dtype='float32').reshape((3, 4))
np.tril(a)

array([[ 0.,  0.,  0.,  0.],
       [ 4.,  5.,  0.,  0.],
       [ 8.,  9., 10.,  0.]], dtype=float32)
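
To make explicit what `np.tril` computes here, the following is a minimal loop-based sketch: element `(i, j)` is kept when `i >= j` and set to zero otherwise. The helper name `tril_reference` is hypothetical and used only for illustration; it is not part of NumPy or TVM.

```python
import numpy as np

def tril_reference(x):
    """Hypothetical helper: keep x[i, j] where i >= j, write 0 elsewhere."""
    out = np.zeros_like(x)
    rows, cols = x.shape
    for i in range(rows):
        for j in range(cols):
            # The element lies on or below the main diagonal.
            if i >= j:
                out[i, j] = x[i, j]
    return out

a = np.arange(12, dtype='float32').reshape((3, 4))
# The loop version should agree with np.tril element by element.
assert np.array_equal(tril_reference(a), np.tril(a))
```

The per-element condition `i >= j` is the only logic involved, which is what makes this computation a natural fit for a conditional expression.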

Now let's implement it in TVM with `te.if_then_else`. It accepts three arguments: the first is the condition; if the condition is true, the expression returns the second argument, and otherwise it returns the third.


In [3]:
n, m = te.var('n'), te.var('m')
A = te.placeholder((n, m))
B = te.compute(A.shape, lambda i, j: te.if_then_else(i >= j, A[i,j], 0.0))


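Verify the results. We print the lowered IR, then build and run the module on `a`; the output `b` should match the `np.tril(a)` result shown earlier.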


In [4]:
b = tvm.nd.array(np.empty_like(a))
s = te.create_schedule(B.op)
print(tvm.lower(s, [A, B], simple_mode=True))
mod = tvm.build(s, [A, B])
mod(tvm.nd.array(a), b)
b

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(placeholder: T.handle, compute: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        n, m = T.int32(), T.int32()
        placeholder_1 = T.match_buffer(placeholder, (n, m), strides=("stride", "stride"), buffer_type="auto")
        compute_1 = T.match_buffer(compute, (n, m), strides=("stride", "stride"), buffer_type="auto")
        for i, j in T.grid(n, m):
            compute_2 = T.Buffer((compute_1.strides[0] * n,), data=compute_1.data, buffer_type="auto")
            placeholder_2 = T.Buffer((placeholder_1.strides[0] * n,), data=placeholder_1.data, buffer_type="auto")
            compute_2[i * compute_1.strides[0] + j * compute_1.strides[1]] = T.if_then_else(j <= i, placeholder_2[i * placeholder_1.strides[0] + j * placeholder_1.strides[1]], T.float32(0.0))


<tvm.nd.NDArray shape=(3, 4), cpu(0)>
array([[ 0.,  0.,  0.,  0.],
       [ 4.,  5.,  0.,  0.],
       [ 8.,  9., 10.,  0.]], dtype=float32)
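
Since `n` and `m` are symbolic, the same condition should apply to matrices of any shape. As a rough NumPy-only analogue (not TVM code), the sketch below applies the same three-argument pattern (condition, value if true, value if false) with `np.where` and compares it with `np.tril` on a few shapes. The helper name `select_lower` and the chosen shapes are assumptions made for illustration.

```python
import numpy as np

def select_lower(x):
    """Hypothetical helper: elementwise select mirroring the i >= j condition."""
    # i and j hold the row and column index of every element of x.
    i, j = np.indices(x.shape)
    # Condition, value if true, value if false.
    return np.where(i >= j, x, 0.0).astype(x.dtype)

# Check wide, tall, and square matrices against np.tril.
for shape in [(3, 4), (4, 3), (5, 5)]:
    x = np.random.normal(size=shape).astype('float32')
    np.testing.assert_array_equal(select_lower(x), np.tril(x))
```

This only checks the selection logic in NumPy; it does not exercise the compiled TVM module.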

## Summary

- We can use `te.if_then_else` to express an if-then-else conditional inside a computation.
