# Railway Track Anomaly Detection - Training Notebook

This notebook provides an end-to-end guide for:
1. Environment setup and installation
2. Dataset exploration
3. Training TWO YOLOv8 models (Binary + Detailed)
4. Model evaluation and visualization
5. Cascade inference testing

---

## 1. Environment Setup & Installation

In [None]:
# Install required packages
!pip install ultralytics opencv-python pillow pyyaml matplotlib --quiet

In [None]:
# Verify installation
# Importing the package confirms it is installed in this kernel; the
# `checks()` call presumably reports the detected software/hardware
# environment (see the Ultralytics docs - TODO confirm exact output).
import ultralytics
ultralytics.checks()

In [None]:
# Imports
import os
import cv2
import yaml
import random
import shutil
from pathlib import Path
from ultralytics import YOLO
import matplotlib.pyplot as plt
from PIL import Image

# Treat the current working directory as the project root. Nothing is
# changed on disk or in the process: the absolute path is only recorded so
# later cells can build paths from it.
PROJECT_ROOT = Path.cwd()
print(f"Project Root: {PROJECT_ROOT}")

## 2. Dataset Exploration

In [None]:
# Dataset paths: short dataset key -> dataset folder under <project>/data
DATASETS = {
    key: PROJECT_ROOT / "data" / folder
    for key, folder in {
        "binary": "Railway Track Fault detection.v4i.yolov8",
        "detailed": "Railway Track Defect Detection.v1i.yolov8",
    }.items()
}

# Report, per dataset, whether its folder is present on disk
for name, path in DATASETS.items():
    status = "Found" if path.exists() else "NOT FOUND"
    print(f"{name}: {status} - {path}")

In [None]:
# Load and display data.yaml for each dataset
def show_dataset_info(dataset_name):
    """Print the class list and training-image count of one dataset.

    Args:
        dataset_name: Key into the module-level ``DATASETS`` mapping.

    Raises:
        KeyError: If ``dataset_name`` is not in ``DATASETS`` or ``data.yaml``
            lacks the ``nc`` / ``names`` entries.
        FileNotFoundError: If the dataset has no ``data.yaml``.
    """
    dataset_dir = DATASETS[dataset_name]
    yaml_path = dataset_dir / "data.yaml"
    # Explicit encoding so non-ASCII class names decode the same way on every
    # platform instead of depending on the locale default.
    with open(yaml_path, 'r', encoding="utf-8") as f:
        data = yaml.safe_load(f)

    print(f"\n=== {dataset_name.upper()} Dataset ===")
    print(f"Number of classes: {data['nc']}")
    print(f"Classes: {data['names']}")

    # Count images: only regular files with an image extension, so hidden
    # files (e.g. .DS_Store), caches and sub-directories are not counted.
    image_suffixes = {".jpg", ".jpeg", ".png", ".bmp"}
    train_dir = dataset_dir / "train" / "images"
    if train_dir.is_dir():
        train_count = sum(
            1
            for entry in train_dir.iterdir()
            if entry.is_file() and entry.suffix.lower() in image_suffixes
        )
        print(f"Training images: {train_count}")
    else:
        # Say so explicitly rather than silently printing nothing.
        print(f"Training images: directory not found ({train_dir})")

show_dataset_info("binary")
show_dataset_info("detailed")

In [None]:
# Visualize sample images from each dataset
def show_samples(dataset_name, num_samples=4):
    """Display up to ``num_samples`` random training images of a dataset.

    Args:
        dataset_name: Key into the module-level ``DATASETS`` mapping.
        num_samples: Maximum number of images to show. Fewer are shown when
            the training folder holds fewer ``.jpg`` / ``.png`` files.

    Returns:
        None. The figure is rendered with ``plt.show()``; when no image is
        available a message is printed instead.
    """
    img_dir = DATASETS[dataset_name] / "train" / "images"
    images = list(img_dir.glob("*.jpg")) + list(img_dir.glob("*.png"))
    samples = random.sample(images, max(0, min(len(images), num_samples)))

    if not samples:
        # plt.subplots() rejects a zero-column grid, so bail out early.
        print(f"No sample images to display for '{dataset_name}' in {img_dir}")
        return

    # One column per image actually drawn (no empty leftover axes), and
    # squeeze=False so `axes` is always a 2-D array - with a single column
    # matplotlib would otherwise return a bare Axes that cannot be iterated.
    fig, axes = plt.subplots(1, len(samples), figsize=(16, 4), squeeze=False)
    fig.suptitle(f"{dataset_name.upper()} Dataset Samples", fontsize=14)

    for ax, img_path in zip(axes[0], samples):
        # Context manager closes the underlying file handle once the pixel
        # data has been handed to matplotlib.
        with Image.open(img_path) as img:
            ax.imshow(img)
        ax.set_title(img_path.name[:20])
        ax.axis('off')

    plt.tight_layout()
    plt.show()

