This notebook is a small memory and runtime profiler for UltraZoom model checkpoints. First, let's define some constants.

In [None]:
from time import time

# Path to the checkpoint whose model will be profiled.
CHECKPOINT_PATH = "./checkpoints/checkpoint.pt"
# Device to load the model and run the forward pass on: "cpu" or "cuda".
DEVICE = "cpu"
# Output path for the Chrome trace; the Unix-timestamp suffix keeps each run
# from overwriting a previous trace.
TRACE_PATH = f"./exports/traces/trace-{time()}.json"

Then we'll load the checkpoint and instantiate the UltraZoom upscaler.

In [None]:
import torch

from src.ultrazoom.model import UltraZoom

checkpoint = torch.load(
    CHECKPOINT_PATH,
    map_location=DEVICE,
    weights_only=True,
)

model = UltraZoom(**checkpoint["model_args"])

# Re-add the weight-norm parametrizations before loading; presumably the
# checkpoint was saved with them applied, so its parameter names expect them.
model.add_weight_norms()

# torch.compile wraps modules and prefixes their state-dict keys with
# "_orig_mod."; strip that so the keys match the uncompiled model.
state_dict = {
    key.replace("_orig_mod.", ""): value
    for key, value in checkpoint["model"].items()
}

model.load_state_dict(state_dict)

# Drop the parametrizations added above (presumably folding them into plain
# weights) so inference runs without the extra reparametrization work.
model.remove_parameterizations()

model = model.to(DEVICE)

# Switch to inference behavior (e.g. for any dropout/normalization layers) so
# the profile reflects how the model actually runs at inference time.
model.eval()

print("Model loaded successfully")

Now let's make some fake image data.

In [None]:
# One random 512x512 input; presumably (batch, channels, height, width) with
# 3 RGB channels — confirm against UltraZoom's expected input layout.
# Allocating directly on DEVICE avoids a throwaway CPU tensor and a copy.
x = torch.randn(1, 3, 512, 512, device=DEVICE)

In this next block, we'll run a forward pass on the fake data inside the profiler's context.

In [None]:
from torch.profiler import profile, record_function, ProfilerActivity

match DEVICE:
    case "cpu":
        activities = [ProfilerActivity.CPU]
    case "cuda":
        activities = [ProfilerActivity.CUDA]
    case _:
        raise ValueError(f"Unsupported device: {DEVICE}")

with profile(activities=activities, profile_memory=True, record_shapes=True) as profiler:
    with record_function("model_inference"):
        y_pred = model(x)

Now let's print out the data that the profiler collected for us.

In [None]:
# Aggregate the recorded events per operator once and reuse the result.
averages = profiler.key_averages()

# Operators ranked by total CPU time (including time spent in child ops).
print(averages.table(sort_by="cpu_time_total"))

# Operators ranked by the CPU memory they allocated themselves.
print(averages.table(sort_by="self_cpu_memory_usage"))

Finally, we'll export a Chrome trace file so we can inspect the timeline in a Chromium-based browser's trace viewer (e.g. at chrome://tracing).

In [None]:
from pathlib import Path

# The export directory may not exist yet on a fresh checkout; create it so
# export_chrome_trace does not fail when opening the output file.
Path(TRACE_PATH).parent.mkdir(parents=True, exist_ok=True)

profiler.export_chrome_trace(TRACE_PATH)

print(f"Trace exported to {TRACE_PATH}")