# Gymnastics Apparatus Detection - MediaPipe Model Maker Training

This notebook trains a custom MediaPipe object detection model for gymnastics apparatus.

**Dataset**: 5,220 images in Pascal VOC format  
**Classes**: Balance_Beam, Horizontal_Bar, Parallel_Bars, Pommel_Horse, Still_Rings, Uneven_Bars, Vault  
**Model**: EfficientDet-Lite2 (balanced speed/accuracy)  
**Output**: TFLite model for MediaPipe integration

---

## 📋 Instructions

1. **Enable GPU**: Runtime → Change runtime type → GPU
2. **Upload Dataset**: Upload your `raw_object_detect_pascalvoc` folder to Colab (zipped, or via Google Drive — see "Upload Dataset" below)
3. **Run All Cells**: Runtime → Run all
4. **Download Model**: Download the trained `.tflite` file at the end

## 1️⃣ Install Dependencies

In [None]:
# IPython shell escape: quietly (-q) install the package that provides the
# `mediapipe_model_maker` module imported in the next cell.
!pip install -q mediapipe-model-maker

## 2️⃣ Import Libraries

In [None]:
import os
import json
import xml.etree.ElementTree as ET
from pathlib import Path
import shutil
from sklearn.model_selection import train_test_split
import cv2
import numpy as np
from tqdm import tqdm
from datetime import datetime
from mediapipe_model_maker import object_detector
import tensorflow as tf

# Reaching this point means every import above succeeded; also report the
# TensorFlow version so training issues can be tied to a specific release.
print("‚úÖ Libraries imported successfully")
print(f"TensorFlow version: {tf.__version__}")

## 3️⃣ Upload Dataset

**Option A**: Upload ZIP file

In [None]:
import zipfile
from google.colab import files

# Ask for the dataset archive through Colab's browser upload widget.
print("üì§ Upload your dataset ZIP file (raw_object_detect_pascalvoc.zip)")
uploaded = files.upload()

# Unpack every uploaded archive into /content, where the configuration cell
# expects to find the raw dataset folder.
for archive_name in uploaded:
    print(f"Extracting {archive_name}...")
    with zipfile.ZipFile(archive_name, 'r') as archive:
        archive.extractall('/content')

print("‚úÖ Dataset uploaded and extracted")

**Option B**: Mount Google Drive (if dataset is in Drive)

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')
# 
# # Update this path to your dataset location in Drive
# RAW_DATA_DIR = Path('/content/drive/MyDrive/gym_data/raw_object_detect_pascalvoc')

## 4️⃣ Configuration

In [None]:
# Paths
RAW_DATA_DIR = Path('/content/raw_object_detect_pascalvoc')  # Update if needed
OUTPUT_DIR = Path('/content/mediapipe_dataset')
MODEL_OUTPUT_DIR = Path('/content/trained_model')

# Training parameters
EPOCHS = 50
BATCH_SIZE = 8

# Data split (fractions of the annotated images; must add up to 1)
TRAIN_SPLIT = 0.7
VAL_SPLIT = 0.2
TEST_SPLIT = 0.1

# Fail early on a typo in the fractions instead of silently producing a
# lopsided split further down.
if abs(TRAIN_SPLIT + VAL_SPLIT + TEST_SPLIT - 1.0) > 1e-9:
    raise ValueError("TRAIN_SPLIT + VAL_SPLIT + TEST_SPLIT must equal 1.0")

# Classes (alphabetically sorted). The position in this list determines the
# numeric category id used in the generated COCO files.
CLASSES = [
    "Balance_Beam",
    "Horizontal_Bar",
    "Parallel_Bars",
    "Pommel_Horse",
    "Still_Rings",
    "Uneven_Bars",
    "Vault"
]

# Model architecture.
# MediaPipe Model Maker's object detector only accepts the members of
# object_detector.SupportedModels (MobileNet-based backbones). The
# "efficientdet_lite*" names come from the older TFLite Model Maker and are
# rejected by this library.
# Options: mobilenet_v2, mobilenet_v2_i320, mobilenet_multi_avg,
#          mobilenet_multi_avg_i384
MODEL_SPEC = "mobilenet_multi_avg"

print(f"‚úÖ Configuration set")
print(f"   Model: {MODEL_SPEC}")
print(f"   Classes: {len(CLASSES)}")
print(f"   Epochs: {EPOCHS}")

## 5️⃣ Data Preparation Functions