show_samples("binary")
show_samples("detailed")

## 3. Model Training

We train TWO models:
1. **Binary Model**: Defective vs Non-Defective (fast filter)
2. **Detailed Model**: 9 specific defect types with severity levels

### 3.1 Train Binary Model (Stage 1 - Fast Filter)

In [None]:
# Training settings for the Binary (stage 1) model; passed as keyword
# arguments to the model's train() call in the next cell.
BINARY_CONFIG = dict(
    data=str(DATASETS["binary"] / "data.yaml"),
    epochs=50,
    imgsz=640,
    batch=16,  # Reduce if you get OOM errors
    project="runs/detect",
    name="train_binary",
    exist_ok=True,
)

print("Binary Model Configuration:")
print("\n".join(f"  {key}: {value}" for key, value in BINARY_CONFIG.items()))

In [None]:
# Train the Binary model starting from the yolov8n.pt checkpoint
print("Loading base model (yolov8n)...")
binary_model = YOLO("yolov8n.pt")

print(
    "Starting Binary Model Training...",
    "This may take 30-60 minutes depending on your hardware.",
    "-" * 50,
    sep="\n",
)

# All hyper-parameters come from the BINARY_CONFIG cell above
binary_results = binary_model.train(**BINARY_CONFIG)

print(
    "\n" + "=" * 50,
    "Binary Model Training Complete!",
    f"Results saved to: {binary_results.save_dir}",
    sep="\n",
)

### 3.2 Train Detailed Model (Stage 2 - Classification)

In [None]:
# Training settings for the Detailed (stage 2) model; passed as keyword
# arguments to the model's train() call in the next cell.
DETAILED_CONFIG = dict(
    data=str(DATASETS["detailed"] / "data.yaml"),
    epochs=50,
    imgsz=640,
    batch=16,
    project="runs/detect",
    name="train_detailed",
    exist_ok=True,
)

print("Detailed Model Configuration:")
print("\n".join(f"  {key}: {value}" for key, value in DETAILED_CONFIG.items()))

In [None]:
# Train the Detailed model starting from the yolov8n.pt checkpoint
print("Loading base model (yolov8n)...")
detailed_model = YOLO("yolov8n.pt")

print(
    "Starting Detailed Model Training...",
    "This may take 30-60 minutes depending on your hardware.",
    "-" * 50,
    sep="\n",
)

# All hyper-parameters come from the DETAILED_CONFIG cell above
detailed_results = detailed_model.train(**DETAILED_CONFIG)

print(
    "\n" + "=" * 50,
    "Detailed Model Training Complete!",
    f"Results saved to: {detailed_results.save_dir}",
    sep="\n",
)

### 3.3 Copy Trained Weights to Models Folder

In [None]:
# Create the models directory and copy the best checkpoint of each run into it
models_dir = PROJECT_ROOT / "models"
models_dir.mkdir(exist_ok=True)

# Checkpoint locations follow the project/name values of the two training
# configs (runs/detect/<name>/weights/best.pt).
binary_weights = Path("runs/detect/train_binary/weights/best.pt")
detailed_weights = Path("runs/detect/train_detailed/weights/best.pt")

for weights_path, target_name in (
    (binary_weights, "binary_model.pt"),
    (detailed_weights, "detailed_model.pt"),
):
    if not weights_path.exists():
        # Missing checkpoint: warn and move on to the next one
        print(f"Warning: {weights_path} not found")
        continue
    shutil.copy(weights_path, models_dir / target_name)
    print(f"Copied: {weights_path} -> models/{target_name}")

