# Handwritten Chinese OCR - Colab Training

Train a handwritten Chinese OCR model on Google Colab using the CASIA-HWDB2.0 dataset.

## Quick Start
1. Upload the project folder (containing `HWDB2.0Train` and `HWDB2.0Test`) to Google Drive
2. Open this notebook in Colab with a GPU runtime
3. Run the cells in order

## 1. Mount Google Drive

In [None]:
from google.colab import drive
import os

if os.path.exists('/content/drive/MyDrive'):
    print('Google Drive already mounted')
else:
    drive.mount('/content/drive')

## 2. Set Project Path

Modify `PROJECT_PATH` to match your Google Drive folder location.

In [None]:
# Modify this path to your project folder in Google Drive
PROJECT_PATH = '/content/drive/MyDrive/handwritten-chinese-ocr-samples'

%cd {PROJECT_PATH}
!ls -la
print(f'\nWorking directory: {os.getcwd()}')
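Before moving on, it can help to confirm that the working directory really contains the dataset folders named in the Quick Start. The following is a minimal sketch; `missing_dataset_dirs` and `EXPECTED_DIRS` are hypothetical helper names introduced here, not part of the project.

```python
import os

# Folder names taken from the Quick Start; adjust if your copy differs.
EXPECTED_DIRS = ('HWDB2.0Train', 'HWDB2.0Test')


def missing_dataset_dirs(project_path, expected=EXPECTED_DIRS):
    """Return the expected sub-folders that are absent from project_path."""
    return [d for d in expected
            if not os.path.isdir(os.path.join(project_path, d))]


# After the %cd above, the working directory is the project folder.
missing = missing_dataset_dirs(os.getcwd())
print('Missing folders:', missing or 'none')
```

If any folder is reported missing, `PROJECT_PATH` most likely points at the wrong location or the upload is incomplete.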

## 3. Install Dependencies

In [None]:
!pip install -q -r requirements.txt

import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')

## 4. Preprocess Dataset

The CASIA-HWDB `.dgrl` files must be converted to PNG images and text labels.

**Two preprocessing options:**
- `preprocess_dgrl.py` - Extracts text **line images** (for line-level OCR)
- `preprocess_dgrl_pages.py` - Extracts **full page images** (for page-level OCR)

**Expected output structure:**
```
data/hwdb2.0/
├── train/          # PNG images
├── val/
├── test/
├── img_gt.txt      # Format: image_name,label_text
└── chars_list.txt  # Character vocabulary
```
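To illustrate the `image_name,label_text` format of `img_gt.txt`, here is a small parsing sketch. It assumes the first comma on each line is the separator, so that any later commas stay part of the label; whether the preprocessing script escapes commas differently has not been verified. `parse_gt_line` is a hypothetical helper, and the sample file name is made up.

```python
def parse_gt_line(line):
    """Split one 'image_name,label_text' record on the first comma only."""
    name, sep, label = line.rstrip('\r\n').partition(',')
    if not sep or not name:
        raise ValueError(f'Malformed ground-truth line: {line!r}')
    return name, label


# Made-up record; a comma inside the label is preserved.
name, label = parse_gt_line('sample_0001.png,你好,世界\n')
print(name)   # sample_0001.png
print(label)  # 你好,世界
```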

In [None]:
# Option A: Extract text line images (recommended for training)
!python preprocess_dgrl.py \
    --input_dir HWDB2.0Train \
    --output_dir data/hwdb2.0 \
    --target_height 128 \
    --workers 4

# Option B: Extract full page images (for page-level analysis)
# Uncomment to use this instead:
# !python preprocess_dgrl_pages.py \
#     --input_dir HWDB2.0Train \
#     --output_dir data/pages \
#     --no-viz

## 5. Verify Dataset

In [None]:
DATASET_PATH = 'data/hwdb2.0'

required_files = [
    f'{DATASET_PATH}/img_gt.txt',
    f'{DATASET_PATH}/chars_list.txt'
]

print('Dataset verification:')
for f in required_files:
    exists = os.path.exists(f)
    print(f"{'✓' if exists else '✗'} {f}")

gt_path = f'{DATASET_PATH}/img_gt.txt'
vocab_path = f'{DATASET_PATH}/chars_list.txt'

# Check each file separately so that one missing file does not crash the cell
if os.path.exists(gt_path):
    with open(gt_path, 'r', encoding='utf-8') as fh:
        count = sum(1 for _ in fh)
    print(f'\nTotal samples: {count}')

if os.path.exists(vocab_path):
    with open(vocab_path, 'r', encoding='utf-8') as fh:
        num_chars = sum(1 for _ in fh)
    print(f'Character vocabulary: {num_chars}')
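Beyond checking that the files exist, one optional sanity check is whether every character used in the labels also appears in the vocabulary. The sketch below assumes `chars_list.txt` holds one character per line and that labels follow the `image_name,label_text` format; both helpers (`load_vocab`, `uncovered_chars`) are hypothetical names, not part of the project.

```python
import os


def load_vocab(path):
    """Read one character per line (assumed layout of chars_list.txt)."""
    with open(path, 'r', encoding='utf-8') as fh:
        entries = (line.rstrip('\r\n') for line in fh)
        return {entry for entry in entries if entry}


def uncovered_chars(gt_lines, vocab):
    """Return the set of label characters that are not in vocab."""
    missing = set()
    for line in gt_lines:
        # Split on the first comma only; the rest is the label text.
        _, sep, label = line.rstrip('\r\n').partition(',')
        if sep:
            missing.update(ch for ch in label if ch not in vocab)
    return missing


gt_path = 'data/hwdb2.0/img_gt.txt'
vocab_path = 'data/hwdb2.0/chars_list.txt'
if os.path.exists(gt_path) and os.path.exists(vocab_path):
    with open(gt_path, 'r', encoding='utf-8') as fh:
        missing = uncovered_chars(fh, load_vocab(vocab_path))
    print(f'Label characters missing from vocabulary: {len(missing)}')
```

A non-zero count would suggest that the vocabulary and the labels came from different preprocessing runs, or that the assumed file layout does not hold.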

## 6. Training

In [None]:
DATASET_PATH = 'data/hwdb2.0'
BATCH_SIZE = 8
EPOCHS = 30
PRINT_FREQ = 50
NUM_WORKERS = 2

!python main.py -m hctr \
    -d {DATASET_PATH} \
    -b {BATCH_SIZE} \
    -ep {EPOCHS} \
    -pf {PRINT_FREQ} \
    -j {NUM_WORKERS}

## 7. Find Best Model

In [None]:
import glob

model_files = glob.glob('hctr_*.pth.tar')

if model_files:
    print('Saved models:')
    for f in sorted(model_files):
        size_mb = os.path.getsize(f) / (1024*1024)
        print(f'  {f} ({size_mb:.1f} MB)')
    
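    # Pick the file with 'acc' in its name that sorts last alphabetically;
    # fall back to the default checkpoint name if there is none
    acc_models = [f for f in model_files if 'acc' in f]
    BEST_MODEL = sorted(acc_models)[-1] if acc_models else 'hctr_checkpoint.pth.tar'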
    print(f'\nSelected: {BEST_MODEL}')
else:
    print('No models found')
    BEST_MODEL = None
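Note that sorting file names as strings does not always match numeric order (for example, `10.20` sorts before `9.50`). If the checkpoint names embed an accuracy value, a numeric comparison may be more reliable. The sketch below assumes a hypothetical naming scheme in which the number follows `acc`, optionally separated by `_` or `-`; the actual names written by `main.py` have not been verified, so adjust `ACC_PATTERN` to match them.

```python
import re

# Hypothetical pattern: 'acc', an optional '_' or '-', then a number.
ACC_PATTERN = re.compile(r'acc[_-]?(\d+(?:\.\d+)?)')


def pick_best_by_accuracy(paths, pattern=ACC_PATTERN):
    """Return the path with the highest embedded accuracy, or None."""
    scored = []
    for path in paths:
        match = pattern.search(path)
        if match:
            scored.append((float(match.group(1)), path))
    return max(scored)[1] if scored else None


# Made-up file names showing where string order and numeric order differ.
demo = ['hctr_acc_9.50.pth.tar', 'hctr_acc_10.20.pth.tar']
print(sorted(demo)[-1])             # hctr_acc_9.50.pth.tar
print(pick_best_by_accuracy(demo))  # hctr_acc_10.20.pth.tar
```

With real checkpoints, `BEST_MODEL = pick_best_by_accuracy(model_files) or 'hctr_checkpoint.pth.tar'` would keep the same fallback as the cell above.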

## 8. Evaluation

In [None]:
DATASET_PATH = 'data/hwdb2.0'
MODEL_FILE = BEST_MODEL
BATCH_SIZE = 16

if MODEL_FILE and os.path.exists(MODEL_FILE):
    print(f'Evaluating: {MODEL_FILE}')
    !python test.py -m hctr \
        -f {MODEL_FILE} \
        -i {DATASET_PATH} \
        -b {BATCH_SIZE} \
        -bm \
        -dm greedy-search \
        -pf 20
else:
    print('Model not found')

## 9. Save Models to Drive

In [None]:
import shutil

SAVE_DIR = f'{PROJECT_PATH}/checkpoints'
os.makedirs(SAVE_DIR, exist_ok=True)

model_files = glob.glob('hctr_*.pth.tar')
for f in model_files:
    dst = os.path.join(SAVE_DIR, f)
    shutil.copy2(f, dst)
    print(f'Saved: {dst}')

if model_files:
    print(f'\nAll models saved to: {SAVE_DIR}')
else:
    print('No models found to save')