In [9]:
import torch
import torch.nn as nn
import math, time
import sys

sys.path.append('/home/boyuan/Faith-NNVerificationCompiler/')

epsilon = 1e-12

from HandTunedKernels.kernel_test.forward_test_bound import Bounds
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size=1;length=2;dim_in=1024;dim_out=1024;dim_y_out=1024

import tvm
import tvm.testing
from tvm import te
import numpy
import timeit
dtype="float32"
target="llvm"
dev = tvm.device(target, 0)

In [10]:
def mamtul_template(batch_size, length, dim_in, dim_out, dim_y_out):
    ibound = te.placeholder((batch_size, length, dim_in*2+2, dim_out), name='ibound')
    w = te.placeholder((dim_out, dim_y_out), name='w')

    w_pos = te.compute(
        w.shape, 
        lambda i,j: te.if_then_else(w[i,j]>0, w[i,j], 0.),
        name='w_pos'
    )
    w_neg = te.compute(w.shape, lambda i,j: w[i,j]-w_pos[i,j], name='w_neg')

    w_concate = te.compute(
        [dim_out*2, dim_y_out], 
        lambda dout, dout_y: 
        te.if_then_else(
            dout < dim_out, 
            w_pos[dout, dout_y],
            w_neg[dout-dim_out, dout_y]
        ),
        name='w_concate'
    )

    ibound_concate = te.compute(
        [batch_size, length, dim_in, dim_out*2],
        lambda b, l, din, dout:
        te.if_then_else(
            dout < dim_out,
            ibound[b, l, 2+din, dout],
            ibound[b, l, 2+dim_in+din, dout]
        ),
        name='ibound_concate'
    )

    dout = te.reduce_axis((0, dim_out*2), "dout")
    obound_lw = te.compute(
        [batch_size, length, dim_in, dim_y_out],
        lambda b, l, din, dout_y:
        te.sum(ibound_concate[b, l, din, dout] * w_concate[dout, dout_y], axis=dout),
        name='obound_lw'
    )


    s = te.create_schedule([obound_lw.op])
    func = tvm.build(s, [ibound, w, obound_lw], target=target, name="certified_matmul")

    input_bound_raw = numpy.random.rand(batch_size, length, 2+2*dim_in, dim_out).astype(dtype) 
    input_bound_tvm = tvm.nd.array(input_bound_raw, dev)

    input_w_raw = numpy.random.rand(dim_out, dim_y_out).astype(dtype)
    input_w_tvm = tvm.nd.array(input_w_raw, dev)

    output_bound_lw_raw = numpy.zeros((batch_size, length, dim_in, dim_y_out)).astype(dtype)
    output_bound_lw_tvm = tvm.nd.array(output_bound_lw_raw, dev)

    func(input_bound_tvm, input_w_tvm, output_bound_lw_tvm)
    evaluator = func.time_evaluator(func.entry_name, dev, number=1)
    baseline_latency = evaluator(input_bound_tvm, input_w_tvm, output_bound_lw_tvm).mean

    bn = 32
    kfactor = 4
    s = te.create_schedule(obound_lw.op)

    # Blocking by loop tiling
    mo, no, mi, ni = s[obound_lw].tile(obound_lw.op.axis[2], obound_lw.op.axis[3], bn, bn)
    (kaxis,) = s[obound_lw].op.reduce_axis
    ko, ki = s[obound_lw].split(kaxis, factor=kfactor)

    # Host reduction domain outside the blocking loop
    s[obound_lw].reorder(mo, no, ko, ki, mi, ni)

    func = tvm.build(s, [ibound, w, obound_lw], target=target, name="certified_matmul")
    func(input_bound_tvm, input_w_tvm, output_bound_lw_tvm)
    evaluator = func.time_evaluator(func.entry_name, dev, number=1)
    loop_tiling_latency = evaluator(input_bound_tvm, input_w_tvm, output_bound_lw_tvm).mean
    
    print("%d\t%d\t%d\t%f\t%f" % (batch_size, length, dim_in, baseline_latency, loop_tiling_latency))