print("\nModels directory contents:")
for weights_file in models_dir.glob("*.pt"):
    print(f"  - {weights_file.name}")

## 4. Model Evaluation

In [None]:
# Load the two trained checkpoints (paths are relative to the working directory)
binary_model, detailed_model = map(
    YOLO,
    ("models/binary_model.pt", "models/detailed_model.pt"),
)

print(
    "Models loaded successfully!",
    f"Binary model classes: {binary_model.names}",
    f"Detailed model classes: {detailed_model.names}",
    sep="\n",
)

In [None]:
# Validate the Binary model against its own dataset definition
print("Validating Binary Model...")
binary_val = binary_model.val(data=str(DATASETS["binary"] / "data.yaml"))
print(
    f"mAP50: {binary_val.box.map50:.4f}",
    f"mAP50-95: {binary_val.box.map:.4f}",
    sep="\n",
)

In [None]:
# Validate the Detailed model against its own dataset definition
print("Validating Detailed Model...")
detailed_val = detailed_model.val(data=str(DATASETS["detailed"] / "data.yaml"))
print(
    f"mAP50: {detailed_val.box.map50:.4f}",
    f"mAP50-95: {detailed_val.box.map:.4f}",
    sep="\n",
)

## 5. Cascade Inference Testing

In [None]:
def extract_severity(class_name):
    """Map a class name to a severity label.

    The name is searched case-insensitively for the substrings "high",
    "medium" and "low", in that order; the first one found decides the
    result.

    Args:
        class_name: Class label string.

    Returns:
        "HIGH", "MEDIUM" or "LOW", or "UNKNOWN" when none of the keywords
        occurs in the name.
    """
    lowered = class_name.lower()
    # Tuple order encodes the priority when several keywords are present.
    for keyword in ("high", "medium", "low"):
        if keyword in lowered:
            return keyword.upper()
    return "UNKNOWN"


def cascade_inference(image_path, binary_model, detailed_model, threshold=0.5):
    """
    Run two-stage cascade inference.

    Stage 1: Binary check (defective / non-defective)
    Stage 2: Detailed classification (only if defective)

    Stage 1 looks at the *class* of every binary detection, not just at
    whether anything was detected: a detection whose label starts with
    "non" (e.g. "non-defective", "Non Defective") does not escalate the
    image to stage 2. Labels that cannot be resolved are treated as
    defective, so a binary model without such a class behaves as before.

    Args:
        image_path: Image to analyse, in any form the models accept.
        binary_model: Callable stage-1 model; called as
            ``binary_model(image_path, conf=threshold, verbose=False)``.
        detailed_model: Callable stage-2 model; called with ``conf=0.25``.
        threshold: Minimum stage-1 confidence for a defect detection.

    Returns:
        dict with keys ``status``, ``confidence`` and ``stage``; defective
        results additionally carry ``defect_type`` and ``severity``.
    """

    def _label_for(names, class_id):
        """Resolve a class id to its label; '' when it cannot be resolved."""
        try:
            return names[class_id]
        except (KeyError, IndexError, TypeError):
            return ""

    def _is_non_defective(label):
        """True for labels like 'non-defective' / 'Non Defective' / 'nondefective'."""
        compact = "".join(ch for ch in str(label).lower() if ch.isalnum())
        return compact.startswith("non")

    # Stage 1: Binary
    binary_result = binary_model(image_path, conf=threshold, verbose=False)
    boxes = binary_result[0].boxes

    if boxes is None or len(boxes) == 0:
        return {
            "status": "non-defective",
            "confidence": 0.0,
            "stage": 1
        }

    # Keep only the confidences of detections that actually indicate a defect.
    binary_names = getattr(binary_result[0], "names", None)
    defect_scores = [
        float(boxes.conf[idx])
        for idx in range(len(boxes))
        if not _is_non_defective(_label_for(binary_names, int(boxes.cls[idx])))
    ]

    if not defect_scores or max(defect_scores) < threshold:
        # Every detection was a "non-defective" one (or below threshold):
        # report the strongest detection's confidence and stop at stage 1.
        return {
            "status": "non-defective",
            "confidence": float(boxes.conf.max()),
            "stage": 1
        }

    max_conf = max(defect_scores)

    # Stage 2: Detailed
    detailed_result = detailed_model(image_path, conf=0.25, verbose=False)
    detailed_boxes = detailed_result[0].boxes

    if detailed_boxes is None or len(detailed_boxes) == 0:
        # Stage 1 saw a defect but stage 2 could not name it; fall back to
        # the stage-1 confidence.
        return {
            "status": "defective",
            "defect_type": "unknown",
            "severity": "UNKNOWN",
            "confidence": max_conf,
            "stage": 2
        }

    # Report the single most confident stage-2 detection.
    best_idx = detailed_boxes.conf.argmax()
    class_id = int(detailed_boxes.cls[best_idx])
    class_name = detailed_result[0].names[class_id]
    conf = float(detailed_boxes.conf[best_idx])

    return {
        "status": "defective",
        "defect_type": class_name,
        "severity": extract_severity(class_name),
        "confidence": conf,
        "stage": 2
    }

