In [1]:
%load_ext autoreload
%env CUDA_VISIBLE_DEVICES = 2

env: CUDA_VISIBLE_DEVICES=2


In [9]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn

from ptflops import get_model_complexity_info

%autoreload 2
pd.set_option("display.precision", 2)

In [10]:
device = torch.device('cuda')
device

device(type='cuda')

In [88]:
def module_time_decorator(module):
    """track module backward/forward pass time"""
    def wrapper(**kwargs):
        if kwargs['mode'] == 'forward':
            kwargs['start'].record()
            outputs = module(kwargs['inputs'])
            kwargs['end'].record()
            torch.cuda.synchronize()
            
        elif kwargs['mode'] == 'backward':
            outputs = module(kwargs['inputs'])
            kwargs['start'].record()
            outputs.backward(outputs)
            kwargs['end'].record()
            torch.cuda.synchronize()
            
        else:
            raise KeyError
            
        return kwargs['start'].elapsed_time(kwargs['end']) 
    return wrapper

In [93]:
def module_memory_decorator(module):
    """track forward/backward module memory consumption"""
    def wrapper(**kwargs):
        if kwargs['mode'] == 'forward':
            start_memory = get_memory(device=kwargs['device'])
            out = module(kwargs['inputs'])
            end_memory = get_memory(reset_memory=False, device=kwargs['device']) 
            
        elif kwargs['mode'] == 'backward':
            outputs = module(kwargs['inputs'])
            start_memory = get_memory(device=kwargs['device'])
            outputs.backward(outputs) 
            end_memory = get_memory(reset_memory=False, device=kwargs['device']) 
        else:
            raise KeyError
            
        return end_memory - start_memory
    return wrapper

In [103]:
def get_memory(reset_memory=True, device=None):
    """Take current max allocated memory, either with or without resetting"""
    if reset_memory:
        torch.cuda.reset_peak_memory_stats()
            
    max_memory = torch.cuda.max_memory_allocated(device)
        
    return max_memory
    
    
def make_initialization_inputs(inputs, device=None):
    """ Take either tensor, shape tuple or list of them, and always return tensor or list of them. """
    if isinstance(inputs, torch.Tensor):
        pass
    elif isinstance(inputs, tuple):
        inputs = torch.rand(*inputs, device=device)
    elif isinstance(inputs, list):
        inputs = [make_initialization_inputs(item, device=device) for item in inputs]
    return inputs

    
    
def tracker(module, inputs, repeats=300, warmup=40, device=None, track_backward=True, channels_last=False, amp=False) -> dict:
    """Track module #macs, #parameters, time and memory consumption on forward and backward pass for a given inputs tensor or inputs shape"""
    
    total_start = torch.cuda.Event(enable_timing=True)
    total_end = torch.cuda.Event(enable_timing=True)
    total_start.record()
    
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    
    result = {}
    
    torch.cuda.empty_cache()
    
    inputs = make_initialization_inputs(inputs=inputs, device=device)
    module.to(device)
    time_module = module_time_decorator(module)
    memory_module = module_memory_decorator(module)
    
    if channels_last:
        inputs.to(memory_format=torch.channels_last)
        module.to(memory_format=torch.channels_last)          

    forward_timings = []
    backward_timings = []
    
    for i in range(repeats + warmup):
        
        with torch.cuda.amp.autocast(enabled=amp):
            if i < warmup:
                outputs = module(inputs)
                del outputs
                i += 1
                continue

            forward_time = time_module(mode='forward', inputs=inputs, start=start, end=end)
            forward_timings.append(forward_time) 
        
        if track_backward:
            backward_time = time_module(mode='backward', inputs=inputs, start=start, end=end)
            backward_timings.append(backward_time)
   
    result['forward time mean(ms)'] = np.mean(forward_timings)
    result['forward time std(ms)'] = np.std(forward_timings)
    
    forward_memory = memory_module(mode='forward', inputs=inputs, device=device)
    result['forward memory(MB)'] = forward_memory / 2**20
   
    if track_backward:
        result['backward time mean(ms)'] = np.mean(backward_timings)
        result['backward time std(ms)'] = np.std(backward_timings)
        
        backward_memory = memory_module(mode='backward', inputs=inputs, device=device)
        result['backward memory(MB)'] = backward_memory / 2**20
          
    macs, params = get_model_complexity_info(module, tuple(inputs.shape[1:]), as_strings=False, print_per_layer_stat=False)
    result['macs'] = macs
    result['parameters'] = float(params)
    total_end.record()
    torch.cuda.synchronize()
    result['time total(ms)'] = total_start.elapsed_time(total_end)
    
    return result

