# Electrical Symbol Detection - Training on Google Colab
Train a Faster R-CNN detector (ResNet50+FPN backbone) with CIoU loss for multi-class detection of electrical symbols.

## 1. Setup Environment

In [18]:
import sys
import torch

# Report the interpreter and framework versions, then whether CUDA is usable.
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# Name CUDA device 0 when CUDA is available; otherwise announce the CPU fallback.
print(
    f"GPU: {torch.cuda.get_device_name(0)}"
    if torch.cuda.is_available()
    else "GPU: Not available (will use CPU)"
)

Python version: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
PyTorch version: 2.9.0+cu128
CUDA available: True
GPU: NVIDIA A100-SXM4-40GB


## 2. Mount Google Drive (for saving checkpoints)

In [19]:
import sys
import os

# Check if running on Google Colab: the presence of the `google.colab` module
# in sys.modules is used as the signal.
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    try:
        from google.colab import drive
        # force_remount=False keeps an already-mounted Drive as is when the
        # cell is re-run.
        drive.mount('/content/drive', force_remount=False)
        print("✓ Google Drive mounted at /content/drive")
    except Exception as e:
        # Best effort: a failed mount is reported but does not stop the
        # notebook; storage then falls back to the local /content/ disk.
        print(f"⚠ Could not mount Google Drive: {e}")
        print("Proceeding without Drive - checkpoints will save locally in /content/")
else:
    print("⚠ Running locally (not on Google Colab)")
    print("Dataset will be saved locally")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✓ Google Drive mounted at /content/drive


## 3. Clone Repository

In [20]:
import os
import subprocess

# Where the project comes from and where it is checked out.
repo_url = 'https://github.com/BhanukaDev/symbol-detection.git'
repo_path = '/content/symbol-detection'

if not os.path.exists(repo_path):
    # First run: fetch a fresh checkout. check=True raises
    # CalledProcessError if git exits with a non-zero status.
    subprocess.run(['git', 'clone', repo_url, repo_path], check=True)
    print(f"Repository cloned to {repo_path}")
else:
    print(f"Repository already exists at {repo_path}")
    # Run the pull inside the checkout via cwd= rather than depending on the
    # kernel's process-wide working directory.
    subprocess.run(['git', 'pull'], cwd=repo_path, check=True)
    print("Repository updated")

# Leave the kernel inside the checkout on both paths, so the working directory
# after this cell does not depend on whether the repository already existed.
os.chdir(repo_path)

Repository already exists at /content/symbol-detection
Repository updated


## 4. Install Dependencies

In [21]:
import os

# Change to python directory if in Colab
if os.path.exists('/content/symbol-detection/python'):
    os.chdir('/content/symbol-detection/python')
    
    # Install local workspace packages first
    print("Installing local workspace packages...")
    !pip install -e ./floor-grid
    !pip install -e ./effects
    print("‚úì Workspace packages installed")

# Install external dependencies
!pip install torch torchvision torchmetrics pycocotools timm

# Install main package
!pip install -e .

