# Prepare the data

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torch.utils.tensorboard import SummaryWriter
from torch.profiler import profile, record_function, ProfilerActivity, tensorboard_trace_handler
from datetime import datetime
import os

# Let cuDNN benchmark convolution algorithms and keep the fastest; this pays
# off here because every batch has the same fixed 224x224 input size.
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Transform for CIFAR-10: upsample the 32x32 RGB images to 224x224 (the usual
# ResNet input size), convert to a tensor, then normalise each of the 3
# channels with mean 0.5 / std 0.5, i.e. map [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # 3 channels
])

# Load the CIFAR-10 training split (downloaded into ./data if not present).
train_dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)

# Batches are later moved with `.to(device, non_blocking=True)`; those copies
# are only truly asynchronous from pinned (page-locked) host memory, so pin
# whenever we are training on a GPU. Pinning is pointless on the CPU path.
train_loader = DataLoader(
    train_dataset,
    batch_size=512,
    shuffle=True,
    num_workers=5,
    pin_memory=device.type == "cuda",
)



Using device: cuda
Files already downloaded and verified
TensorBoard log dir: runs/20250519-124055
Epoch 0 [0/98] Loss: 2.4066
Epoch 0 [10/98] Loss: 1.8022
Epoch 0 [20/98] Loss: 1.6009
Epoch 0 [30/98] Loss: 1.5219
Epoch 0 [40/98] Loss: 1.3999
Epoch 0 [50/98] Loss: 1.3942
Epoch 0 [60/98] Loss: 1.3605
Epoch 0 [70/98] Loss: 1.2746
Epoch 0 [80/98] Loss: 1.1651
Epoch 0 [90/98] Loss: 1.1534
✅ Training and profiling complete. View in TensorBoard.


# Training and profiling

In [3]:
# torch.cuda.synchronize() is called during training to make the profile accurate:
# CUDA kernels execute asynchronously, so without an explicit sync the CPU-side ranges
# reported by the profiler can be shorter than (or misattributed relative to) the GPU work they launch.

In [None]:
# Define ResNet18 for 10 classes
class ResNetCIFAR(nn.Module):
    """ResNet-18, trained from scratch, with its classifier head resized for CIFAR.

    Args:
        num_classes: Number of output logits produced by the final linear
            layer. Defaults to 10, the number of CIFAR-10 classes, so existing
            ``ResNetCIFAR()`` calls behave exactly as before.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # weights=None: randomly initialised backbone, no pretrained weights.
        self.resnet = models.resnet18(weights=None)
        # Replace the default classifier with a num_classes-way linear layer,
        # keeping the backbone's feature width as its input size.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class logits for the image batch *x*."""
        return self.resnet(x)

# Instantiate the classifier on the chosen device and attach an Adam
# optimizer (learning rate 1e-3) to all of its parameters.
model = ResNetCIFAR().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# TensorBoard logging: each run gets its own runs/<YYYYmmdd-HHMMSS> directory.
log_dir = os.path.join("runs", f"{datetime.now():%Y%m%d-%H%M%S}")
writer = SummaryWriter(log_dir=log_dir)
print(f"TensorBoard log dir: {log_dir}")

EPOCHS = 1

# Only touch CUDA-specific APIs when actually running on a GPU:
# torch.cuda.synchronize() fails on a CPU-only machine, and there is no GPU
# activity to trace there.
use_cuda = device.type == "cuda"
profiler_activities = [ProfilerActivity.CPU]
if use_cuda:
    profiler_activities.append(ProfilerActivity.CUDA)

# Profiler setup: skip 2 steps, warm up for 2, record 6, once. Traces go to
# the same log_dir as the TensorBoard scalars.
try:
    with profile(
        activities=profiler_activities,
        schedule=torch.profiler.schedule(wait=2, warmup=2, active=6, repeat=1),
        on_trace_ready=tensorboard_trace_handler(log_dir),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as prof:

        step = 0
        model.train()

        for epoch in range(EPOCHS):
            for batch_idx, (images, labels) in enumerate(train_loader):
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                # Drain pending GPU work (including the copies above) so it is
                # not billed to the training-step range below.
                if use_cuda:
                    torch.cuda.synchronize()
                optimizer.zero_grad()

                with record_function("model_training_step"):
                    outputs = model(images).float()

                    with record_function("loss_computation"):
                        loss = F.cross_entropy(outputs, labels)

                    loss.backward()
                    optimizer.step()

                    # Sync *inside* the range: otherwise the range closes as
                    # soon as the kernels are queued and reports only launch time.
                    if use_cuda:
                        torch.cuda.synchronize()

                # Read the scalar once per step (.item() itself synchronises)
                # and reuse it for TensorBoard and the console.
                loss_value = loss.item()
                writer.add_scalar("Loss/train", loss_value, step)
                prof.step()
                step += 1

                if batch_idx % 10 == 0:
                    print(f"Epoch {epoch} [{batch_idx}/{len(train_loader)}] Loss: {loss_value:.4f}")
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()

print("Training and profiling complete.")