In [11]:
batch_size = 1
print("batch\tlength\tdim_in\ttemplate(raw)\ttemplate(tuned)")
for length in [2,4,8,16,32,64,128]:
    for dim_in in [64,128,256,512,1024]: # matmul
        if length >=8 and dim_in >=512:
            pass
        elif length >=64 and dim_in >=128:
            pass
        else:
            mamtul_template(batch_size, length, dim_in, dim_out, dim_y_out)

batch	length	dim_in	template(raw)	template(tuned)
1	2	64	1.307692	0.073274
1	2	128	2.568111	0.161789
1	2	256	5.304649	0.320858
1	2	512	10.823001	0.692673
1	2	1024	20.684462	1.112095
1	4	64	2.520366	0.161275
1	4	128	5.043037	0.319600
1	4	256	10.066399	0.635296
1	4	512	20.691795	1.275573
1	4	1024	56.506438	2.609836
1	8	64	6.823598	0.564768
1	8	128	12.274440	1.162039
1	8	256	26.274846	1.431888
1	16	64	13.702192	0.669615
1	16	128	28.436459	1.431849
1	16	256	41.498004	2.229562
1	32	64	20.670114	1.100166
1	32	128	41.519295	2.223166


In [None]:
# ================ Below are expired ===========

In [11]:
lb = torch.rand(1,length,dim_out).to(device)
ub = lb + torch.rand(1,length,dim_out).to(device)
lw = torch.rand(1,length,dim_in,dim_out).to(device) - 0.5
uw = torch.rand(1,length,dim_in,dim_out).to(device) - 0.5
bound = Bounds(p=2,eps=0.5,lw=lw,lb=lb,uw=uw,ub=ub)
W = torch.rand(dim_y_out, dim_out).to(device) - 0.5
l, u = bound.concretize()
bound_output = bound.matmul(W)

In [96]:
input_bound_lb_raw = numpy.random.rand(batch_size, length, dim_out).astype(dtype)
input_bound_lb_tvm = tvm.nd.array(input_bound_lb_raw, dev)

input_bound_ub_raw = input_bound_lb_raw + numpy.random.rand(batch_size, length, dim_out).astype(dtype),
input_bound_ub_tvm = tvm.nd.array(input_bound_ub_raw, dev)

input_bound_lw_raw = numpy.random.rand(batch_size, length, dim_in, dim_out).astype(dtype) - 0.5
input_bound_lw_tvm = tvm.nd.array(input_bound_lw_raw, dev)

input_bound_uw_raw = numpy.random.rand(batch_size, length, dim_in, dim_out).astype(dtype) - 0.5
input_bound_uw_tvm = tvm.nd.array(input_bound_uw_raw, dev)

input_w_raw = numpy.random.rand(dim_y_out, dim_out).astype(dtype)
input_w_tvm = tvm.nd.array(input_w_raw, dev)

pos_mask_raw = numpy.zeros((dim_y_out, dim_out)).astype(dtype)
pos_mask_tvm = tvm.nd.array(pos_mask_raw, dev)

w_pos_raw = numpy.zeros((dim_y_out, dim_out)).astype(dtype)
w_pos_tvm =  tvm.nd.array(w_pos_raw, dev)

output_bound_lw_raw = numpy.random.rand(batch_size, length, dim_in, dim_y_out).astype(dtype) - 0.5
output_bound_lw_tvm = tvm.nd.array(output_bound_lw_raw, dev)

In [16]:
ibound_lb = te.placeholder((batch_size, length, dim_out))
ibound_ub = te.placeholder((batch_size, length, dim_out))
ibound_lw = te.placeholder((batch_size, length, dim_in, dim_out))
ibound_uw = te.placeholder((batch_size, length, dim_in, dim_out))
w = te.placeholder((dim_out, dim_y_out))
obound_lb = te.placeholder((batch_size, length, dim_y_out))
obound_ub = te.placeholder((batch_size, length, dim_y_out))
obound_lw = te.placeholder((batch_size, length, dim_in, dim_y_out))
obound_uw = te.placeholder((batch_size, length, dim_in, dim_y_out))

pos_mask = te.compute(w.shape, 
                      lambda i,j: te.if_then_else(w[i,j]>0, 1., 0.)
                     )

