### CELL 1: Install Required Packages

In [None]:
# Install ultralytics (YOLOv8)
!pip install ultralytics opencv-python matplotlib pillow pandas

In [None]:
# Import libraries
import os
import random
import shutil
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import pandas as pd
import torch
import yaml
from IPython.display import Image, display
from ultralytics import YOLO

# Reaching this line means every import above resolved without raising.
print("✓ All packages installed successfully!")

### CELL 2: Setup Configuration


In [None]:
# Root folder of the dataset; the split folders and data.yaml are looked up under it.
DATASET_PATH = "dataset"
# Dataset description file handed to model.train(data=...) and checked by verify_dataset().
DATA_YAML = os.path.join(DATASET_PATH, "data.yaml")

In [None]:
# Training configuration
MODEL_NAME = "yolov8m.pt"  # checkpoint name passed to YOLO()
EPOCHS = 100
BATCH_SIZE = 16
IMAGE_SIZE = 640
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook does not abort on machines without a GPU.
DEVICE = 0 if torch.cuda.is_available() else "cpu"

In [None]:
# Output directory: the sample visualization is saved here, and it is the same
# path as the project/name pair the training cell passes to model.train().
OUTPUT_DIR = "runs/detect/khmer_text_detection"

In [None]:
# Make sure the output folder exists before any cell writes into it.
# NOTE(review): because the folder exists before training starts, Ultralytics
# may write the run to a suffixed folder unless training passes exist_ok=True
# -- confirm against the installed version.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Echo the active settings, one per line.
print("Configuration:")
for setting_label, setting_value in (
    ("Dataset", DATASET_PATH),
    ("Model", MODEL_NAME),
    ("Epochs", EPOCHS),
    ("Batch Size", BATCH_SIZE),
    ("Image Size", IMAGE_SIZE),
    ("Device", DEVICE),
):
    print(f"  {setting_label}: {setting_value}")

### CELL 3: Verify Dataset Structure

In [None]:
def verify_dataset(dataset_path, data_yaml):
    """Check the dataset config file and the expected split folders.

    Prints a report with the ``path`` and ``names`` entries of ``data_yaml``
    and, for each of the ``train``/``val``/``test`` splits under
    ``dataset_path``, the number of image and label files found.

    Args:
        dataset_path: Dataset root that should contain ``<split>/images`` and
            ``<split>/labels`` folders.
        data_yaml: Path to the dataset YAML file.

    Returns:
        bool: False when ``data_yaml`` is missing, is empty / not a mapping,
        or when the ``train`` or ``val`` folders are missing. True otherwise.
        A missing ``test`` split or an image/label count mismatch is only
        reported, not treated as a failure.
    """
    image_exts = ('.png', '.jpg', '.jpeg')
    required_splits = ('train', 'val')  # training cannot run without these

    print("=" * 70)
    print("DATASET VERIFICATION")
    print("=" * 70)

    # Check if data.yaml exists
    if not os.path.exists(data_yaml):
        print(f"Error: {data_yaml} not found!")
        return False

    print(f"Found data.yaml")

    # Load data.yaml; read as UTF-8 explicitly so non-ASCII class names do not
    # depend on the platform's default encoding.
    with open(data_yaml, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)

    # An empty file parses to None and a bare list/scalar has no .get().
    if not isinstance(data, dict):
        print(f"Error: {data_yaml} is empty or is not a YAML mapping!")
        return False

    print(f"\nDataset configuration:")
    print(f"  Path: {data.get('path', 'Not specified')}")
    print(f"  Classes: {data.get('names', {})}")

    # Check splits
    is_valid = True
    for split in ('train', 'val', 'test'):
        img_dir = os.path.join(dataset_path, split, 'images')
        label_dir = os.path.join(dataset_path, split, 'labels')

        if os.path.isdir(img_dir) and os.path.isdir(label_dir):
            # Case-insensitive match so files such as IMG_001.JPG are counted.
            img_count = sum(1 for f in os.listdir(img_dir) if f.lower().endswith(image_exts))
            label_count = sum(1 for f in os.listdir(label_dir) if f.lower().endswith('.txt'))
            print(f"\n  {split.upper()}:")
            print(f"    Images: {img_count}")
            print(f"    Labels: {label_count}")

            if img_count != label_count:
                print(f"Warning: Image and label counts don't match!")
        else:
            print(f"\n  {split.upper()}: Directory not found")
            if split in required_splits:
                is_valid = False

    print("\n" + "=" * 70)
    return is_valid


