# Kraken OCR Training on Google Colab

This notebook trains or fine-tunes Kraken OCR text-recognition models for handwritten Arabic and Persian text on Google Colab.

## Setup Instructions

1. **Enable GPU**: Go to `Runtime` > `Change runtime type` > Select `T4 GPU` (or better)
2. **Upload your data** to Google Drive before running
3. **Run cells in order** from top to bottom

---

## 1. Check GPU and Setup Environment

In [None]:
# Check GPU availability
!nvidia-smi

import torch
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

## 2. Mount Google Drive

Your training data and models will be stored in Google Drive so they persist between sessions.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Create project directory in Drive
import os
PROJECT_DIR = '/content/drive/MyDrive/kraken_ocr_training'
os.makedirs(PROJECT_DIR, exist_ok=True)
os.makedirs(f'{PROJECT_DIR}/models', exist_ok=True)
os.makedirs(f'{PROJECT_DIR}/training_data', exist_ok=True)
os.makedirs(f'{PROJECT_DIR}/checkpoints', exist_ok=True)

print(f"Project directory: {PROJECT_DIR}")
print("\nDirectory structure created:")
print("  - models/          (store base and fine-tuned models)")
print("  - training_data/   (upload your .png + .gt.txt pairs here)")
print("  - checkpoints/     (training checkpoints saved here)")

## 3. Install Kraken OCR

In [None]:
# Install Kraken (pip also installs the PyTorch version Kraken requires; the Linux PyTorch wheels include CUDA support)
!pip install kraken --quiet

# Verify installation
!ketos --version

## 4. Download Base Model (Optional)

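If you want to fine-tune from a pre-trained model, download one here. Skip this step if you already have a model in Drive. Section 6 expects the base model at `models/all_arabic_scripts.mlmodel`; change `BASE_MODEL` there if your file has a different name.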

In [None]:
# List available models
!kraken list

In [None]:
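# Download a base model (uncomment and modify as needed)
# For Arabic text: pick a matching model from the `kraken list` output above.
# `kraken get` takes the model's identifier (DOI) as shown in that list:
# !kraken get <DOI from kraken list>
# then copy the downloaded .mlmodel file into {PROJECT_DIR}/models/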

# Or copy your own model from local upload
# If you uploaded a model to Colab, copy it to Drive:
# !cp /content/your_model.mlmodel {PROJECT_DIR}/models/

## 5. Upload Training Data

**Option A**: Upload directly to Google Drive (recommended)
- Upload your training data to `kraken_ocr_training/training_data/` in Google Drive
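- Each training sample is one text-line image plus a UTF-8 text file with its transcription, both sharing the same base name: `line_001.png` + `line_001.gt.txt`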

**Option B**: Upload a ZIP file

In [None]:
# Option B: Upload and extract a ZIP file
from google.colab import files

# Uncomment to upload a ZIP file:
# uploaded = files.upload()
# !unzip -o *.zip -d {PROJECT_DIR}/training_data/

# Check what's in training_data
!ls -la {PROJECT_DIR}/training_data/ | head -20

# Count training files
import glob
png_files = glob.glob(f'{PROJECT_DIR}/training_data/*.png')
gt_files = glob.glob(f'{PROJECT_DIR}/training_data/*.gt.txt')
print(f"\nFound {len(png_files)} images and {len(gt_files)} ground truth files")
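If both counts are zero even though you uploaded data, check for a subfolder: a ZIP that contains a top-level folder unzips into `training_data/<folder>/`, and the `*.png` patterns used here and in section 8 only match files directly inside `training_data/`. The helper below is a minimal sketch that walks the directory tree and reports how many `.png` files each folder holds; `images_by_folder` is a hypothetical name, not part of Kraken.

```python
import os


def images_by_folder(root, exts=(".png",)):
    """Count image files per folder under `root`, searching recursively.

    Matching is case-sensitive, like the glob patterns in this notebook.
    Only the entry "." (root itself) is seen by the non-recursive `*.png` pattern.
    """
    counts = {}
    for dirpath, _dirnames, filenames in os.walk(root):
        n = sum(1 for name in filenames if name.endswith(exts))
        if n:
            counts[os.path.relpath(dirpath, root)] = n
    return counts


# Usage in this notebook:
# for folder, n in images_by_folder(f'{PROJECT_DIR}/training_data').items():
#     print(f"{folder}: {n} images")
```

If images show up under a subfolder, move them up into `training_data/` or point `TRAINING_DATA` in section 6 at that subfolder.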

## 6. Training Configuration

Adjust these parameters based on your needs:

In [None]:
#@title Training Configuration { display-mode: "form" }

#@markdown ### Training Mode
TRAINING_MODE = "finetune"  #@param ["finetune", "scratch", "continue"]

#@markdown ### Paths
BASE_MODEL = "/content/drive/MyDrive/kraken_ocr_training/models/all_arabic_scripts.mlmodel"  #@param {type:"string"}
TRAINING_DATA = "/content/drive/MyDrive/kraken_ocr_training/training_data"  #@param {type:"string"}
OUTPUT_MODEL = "/content/drive/MyDrive/kraken_ocr_training/checkpoints/model"  #@param {type:"string"}

#@markdown ### Hyperparameters
BATCH_SIZE = 8  #@param {type:"slider", min:1, max:32, step:1}
EPOCHS = 50  #@param {type:"slider", min:10, max:200, step:10}
LEARNING_RATE = 0.0001  #@param {type:"number"}
EARLY_STOPPING = 10  #@param {type:"slider", min:3, max:20, step:1}

#@markdown ### Options
USE_AUGMENTATION = True  #@param {type:"boolean"}
LR_SCHEDULE = "reduceonplateau"  #@param ["reduceonplateau", "cosine", "exponential", "1cycle"]

print("Configuration set:")
print(f"  Mode: {TRAINING_MODE}")
print(f"  Base model: {BASE_MODEL}")
print(f"  Training data: {TRAINING_DATA}")
print(f"  Output: {OUTPUT_MODEL}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Epochs: {EPOCHS}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Early stopping lag: {EARLY_STOPPING}")
print(f"  Augmentation: {USE_AUGMENTATION}")
print(f"  LR schedule: {LR_SCHEDULE}")
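When `BASE_MODEL` points at a missing file, the training cell in section 8 falls back to training from scratch and only prints "Training from scratch", which is easy to miss at the start of a long run. The sketch below checks the settings first. `check_config` is a hypothetical helper written for this notebook; it only compares the values against the options offered in the form above and against the file system, and does not query Kraken.

```python
import os

# The schedule values offered by the form above (ketos may accept others).
KNOWN_SCHEDULES = {"reduceonplateau", "cosine", "exponential", "1cycle"}


def check_config(mode, base_model, training_data, epochs, lag, schedule, lr):
    """Return a list of human-readable problems with the training settings."""
    problems = []
    if mode not in {"finetune", "scratch", "continue"}:
        problems.append(f"unknown TRAINING_MODE {mode!r}")
    if mode == "finetune" and not os.path.isfile(base_model):
        problems.append(f"BASE_MODEL not found: {base_model} (would train from scratch)")
    if not os.path.isdir(training_data):
        problems.append(f"TRAINING_DATA is not a directory: {training_data}")
    if lag >= epochs:
        problems.append("EARLY_STOPPING >= EPOCHS, so early stopping cannot end training early")
    if schedule not in KNOWN_SCHEDULES:
        problems.append(f"unexpected LR_SCHEDULE {schedule!r}")
    if not 0 < lr < 1:
        problems.append(f"LEARNING_RATE {lr} looks wrong")
    return problems


# Usage in this notebook:
# for problem in check_config(TRAINING_MODE, BASE_MODEL, TRAINING_DATA,
#                             EPOCHS, EARLY_STOPPING, LR_SCHEDULE, LEARNING_RATE):
#     print("WARNING:", problem)
```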

## 7. Validate Training Data

In [None]:
import os
import glob
from pathlib import Path

# Find all training images. Only .png files are passed to ketos in section 8,
# so other formats (e.g. .jpg) are not counted here either.
images = sorted(glob.glob(f'{TRAINING_DATA}/*.png'))
print(f"Found {len(images)} training images")

# Validate ground truth files
valid_pairs = []
missing_gt = []

for img in images:
    gt_file = Path(img).with_suffix('.gt.txt')
    if gt_file.exists():
        valid_pairs.append(img)
    else:
        missing_gt.append(img)

print(f"Valid image+text pairs: {len(valid_pairs)}")

if missing_gt:
    print(f"\nWARNING: {len(missing_gt)} images missing ground truth files")
    print("First 5 missing:")
    for m in missing_gt[:5]:
        print(f"  - {os.path.basename(m)}")

if len(valid_pairs) == 0:
    print("\nERROR: No valid training pairs found!")
    print("Make sure each .png has a matching .gt.txt file")
else:
    print(f"\nReady to train with {len(valid_pairs)} samples!")
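Pairing files is only half the check: an empty `.gt.txt` teaches the model nothing, and characters that occur only a handful of times are hard to learn. For Persian data it also pays to look for mixed look-alike code points, such as ARABIC LETTER YEH (U+064A) vs. ARABIC LETTER FARSI YEH (U+06CC), or ARABIC LETTER KAF (U+0643) vs. ARABIC LETTER KEHEH (U+06A9); inconsistent ground truth forces the model to guess between them. A minimal sketch, assuming UTF-8 `.gt.txt` files next to the images; `audit_ground_truth` is a hypothetical helper.

```python
import os
from collections import Counter
from pathlib import Path


def audit_ground_truth(image_paths):
    """Return (empty, chars): images whose transcription is blank, and a
    Counter of every character used across all transcriptions.

    Assumes the `path` layout used here: name.png next to name.gt.txt (UTF-8).
    """
    empty = []
    chars = Counter()
    for img in image_paths:
        gt = Path(img).with_suffix(".gt.txt")
        text = gt.read_text(encoding="utf-8").strip()
        if not text:
            empty.append(os.path.basename(img))
        chars.update(text)  # counts each Unicode code point
    return empty, chars


# Usage in this notebook (rarest characters first):
# empty, chars = audit_ground_truth(valid_pairs)
# print(f"{len(empty)} empty transcriptions")
# for ch, n in sorted(chars.items(), key=lambda kv: kv[1])[:20]:
#     print(f"U+{ord(ch):04X} {ch!r}: {n}")
```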

## 8. Start Training

This cell runs the actual training. It will:
- Save checkpoints to Google Drive (so you don't lose progress if disconnected)
- Display training progress with accuracy metrics

In [None]:
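import glob
import os
import re
import shlex

# Build training command
cmd = [
    'ketos',
    '-d', 'cuda:0',
    'train',
    '-o', OUTPUT_MODEL,
    '-f', 'path',
    '-B', str(BATCH_SIZE),
    '-N', str(EPOCHS),
    '--lag', str(EARLY_STOPPING),
    '-r', str(LEARNING_RATE),
    '--schedule', LR_SCHEDULE,
]

# Add augmentation if enabled
if USE_AUGMENTATION:
    cmd.append('--augment')

# Checkpoint for 'continue' mode: <OUTPUT_MODEL>_best.mlmodel is normally written
# only when a run finishes, so after a disconnect fall back to the newest
# per-epoch file (<OUTPUT_MODEL>_<N>.mlmodel).
def epoch_of(path):
    m = re.search(r'_(\d+)\.mlmodel$', path)
    return int(m.group(1)) if m else -1

best_ckpt = f'{OUTPUT_MODEL}_best.mlmodel'
epoch_ckpts = [p for p in glob.glob(f'{OUTPUT_MODEL}_*.mlmodel') if epoch_of(p) >= 0]
resume_from = best_ckpt if os.path.exists(best_ckpt) else max(epoch_ckpts, key=epoch_of, default=None)

# Add the starting model based on training mode
if TRAINING_MODE == 'finetune' and os.path.exists(BASE_MODEL):
    cmd.extend(['-i', BASE_MODEL, '--resize', 'union'])
    print(f"Fine-tuning from: {BASE_MODEL}")
elif TRAINING_MODE == 'continue' and resume_from:
    cmd.extend(['-i', resume_from, '--resize', 'union'])
    print(f"Continuing from: {resume_from}")
else:
    if TRAINING_MODE != 'scratch':
        print(f"WARNING: no model found for mode '{TRAINING_MODE}'")
    print("Training from scratch")

# Add the training data pattern. shlex.join() quotes it, so ketos (not the shell)
# expands the glob, as in section 12; this also protects paths containing spaces.
cmd.append(f'{TRAINING_DATA}/*.png')
command = shlex.join(cmd)

print(f"\nCommand: {command}")
print("\n" + "="*60)
print("Starting training... (this may take a while)")
print("="*60 + "\n")

# Run training
!{command}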
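Section 12 repeats the same ketos options by hand, so the two cells can drift apart. One way to keep them in sync is a small function that builds the argument list once. The sketch below (`build_train_command` is a hypothetical name) uses only the options already present in section 8. It assumes that ketos expands a quoted `*.png` pattern itself, as the command in section 12 already relies on, and that your Kraken release accepts `union` for `--resize` (the accepted values have changed between releases; older ones used `add`). Check `ketos train --help` if unsure.

```python
import shlex


def build_train_command(output, data_dir, batch_size, epochs, lag, lr,
                        schedule, augment=True, load=None, resize="union",
                        device="cuda:0"):
    """Assemble the `ketos train` argument list used in this notebook."""
    cmd = ["ketos", "-d", device, "train",
           "-o", output, "-f", "path",
           "-B", str(batch_size), "-N", str(epochs),
           "--lag", str(lag), "-r", str(lr),
           "--schedule", schedule]
    if augment:
        cmd.append("--augment")
    if load:  # start from an existing model (fine-tune or continue)
        cmd += ["-i", load, "--resize", resize]
    # Pass the pattern itself; shlex.join() below quotes it for the shell.
    cmd.append(f"{data_dir}/*.png")
    return cmd


# Usage in a Colab cell:
# cmd = build_train_command(OUTPUT_MODEL, TRAINING_DATA, BATCH_SIZE, EPOCHS,
#                           EARLY_STOPPING, LEARNING_RATE, LR_SCHEDULE,
#                           augment=USE_AUGMENTATION, load=BASE_MODEL)
# !{shlex.join(cmd)}
```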

## 9. Check Training Results

In [None]:
import os
import glob

# List all generated models
checkpoint_dir = os.path.dirname(OUTPUT_MODEL)
models = glob.glob(f'{checkpoint_dir}/*.mlmodel')

print("Generated models:")
for m in sorted(models, key=os.path.getmtime):  # oldest first
    size_mb = os.path.getsize(m) / (1024*1024)
    print(f"  - {os.path.basename(m)} ({size_mb:.1f} MB)")

# Find best model
best_model = f'{OUTPUT_MODEL}_best.mlmodel'
if os.path.exists(best_model):
    print(f"\nBest model: {best_model}")
    print("This is the model you should use for inference.")
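Sorting file names as strings puts `model_10.mlmodel` before `model_2.mlmodel`, and `_best.mlmodel` normally appears only once a run has finished. The sketch below sorts the per-epoch files numerically and picks the file to resume from: `<prefix>_best.mlmodel` if it exists, otherwise the newest epoch file. The `<prefix>_<N>.mlmodel` naming is an assumption based on the files ketos writes into `checkpoints/`; `epoch_checkpoints` and `pick_resume_model` are hypothetical helpers.

```python
import glob
import os
import re


def epoch_checkpoints(prefix):
    """Return [(epoch, path), ...] for files named <prefix>_<N>.mlmodel, in epoch order."""
    name = re.compile(re.escape(os.path.basename(prefix)) + r"_(\d+)\.mlmodel$")
    found = []
    for path in glob.glob(glob.escape(prefix) + "_*.mlmodel"):
        m = name.match(os.path.basename(path))
        if m:  # skips <prefix>_best.mlmodel and unrelated files
            found.append((int(m.group(1)), path))
    return sorted(found)


def pick_resume_model(prefix):
    """Prefer <prefix>_best.mlmodel; otherwise the highest-numbered epoch file (or None)."""
    best = f"{prefix}_best.mlmodel"
    if os.path.exists(best):
        return best
    ckpts = epoch_checkpoints(prefix)
    return ckpts[-1][1] if ckpts else None


# Usage in this notebook:
# for epoch, path in epoch_checkpoints(OUTPUT_MODEL):
#     print(epoch, os.path.basename(path))
# print("Resume from:", pick_resume_model(OUTPUT_MODEL))
```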

## 10. Test the Model

In [None]:
# Test OCR on a sample line image.
# Note: the image is picked from the training data, which the model has already
# seen, so the result is optimistic. Hold out a few lines for a fairer check.
import glob
import os
import random
from IPython.display import Image, display

# Get a random test image
test_images = glob.glob(f'{TRAINING_DATA}/*.png')
if test_images:
    test_image = random.choice(test_images)
    print(f"Testing on: {os.path.basename(test_image)}")

    # Show the image
    display(Image(filename=test_image, width=600))

    # Run OCR. The training images are single text lines, so skip layout
    # analysis and recognize the whole image as one line (--no-segmentation).
    best_model = f'{OUTPUT_MODEL}_best.mlmodel'
    if os.path.exists(best_model):
        print("\nOCR Output:")
        !kraken -i "{test_image}" output.txt ocr -m "{best_model}" --no-segmentation
        !cat output.txt

        # Show ground truth for comparison
        gt_file = os.path.splitext(test_image)[0] + '.gt.txt'
        if os.path.exists(gt_file):
            print("\nGround Truth:")
            !cat "{gt_file}"
    else:
        print(f"Model not found: {best_model}")
else:
    print("No test images found")
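Comparing the two texts by eye is a rough check. Character error rate (CER), the edit distance between prediction and ground truth divided by the ground-truth length, turns it into a number you can track across models. The sketch below is a plain-Python implementation, not Kraken's own evaluation code; `levenshtein` and `cer` are hypothetical helpers, and the comparison assumes both texts use the same Unicode normalization and character order.

```python
def levenshtein(a, b):
    """Edit distance (insertions, deletions, substitutions) between two strings."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                # delete ca
                           cur[j - 1] + 1,             # insert cb
                           prev[j - 1] + (ca != cb)))  # substitute (free if equal)
        prev = cur
    return prev[-1]


def cer(prediction, reference):
    """Character error rate: edit distance divided by the reference length."""
    prediction, reference = prediction.strip(), reference.strip()
    if not reference:
        return 0.0 if not prediction else 1.0
    return levenshtein(prediction, reference) / len(reference)


# Usage with the files from the cell above:
# pred = open('output.txt', encoding='utf-8').read()
# ref = open(gt_file, encoding='utf-8').read()
# print(f"CER: {cer(pred, ref):.1%}")
```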

## 11. Download the Model

In [None]:
# Option 1: Model is already in Google Drive - access it directly
print(f"Your model is saved at: {OUTPUT_MODEL}_best.mlmodel")
print("You can access it from Google Drive.")

# Option 2: Download directly to your computer
from google.colab import files

best_model = f'{OUTPUT_MODEL}_best.mlmodel'
if os.path.exists(best_model):
    print("\nDownloading best model...")
    files.download(best_model)
else:
    print("Best model not found. Check if training completed successfully.")

---

## Tips for Long Training Sessions

1. **Keep the browser tab open** - Colab disconnects after inactivity
2. **Use Colab Pro** for longer sessions and better GPUs
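3. **Checkpoints are saved to Drive** - ketos writes a model file to `checkpoints/` after each epoch, so a disconnect loses at most the epoch in progress
4. **To resume training**: reconnect, re-run sections 2, 3 and 6 with `TRAINING_MODE` set to `continue`, then run section 8. This reloads the saved weights; the epoch counter and learning-rate schedule start over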

## Troubleshooting

- **Out of memory**: Reduce `BATCH_SIZE` to 4 or 2
- **Training too slow**: Enable GPU in Runtime settings
- **Disconnected**: the runtime state is lost. Re-run sections 2, 3 and 6 with `TRAINING_MODE` set to `continue`, then section 8 (or use section 12 below)

## 12. Continue Training (After Disconnect)

If you got disconnected, run these cells to continue:

In [None]:
# Quick reconnect and continue
from google.colab import drive
drive.mount('/content/drive')

# Reinstall kraken
!pip install kraken --quiet

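# Set paths (update these to match your setup)
PROJECT_DIR = '/content/drive/MyDrive/kraken_ocr_training'
TRAINING_DATA = f'{PROJECT_DIR}/training_data'
OUTPUT_MODEL = f'{PROJECT_DIR}/checkpoints/model'

# Pick the checkpoint to continue from. <OUTPUT_MODEL>_best.mlmodel is normally
# written only when a run finishes, so after a disconnect fall back to the
# newest per-epoch file (<OUTPUT_MODEL>_<N>.mlmodel).
import glob
import os
import re

def epoch_of(path):
    m = re.search(r'_(\d+)\.mlmodel$', path)
    return int(m.group(1)) if m else -1

CHECKPOINT = f'{OUTPUT_MODEL}_best.mlmodel'
if not os.path.exists(CHECKPOINT):
    epoch_ckpts = [p for p in glob.glob(f'{OUTPUT_MODEL}_*.mlmodel') if epoch_of(p) >= 0]
    if not epoch_ckpts:
        raise FileNotFoundError(f'No checkpoints found for {OUTPUT_MODEL}')
    CHECKPOINT = max(epoch_ckpts, key=epoch_of)
print(f"Continuing from: {CHECKPOINT}")

# Continue training from the checkpoint (keep these values in sync with section 6).
# The weights are reloaded; the epoch counter and learning-rate schedule start over.
!ketos -d cuda:0 train \
    -o {OUTPUT_MODEL} \
    -f path \
    -B 8 \
    -N 50 \
    --lag 10 \
    -r 0.0001 \
    --schedule reduceonplateau \
    --augment \
    -i "{CHECKPOINT}" \
    --resize union \
    '{TRAINING_DATA}/*.png'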