In [1]:
import time

import numpy as np
import torch
import timm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
args = {}
args['mode']       = 'default'
args['model']      = 'resnet50'
args['steps']      = 10
args['backend']    = 'inductor'
args['batch_size'] = 128

ordinal = {1 : 'st', 2 : 'nd', 3 : 'rd'}

In [3]:
def run_model(model, inputs, steps = 20):
    
    model     = model.cuda()
    optimizer = torch.optim.Adam(model.parameters())
    
    times = []
    for step in range(1, steps + 1):
        
        start = time.time()
        optimizer.zero_grad()
        
        output = model(inputs.cuda())
        if not isinstance(output, torch.Tensor): output = output.logits
        
        output.sum().backward()
        
        optimizer.step()
        end = time.time()
        
        times.append(end - start)
        
        ordinal_num = f'{step}{ordinal[step]}' if step in ordinal.keys() else f'{step}th'
        print(f'time for {ordinal_num} forward pass is {end - start:.3f}')
        
    median = np.median(times)
    print(f'median step time is {median:.3f}')

In [4]:
model = torch.hub.load('pytorch/vision:v0.10.0', args['model'], pretrained = True)

if args['mode'] in ['default', 'reduce-overhead']:
    model = torch.compile(model, mode = args['mode'], backend = args['backend'])

Using cache found in /home/jovyan/.cache/torch/hub/pytorch_vision_v0.10.0


In [5]:
inputs = torch.randn(args['batch_size'], 3, 224, 224)
run_model(model, inputs, args['steps'])



time for 1st forward pass is 18.521
time for 2nd forward pass is 0.059
time for 3rd forward pass is 0.147
time for 4th forward pass is 0.147
time for 5th forward pass is 0.147
time for 6th forward pass is 0.147
time for 7th forward pass is 0.147
time for 8th forward pass is 0.147
time for 9th forward pass is 0.147
time for 10th forward pass is 0.147
median step time is 0.147


Process ForkProcess-31:
Process ForkProcess-4:
Process ForkProcess-12:
Process ForkProcess-25:
Process ForkProcess-18:
Process ForkProcess-22:
Process ForkProcess-2:
Process ForkProcess-8:
Process ForkProcess-26:
Process ForkProcess-19:
Process ForkProcess-14:
Process ForkProcess-23:
Process ForkProcess-1:
Process ForkProcess-28:
Process ForkProcess-21:
Process ForkProcess-20:
Process ForkProcess-24:
Process ForkProcess-29:
Process ForkProcess-6:
Process ForkProcess-27:
Process ForkProcess-13:
Process ForkProcess-32:
Process ForkProcess-7:
Process ForkProcess-30:
Process ForkProcess-3:
Process ForkProcess-17:
Traceback (most recent call last):
Process ForkProcess-5:
Process ForkProcess-16:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Process ForkProcess-15:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last

In [6]:
model = timm.create_model(args['model'], pretrained = True)
if args['mode'] in ['default', 'reduce-overhead']:
    model = torch.compile(model, mode = args['mode'], backend = args['backend'])
    
inputs = torch.randn(args['batch_size'], 3, 224, 224)
run_model(model, inputs, args['steps'])

time for 1st forward pass is 13.091
time for 2nd forward pass is 0.143
time for 3rd forward pass is 0.147
time for 4th forward pass is 0.147
time for 5th forward pass is 0.147
time for 6th forward pass is 0.147
time for 7th forward pass is 0.147
time for 8th forward pass is 0.147
time for 9th forward pass is 0.147
time for 10th forward pass is 0.147
median step time is 0.147