In [None]:
# Run verification; the call returns False when the dataset config file cannot
# be found, and the returned value is shown as the cell output.
verify_dataset(DATASET_PATH, DATA_YAML)

### CELL 4: Visualize Sample Images

In [None]:
def visualize_samples(dataset_path, num_samples=4):
    """Show random training images with their YOLO label boxes drawn on top.

    Picks up to ``num_samples`` images from ``<dataset_path>/train/images``,
    draws at most 50 boxes per image from the matching ``.txt`` file in
    ``<dataset_path>/train/labels``, and lays the images out on a grid sized
    to the number of samples (two columns wide). The figure is saved as
    ``sample_images.png`` in the module-level ``OUTPUT_DIR`` and then shown.

    Args:
        dataset_path: Dataset root containing ``train/images`` and
            ``train/labels``.
        num_samples: Maximum number of images to display.

    Returns:
        None. Prints a message and returns early when no images are found.
    """
    train_img_dir = os.path.join(dataset_path, 'train', 'images')
    train_label_dir = os.path.join(dataset_path, 'train', 'labels')

    # Get random images (extension match is case-insensitive)
    images = [f for f in os.listdir(train_img_dir)
              if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    samples = random.sample(images, max(0, min(num_samples, len(images))))

    if not samples:
        print(f"No images found in {train_img_dir}")
        return

    # Size the grid to the sample count instead of assuming exactly 2x2;
    # four samples still give the original 2x2 / 15x15 inch layout.
    n_cols = min(2, len(samples))
    n_rows = (len(samples) + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(7.5 * n_cols, 7.5 * n_rows),
                             squeeze=False)
    axes = axes.flatten()

    for ax, img_name in zip(axes, samples):
        # Read image; cv2.imread returns None (no exception) when it fails.
        img_path = os.path.join(train_img_dir, img_name)
        img = cv2.imread(img_path)
        if img is None:
            ax.set_title(f'{img_name}\n(could not be read)')
            ax.axis('off')
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]

        # Read labels; an image without a label file is shown with zero boxes.
        label_path = os.path.join(train_label_dir, Path(img_name).stem + '.txt')
        lines = []
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                lines = [ln for ln in f if ln.strip()]

        # Draw bounding boxes
        for line in lines[:50]:  # Limit to 50 boxes for visualization
            parts = line.split()
            if len(parts) < 5:
                continue
            try:
                x_center, y_center, width, height = map(float, parts[1:5])
            except ValueError:
                continue  # skip malformed label rows rather than abort the plot

            # Convert YOLO format (normalized center/size) to pixel corners
            x1 = int((x_center - width / 2) * w)
            y1 = int((y_center - height / 2) * h)
            x2 = int((x_center + width / 2) * w)
            y2 = int((y_center + height / 2) * h)

            # Draw rectangle; the image is RGB here, so (255, 0, 0) is red
            cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)

        ax.imshow(img)
        ax.set_title(f'{img_name}\n{len(lines)} words')
        ax.axis('off')

    # Hide grid cells that received no image.
    for ax in axes[len(samples):]:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'sample_images.png'), dpi=150, bbox_inches='tight')
    plt.show()
    print(f"✓ Saved visualization to {OUTPUT_DIR}/sample_images.png")

# Preview up to four random training images with their label boxes.
visualize_samples(DATASET_PATH, num_samples=4)


### CELL 5: Initialize YOLOv8 Model

In [None]:
# Build the detector from the checkpoint named in the configuration cell
model = YOLO(MODEL_NAME)

# Confirmation line, then a blank line, a heading and the printed model.model
print("✓ Loaded {} model".format(MODEL_NAME))
print("\nModel Summary:", model.model, sep="\n")

### CELL 6: Train the Model

In [None]:
print("=" * 70)
print("STARTING TRAINING")
print("=" * 70)

# Split OUTPUT_DIR into the project/name pair expected by model.train() so the
# run folder stays in sync with the configuration cell.
run_project, run_name = os.path.split(os.path.normpath(OUTPUT_DIR))

