# Handwritten Chinese OCR - Colab Training (A100 Optimized)

Train a handwritten Chinese OCR model on Google Colab using the CASIA-HWDB 2.0, 2.1, and 2.2 datasets.

**Optimized for A100 GPU:**
- TF32 precision (up to ~8x the peak FP32 throughput on A100; actual speedup depends on the workload)
- Mixed precision training (AMP)
- Optimized data loading with prefetching
- Batch size selected automatically for the detected GPU

## Quick Start
1. Upload **only data** (`HWDB2.0Train`, `HWDB2.0Test`, `HWDB2.1Train`, `HWDB2.1Test`, `HWDB2.2Train`, `HWDB2.2Test`) to `My Drive/HWDB-data/`
2. Code is automatically cloned from GitHub (always up-to-date)
3. Open this notebook in Colab with **A100 GPU runtime**
4. Run cells in order

## 1. Check GPU

In [None]:
!nvidia-smi

import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f'GPU: {gpu_name}')
    print(f'Memory: {gpu_memory:.1f} GB')
    
    if 'A100' in gpu_name:
        print('\n✓ A100 detected! Optimal settings will be used:')
        print('  - Batch size: 32-64')
        print('  - TF32 precision: enabled')
        print('  - Mixed precision (AMP): enabled')
    elif 'V100' in gpu_name:
        print('\n✓ V100 detected. Recommended batch size: 16-32')
    elif 'T4' in gpu_name:
        print('\n✓ T4 detected. Recommended batch size: 8-16')
    else:
        print(f'\n✓ {gpu_name} detected.')

## 2. Mount Google Drive

In [None]:
# Mount Google Drive inside the Colab VM. `os` is imported here because the
# later cells of this notebook use it without importing it again.
import os
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)
print('✓ Drive mounted')

## 3. Clone Project from GitHub

Code is cloned fresh from GitHub, so you always get the latest updates.

In [None]:
# Remove old clone if exists
!rm -rf handwritten-chinese-ocr-samples

# Clone latest code from GitHub
!git clone https://github.com/AndrewCullacino/handwritten-chinese-ocr-samples.git
%cd handwritten-chinese-ocr-samples

print('✓ Latest code cloned from GitHub')

## 4. Install Dependencies

In [None]:
!pip install -q -r requirements.txt
print('✓ Dependencies installed')

## 5. Link Data from Google Drive

**Important:** Your data should be in `My Drive/HWDB-data/` with this structure:
```
My Drive/HWDB-data/
├── HWDB2.0Train/
│   ├── *.dgrl files
├── HWDB2.0Test/
│   ├── *.dgrl files
├── HWDB2.1Train/
│   ├── *.dgrl files
├── HWDB2.1Test/
│   ├── *.dgrl files
├── HWDB2.2Train/
│   ├── *.dgrl files
└── HWDB2.2Test/
    ├── *.dgrl files
```

In [None]:
# Verify data exists in Drive
import os

DATA_DIR = '/content/drive/MyDrive/HWDB-data'

# Define all dataset versions to use
DATASET_VERSIONS = ['2.0', '2.1', '2.2']


def count_dgrl_files(directory):
    """Count the ``.dgrl`` files located directly inside *directory*.

    Args:
        directory: Path of the folder to inspect (not searched recursively).

    Returns:
        The number of entries whose name ends with ``.dgrl``, or ``None`` when
        *directory* does not exist or is not a directory.
    """
    if not os.path.isdir(directory):
        return None
    return sum(1 for name in os.listdir(directory) if name.endswith('.dgrl'))


if not os.path.exists(DATA_DIR):
    print(f'✗ Error: {DATA_DIR} not found')
    print('Please upload HWDB2.x datasets to My Drive/HWDB-data/')
else:
    print('Data verification:')

    total_train_dgrl = 0
    total_test_dgrl = 0
    all_train_exist = True
    all_test_exist = True
    # Folders that exist but hold no .dgrl files. They would otherwise be
    # reported as "ready" even though preprocessing has nothing to extract.
    empty_dirs = []

    for version in DATASET_VERSIONS:
        train_dir = f'{DATA_DIR}/HWDB{version}Train'
        test_dir = f'{DATA_DIR}/HWDB{version}Test'

        print(f"\n  HWDB{version}:")
        # The label padding keeps the two printed paths aligned.
        for is_train, label, split_dir in (
            (True, 'Train:', train_dir),
            (False, 'Test: ', test_dir),
        ):
            dgrl_count = count_dgrl_files(split_dir)

            if dgrl_count is None:
                print(f"    ✗ {label} {split_dir}")
                if is_train:
                    all_train_exist = False
                else:
                    all_test_exist = False
                continue

            print(f"    ✓ {label} {split_dir}")
            print(f'      → Found {dgrl_count} .dgrl files')
            if dgrl_count == 0:
                empty_dirs.append(split_dir)
            if is_train:
                total_train_dgrl += dgrl_count
            else:
                total_test_dgrl += dgrl_count

    print('\n  Summary:')
    print(f'    Total train .dgrl files: {total_train_dgrl}')
    print(f'    Total test .dgrl files:  {total_test_dgrl}')

    if empty_dirs:
        print('\n⚠ These dataset folders contain no .dgrl files:')
        for empty_dir in empty_dirs:
            print(f'    {empty_dir}')

    if not (all_train_exist and all_test_exist):
        print('\n⚠ Some datasets are missing. Preprocessing will use available data.')
    elif not empty_dirs:
        print('\n✓ All data ready for preprocessing')

