In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import time

import torch.profiler

In [17]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


Using device: cuda


In [18]:
# Define a simple CNN for MNIST classification
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier producing 10 class scores.

    Each stage applies a 3x3 convolution (padding 1), a ReLU and a 2x2
    max-pool. For a 1x28x28 input the feature map therefore goes
    1x28x28 -> 16x14x14 -> 32x7x7 before being flattened and passed through
    a 128-unit hidden layer and a 10-unit output layer.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: two 3x3 convolutions that keep the spatial size.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        # Classifier head operating on the flattened 32x7x7 feature map.
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        # Parameter-free layers, reused by every stage in forward().
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a batch shaped [batch_size, 1, 28, 28] to scores [batch_size, 10]."""
        stage1 = self.pool(self.relu(self.conv1(x)))       # [batch_size, 16, 14, 14]
        stage2 = self.pool(self.relu(self.conv2(stage1)))  # [batch_size, 32, 7, 7]
        flat = stage2.view(stage2.size(0), -1)             # [batch_size, 32*7*7]
        hidden = self.relu(self.fc1(flat))
        # No softmax here: the raw scores are returned as-is.
        return self.fc2(hidden)

In [4]:
# Data preprocessing and loading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # Normalize using MNIST stats
])

# MNIST train / test splits; download=True fetches the files into ./data
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset  = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

batch_size = 64

# Settings shared by both loaders. Pinned host memory only speeds up
# host-to-GPU copies, so it is requested only when CUDA is actually available;
# asking for it on a CPU-only machine just makes DataLoader emit a warning.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
)

# Shuffle the training data; keep the test order fixed.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:04<00:00, 2.01MB/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 322kB/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 2.74MB/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<?, ?B/s]

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw






In [19]:

# Build the network and move it to the selected device, then create the
# cross-entropy loss and an Adam optimizer over all of the model's parameters.
model = SimpleCNN()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [21]:
# Always trace CPU activity; add CUDA tracing only when a GPU is usable so the
# same cell also runs cleanly on the CPU fallback.
profiler_activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
    profiler_activities.append(torch.profiler.ProfilerActivity.CUDA)

# Despite its name this is the profiler object itself: it is used as a context
# manager during training and queried for statistics afterwards.
profiler_config = torch.profiler.profile(
    activities=profiler_activities,
    # Skip 1 step, warm up for 1 step, record 3 steps, and do that cycle once.
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
    # Write the finished trace under ./log in TensorBoard format.
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'),
    record_shapes=True,
    profile_memory=True,
    with_stack=True
)

In [22]:
def train(num_epochs, log_every=100):
    """Train the notebook-level ``model`` for ``num_epochs`` passes over ``train_loader``.

    Relies on names defined in earlier cells: ``model``, ``criterion``,
    ``optimizer``, ``train_loader``, ``device`` and ``profiler_config``.

    Args:
        num_epochs: Number of full passes over ``train_loader``.
        log_every: Print the mean loss of the last ``log_every`` steps each
            time that many steps have completed within an epoch. Steps left
            over at the end of an epoch are not reported.

    Raises:
        ValueError: If ``log_every`` is not a positive number.
    """
    if log_every <= 0:
        raise ValueError(f"log_every must be positive, got {log_every}")

    # Make the function self-contained instead of relying on a separate cell
    # having switched the model to training mode.
    model.train()
    steps_per_epoch = len(train_loader)
    on_cuda = device.type == "cuda"

    # Enter the profiler once for the whole run rather than re-entering the
    # same profiler object on every epoch; its schedule decides which steps
    # are actually recorded.
    with profiler_config as prof:
        for epoch in range(num_epochs):
            epoch_start = time.perf_counter()
            running_loss = 0.0

            for i, (images, labels) in enumerate(train_loader):
                # Move data to the training device
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)

                # Backward pass and optimization
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Tell the profiler that one training step has finished.
                prof.step()

                running_loss += loss.item()

                # Print status every `log_every` steps
                if (i + 1) % log_every == 0:
                    print(f"Epoch [{epoch + 1}/{num_epochs}], "
                          f"Step [{i + 1}/{steps_per_epoch}], "
                          f"Loss: {running_loss / log_every:.4f}")
                    running_loss = 0.0

            # Wait for outstanding GPU work so the epoch time is accurate.
            # Guarded because the notebook can fall back to the CPU, where
            # there is nothing to synchronize.
            if on_cuda:
                torch.cuda.synchronize()
            epoch_end = time.perf_counter()
            print(f"Epoch {epoch + 1} finished in {epoch_end - epoch_start:.2f} seconds")

In [23]:

# Switch the model to training mode before running the training loop.
model.train(mode=True)

SimpleCNN(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=1568, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (relu): ReLU()
)

In [24]:
%%time

train(3)

Epoch [1/3], Step [100/938], Loss: 0.7643
Epoch [1/3], Step [200/938], Loss: 0.2405
Epoch [1/3], Step [300/938], Loss: 0.1450
Epoch [1/3], Step [400/938], Loss: 0.1320
Epoch [1/3], Step [500/938], Loss: 0.0988
Epoch [1/3], Step [600/938], Loss: 0.0826
Epoch [1/3], Step [700/938], Loss: 0.0872
Epoch [1/3], Step [800/938], Loss: 0.0728
Epoch [1/3], Step [900/938], Loss: 0.0662
Epoch 1 finished in 25.96 seconds
Epoch [2/3], Step [100/938], Loss: 0.0634
Epoch [2/3], Step [200/938], Loss: 0.0524
Epoch [2/3], Step [300/938], Loss: 0.0599
Epoch [2/3], Step [400/938], Loss: 0.0398
Epoch [2/3], Step [500/938], Loss: 0.0488
Epoch [2/3], Step [600/938], Loss: 0.0476
Epoch [2/3], Step [700/938], Loss: 0.0497
Epoch [2/3], Step [800/938], Loss: 0.0482
Epoch [2/3], Step [900/938], Loss: 0.0492
Epoch 2 finished in 21.58 seconds
Epoch [3/3], Step [100/938], Loss: 0.0362
Epoch [3/3], Step [200/938], Loss: 0.0420
Epoch [3/3], Step [300/938], Loss: 0.0324
Epoch [3/3], Step [400/938], Loss: 0.0394
Epoch [3

In [26]:
# Aggregate the recorded profiler events and print the first 20 rows,
# ordered by total CUDA time.
profiler_stats = profiler_config.key_averages()
print(profiler_stats.table(sort_by="cuda_time_total", row_limit=20))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          ProfilerStep*        37.25%      56.564ms        69.74%     105.905ms      35.302ms     104.813ms        46.92%     143.240ms      47.747ms           0 b           0 b     196.50 Kb     -36.12 M