# Train the model
results = model.train(
    data=DATA_YAML,
    epochs=EPOCHS,
    batch=BATCH_SIZE,
    imgsz=IMAGE_SIZE,
    device=DEVICE,
    project=run_project,
    name=run_name,
    # OUTPUT_DIR already exists (created in the configuration cell). Without
    # exist_ok=True Ultralytics writes the run to a suffixed folder
    # (e.g. "<name>2"), which the later result cells do not read.
    # Side effect: re-running this cell overwrites the previous run's files.
    exist_ok=True,
    patience=20,  # Early stopping patience
    save=True,
    save_period=10,  # Save checkpoint every 10 epochs
    cache=False,  # Set to True to cache images for faster training (requires more RAM)
    pretrained=True,
    optimizer='AdamW',
    verbose=True,
    seed=42,
    deterministic=True,
    single_cls=True,  # Single class (word)
    rect=False,
    cos_lr=True,  # Cosine learning rate scheduler
    close_mosaic=10,  # Disable mosaic augmentation for last N epochs
    resume=False,  # Do not resume from the last checkpoint
    amp=True,  # Automatic Mixed Precision
    fraction=1.0,  # Use 100% of dataset
    profile=False,
    freeze=None,  # No frozen layers
    # Data augmentation
    hsv_h=0.015,
    hsv_s=0.7,
    hsv_v=0.4,
    degrees=0.0,
    translate=0.1,
    scale=0.5,
    shear=0.0,
    perspective=0.0,
    flipud=0.0,  # flips are disabled: mirrored text is not a valid sample
    fliplr=0.0,
    mosaic=1.0,
    mixup=0.0,
    copy_paste=0.0
)

print("\n" + "=" * 70)
print("TRAINING COMPLETE!")
print("=" * 70)


### CELL 7: Evaluate Model Performance

In [None]:
_rule = "=" * 70
print(_rule, "MODEL EVALUATION", _rule, sep="\n")

# Run validation with the model's current settings and keep the metrics object
val_results = model.val()

# Report the four box metrics with four decimals each
box_metrics = val_results.box
print("\nValidation Results:")
for metric_label, metric_value in (
    ("mAP50", box_metrics.map50),
    ("mAP50-95", box_metrics.map),
    ("Precision", box_metrics.mp),
    ("Recall", box_metrics.mr),
):
    print(f"  {metric_label}: {metric_value:.4f}")


### CELL 8: Visualize Training Results


In [None]:
def plot_training_results(results_dir):
    """Draw a 2x2 overview of the metrics logged in ``results.csv``.

    Reads ``<results_dir>/results.csv``, strips whitespace from the column
    names and plots training losses, mAP, precision/recall and learning rate
    against the ``epoch`` column. The last three panels are drawn only when
    their first column is present in the CSV. The figure is saved as
    ``training_metrics.png`` inside ``results_dir`` and then shown.

    Args:
        results_dir: Folder that contains ``results.csv``.

    Returns:
        None. Prints a message and returns early when the CSV is missing.
    """
    csv_path = os.path.join(results_dir, 'results.csv')
    if not os.path.exists(csv_path):
        print(f"Results file not found: {csv_path}")
        return

    # Column names are stripped because the lookups below use exact names
    metrics = pd.read_csv(csv_path)
    metrics.columns = metrics.columns.str.strip()

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # One entry per panel:
    #   (grid position, guard column or None, series, y-axis label, title)
    # Each series is (column, legend label, marker, color or None).
    panels = (
        ((0, 0), None,
         (('train/box_loss', 'Train Box Loss', 'o', None),
          ('train/cls_loss', 'Train Class Loss', 's', None),
          ('train/dfl_loss', 'Train DFL Loss', '^', None)),
         'Loss', 'Training Losses'),
        ((0, 1), 'metrics/mAP50(B)',
         (('metrics/mAP50(B)', 'mAP@0.5', 'o', 'green'),
          ('metrics/mAP50-95(B)', 'mAP@0.5:0.95', 's', 'blue')),
         'mAP', 'Mean Average Precision'),
        ((1, 0), 'metrics/precision(B)',
         (('metrics/precision(B)', 'Precision', 'o', 'purple'),
          ('metrics/recall(B)', 'Recall', 's', 'orange')),
         'Score', 'Precision and Recall'),
        ((1, 1), 'lr/pg0',
         (('lr/pg0', 'Learning Rate', 'o', 'red'),),
         'Learning Rate', 'Learning Rate Schedule'),
    )

    for position, guard_column, series, y_label, panel_title in panels:
        # Guarded panels are left untouched when their key column is absent
        if guard_column is not None and guard_column not in metrics.columns:
            continue

        ax = axes[position]
        for column, legend_label, marker, color in series:
            line_style = {'label': legend_label, 'marker': marker}
            if color is not None:
                # Uncolored series keep matplotlib's default color cycle
                line_style['color'] = color
            ax.plot(metrics['epoch'], metrics[column], **line_style)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(panel_title)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, 'training_metrics.png'), dpi=150, bbox_inches='tight')
    plt.show()

    print(f"✓ Saved training metrics to {results_dir}/training_metrics.png")

