In [1]:
import time
from collections import defaultdict

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
import torchvision.models as models

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

## Measuring latency without Torch Dispatcher

In [3]:
inp = torch.randn(1, 3, 224, 224, device=device)
mod = models.resnet50().to(device)
optimizer = torch.optim.Adam(mod.parameters(), lr=0.001)

total_start = time.time()

print("=================== Forward =====================")
torch.cuda.synchronize()
start_forward = time.time()
optimizer.zero_grad()
outputs = mod(inp)
torch.cuda.synchronize()
end_forward = time.time()
forward_latency = end_forward - start_forward
print(f"Forward pass latency: {forward_latency} seconds")

print("=================== Backward =====================")
start_backward = time.time()
loss = outputs.sum()
loss.backward()
optimizer.step()
torch.cuda.synchronize()
end_backward = time.time()
backward_latency = end_backward - start_backward
print(f"Backward pass latency: {backward_latency} seconds")

total_end = time.time()
total_latency = total_end - total_start
print(f"Total latency (forward + backward): {total_latency} seconds")


Forward pass latency: 1.48537015914917 seconds
Backward pass latency: 0.4989480972290039 seconds
Total latency (forward + backward): 1.9885234832763672 seconds


In [12]:
batch_size = 1
num_batches = 10

data = [torch.randn(batch_size, 3, 224, 224, device=device) for _ in range(num_batches)]
forward_latency_list, backward_latency_list, total_latency_list = [],[],[]

for i, batch in enumerate(data):
  if i == 0:
    optimizer.zero_grad()
    outputs = mod(inp)
    loss = outputs.sum()
    loss.backward()
    optimizer.step()

  else:
    total_start = time.time()
    print(f"=================== Batch {i+1} =====================")
    print()
    torch.cuda.synchronize()
    start_forward = time.time()
    optimizer.zero_grad()
    outputs = mod(inp)
    torch.cuda.synchronize()
    end_forward = time.time()
    forward_latency = end_forward - start_forward

    start_backward = time.time()
    loss = outputs.sum()
    loss.backward()
    optimizer.step()
    torch.cuda.synchronize()
    end_backward = time.time()
    backward_latency = end_backward - start_backward

    total_end = time.time()
    total_latency = total_end - total_start
    print(f"Total latency (forward + backward): {total_latency} seconds")
    print()
    print(f"Forward pass latency: {forward_latency} seconds")
    print(f"Backward pass latency: {backward_latency} seconds")
    print()

    forward_latency_list.append(forward_latency)
    backward_latency_list.append(backward_latency)
    total_latency_list.append(total_latency)

avg_forward_latency = sum(forward_latency_list) / len(forward_latency_list)
avg_backward_latency = sum(backward_latency_list) / len(backward_latency_list)
avg_total_latency = sum(total_latency_list) / len(total_latency_list)

print()
print(f"Average forward pass latency: {avg_forward_latency} seconds")
print(f"Average backward pass latency: {avg_backward_latency} seconds")
print(f"Average total latency: {avg_total_latency} seconds")



Total latency (forward + backward): 0.09418249130249023 seconds

Forward pass latency: 0.05772280693054199 seconds
Backward pass latency: 0.026458740234375 seconds


Total latency (forward + backward): 0.08638954162597656 seconds

Forward pass latency: 0.05787229537963867 seconds
Backward pass latency: 0.0274200439453125 seconds


Total latency (forward + backward): 0.09773874282836914 seconds

Forward pass latency: 0.07194256782531738 seconds
Backward pass latency: 0.024062633514404297 seconds


Total latency (forward + backward): 0.2598686218261719 seconds

Forward pass latency: 0.2314894199371338 seconds
Backward pass latency: 0.028301715850830078 seconds


Total latency (forward + backward): 0.08493828773498535 seconds

Forward pass latency: 0.05778980255126953 seconds
Backward pass latency: 0.02482295036315918 seconds


Total latency (forward + backward): 0.1203770637512207 seconds

Forward pass latency: 0.0813291072845459 seconds
Backward pass latency: 0.03786134719848633 second

## Measuring latency with Torch Dispatcher

In [10]:
def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x

