# SurgicalAI Training Notebook

This notebook trains all components of the SurgicalAI system:
1. Tool Detection
2. Phase Recognition
3. Mistake Detection

After running the setup and data cells, each training section can be run independently.

In [None]:
# Setup environment and mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Create a Drive directory for trained weights so they persist after the Colab session ends
!mkdir -p /content/drive/MyDrive/SurgicalAI/weights

# Clone the repository
!git clone https://github.com/YOUR_USERNAME/SurgicalAI
%cd SurgicalAI

# Install dependencies
!pip install -r requirements.txt
!pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118
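The CUDA 11.8 wheels installed above suggest the training cells expect a GPU runtime. This minimal sketch uses only standard PyTorch calls to confirm that PyTorch can see a GPU before you start a long run. The `gpu_summary` helper is hypothetical. If `cuda_available` is `False`, switch the Colab runtime to a GPU under Runtime > Change runtime type.

```python
import importlib.util


def gpu_summary() -> dict:
    """Report whether PyTorch is importable and whether it can see a CUDA GPU."""
    summary = {"torch_installed": False, "cuda_available": False, "device_name": None}
    if importlib.util.find_spec("torch") is None:
        return summary  # torch is not installed in this environment
    import torch

    summary["torch_installed"] = True
    summary["cuda_available"] = torch.cuda.is_available()
    if summary["cuda_available"]:
        # Name of the first visible GPU (the accelerator Colab assigned)
        summary["device_name"] = torch.cuda.get_device_name(0)
    return summary


print(gpu_summary())
```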

## Upload Training Data

Upload your training data to this session, or copy it from Google Drive if you have already uploaded it there.

In [None]:
# Option 1: Upload data directly to this Colab session
# from google.colab import files
# uploaded = files.upload()  # Upload annotation files

# Option 2: Copy from Google Drive if already uploaded
!mkdir -p data/annotations
!mkdir -p data/train_processed
!mkdir -p data/phases

# Copy your data from Drive (uncomment and modify paths as needed)
# !cp /content/drive/MyDrive/SurgicalAI/data/annotations/* data/annotations/
# !cp -r /content/drive/MyDrive/SurgicalAI/data/train_processed/* data/train_processed/
# !cp -r /content/drive/MyDrive/SurgicalAI/data/phases/* data/phases/
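Because the copy commands above are commented out, it is easy to start training on empty directories. This minimal sketch counts the entries in each of the three directories created above. It only checks that something is present, not that the files are in the format the training scripts expect. The `check_data_dirs` helper is hypothetical.

```python
from pathlib import Path

# Directories created in the cell above; adjust if your layout differs.
EXPECTED_DATA_DIRS = ("data/annotations", "data/train_processed", "data/phases")


def check_data_dirs(root=".", expected=EXPECTED_DATA_DIRS):
    """Map each expected directory to its number of entries.

    0 means the directory exists but is empty; -1 means it is missing.
    """
    report = {}
    for rel in expected:
        path = Path(root) / rel
        report[rel] = len(list(path.iterdir())) if path.is_dir() else -1
    return report


for name, count in check_data_dirs().items():
    status = "MISSING" if count < 0 else ("EMPTY" if count == 0 else f"{count} entries")
    print(f"{name}: {status}")
```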

## 1. Tool Detection Training

In [None]:
# Train the tool detection model
!python training/train_tool_detection.py \
  --data_dir data \
  --output_dir models/weights \
  --batch_size 4 \
  --num_epochs 20 \
  --learning_rate 3e-4 \
  --backbone resnet50 \
  --use_mixed_precision True

# Save the trained model to Drive
!cp models/weights/tool_detection/tool_detection.pth /content/drive/MyDrive/SurgicalAI/weights/
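If the checkpoint is missing (for example, because training failed or wrote to a different subdirectory), `!cp` prints an error but the notebook carries on. This minimal sketch does the same copy from Python and raises an exception instead. The `save_to_drive` helper is hypothetical. The checkpoint path in the commented usage line is copied from the cell above.

```python
import shutil
from pathlib import Path


def save_to_drive(checkpoint, drive_dir):
    """Copy a checkpoint into drive_dir, raising if the checkpoint does not exist."""
    src = Path(checkpoint)
    if not src.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {src} (did training finish?)")
    dst_dir = Path(drive_dir)
    dst_dir.mkdir(parents=True, exist_ok=True)
    # copy2 copies the file contents and preserves its modification time
    return Path(shutil.copy2(src, dst_dir / src.name))


# Usage in Colab:
# save_to_drive("models/weights/tool_detection/tool_detection.pth",
#               "/content/drive/MyDrive/SurgicalAI/weights")
```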

## 2. Phase Recognition Training

In [None]:
# Train the phase recognition model
!python training/train_phase_recognition.py \
  --data_dir data \
  --output_dir models/weights/vit_lstm \
  --batch_size 2 \
  --num_epochs 15 \
  --vit_model vit_base_patch16_224 \
  --freeze_vit True

# Save the trained model to Drive
!cp models/weights/vit_lstm/phase_recognition.pth /content/drive/MyDrive/SurgicalAI/weights/
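Both `--freeze_vit True` and `--use_mixed_precision True` pass booleans as strings. We have not checked how the training scripts parse these flags. If a script declares one with `argparse`'s `type=bool`, every non-empty string, including `"False"`, becomes `True`, so the option cannot be switched off from the command line. A common workaround is an explicit converter. This sketch uses a hypothetical `str2bool` helper:

```python
import argparse


def str2bool(value: str) -> bool:
    """Convert common textual spellings of a boolean; reject anything else."""
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--freeze_vit", type=str2bool, default=True)

# str2bool parses "False" correctly ...
print(parser.parse_args(["--freeze_vit", "False"]).freeze_vit)
# ... whereas bool() on any non-empty string is True, which is the type=bool pitfall
print(bool("False"))
```

On Python 3.9+, `argparse.BooleanOptionalAction` is an alternative. It changes the command-line syntax to `--freeze_vit` / `--no-freeze_vit`, so the notebook cells would need to change as well.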

## 3. Mistake Detection Training

In [None]:
# Train the mistake detection model
!python training/train_all_models.py \
  --train_subset mistake_detection \
  --data_dir data \
  --output_dir models/weights \
  --batch_size 4 \
  --num_epochs 10

# Save the trained model to Drive
!cp models/weights/mistake_detector/mistake_detection.pth /content/drive/MyDrive/SurgicalAI/weights/
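IPython's `!` shell escape does not stop a cell when the command fails. If training crashes, the `!cp` line still runs and either reports a missing file or copies a stale checkpoint left over from an earlier run. If you prefer to launch the scripts from Python, this minimal sketch uses `subprocess.run(..., check=True)` so that a non-zero exit status raises immediately. The `run_training` helper is hypothetical, and the commented example mirrors the mistake-detection cell above.

```python
import subprocess
import sys


def run_training(script, **flags):
    """Run a training script with --key value flags; raise if it exits non-zero."""
    cmd = [sys.executable, script]
    for key, value in flags.items():
        cmd += [f"--{key}", str(value)]
    print("Running:", " ".join(cmd))
    # check=True raises CalledProcessError on failure, which stops the cell
    subprocess.run(cmd, check=True)
    return cmd


# Usage in Colab:
# run_training("training/train_all_models.py", train_subset="mistake_detection",
#              data_dir="data", output_dir="models/weights",
#              batch_size=4, num_epochs=10)
```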

## Verify and Download Models

Check that all three models were trained and saved to Drive.

In [None]:
# List saved models in Drive
!ls -la /content/drive/MyDrive/SurgicalAI/weights/

# Optionally download the models from Drive to your local machine
from google.colab import files

# Uncomment to download specific models
# files.download('/content/drive/MyDrive/SurgicalAI/weights/tool_detection.pth')
# files.download('/content/drive/MyDrive/SurgicalAI/weights/phase_recognition.pth')
# files.download('/content/drive/MyDrive/SurgicalAI/weights/mistake_detection.pth')
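`ls -la` lists the files but leaves it to you to spot a missing or empty checkpoint. This minimal sketch checks for the three file names used in the copy commands above and treats zero-byte files as missing. It does not check that the weights actually load. The `verify_weights` helper is hypothetical.

```python
from pathlib import Path

# File names taken from the copy commands in the training cells above.
EXPECTED_WEIGHTS = ("tool_detection.pth", "phase_recognition.pth", "mistake_detection.pth")


def verify_weights(weights_dir, expected=EXPECTED_WEIGHTS):
    """Return (present, missing); a zero-byte file counts as missing."""
    present, missing = [], []
    for name in expected:
        path = Path(weights_dir) / name
        if path.is_file() and path.stat().st_size > 0:
            present.append(name)
        else:
            missing.append(name)
    return present, missing


present, missing = verify_weights("/content/drive/MyDrive/SurgicalAI/weights")
print("Present:", present)
print("Missing:", missing)
```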