# Plot results
# Prefer the folder the trainer actually wrote to: Ultralytics can choose a
# run folder that differs from the configured one (for example by appending a
# numeric suffix when the folder already exists). Fall back to the configured
# path when no training has been run in this kernel session.
_trainer = getattr(globals().get('model'), 'trainer', None)
_trained_dir = getattr(_trainer, 'save_dir', None)
results_dir = str(_trained_dir) if _trained_dir else 'runs/detect/khmer_text_detection'
plot_training_results(results_dir)

### CELL 9: Display Confusion Matrix and Other Plots


In [None]:
def _find_plot(directory, candidates):
    """Return the path of the first file in ``candidates`` that exists in ``directory``, else None."""
    for filename in candidates:
        candidate_path = os.path.join(directory, filename)
        if os.path.exists(candidate_path):
            return candidate_path
    return None


# Locate the evaluation plots in the results folder.
# NOTE(review): some Ultralytics releases appear to write the curve files with
# a "Box" prefix (BoxF1_curve.png / BoxPR_curve.png); both spellings are tried
# so the cell works with either -- confirm against the installed version.
confusion_matrix_path = _find_plot(results_dir, ('confusion_matrix.png',))
f1_curve_path = _find_plot(results_dir, ('F1_curve.png', 'BoxF1_curve.png'))
pr_curve_path = _find_plot(results_dir, ('PR_curve.png', 'BoxPR_curve.png'))

# Show every plot that was found and say so explicitly when one is missing,
# instead of silently skipping it.
for heading, plot_path, missing_message in (
    ("\nConfusion Matrix:", confusion_matrix_path, "Confusion matrix not found"),
    ("\nF1 Score Curve:", f1_curve_path, "F1 curve not found"),
    ("\nPrecision-Recall Curve:", pr_curve_path, "PR curve not found"),
):
    if plot_path is None:
        print(missing_message)
    else:
        print(heading)
        display(Image(filename=plot_path, width=600))


### CELL 10: Test on Validation Images

In [None]:
def test_on_samples(model, dataset_path, num_samples=6, conf_threshold=0.25):
    """Run the model on random validation images and show the predictions.

    Picks up to ``num_samples`` images from ``<dataset_path>/val/images``,
    runs ``model.predict`` on each and lays the annotated images out on a grid
    sized to the number of samples (three columns wide). The figure is saved
    as ``validation_predictions.png`` in the module-level ``results_dir`` and
    then shown.

    Args:
        model: Object exposing ``predict(path, conf=..., verbose=...)`` whose
            first result provides ``plot()`` and ``boxes``.
        dataset_path: Dataset root containing ``val/images``.
        num_samples: Maximum number of images to display.
        conf_threshold: Confidence threshold forwarded to ``model.predict``.

    Returns:
        None. Prints a message and returns early when no images are found.
    """
    val_img_dir = os.path.join(dataset_path, 'val', 'images')

    # Get random images (extension match is case-insensitive)
    images = [f for f in os.listdir(val_img_dir)
              if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    samples = random.sample(images, max(0, min(num_samples, len(images))))

    if not samples:
        print(f"No images found in {val_img_dir}")
        return

    # Size the grid to the sample count instead of assuming exactly 2x3;
    # six samples still give the original 2x3 / 18x12 inch layout.
    n_cols = min(3, len(samples))
    n_rows = (len(samples) + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(6 * n_cols, 6 * n_rows),
                             squeeze=False)
    axes = axes.flatten()

    for ax, img_name in zip(axes, samples):
        img_path = os.path.join(val_img_dir, img_name)

        # Run prediction
        results = model.predict(img_path, conf=conf_threshold, verbose=False)

        # The annotated image is converted from BGR to RGB for matplotlib
        result_img = cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)
        num_detections = len(results[0].boxes)

        ax.imshow(result_img)
        ax.set_title(f'{img_name}\nDetections: {num_detections}')
        ax.axis('off')

    # Hide grid cells that received no image.
    for ax in axes[len(samples):]:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, 'validation_predictions.png'), dpi=150, bbox_inches='tight')
    plt.show()

    print(f"✓ Saved predictions to {results_dir}/validation_predictions.png")

