# Electrical Symbol Detection - Inference
Run symbol predictions on floor-plan images using the trained Faster R-CNN model.

## 1. Setup and Imports

In [None]:
import sys
import os
from pathlib import Path

# Resolve project paths.
# - On Colab: the repo is cloned under /content, and checkpoints/dataset live on
#   the mounted Google Drive.
# - Locally: checkpoints/ and dataset/ are expected inside the directory two
#   levels above the current working directory.
# Both branches produce pathlib.Path objects so later cells can use them uniformly.
COLAB_PROJECT_DIR = Path('/content/symbol-detection/python')
DRIVE_PROJECT_DIR = Path('/content/drive/MyDrive/symbol-detection')

if COLAB_PROJECT_DIR.exists():
    os.chdir(COLAB_PROJECT_DIR)
    # Make the `symbol_detection` package (imported below) importable from the cloned repo.
    sys.path.insert(0, str(COLAB_PROJECT_DIR / 'src'))
    checkpoint_dir = DRIVE_PROJECT_DIR / 'checkpoints'
    dataset_dir = DRIVE_PROJECT_DIR / 'dataset'
    IN_COLAB = True
else:
    IN_COLAB = False
    checkpoint_dir = Path.cwd().parent.parent / 'checkpoints'
    dataset_dir = Path.cwd().parent.parent / 'dataset'

# Imports
import cv2
import json
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from symbol_detection.inference import SymbolDetectionPredictor

print(f"Setup complete. Checkpoint dir: {checkpoint_dir}")

## 2. Load Trained Model

In [None]:
# Load the final checkpoint written by training.
# NOTE(review): the file is the *final* epoch, not necessarily the best one —
# point this at a best-epoch checkpoint if training saves one separately.
checkpoint_path = Path(checkpoint_dir) / 'model_epoch_final.pth'
categories_file = Path(dataset_dir) / 'annotations.json'

# Fail early with a clear message instead of an opaque error from inside the predictor.
for required_file in (checkpoint_path, categories_file):
    if not required_file.exists():
        raise FileNotFoundError(f"Required file not found: {required_file}")

# Load the dataset annotations (presumably COCO-format, given the variable name);
# kept as `coco_data` so it can be inspected in later cells.
with open(categories_file, 'r', encoding='utf-8') as f:
    coco_data = json.load(f)

predictor = SymbolDetectionPredictor(
    checkpoint_path=checkpoint_path,
    # NOTE(review): hard-coded; presumably the dataset's categories plus one
    # background class — confirm against coco_data['categories'].
    num_classes=7,
    categories_file=categories_file,
    confidence_threshold=0.50,
)

print(f"✓ Model loaded: {checkpoint_path.name}")
print(f"✓ Categories: {list(predictor.categories.values())}")

## 3. Inference on Test Images

In [None]:
# Select a few images from the dataset to run inference on.
# NOTE(review): these come from the same images/ folder the annotations describe,
# so they may include training images — use a held-out split for real evaluation.
NUM_SAMPLES = 5
images_dir = Path(dataset_dir) / 'images'
all_images = sorted(images_dir.glob('*.png'))
candidate_images = all_images[:NUM_SAMPLES]

print(f"Found {len(all_images)} images in {images_dir}")
print(f"Processing {len(candidate_images)} samples...\n")

# Run inference.
# `test_images` is rebuilt to hold only the images that loaded successfully, so it
# stays aligned 1:1 with `results` (later cells pair the two lists).
results = []
test_images = []
for img_path in candidate_images:
    image = cv2.imread(str(img_path))
    if image is None:
        # cv2.imread returns None instead of raising on unreadable files.
        print(f"⚠ Skipping unreadable image: {img_path.name}")
        continue

    # Use the threshold configured on the predictor rather than a duplicated literal.
    detections = predictor.predict(image, conf_threshold=predictor.confidence_threshold)

    results.append({
        'image': img_path.name,
        'path': img_path,
        'num_detections': len(detections),
        'detections': detections,
        'image_shape': image.shape,
    })
    test_images.append(img_path)

    print(f"{img_path.name}: {len(detections)} symbols detected")

print(f"\n✓ Inference complete")

## 4. Visualize Detections

In [None]:
# Visualize detections with matplotlib, one row per processed image.
# Each image is looked up by the name stored in its result rather than by zipping
# with `test_images`, so the pairing stays correct even if images were skipped.
paths_by_name = {p.name: p for p in test_images}

if not results:
    print("No results to visualize")