## 6. Preprocess Dataset

Extract text line images from DGRL files using `dgrl2png.py`.

**Preprocessed data is saved to Google Drive** (`My Drive/HWDB-data/preprocessed/`) for persistence across sessions. A symlink to it is created at `data/hwdb2.x/` for training.

In [None]:
# Preprocessed data is saved to Google Drive for persistence
# Then symlinked to local path for faster training I/O

DRIVE_DATA_DIR = '/content/drive/MyDrive/HWDB-data/preprocessed'
LOCAL_DATA_DIR = 'data/hwdb2.x'

# Dataset versions to process
DATASET_VERSIONS = ['2.0', '2.1', '2.2']

# Check if preprocessed data exists in Drive
if os.path.exists(f'{DRIVE_DATA_DIR}/train_img_id_gt.txt'):
    print('✓ Found preprocessed data in Google Drive')
    
    # Create symlink to local path for training
    !rm -rf {LOCAL_DATA_DIR}
    !mkdir -p data
    !ln -s {DRIVE_DATA_DIR} {LOCAL_DATA_DIR}
    
    with open(f'{LOCAL_DATA_DIR}/train_img_id_gt.txt', 'r') as f:
        train_samples = len(f.readlines())
    with open(f'{LOCAL_DATA_DIR}/val_img_id_gt.txt', 'r') as f:
        val_samples = len(f.readlines())
    with open(f'{LOCAL_DATA_DIR}/test_img_id_gt.txt', 'r') as f:
        test_samples = len(f.readlines())
    print(f'  Train: {train_samples} | Val: {val_samples} | Test: {test_samples}')
    print(f'  Symlinked: {DRIVE_DATA_DIR} -> {LOCAL_DATA_DIR}')

