[Reference](https://levelup.gitconnected.com/pytorch-a-comprehensive-performance-tuning-guide-a917d18bc6c2)

# 1. Always use Mixed Precision

In [1]:
import torch

# Assume model, optimizer, data_loader, loss_fn are defined
model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.Adam(model.parameters())
# Gradient scaler is crucial for stability
scaler = torch.cuda.amp.GradScaler()

# Sample Data
input_data = torch.randn(64, 1024).cuda()
target = torch.randn(64, 1024).cuda()

# Training step with autocast
optimizer.zero_grad()
with torch.autocast(device_type="cuda", dtype=torch.float16):
    output = model(input_data)
    loss = torch.mean((output - target)**2) # Example loss

# Scales loss. Calls backward() on scaled loss to create scaled gradients.
scaler.scale(loss).backward()

# scaler.step() first unscales the gradients of the optimizer's assigned params.
# If gradients aren't inf/NaN, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)

# Updates the scale for next iteration.
scaler.update()

print(f"Loss: {loss.item()}")

  scaler = torch.cuda.amp.GradScaler()


Loss: 1.3273839950561523


# 2. Use PyTorch 2.0 (or later) If Possible

In [2]:
import time

import torch

# Define a regular Python function using PyTorch operations
def my_complex_function(a, b):
    """Elementwise demo: tanh((sin(a) + cos(b)) * a) / (|b| + 1e-6).

    The 1e-6 keeps the division finite where b is zero.
    """
    x = torch.sin(a) + torch.cos(b)
    y = torch.tanh(x * a)
    return y / (torch.abs(b) + 1e-6)

# Compile the function
compiled_function = torch.compile(my_complex_function)

# Best results are usually on GPU; fall back to CPU so the cell still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"


def _synchronize():
    """Block until queued GPU work has finished (no-op on CPU).

    CUDA kernels launch asynchronously, so without this a wall-clock timer
    only measures launch overhead, not the actual computation.
    """
    if device == "cuda":
        torch.cuda.synchronize()


# Use the compiled function - first run might be slower due to compilation
input_a = torch.randn(1000, 1000).to(device)
input_b = torch.randn(1000, 1000).to(device)

# Warm-up run: triggers compilation so it is excluded from the timing below
_ = compiled_function(input_a, input_b)
_synchronize()

# Timed run (perf_counter is the monotonic, high-resolution timer for intervals)
start = time.perf_counter()
output = compiled_function(input_a, input_b)
_synchronize()
end = time.perf_counter()
print(f"Compiled function execution time: {end - start:.4f} seconds")

Compiled function execution time: 0.0003 seconds


# 3. Never Forget Inference Mode

In [3]:
import torch

model = torch.nn.Linear(10, 2)  # toy model: 10 input features -> 2 outputs
input_tensor = torch.randn(1, 10)

# torch.no_grad(): gradient tracking is disabled, so the result carries no
# autograd history and requires_grad is False.
with torch.no_grad():
    output_no_grad = model(input_tensor)

# torch.inference_mode() - Recommended: behaves like no_grad but additionally
# skips view tracking and version-counter bumps, making it cheaper.
with torch.inference_mode():
    output_inference_mode = model(input_tensor)

# Report both results in the same order as they were computed (both print False).
for mode_label, result in (
    ("no_grad", output_no_grad),
    ("inference_mode", output_inference_mode),
):
    print(f"Output ({mode_label}) requires_grad: {result.requires_grad}")

Output (no_grad) requires_grad: False
Output (inference_mode) requires_grad: False


# 4. Use Channels-Last Memory Format for CNNs

In [4]:
import torch
import torch.nn as nn

# Fall back to CPU when no GPU is present; channels-last is supported there too.
device = "cuda" if torch.cuda.is_available() else "cpu"

N, C, H, W = 32, 3, 224, 224  # Example dimensions
model = nn.Conv2d(C, 64, kernel_size=3, stride=1, padding=1).to(device)
input_tensor = torch.randn(N, C, H, W).to(device)

# Convert model and input to channels-last: NHWC order in memory, while the
# logical shape reported by .shape stays NCHW.
model = model.to(memory_format=torch.channels_last)
input_tensor = input_tensor.to(memory_format=torch.channels_last)


def _is_channels_last(t):
    """Return True if the 4-D tensor ``t`` is laid out in channels-last order."""
    return t.is_contiguous(memory_format=torch.channels_last)


# Strides reveal the layout: in channels-last the channel dim (index 1) has stride 1.
print(f"Model parameter memory format: {model.weight.stride()}")
print(f"Input tensor memory format: {input_tensor.stride()}")

# Conv2d consumes the channels-last input directly; no manual permute is needed.
output = model(input_tensor)
print(f"Output tensor memory format: {output.stride()}")

# Explicit check, so the layout does not have to be decoded from strides by eye.
print(
    f"Channels-last? weight={_is_channels_last(model.weight)}, "
    f"input={_is_channels_last(input_tensor)}, output={_is_channels_last(output)}"
)

Model parameter memory format: (27, 1, 9, 3)
Input tensor memory format: (150528, 1, 672, 3)
Output tensor memory format: (3211264, 1, 14336, 64)


# 5. Perform Graph Surgery where Required

In [5]:
import torch
import torch.fx as fx


class SimpleNet(torch.nn.Module):
    """Minimal network to trace: a 5->5 linear layer followed by ReLU."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)

    def forward(self, x):
        # FX names graph nodes after the called module/function, so the traced
        # graph still has `linear` and `relu` nodes.
        return torch.relu(self.linear(x))


module = SimpleNet()
# symbolic_trace records forward() into an editable graph plus regenerated code.
symbolic_traced: fx.GraphModule = fx.symbolic_trace(module)

# Print the traced graph representation, then the generated Python code.
for heading, rendering in (
    ("--- FX Graph ---", symbolic_traced.graph),
    ("\n--- FX Code ---", symbolic_traced.code),
):
    print(heading)
    print(rendering)

--- FX Graph ---
graph():
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_module[target=linear](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.relu](args = (%linear,), kwargs = {})
    return relu

--- FX Code ---



def forward(self, x):
    linear = self.linear(x);  x = None
    relu = torch.relu(linear);  linear = None
    return relu
    


# 6. Use Activation Checkpointing

In [7]:
import torch
from torch.utils import checkpoint as chkpt

# A small stand-in for one segment of a larger model (self-contained, so the
# cell no longer depends on undefined `args` / `kwargs`).
segment = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 16),
)
inputs = torch.randn(8, 16, requires_grad=True)

# regular invocation with default activation caching: every intermediate
# activation inside `segment` is kept alive until backward.
result = segment(inputs)

# checkpointed invocation: intermediates inside `segment` are dropped after the
# forward pass and recomputed during backward, trading compute for memory.
# use_reentrant must be passed explicitly; omitting it warns in recent PyTorch
# releases, and False is the recommended implementation.
result_checkpointed = chkpt.checkpoint(segment, inputs, use_reentrant=False)

# Backward through the checkpoint re-runs segment's forward to rebuild activations.
result_checkpointed.sum().backward()

# 7. Diligent Optimizer Choices

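The bitsandbytes library, developed by Tim Dettmers, offers 8-bit versions of many optimizers found in `torch.optim` (e.g. `bnb.optim.Adam8bit`). These store optimizer state in 8 bits to shrink its memory footprint, and the library's paged variants can additionally move optimizer state between host and GPU memory as needed.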

# 8. Autotune Convolutions Using cuDNN Benchmarking

In [8]:
import torch
import torch.nn as nn

# Enable benchmark mode (usually at the start of your script). cuDNN then
# times candidate convolution algorithms for each new input shape and caches
# the fastest one.
torch.backends.cudnn.benchmark = True

# Fall back to CPU when no GPU is present; the cuDNN flag has no effect there.
device = "cuda" if torch.cuda.is_available() else "cpu"


def _synchronize():
    """Block until queued GPU work has finished (no-op on CPU)."""
    if device == "cuda":
        torch.cuda.synchronize()


# Define and run your CNN model as usual
model = nn.Sequential(
    nn.Conv2d(3, 64, 7, stride=2, padding=3),
    nn.ReLU(),
    nn.MaxPool2d(3, 2, 1),
    # ... more layers
).to(device)

# Fixed input size helps benchmark mode: every new shape triggers a new search.
input_tensor = torch.randn(64, 3, 224, 224).to(device)

# These passes are inference-only, so skip autograd bookkeeping entirely.
with torch.inference_mode():
    # The first forward pass might be slightly slower as benchmarking occurs
    print("Running first forward pass (benchmarking)...")
    output = model(input_tensor)
    # CUDA launches are asynchronous; wait so "complete" is actually true.
    _synchronize()
    print("First pass complete.")

    # Subsequent passes should use the optimized algorithms
    print("Running second forward pass...")
    output = model(input_tensor)
    _synchronize()
    print("Second pass complete.")

Running first forward pass (benchmarking)...
First pass complete.
Running second forward pass...
Second pass complete.


# 9. Enable Asynchronous Data Loading

In [9]:
import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    """Toy in-memory dataset of ``size`` random 128-dim samples with labels in [0, 10)."""

    def __init__(self, size=1000):
        self.data = torch.randn(size, 128)
        self.labels = torch.randint(0, 10, (size,))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Simulate some data loading/processing
        # time.sleep(0.001) # Uncomment (and `import time`) to simulate I/O delay
        return self.data[idx], self.labels[idx]


dataset = MyDataset()

# Optimized DataLoader configuration
# Rule of thumb for num_workers: Start with 4 * num_gpus, benchmark, adjust.
# Requires sufficient CPU cores and memory.
optimized_loader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,  # Use multiple processes for loading
    # Page-locked host memory speeds up CPU-to-GPU copies (and enables
    # non_blocking=True transfers). It only helps when CUDA is present;
    # requesting it without an accelerator just triggers a warning.
    pin_memory=torch.cuda.is_available(),
)



# 10. Optimize Memory Usage
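To save memory, reset gradient tensors to `None` instead of overwriting them with dense tensors of zeros: call `optimizer.zero_grad(set_to_none=True)` or `model.zero_grad(set_to_none=True)`. Since PyTorch 2.0, `set_to_none=True` is the default for both. Besides conserving memory, this can improve training performance. It skips a memset for every gradient, and the next backward pass assigns gradients instead of accumulating them into zero-filled tensors.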