# EasyOCR Model Training Notebook

This notebook provides a comprehensive training pipeline for EasyOCR models. It includes:
- Configuration loading from YAML files
- Model training with customizable parameters
- Progress monitoring and validation
- Model checkpointing

## Getting Started
Make sure you have prepared your dataset and configuration files before running this notebook.
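
Before loading the real file, it can help to confirm that the configuration defines every key the cells below read. The following is a minimal sketch, not part of the trainer: the key list is collected from this notebook only (`train()` may need more), and `missing_config_keys` is a hypothetical helper name.

```python
# Keys that the cells in this notebook read from the loaded configuration.
# Assumption: train() itself may require additional keys not listed here.
REQUIRED_KEYS = [
    "experiment_name", "train_data", "valid_data", "select_data",
    "lang_char", "num_iter", "batch_size", "lr", "imgH", "imgW",
    "valInterval", "saved_model",
    "FeatureExtraction", "SequenceModeling", "Prediction",
]
# Only read when lang_char is not the string 'None' (predefined character set).
CHARSET_KEYS = ["number", "symbol"]


def missing_config_keys(cfg):
    """Return the sorted list of expected keys that are absent from cfg (a dict)."""
    required = list(REQUIRED_KEYS)
    if cfg.get("lang_char") != "None":
        required += CHARSET_KEYS
    return sorted(key for key in required if key not in cfg)


# Example with a deliberately incomplete configuration.
example = {"experiment_name": "thai_auto", "lang_char": "None", "num_iter": 5000}
print(missing_config_keys(example))
```

The dict returned by `yaml.safe_load` can be passed to this helper directly, before it is wrapped in `AttrDict`.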

In [16]:
# Import required libraries and modules
import os
import sys
import time
import torch
import torch.backends.cudnn as cudnn
import yaml
import pandas as pd
import numpy as np
from datetime import datetime

# Import custom modules
from train import train
from utils import AttrDict

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device count: {torch.cuda.device_count()}")
    print(f"Current device: {torch.cuda.current_device()}")
    print(f"Device name: {torch.cuda.get_device_name()}")

PyTorch version: 2.7.1
CUDA available: False


In [17]:
# Configure CUDNN backend for performance optimization
cudnn.benchmark = True  # Enable auto-tuner to find the best algorithm
cudnn.deterministic = False  # Allow non-deterministic algorithms for better performance

print("CUDNN Configuration:")
print(f"  - Benchmark: {cudnn.benchmark}")
print(f"  - Deterministic: {cudnn.deterministic}")
print("  - This configuration optimizes training speed but may affect reproducibility")

CUDNN Configuration:
  - Benchmark: True
  - Deterministic: False
  - This configuration optimizes training speed but may affect reproducibility


In [21]:
def get_config(file_path):
    """
    Load and process training configuration from YAML file
    
    Args:
        file_path (str): Path to the YAML configuration file
        
    Returns:
        AttrDict: Configuration object with all training parameters
    """
    print(f"Loading configuration from: {file_path}")
    
    # Load YAML configuration
    with open(file_path, 'r', encoding="utf8") as stream:
        opt = yaml.safe_load(stream)
    
    # Convert to AttrDict for easier access
    opt = AttrDict(opt)
    
    # Process character set based on configuration
    if opt.lang_char == 'None':
        print("Extracting character set from training data...")
        characters = ''
        
        # Extract characters from all selected datasets
        for data in opt['select_data'].split('-'):
            csv_path = os.path.join(opt['train_data'], data, 'labels.csv')
            print(f"  - Processing dataset: {data}")
            
            # Read labels and extract unique characters
            df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python', 
                           usecols=['filename', 'words'], keep_default_na=False)
            all_char = ''.join(df['words'])
            characters += ''.join(set(all_char))
        
        # Create sorted unique character set
        characters = sorted(set(characters))
        opt.character = ''.join(characters)
        print(f"  - Extracted {len(characters)} unique characters")
    else:
        # Use predefined character set
        opt.character = opt.number + opt.symbol + opt.lang_char
        print(f"Using predefined character set: {len(opt.character)} characters")
    
    # Create output directory
    output_dir = f'./saved_models/{opt.experiment_name}'
    os.makedirs(output_dir, exist_ok=True)
    print(f"Models will be saved to: {output_dir}")
    
    # Print key configuration parameters
    print("\nKey Configuration Parameters:")
    print(f"  - Experiment name: {opt.experiment_name}")
    print(f"  - Number of iterations: {opt.num_iter}")
    print(f"  - Batch size: {opt.batch_size}")
    print(f"  - Learning rate: {opt.lr}")
    print(f"  - Image size: {opt.imgH}x{opt.imgW}")
    print(f"  - Character set length: {len(opt.character)}")
    
    return opt
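
The `sep='^([^,]+),'` argument in `get_config` splits each row of `labels.csv` on the first comma only, so a label that itself contains commas stays intact. With `engine='python'`, pandas appears to apply the pattern much like `re.split`. The sketch below reproduces that behaviour with the standard library and then builds the sorted character set the same way `get_config` does. The helper names `split_label_line` and `build_charset` and the sample rows are illustrative assumptions, not part of the trainer.

```python
import re

# Same pattern that get_config passes to pandas as `sep`.
FIRST_COMMA = re.compile(r'^([^,]+),')


def split_label_line(line):
    """Split a 'filename,words' row on the first comma only."""
    parts = FIRST_COMMA.split(line.rstrip("\n"))
    # A matching row yields ['', filename, words]; anything else is malformed.
    if len(parts) != 3:
        raise ValueError(f"Unexpected label line: {line!r}")
    return parts[1], parts[2]


def build_charset(label_lines):
    """Collect the unique characters of all labels into one sorted string."""
    chars = set()
    for line in label_lines:
        _, words = split_label_line(line)
        chars.update(words)
    return ''.join(sorted(chars))


print(split_label_line("002.jpg,a,b"))            # ('002.jpg', 'a,b')
print(build_charset(["x.jpg,ba", "y.jpg,a,c"]))   # ,abc
```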

In [22]:
# 🇹🇭 Load Thai OCR Configuration
config_file = 'config_files/thai_auto_config.yaml'
print(f"Loading Thai OCR configuration: {config_file}")

try:
    opt = get_config(config_file)
    print("✅ Configuration loaded successfully!")
    
    # Show key parameters only
    print(f"\nKEY SETTINGS:")
    print(f"  Experiment: {opt.experiment_name}")
    print(f"  Iterations: {opt.num_iter:,}")
    print(f"  Batch size: {opt.batch_size}")
    print(f"  Learning rate: {opt.lr}")
    print(f"  Characters: {len(opt.character)}")
    
except Exception as e:
    print(f"❌ Error loading configuration: {e}")
    raise

Loading Thai OCR configuration: config_files/thai_auto_config.yaml
Loading configuration from: config_files/thai_auto_config.yaml
Using predefined character set: 92 characters
Models will be saved to: ./saved_models/thai_auto

Key Configuration Parameters:
  - Experiment name: thai_auto
  - Number of iterations: 5000
  - Batch size: 8
  - Learning rate: 0.001
  - Image size: 64x400
  - Character set length: 92
✅ Configuration loaded successfully!

KEY SETTINGS:
  Experiment: thai_auto
  Iterations: 5,000
  Batch size: 8
  Learning rate: 0.001
  Characters: 92



DATASET VALIDATION
🔍 Checking datasets...
🔧 Checking for common issues...
   No automatic fixes needed
✅ Training data directory found: all_data
✅ Validation uses hierarchical structure from: all_data

⚠️  Found 1 issue(s):
   ❌ Dataset not found: all_data/thai_train

💡 Suggested solutions:
   1. Check that your dataset folders contain 'labels.csv' files
   2. Verify dataset paths in the configuration file
   3. Make sure validation data path is correct
   4. Consider using train_data path for validation if no separate validation set
ℹ️  Validation: using hierarchical structure

⚠️  ISSUES FOUND:
   ❌ Missing: all_data/thai_train/labels.csv


Exception: Dataset issues found

In [15]:
# 📊 Dataset Folder Check
print("📊 Checking required folders:")

# Define expected folders
REQUIRED_FOLDERS = {
    'thai_train': 'all_data/thai_train',
    'thai_val': 'all_data/thai_val'
}

all_good = True

for folder_name, folder_path in REQUIRED_FOLDERS.items():
    labels_file = os.path.join(folder_path, 'labels.csv')
    
    if os.path.exists(labels_file):
        try:
            df = pd.read_csv(labels_file, sep='^([^,]+),', engine='python', 
                           usecols=['filename', 'words'], keep_default_na=False)
            print(f"✅ {folder_name}: OK ({len(df)} samples)")
        except Exception as e:
            # Report the parsing error instead of silently swallowing it
            print(f"❌ {folder_name}: labels.csv could not be parsed ({e})")
            all_good = False
    else:
        print(f"❌ {folder_name}: Missing or no labels.csv")
        all_good = False

print(f"\n🎯 Result: {'✅ Ready to train!' if all_good else '❌ Fix folders first'}")
print("-" * 30)

📊 Checking required folders:
❌ thai_train: Missing or no labels.csv
❌ thai_val: Missing or no labels.csv

🎯 Result: ❌ Fix folders first
------------------------------


In [None]:
# 🔧 Quick Parameter Adjustments (Optional)
print("🔧 Parameter adjustments:")

# CRITICAL PARAMETERS - Modify if needed
ADJUSTMENTS = {
    'batch_size': None,     # Reduce if out of memory (e.g., 4, 8, 16)
    'num_iter': None,       # Reduce for testing (e.g., 1000, 5000)
    'workers': None,        # Set to 0 if multiprocessing issues
}

# Apply adjustments
modified = 0
for param, value in ADJUSTMENTS.items():
    if value is not None:
        old_value = getattr(opt, param)
        setattr(opt, param, value)
        print(f"   🔄 {param}: {old_value} → {value}")
        modified += 1

if modified == 0:
    print("   ✅ Using default parameters")

# Quick memory warning
if opt.batch_size > 16:
    print(f"   ⚠️  Large batch_size ({opt.batch_size}) - reduce if CUDA out of memory")

In [None]:
# 📋 Final Training Summary
print("="*50)
print("📋 READY TO TRAIN")
print("="*50)

print(f"🎯 CRITICAL INFO:")
print(f"   📝 Experiment: {opt.experiment_name}")
print(f"   🏗️  Iterations: {opt.num_iter:,}")
print(f"   📦 Batch size: {opt.batch_size}")
print(f"   ⚡ Learning rate: {opt.lr}")
print(f"   🔄 Validation every: {opt.valInterval} iterations")

print(f"\n📊 Data:")
print(f"   📚 Training: {opt.select_data}")
print(f"   ✅ Validation: {opt.valid_data.split('/')[-1]}")

print(f"\n🖼️  Image: {opt.imgH}x{opt.imgW}")
print(f"🧠 Model: {opt.FeatureExtraction}+{opt.SequenceModeling}+{opt.Prediction}")

print(f"\n💾 Output: ./saved_models/{opt.experiment_name}/")
print("="*50)

In [25]:
# Start model training
print("="*50)
print("STARTING MODEL TRAINING")
print("="*50)

# Training parameters
use_amp = False  # Set to True to enable Automatic Mixed Precision for faster training
show_samples = 3  # Number of prediction samples to show during validation

print(f"🚀 Starting training at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"📊 Training samples to show: {show_samples}")
print(f"⚡ Mixed precision (AMP): {'Enabled' if use_amp else 'Disabled'}")

# Check if resuming from checkpoint
if opt.saved_model:  # An empty string or None means no checkpoint
    print(f"🔄 Resuming training from: {opt.saved_model}")
else:
    print("🆕 Starting training from scratch")

print("\n" + "="*50)
print("TRAINING LOG")
print("="*50)

try:
    # Start training
    train(opt, show_number=show_samples, amp=use_amp)
except KeyboardInterrupt:
    print("\n⚠️  Training interrupted by user (Ctrl+C)")
    print(f"Model checkpoints are saved in: ./saved_models/{opt.experiment_name}/")
except Exception as e:
    print(f"\n❌ Training failed with error: {e}")
    print("Please check the error details above and ensure:")
    print("1. Data paths are correct")
    print("2. Required dependencies are installed")
    print("3. GPU memory is sufficient")
    raise
finally:
    print(f"\n🏁 Training session ended at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

STARTING MODEL TRAINING
🚀 Starting training at: 2025-06-30 02:50:39
📊 Training samples to show: 3
⚡ Mixed precision (AMP): Disabled
🆕 Starting training from scratch

TRAINING LOG
Filtering the images containing characters which are not in opt.character
Filtering the images whose label is longer than opt.batch_max_length
--------------------------------------------------------------------------------
dataset_root: all_data
opt.select_data: ['thai_train']
opt.batch_ratio: ['1']
--------------------------------------------------------------------------------
dataset_root:    all_data	 dataset: thai_train

❌ Training failed with error: datasets should not be an empty iterable
Please check the error details above and ensure:
1. Data paths are correct
2. Required dependencies are installed
3. GPU memory is sufficient

🏁 Training session ended at: 2025-06-30 02:50:39


AssertionError: datasets should not be an empty iterable

In [None]:
# 📊 Quick Monitoring Tools
print("📊 Monitoring tools ready:")

def check_progress(experiment_name):
    """Quick progress check"""
    log_dir = f"./saved_models/{experiment_name}"
    
    if not os.path.exists(log_dir):
        print(f"❌ No logs yet: {log_dir}")
        return
    
    # Check models
    models = [f for f in os.listdir(log_dir) if f.endswith('.pth')]
    if models:
        print(f"💾 Models: {len(models)} saved")
        for model in sorted(models)[-3:]:  # Show last 3
            size_mb = os.path.getsize(os.path.join(log_dir, model)) / (1024*1024)
            print(f"   - {model} ({size_mb:.1f}MB)")
    
    # Check training log
    log_file = os.path.join(log_dir, "log_train.txt")
    if os.path.exists(log_file):
        with open(log_file, 'r', encoding='utf8') as f:
            lines = f.readlines()
        print(f"📄 Training log: {len(lines)} lines")
        if lines:
            print(f"   Last: {lines[-1].strip()}")

def quick_log(experiment_name, lines=5):
    """Show last few log lines"""
    log_file = f"./saved_models/{experiment_name}/log_train.txt"
    if os.path.exists(log_file):
        with open(log_file, 'r', encoding='utf8') as f:
            log_lines = f.readlines()
        print(f"📄 Last {lines} lines:")
        for line in log_lines[-lines:]:
            print(f"   {line.strip()}")
    else:
        print("❌ No training log found")

print("   📊 check_progress('experiment_name') - Check saved models")
print("   📄 quick_log('experiment_name') - Show recent logs")
print("   Example: check_progress(opt.experiment_name)")
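
Both helpers above read the whole log with `readlines()`, which is fine for short runs. For long runs, a variant that keeps only the last few lines in memory may be preferable. This is a minimal sketch using `collections.deque`; `tail_lines` is a hypothetical helper, not something the trainer provides.

```python
from collections import deque


def tail_lines(path, n=5):
    """Return the last n lines of a text file, holding at most n lines in memory."""
    with open(path, 'r', encoding='utf8') as f:
        # deque with maxlen discards older lines as the file is streamed
        return [line.rstrip('\n') for line in deque(f, maxlen=n)]
```

For example, `tail_lines(f"./saved_models/{opt.experiment_name}/log_train.txt", 5)` would return the same lines that `quick_log` prints, assuming the log file exists.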

## 💡 Essential Tips

### 🚨 Critical Issues:
- **Out of memory**: Reduce `batch_size` in the "Quick Parameter Adjustments" cell
- **Slow training**: Set `workers` to 0 if multiprocessing causes problems
- **Stop training**: Press Ctrl+C (models auto-saved every 10k iterations)

### 📁 After Training:
Check `./saved_models/{experiment_name}/`:
- `best_accuracy.pth` - Best model
- `log_train.txt` - Training progress

### 🔧 Quick Fixes:
- Memory error → Reduce batch_size to 4 or 8
- File not found → Check dataset paths in config file
- Training stuck → Set workers=0
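
For the "File not found" case, the expected `labels.csv` paths can be derived from the configuration instead of being typed by hand. The sketch below follows the path logic in `get_config` (one folder per `-`-separated name in `select_data`, under `train_data`). The function names are hypothetical, and the sample values are the ones shown in the training log above.

```python
import os


def expected_label_files(train_data, select_data):
    """Build one labels.csv path per '-'-separated dataset name, as get_config does."""
    return [os.path.join(train_data, name, 'labels.csv')
            for name in select_data.split('-')]


def missing_label_files(train_data, select_data):
    """Return the expected labels.csv paths that do not exist on disk."""
    return [path for path in expected_label_files(train_data, select_data)
            if not os.path.isfile(path)]


print(expected_label_files('all_data', 'thai_train'))
```

Calling `missing_label_files(opt.train_data, opt.select_data)` before training should return an empty list; any path it returns points at a dataset folder that still needs to be created or corrected in the config file.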