# s = te.create_schedule(pos_mask.op)
# func = tvm.build(s, [w,pos_mask], target=target, name="foo")
# func(input_w_tvm, pos_mask_tvm)
# evaluator = func.time_evaluator(func.entry_name, dev, number=1)
# print("Baseline: %f" % evaluator(input_w_tvm, pos_mask_tvm).mean)
# print(tvm.lower(s, [w, pos_mask], simple_mode=True))

w_pos = te.compute(w.shape, lambda i,j: w[i,j]*pos_mask[i,j])
# w_pos = te.compute(w.shape, lambda i,j: function_lambda(w, pos_mask, i, j))
w_neg = te.compute(w.shape, lambda i,j: w[i,j]-w_pos[i,j])

# s = te.create_schedule(w_pos.op)
# func = tvm.build(s, [w, w_pos], target=target, name="foo")
# func(input_w_tvm, w_pos_tvm)
# evaluator = func.time_evaluator(func.entry_name, dev, number=1)
# print("Baseline: %f" % evaluator(input_w_tvm, w_pos_tvm).mean)
# print(tvm.lower(s, [w, w_pos], simple_mode=True))


dout = te.reduce_axis((0, dim_out), "dout")
res_a = te.compute(obound_lw.shape,
                 lambda b,l,din,dout_y: 
                 te.sum(ibound_lw[b,l,din,dout]*w_pos[dout,dout_y], axis=dout)#+te.sum(ibound_uw[b,l,din,dout]*w_neg[dout_y,dout], axis=dout)
                )

print(res_a.shape)

# s = te.create_schedule(res_a.op)
# func = tvm.build(s, [w, ibound_lw, res_a], target=target, name="foo")
# func(input_w_tvm, input_bound_lw_tvm, output_bound_lw_tvm)
# evaluator = func.time_evaluator(func.entry_name, dev, number=1)
# print("Baseline: %f" % evaluator(input_w_tvm, input_bound_lw_tvm, output_bound_lw_tvm).mean)
# print(tvm.lower(s, [w, ibound_lw, res_a], simple_mode=True))


res_b = te.compute(obound_lw.shape,
                   lambda b, l, din, dout_y:
                   te.sum(ibound_uw[b,l,din,dout]*w_neg[dout,dout_y], axis=dout)
                  )
# print(res_b.shape)
# s = te.create_schedule(res_b.op)
# func = tvm.build(s, [w, ibound_uw, res_b], target=target, name="foo")
# func(input_w_tvm, input_bound_lw_tvm, output_bound_lw_tvm)
# evaluator = func.time_evaluator(func.entry_name, dev, number=1)
# print("Baseline: %f" % evaluator(input_w_tvm, input_bound_lw_tvm, output_bound_lw_tvm).mean)
# print(tvm.lower(s, [w, ibound_uw, res_b], simple_mode=True))


res = te.compute(obound_lw.shape, lambda b,l,din,dout_y:res_a[b,l,din,dout_y]+res_b[b,l,din,dout_y])
print(ibound_uw.shape, ibound_lw.shape, w.shape, res.shape)
s = te.create_schedule(res.op)
# print()
# print(tvm.lower(s, [ibound_lw,ibound_uw, res, w], simple_mode=True))

func = tvm.build(s, [ibound_lw, ibound_uw, w, res], target=target, name="certified_matmul")


[1, 2, 64, 64]
[1, 2, 64, 64] [1, 2, 64, 64] [64, 64] [1, 2, 64, 64]


TVMError: Traceback (most recent call last):
  6: TVMFuncCall
  5: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void> const&, bool)>::AssignTypedLambda<tvm::{lambda(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void> const&, bool)#5}>(tvm::{lambda(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void> const&, bool)#5}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  4: tvm::LowerSchedule(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&, bool)
  3: tvm::ScheduleToModule(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&)
  2: tvm::te::InferBound(tvm::te::Schedule const&)
  1: tvm::te::InferRootBound(tvm::te::Stage const&, tvm::te::GraphContext const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > >*)
  0: tvm::te::BaseComputeOpNode::GatherBound(tvm::te::Operation const&, std::unordered_map<tvm::te::Tensor, tvm::te::TensorDom, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::te::TensorDom> > > const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > >*) const
  File "/home/tianqi_tang/tvm/src/te/operation/compute_op.cc", line 256
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (!out_dom_map->count(this->reduce_axis[i])) is false: 