In [5]:
import torch
import torch.nn
import torch.optim
import torch.profiler
import torch.utils.data
import torchvision.datasets
import torchvision.models
import torchvision.transforms as T

In [6]:
# Preprocessing applied to every image: resize to 224, convert to a tensor,
# then normalize each of the three channels with mean 0.5 and std 0.5.
transform = T.Compose([
    T.Resize(size=224),
    T.ToTensor(),
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 training split, stored under ./data (download is requested if needed).
train_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batches of 32 samples for the training/profiling loop.
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=32,
    shuffle=True,
)

Files already downloaded and verified


In [7]:
# Pick the best available backend: CUDA first, then Apple's Metal (MPS),
# and finally the CPU so the notebook still runs on machines with neither.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [8]:
# Report which backend was selected so the recorded output documents the hardware used.
print(f"Using device: {device}")

Using device: mps


In [9]:
import time

# ResNet-18 initialised with the IMAGENET1K_V1 weights, moved to the selected device.
model = torchvision.models.resnet18(weights='IMAGENET1K_V1').to(device)
# Cross-entropy loss and SGD with momentum over all model parameters.
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("start training")
# perf_counter() is the clock intended for measuring short elapsed intervals;
# unlike time.time() it is not affected by system clock adjustments.
start_time = time.perf_counter()
# NOTE: model.train() only switches the module into training mode; no forward/backward
# pass or optimizer step runs here, so the interval printed below covers this call alone.
model.train()
end_time = time.perf_counter()
print(f"end training, training time: {end_time - start_time}s")

start training
end training, training time: 0.00011801719665527344s


In [6]:
def train(data):
    """Run one optimization step on a single batch.

    Args:
        data: An indexable batch where ``data[0]`` holds the inputs and
            ``data[1]`` holds the labels; both are moved to ``device``.

    Uses the notebook-level ``model``, ``criterion``, ``optimizer`` and
    ``device``. Returns nothing; the model parameters are updated in place
    by the optimizer.
    """
    batch_inputs = data[0].to(device=device)
    batch_labels = data[1].to(device=device)

    # Forward pass and loss.
    predictions = model(batch_inputs)
    batch_loss = criterion(predictions, batch_labels)

    # Clear gradients left over from the previous step, backpropagate, then update.
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()

In [7]:
# Profiler schedule, expressed in training steps. The same values drive both the
# schedule and the loop bound below so the two cannot drift apart.
WAIT_STEPS = 1
WARMUP_STEPS = 1
ACTIVE_STEPS = 3
PROFILED_STEPS = WAIT_STEPS + WARMUP_STEPS + ACTIVE_STEPS

with torch.profiler.profile(
        schedule=torch.profiler.schedule(
            wait=WAIT_STEPS, warmup=WARMUP_STEPS, active=ACTIVE_STEPS, repeat=1),
        # Hand the finished trace to the TensorBoard handler, targeting ./log/resnet18.
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/resnet18'),
        record_shapes=True,
        profile_memory=True,
        with_stack=True
) as prof:
    for step, batch_data in enumerate(train_loader):
        prof.step()  # Need to call this at each step to notify profiler of steps' boundary.
        # Stop once the single wait/warmup/active cycle has been covered.
        if step >= PROFILED_STEPS:
            break
        train(batch_data)

STAGE:2024-05-18 12:20:39 17995:3568965 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
[W CPUAllocator.cpp:249] Memory block of unknown size was allocated before the profiling started, profiler results will not include the deallocation event
STAGE:2024-05-18 12:20:40 17995:3568965 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-05-18 12:20:40 17995:3568965 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