else:
    print('Preprocessing DGRL files from HWDB 2.0, 2.1, 2.2 using dgrl2png.py...')
    print(f'Output will be saved to: {DRIVE_DATA_DIR}')
    
    # Create output directories in Google Drive (persistent)
    !mkdir -p {DRIVE_DATA_DIR}/train {DRIVE_DATA_DIR}/val {DRIVE_DATA_DIR}/test
    
    # Process each dataset version
    all_train_lines = []
    all_test_lines = []
    
    for idx, version in enumerate(DATASET_VERSIONS):
        train_dir = f'/content/drive/MyDrive/HWDB-data/HWDB{version}Train'
        test_dir = f'/content/drive/MyDrive/HWDB-data/HWDB{version}Test'
        
        # Extract training data
        if os.path.exists(train_dir):
            print(f'\n[{idx*2 + 1}/{len(DATASET_VERSIONS)*2}] Extracting HWDB{version} training data...')
            !python utils/casia-hwdb-data-preparation/dgrl2png.py \
                {train_dir} \
                {DRIVE_DATA_DIR}/train \
                --image_height 128
            
            # Read and collect ground truth
            gt_file = f'{DRIVE_DATA_DIR}/train/dgrl_img_gt.txt'
            if os.path.exists(gt_file):
                with open(gt_file, 'r', encoding='utf-8') as f:
                    lines = [line.strip() for line in f.readlines() if line.strip()]
                    all_train_lines.extend(lines)
                print(f'    → Collected {len(lines)} samples')
        else:
            print(f'\n⚠ Skipping HWDB{version}Train (not found)')
        
        # Extract test data
        if os.path.exists(test_dir):
            print(f'\n[{idx*2 + 2}/{len(DATASET_VERSIONS)*2}] Extracting HWDB{version} test data...')
            !python utils/casia-hwdb-data-preparation/dgrl2png.py \
                {test_dir} \
                {DRIVE_DATA_DIR}/test \
                --image_height 128
            
            # Read and collect ground truth
            gt_file = f'{DRIVE_DATA_DIR}/test/dgrl_img_gt.txt'
            if os.path.exists(gt_file):
                with open(gt_file, 'r', encoding='utf-8') as f:
                    lines = [line.strip() for line in f.readlines() if line.strip()]
                    all_test_lines.extend(lines)
                print(f'    → Collected {len(lines)} samples')
        else:
            print(f'\n⚠ Skipping HWDB{version}Test (not found)')
    
    # Create train/val split from all collected training data
    print('\n[Final] Creating train/val split from all datasets...')
    
    import random
    import shutil
    
    # Shuffle and split (90% train, 10% val)
    random.seed(42)
    random.shuffle(all_train_lines)
    
    val_size = int(len(all_train_lines) * 0.1)
    val_lines = all_train_lines[:val_size]
    train_lines = all_train_lines[val_size:]
    
    # Move val images to val folder
    print(f'  Moving {len(val_lines)} images to validation set...')
    for line in val_lines:
        img_name = line.split(',')[0]
        src = f'{DRIVE_DATA_DIR}/train/{img_name}'
        dst = f'{DRIVE_DATA_DIR}/val/{img_name}'
        if os.path.exists(src):
            shutil.move(src, dst)
    
    # Write metadata files to Drive
    with open(f'{DRIVE_DATA_DIR}/train_img_id_gt.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(train_lines))
    
    with open(f'{DRIVE_DATA_DIR}/val_img_id_gt.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(val_lines))
    
    with open(f'{DRIVE_DATA_DIR}/test_img_id_gt.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(all_test_lines))
    
    # Generate character list from all data
    all_chars = set()
    for line in train_lines + val_lines + all_test_lines:
        if ',' in line:
            text = line.split(',', 1)[1]
            all_chars.update(text)
    
    with open(f'{DRIVE_DATA_DIR}/chars_list.txt', 'w', encoding='utf-8') as f:
        for char in sorted(all_chars):
            f.write(char + '\n')
    
    # Create symlink for training
    !rm -rf {LOCAL_DATA_DIR}
    !mkdir -p data
    !ln -s {DRIVE_DATA_DIR} {LOCAL_DATA_DIR}
    
    print(f'\n✓ Preprocessing complete - saved to Google Drive')
    print(f'  Train: {len(train_lines)} | Val: {len(val_lines)} | Test: {len(all_test_lines)}')
    print(f'  Characters: {len(all_chars)}')
    print(f'  Location: {DRIVE_DATA_DIR}')
    print(f'  Symlinked to: {LOCAL_DATA_DIR}')

## 7. Verify Preprocessed Dataset

In [None]:
DATASET_PATH = 'data/hwdb2.x'


def _count_lines(path):
    """Return the number of lines in the UTF-8 text file at *path*."""
    with open(path, 'r', encoding='utf-8') as handle:
        return len(handle.readlines())


# The train index is used as the indicator that preprocessing has been run.
if not os.path.exists(f'{DATASET_PATH}/train_img_id_gt.txt'):
    print('✗ Dataset not found. Please run preprocessing first.')
else:
    # Sample count per split; a split whose index file is missing is simply
    # left out and later reported as 0.
    splits = {}
    for split in ('train', 'val', 'test'):
        gt_file = f'{DATASET_PATH}/{split}_img_id_gt.txt'
        if os.path.exists(gt_file):
            splits[split] = _count_lines(gt_file)

    # Vocabulary size = number of lines in the character list.
    chars = _count_lines(f'{DATASET_PATH}/chars_list.txt')

    print('Dataset ready (HWDB 2.0 + 2.1 + 2.2):')
    print(f'  Train: {splits.get("train", 0):,} samples')
    print(f'  Val:   {splits.get("val", 0):,} samples')
    print(f'  Test:  {splits.get("test", 0):,} samples')
    print(f'  Character vocab: {chars:,}')
    print(f'  Location: {DATASET_PATH}/')

## 8. Training

Train the model with A100-optimized settings:
- **TF32 precision**: up to ~8x the peak FP32 throughput on A100 (actual speedup depends on the workload)
- **Mixed precision (AMP)**: Further speedup with FP16/BF16
- **GPU-specific batch size**: 16 on A100, 12 on V100, 8 on T4 and other GPUs (kept conservative because the model is large)
- **Gradient clipping**: Prevents exploding gradients in RNN

Checkpoints will be saved locally in Colab.

In [None]:
import torch

DATASET_PATH = 'data/hwdb2.x'

# Auto-detect optimal batch size based on GPU
# Note: max_width=1200 is enforced in main.py to prevent OOM on long text lines
# The model is large (~38M params), so conservative batch sizes are needed
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else ''
if 'A100' in gpu_name:
    BATCH_SIZE = 16        # A100 40GB: 16-24 (conservative for large model)
    NUM_WORKERS = 4
    print(f'✓ A100 detected - using batch_size={BATCH_SIZE}')
elif 'V100' in gpu_name:
    BATCH_SIZE = 12        # V100 16GB: 8-12
    NUM_WORKERS = 4
    print(f'✓ V100 detected - using batch_size={BATCH_SIZE}')
elif 'T4' in gpu_name:
    BATCH_SIZE = 8         # T4 16GB: 4-8
    NUM_WORKERS = 2
    print(f'✓ T4 detected - using batch_size={BATCH_SIZE}')
else:
    BATCH_SIZE = 8
    NUM_WORKERS = 2
    print(f'Using default batch_size={BATCH_SIZE}')

EPOCHS = 50             # Recommended: 30-50 epochs
PRINT_FREQ = 100        # Print every N batches
VAL_FREQ = 5000         # Validate every N batches

print(f'\nTraining configuration:')
print(f'  Batch size: {BATCH_SIZE}')
print(f'  Epochs: {EPOCHS}')
print(f'  Workers: {NUM_WORKERS}')
print(f'  Dataset: {DATASET_PATH} (HWDB 2.0 + 2.1 + 2.2)')
print(f'  Max image width: 1200 (enforced in main.py)')

!python main.py -m hctr \
    -d {DATASET_PATH} \
    -b {BATCH_SIZE} \
    -ep {EPOCHS} \
    -pf {PRINT_FREQ} \
    -vf {VAL_FREQ} \
    -j {NUM_WORKERS}

## 9. Find Best Model

In [None]:
import glob

# Checkpoints are looked up in the current working directory.
model_files = glob.glob('hctr_*.pth.tar')

if not model_files:
    print('No models found')
    BEST_MODEL = None
else:
    print('Saved models:')
    for model_path in sorted(model_files):
        size_mb = os.path.getsize(model_path) / (1024*1024)
        print(f'  {model_path} ({size_mb:.1f} MB)')

    # "Best" is the lexicographically greatest file name containing 'acc';
    # without any such file, fall back to the plain checkpoint name.
    acc_models = [name for name in model_files if 'acc' in name]
    if acc_models:
        BEST_MODEL = max(acc_models)
    else:
        BEST_MODEL = 'hctr_checkpoint.pth.tar'
    print(f'\n✓ Best model: {BEST_MODEL}')

## 10. Evaluation

In [None]:
DATASET_PATH = 'data/hwdb2.x'

if BEST_MODEL and os.path.exists(BEST_MODEL):
    print(f'Evaluating: {BEST_MODEL}\n')
    !python test.py -m hctr \
        -f {BEST_MODEL} \
        -i {DATASET_PATH} \
        -b 16 \
        -bm \
        -dm greedy-search \
        -pf 20
else:
    print('Model not found')

## 11. Save Checkpoints to Drive

**Important:** Save models to Drive to prevent loss when Colab disconnects.

In [None]:
# Imported here as well so this cell also works when the "Find Best Model"
# cell (which is where glob was imported) has not been run in this session.
import glob
import os
import shutil

# Save to Drive
CHECKPOINT_DIR = '/content/drive/MyDrive/HWDB-data/checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Checkpoints are looked up in the current working directory, so glob returns
# bare file names that can be joined directly onto CHECKPOINT_DIR.
model_files = glob.glob('hctr_*.pth.tar')
if model_files:
    for f in model_files:
        dst = os.path.join(CHECKPOINT_DIR, f)
        # copy2 also preserves file metadata such as the modification time.
        shutil.copy2(f, dst)
        print(f'✓ Saved: {dst}')
    print(f'\n✓ All models saved to Drive: {CHECKPOINT_DIR}')
else:
    print('No models to save')

## 12. Resume Training (Optional)

If Colab disconnects, run cells 1-7 to restore environment, then run this cell to resume training.

In [None]:
# Copy checkpoint from Drive
CHECKPOINT_DIR = '/content/drive/MyDrive/HWDB-data/checkpoints'
checkpoint_file = 'hctr_checkpoint.pth.tar'

if os.path.exists(f'{CHECKPOINT_DIR}/{checkpoint_file}'):
    !cp {CHECKPOINT_DIR}/{checkpoint_file} .
    print(f'✓ Restored checkpoint: {checkpoint_file}')
    
    # Detect GPU and set batch size (conservative values for large model)
    import torch
    gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else ''
    BATCH_SIZE = 16 if 'A100' in gpu_name else (12 if 'V100' in gpu_name else 8)
    
    # Resume training with combined dataset
    !python main.py -m hctr \
        -d data/hwdb2.x \
        -b {BATCH_SIZE} \
        -ep 100 \
        -pf 100 \
        -vf 5000 \
        -j 4 \
        -re {checkpoint_file}
else:
    print(f'✗ Checkpoint not found in {CHECKPOINT_DIR}')