In [None]:
# Run the cascade on a handful of images from the detailed dataset's test split
test_images_dir = DATASETS["detailed"] / "test" / "images"
test_images = [
    image_file
    for pattern in ("*.jpg", "*.png")
    for image_file in test_images_dir.glob(pattern)
]

print(f"Found {len(test_images)} test images")
print("-" * 60)

# Up to 10 randomly chosen images
samples = random.sample(test_images, min(10, len(test_images)))

for img_path in samples:
    result = cascade_inference(str(img_path), binary_model, detailed_model)

    if result["status"] != "non-defective":
        print(f"[!!] {img_path.name}: {result['defect_type']} | Severity: {result['severity']} | Conf: {result['confidence']:.2f}")
        continue

    print(f"[OK] {img_path.name}: Non-defective (conf: {result['confidence']:.2f})")

In [None]:
# Visualize cascade results
def visualize_cascade_result(image_path, binary_model, detailed_model):
    """Show an image with the detailed model's annotations and the cascade verdict.

    The title carries the cascade status, the defect type and severity when
    the status is defective, and the reported confidence. The overlay always
    comes from a separate run of the detailed model, whatever the status.

    Args:
        image_path: Image to analyse; converted with ``str()`` before use.
        binary_model: Stage-1 model forwarded to ``cascade_inference``.
        detailed_model: Stage-2 model; also used to draw the annotations.
    """
    verdict = cascade_inference(str(image_path), binary_model, detailed_model)

    # Annotated frame from the detailed model, converted from BGR to RGB
    # channel order before it is handed to matplotlib.
    prediction = detailed_model(str(image_path), verbose=False)[0]
    frame_rgb = cv2.cvtColor(prediction.plot(), cv2.COLOR_BGR2RGB)

    # Title segments, joined with " | "
    segments = [f"Status: {verdict['status'].upper()}"]
    if verdict['status'] == 'defective':
        segments.append(f"Type: {verdict['defect_type']}")
        segments.append(f"Severity: {verdict['severity']}")
    segments.append(f"Conf: {verdict['confidence']:.2f}")

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.imshow(frame_rgb)
    ax.set_title(" | ".join(segments), fontsize=12)
    ax.axis('off')
    plt.show()

# Visualize a few samples
for img_path in random.sample(test_images, min(3, len(test_images))):
    visualize_cascade_result(img_path, binary_model, detailed_model)

## 6. Summary & Next Steps

In [None]:
# Closing recap: one entry per printed line, emitted in a single call
summary_lines = (
    "=" * 60,
    "TRAINING COMPLETE - SUMMARY",
    "=" * 60,

    "\nTrained Models:",
    "  1. Binary Model: models/binary_model.pt",
    "     Classes: defective, non-defective",
    "     Purpose: Fast filter",

    "\n  2. Detailed Model: models/detailed_model.pt",
    "     Classes: 9 defect types with severity",
    "     Purpose: Classification",

    "\nCascade Pipeline:",
    "  Image -> Binary Model -> (if defective) -> Detailed Model -> Result",

    "\nNext Steps:",
    "  1. Start FastAPI backend: uvicorn backend.main:app --reload",
    "  2. Start Next.js frontend: cd frontend && npm run dev",
    "  3. Expose with ngrok: ngrok http 3000",
    "  4. Test on mobile!",
)
print(*summary_lines, sep="\n")