In [1]:
%load_ext autoreload
%autoreload 2

# Necessary imports
import os
import sys

import torch
import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 200)

sys.path.insert(0, "../../..")
from batchflow import Notifier
from batchflow.models.torch import *

# Setup

In [2]:
DEVICE = torch.device('cuda:0')

# Pyramid of benchmark input shapes in (batch, channels, height, width) order.
# Each of the N_LEVELS levels doubles the channel count and halves both spatial
# sizes of BASE_SHAPE, so every level holds half as many elements as the previous one.
BASE_SHAPE = (16, 32, 256, 256)
N_LEVELS = 5
SHAPES = [(BASE_SHAPE[0], BASE_SHAPE[1] * (2 ** level),
           BASE_SHAPE[2] // (2 ** level), BASE_SHAPE[3] // (2 ** level))
          for level in range(N_LEVELS)]
SHAPES

[(16, 32, 256, 256),
 (16, 64, 128, 128),
 (16, 128, 64, 64),
 (16, 256, 32, 32),
 (16, 512, 16, 16)]

# Conv

In [3]:
%%time
stats_list = []
for shape in SHAPES:
    inputs = torch.rand(*shape, device=DEVICE)

    modules = {
        'conv_1x1': Block(inputs=inputs, layout='c', channels='same', kernel_size=1),
        'conv_1x1_bias': Block(inputs=inputs, layout='c', channels='same', kernel_size=1, bias=True),

        'conv_3x3': Block(inputs=inputs, layout='c', channels='same', kernel_size=3),
        'conv_3x3_bias': Block(inputs=inputs, layout='c', channels='same', kernel_size=3, bias=True),
        'conv_3x3_depthwise': Block(inputs=inputs, layout='w', channels='same', kernel_size=3),
        'conv_3x3_depthwise_bias': Block(inputs=inputs, layout='w', channels='same', kernel_size=3, bias=True),
        'conv_3x3_dilation3': Block(inputs=inputs, layout='c', channels='same', kernel_size=3, dilation=3),

        'K_1357': Block(inputs=inputs, layout='K', channels='same', kernel_size=(1, 3, 5, 7), dilation=1),
        'K_3x3_dilation1357': Block(inputs=inputs, layout='K', channels='same', kernel_size=(3, 3, 3, 3), dilation=(1, 3, 5, 7)),
    }

    for name, module in modules.items():
        for amp in [False, True]:
            stats = get_module_performance(module, inputs=inputs, n_repeats=100, warmup=10, amp=amp)
            stats_list.append({'module_name': name, 'amp': amp, 'shape': shape, **stats})


dataframe = pd.DataFrame(stats_list).set_index(['shape', 'module_name', 'amp'])
dataframe

CPU times: user 33.4 s, sys: 17.7 s, total: 51 s
Wall time: 51.7 s


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,parameters,"forward time mean, ms","forward time std, ms","forward memory, MB","backward time mean, ms","backward time std, ms","backward memory, MB","time total, ms"
shape,module_name,amp,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
"(16, 32, 256, 256)",conv_1x1,False,1024,0.650506,0.009159,128.380371,2.106046,0.043341,0.003906,435.206543
"(16, 32, 256, 256)",conv_1x1,True,1024,0.717346,0.002652,128.380371,1.223467,0.038532,0.003906,232.574661
"(16, 32, 256, 256)",conv_1x1_bias,False,1056,1.198123,0.004649,128.380371,2.275003,0.060293,0.003906,395.928223
"(16, 32, 256, 256)",conv_1x1_bias,True,1056,0.996343,0.003214,128.380371,1.369967,0.027851,0.003906,276.730865
"(16, 32, 256, 256)",conv_3x3,False,9216,1.699051,0.016151,128.099121,4.602156,0.059066,0.087891,710.122803
"(16, 32, 256, 256)",conv_3x3,True,9216,1.415402,0.002752,128.099121,1.960329,0.022019,0.087891,392.005188
"(16, 32, 256, 256)",conv_3x3_bias,False,9248,2.25366,0.015858,128.099121,4.742179,0.055564,0.087891,785.6073
"(16, 32, 256, 256)",conv_3x3_bias,True,9248,1.694615,0.007743,128.099121,2.056038,0.028126,0.087891,432.138245
"(16, 32, 256, 256)",conv_3x3_depthwise,False,288,0.878508,0.004724,128.0,2.080521,0.051381,0.001465,338.550781
"(16, 32, 256, 256)",conv_3x3_depthwise,True,288,0.686915,0.004162,128.0,0.767289,0.016721,0.001465,175.661057


In [4]:
# Average every metric over all shapes and conv variants, split by the AMP flag:
# AMP cuts forward/backward time, while the memory columns stay essentially the same.
# NOTE: later benchmark cells rebind `dataframe`; re-run the Conv cell first if running out of order.
dataframe.groupby(level='amp').mean()

Unnamed: 0_level_0,parameters,"forward time mean, ms","forward time std, ms","forward memory, MB","backward time mean, ms","backward time std, ms","backward memory, MB","time total, ms"
amp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
False,458281.955556,1.914209,0.015475,86.706489,4.272227,0.053522,87.051074,700.355594
True,458281.955556,1.502709,0.038628,86.702322,1.63562,0.080522,87.071908,368.096958


# Normalization

In [5]:
%%time
stats_list = []
for shape in SHAPES:
    inputs = torch.rand(*shape, device=DEVICE)

    modules = {
        'batchnorm': Block(inputs=inputs, layout='n', normalization_type='batch'),
        'instancenorm': Block(inputs=inputs, layout='n', normalization_type='instance'),
        'instancenorm_affine': Block(inputs=inputs, layout='n', normalization_type='instance',
                                     normalization={'affine': True}),
        'layernorm': Block(inputs=inputs, layout='n', normalization_type='layer'),
    }

    for name, module in modules.items():
        module = module.to(inputs.device)
        for amp in [False, True]:
            stats = get_module_performance(module, inputs=inputs, n_repeats=100, warmup=10, amp=amp)
            stats_list.append({'module_name': name, 'amp': amp, 'shape': shape, **stats})


dataframe = pd.DataFrame(stats_list).set_index(['shape', 'module_name', 'amp'])
dataframe

CPU times: user 4.07 s, sys: 1.54 s, total: 5.61 s
Wall time: 5.72 s


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,parameters,"forward time mean, ms","forward time std, ms","forward memory, MB","backward time mean, ms","backward time std, ms","backward memory, MB","time total, ms"
shape,module_name,amp,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
"(16, 32, 256, 256)",batchnorm,False,64,1.062239,0.006829,128.000977,1.930649,0.012426,128.000977,342.196167
"(16, 32, 256, 256)",batchnorm,True,64,1.064653,0.007264,128.000977,1.93922,0.010426,128.000977,343.314697
"(16, 32, 256, 256)",instancenorm,False,0,0.897957,0.007719,128.003906,,,,105.406364
"(16, 32, 256, 256)",instancenorm,True,0,0.908005,0.005937,128.003906,,,,106.622368
"(16, 32, 256, 256)",instancenorm_affine,False,64,0.940601,0.01156,128.007812,1.531597,0.019149,128.003906,284.662262
"(16, 32, 256, 256)",instancenorm_affine,True,64,0.936352,0.007931,128.007812,1.535387,0.011838,128.003906,284.868835
"(16, 32, 256, 256)",layernorm,False,64,3.113319,0.002467,520.0,1.422284,0.010865,128.001953,513.286804
"(16, 32, 256, 256)",layernorm,True,64,3.112373,0.002531,520.0,1.423372,0.01218,128.001953,514.541931
"(16, 64, 128, 128)",batchnorm,False,128,0.536193,0.004219,64.000977,0.89655,0.011819,64.000977,168.390564
"(16, 64, 128, 128)",batchnorm,True,128,0.539322,0.015783,64.000977,0.912432,0.125144,64.000977,170.194946


# Named blocks

In [6]:
def simplegate(x):
    """SimpleGate activation: multiply the even-indexed channels by the odd-indexed ones.

    Works on any array/tensor supporting ``x[:, start::2]`` slicing (torch tensors,
    numpy arrays) and halves the size of dim 1.

    Args:
        x: input whose dim 1 has an even size.

    Returns:
        ``x[:, 0::2] * x[:, 1::2]``, i.e. the input with dim 1 halved.

    Raises:
        ValueError: if dim 1 has an odd size. The two strided halves would then
            differ in length, which either fails or (for sizes 1 and 3) silently
            broadcasts to a wrong shape.
    """
    n_channels = x.shape[1]
    if n_channels % 2:
        raise ValueError(f'simplegate expects an even number of channels, got {n_channels}')
    return x[:, 0::2] * x[:, 1::2]

In [7]:
%%time
stats_list = []
for shape in SHAPES:
    inputs = torch.rand(*shape, device=DEVICE)

    modules = {
        'resblock': ResBlock(inputs=inputs, attention='se'),
        'resblock_nrep2': ResBlock(inputs=inputs, n_reps=2, attention='se'),
        'resblock_bottleneck4': ResBlock(inputs=inputs, bottleneck=4, attention='se'),
        'resblock_groups8': ResBlock(inputs=inputs, groups=8, attention='se'),
        'resblock_bottleneck4groups8': ResBlock(inputs=inputs, bottleneck=4, groups=8, attention='se'),

        'convnext_layernorm': ConvNeXtBlock(inputs=inputs, normalization_type='layer'),
        'convnext_batchnorm': ConvNeXtBlock(inputs=inputs, normalization_type='batch'),
        
        'afblock': Block(inputs=inputs, layout='RncwaSc! Rncac!',
                         channels=['2 * same', 'same', 'same', '2 * same', 'same'],
                         kernel_size=[1, 3, 1, 1, 1], bias=True,
                         activation=simplegate, attention='se', self_attention={'bias': True},
                         branch_end={'drop_path': 0.0, 'layer_scale': 0.001})
    }

    for name, module in modules.items():
        module = module.to(inputs.device)
        for amp in [False, True]:
            stats = get_module_performance(module, inputs=inputs, n_repeats=100, warmup=10, amp=amp)
            stats_list.append({'module_name': name, 'amp': amp, 'shape': shape, **stats})


dataframe = pd.DataFrame(stats_list).set_index(['shape', 'module_name', 'amp'])
dataframe

CPU times: user 2min 2s, sys: 1min 1s, total: 3min 3s
Wall time: 3min 4s


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,parameters,"forward time mean, ms","forward time std, ms","forward memory, MB","backward time mean, ms","backward time std, ms","backward memory, MB","time total, ms"
shape,module_name,amp,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
"(16, 32, 256, 256)",resblock,False,19072,8.372287,0.04932,896.006348,18.659613,0.096503,384.001953,3015.032959
"(16, 32, 256, 256)",resblock,True,19072,6.090035,0.010439,896.006348,9.828188,0.026411,384.001953,1800.89502
"(16, 32, 256, 256)",resblock_nrep2,False,38144,16.563763,0.023789,1536.012695,41.026332,0.143833,384.001953,6405.561523
"(16, 32, 256, 256)",resblock_nrep2,True,38144,12.104493,0.01666,1536.012695,21.855963,0.033683,384.001953,3809.650146
"(16, 32, 256, 256)",resblock_bottleneck4,False,52160,20.923508,0.032501,2176.008301,52.206494,0.915898,1312.695801,8133.459961
"(16, 32, 256, 256)",resblock_bottleneck4,True,52160,14.172451,0.028933,2176.008301,27.874911,0.33938,1312.695801,4723.103516
"(16, 32, 256, 256)",resblock_groups8,False,2944,8.613645,0.023105,896.006348,21.514019,0.093533,384.001953,3358.734375
"(16, 32, 256, 256)",resblock_groups8,True,2944,7.631383,0.023514,896.006348,26.374286,0.09965,384.001953,3783.669189
"(16, 32, 256, 256)",resblock_bottleneck4groups8,False,11840,22.34555,0.062642,2176.008301,57.624238,0.154119,384.001953,8891.181641
"(16, 32, 256, 256)",resblock_bottleneck4groups8,True,11840,22.278,0.068597,2176.008301,56.376773,0.521825,384.001953,8751.759766
