<a href="https://colab.research.google.com/github/AndreSlavescu/EasyAI/blob/main/MLSystemsGroup_Lecture1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Linear.

    The defaults (10 -> 5 -> 2) are the sizes this notebook was written with,
    so ``SimpleNet()`` builds exactly the same network as before.

    Args:
        in_features: size of the last dimension of the input.
        hidden_features: width of the hidden layer.
        out_features: size of the last dimension of the output.
    """

    def __init__(self, in_features=10, hidden_features=5, out_features=2):
        super().__init__()
        # Keep these attribute names stable: the FX graph node targets
        # (linear1/relu/linear2) and the state_dict keys are derived from them.
        self.linear1 = nn.Linear(in_features, hidden_features)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Map an input whose last dim is ``in_features`` to one whose last dim is ``out_features``."""
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)

# Seed the global torch RNG so the model's initial weights and the sample
# input (and therefore the printed output) are reproducible across restarts.
SEED = 0
torch.manual_seed(SEED)

x = torch.randn(1, 10)  # one sample with 10 features, matching SimpleNet's first Linear layer
model = SimpleNet()
output = model(x)
print(output)

tensor([[ 0.5874, -0.0648]], grad_fn=<AddmmBackward0>)


In [2]:
"""
Comparing the FX Graphs with and without torch.no_grad
"""

from torch.fx import symbolic_trace
import time

time_average_no_grad = 0
time_average_with_grad = 0
iters = 100

for _ in range(iters):
  with torch.no_grad():
      x = torch.randn(1, 10)
      start_no_grad = time.time()
      output_no_grad = model(x)
      end_no_grad = time.time()
      time_average_no_grad += end_no_grad - start_no_grad

  x = torch.randn(1, 10)
  start_with_grad = time.time()
  output_with_grad = model(x)
  end_with_grad = time.time()
  time_average_with_grad += end_with_grad - start_with_grad

print(f'Time with no_grad: {round(time_average_no_grad / iters, 6)} seconds')
print(f'Time with grad: {round(time_average_with_grad / iters, 6)} seconds')

with torch.no_grad():
    traced_model_no_grad = symbolic_trace(model)

traced_model_with_grad = symbolic_trace(model)

print("\nGraph with torch.no_grad:")
print(traced_model_no_grad.graph)

print("\nGraph without torch.no_grad:")
print(traced_model_with_grad.graph)

Time with no_grad: 5.3e-05 seconds
Time with grad: 0.000113 seconds

Graph with torch.no_grad:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %linear1 : [num_users=1] = call_module[target=linear1](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_module[target=relu](args = (%linear1,), kwargs = {})
    %linear2 : [num_users=1] = call_module[target=linear2](args = (%relu,), kwargs = {})
    return linear2

Graph without torch.no_grad:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %linear1 : [num_users=1] = call_module[target=linear1](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_module[target=relu](args = (%linear1,), kwargs = {})
    %linear2 : [num_users=1] = call_module[target=linear2](args = (%relu,), kwargs = {})
    return linear2


In [3]:
"""
Looking at dispatched operators with trace.json
"""
import torch.profiler

with torch.no_grad():
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU],
        record_shapes=True,
        profile_memory=True
    ) as prof_no_grad:
        output_no_grad = model(x)

prof_no_grad.export_chrome_trace("trace_no_grad.json")

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
        record_shapes=True,
        profile_memory=True
) as prof_with_grad:
    output_with_grad = model(x)

prof_with_grad.export_chrome_trace("trace_with_grad.json")

# View Trace

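Open `chrome://tracing` in a Chromium-based browser (type it into the address bar; browsers do not allow clicking through to `chrome://` URLs from a page), click **Load**, and select `trace_no_grad.json` or `trace_with_grad.json` to view the trace. The same files can also be opened in the [Perfetto UI](https://ui.perfetto.dev).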