else:
    # squeeze=False always returns a 2-D array of axes, even for a single row.
    fig, axes = plt.subplots(len(results), 1, figsize=(12, 4 * len(results)), squeeze=False)

    for ax, result in zip(axes[:, 0], results):
        image = cv2.imread(str(paths_by_name[result['image']]))

        # predictor.visualize works on the BGR image; convert for matplotlib display.
        annotated = predictor.visualize(image, result['detections'])
        ax.imshow(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
        ax.set_title(f"{result['image']} - {result['num_detections']} symbols detected")
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    print("✓ Visualization complete")

## 5. Detailed Detection Results

In [None]:
from collections import Counter

# Show detailed results for the first processed image (if any were processed).
if results:
    first_result = results[0]
    print("Detection Details for First Image:\n")
    print(f"Image: {first_result['image']}")
    print(f"Total Detections: {first_result['num_detections']}\n")

    for i, det in enumerate(first_result['detections'], 1):
        bbox = det['bbox']
        print(f"{i}. {det['class_name']}")
        print(f"   Confidence: {det['confidence']:.3f}")
        print(f"   Box: ({bbox[0]:.0f}, {bbox[1]:.0f}) -> ({bbox[2]:.0f}, {bbox[3]:.0f})")
        print(f"   Size: {det['width']:.0f} x {det['height']:.0f}")
        print()
else:
    print("No images were processed; nothing to show.")

# Aggregate statistics over all processed images.
print("\nInference Statistics:")
total_detections = sum(r['num_detections'] for r in results)
avg_detections = total_detections / len(results) if results else 0
print(f"  Total images: {len(results)}")
print(f"  Total detections: {total_detections}")
print(f"  Average detections per image: {avg_detections:.1f}")

# Detection breakdown by class, across every image.
class_counts = Counter(det['class_name'] for r in results for det in r['detections'])

print("\n  Detections by Class:")
for class_name, count in sorted(class_counts.items()):
    print(f"    {class_name}: {count}")

## 6. Inference Performance Benchmarking

In [None]:
import time
import torch

# Benchmark end-to-end predictor latency on a single image.
# - time.perf_counter() is a monotonic, high-resolution clock suited to timing.
# - A few untimed warm-up calls absorb one-off costs (e.g. CUDA context/cuDNN setup)
#   that would otherwise skew the first measurement.
# - torch.cuda.synchronize() makes sure queued GPU work is finished before the
#   clock is read.
print("Benchmarking inference performance...\n")

if not test_images:
    raise RuntimeError("No test images available to benchmark")
test_image = cv2.imread(str(test_images[0]))
if test_image is None:
    raise RuntimeError(f"Could not read benchmark image: {test_images[0]}")

num_warmup = 2
num_runs = 10
use_cuda = torch.cuda.is_available()

for _ in range(num_warmup):
    predictor.predict(test_image)
if use_cuda:
    torch.cuda.synchronize()

times = []
for _ in range(num_runs):
    start = time.perf_counter()
    predictor.predict(test_image)
    if use_cuda:
        torch.cuda.synchronize()
    times.append((time.perf_counter() - start) * 1000)  # seconds -> ms

avg_time = np.mean(times)
std_time = np.std(times)
min_time = np.min(times)
max_time = np.max(times)

print(f"Inference Speed ({num_runs} runs, {num_warmup} warm-up):")
print(f"  Average: {avg_time:.2f} ms")
print(f"  Std Dev: {std_time:.2f} ms")
print(f"  Min: {min_time:.2f} ms")
print(f"  Max: {max_time:.2f} ms")
print(f"  Throughput: {1000/avg_time:.1f} images/sec")

# GPU Memory
if use_cuda:
    print(f"\nGPU Memory:")
    print(f"  Allocated: {torch.cuda.memory_allocated() / 1024**2:.1f} MB")
    print(f"  Reserved: {torch.cuda.memory_reserved() / 1024**2:.1f} MB")
    print(f"  Device: {torch.cuda.get_device_name(0)}")

# Visualization
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Histogram of inference times (never more bins than samples).
ax1.hist(times, bins=min(20, num_runs), edgecolor='black', alpha=0.7)
ax1.set_xlabel('Inference Time (ms)')
ax1.set_ylabel('Frequency')
ax1.set_title(f'Inference Time Distribution ({num_runs} runs)')
ax1.axvline(avg_time, color='red', linestyle='--', linewidth=2, label=f'Mean: {avg_time:.2f}ms')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Performance summary.
# str() makes this work whether predictor.device is a string or a torch.device.
# NOTE(review): the model/backbone lines are hard-coded — keep them in sync with training.
summary_text = f"""
Inference Performance Summary

Average: {avg_time:.2f} ms
Std Dev: ±{std_time:.2f} ms
Throughput: {1000/avg_time:.1f} img/s

Model: FasterRCNN ResNet50+FPN
Backbone: Frozen
Device: {str(predictor.device).upper()}
Confidence: {predictor.confidence_threshold}
"""

ax2.text(0.1, 0.5, summary_text, fontsize=12, family='monospace',
         transform=ax2.transAxes, verticalalignment='center')
ax2.axis('off')

plt.tight_layout()
plt.show()

print("\n✓ Benchmark complete")