In [1]:
import tvm
import topi
import torch
from tvm import autotvm
import sys
sys.path.append("../..")

from fast3d.autotuning import auto_tune, the_best_config_model
from fast3d.get_data import *
from fast3d.schedules import schedule_direct_3d_cuda, schedule_conv3d_transpose_nchw_cuda, schedule_conv3d_nchw_cuda
from fast3d.utils import _grad_input_padding,return_log_name,reshape_inp_weight_shape, conv3d_ZHW_size

### How to use it
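Define a `data_dict` describing the conv3d layers you want to autotune. Each entry needs the keys `inshape`, `kershape`, `outshape`, `stride`, `padding`, `dilation`, `groups` and `bias`, in that order.

After autotuning, copy the generated log files to the "auto_log" folder.

Remember to update `config.py` accordingly.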

In [2]:
# Conv3d layers to autotune, keyed by layer id.
# Tensor shapes appear to follow (N, C, D, H, W) and kernels (O, I, kD, kH, kW);
# every outshape must equal the conv output of inshape/kershape/stride/padding.
# NOTE: entries are consumed via ``val.values()`` (e.g. ``return_log_name``),
# so keep the key order: inshape, kershape, outshape, stride, padding,
# dilation, groups, bias.
data_dict = {
    0: {
        "inshape": (1, 1, 64, 128, 128),
        "kershape": (24, 1, 1, 5, 5),
        "outshape": (1, 24, 64, 64, 64),
        "stride": (1, 2, 2),
        "padding": (0, 2, 2),
        "dilation": 1,
        "groups": 1,
        "bias": None,
    },
    1: {
        "inshape": (1, 24, 64, 128, 128),
        "kershape": (12, 24, 1, 1, 1),
        "outshape": (1, 12, 64, 128, 128),
        "stride": (1, 1, 1),
        "padding": (0, 0, 0),
        "dilation": 1,
        "groups": 1,
        "bias": None,
    },
    2: {
        "inshape": (1, 32, 64, 128, 128),
        "kershape": (12, 32, 1, 1, 1),
        "outshape": (1, 12, 64, 128, 128),
        "stride": (1, 1, 1),
        "padding": (0, 0, 0),
        "dilation": 1,
        "groups": 1,
        "bias": None,
    },
    3: {
        "inshape": (1, 24, 64, 64, 64),
        "kershape": (24, 24, 3, 3, 3),
        "outshape": (1, 24, 64, 64, 64),
        "stride": (1, 1, 1),
        "padding": (1, 1, 1),
        "dilation": 1,
        "groups": 1,
        "bias": None,
    },
    4: {
        "inshape": (1, 64, 32, 32, 32),
        "kershape": (64, 64, 3, 3, 3),
        "outshape": (1, 64, 32, 32, 32),
        "stride": (1, 1, 1),
        "padding": (1, 1, 1),
        "dilation": 1,
        "groups": 1,
        "bias": None,
    },
    5: {
        "inshape": (1, 192, 16, 16, 16),
        "kershape": (192, 192, 3, 3, 3),
        "outshape": (1, 192, 16, 16, 16),
        "stride": (1, 1, 1),
        "padding": (1, 1, 1),
        "dilation": 1,
        "groups": 1,
        "bias": None,
    },
    6: {
        "inshape": (1, 192, 8, 8, 8),
        "kershape": (192, 192, 3, 3, 3),
        "outshape": (1, 192, 8, 8, 8),
        "stride": (1, 1, 1),
        "padding": (1, 1, 1),
        "dilation": 1,
        "groups": 1,
        "bias": None,
    },
    7: {
        "inshape": (1, 384, 8, 8, 8),
        "kershape": (384, 384, 3, 3, 3),
        "outshape": (1, 384, 8, 8, 8),
        "stride": (1, 1, 1),
        "padding": (1, 1, 1),
        "dilation": 1,
        "groups": 1,
        "bias": None,
    },
    8: {
        "inshape": (1, 24, 32, 32, 32),
        "kershape": (64, 24, 1, 1, 1),
        "outshape": (1, 64, 32, 32, 32),
        "stride": (1, 1, 1),
        "padding": (0, 0, 0),
        "dilation": 1,
        "groups": 1,
        "bias": None,
    },
    9: {
        "inshape": (1, 64, 32, 32, 32),
        "kershape": (24, 64, 1, 1, 1),
        "outshape": (1, 24, 32, 32, 32),
        "stride": (1, 1, 1),
        "padding": (0, 0, 0),
        "dilation": 1,
        "groups": 1,
        "bias": None,
    },
    10: {
        "inshape": (1, 64, 16, 16, 16),
        "kershape": (192, 64, 1, 1, 1),
        "outshape": (1, 192, 16, 16, 16),
        "stride": (1, 1, 1),
        "padding": (0, 0, 0),
        "dilation": 1,
        "groups": 1,
        "bias": None,
    },
    11: {
        "inshape": (1, 192, 16, 16, 16),
        "kershape": (64, 192, 1, 1, 1),
        "outshape": (1, 64, 16, 16, 16),
        "stride": (1, 1, 1),
        "padding": (0, 0, 0),
        "dilation": 1,
        "groups": 1,
        "bias": None,
    },
    12: {
        "inshape": (1, 192, 8, 8, 8),
        "kershape": (384, 192, 1, 1, 1),
        "outshape": (1, 384, 8, 8, 8),
        "stride": (1, 1, 1),
        "padding": (0, 0, 0),
        "dilation": 1,
        "groups": 1,
        "bias": None,
    },
    13: {
        "inshape": (1, 384, 8, 8, 8),
        "kershape": (192, 384, 1, 1, 1),
        # A 1x1x1 kernel with stride 1 and no padding keeps D/H/W at 8
        # (this entry previously, and incorrectly, declared 16x16x16).
        "outshape": (1, 192, 8, 8, 8),
        "stride": (1, 1, 1),
        "padding": (0, 0, 0),
        "dilation": 1,
        "groups": 1,
        "bias": None,
    },
}

In [3]:
def print_args(log_name, func, para, env='cuda'):
    """Print the argument buffers that a tuned template produces.

    ``func(*para)`` is evaluated inside the AutoTVM best-history context
    loaded from ``log_name`` and the target context built from ``env``.
    The second element of its ``(schedule, arg_bufs)`` result is printed.

    Args:
        log_name: AutoTVM log file passed to ``autotvm.apply_history_best``.
        func: template function that returns a ``(schedule, arg_bufs)`` pair.
        para: positional arguments forwarded to ``func``.
        env: target string passed to ``tvm.target.create`` (default ``'cuda'``).
    """
    # A multi-item ``with`` enters the contexts left to right, just like the
    # nested form, so the target is created after the history context is active.
    with autotvm.apply_history_best(log_name), tvm.target.create(env):
        _schedule, buffers = func(*para)
        print(buffers)

In [4]:
def test_correct(fw_log, bwd_inp_g_log, bwd_wei_g_log, inshape, kershape, outshape, stride, padding, dilation, groups, bias, schedule_direct_3d_cuda, schedule_conv3d_transpose_nchw_cuda, schedule_conv3d_nchw_cuda):
    """Run the best tuned forward, input-grad and weight-grad kernels once.

    Each kernel is built from its log via ``the_best_config_model`` and then
    invoked on random/zero buffers from ``get_rand``/``get_zero``.

    NOTE(review): despite the name, no result is compared against a
    reference; this only checks that the tuned kernels build and run.
    ``bias`` is accepted but not used here.

    Args:
        fw_log, bwd_inp_g_log, bwd_wei_g_log: AutoTVM logs for the forward,
            input-gradient and weight-gradient kernels.
        inshape, kershape, outshape: layer input, kernel and output shapes.
        stride, padding: per-axis (D, H, W) tuples.
        dilation, groups, bias: remaining layer parameters.
        schedule_*: template functions used to rebuild each kernel.
    """
    # Buffers shared by the forward pass and the input-gradient pass.
    x = get_rand(inshape)
    w = get_rand(kershape)
    y = get_zero(outshape)
    dx = get_zero(inshape)
    out_pad = _grad_input_padding(outshape, inshape, stride, padding, kershape[2:])

    # Forward convolution.
    forward_args = (inshape, kershape, stride, padding, dilation)
    forward = the_best_config_model(fw_log, schedule_direct_3d_cuda, forward_args)
    forward(x, w, y)

    # Input gradient, computed with the transposed-conv template.
    grad_in_args = (outshape, kershape, stride, padding, out_pad, 'float32')
    grad_in = the_best_config_model(bwd_inp_g_log, schedule_conv3d_transpose_nchw_cuda, grad_in_args)
    grad_in(y, w, dx)

    # Weight gradient, computed as a conv between the reshaped input and
    # the reshaped output gradient.
    x_t_shape = reshape_inp_weight_shape(inshape)
    dy_t_shape = reshape_inp_weight_shape(outshape)
    grad_w_args = (x_t_shape, dy_t_shape, dilation, padding, stride, groups)
    grad_w = the_best_config_model(bwd_wei_g_log, schedule_conv3d_nchw_cuda, grad_w_args)

    x_t = get_rand(x_t_shape)
    dy_t = get_rand(dy_t_shape)

    # Spatial extent of the weight-gradient result along D, H and W (in order).
    kz, kh, kw = [
        conv3d_ZHW_size(x_t_shape[axis + 2], dilation, padding[axis], dy_t_shape[axis + 2], stride[axis])
        for axis in range(3)
    ]
    dw_t = get_rand((kershape[1], kershape[0], kz, kh, kw))
    grad_w(x_t, dy_t, dw_t)

In [5]:
def get_index(infile):
    """Return the config index of the best record in an AutoTVM log.

    The log is loaded with ``autotvm.apply_history_best``. The function takes
    the first entry of its ``best_by_targetkey`` mapping and returns that
    entry's ``[0].config.index``.

    Args:
        infile: path to an AutoTVM tuning log.

    Returns:
        The ``config.index`` of the selected best record.

    Raises:
        ValueError: if the log yields no usable records. Previously this case
            surfaced as a confusing ``UnboundLocalError`` from the
            loop-and-break idiom.
    """
    best_by_key = autotvm.apply_history_best(infile).best_by_targetkey
    # NOTE(review): assumes one tuning task per log file, so every target key
    # presumably maps to the same best record and taking the first is enough.
    best = next(iter(best_by_key.values()), None)
    if best is None:
        raise ValueError("no usable AutoTVM records found in %s" % infile)
    return best[0].config.index

def write2best_log(infile, outfile="best.log"):
    """Append the record(s) of the best config in ``infile`` to ``outfile``.

    The best config index comes from ``get_index``. Each record line of
    ``infile`` is decoded with ``autotvm.record.decode``, and the raw line is
    appended when its ``config.index`` equals that index exactly.

    The previous implementation matched ``str(index)`` as a substring of the
    whole line. That also copied unrelated records whose shapes, costs or
    timestamps happened to contain those digits (e.g. index 12 vs "128").

    Args:
        infile: AutoTVM log to read.
        outfile: file to append to (default ``"best.log"``); it is opened in
            append mode, so re-running adds duplicate entries.
    """
    index = get_index(infile)
    with open(infile, encoding="utf-8") as src, open(outfile, "a", encoding="utf-8") as dst:
        for line in src:
            row = line.strip()
            # Skip blank lines and comments.
            if not row or row.startswith("#"):
                continue
            record = autotvm.record.decode(row)
            if record is None:  # some TVM versions return None for unsupported rows
                continue
            if record[0].config.index == index:
                dst.write(line if line.endswith("\n") else line + "\n")

In [16]:

def autotuning_conv3d(data_dict, measure_option, n_trial, autotune_range = "all"):
    """Autotune forward, input-grad and weight-grad conv3d kernels per layer.

    For every selected layer this function:

    1. tunes the three kernels with ``auto_tune`` (one log each, named by
       ``return_log_name``);
    2. runs them once through ``test_correct``;
    3. appends their best records to ``best.log`` via ``write2best_log``.

    Args:
        data_dict: mapping layer id -> dict with keys ``inshape``,
            ``kershape``, ``outshape``, ``stride``, ``padding``,
            ``dilation``, ``groups``, ``bias``.
        measure_option: AutoTVM measure option forwarded to ``auto_tune``.
        n_trial: tuning trial budget forwarded to ``auto_tune``.
        autotune_range: ``"all"``, a single layer id, or a container of layer
            ids to tune. A bare int previously raised ``TypeError``.
    """
    if isinstance(autotune_range, int):
        autotune_range = (autotune_range,)

    for k, val in data_dict.items():
        if autotune_range != 'all' and k not in autotune_range:
            continue

        print("Autotuning the ", k, " layer")

        # Look fields up by name so a reordered layer dict cannot silently
        # mis-assign them (the old code unpacked ``val.values()`` positionally).
        inshape, kershape, outshape = val["inshape"], val["kershape"], val["outshape"]
        stride, padding, dilation = val["stride"], val["padding"], val["dilation"]
        groups, bias = val["groups"], val["bias"]

        # Three log files per layer: forward, input grad, weight grad.
        fw_log, bwd_inp_g_log, bwd_wei_g_log = return_log_name(val.values())

        # Forward.
        auto_tune(schedule_direct_3d_cuda, (inshape, kershape, stride, padding, dilation), fw_log, n_trial, measure_option)

        # Input gradient.
        output_padding = _grad_input_padding(outshape, inshape, stride, padding, kershape[2:])
        auto_tune(schedule_conv3d_transpose_nchw_cuda, (outshape, kershape, stride, padding, output_padding, 'float32'), bwd_inp_g_log, n_trial, measure_option)

        # Weight gradient.
        inp_rshape = reshape_inp_weight_shape(inshape)
        grad_rshape = reshape_inp_weight_shape(outshape)
        auto_tune(schedule_conv3d_nchw_cuda, (inp_rshape, grad_rshape, dilation, padding, stride, groups), bwd_wei_g_log, n_trial, measure_option)

        # Smoke-run the tuned kernels.
        test_correct(fw_log, bwd_inp_g_log, bwd_wei_g_log,
                     inshape, kershape, outshape, stride, padding, dilation, groups, bias,
                     schedule_direct_3d_cuda, schedule_conv3d_transpose_nchw_cuda, schedule_conv3d_nchw_cuda)

        # Append the best config of each kernel to best.log.
        write2best_log(fw_log)
        write2best_log(bwd_inp_g_log)
        write2best_log(bwd_wei_g_log)

In [17]:
# Trial budget per kernel; the log below reports "Progress: (800/800)".
n_trial = 800

# Build candidates and time them on the local machine
# (LocalRunner: repeat=2, timeout=4).
measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(repeat=2, timeout=4),
)

In [None]:
# Tune every layer; pass a list such as autotune_range=[3, 5] to tune a subset.
autotuning_conv3d(
    data_dict=data_dict,
    measure_option=measure_option,
    n_trial=n_trial,
    autotune_range="all",
)

In [None]:
Autotuning the  0  layer
 Current/Best: 2415.93/2582.34 GFLOPS | Progress: (800/800) | 2953.19 s Done.
 Current/Best:  155.52/ 218.30 GFLOPS | Progress: (800/800) | 2433.96 s Done.
 Current/Best:   86.82/ 178.83 GFLOPS | Progress: (800/800) | 3117.65 s Done.
Autotuning the  1  layer
 Current/Best: 1421.88/1509.51 GFLOPS | Progress: (800/800) | 2550.92 s Done.
 Current/Best: 2952.57/2994.98 GFLOPS | Progress: (800/800) | 3023.75 s Done.
 Current/Best:   36.76/  44.35 GFLOPS | Progress: (800/800) | 2674.58 s Done.
Autotuning the  2  layer
 Current/Best: 1602.93/1655.66 GFLOPS | Progress: (800/800) | 3044.40 s Done.
 Current/Best: 4261.29/4361.13 GFLOPS | Progress: (800/800) | 2765.42 s Done.
 Current/Best:   39.25/  60.49 GFLOPS | Progress: (800/800) | 2656.35 s Done.
Autotuning the  3  layer
 Current/Best: 5707.64/5811.98 GFLOPS | Progress: (800/800) | 3313.88 s Done.
 Current/Best: 6711.03/7495.22 GFLOPS | Progress: (800/800) | 3412.51 s Done.
 Current/Best:  951.74/1479.39 GFLOPS | Progress: (800/800) | 3568.30 s Done.
Autotuning the  4  layer
 Current/Best: 2679.93/7148.30 GFLOPS | Progress: (800/800) | 3189.32 s Done.
 Current/Best: 6958.82/7565.79 GFLOPS | Progress: (800/800) | 3287.95 s Done.
 Current/Best: 1232.42/3567.63 GFLOPS | Progress: (800/800) | 3048.67 s Done.
Autotuning the  5  layer
 Current/Best: 5739.71/7246.40 GFLOPS | Progress: (800/800) | 3337.61 s Done.
 Current/Best: 3460.18/5630.15 GFLOPS | Progress: (800/800) | 3105.91 s Done.
 Current/Best: 5797.22/7418.22 GFLOPS | Progress: (756/800) | 3169.93 s
Autotuning the  6  layer
 Current/Best: 2054.31/4404.14 GFLOPS | Progress: (800/800) | 3194.88 s Done.
 Current/Best: 3057.28/3420.56 GFLOPS | Progress: (800/800) | 3459.32 s Done.
 Current/Best: 3021.82/5116.12 GFLOPS | Progress: (756/800) | 3131.69 s