In [2]:
import itertools

import torch
import torch.nn as nn
import torchvision
from torch.profiler import profile, record_function, ProfilerActivity
from torchvision import transforms

### Model

In [3]:
class SimpleCNN(nn.Module):
    """Two conv/ReLU/max-pool stages followed by two fully connected layers.

    The flattened feature size ``32 * 8 * 8`` assumes 3-channel 32x32 inputs:
    each of the two 2x2 max-pools halves the spatial size (32 -> 16 -> 8).
    ``forward`` returns 10 raw class scores per sample (no softmax applied).
    """

    def __init__(self):
        super().__init__()
        # kernel_size=3 with padding=1 (stride 1) keeps the spatial size unchanged.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch ``x`` of shape (N, 3, 32, 32) to scores of shape (N, 10)."""
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        # Flatten everything except the batch dimension. Unlike
        # ``view(-1, 32 * 8 * 8)`` this never folds surplus features into extra
        # "samples": a wrongly sized input now fails loudly in fc1 instead of
        # silently producing a different batch size.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

### Data

In [4]:
# Load the CIFAR-10 training split.
# Resize pins every image to 32x32, the spatial size the models' hard-coded
# layer sizes assume. ToTensor yields floats in [0, 1]; normalizing with
# mean = std = 0.5 per channel then maps each channel to [-1, 1].
NORM_MEAN = (0.5, 0.5, 0.5)
NORM_STD = (0.5, 0.5, 0.5)

transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]
)
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

Files already downloaded and verified


### Training loop

In [5]:
# Function to train the model
def train(model, trainloader, criterion, optimizer, device, epochs=1, max_batches=201):
    """Run a short optimisation loop of ``model`` over ``trainloader``.

    Args:
        model: module being trained; put into training mode before the loop.
            It is not moved here, so it should already live on ``device``.
        trainloader: iterable yielding ``(inputs, labels)`` pairs.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the inputs and labels are moved to.
        epochs: number of passes over ``trainloader``.
        max_batches: maximum number of batches per epoch; ``None`` uses the whole
            loader. The default of 201 preserves the original cut-off (``break``
            after batch index 200, i.e. indices 0..200 inclusive).
    """
    model.train()
    for _ in range(epochs):
        # islice stops after ``max_batches`` items without fetching another
        # batch; ``islice(..., None)`` walks the entire loader.
        for inputs, labels in itertools.islice(trainloader, max_batches):
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

### Execution time profiling

In [6]:
# utility function for running the profiler
def run_profiler(trainloader, model, profile_memory=False, device=None, row_limit=10):
    """Train ``model`` briefly under ``torch.profiler`` and print a summary table.

    Args:
        trainloader: iterable of ``(inputs, labels)`` batches handed to ``train``.
        model: module to profile; it is moved to ``device`` first.
        profile_memory: if False, rows are sorted by total CUDA time (total CPU
            time when not running on CUDA); if True, memory allocations are
            recorded and rows are sorted by self CPU memory usage.
        device: device to train on. ``None`` picks ``'cuda'`` when available and
            ``'cpu'`` otherwise (previously ``'cuda'`` was hard-coded, which
            crashed on CPU-only machines).
        row_limit: number of rows to print.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    on_cuda = torch.device(device).type == 'cuda'

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Only ask for CUDA activity tracing when the run actually uses a GPU.
    activities = [ProfilerActivity.CPU]
    if on_cuda:
        activities.append(ProfilerActivity.CUDA)

    with profile(activities=activities, record_shapes=True, profile_memory=profile_memory) as prof:
        with record_function("training"):
            train(model, trainloader, criterion, optimizer, device, epochs=1)

    if profile_memory:
        sort_by = "self_cpu_memory_usage"
    else:
        sort_by = "cuda_time_total" if on_cuda else "cpu_time_total"
    print(prof.key_averages().table(sort_by=sort_by, row_limit=row_limit))

In [7]:
# Time a short training run of the CNN with batches of 32 images.
model = SimpleCNN()
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)
run_profiler(trainloader, model, profile_memory=False)

STAGE:2024-05-14 04:27:28 10122:10122 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2024-05-14 04:27:29 10122:10122 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2024-05-14 04:27:29 10122:10122 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                               training        23.76%     360.249ms        71.31%        1.081s        1.081s       0.000us         0.00%      68.837ms      68.837ms             1  
autograd::engine::evaluate_function: ConvolutionBack...         0.15%       2.271ms         3.63%      55.037ms     136.908us       0.000us         0.00%      34.770ms      86.493us           402  
         

In [8]:
class SimpleNet(nn.Module):
    """A single linear layer over flattened 3x32x32 inputs, giving 10 class scores."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3 * 32 * 32, 10)

    def forward(self, x):
        """Map a batch ``x`` of shape (N, 3, 32, 32) to scores of shape (N, 10)."""
        # Flatten per sample, keeping dim 0. ``view(-1, 3 * 32 * 32)`` would
        # silently re-chunk a wrongly sized input into a different number of
        # rows; this version fails loudly in fc1 instead.
        x = torch.flatten(x, 1)
        return self.fc1(x)

In [9]:
# Repeat the timing run with the single-layer model on the same loader.
model = SimpleNet()
run_profiler(trainloader, model, profile_memory=False)

STAGE:2024-05-14 04:27:33 10122:10122 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2024-05-14 04:27:34 10122:10122 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2024-05-14 04:27:34 10122:10122 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                               training        23.91%     192.128ms        84.59%     679.785ms     679.785ms       0.000us         0.00%      39.361ms      39.361ms             1  
                                           aten::linear         0.10%     768.000us         1.57%      12.605ms      62.711us       0.000us         0.00%      16.955ms      84.353us           201  
         

### Memory consumption profiling

In [10]:
# Record memory allocations for the CNN with batches of 32 images.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)
model = SimpleCNN()
run_profiler(trainloader, model, profile_memory=True)

STAGE:2024-05-14 04:27:35 10122:10122 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2024-05-14 04:27:36 10122:10122 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2024-05-14 04:27:36 10122:10122 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...        22.44%     224.849ms        22.74%     227.911ms       1.134ms       0.000us         0.00%       0.000us       0.000us      75.42 Mb      75.42 Mb           0 b           0 

In [11]:
# Same memory profile with batches of 4, for comparison with batch size 32.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
)
model = SimpleCNN()
run_profiler(trainloader, model, profile_memory=True)

STAGE:2024-05-14 04:27:41 10122:10122 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2024-05-14 04:27:42 10122:10122 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2024-05-14 04:27:42 10122:10122 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...        13.45%     127.135ms        13.74%     129.910ms     646.318us       0.000us         0.00%       0.000us       0.000us       9.43 Mb       9.43 Mb           0 b           0 