In [2]:
import time

import torch

from replknet import *
from convert_opt_conv import *

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def is_lowest(module):
    """Return True when ``module`` is a leaf, i.e. it has no child modules.

    Args:
        module: Object exposing a ``children()`` iterator
            (e.g. a ``torch.nn.Module``).

    Returns:
        bool: ``True`` if ``module.children()`` yields nothing, ``False`` if
        it yields at least one child.
    """
    # Only an exhausted iterator means "no children". Catching everything
    # (bare ``except:``) would also hide unrelated errors and
    # KeyboardInterrupt, misreporting broken objects as leaf modules.
    try:
        next(module.children())
    except StopIteration:
        return True
    return False

def dynamic_fit_conv(kernel_size, input_width, input_height, input_channel=3, output_channel=2):
    """Dynamically decide the faster convolution operator for the given parameters.

    A ``torch.nn.Conv2d`` and an ``FFTConv2d`` are built with the same random
    weights and bias, run on the same random input (batch size 3), and their
    average forward time is compared.

    Args:
        kernel_size: Side length of the (square) convolution kernel.
        input_width: Size of the third dimension of the simulated input.
        input_height: Size of the fourth dimension of the simulated input.
        input_channel: Number of input channels.
        output_channel: Number of output channels.

    Returns:
        int: ``0`` if ``torch.nn.Conv2d`` was strictly faster, ``1`` otherwise
        (``FFTConv2d`` is chosen on ties).

    Note:
        Only kernel size, spatial size and channel counts are modelled; both
        candidates use their default stride, padding, dilation and groups.
    """
    # Simulated input data with a batch size of 3.
    signal = torch.randn(3, input_channel, input_width, input_height)
    kernel = torch.randn(output_channel, input_channel, kernel_size, kernel_size)
    bias = torch.randn(output_channel)

    # Candidates: Conv2d and FFTConv2d, sharing identical parameters so that
    # only the operator implementation differs.
    test_torch_conv = torch.nn.Conv2d(input_channel, output_channel, kernel_size, bias=True)
    test_torch_conv.weight = torch.nn.Parameter(kernel)
    test_torch_conv.bias = torch.nn.Parameter(bias)

    test_fft_conv = FFTConv2d(input_channel, output_channel, kernel_size, bias=True)
    test_fft_conv.weight = torch.nn.Parameter(kernel)
    test_fft_conv.bias = torch.nn.Parameter(bias)

    # Total forward passes per candidate; the first one is an untimed warm-up.
    iters = 10
    timed_runs = iters - 1

    def _mean_forward_ms(layer):
        """Average forward time of ``layer`` on ``signal`` in milliseconds."""
        layer(signal)  # warm-up pass, excluded from the measurement
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(timed_runs):
            layer(signal)
        elapsed = time.perf_counter() - start
        # Divide by the number of passes actually timed.
        return elapsed / timed_runs * 1000

    # Forward-only benchmark: skip autograd bookkeeping for both candidates.
    with torch.no_grad():
        torch_time = _mean_forward_ms(test_torch_conv)
        fft_time = _mean_forward_ms(test_fft_conv)

    # 0: torch conv is better; 1: fft conv is better
    return 0 if torch_time < fft_time else 1

