In [1]:
%load_ext autoreload
%autoreload 2

# Necessary imports
import sys

import torch
import pandas as pd
# Allow up to 200 rows in rendered DataFrames so the per-shape benchmark tables are shown in full
pd.set_option('display.max_rows', 200)

# Make the repository-local `batchflow` package importable; the relative path
# assumes this notebook sits three directories below the repo root — TODO confirm
sys.path.insert(0, "../../..")
# NOTE(review): this star import presumably supplies Block, ResBlock, BottleneckBlock,
# ConvNeXtBlock and get_module_performance used below; explicit names would be clearer
from batchflow.models.torch import *

# Setup

In [2]:
DEVICE = torch.device('cuda:0')

# Largest benchmarked input; assumed (batch, channels, height, width) layout
BASE_SHAPE = (16, 32, 256, 256)

# Five pyramid levels: every level doubles the channel count and halves both
# spatial sizes, so each level holds half as many elements as the previous one
SHAPES = [
    (BASE_SHAPE[0], BASE_SHAPE[1] * scale, BASE_SHAPE[2] // scale, BASE_SHAPE[3] // scale)
    for scale in (2 ** level for level in range(5))
]
SHAPES

[(16, 32, 256, 256),
 (16, 64, 128, 128),
 (16, 128, 64, 64),
 (16, 256, 32, 32),
 (16, 512, 16, 16)]

# Convolutions

In [3]:
%%time
stats_list = []
for shape in SHAPES:
    inputs = torch.rand(*shape, device=DEVICE)

    modules = {
        'conv_1x1': Block(inputs=inputs, layout='c', channels='same', kernel_size=1),
        'conv_1x1_bias': Block(inputs=inputs, layout='c', channels='same', kernel_size=1, bias=True),

        'conv_3x3': Block(inputs=inputs, layout='c', channels='same', kernel_size=3),
        'conv_3x3_bias': Block(inputs=inputs, layout='c', channels='same', kernel_size=3, bias=True),
        'conv_3x3_depthwise': Block(inputs=inputs, layout='w', channels='same', kernel_size=3),
        'conv_3x3_depthwise_bias': Block(inputs=inputs, layout='w', channels='same', kernel_size=3, bias=True),
        'conv_3x3_dilation3': Block(inputs=inputs, layout='c', channels='same', kernel_size=3, dilation=3),

        'K_1357': Block(inputs=inputs, layout='K', channels='same', kernel_size=(1, 3, 5, 7), dilation=1),
        'K_3x3_dilation1357': Block(inputs=inputs, layout='K', channels='same', kernel_size=(3, 3, 3, 3),
                                    dilation=(1, 3, 5, 7)),
    }

    for name, module in modules.items():
        for amp in [False, True]:
            stats = get_module_performance(module, inputs=inputs, n_repeats=100, warmup=10, amp=amp)
            stats_list.append({'module_name': name, 'amp': amp, 'shape': shape, **stats})


dataframe = pd.DataFrame(stats_list).set_index(['shape', 'module_name', 'amp'])
dataframe

CPU times: user 25.9 s, sys: 16.4 s, total: 42.2 s
Wall time: 43 s


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,parameters,"forward time mean, ms","forward time std, ms","forward memory, MB","backward time mean, ms","backward time std, ms","backward memory, MB","time total, ms"
shape,module_name,amp,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
"(16, 32, 256, 256)",conv_1x1,False,1024,0.625838,0.007044,128.379395,2.046299,0.028055,0.003906,372.032135
"(16, 32, 256, 256)",conv_1x1,True,1024,0.71915,0.003102,128.379395,1.056207,0.013646,0.003906,227.913605
"(16, 32, 256, 256)",conv_1x1_bias,False,1056,1.134604,0.005593,128.379395,2.289348,0.023324,0.005859,388.965179
"(16, 32, 256, 256)",conv_1x1_bias,True,1056,1.033707,0.005163,128.379395,1.201721,0.016406,0.005859,263.728577
"(16, 32, 256, 256)",conv_3x3,False,9216,1.68572,0.010448,128.098145,7.371203,0.018811,576.175781,1023.007507
"(16, 32, 256, 256)",conv_3x3,True,9216,1.420374,0.003433,128.098145,1.966549,0.014605,576.175781,401.979675
"(16, 32, 256, 256)",conv_3x3_bias,False,9248,2.198337,0.010281,128.098145,7.582058,0.044991,576.175781,1101.029297
"(16, 32, 256, 256)",conv_3x3_bias,True,9248,1.731039,0.00229,128.098145,2.09969,0.016729,576.175781,445.302063
"(16, 32, 256, 256)",conv_3x3_depthwise,False,288,0.914312,0.005124,128.0,1.977559,0.034835,0.001465,328.591064
"(16, 32, 256, 256)",conv_3x3_depthwise,True,288,0.691681,0.004526,128.0,0.718933,0.014751,0.001465,172.286179


In [4]:
# Averaged over all shapes and conv modules: amp roughly halves the total time,
# while the measured forward/backward memory stays essentially unchanged
dataframe.groupby(level='amp').mean()

Unnamed: 0_level_0,parameters,"forward time mean, ms","forward time std, ms","forward memory, MB","backward time mean, ms","backward time std, ms","backward memory, MB","time total, ms"
amp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
False,458281.955556,1.76436,0.012868,126.101997,3.525029,0.052498,109.933887,598.363288
True,458281.955556,1.158195,0.00944,126.092274,1.30579,0.047201,109.944998,296.948743


# Normalization

In [5]:
%%time
stats_list = []
for shape in SHAPES:
    inputs = torch.rand(*shape, device=DEVICE)

    modules = {
        'batchnorm': Block(inputs=inputs, layout='n', normalization_type='batch'),
        'instancenorm': Block(inputs=inputs, layout='n', normalization_type='instance'),
        'instancenorm_affine': Block(inputs=inputs, layout='n', normalization_type='instance',
                                     normalization={'affine': True}),
        'layernorm': Block(inputs=inputs, layout='n', normalization_type='layer'),
    }

    for name, module in modules.items():
        module = module.to(inputs.device)
        for amp in [False, True]:
            stats = get_module_performance(module, inputs=inputs, n_repeats=100, warmup=10, amp=amp)
            stats_list.append({'module_name': name, 'amp': amp, 'shape': shape, **stats})


dataframe = pd.DataFrame(stats_list).set_index(['shape', 'module_name', 'amp'])
dataframe

CPU times: user 4.02 s, sys: 1.7 s, total: 5.72 s
Wall time: 5.86 s


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,parameters,"forward time mean, ms","forward time std, ms","forward memory, MB","backward time mean, ms","backward time std, ms","backward memory, MB","time total, ms"
shape,module_name,amp,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
"(16, 32, 256, 256)",batchnorm,False,64,1.032504,0.034165,128.000977,1.976578,0.073454,128.000977,344.819611
"(16, 32, 256, 256)",batchnorm,True,64,1.007319,0.010618,128.000977,1.931208,0.010929,128.000977,335.219208
"(16, 32, 256, 256)",instancenorm,False,0,0.893964,0.004516,128.003906,,,,105.609756
"(16, 32, 256, 256)",instancenorm,True,0,0.891875,0.004406,128.003906,,,,104.526428
"(16, 32, 256, 256)",instancenorm_affine,False,64,0.915812,0.005453,128.007812,1.560172,0.015421,128.003906,286.190247
"(16, 32, 256, 256)",instancenorm_affine,True,64,0.916036,0.005839,128.007812,1.552491,0.019947,128.003906,283.566071
"(16, 32, 256, 256)",layernorm,False,64,3.103204,0.013519,520.0,1.472187,0.411092,128.001953,527.725464
"(16, 32, 256, 256)",layernorm,True,64,3.1007,0.005671,520.0,1.437151,0.01622,128.001953,518.169861
"(16, 64, 128, 128)",batchnorm,False,128,0.542812,0.017543,64.000977,0.905804,0.012454,64.000977,170.564224
"(16, 64, 128, 128)",batchnorm,True,128,0.540575,0.00765,64.000977,0.897659,0.012934,64.000977,169.040833


# Named blocks

In [6]:
def simplegate(x):
    """Gating activation: multiply even-indexed channels by odd-indexed channels.

    Channel ``2k`` of the input is multiplied element-wise by channel ``2k + 1``,
    so the output has half as many channels (dimension 1) as the input.

    Parameters
    ----------
    x : torch.Tensor
        Tensor whose dimension 1 holds the channels; its size must be even.

    Returns
    -------
    torch.Tensor
        Tensor with the same shape as ``x`` except dimension 1 is halved.

    Raises
    ------
    ValueError
        If the channel count is odd. The even/odd slices would then differ in size
        and broadcasting could silently yield a wrong result (e.g. 3 channels -> 2
        outputs ``[x0*x1, x2*x1]``, or 1 channel -> an empty tensor).
    """
    n_channels = x.shape[1]
    if n_channels % 2:
        raise ValueError(f"simplegate expects an even number of channels, got {n_channels}")
    return x[:, 0::2] * x[:, 1::2]

In [7]:
%%time
stats_list = []
for shape in SHAPES:
    inputs = torch.rand(*shape, device=DEVICE)

    modules = {
        'resblock': ResBlock(inputs=inputs, attention='se'),
        'resblock_nrep2': ResBlock(inputs=inputs, n_reps=2, attention='se'),
        'resblock_bottleneck4': BottleneckBlock(inputs=inputs, bottleneck=4, attention='se'),
        'resblock_groups8': ResBlock(inputs=inputs, groups=8, attention='se'),
        'resblock_bottleneck4groups8': BottleneckBlock(inputs=inputs, bottleneck=4, groups=8, attention='se'),

        'convnext_layernorm': ConvNeXtBlock(inputs=inputs, normalization_type='layer'),
        'convnext_batchnorm': ConvNeXtBlock(inputs=inputs, normalization_type='batch'),

        'afblock': Block(inputs=inputs, layout='RncwaSc! Rncac!',
                         channels=['2 * same', 'same', 'same', '2 * same', 'same'],
                         kernel_size=[1, 3, 1, 1, 1], bias=True,
                         activation=simplegate, attention='se', self_attention={'bias': True},
                         branch_end={'drop_path': 0.0, 'layer_scale': 0.001})
    }

    for name, module in modules.items():
        module = module.to(inputs.device)
        for amp in [False, True]:
            stats = get_module_performance(module, inputs=inputs, n_repeats=100, warmup=10, amp=amp)
            stats_list.append({'module_name': name, 'amp': amp, 'shape': shape, **stats})


dataframe = pd.DataFrame(stats_list).set_index(['shape', 'module_name', 'amp'])
dataframe

CPU times: user 1min 44s, sys: 57.3 s, total: 2min 41s
Wall time: 2min 43s


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,parameters,"forward time mean, ms","forward time std, ms","forward memory, MB","backward time mean, ms","backward time std, ms","backward memory, MB","time total, ms"
shape,module_name,amp,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
"(16, 32, 256, 256)",resblock,False,19072,8.155047,0.031291,896.006348,23.330101,0.036642,576.17041,3528.040039
"(16, 32, 256, 256)",resblock,True,19072,6.051027,0.01442,896.006348,10.003004,0.032485,576.17041,1846.230225
"(16, 32, 256, 256)",resblock_nrep2,False,38144,16.209383,0.042686,1536.012695,49.156775,0.068283,704.17041,7275.398926
"(16, 32, 256, 256)",resblock_nrep2,True,38144,12.037552,0.032422,1536.012695,22.280271,0.044255,704.17041,3855.489014
"(16, 32, 256, 256)",resblock_bottleneck4,False,42880,17.525905,0.064805,1920.007324,40.849086,0.197714,511.993652,6575.134766
"(16, 32, 256, 256)",resblock_bottleneck4,True,42880,12.114142,0.034823,1920.007324,21.619274,0.046482,511.993652,3832.240967
"(16, 32, 256, 256)",resblock_groups8,False,2944,8.780751,0.026551,896.006348,24.65129,0.08282,384.001953,3751.855469
"(16, 32, 256, 256)",resblock_groups8,True,2944,6.888771,0.015053,896.006348,10.395284,0.042175,384.001953,1957.067505
"(16, 32, 256, 256)",resblock_bottleneck4groups8,False,10624,19.070248,0.279037,1920.007324,44.446571,0.160912,511.993652,7108.094238
"(16, 32, 256, 256)",resblock_bottleneck4groups8,True,10624,14.121785,0.042042,1920.007324,27.148014,0.108335,511.993652,4689.23584