import random
# Show predictions for up to six random validation images at a 0.25 confidence threshold.
test_on_samples(model, DATASET_PATH, num_samples=6, conf_threshold=0.25)


### CELL 11: Export Model

In [None]:
_rule = "=" * 70
print(_rule)
print("EXPORTING MODEL")
print(_rule)

# Checkpoints are looked up in the run's weights folder
weights_dir = os.path.join(results_dir, 'weights')
best_model_path = os.path.join(weights_dir, 'best.pt')

if not os.path.exists(best_model_path):
    print(f"Best model not found at {best_model_path}")
else:
    # Reload the best checkpoint as its own model object
    best_model = YOLO(best_model_path)

    # ONNX export (for deployment)
    onnx_path = best_model.export(format='onnx')
    print(f"✓ Exported to ONNX: {onnx_path}")

    # TorchScript export (for PyTorch deployment)
    torchscript_path = best_model.export(format='torchscript')
    print(f"✓ Exported to TorchScript: {torchscript_path}")

    print("\nModel weights saved at:")
    print(f"  Best: {best_model_path}")
    print(f"  Last: {os.path.join(weights_dir, 'last.pt')}")

### CELL 12: Inference on Custom Image


In [None]:
def predict_on_image(model_path, image_path, conf_threshold=0.25, save_path=None):
    """Load a model, run it on one image, then show and list the detections.

    Args:
        model_path: Weights file passed to ``YOLO()``.
        image_path: Image passed to ``predict``.
        conf_threshold: Confidence threshold forwarded to ``predict``.
        save_path: When truthy, the figure is also written to this path.

    Returns:
        None. The annotated image is displayed and one line per detected box
        (class, confidence, corner coordinates) is printed.
    """
    detector = YOLO(model_path)

    # Only the first result is used; a single image path is passed in
    predictions = detector.predict(image_path, conf=conf_threshold, verbose=True)
    first_result = predictions[0]

    # Convert the annotated image from BGR to RGB before handing it to matplotlib
    annotated = cv2.cvtColor(first_result.plot(), cv2.COLOR_BGR2RGB)

    plt.figure(figsize=(12, 8))
    plt.imshow(annotated)
    plt.axis('off')
    plt.title(f'Detections: {len(first_result.boxes)} words')

    # Save before show() so the figure is written while it is still current
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"✓ Saved prediction to {save_path}")

    plt.show()

    # One summary line per box, numbered from 1
    print("\nDetection Details:")
    for box_number, box in enumerate(first_result.boxes, start=1):
        confidence = box.conf[0].item()
        class_index = int(box.cls[0].item())
        corners = box.xyxy[0].tolist()
        print(f"  Box {box_number}: Class={class_index}, Confidence={confidence:.3f}, BBox={corners}")

# Example usage (replace with your image path)
# predict_on_image(
#     model_path=best_model_path,
#     image_path='path/to/your/test_image.png',
#     conf_threshold=0.25,
#     save_path=os.path.join(results_dir, 'custom_prediction.png')
# )

_rule = "=" * 70
print("\n" + _rule)
print("✓ Training notebook complete!")
print(_rule)

# Numbered follow-up hints, pointing at the results folder and the best checkpoint
print("\nNext steps:")
for step_number, step_text in enumerate(
    (
        f"Check results in: {results_dir}",
        f"Use best model: {best_model_path}",
        "Run inference on new images using the predict_on_image() function",
    ),
    start=1,
):
    print(f"  {step_number}. {step_text}")