In [None]:
def parse_pascal_voc(xml_file):
    """Parse one Pascal VOC XML annotation.

    Args:
        xml_file: Path or file object handed to ``ET.parse``.

    Returns:
        A dict with the keys ``'filename'`` (text of ``<filename>``),
        ``'width'`` / ``'height'`` (ints from ``<size>``) and ``'objects'``,
        a list of ``{'class': str, 'bbox': [xmin, ymin, xmax, ymax]}`` dicts,
        one per ``<object>`` element.

    Raises:
        xml.etree.ElementTree.ParseError: The file is not well-formed XML.
        AttributeError: A required element (``size``, ``filename``, ``name``,
            ``bndbox`` or one of its corners) is missing.
        ValueError: A size or coordinate value is not numeric.
    """
    def _to_int(text):
        # Coordinates may be serialized as floats (e.g. "12.0"); plain int()
        # would reject those, so go through float and round to the nearest
        # pixel. Integer strings come out unchanged.
        return int(round(float(text)))

    root = ET.parse(xml_file).getroot()

    size = root.find('size')
    width = _to_int(size.find('width').text)
    height = _to_int(size.find('height').text)
    filename = root.find('filename').text

    objects = []
    for obj in root.findall('object'):
        # Strip padding so " Vault " still matches the configured class name.
        name = (obj.find('name').text or '').strip()
        bbox = obj.find('bndbox')

        objects.append({
            'class': name,
            'bbox': [
                _to_int(bbox.find(tag).text)
                for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            ]
        })

    return {
        'filename': filename,
        'width': width,
        'height': height,
        'objects': objects
    }


def convert_to_coco_format(annotations_list, class_mapping):
    """Convert parsed Pascal VOC annotations into a COCO-style dict.

    Args:
        annotations_list: Iterable of dicts shaped like the return value of
            ``parse_pascal_voc`` (``'filename'``, ``'width'``, ``'height'``,
            ``'objects'``).
        class_mapping: Dict mapping class name to a zero-based index.

    Returns:
        A dict with ``"images"``, ``"annotations"`` and ``"categories"`` lists.
        Image ids and annotation ids start at 1; category ids are the mapped
        index plus 1. Boxes are stored as ``[xmin, ymin, width, height]``.
        Objects whose class is not in ``class_mapping`` and boxes with
        non-positive width or height are left out.
    """
    coco_data = {
        "images": [],
        "annotations": [],
        # Category ids are shifted by one so that id 0 is never used for a
        # real class.
        "categories": [
            {
                "id": class_id + 1,
                "name": class_name,
                "supercategory": "apparatus"
            }
            for class_name, class_id in class_mapping.items()
        ]
    }

    annotation_id = 1

    for image_id, annotation in enumerate(annotations_list, start=1):
        coco_data["images"].append({
            "id": image_id,
            "file_name": annotation['filename'],
            "width": annotation['width'],
            "height": annotation['height']
        })

        for obj in annotation['objects']:
            class_name = obj['class']
            if class_name not in class_mapping:
                continue

            xmin, ymin, xmax, ymax = obj['bbox']
            width = xmax - xmin
            height = ymax - ymin

            # A zero- or negative-size box is a labelling error; emitting it
            # would put an invalid (area <= 0) annotation in the training set.
            if width <= 0 or height <= 0:
                continue

            coco_data["annotations"].append({
                "id": annotation_id,
                "image_id": image_id,
                "category_id": class_mapping[class_name] + 1,
                "bbox": [xmin, ymin, width, height],
                "area": width * height,
                "iscrowd": 0
            })
            annotation_id += 1

    return coco_data

# Cell-completion marker: both helper functions above are now defined.
print("‚úÖ Data preparation functions defined")

## 6️⃣ Prepare Dataset (Pascal VOC → COCO)

In [None]:
print("="*80)
print("PREPARING DATASET (COCO FORMAT)")
print("="*80)

# Create output directories. MediaPipe Model Maker's
# Dataset.from_coco_folder expects, for each split:
#   <split>/images/<image files>
#   <split>/labels.json
for split in ['train', 'val', 'test']:
    (OUTPUT_DIR / split / 'images').mkdir(parents=True, exist_ok=True)

# Create class mapping (class name -> zero-based index)
class_mapping = {name: idx for idx, name in enumerate(CLASSES)}

# Get all image files
if not RAW_DATA_DIR.is_dir():
    raise FileNotFoundError(f"Dataset directory not found: {RAW_DATA_DIR}")

# Sorted so the seeded split below is reproducible: directory listing order
# is filesystem-dependent. Extension match is case-insensitive and also
# accepts ".jpeg".
image_files = sorted(
    path for path in RAW_DATA_DIR.iterdir()
    if path.is_file() and path.suffix.lower() in {'.jpg', '.jpeg'}
)
print(f"\nFound {len(image_files)} images")

# Parse annotations
print("\nParsing annotations...")
all_annotations = []
valid_image_files = []