class LatencyMeasurementMode(TorchDispatchMode):
    def __init__(self, module=None):
        self.latency_counts = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
        self.parents = ['Global']
        if module is not None:
            for name, module in dict(module.named_children()).items():
                module.register_forward_pre_hook(self.enter_module(name))
                module.register_forward_hook(self.exit_module(name))

    def enter_module(self, name):
        def f(module, inputs):
            self.parents.append(name)
            inputs = normalize_tuple(inputs)
            out = self.create_backwards_pop(name)(*inputs)
            return out

        return f

    def exit_module(self, name):
        def f(module, inputs, outputs):
            assert self.parents[-1] == name
            self.parents.pop()
            outputs = normalize_tuple(outputs)
            return self.create_backwards_push(name)(*outputs)
        return f

    def create_backwards_push(self, name):
        class PushState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                self.parents.append(name)
                return grad_outs

        return PushState.apply

    def create_backwards_pop(self, name):
        class PopState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                assert self.parents[-1] == name
                self.parents.pop()
                return grad_outs

        return PopState.apply

    def __enter__(self):
        self.latency_counts.clear()
        super().__enter__()

    def __exit__(self, *args):

        ## Added for popping last layer during backward.
        # print("Pop:", self.parents.pop())
        print()

        ## Added for logging forward & backward latency seperately.
        self.total_forward_latency = sum(self.latency_counts['Global']['forward'].values())
        self.total_backward_latency = sum(self.latency_counts['Global']['backward'].values())

        self.total_latency = sum(self.latency_counts['Global']["total"].values())

        print(f"Total latency (forward + backward): {self.total_latency} seconds")
        print()
        # print("=================== Latency per training steps ===================")
        print(f"Forward pass latency: {self.total_forward_latency} seconds")
        print(f"Backward pass latency: {self.total_backward_latency} seconds")
        print()

        # print("=================== Latency per model modules  ===================")

        # for mod in self.latency_counts.keys():
        #     print(f"Module: ", mod)
        #     for phase in ['forward', 'backward']:
        #         for k, v in self.latency_counts[mod][phase].items():
        #             print(f"{phase} {k} latency: {v} seconds")
        #     print()

        super().__exit__(*args)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}
        torch.cuda.synchronize()
        start_time = time.time()
        out = func(*args, **kwargs)
        torch.cuda.synchronize()
        end_time = time.time()
        latency = end_time - start_time

        func_packet = func._overloadpacket

        ## Added for checking the module if it is forward or backward.(GPT-4)
        ## 1)
        current_phase = 'backward'
        if all(isinstance(arg, torch.Tensor) and arg.grad_fn is None for arg in args):
            current_phase = 'forward'

        ## 2)
        # current_phase = 'backward'
        # if torch.is_grad_enabled():
        #     current_phase = 'forward'

        for par in self.parents:
            self.latency_counts[par]["total"][func_packet] += latency
            self.latency_counts[par][current_phase][func_packet] += latency

        return out

In [None]:
inp = torch.randn(1, 3, 224, 224, device=device)
mod = models.resnet50().to(device)
optimizer = torch.optim.Adam(mod.parameters(), lr=0.001)


latency_counter = LatencyMeasurementMode(mod)

with latency_counter:
    optimizer.zero_grad()
    outputs = mod(inp)
    loss = outputs.sum()
    loss.backward()
    optimizer.step()

In [11]:
latency_counter = LatencyMeasurementMode(mod)

forward_latency_list, backward_latency_list, total_latency_list = [],[],[]

for i, batch in enumerate(data):
  if i == 0:
    optimizer.zero_grad()
    outputs = mod(inp)
    loss = outputs.sum()
    loss.backward()
    optimizer.step()

  else:
    print(f"=================== Batch {i+ 1} =====================")
    latency_counter = LatencyMeasurementMode(mod)
    with latency_counter:
      optimizer.zero_grad()
      outputs = mod(inp)
      loss = outputs.sum()
      loss.backward()
      optimizer.step()

    forward_latency_list.append(latency_counter.total_forward_latency)
    backward_latency_list.append(latency_counter.total_backward_latency)
    total_latency_list.append(latency_counter.total_latency)

avg_forward_latency = sum(forward_latency_list) / len(forward_latency_list)
avg_backward_latency = sum(backward_latency_list) / len(backward_latency_list)
avg_total_latency = sum(total_latency_list) / len(total_latency_list)

print(f"Average forward pass latency: {avg_forward_latency} seconds")
print(f"Average backward pass latency: {avg_backward_latency} seconds")
print(f"Average total latency: {avg_total_latency} seconds")


Total latency (forward + backward): 0.09601140022277832 seconds

Forward pass latency: 0.015318632125854492 seconds
Backward pass latency: 0.08069276809692383 seconds


Total latency (forward + backward): 0.09161567687988281 seconds

Forward pass latency: 0.010662317276000977 seconds
Backward pass latency: 0.08095335960388184 seconds


Total latency (forward + backward): 0.17192530632019043 seconds

Forward pass latency: 0.01895594596862793 seconds
Backward pass latency: 0.1529693603515625 seconds


Total latency (forward + backward): 0.11286115646362305 seconds

Forward pass latency: 0.013003826141357422 seconds
Backward pass latency: 0.09985733032226562 seconds


Total latency (forward + backward): 0.09781455993652344 seconds

Forward pass latency: 0.011237859725952148 seconds
Backward pass latency: 0.08657670021057129 seconds


Total latency (forward + backward): 0.09904956817626953 seconds

Forward pass latency: 0.011846065521240234 seconds
Backward pass latency: 0.087203502655029