## Objective
Benchmark the inference speed of the spike sorter's convolutional model after compiling it with Torch-TensorRT.

In [1]:
import time

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch_tensorrt

from src.model import ModelSpikeSorter
from src.utils import random_seed

In [2]:
# Precision applied to the model weights below and reused by the benchmark cells.
dtype = torch.float16

# Active configuration: checkpoint to benchmark and the file for the timing results.
full_model = ModelSpikeSorter.load(
    "/data/MEAprojects/DLSpikeSorter/models/v0_4_4/5118/230101_135307_305876"
)
SAVE_PATH = "/data/MEAprojects/DLSpikeSorter/models/v0_4_4/model_speed_rerun.npy"

# Alternative configuration (disabled); uncomment both lines to benchmark this checkpoint instead.
# full_model = ModelSpikeSorter.load("/data/MEAprojects/buzsaki/SiegleJ/AllenInstitute_744912849/session_766640955/dl_models/240318/c/240318_165245_967091")
# SAVE_PATH = "/data/MEAprojects/buzsaki/SiegleJ/AllenInstitute_744912849/session_766640955/dl_models/240318/model_speed.npy"

##
# Only the convolutional sub-module is benchmarked; move it to the GPU in the chosen precision.
conv_module = full_model.model.conv
model = conv_module.to(device="cuda", dtype=dtype)

## Computation time vs batch size (number of sample electrodes)

In [None]:
# Batch sizes to benchmark; a separate TensorRT module is compiled for each one.
# NOTE(review): 1020 breaks the power-of-two pattern of its neighbours and may
# have been meant to be 1024 -- confirm before relying on that data point.
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 384, 512, 1020, 2048]
dtype = torch.float16

random_seed(501)


def get_inputs(size):
    """Return a random CUDA tensor holding a batch of `size` model inputs.

    The tensor has shape (size, full_model.num_channels_in, full_model.sample_size)
    and uses the module-level `dtype`.
    """
    return torch.rand(size, full_model.num_channels_in, full_model.sample_size, dtype=dtype, device="cuda")


# Mean latency in milliseconds per forward pass, one entry per batch size.
speed_avgs = []
for size in batch_sizes:
    # Trace the model and compile it with Torch-TensorRT for this fixed input shape.
    input_shape = (size, full_model.num_channels_in, full_model.sample_size)
    model_rt = torch.jit.trace(model, [get_inputs(size)])
    model_rt = torch_tensorrt.compile(
        model_rt,
        inputs=[torch_tensorrt.Input(input_shape, dtype=dtype)],
        enabled_precisions={dtype},
    )

    with torch.no_grad():
        # Warm up the GPU so one-time start-up costs are not measured.
        for _ in range(50):
            model_rt(get_inputs(size))
            torch.cuda.synchronize()

        # Timed runs.
        speeds = []
        for _ in range(100):
            inputs = get_inputs(size)
            # CUDA kernels are queued asynchronously: wait for the input
            # generation to finish so it is not counted in the timed region.
            torch.cuda.synchronize()
            start_time = time.perf_counter()
            model_rt(inputs)
            # Block until the forward pass has actually completed on the GPU.
            torch.cuda.synchronize()
            end_time = time.perf_counter()
            speeds.append((end_time - start_time) * 1000)  # seconds -> ms

    speed_avgs.append(np.mean(speeds))
    # Report the recorded mean rather than only the last single measurement.
    print(f"{size} samples: {speed_avgs[-1]:.2f} ms")

In [4]:
# np.save(SAVE_PATH, [batch_sizes, speed_avgs])

## Computation time vs input window size

In [None]:
# MEA model
# NOTE(review): `input_sizes` is defined but never used below. The loop still
# iterates `batch_sizes` and every input uses the fixed `full_model.sample_size`,
# so this cell currently repeats the batch-size benchmark instead of varying the
# input window size. Confirm the intended batch size before switching the loop.
input_sizes = [81, 200]

# Neuropixels model

dtype = torch.float16

random_seed(501)


def get_inputs(size):
    """Return a random CUDA tensor holding a batch of `size` model inputs.

    The tensor has shape (size, full_model.num_channels_in, full_model.sample_size)
    and uses the module-level `dtype`.
    """
    return torch.rand(size, full_model.num_channels_in, full_model.sample_size, dtype=dtype, device="cuda")


# Mean latency in milliseconds per forward pass, one entry per tested size.
speed_avgs = []
for size in batch_sizes:
    # Trace the model and compile it with Torch-TensorRT for this fixed input shape.
    input_shape = (size, full_model.num_channels_in, full_model.sample_size)
    model_rt = torch.jit.trace(model, [get_inputs(size)])
    model_rt = torch_tensorrt.compile(
        model_rt,
        inputs=[torch_tensorrt.Input(input_shape, dtype=dtype)],
        enabled_precisions={dtype},
    )

    with torch.no_grad():
        # Warm up the GPU so one-time start-up costs are not measured.
        for _ in range(50):
            model_rt(get_inputs(size))
            torch.cuda.synchronize()

        # Timed runs.
        speeds = []
        for _ in range(100):
            inputs = get_inputs(size)
            # CUDA kernels are queued asynchronously: wait for the input
            # generation to finish so it is not counted in the timed region.
            torch.cuda.synchronize()
            start_time = time.perf_counter()
            model_rt(inputs)
            # Block until the forward pass has actually completed on the GPU.
            torch.cuda.synchronize()
            end_time = time.perf_counter()
            speeds.append((end_time - start_time) * 1000)  # seconds -> ms

    speed_avgs.append(np.mean(speeds))
    # Report the recorded mean rather than only the last single measurement.
    print(f"{size} samples: {speed_avgs[-1]:.2f} ms")