In [1]:
import numpy as np
import torch

In [2]:
from tools.train import rt_detr_config

# All benchmarking runs on the GPU; the timing cell uses CUDA events, so a CUDA device is required.
device = torch.device("cuda")
# Take the model exposed by `rt_detr_config()`, move it to the GPU and switch it to eval mode.
model = rt_detr_config().model.to(device).eval()

Load PResNet18 state_dict


In [26]:
from powerlines.sahi import multiscale_image_patches

# Construct a random test input and split it into multi-scale patches.
# Only the *number* of resulting patches is used afterwards, not their content.

# Patch sizes to extract; prepend 256 to also evaluate a finer scale.
scales = [512, 1024]

# Batch size assumed for the quick per-frame estimate below. The benchmark cell
# recomputes `num_batches_per_frame` for every batch size it evaluates.
REFERENCE_BATCH_SIZE = 38

original_input = torch.randn((1, 3, 3000, 4096), dtype=torch.float)
patches = multiscale_image_patches(
    original_input, patch_sizes=scales, step_size_fraction=0.8, predict_on_full_image=True
)
num_patch_inputs = len(patches.patches)
# True division: the result may be fractional (e.g. 101 / 38 ≈ 2.66).
num_batches_per_frame = num_patch_inputs / REFERENCE_BATCH_SIZE

In [27]:
# Display the number of patches generated from the test input.
num_patch_inputs

101

In [22]:
# Display the (possibly fractional) number of batches needed per frame.
num_batches_per_frame

2.6578947368421053

In [18]:
# Lower/upper bounds of the batch-size sweep. In the visible cells they are only
# referenced by the commented-out `np.linspace` sweep in the benchmark cell.
min_batch_size = 2
max_batch_size = 58

In [19]:
from tqdm import tqdm

# Benchmark the model's forward pass for each candidate batch size and record
# the mean / standard deviation of the latency, scaled to one full frame.

batch_size_step = 4
# To sweep a range of batch sizes instead of a single one, use:
# batch_sizes = np.linspace(min_batch_size, max_batch_size, num=int((max_batch_size - min_batch_size) / batch_size_step + 1)).astype(int)
batch_sizes = [38]
means, stds = [], []

# Benchmark settings and CUDA timers do not depend on the batch size,
# so they are created once instead of on every loop iteration.
repetitions = 1000
warm_up_rounds = 50
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)

for batch_size in batch_sizes:
    print(f"Evalutating batch size {batch_size}")

    # NOTE: `input` shadows the Python builtin; the name is kept so that any
    # other cell referring to it keeps working.
    input = torch.randn((batch_size, 3, 640, 640), dtype=torch.float).to(device)
    num_batches_per_frame = num_patch_inputs / batch_size

    # One timing (in milliseconds, as reported by CUDA events) per repetition
    timings = torch.zeros((repetitions, 1))

    # `no_grad` wraps warm-up AND measurement: the warm-up must exercise the same
    # gradient-free code path that is timed, and must not build autograd graphs
    # that needlessly consume GPU memory.
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
        # GPU warm-up
        for _ in range(warm_up_rounds):
            _ = model(input)

        # Measure inference time
        for rep in tqdm(range(repetitions), desc="Inference"):
            starter.record()
            _ = model(input)
            ender.record()

            # Wait for the GPU to finish; `elapsed_time` needs both events completed
            torch.cuda.synchronize()
            elapsed_time = starter.elapsed_time(ender)
            timings[rep] = elapsed_time

    # Compute inference time (per entire frame): per-batch latency times the
    # (possibly fractional) number of batches needed to cover all patches
    mean_syn = timings.mean() * num_batches_per_frame
    std_syn = timings.std() * num_batches_per_frame
    means.append(mean_syn)
    stds.append(std_syn)

Evalutating batch size 38


Inference:  25%|██▍       | 246/1000 [00:27<01:23,  9.02it/s]

KeyboardInterrupt



In [16]:
# Display the collected per-frame latency means and standard deviations
# (one entry per evaluated batch size).
means, stds

([tensor(290.1152)], [tensor(0.6649)])

In [9]:
# Select the batch size whose mean per-frame latency is the lowest,
# then display that latency together with its standard deviation.
idx = torch.as_tensor(means).argmin()
best_position = idx.item()
optimal_batch_size = batch_sizes[best_position]
means[best_position], stds[best_position]