for img_file in tqdm(image_files):
    xml_file = img_file.with_suffix('.xml')

    if not xml_file.exists():
        continue

    try:
        annotation = parse_pascal_voc(xml_file)
    except Exception as e:
        # Best effort: report the offending file and keep going.
        print(f"Error parsing {xml_file.name}: {e}")
        continue

    # Keep only objects of a configured class; an image left with none would
    # contribute no labels, so it is dropped.
    annotation['objects'] = [
        obj for obj in annotation['objects'] if obj['class'] in class_mapping
    ]
    if not annotation['objects']:
        continue

    # The <filename> stored inside the XML can differ from the file on disk
    # (renamed or re-exported datasets). The COCO entry must name the file
    # that is actually copied below.
    annotation['filename'] = img_file.name

    all_annotations.append(annotation)
    valid_image_files.append(img_file)

print(f"Valid images: {len(valid_image_files)}")

if not valid_image_files:
    raise RuntimeError(
        f"No annotated images found in {RAW_DATA_DIR}; "
        "expected <name>.jpg files with matching <name>.xml annotations"
    )

# Split dataset: first carve off the training share, then divide the
# remainder between validation and test in the configured proportion.
train_idx, temp_idx = train_test_split(
    range(len(valid_image_files)),
    train_size=TRAIN_SPLIT,
    random_state=42
)
val_idx, test_idx = train_test_split(
    temp_idx,
    train_size=VAL_SPLIT / (VAL_SPLIT + TEST_SPLIT),
    random_state=42
)

splits = {
    'train': train_idx,
    'val': val_idx,
    'test': test_idx
}

print(f"\nTrain: {len(train_idx)}")
print(f"Val: {len(val_idx)}")
print(f"Test: {len(test_idx)}")

# Process each split
split_paths = {}

for split_name, indices in splits.items():
    print(f"\nProcessing {split_name}...")

    split_dir = OUTPUT_DIR / split_name
    images_dir = split_dir / 'images'

    split_annotations = [all_annotations[i] for i in indices]
    split_images = [valid_image_files[i] for i in indices]

    coco_data = convert_to_coco_format(split_annotations, class_mapping)

    for img_file in tqdm(split_images, desc=f"Copying {split_name}"):
        shutil.copy(img_file, images_dir / img_file.name)

    coco_json_path = split_dir / "labels.json"
    with open(coco_json_path, 'w') as f:
        json.dump(coco_data, f, indent=2)

    split_paths[split_name] = {
        # Root folder to hand to Dataset.from_coco_folder.
        'data_dir': str(split_dir),
        'images': str(images_dir),
        'annotations': str(coco_json_path)
    }

    print(f"{split_name.upper()}: {len(coco_data['images'])} images, {len(coco_data['annotations'])} objects")

print("\n‚úÖ Dataset preparation complete!")

## 7️⃣ Load Data for Training

In [None]:
def _ensure_coco_folder_layout(images_path, annotations_path):
    """Return a folder laid out the way ``Dataset.from_coco_folder`` expects.

    ``from_coco_folder`` takes a single directory containing an ``images/``
    sub-folder and a ``labels.json`` file. If the split was written flat
    (images next to an arbitrarily named annotation file), the annotation
    file is copied to ``labels.json`` and the images it lists are moved into
    ``images/``. A folder that already has the expected layout is untouched.

    Args:
        images_path: Directory currently holding the split's image files.
        annotations_path: Path of the split's COCO JSON file.

    Returns:
        The split's root directory as a string.
    """
    annotations_path = Path(annotations_path)
    data_dir = annotations_path.parent
    labels_path = data_dir / 'labels.json'
    images_dir = data_dir / 'images'

    if not labels_path.exists():
        shutil.copy(annotations_path, labels_path)

    images_dir.mkdir(exist_ok=True)
    source_dir = Path(images_path)
    if source_dir.resolve() != images_dir.resolve():
        with open(labels_path) as f:
            listed_images = json.load(f)['images']
        for image_info in listed_images:
            src = source_dir / image_info['file_name']
            dst = images_dir / image_info['file_name']
            # Move rather than copy to avoid duplicating the dataset on disk.
            if src.exists() and not dst.exists():
                shutil.move(str(src), str(dst))

    return str(data_dir)


# Dataset.from_coco_folder accepts only the dataset root; it has no
# `annotations_json_path` keyword.
print("Loading training data...")
train_data = object_detector.Dataset.from_coco_folder(
    _ensure_coco_folder_layout(
        split_paths['train']['images'],
        split_paths['train']['annotations']
    )
)

print("Loading validation data...")
val_data = object_detector.Dataset.from_coco_folder(
    _ensure_coco_folder_layout(
        split_paths['val']['images'],
        split_paths['val']['annotations']
    )
)

