## Network analysis
This notebook analyses the individual components of the SSDv2 network (backbone, neck
and head). Use it to better understand the network, so that optimisations can be made
to improve its performance and accuracy.

In [None]:
import time

import torch

from ssdv2.models import SSDv2

In [None]:
DTYPE = torch.float32
# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# on machines without CUDA (timings will of course differ).
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Number of object classes the model predicts.
# NOTE(review): presumably the 80 COCO categories — confirm.
NUM_CLASSES = 80
MODEL = SSDv2(NUM_CLASSES).to(dtype=DTYPE, device=DEVICE)

### Investigate model size

In [None]:
def _num_params(module: torch.nn.Module) -> int:
    """Return the total number of scalar elements in *module*'s parameters."""
    # Generator rather than a list: sum() consumes it lazily (ruff C419).
    return sum(p.numel() for p in module.parameters())


print(f"SSDv2 params: {_num_params(MODEL):.3e}")
print(f"Backbone params: {_num_params(MODEL.backbone):.3e}")
print(f"Neck params: {_num_params(MODEL.neck):.3e}")
print(f"Head params: {_num_params(MODEL.head):.3e}")

### Investigate execution times

In [None]:
# Create a "batch of images" to infer on
images = torch.rand((2, 3, 640, 640), dtype=DTYPE, device=DEVICE)


def _sync():
    """Block until all queued work on DEVICE has finished (no-op on CPU).

    CUDA kernel launches are asynchronous. Without this, the wall-clock
    timers below would mostly measure launch overhead, and the real compute
    time would be attributed to whichever later call happens to block.
    """
    if DEVICE.type == "cuda":
        torch.cuda.synchronize(DEVICE)


# Gradients are not needed for timing. Skipping autograd bookkeeping gives
# numbers closer to the real inference cost and uses less memory.
with torch.no_grad():
    # Warm-up pass: the first call pays one-off costs (CUDA context setup,
    # cuDNN algorithm selection, allocator growth) that would otherwise be
    # charged to the backbone.
    # NOTE(review): MODEL is not switched to eval() here. If it contains
    # BatchNorm layers in training mode, every pass (including this one)
    # updates their running statistics — confirm whether MODEL.eval() is wanted.
    MODEL.head(MODEL.neck(MODEL.backbone(images)))
    _sync()

    # perf_counter is monotonic and high-resolution, unlike time.time().
    all_start = time.perf_counter()

    # Run backbone
    start = time.perf_counter()
    fms = MODEL.backbone(images)
    _sync()
    print(f"Backbone time: {time.perf_counter() - start:.3f}s")

    # Run neck
    start = time.perf_counter()
    fms = MODEL.neck(fms)
    _sync()
    print(f"Neck time: {time.perf_counter() - start:.3f}s")

    # Run head
    start = time.perf_counter()
    logits, centerness, boxes = MODEL.head(fms)
    _sync()
    print(f"Head time: {time.perf_counter() - start:.3f}s")

    total_time = time.perf_counter() - all_start

print(f"Total time: {total_time:.3f}s")
print(f"Total FPS: {images.shape[0] / total_time:.3f}")