In [101]:
shape = (1, 64, 128, 128)

module_collection = {'conv_64_512_1x1': nn.Conv2d(kernel_size=1, in_channels=64, out_channels=512), 
                     'conv_64_512_3x3': nn.Conv2d(kernel_size=3, in_channels=64, out_channels=512), 
                     'bottleneck_64_512_3x3': nn.Sequential(*[nn.Conv2d(kernel_size=1, in_channels=64, out_channels=32), 
                                                             nn.Conv2d(kernel_size=3, in_channels=32, out_channels=512),
                                                              ]),
                      'conv_64_512_3x3_g2': nn.Conv2d(kernel_size=7, in_channels=64, out_channels=512, groups=2),
                      'conv_64_512_3x3_g8': nn.Conv2d(kernel_size=7, in_channels=64, out_channels=512, groups=8),
                      'conv_padding': nn.Conv2d(kernel_size=3, in_channels=64, out_channels=512, padding=1),
                      'conv_nn_Padding': nn.Sequential(nn.ZeroPad2d(1),
                                                       nn.Conv2d(kernel_size=3, in_channels=64, out_channels=512))
                    }

In [106]:
# contiguous
module_collection_stats = pd.DataFrame(index=module_collection.keys(), 
                                               columns=['forward time mean(ms)', 'forward time std(ms)',
                                                        'backward time mean(ms)', 'backward time std(ms)',
                                                        'forward memory(MB)','backward memory(MB)',
                                                        'macs', 'parameters', 'time total(ms)'])
        
for module_name, module_value in module_collection.items():
    module_collection_stats.loc[module_name] = tracker(module_value, inputs=shape, device=device, track_backward=True)

module_collection_stats

Unnamed: 0,forward time mean(ms),forward time std(ms),backward time mean(ms),backward time std(ms),forward memory(MB),backward memory(MB),macs,parameters,time total(ms)
conv_64_512_1x1,0.38,0.01,0.3,0.05,32.22,0.12,545259520.0,33280.0,1156.15
conv_64_512_3x3,0.81,0.01,1.04,0.01,35.13,1.12,4690151424.0,295424.0,2149.93
bottleneck_64_512_3x3,0.79,0.04,1.28,0.19,35.57,2.56,2383218688.0,150048.0,906.16
conv_64_512_3x3_g2,2.01,0.02,2.48,0.03,33.23,3.06,11956733952.0,803328.0,2982.55
conv_64_512_3x3_g8,0.69,0.02,1.25,0.02,31.46,0.77,2994898944.0,201216.0,1253.92
conv_padding,0.81,0.1,3.28,43.92,35.13,87.12,4840226816.0,295424.0,2819.29
conv_nn_Padding,0.82,0.02,1.07,0.11,39.25,1.12,4840226816.0,295424.0,1683.52


In [107]:
# channels last
module_collection_stats = pd.DataFrame(index=module_collection.keys(), 
                                               columns=['forward time mean(ms)', 'forward time std(ms)',
                                                        'backward time mean(ms)', 'backward time std(ms)',
                                                        'forward memory(MB)','backward memory(MB)',
                                                        'macs', 'parameters','time total(ms)'])
        
for module_name, module_value in module_collection.items():
    module_collection_stats.loc[module_name] = tracker(module_value, inputs=shape, device=device, track_backward=True, channels_last=True)

module_collection_stats

