
# Using PyTorch profiler for FLOPS and Model Sizing

Hi everyone, here's how you can measure FLOPS in GB and FLOPS of your model. Useful for the lab and project!




In [1]:

import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity


###  Define a Simple CNN Model

You'll have to define your model first in order to use the torch profiler and run it through one forward pass with a dummy input to track the FLOPS and parameters. Below is a simple CNN with 3 increasing filters, a max pool, fc layers, and RELU. Since we're using pytorch operations, they get counted in the pytorch profiler.

In [2]:

#for CIFAR10
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(4 * 4 * 128, 512)
        self.fc2 = nn.Linear(512, 10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        x = x.view(-1, 4 * 4 * 128)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)


### Count Parameters & Model Size

In [3]:

#we set our model to what we defined earlier
model = SimpleCNN()

param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total Parameters: {param_count:,}")
param_count_m = param_count / 1_000_000  # convert to millions (used for lab grading!!!)
print(f"Model Size: {param_count_m:.2f} M parameters")
model_size_gb = param_count * 4 / (1024 * 1024 * 1024)
print(f"Model Size: {model_size_gb} GB")


Total Parameters: 1,147,466
Model Size: 1.15 M parameters
Model Size: 0.004274643957614899 GB


### Forward Pass Check

In [4]:

dummy_input = torch.randn(1, 3, 32, 32)  #batch size of one, we use random numbers as dummy input
model.eval() #do not update the gradients, just evaluate
with torch.no_grad(): #run it through
    output = model(dummy_input)
print("Output shape:", output.shape) #10 classes for cifar-10


Output shape: torch.Size([1, 10])


###  FLOP Profiling

In [5]:
#Note THAT profileracitvity is set for CPU but can be set to .CUDA to run on cuda
#if you're setting it for CUDA, make sure to also set the cuda for the model as well

with profile(activities=[ProfilerActivity.CPU],record_shapes=True,with_flops=True) as prof:
    with torch.no_grad():
        model(dummy_input)

events = prof.events()


###  Total FLOPs

In [6]:

total_flops = sum(e.flops for e in events if e.flops is not None)
print(f"Total FLOPs: {total_flops:,}")
print(f"GFLOPs: {total_flops / 1e9:.3f}") #1e9 moves it three decimal places


Total FLOPs: 22,751,232
GFLOPs: 0.023


###  Top 10 Most Expensive Operations

In [7]:

flop_events = [(e.flops, e.key) for e in events if e.flops is not None]
flop_events.sort(reverse=True)

print("Top 10 most expensive operations:")
for i, (flops, op) in enumerate(flop_events[:10]): #top 10
    pct = (flops / total_flops) * 100
    print(f"{i+1:2d}. {op[:45]:45s} | {flops:>12,} FLOPs ({pct:5.1f}%)")


Top 10 most expensive operations:
 1. aten::conv2d                                  |    9,437,184 FLOPs ( 41.5%)
 2. aten::conv2d                                  |    9,437,184 FLOPs ( 41.5%)
 3. aten::addmm                                   |    2,097,152 FLOPs (  9.2%)
 4. aten::conv2d                                  |    1,769,472 FLOPs (  7.8%)
 5. aten::addmm                                   |       10,240 FLOPs (  0.0%)
 6. aten::view                                    |            0 FLOPs (  0.0%)
 7. aten::view                                    |            0 FLOPs (  0.0%)
 8. aten::view                                    |            0 FLOPs (  0.0%)
 9. aten::view                                    |            0 FLOPs (  0.0%)
10. aten::view                                    |            0 FLOPs (  0.0%)


### Layer-by-Layer Parameter Breakdown

In [8]:

print("Layer parameter breakdown:")
for name, module in model.named_modules():
    if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)): #include more nn modules here to check layers
        layer_params = sum(p.numel() for p in module.parameters())
        print(f"{name:25s} | {layer_params:>10,} params")


Layer parameter breakdown:
conv1                     |        896 params
conv2                     |     18,496 params
conv3                     |     73,856 params
fc1                       |  1,049,088 params
fc2                       |      5,130 params



### Paste your model here

Replace `SimpleCNN` with your own model below and re-run the steps if you want to calculate the Gflops, parameters in M.


In [9]:

# class someModel(nn.Module):
#......................