Installing local workspace packages...
Obtaining file:///content/symbol-detection/python/floor-grid
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: floor-grid
  Building editable for floor-grid (pyproject.toml) ... [?25l[?25hdone
  Created wheel for floor-grid: filename=floor_grid-0.1.0-py3-none-any.whl size=1196 sha256=78699117b6f3e43fc7361415f221b0fc1e7cda53ed8d7db66595655680a66a57
  Stored in directory: /tmp/pip-ephem-wheel-cache-aah615rx/wheels/9e/6b/f6/93c9e88f3c6f9856769f5b99711035582c95144db05c26c467
Successfully built floor-grid
Installing collected packages: floor-grid
  Attempting uninstall: floor-grid
    Found existing installation: floor-grid 0.1.0
    Uninstalling floor-grid-0.1.0:
      Successfully uninstalled floor-grid-0.1

## 5. Verify Installation

In [22]:
import sys
import os

# Ensure we're in the right directory
if os.path.exists('/content/symbol-detection/python'):
    os.chdir('/content/symbol-detection/python')
    sys.path.insert(0, '/content/symbol-detection/python/src')

try:
    from symbol_detection.training import Trainer, CIoULoss
    from symbol_detection.dataset.generator import COCODatasetGenerator
    
    print("‚úì symbol-detection package imported successfully")
    print("‚úì Trainer available")
    print("‚úì CIoU Loss available")
    print("‚úì COCODatasetGenerator available")
except ImportError as e:
    print(f"‚úó Import failed: {e}")
    print("\nReinstalling package...")
    os.chdir('/content/symbol-detection/python')
    !pip install -e .
    print("Please re-run this cell after installation completes.")

✗ Import failed: No module named 'floor_grid'

Reinstalling package...
Obtaining file:///content/symbol-detection/python
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: symbol-detection
  Building editable for symbol-detection (pyproject.toml) ... [?25l[?25hdone
  Created wheel for symbol-detection: filename=symbol_detection-0.1.0-0.editable-py3-none-any.whl size=1780 sha256=4c1b5d4959102b7412f09d66e0059ffa5bbd61083af8d00d466d31c1dfca669a
  Stored in directory: /tmp/pip-ephem-wheel-cache-10qhihfd/wheels/94/36/75/5aa7df0c2e953f991eb4aa945f3eb59d6faa6eeb9272dcd759
Successfully built symbol-detection
Installing collected packages: symbol-detection
  Attempting uninstall: symbol-detection
    Found existing installation: symbol-detection 0.1.

Please re-run this cell after installation completes.


## 6. Configure Dataset and Checkpoint Locations

In [23]:
import os
import sys
from pathlib import Path

# Determine paths based on environment
IN_COLAB = 'google.colab' in sys.modules
# Drive only counts as mounted on Colab, and only if the MyDrive folder exists.
DRIVE_MOUNTED = os.path.exists('/content/drive/MyDrive') if IN_COLAB else False

if IN_COLAB and DRIVE_MOUNTED:
    # Save to Google Drive
    dataset_dir = '/content/drive/MyDrive/symbol-detection/dataset'
    checkpoints_dir = '/content/drive/MyDrive/symbol-detection/checkpoints'
    print("✓ Using Google Drive for storage")
elif IN_COLAB:
    # Fallback to temporary Colab storage
    dataset_dir = '/content/symbol-detection/dataset'
    checkpoints_dir = '/content/symbol-detection/checkpoints'
    print("⚠ Google Drive not mounted - using temporary Colab storage")
else:
    # Local development: paths are resolved two levels above the current
    # working directory, under its `python/` folder.
    dataset_dir = str(Path.cwd().parent.parent / 'python' / 'dataset')
    checkpoints_dir = str(Path.cwd().parent.parent / 'python' / 'checkpoints')
    print("📁 Using local storage")

# Create both directories up front; exist_ok makes the cell safe to re-run.
os.makedirs(dataset_dir, exist_ok=True)
os.makedirs(checkpoints_dir, exist_ok=True)

print(f"Dataset directory: {dataset_dir}")
print(f"Checkpoints directory: {checkpoints_dir}")
# A dataset counts as present when its annotations.json file exists.
print(f"Dataset exists: {os.path.exists(os.path.join(dataset_dir, 'annotations.json'))}")

✓ Using Google Drive for storage
Dataset directory: /content/drive/MyDrive/symbol-detection/dataset
Checkpoints directory: /content/drive/MyDrive/symbol-detection/checkpoints
Dataset exists: False


## 7. Generate Dataset (Optional - skip if using a pre-generated dataset)

In [34]:
import os
import sys

# Colab location of the Python project inside the cloned repository.
PROJECT_DIR = '/content/symbol-detection/python'
# Single source of truth for the dataset size (used by the message and the call).
NUM_IMAGES = 200

# `symbols_dir` below is relative, so the working directory matters. The guard
# matches the earlier setup cells and lets the cell run where the Colab path
# is absent (the current working directory is then used as is).
if os.path.isdir(PROJECT_DIR):
    os.chdir(PROJECT_DIR)
    src_dir = os.path.join(PROJECT_DIR, 'src')
    if src_dir not in sys.path:
        sys.path.insert(0, src_dir)

from symbol_detection.dataset.generator import COCODatasetGenerator

print(f"Generating dataset ({NUM_IMAGES} images)...")

generator = COCODatasetGenerator(
    output_dir=dataset_dir,
    symbols_dir='data/electrical-symbols',
)

# Use the built-in generator with proper COCO annotation conversion
generator.generate_dataset(
    num_images=NUM_IMAGES,
    rows=(15, 30),              # min, max rows
    cols=(15, 30),              # min, max columns
    cell_size=(20, 25),         # min, max cell size
    apply_symbol_effects=False, # Skip slow effects
    apply_image_effects=True,   # Keep image effects
)

# Save annotations to disk
generator.save_annotations()

# Counts every entry in the images folder, so files left over from an earlier
# run into the same directory are included in the total.
num_images = len(os.listdir(os.path.join(dataset_dir, 'images')))
print(f"✓ Dataset generated: {num_images} images")
print(f"✓ Annotations saved: {dataset_dir}/annotations.json")

Generating dataset (200 images)...
Generating 200 dataset images with varied dimensions...
  - Rows range: 15 to 30
  - Cols range: 15 to 30
  - Cell size range: 20 to 25 pixels
  - Symbol effects: disabled
  - Image effects: enabled
Loaded 7 symbol classes:
  - Junction Box: 1 variant(s)
  - Two-way switch: 1 variant(s)
  - Single-pole, one-way switch: 1 variant(s)
  - Light: 1 variant(s)
  - Three-pole, one-way switch: 1 variant(s)
  - Two-pole, one-way switch: 1 variant(s)
  - Duplex Receptacle: 1 variant(s)
[1/200] Generated floor_plan_0000.png (22x30) - 4 rooms, 9 symbols
Loaded 7 symbol classes:
  - Junction Box: 1 variant(s)
  - Two-way switch: 1 variant(s)
  - Single-pole, one-way switch: 1 variant(s)
  - Light: 1 variant(s)
  - Three-pole, one-way switch: 1 variant(s)
  - Two-pole, one-way switch: 1 variant(s)
  - Duplex Receptacle: 1 variant(s)
[2/200] Generated floor_plan_0001.png (29x25) - 4 rooms, 4 symbols
Loaded 7 symbol classes:
  - Junction Box: 1 variant(s)
  - Two-wa

## 8. Training Configuration

In [35]:
# Training hyperparameters for A100 GPU (40GB memory)
training_config = {
    'num_epochs': 50,       # Full training
    'batch_size': 12,       # A100 can handle larger batches
    'learning_rate': 0.005,
    # NOTE(review): torchvision detection heads count background as a class -
    # confirm whether Trainer adds it on top of these 7 symbol classes.
    'num_classes': 7,       # Electrical symbols
    'use_ciou_loss': True,  # Complete IoU Loss per paper
}

print("Training Configuration:")
for key, value in training_config.items():
    print(f"  {key}: {value}")
# Read the batch size from the config so the message cannot drift from it.
print(f"\n✓ A100 GPU selected - using batch_size={training_config['batch_size']} for optimal performance")

Training Configuration:
  num_epochs: 50
  batch_size: 12
  learning_rate: 0.005
  num_classes: 7
  use_ciou_loss: True

✓ A100 GPU selected - using batch_size=12 for optimal performance


## 9. Run Training

In [41]:
import importlib

import torch

import symbol_detection.training
import symbol_detection.training.trainer

# Pick up edits made to trainer.py since it was first imported.
# importlib.reload() re-executes only the module it is given; names that other
# namespaces already imported from it are NOT rebound. Reloading just the
# submodule would leave `symbol_detection.training.Trainer` pointing at the
# stale class, so the package is reloaded as well, after the submodule.
importlib.reload(symbol_detection.training.trainer)
importlib.reload(symbol_detection.training)

from symbol_detection.training import Trainer  # Now with fixed validate() method

# Initialize trainer with updated code; paths come from the storage cell and
# hyperparameters from the training configuration cell.
trainer = Trainer(
    dataset_dir=dataset_dir,
    output_dir=checkpoints_dir,
    num_classes=training_config['num_classes'],
    batch_size=training_config['batch_size'],
    learning_rate=training_config['learning_rate'],
    num_epochs=training_config['num_epochs'],
    device='cuda' if torch.cuda.is_available() else 'cpu',
    use_ciou_loss=training_config['use_ciou_loss'],
)

print(f"Trainer initialized on device: {trainer.device}")
print("Model: FasterRCNN with ResNet50+FPN backbone")
print(f"CIoU Loss: {training_config['use_ciou_loss']}")

Using device: cuda
Trainer initialized on device: cuda
Model: FasterRCNN with ResNet50+FPN backbone
CIoU Loss: True


In [39]:
import json
import os

# Verify annotations have bbox field before training. Every annotation is
# checked, not just the first, so a partly malformed file is caught too.
ann_file = os.path.join(dataset_dir, 'annotations.json')
with open(ann_file, 'r', encoding='utf-8') as f:
    data = json.load(f)

# Treat a missing or null "annotations" entry as an empty list.
annotations = data.get('annotations') or []

print("Checking annotations format...")
print(f"  - Total images: {len(data.get('images', []))}")
print(f"  - Total annotations: {len(annotations)}")
print(f"  - Categories: {len(data.get('categories', []))}")

if annotations:
    first_ann = annotations[0]
    missing_bbox = sum(1 for ann in annotations if 'bbox' not in ann)

    print("\nFirst annotation sample:")
    print(f"  - Keys: {list(first_ann.keys())}")
    if 'bbox' in first_ann:
        print(f"  - bbox: {first_ann['bbox']} ✓")
    else:
        print("  - bbox: MISSING ✗")

    if missing_bbox:
        print("\n⚠ ERROR: annotations.json does not have 'bbox' field!")
        print(f"  - Annotations without bbox: {missing_bbox} of {len(annotations)}")
        print("Solution: Please re-run cell 7 (Dataset Generation) to regenerate with proper COCO format")
else:
    print("⚠ No annotations found in JSON")

Checking annotations format...
  - Total images: 200
  - Total annotations: 1546
  - Categories: 7

First annotation sample:
  - Keys: ['id', 'image_id', 'category_id', 'bbox', 'area', 'iscrowd']
  - bbox: [223.0, 125.0, 62.0, 45.0] ✓


In [40]:
import traceback

# Start training. A failure is reported together with its full traceback
# instead of propagating, so the cells below can still be run afterwards.
try:
    trainer.train()
except Exception as e:
    print(f"✗ Training failed: {e}")
    traceback.print_exc()
else:
    # Reached only when train() returned without raising.
    print("\n✓ Training completed successfully")

Training for 50 epochs...
Training samples: 160, Validation samples: 40
✗ Training failed: 'list' object has no attribute 'values'


Traceback (most recent call last):
  File "/tmp/ipython-input-1712165534.py", line 3, in <cell line: 0>
    trainer.train()
  File "/content/symbol-detection/python/src/symbol_detection/training/trainer.py", line 191, in train
    val_loss = self.validate(val_loader)
               ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/symbol-detection/python/src/symbol_detection/training/trainer.py", line 172, in validate
    losses = sum(loss_dict.values(), torch.tensor(0.0, device=self.device))
                 ^^^^^^^^^^^^^^^^
AttributeError: 'list' object has no attribute 'values'


## 10. Visualize Training Metrics

In [None]:
import matplotlib.pyplot as plt
import json
from pathlib import Path

# Loss history is read from metrics.json inside the checkpoints directory.
metrics_file = Path(checkpoints_dir) / 'metrics.json'

if not metrics_file.exists():
    print("Metrics file not found. Training may not have completed.")
else:
    with open(metrics_file, 'r') as f:
        metrics = json.load(f)

    # One curve per loss series, drawn on a single shared figure.
    plt.figure(figsize=(10, 6))
    for series_key, series_label, series_marker in (
        ('train_losses', 'Train Loss', 'o'),
        ('val_losses', 'Validation Loss', 's'),
    ):
        plt.plot(metrics[series_key], label=series_label, marker=series_marker)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Progress')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    # Save the figure next to the checkpoints before displaying it.
    plt.savefig(f'{checkpoints_dir}/training_curve.png', dpi=150)
    plt.show()

    # Last recorded value of each series.
    final_train_loss = metrics['train_losses'][-1]
    final_val_loss = metrics['val_losses'][-1]
    print(f"Final train loss: {final_train_loss:.4f}")
    print(f"Final val loss: {final_val_loss:.4f}")

## 11. List Saved Checkpoints

In [None]:
from pathlib import Path

# Every *.pth file directly inside the checkpoints directory.
checkpoints = list(Path(checkpoints_dir).glob('*.pth'))

if not checkpoints:
    print("No checkpoints found")
else:
    bytes_per_mb = 1024 * 1024
    print(f"Saved checkpoints ({len(checkpoints)}):")
    # Listed in sorted path order, each with its size in MB.
    for checkpoint_path in sorted(checkpoints):
        size_mb = checkpoint_path.stat().st_size / bytes_per_mb
        print(f"  {checkpoint_path.name} ({size_mb:.1f} MB)")

    # "Latest" means the file with the greatest modification time.
    newest = max(checkpoints, key=lambda candidate: candidate.stat().st_mtime)
    print(f"\nLatest checkpoint: {newest.name}")

## 12. Download Best Model (Optional)

In [None]:
# Nothing is copied here: the cell only reports where the checkpoints were
# written. That is the Google Drive folder when Drive was mounted, so the files
# can be fetched from Drive or through the Colab Files panel.
print(
    f"Checkpoints saved to: {checkpoints_dir}",
    "You can download them from Google Drive or use the Colab Files panel on the left",
    sep="\n",
)