print("‚úÖ Data loaded successfully")

## 8️⃣ Train Model

**This will take 2-4 hours with GPU**

In [None]:
print("="*80)
print("TRAINING MEDIAPIPE OBJECT DETECTION MODEL")
print("="*80)

# Create model specification.
# SupportedModels is an enum, so the configured name is resolved to one of its
# members; an unknown name fails here with the list of valid choices rather
# than deep inside the library.
supported_models = object_detector.SupportedModels
try:
    spec = supported_models[MODEL_SPEC.upper()]
except KeyError:
    valid_names = ", ".join(member.name.lower() for member in supported_models)
    raise ValueError(
        f"Unsupported MODEL_SPEC {MODEL_SPEC!r}; choose one of: {valid_names}"
    ) from None

# Configure hyperparameters
# NOTE(review): cosine_decay_alpha is the floor of the schedule as a fraction
# of the initial rate; at 1.0 the rate presumably never drops below
# learning_rate, i.e. no effective decay. Lower it if decay is intended —
# confirm against the CosineDecay documentation.
hparams = object_detector.HParams(
    learning_rate=0.3,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    cosine_decay_epochs=EPOCHS,
    cosine_decay_alpha=1.0
)

# ObjectDetector.create takes the model choice and hyperparameters bundled in
# an ObjectDetectorOptions object.
options = object_detector.ObjectDetectorOptions(
    supported_model=spec,
    hparams=hparams
)

# Train model
print(f"\nModel: {MODEL_SPEC}")
print(f"Epochs: {EPOCHS}")
print(f"Batch size: {BATCH_SIZE}")
print("\nStarting training...\n")

model = object_detector.ObjectDetector.create(
    train_data=train_data,
    validation_data=val_data,
    options=options
)

print("\n‚úÖ Training complete!")

## 9️⃣ Evaluate Model

In [None]:
print("="*80)
print("EVALUATING MODEL")
print("="*80)

# evaluate() returns (losses, coco_metrics). `losses` is a list of loss
# components rather than a single number, so it cannot be formatted with
# ":.4f" directly.
losses, coco_metrics = model.evaluate(val_data, batch_size=BATCH_SIZE)

# Keep `loss` as a plain float for the cells that follow. The first entry is
# presumably the total loss (Keras ordering) — TODO confirm. np.ravel also
# copes with a bare scalar.
loss = float(np.ravel(losses)[0])

print(f"\nValidation Loss: {loss:.4f}")
print(f"COCO mAP: {coco_metrics}")

## 🔟 Export Model

In [None]:
print("="*80)
print("EXPORTING MODEL")
print("="*80)

# Create output directory
MODEL_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Export to TFLite
tflite_path = MODEL_OUTPUT_DIR / 'gym_apparatus_detector.tflite'
model.export_model(str(tflite_path))

print(f"\n‚úÖ Model exported to: {tflite_path}")

# Save label map: one class name per line, in CLASSES order
label_map_path = MODEL_OUTPUT_DIR / 'labels.txt'
with open(label_map_path, 'w') as f:
    f.writelines(f"{class_name}\n" for class_name in CLASSES)

print(f"Label map saved to: {label_map_path}")

# `loss` may be a single number or the list of loss components returned by
# model.evaluate(); float() alone fails on a list, so take the first entry
# either way.
validation_loss = float(np.ravel(loss)[0])

# Save metadata
metadata = {
    'model_spec': MODEL_SPEC,
    'epochs': EPOCHS,
    'batch_size': BATCH_SIZE,
    'classes': CLASSES,
    'num_classes': len(CLASSES),
    'trained_on': datetime.now().isoformat(),
    'validation_loss': validation_loss,
    # Stored as text because the metric values may not be JSON-serializable.
    'coco_metrics': str(coco_metrics)
}

metadata_path = MODEL_OUTPUT_DIR / 'metadata.json'
with open(metadata_path, 'w') as f:
    json.dump(metadata, f, indent=2)

print(f"Metadata saved to: {metadata_path}")
print("\n" + "="*80)
print("TRAINING COMPLETE!")
print("="*80)

## 📥 Download Trained Model

In [None]:
from google.colab import files

# Trigger a browser download for each exported artifact, model first.
for artifact_path in (tflite_path, label_map_path, metadata_path):
    files.download(str(artifact_path))

print("‚úÖ Files downloaded!")

# Deployment checklist for the downloaded model.
next_steps = (
    "\nNext steps:",
    "1. Copy gym_apparatus_detector.tflite to your project",
    "2. Rename to gym_apparatus_custom.tflite",
    "3. Place in model_service/models/",
    "4. Restart your API server",
    "5. Test on gymnastics videos!",
)
for step in next_steps:
    print(step)