# SGNLite Demo: Tennis Swing Detection

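This notebook demonstrates how to use SGNLite for tennis swing detection and classification. The checkpoint-loading lines in section 1 are commented out by default, so unless you point them at trained weights, the predictions printed below only demonstrate the API and should not be read as real stroke classifications.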

In [None]:
# Setup
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
from sgnlite.model import SGNLite

# Seed PyTorch's RNGs before anything random happens (the dummy inputs drawn
# with torch.randn later on, and any random initialisation performed while the
# model is constructed) so that repeated Restart Kernel -> Run All passes on
# the same machine print the same numbers.
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

## 1. Load the Model

In [None]:
# Build the network from an explicit configuration:
#   in_channels  -> 2  (x, y coordinates)
#   num_joints   -> 17 (COCO-17 keypoints)
#   num_classes  -> 6  (tennis stroke types)
model_config = dict(in_channels=2, num_joints=17, num_classes=6)
model = SGNLite(**model_config)

# Report the model size as computed by the model itself.
param_count = model.get_num_params()
print(f'SGNLite Parameters: {param_count:,}')

# To use trained weights, uncomment the two lines below and update the path
# so it points at your checkpoint file.
# checkpoint = torch.load('path/to/checkpoint.pt', map_location=device)
# model.load_state_dict(checkpoint['state_dict'])

# Move the parameters to the selected device, then switch to evaluation mode.
# `model.eval()` is left as the final expression of the cell on purpose.
model = model.to(device)
model.eval()

## 2. Test with Random Input

In [None]:
# Dummy input layout: [batch, channels, frames, joints]
batch_size, num_channels, num_frames, num_joints = 4, 2, 20, 17  # channels = x, y
input_shape = (batch_size, num_channels, num_frames, num_joints)

x = torch.randn(*input_shape).to(device)
print(f'Input shape: {x.shape}')

# Forward pass with gradient tracking disabled; the softmax over dim 1 turns
# the per-class logits into per-sample probabilities.
with torch.no_grad():
    logits = model(x)
    probs = logits.softmax(dim=1)

predicted_classes = torch.argmax(probs, dim=1).tolist()
print(f'Output shape: {logits.shape}')
print(f'Predictions: {predicted_classes}')

## 3. Class Labels

In [None]:
# Tennis stroke classes; the position in the list is the class index (0-5).
class_names = ['feed', 'ground_stroke', 'negative', 'overhead', 'serve', 'volley']

# Resolve the winning class of every sample once, then report each sample
# with its class name and the probability assigned to that class.
predicted_indices = probs.argmax(dim=1).tolist()
for i in range(batch_size):
    pred_idx = predicted_indices[i]
    label = class_names[pred_idx]
    confidence = probs[i, pred_idx].item()
    print(f'Sample {i}: {label} ({confidence:.2%})')

## 4. Inference Speed Benchmark

In [None]:
import time

# Benchmark configuration. The batch size is named once here and the
# throughput below is derived from the tensor itself, so changing it cannot
# leave the reported figure out of sync with what was actually run.
num_runs = 100
num_warmup = 10
bench_batch_size = 64

x = torch.randn(bench_batch_size, 2, 20, 17).to(device)

# One no_grad context for the whole benchmark: entering it is loop-invariant,
# so it does not need to be re-entered on every iteration.
with torch.no_grad():
    # Warmup passes are excluded from the measurement.
    for _ in range(num_warmup):
        model(x)

    # Wait for outstanding GPU work so warmup does not leak into the timing.
    if device.type == 'cuda':
        torch.cuda.synchronize()

    # Timed runs
    start = time.perf_counter()
    for _ in range(num_runs):
        model(x)

    # Wait for the GPU to finish the timed work before reading the clock.
    if device.type == 'cuda':
        torch.cuda.synchronize()

    elapsed = time.perf_counter() - start

# Throughput counts samples (batch elements) processed per second.
fps = (num_runs * x.shape[0]) / elapsed

print(f'Inference Speed: {fps:.0f} FPS')
print(f'Latency: {1000 * elapsed / num_runs:.2f} ms per batch')

## 5. Model Architecture Summary

In [None]:
# Recap of the configuration used in this notebook. Embedding size, block
# count and parameter count are read from the model.
# NOTE(review): the input shape, head count and class count are literals in
# this cell and are not read from the model - keep them in sync by hand.
rule = '=' * 50
summary_lines = [
    'SGNLite Architecture:',
    rule,
    'Input: [N, 2, 20, 17] (batch, xy, frames, joints)',
    f'Embedding dim: {model.embed_dim}',
    f'Transformer blocks: {len(model.blocks)}',
    'Attention heads: 6',
    'Output: [N, 6] (6 class logits)',
    rule,
    f'Total parameters: {model.get_num_params():,}',
]
for line in summary_lines:
    print(line)