def dynamic_convert_opt_conv2d(model, org_in_channels, org_in_width, org_in_height):
    """Replace ``Conv2d`` layers in ``model`` with ``FFTConv2d`` where that is faster.

    The children of ``model`` are visited in registration order and are
    assumed to be applied one after another, so that the output size of one
    child is the input size of the next. Containers are processed
    recursively. For every leaf ``torch.nn.Conv2d``, ``dynamic_fit_conv``
    benchmarks both operators for that layer's kernel size, input size and
    channel counts; if it returns ``1`` the layer is replaced in place by an
    ``FFTConv2d`` that carries over the original weight and bias.

    Args:
        model: Container module to convert. It is modified in place.
        org_in_channels: Channel count of the input fed to ``model``.
        org_in_width: Size of the third dimension of that input.
        org_in_height: Size of the fourth dimension of that input.

    Returns:
        tuple: The last three dimensions of the tensor produced after the
        final child of ``model``.

    Note:
        Only ``kernel_size[0]`` is passed to the benchmark, so non-square
        kernels are benchmarked as square ones.
    """
    cur_in_channels, cur_in_width, cur_in_height = org_in_channels, org_in_width, org_in_height

    # named_children + setattr works for any nn.Module, whereas integer
    # indexing (model[i]) is limited to containers such as nn.Sequential.
    for name, submodule in list(model.named_children()):
        if not is_lowest(submodule):
            cur_in_channels, cur_in_width, cur_in_height = dynamic_convert_opt_conv2d(
                submodule, cur_in_channels, cur_in_width, cur_in_height)
            continue

        if isinstance(submodule, torch.nn.Conv2d):
            # The benchmark must see the real kernel size and input size of this layer.
            best_candidate = dynamic_fit_conv(kernel_size=submodule.kernel_size[0],
                                              input_width=cur_in_width, input_height=cur_in_height,
                                              input_channel=submodule.in_channels,
                                              output_channel=submodule.out_channels)
            if best_candidate == 1:
                fft_conv = FFTConv2d(in_channels=submodule.in_channels,
                                     out_channels=submodule.out_channels,
                                     kernel_size=submodule.kernel_size,
                                     stride=submodule.stride,
                                     padding=submodule.padding,
                                     dilation=submodule.dilation,
                                     groups=submodule.groups,
                                     bias=submodule.bias is not None)
                # Carry the existing parameters over; otherwise the converted
                # layer would start from fresh random weights.
                # NOTE(review): assumes FFTConv2d stores `weight`/`bias` with the
                # same layout as Conv2d (dynamic_fit_conv assigns them the same
                # way) -- confirm for groups > 1.
                fft_conv.weight = torch.nn.Parameter(
                    submodule.weight.detach().clone(),
                    requires_grad=submodule.weight.requires_grad)
                if submodule.bias is not None:
                    fft_conv.bias = torch.nn.Parameter(
                        submodule.bias.detach().clone(),
                        requires_grad=submodule.bias.requires_grad)
                fft_conv.train(submodule.training)
                setattr(model, name, fft_conv)

        # Shape probe: push a dummy tensor through the original leaf to learn
        # the input size of the next layer. Run it in eval mode and without
        # autograd so random data cannot alter BatchNorm running statistics
        # or build a graph; the training flag is restored afterwards.
        was_training = submodule.training
        submodule.eval()
        try:
            with torch.no_grad():
                probe = torch.randn(1, cur_in_channels, cur_in_width, cur_in_height)
                cur_in_channels, cur_in_width, cur_in_height = submodule(probe).size()[-3:]
        finally:
            submodule.train(was_training)
    return cur_in_channels, cur_in_width, cur_in_height

In [4]:
# Custom model 1: four stages, each made of two stacked convolutions
# (the second one with a large 30-60 pixel kernel), BatchNorm and ReLU.
module1 = torch.nn.Sequential(
    torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1),
    torch.nn.Conv2d(in_channels=3, out_channels=2, kernel_size=30),
    torch.nn.BatchNorm2d(num_features=2),
    torch.nn.ReLU(),
)
module2 = torch.nn.Sequential(
    torch.nn.Conv2d(in_channels=2, out_channels=5, kernel_size=10),
    torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=30),
    torch.nn.BatchNorm2d(num_features=1),
    torch.nn.ReLU(),
)
module3 = torch.nn.Sequential(
    torch.nn.Conv2d(in_channels=1, out_channels=20, kernel_size=3),
    torch.nn.Conv2d(in_channels=20, out_channels=3, kernel_size=40),
    torch.nn.BatchNorm2d(num_features=3),
    torch.nn.ReLU(),
)
module4 = torch.nn.Sequential(
    torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3),
    torch.nn.Conv2d(in_channels=3, out_channels=2, kernel_size=60),
    torch.nn.BatchNorm2d(num_features=2),
    torch.nn.ReLU(),
)

# Chain the four stages; the trailing bare name displays the model repr.
myNet = torch.nn.Sequential(module1, module2, module3, module4)
myNet

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(3, 2, kernel_size=(30, 30), stride=(1, 1))
    (2): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
  )
  (1): Sequential(
    (0): Conv2d(2, 5, kernel_size=(10, 10), stride=(1, 1))
    (1): Conv2d(5, 1, kernel_size=(30, 30), stride=(1, 1))
    (2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
  )
  (2): Sequential(
    (0): Conv2d(1, 20, kernel_size=(3, 3), stride=(1, 1))
    (1): Conv2d(20, 3, kernel_size=(40, 40), stride=(1, 1))
    (2): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
  )
  (3): Sequential(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): Conv2d(3, 2, kernel_size=(60, 60), stride=(1, 1))
    (2): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()


In [5]:
# Benchmark myNet on a random batch before and after the conv substitution.
image = torch.randn(10, 3, 512, 512)
time0 = time.time()
iters = 3
for i in range(iters):
    out = myNet(image)
    if i == 0:
        # The first pass is treated as warm-up: restart the clock once it finishes.
        time0 = time.time()
time1 = time.time()
# Average over the iters - 1 passes that ran after the warm-up, in milliseconds.
org_inf_time = (time1 - time0) / (iters - 1) * 1000
print('original inference time: {}ms'.format(org_inf_time))
time1 = time.time()
# Last three dimensions of the input batch, used as the starting size for the conversion.
org_in_channels, org_in_width, org_in_height = image.size()[-3:]
# Modifies myNet in place; the returned final feature-map size is not used here.
dynamic_convert_opt_conv2d(myNet, org_in_channels, org_in_width, org_in_height)
time2 = time.time()
# Wall-clock cost of the conversion (including its internal benchmarks), in seconds.
convert_time = time2 - time1
print('Convert to opt used time: {}s'.format(convert_time), myNet)

# Same measurement as above, now on the converted model.
time2 = time.time()
for i in range(iters):
    out = myNet(image)
    if i == 0:
        # Warm-up pass again excluded from the timing.
        time2 = time.time()
time3 = time.time()
opt_inf_time = (time3 - time2) / (iters - 1) * 1000
print('optimized inference time: {}ms'.format(opt_inf_time))

original inference time: 9635.271310806274ms
Convert to opt used time: 44.394498348236084s Sequential(
  (0): Sequential(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): _FFTConv()
    (2): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
  )
  (1): Sequential(
    (0): Conv2d(2, 5, kernel_size=(10, 10), stride=(1, 1))
    (1): _FFTConv()
    (2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
  )
  (2): Sequential(
    (0): Conv2d(1, 20, kernel_size=(3, 3), stride=(1, 1))
    (1): _FFTConv()
    (2): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
  )
  (3): Sequential(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): _FFTConv()
    (2): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
  )
)
optimized inference time: 2257.57098197937ms