Unnamed: 0,forward time mean(ms),forward time std(ms),backward time mean(ms),backward time std(ms),forward memory(MB),backward memory(MB),macs,parameters,time total(ms)
conv_64_512_1x1,0.45,0.04,0.35,0.2,36.22,34.0,545259520.0,33280.0,415.22
conv_64_512_3x3,1.0,0.02,1.27,2.72,37.22,34.0,4690151424.0,295424.0,1125.64
bottleneck_64_512_3x3,0.84,0.05,4.72,44.39,38.66,34.0,2383218688.0,150048.0,3068.36
conv_64_512_3x3_g2,1.9,0.03,13.18,0.05,37.23,39.2,11956733952.0,803328.0,5726.83
conv_64_512_3x3_g8,1.06,0.02,11.95,3.78,35.46,34.77,2994898944.0,201216.0,4536.95
conv_padding,1.0,0.02,1.77,0.06,37.47,34.0,4840226816.0,295424.0,1703.9
conv_nn_Padding,1.07,0.07,3.74,44.69,42.38,34.0,4840226816.0,295424.0,3297.78


In [109]:
# amp
module_collection_stats = pd.DataFrame(index=module_collection.keys(), 
                                               columns=['forward time mean(ms)', 'forward time std(ms)',
                                                        'backward time mean(ms)', 'backward time std(ms)',
                                                        'forward memory(MB)','backward memory(MB)',
                                                        'macs', 'parameters','time total(ms)'])
        
for module_name, module_value in module_collection.items():
    module_collection_stats.loc[module_name] = tracker(module_value, inputs=shape, device=device, track_backward=True, amp=True)

module_collection_stats

Unnamed: 0,forward time mean(ms),forward time std(ms),backward time mean(ms),backward time std(ms),forward memory(MB),backward memory(MB),macs,parameters,time total(ms)
conv_64_512_1x1,0.37,0.01,0.27,0.06,36.22,34.0,545259520.0,33280.0,367.1
conv_64_512_3x3,0.49,0.01,1.2,1.5,37.88,34.0,4690151424.0,295424.0,1012.36
bottleneck_64_512_3x3,0.66,0.06,5.01,48.93,38.66,34.0,2383218688.0,150048.0,4060.72
conv_64_512_3x3_g2,10.58,0.03,15.95,48.3,37.23,39.26,11956733952.0,803328.0,11890.7
conv_64_512_3x3_g8,2.5,0.03,11.73,0.04,35.88,34.77,2994898944.0,201216.0,6158.7
conv_padding,0.48,0.02,1.74,0.03,37.88,34.0,4840226816.0,295424.0,1332.97
conv_nn_Padding,0.53,0.03,3.92,47.68,42.38,34.0,4840226816.0,295424.0,3333.8


In [110]:
# amp + channels last
module_collection_stats = pd.DataFrame(index=module_collection.keys(), 
                                               columns=['forward time mean(ms)', 'forward time std(ms)',
                                                        'backward time mean(ms)', 'backward time std(ms)',
                                                        'forward memory(MB)','backward memory(MB)',
                                                        'macs', 'parameters','time total(ms)'])
        
for module_name, module_value in module_collection.items():
    module_collection_stats.loc[module_name] = tracker(module_value, inputs=shape, device=device, track_backward=True, channels_last=True, amp=True)

module_collection_stats

Unnamed: 0,forward time mean(ms),forward time std(ms),backward time mean(ms),backward time std(ms),forward memory(MB),backward memory(MB),macs,parameters,time total(ms)
conv_64_512_1x1,0.42,0.06,3.3,50.79,36.22,34.0,545259520.0,33280.0,2849.2
conv_64_512_3x3,0.49,0.04,5.26,72.27,37.88,34.0,4690151424.0,295424.0,2241.06
bottleneck_64_512_3x3,0.66,0.13,2.07,0.23,38.66,34.0,2383218688.0,150048.0,2849.95
conv_64_512_3x3_g2,10.57,0.03,15.96,48.94,37.23,39.26,11956733952.0,803328.0,11806.86
conv_64_512_3x3_g8,2.52,0.02,11.73,0.03,35.88,34.77,2994898944.0,201216.0,5684.43
conv_padding,0.49,0.02,1.73,0.05,37.88,34.0,4840226816.0,295424.0,1455.89
conv_nn_Padding,0.54,0.05,3.84,45.69,42.38,34.0,4840226816.0,295424.0,3366.31
