# Handwritten Chinese OCR - Colab Training (A100 Optimized)

Train a handwritten Chinese OCR model on Google Colab using the CASIA-HWDB2.0 dataset.

**Optimized for A100 GPU:**
- TF32 precision (up to about 8x the peak FP32 matmul throughput on A100; end-to-end training speedups are smaller)
- Mixed precision training (AMP)
- Optimized data loading with prefetching
- Per-GPU batch sizes (kept conservative to avoid out-of-memory errors on long text lines)

## Quick Start
1. Upload **only the data** (`HWDB2.0Train`, `HWDB2.0Test`) to `My Drive/HWDB-data/`
2. Code is automatically cloned from GitHub (always up-to-date)
3. Open this notebook in Colab with **A100 GPU runtime**
4. Run cells in order

## 1. Check GPU

In [1]:
!nvidia-smi

import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f'GPU: {gpu_name}')
    print(f'Memory: {gpu_memory:.1f} GB')

    # The batch sizes printed here match the values set in section 8
    if 'A100' in gpu_name:
        print('\n✓ A100 detected! Settings used in section 8:')
        print('  - Batch size: 8 (conservative; the model is large)')
        print('  - TF32 precision: enabled')
        print('  - Mixed precision (AMP): enabled')
    elif 'V100' in gpu_name:
        print('\n✓ V100 detected. Section 8 uses batch size 6')
    elif 'T4' in gpu_name:
        print('\n✓ T4 detected. Section 8 uses batch size 8')
    else:
        print(f'\n✓ {gpu_name} detected.')

Mon Jan 19 05:45:08 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   31C    P0             47W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

## 2. Mount Google Drive

In [2]:
from google.colab import drive
import os

drive.mount('/content/drive')
print('✓ Drive mounted')

Mounted at /content/drive
✓ Drive mounted


## 3. Clone Project from GitHub

Code is cloned fresh from GitHub, so you always get the latest updates.

In [4]:
# Remove old clone if exists
!rm -rf handwritten-chinese-ocr-samples

# Clone latest code from GitHub
!git clone https://github.com/AndrewCullacino/handwritten-chinese-ocr-samples.git
%cd handwritten-chinese-ocr-samples

print('✓ Latest code cloned from GitHub')

Cloning into 'handwritten-chinese-ocr-samples'...
remote: Enumerating objects: 231, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (83/83), done.
remote: Total 231 (delta 70), reused 78 (delta 40), pack-reused 108 (from 1)
Receiving objects: 100% (231/231), 746.76 KiB | 26.67 MiB/s, done.
Resolving deltas: 100% (111/111), done.
/content/handwritten-chinese-ocr-samples
✓ Latest code cloned from GitHub


## 4. Install Dependencies

In [None]:
!pip install -q -r requirements.txt
print('✓ Dependencies installed')

## 5. Link Data from Google Drive

**Important:** Your data should be in `My Drive/HWDB-data/` with this structure:
```
My Drive/HWDB-data/
├── HWDB2.0Train/
│   └── *.dgrl files
└── HWDB2.0Test/
    └── *.dgrl files
```

In [None]:
# Verify data exists in Drive
DATA_DIR = '/content/drive/MyDrive/HWDB-data'

if not os.path.exists(DATA_DIR):
    print(f'✗ Error: {DATA_DIR} not found')
    print('Please upload HWDB2.0Train and HWDB2.0Test to My Drive/HWDB-data/')
else:
    # Check for .dgrl files
    train_dir = f'{DATA_DIR}/HWDB2.0Train'
    test_dir = f'{DATA_DIR}/HWDB2.0Test'

    train_exists = os.path.exists(train_dir)
    test_exists = os.path.exists(test_dir)

    print(f'Data verification:')
    print(f"  {'✓' if train_exists else '✗'} {train_dir}")
    print(f"  {'✓' if test_exists else '✗'} {test_dir}")

    if train_exists:
        dgrl_count = len([f for f in os.listdir(train_dir) if f.endswith('.dgrl')])
        print(f'  → Found {dgrl_count} .dgrl files in training data')

    if train_exists and test_exists:
        print('\n✓ Data ready for preprocessing')

## 6. Preprocess Dataset

Extract text line images from DGRL files using `dgrl2png.py`.

Processed data will be saved to Colab local storage (faster I/O than Drive).

In [None]:
# Check if already preprocessed (all three ground-truth files must exist)
gt_files = {split: f'data/hwdb2.0/{split}_img_id_gt.txt' for split in ('train', 'val', 'test')}
if all(os.path.exists(path) for path in gt_files.values()):
    print('✓ Data already preprocessed, skipping...')
    counts = {}
    for split, path in gt_files.items():
        # The ground truth contains Chinese text, so read it explicitly as UTF-8
        with open(path, 'r', encoding='utf-8') as f:
            counts[split] = sum(1 for _ in f)
    print(f"  Train: {counts['train']} | Val: {counts['val']} | Test: {counts['test']}")
else:
    print('Preprocessing DGRL files using dgrl2png.py...')

    # Create output directories
    !mkdir -p data/hwdb2.0/train data/hwdb2.0/val data/hwdb2.0/test

    # Extract training data
    print('\n[1/3] Extracting training data...')
    !python utils/casia-hwdb-data-preparation/dgrl2png.py \
        /content/drive/MyDrive/HWDB-data/HWDB2.0Train \
        data/hwdb2.0/train \
        --image_height 128

    # Extract test data
    print('\n[2/3] Extracting test data...')
    !python utils/casia-hwdb-data-preparation/dgrl2png.py \
        /content/drive/MyDrive/HWDB-data/HWDB2.0Test \
        data/hwdb2.0/test \
        --image_height 128

    # Create train/val split and generate metadata files
    print('\n[3/3] Creating train/val split...')

    import random
    import shutil

    # Read training ground truth
    with open('data/hwdb2.0/train/dgrl_img_gt.txt', 'r', encoding='utf-8') as f:
        train_lines = [line.strip() for line in f.readlines() if line.strip()]

    # Shuffle and split
    random.seed(42)
    random.shuffle(train_lines)

    val_size = int(len(train_lines) * 0.1)
    val_lines = train_lines[:val_size]
    train_lines = train_lines[val_size:]

    # Move val images to val folder
    for line in val_lines:
        img_name = line.split(',')[0]
        src = f'data/hwdb2.0/train/{img_name}'
        dst = f'data/hwdb2.0/val/{img_name}'
        if os.path.exists(src):
            shutil.move(src, dst)

    # Write metadata files
    with open('data/hwdb2.0/train_img_id_gt.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(train_lines))

    with open('data/hwdb2.0/val_img_id_gt.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(val_lines))

    # Copy test ground truth
    with open('data/hwdb2.0/test/dgrl_img_gt.txt', 'r', encoding='utf-8') as f:
        test_lines = f.read()
    with open('data/hwdb2.0/test_img_id_gt.txt', 'w', encoding='utf-8') as f:
        f.write(test_lines)

    # Generate character list from all data
    all_chars = set()
    for line in train_lines + val_lines + test_lines.strip().split('\n'):
        if ',' in line:
            text = line.split(',', 1)[1]
            all_chars.update(text)

    with open('data/hwdb2.0/chars_list.txt', 'w', encoding='utf-8') as f:
        for char in sorted(all_chars):
            f.write(char + '\n')

    # Count test lines up front instead of splitting inside the f-string
    test_count = len(test_lines.strip().split('\n'))
    print('\n✓ Preprocessing complete')
    print(f'  Train: {len(train_lines)} | Val: {len(val_lines)} | Test: {test_count}')
    print(f'  Characters: {len(all_chars)}')
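
The split and vocabulary steps in the cell above can also be written as small pure functions, which makes them easy to check on toy data before running on the full dataset. This is a minimal sketch rather than the project's code: `split_train_val` and `build_vocab` are hypothetical names, the toy file names are made up, and each ground-truth line is assumed to have the form `image_name,text`, as the cell above assumes.

```python
import random


def split_train_val(lines, val_fraction=0.1, seed=42):
    """Shuffle a copy of `lines` reproducibly and return (train, val)."""
    shuffled = list(lines)
    # A local generator keeps the global random state untouched.
    random.Random(seed).shuffle(shuffled)
    val_size = int(len(shuffled) * val_fraction)
    return shuffled[val_size:], shuffled[:val_size]


def build_vocab(lines):
    """Collect every character that appears in the text part of each line."""
    chars = set()
    for line in lines:
        # Split on the first comma only: the transcription may contain commas.
        _name, sep, text = line.partition(',')
        if sep:
            chars.update(text)
    return sorted(chars)


# Toy ground truth, for illustration only
toy_lines = [f'img_{i:03d}.png,你好,世界' for i in range(20)]
train, val = split_train_val(toy_lines)
vocab = build_vocab(toy_lines)
print(len(train), len(val), vocab)
```

Because the split depends only on the input list and the seed, rerunning it yields the same train and validation sets.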

## 7. Verify Preprocessed Dataset

In [None]:
DATASET_PATH = 'data/hwdb2.0'

if os.path.exists(f'{DATASET_PATH}/train_img_id_gt.txt'):
    # Count samples per split
    splits = {}
    for split in ['train', 'val', 'test']:
        gt_file = f'{DATASET_PATH}/{split}_img_id_gt.txt'
        if os.path.exists(gt_file):
            with open(gt_file, 'r', encoding='utf-8') as f:
                splits[split] = len(f.readlines())

    # Count characters
    with open(f'{DATASET_PATH}/chars_list.txt', 'r', encoding='utf-8') as f:
        chars = len(f.readlines())

    print(f'Dataset ready:')
    print(f'  Train: {splits.get("train", 0):,} samples')
    print(f'  Val:   {splits.get("val", 0):,} samples')
    print(f'  Test:  {splits.get("test", 0):,} samples')
    print(f'  Character vocab: {chars:,}')
    print(f'  Location: {DATASET_PATH}/')
else:
    print('✗ Dataset not found. Please run preprocessing first.')
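
The counts above only confirm that the ground-truth files exist. A further check, sketched below, confirms that every image named in a ground-truth file is present in the matching image folder. `find_missing_images` is a hypothetical helper, and the layout (`<split>_img_id_gt.txt` next to a `<split>/` image folder, one `image_name,text` entry per line) is assumed from the preprocessing cell rather than taken from `main.py`.

```python
import os
import tempfile


def find_missing_images(dataset_path, split):
    """Return image names listed in the split's ground truth but absent on disk."""
    gt_file = os.path.join(dataset_path, f'{split}_img_id_gt.txt')
    image_dir = os.path.join(dataset_path, split)
    missing = []
    with open(gt_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            img_name = line.split(',', 1)[0]
            if not os.path.exists(os.path.join(image_dir, img_name)):
                missing.append(img_name)
    return missing


# Demonstration on a throwaway directory: a.png exists, b.png does not.
with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, 'val'))
    open(os.path.join(root, 'val', 'a.png'), 'wb').close()
    with open(os.path.join(root, 'val_img_id_gt.txt'), 'w', encoding='utf-8') as f:
        f.write('a.png,你好\nb.png,世界\n')
    demo_missing = find_missing_images(root, 'val')

print(demo_missing)
```

In the notebook this would be called as `find_missing_images(DATASET_PATH, split)` for each of `train`, `val`, and `test`.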

## 8. Training

Train the model with A100-optimized settings:
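- **TF32 precision**: up to about 8x the peak FP32 matmul throughput on A100; real training speedups are smaller
- **Mixed precision (AMP)**: further speedup with FP16/BF16
- **Per-GPU batch sizes**: conservative values that avoid out-of-memory errors on long text lines
- **Gradient clipping**: prevents exploding gradients in the RNN layers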

Checkpoints will be saved locally in Colab.
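
How the TF32, mixed-precision, and gradient-clipping settings are typically switched on in PyTorch can be sketched as follows. This is not the code in `main.py`, which is not shown here; the toy model, learning rate, and clipping threshold are placeholders chosen for illustration. The sketch uses BF16 autocast, which needs no loss scaler; an FP16 setup would normally add a gradient scaler.

```python
import torch
from torch import nn

# Allow TF32 tensor-core math for matmuls and cuDNN convolutions.
# These flags only take effect on Ampere-class GPUs or newer.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Let cuDNN benchmark kernels and pick the fastest for the observed input shapes.
torch.backends.cudnn.benchmark = True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_amp = device.type == 'cuda'  # autocast is switched off on CPU in this sketch

# Toy stand-in for the OCR model (placeholder, not the hctr architecture)
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 8)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)


def train_step(batch, max_grad_norm=5.0):
    """Run one optimizer step with autocast and gradient-norm clipping."""
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=use_amp):
        output = model(batch)
        loss = output.float().pow(2).mean()  # dummy loss for illustration
    loss.backward()
    # Rescale gradients so their total norm does not exceed max_grad_norm.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    return loss.item(), float(grad_norm)


loss_value, grad_norm = train_step(torch.randn(4, 32, device=device))
print(f'loss={loss_value:.4f} grad_norm={grad_norm:.4f}')
```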

In [5]:
import torch

DATASET_PATH = 'data/hwdb2.0'

# Auto-detect optimal batch size based on GPU
# Note: max_width=1200 is enforced in main.py to prevent OOM on long text lines
# The model is large (~38M params), so conservative batch sizes are needed
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else ''
if 'A100' in gpu_name:
    BATCH_SIZE = 8        # A100 40GB: conservative; try 16-24 if memory allows
    NUM_WORKERS = 4
    print(f'✓ A100 detected - using batch_size={BATCH_SIZE}')
elif 'V100' in gpu_name:
    BATCH_SIZE = 6        # V100 16GB: conservative; try 8-12 if memory allows
    NUM_WORKERS = 4
    print(f'✓ V100 detected - using batch_size={BATCH_SIZE}')
elif 'T4' in gpu_name:
    BATCH_SIZE = 8        # T4 16GB: upper end of the 4-8 range
    NUM_WORKERS = 2
    print(f'✓ T4 detected - using batch_size={BATCH_SIZE}')
else:
    BATCH_SIZE = 8
    NUM_WORKERS = 2
    print(f'Using default batch_size={BATCH_SIZE}')

EPOCHS = 30             # Recommended: 30-50 epochs
PRINT_FREQ = 100        # Print every N batches
VAL_FREQ = 5000         # Validate every N batches

print(f'\nTraining configuration:')
print(f'  Batch size: {BATCH_SIZE}')
print(f'  Epochs: {EPOCHS}')
print(f'  Workers: {NUM_WORKERS}')
print(f'  Dataset: {DATASET_PATH}')
print(f'  Max image width: 1200 (enforced in main.py)')

!python main.py -m hctr \
    -d {DATASET_PATH} \
    -b {BATCH_SIZE} \
    -ep {EPOCHS} \
    -pf {PRINT_FREQ} \
    -vf {VAL_FREQ} \
    -j {NUM_WORKERS}

✓ A100 detected - using batch_size=16

Training configuration:
  Batch size: 16
  Epochs: 50
  Workers: 4
  Dataset: data/hwdb2.0
  Max image width: 1200 (enforced in main.py)
  self.setter(val)

GPU: NVIDIA A100-SXM4-40GB
Memory: 42.5 GB
✓ A100 detected - TF32 and optimizations enabled
  - TF32 matmul: ON (8-19x faster than FP32)
  - cuDNN benchmark: ON
  - Mixed precision (AMP): ON

Traceback (most recent call last):
  File "/content/handwritten-chinese-ocr-samples/main.py", line 617, in <module>
    main()
  File "/content/handwritten-chinese-ocr-samples/main.py", line 177, in main
    main_worker(args.gpu, ngpus_per_node, args)
  File "/content/handwritten-chinese-ocr-samples/main.py", line 192, in main_worker
    model, characters = get_model_info(args)
                        ^^^^^^^^^^^^^^^^^^^^
  File "/content/handwritten-chinese-ocr-samples/main.py", line 597, in get_model_info
    with open(chars_list_file, 'r', encoding='utf-8') as f:
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
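
The traceback above is cut off before the exception message, so the cause is not confirmed, but it stops inside `get_model_info` while opening the character list file. One plausible explanation is that the preprocessing cell had not been run in this session. A pre-flight check along the following lines could catch that before training starts. `missing_training_inputs` is a hypothetical helper, and the file names are the ones written by the preprocessing cell; the exact path that `main.py` opens is an assumption.

```python
import os


def missing_training_inputs(dataset_path):
    """Return the expected dataset files that are absent or empty."""
    expected = [
        'train_img_id_gt.txt',
        'val_img_id_gt.txt',
        'test_img_id_gt.txt',
        'chars_list.txt',
    ]
    problems = []
    for name in expected:
        path = os.path.join(dataset_path, name)
        if not os.path.isfile(path) or os.path.getsize(path) == 0:
            problems.append(path)
    return problems


problems = missing_training_inputs('data/hwdb2.0')
if problems:
    print('✗ Run the preprocessing cell first. Missing or empty:')
    for path in problems:
        print(f'  {path}')
else:
    print('✓ All training inputs found')
```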

## 9. Find Best Model

In [None]:
import glob

model_files = glob.glob('hctr_*.pth.tar')

if model_files:
    print('Saved models:')
    for f in sorted(model_files):
        size_mb = os.path.getsize(f) / (1024*1024)
        print(f'  {f} ({size_mb:.1f} MB)')

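    # File names are compared as strings, so the last one is the highest accuracy
    # only if the accuracy values in the names have the same number of digits.
    acc_models = [f for f in model_files if 'acc' in f]
    BEST_MODEL = sorted(acc_models)[-1] if acc_models else 'hctr_checkpoint.pth.tar'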
    print(f'\n✓ Best model: {BEST_MODEL}')
else:
    print('No models found')
    BEST_MODEL = None
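
Sorting file names picks the best model only when the names happen to sort in accuracy order. If the accuracy is embedded in the name, it can be parsed and compared numerically instead. The sketch below assumes a number directly follows `acc` in the file name; that naming pattern, the example file names, and the helper names are all assumptions, since the checkpoint naming used by `main.py` is not shown here.

```python
import re

# Hypothetical pattern: a number right after "acc", optionally separated by _ or -
ACC_PATTERN = re.compile(r'acc[_-]?(\d+(?:\.\d+)?)')


def accuracy_from_name(filename):
    """Return the accuracy parsed from the file name, or None if there is none."""
    match = ACC_PATTERN.search(filename)
    return float(match.group(1)) if match else None


def pick_best_model(filenames):
    """Return the file name with the highest parsed accuracy, or None."""
    scored = [(accuracy_from_name(f), f) for f in filenames]
    scored = [(acc, f) for acc, f in scored if acc is not None]
    return max(scored)[1] if scored else None


# Made-up names for illustration
example_files = [
    'hctr_acc_9.50.pth.tar',
    'hctr_acc_10.25.pth.tar',
    'hctr_checkpoint.pth.tar',
]
best = pick_best_model(example_files)
print(best)
```

With these example names a plain string sort would rank `9.50` above `10.25`, which is the case the numeric comparison avoids.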

## 10. Evaluation

In [None]:
if BEST_MODEL and os.path.exists(BEST_MODEL):
    print(f'Evaluating: {BEST_MODEL}\n')
    !python test.py -m hctr \
        -f {BEST_MODEL} \
        -i {DATASET_PATH} \
        -b 16 \
        -bm \
        -dm greedy-search \
        -pf 20
else:
    print('Model not found')

## 11. Save Checkpoints to Drive

**Important:** Save models to Drive to prevent loss when Colab disconnects.

In [None]:
import glob
import os
import shutil

# Save to Drive
CHECKPOINT_DIR = '/content/drive/MyDrive/HWDB-data/checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

model_files = glob.glob('hctr_*.pth.tar')
if model_files:
    for f in model_files:
        dst = os.path.join(CHECKPOINT_DIR, f)
        shutil.copy2(f, dst)
        print(f'✓ Saved: {dst}')
    print(f'\n✓ All models saved to Drive: {CHECKPOINT_DIR}')
else:
    print('No models to save')
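
A disconnect in the middle of a copy can leave a truncated checkpoint on Drive. One way to reduce that risk, sketched here, is to copy to a temporary name, compare sizes, and only then rename the file into place. `backup_file` and the `.part` suffix are hypothetical, and whether the final rename is atomic on a mounted Drive folder is not guaranteed.

```python
import os
import shutil
import tempfile


def backup_file(src, dst_dir):
    """Copy `src` into `dst_dir` via a temporary name, then verify the size."""
    os.makedirs(dst_dir, exist_ok=True)
    dst = os.path.join(dst_dir, os.path.basename(src))
    tmp = dst + '.part'
    shutil.copy2(src, tmp)
    if os.path.getsize(tmp) != os.path.getsize(src):
        os.remove(tmp)
        raise IOError(f'Incomplete copy of {src}')
    os.replace(tmp, dst)  # rename over any older copy
    return dst


# Demonstration with a dummy 1 KiB file in a throwaway directory
with tempfile.TemporaryDirectory() as work:
    src = os.path.join(work, 'hctr_checkpoint.pth.tar')
    with open(src, 'wb') as f:
        f.write(b'\x00' * 1024)
    saved = backup_file(src, os.path.join(work, 'drive', 'checkpoints'))
    saved_size = os.path.getsize(saved)

print(saved_size)
```

In the notebook the call would be `backup_file(f, CHECKPOINT_DIR)` for each file in `model_files`.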

## 12. Resume Training (Optional)

If Colab disconnects, rerun sections 1-7 to restore the environment, then run this cell to resume training from the checkpoint saved to Drive.

In [None]:
# Copy checkpoint from Drive
CHECKPOINT_DIR = '/content/drive/MyDrive/HWDB-data/checkpoints'
checkpoint_file = 'hctr_checkpoint.pth.tar'

if os.path.exists(f'{CHECKPOINT_DIR}/{checkpoint_file}'):
    !cp {CHECKPOINT_DIR}/{checkpoint_file} .
    print(f'✓ Restored checkpoint: {checkpoint_file}')

    # Detect GPU and set batch size (same conservative values as section 8)
    import torch
    gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else ''
    BATCH_SIZE = 6 if 'V100' in gpu_name else 8

    # Resume training
    !python main.py -m hctr \
        -d data/hwdb2.0 \
        -b {BATCH_SIZE} \
        -ep 100 \
        -pf 100 \
        -vf 5000 \
        -j 4 \
        -re {checkpoint_file}
else:
    print(f'✗ Checkpoint not found in {CHECKPOINT_DIR}')
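
Sections 8 and 12 each choose a batch size from the GPU name, so the two can drift apart. Keeping the mapping in one place is a simple way to avoid that. The sketch below copies the values from section 8; `GPU_SETTINGS` and `pick_settings` are hypothetical names, not part of the repository.

```python
# (batch_size, num_workers) per GPU family, copied from section 8
GPU_SETTINGS = {
    'A100': (8, 4),
    'V100': (6, 4),
    'T4': (8, 2),
}
DEFAULT_SETTINGS = (8, 2)


def pick_settings(gpu_name):
    """Return (batch_size, num_workers) for the first GPU family found in the name."""
    for family, settings in GPU_SETTINGS.items():
        if family in gpu_name:
            return settings
    return DEFAULT_SETTINGS


batch_size, num_workers = pick_settings('NVIDIA A100-SXM4-40GB')
print(batch_size, num_workers)
```

Both the training cell and the resume cell could then call `pick_settings(gpu_name)` instead of repeating the conditionals.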