In [17]:
# Custom model 2: three stages whose large-kernel convolutions use stride 2
# (and padding 2 in the first and last stage); the last stage ends with an
# adaptive average pool instead of a ReLU.
module1 = torch.nn.Sequential(
    torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1),
    torch.nn.Conv2d(in_channels=3, out_channels=2, kernel_size=30, stride=2, padding=2),
    torch.nn.BatchNorm2d(num_features=2),
    torch.nn.ReLU(),
)
module2 = torch.nn.Sequential(
    torch.nn.Conv2d(in_channels=2, out_channels=20, kernel_size=3),
    torch.nn.Conv2d(in_channels=20, out_channels=30, kernel_size=30, stride=2),
    torch.nn.BatchNorm2d(num_features=30),
    torch.nn.ReLU(),
)
module3 = torch.nn.Sequential(
    torch.nn.Conv2d(in_channels=30, out_channels=3, kernel_size=3),
    torch.nn.Conv2d(in_channels=3, out_channels=2, kernel_size=60, stride=2, padding=2),
    torch.nn.BatchNorm2d(num_features=2),
    torch.nn.AdaptiveAvgPool2d(output_size=10),
)

# Chain the three stages; the trailing bare name displays the model repr.
myNet2 = torch.nn.Sequential(module1, module2, module3)
myNet2

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(3, 2, kernel_size=(30, 30), stride=(2, 2), padding=(2, 2))
    (2): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
  )
  (1): Sequential(
    (0): Conv2d(2, 20, kernel_size=(3, 3), stride=(1, 1))
    (1): Conv2d(20, 30, kernel_size=(30, 30), stride=(2, 2))
    (2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
  )
  (2): Sequential(
    (0): Conv2d(30, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): Conv2d(3, 2, kernel_size=(60, 60), stride=(2, 2), padding=(2, 2))
    (2): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): AdaptiveAvgPool2d(output_size=10)
  )
)

In [18]:
# Benchmark myNet2 on a random batch before and after the conv substitution.
image = torch.randn(10, 3, 512, 512)
time0 = time.time()
iters = 3
for i in range(iters):
    out = myNet2(image)
    if i == 0:
        # The first pass is treated as warm-up: restart the clock once it finishes.
        time0 = time.time()
time1 = time.time()
# Average over the iters - 1 passes that ran after the warm-up, in milliseconds.
org_inf_time = (time1 - time0) / (iters - 1) * 1000
print('original inference time: {}ms'.format(org_inf_time))
time1 = time.time()
# Last three dimensions of the input batch, used as the starting size for the conversion.
org_in_channels, org_in_width, org_in_height = image.size()[-3:]
# Modifies myNet2 in place; the returned final feature-map size is not used here.
dynamic_convert_opt_conv2d(myNet2, org_in_channels, org_in_width, org_in_height)
time2 = time.time()
# Wall-clock cost of the conversion (including its internal benchmarks), in seconds.
convert_time = time2 - time1
print('Convert to opt used time: {}s'.format(convert_time), myNet2)

# Same measurement as above, now on the converted model.
time2 = time.time()
for i in range(iters):
    out = myNet2(image)
    if i == 0:
        # Warm-up pass again excluded from the timing.
        time2 = time.time()
time3 = time.time()
opt_inf_time = (time3 - time2) / (iters - 1) * 1000
print('optimized inference time: {}ms'.format(opt_inf_time))

original inference time: 3422.3471879959106ms
Convert to opt used time: 13.753021955490112s Sequential(
  (0): Sequential(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): _FFTConv()
    (2): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
  )
  (1): Sequential(
    (0): Conv2d(2, 20, kernel_size=(3, 3), stride=(1, 1))
    (1): Conv2d(20, 30, kernel_size=(30, 30), stride=(2, 2))
    (2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
  )
  (2): Sequential(
    (0): Conv2d(30, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): _FFTConv()
    (2): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): AdaptiveAvgPool2d(output_size=10)
  )
)
optimized inference time: 664.1570329666138ms
