# Fine-tune dots.ocr for Custom OCR Tasks

This notebook shows how to fine-tune the dots.ocr model for custom OCR tasks on Google Colab.

> **Note:** This notebook makes use of wjbmattingly's [dots.ocr training repo](https://github.com/wjbmattingly/dots.ocr).  

## What is in this notebook:
- Autolabel images using the base dots.ocr model
- Prepare training data from your custom images
- Fine-tune the model for better OCR on your specific content
- Test and evaluate your fine-tuned model

## Requirements:
- A100/L4 GPU recommended
- Images you want to train on (upload to Google Drive)

## Workflow:
1. Setup - Install dependencies and download base model
2. Auto-label - Generate initial OCR predictions (skip if you have prepared data)
3. Correct - Manually fix the generated labels (skip if you have prepared data)
4. Train - Fine-tune the model on your corrected data
5. Test - Evaluate the model

In [None]:
## Environment Setup

import os
import subprocess

# Set memory allocation for better GPU usage
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Check GPU availability
print("Checking GPU...")
!nvidia-smi

# Clone repository
if not os.path.exists("dots.ocr"):
    !git clone https://github.com/wjbmattingly/dots.ocr.git

%cd dots.ocr

# Install dependencies
!pip install -q -r requirements.txt -r training_requirements.txt
!pip install -e .

# Download base model weights (~6GB)
!python tools/download_model.py

print("Setup complete")


In [None]:
## Mount Google Drive

from google.colab import drive

# Mount Google Drive
drive.mount("/content/drive")

# Set up paths
IMAGES_DIR = "/content/drive/MyDrive/images"  # Upload your images here
AUTOLABEL_DIR = "/content/autolabel"  # Autolabeled results

print(f"Images directory: {IMAGES_DIR}")
print(f"Autolabel directory: {AUTOLABEL_DIR}")


Mounted at /content/drive
Images directory: /content/drive/MyDrive/images
Autolabel directory: /content/autolabel


In [None]:
## Upload Your Images

**Before proceeding:**
1. Go to Google Drive and create a folder called 'images' in your My Drive
2. Upload the images you want to train on
3. Supported formats: .jpg, .jpeg, .png, .pdf



Note: If you already have prepared training data, skip to the training section.
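
Before autolabeling, it can help to confirm that the uploaded files will actually be found. The sketch below is an optional check, not part of the original workflow: `scan_image_folder` is a hypothetical helper name, and the path is assumed to match `IMAGES_DIR` from the mount cell. It counts files whose extension is one of the supported formats listed above, and separately reports files such as `scan.JPG` whose extension only matches when case is ignored, since lowercase glob patterns like `*.jpg` are case-sensitive on Colab's Linux file system.

```python
import os
from collections import Counter

# Extensions listed as supported in this notebook.
SUPPORTED_EXTS = {".jpg", ".jpeg", ".png", ".pdf"}

def scan_image_folder(folder):
    """Return (counts per supported extension, files with a case-mismatched extension)."""
    matched = Counter()
    case_mismatch = []
    if not os.path.isdir(folder):
        return matched, case_mismatch
    for name in sorted(os.listdir(folder)):
        if not os.path.isfile(os.path.join(folder, name)):
            continue
        ext = os.path.splitext(name)[1]
        if ext in SUPPORTED_EXTS:
            matched[ext] += 1
        elif ext.lower() in SUPPORTED_EXTS:
            # e.g. "page.JPG": a lowercase pattern such as "*.jpg" would miss it
            case_mismatch.append(name)
    return matched, case_mismatch

# Assumed to be the same path as IMAGES_DIR in the mount cell.
counts, mismatched = scan_image_folder("/content/drive/MyDrive/images")
print("Supported files:", dict(counts))
print("Rename to a lowercase extension:", mismatched)
```

If the second list is not empty, renaming those files in Drive is probably the simplest fix.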


In [None]:
## Autolabel Your Images

import os
import glob
from tqdm import tqdm

# Create output directory
os.makedirs(AUTOLABEL_DIR, exist_ok=True)

# Find all image files
image_files = []
for ext in ["*.jpg", "*.jpeg", "*.png", "*.pdf"]:
    image_files.extend(glob.glob(os.path.join(IMAGES_DIR, ext)))

print(f"Found {len(image_files)} images to process")

if len(image_files) == 0:
    print("No images found! Please upload images to Google Drive first.")
    print(f"Expected location: {IMAGES_DIR}")
else:
    # Process each image
    successful = 0
    failed = 0

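    for img_path in tqdm(image_files, desc="Autolabeling"):
        # A `!` shell command never raises on failure, so run the parser through
        # subprocess (imported in the setup cell) and check its exit code instead.
        cmd = ["python", "-m", "dots_ocr.parser", img_path,
               "--output", AUTOLABEL_DIR, "--prompt", "prompt_ocr", "--use_hf", "true"]
        try:
            subprocess.run(cmd, check=True)
            successful += 1
        except subprocess.CalledProcessError as e:
            print(f"Failed to process {os.path.basename(img_path)}: exit code {e.returncode}")
            failed += 1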

    print(f"\nAuto-labeling completed!")
    print(f"Successful: {successful}")
    print(f"Failed: {failed}")
    print(f"Results saved to: {AUTOLABEL_DIR}")


## Manual Correction


### How to correct your labels:

1. Download the '/content/autolabel' folder from Colab

2. Edit the '.md' files in each subfolder to correct OCR errors

3. Upload the corrected folder back to Google Drive:
   - Upload the entire corrected 'autolabel' folder to '/content/drive/MyDrive/autolabel'
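
Downloading a folder with many small files from Colab is easier as a single archive. The following is a minimal sketch of step 1 under that assumption: `zip_folder` and the archive name `autolabel_for_review` are hypothetical, and only the standard-library `shutil.make_archive` is used.

```python
import os
import shutil

def zip_folder(src_dir, archive_base):
    """Create '<archive_base>.zip' containing src_dir; return its path, or None if src_dir is missing."""
    if not os.path.isdir(src_dir):
        return None
    parent, folder = os.path.split(os.path.abspath(src_dir))
    # root_dir/base_dir keep the top-level folder name inside the archive
    return shutil.make_archive(archive_base, "zip", root_dir=parent, base_dir=folder)

archive = zip_folder("/content/autolabel", "/content/autolabel_for_review")
print(archive or "Nothing to archive yet")
```

In Colab, the resulting file can then be fetched from the file browser, or with `google.colab.files.download(archive)`.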



Skip this section if you already have prepared training data.

## Prepare Training Data

Expected data format: Your data should be in '/content/drive/MyDrive/autolabel/' with this structure:

    autolabel/
    ├── sample1/
    │   ├── sample1.jpg
    │   ├── sample1.md
    │   └── sample1.json
    ├── sample2/
    │   ├── sample2.jpg
    │   ├── sample2.md
    │   └── sample2.json
    └── ...
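
The layout above can be checked before running the preparation cell. This is an illustrative sketch only: `check_samples` is a hypothetical helper, and it assumes that each sample needs a non-empty `.md` file and an image named after its folder (the `.json` file is not inspected).

```python
import os

IMAGE_EXTS = (".jpg", ".jpeg", ".png")

def check_samples(base_dir):
    """Map each sample folder name to a list of problems (an empty list means it looks fine)."""
    report = {}
    if not os.path.isdir(base_dir):
        return report
    for name in sorted(os.listdir(base_dir)):
        sample_dir = os.path.join(base_dir, name)
        if not os.path.isdir(sample_dir):
            continue
        problems = []
        md_path = os.path.join(sample_dir, f"{name}.md")
        if not os.path.isfile(md_path):
            problems.append("missing .md")
        elif os.path.getsize(md_path) == 0:
            problems.append("empty .md")
        has_image = any(
            os.path.isfile(os.path.join(sample_dir, name + ext)) for ext in IMAGE_EXTS
        )
        if not has_image:
            problems.append("missing image")
        report[name] = problems
    return report

# Print only the samples that need attention.
for sample, problems in check_samples("/content/drive/MyDrive/autolabel").items():
    if problems:
        print(f"{sample}: {', '.join(problems)}")
```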


### **Training JSONL Format**

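Your corrected data will be converted into a `.jsonl` file for training (`train_ocr_resized.jsonl`).  
Each line in this file is one training sample in JSON format. The records below are wrapped across several lines for readability, and `"prompt_ocr"` stands in for the full prompt text that the preparation cell looks up under that key.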

**Expected JSONL structure:**

```json
{"messages":[
  {"role":"user","content":[
    {"type":"image","image":"/content/resized_images/sample1.jpg"},
    {"type":"text","text":"prompt_ocr"}
  ]},
  {"role":"assistant","content":"This is the corrected OCR text for sample 1."}
]}
{"messages":[
  {"role":"user","content":[
    {"type":"image","image":"/content/resized_images/sample2.jpg"},
    {"type":"text","text":"prompt_ocr"}
  ]},
  {"role":"assistant","content":"This is the corrected OCR text for sample 2."}
]}
```
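
A training file in this shape can be sanity-checked line by line. The sketch below is an assumption-laden helper, not part of the training repo: `validate_record` and `validate_jsonl` are hypothetical names, and the rules simply mirror the structure shown above (one user message with an image part and a text part, followed by one assistant message with non-empty text).

```python
import json

def validate_record(record):
    """Return a list of problems found in one training record (empty if well-formed)."""
    messages = record.get("messages") if isinstance(record, dict) else None
    if not isinstance(messages, list) or len(messages) != 2:
        return ["expected 'messages' to be a list with two entries"]
    user, assistant = messages
    problems = []
    if not isinstance(user, dict) or user.get("role") != "user":
        problems.append("first message must have role 'user'")
    else:
        content = user.get("content")
        parts = content if isinstance(content, list) else []
        types = sorted(p.get("type") for p in parts if isinstance(p, dict))
        if types != ["image", "text"]:
            problems.append("user content must hold one image part and one text part")
    if not isinstance(assistant, dict) or assistant.get("role") != "assistant":
        problems.append("second message must have role 'assistant'")
    elif not isinstance(assistant.get("content"), str) or not assistant["content"].strip():
        problems.append("assistant content must be non-empty text")
    return problems

def validate_jsonl(path):
    """Return {line_number: problems} for every line that fails validation."""
    bad = {}
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                bad[lineno] = [f"invalid JSON: {exc.msg}"]
                continue
            problems = validate_record(record)
            if problems:
                bad[lineno] = problems
    return bad

sample = {"messages": [
    {"role": "user", "content": [
        {"type": "image", "image": "/content/resized_images/sample1.jpg"},
        {"type": "text", "text": "prompt_ocr"}]},
    {"role": "assistant", "content": "This is the corrected OCR text for sample 1."}]}
print(validate_record(sample))  # []
```

Once the training file exists, `validate_jsonl("/content/drive/MyDrive/train_ocr_resized.jsonl")` should return an empty dictionary if every line follows this shape.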


In [None]:
# resizing images and preparing training data
import os
import json
from PIL import Image
from tqdm import tqdm
from dots_ocr.utils.prompts import dict_promptmode_to_prompt

# Set up paths
CORRECTED_DIR = "/content/drive/MyDrive/autolabel"  # Your corrected labels
RESIZED_DIR = "/content/resized_images"
TRAINING_JSONL = "/content/drive/MyDrive/train_ocr_resized.jsonl"

os.makedirs(RESIZED_DIR, exist_ok=True)
prompt = dict_promptmode_to_prompt["prompt_ocr"]

def list_samples(base_dir):
    """List all sample directories"""
    if not os.path.exists(base_dir):
        return []
    return [d for d in sorted(os.listdir(base_dir)) if os.path.isdir(os.path.join(base_dir, d))]

def read_original_image_path(base_dir, name):
    """Get original image path from metadata"""
    jsonl_path = os.path.join(base_dir, f"{name}.jsonl")
    if os.path.exists(jsonl_path):
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError:
                    continue  # Skip malformed metadata lines
                if isinstance(obj, dict) and "file_path" in obj:
                    return obj["file_path"]
    # Fallback to sample directory
    return os.path.join(base_dir, name, f"{name}.jpg")

def load_corrected_text(sample_dir, name):
    """Load corrected text from .md file"""
    md_path = os.path.join(sample_dir, f"{name}.md")
    if os.path.exists(md_path):
        # Use a context manager so the file handle is closed promptly
        with open(md_path, "r", encoding="utf-8") as f:
            return f.read().strip()
    return None

def resize_image(src_path, dst_path, max_side=512):
    """Resize image to reduce memory usage during training"""
    img = Image.open(src_path).convert("RGB")
    w, h = img.size
    scale = max_side / max(w, h)
    if scale < 1.0:
        new_w, new_h = int(w * scale), int(h * scale)
        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
    img.save(dst_path)

# Check if corrected directory exists
if not os.path.exists(CORRECTED_DIR):
    print(f"Corrected directory not found: {CORRECTED_DIR}")
    print("Please upload your corrected autolabel folder to Google Drive first.")
else:
    # Process samples
    samples = list_samples(CORRECTED_DIR)
    print(f"Found {len(samples)} corrected samples")

    if len(samples) == 0:
        print("No samples found in corrected directory!")
    else:
        training_data = []
        processed = 0
        skipped = 0

        for name in tqdm(samples, desc="Preparing training data"):
            sample_dir = os.path.join(CORRECTED_DIR, name)

            # Get original image and resize it
            orig_path = read_original_image_path(CORRECTED_DIR, name)
            resized_path = os.path.join(RESIZED_DIR, f"{name}.jpg")

            try:
                resize_image(orig_path, resized_path, max_side=512)
            except Exception as e:
                print(f"Failed to resize {name}: {e}")
                skipped += 1
                continue

            # Load corrected text
            text = load_corrected_text(sample_dir, name)
            if not text:
                print(f"No corrected text found for {name}")
                skipped += 1
                continue

            # Create training record
            record = {
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {"type": "image", "image": resized_path},
                            {"type": "text", "text": prompt}
                        ]
                    },
                    {"role": "assistant", "content": text}
                ]
            }
            training_data.append(record)
            processed += 1

        # Save training data
        with open(TRAINING_JSONL, "w", encoding="utf-8") as f:
            for record in training_data:
                f.write(json.dumps(record, ensure_ascii=False) + "\n")

        print(f"\nTraining data prepared!")
        print(f"Processed: {processed} samples")
        print(f"Skipped: {skipped} samples")
        print(f"Training file: {TRAINING_JSONL}")
        print(f"Resized images: {RESIZED_DIR}")


In [None]:
## Finetune the Model

TRAINING_JSONL = "/content/drive/MyDrive/train_ocr_resized.jsonl"

# Check if training data exists
if not os.path.exists(TRAINING_JSONL):
    print(f"Training data not found: {TRAINING_JSONL}")
    print("Please run the previous step to prepare training data first.")
else:
    print("Starting finetuning...")

    # Train the model
    !python train_simple.py \
        --data "{TRAINING_JSONL}" \
        --epochs 15 \
        --batch_size 1 \
        --learning_rate 3e-4 \
        --max_length 1024 \
        --gradient_checkpointing \
        --output_dir "/content/local_checkpoints"

    print("Training completed")
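
To get a feel for how long the run above might take, the number of optimizer steps can be estimated from the dataset size. This is a rough sketch under stated assumptions: it assumes one optimizer step per batch unless gradient accumulation is used, that the last partial batch is kept, and the `grad_accum` parameter is hypothetical (it is not a flag passed to `train_simple.py` above). The 40-sample figure is only an example.

```python
import math

def count_samples(jsonl_path):
    """Count non-empty lines, i.e. training samples, in a JSONL file."""
    with open(jsonl_path, encoding="utf-8") as f:
        return sum(1 for line in f if line.strip())

def estimate_steps(num_samples, epochs, batch_size, grad_accum=1):
    """Estimate total optimizer steps for a training run."""
    steps_per_epoch = math.ceil(num_samples / (batch_size * grad_accum))
    return steps_per_epoch * epochs

# Example: 40 samples with the settings used above (15 epochs, batch size 1)
print(estimate_steps(40, epochs=15, batch_size=1))  # 600
```

With a real dataset, `count_samples(TRAINING_JSONL)` supplies the first argument.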


In [None]:
## Setup Finetuned Model

print("Copying configuration files...")

# Ensure we have base model files
!python tools/download_model.py

# Copy missing configuration files from base model
!cp ./weights/DotsOCR/configuration_dots.py /content/local_checkpoints/final_model/
!cp ./weights/DotsOCR/modeling_*.py /content/local_checkpoints/final_model/

print("Replacing base model with finetuned model...")
!rm -rf ./weights/DotsOCR
!cp -r /content/local_checkpoints/final_model ./weights/DotsOCR

print("Verifying model setup...")
!python -c "import json; json.load(open('./weights/DotsOCR/config.json')); print('Model setup complete')"

print("\nYour finetuned model is ready for inference")
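
Because the cell above deletes the base weights before copying the fine-tuned ones into place, it may be worth confirming first that the checkpoint folder is complete. The sketch below is an optional pre-check: `missing_model_files` is a hypothetical helper, and the file names it looks for are only the ones referenced in the cell above (`config.json`, `configuration_dots.py`, and at least one `modeling_*.py`); a real checkpoint also contains weight and tokenizer files that are not checked here.

```python
import glob
import os

def missing_model_files(model_dir):
    """Return the expected files that are absent from model_dir."""
    required = ["config.json", "configuration_dots.py"]
    missing = [n for n in required if not os.path.isfile(os.path.join(model_dir, n))]
    # At least one modeling_*.py file is copied over from the base model
    if not glob.glob(os.path.join(model_dir, "modeling_*.py")):
        missing.append("modeling_*.py")
    return missing

missing = missing_model_files("/content/local_checkpoints/final_model")
print("Checkpoint looks complete" if not missing else f"Missing: {missing}")
```

Running this after the two `cp` commands and before `rm -rf` would catch a failed training run without losing the base model.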



In [None]:
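import os
from PIL import Image, ImageFile

# Let Pillow load images whose data is truncated instead of raising an error
ImageFile.LOAD_TRUNCATED_IMAGES = True

def sanitize(img_path, dst_dir):
    """Save an RGB JPEG copy of img_path (longest side at most 1024 px) in dst_dir.

    Returns the path of the clean copy, or None if the image cannot be read.
    """
    try:
        os.makedirs(dst_dir, exist_ok=True)
        with Image.open(img_path) as im:
            im = im.convert("RGB")
            w, h = im.size
            if max(w, h) > 1024:
                s = 1024 / max(w, h)
                im = im.resize((int(w*s), int(h*s)), Image.Resampling.LANCZOS)
            # Use a .jpg extension so the file name matches the JPEG data written below
            stem = os.path.splitext(os.path.basename(img_path))[0]
            clean_path = os.path.join(dst_dir, f"{stem}.jpg")
            im.save(clean_path, format="JPEG", quality=92, optimize=True)
            return clean_path
    except Exception as e:
        print(f"[skip bad image] {img_path} -> {e}")
        return None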


In [None]:
# Test OCR
# Set your file path here:
IMAGE_PATH = "/content/resized_images/SCR-20250715-iho.jpg"

from dots_ocr.parser import DotsOCRParser
import os

parser = DotsOCRParser(use_hf=True, max_completion_tokens=128)
result = parser.parse_file(IMAGE_PATH, prompt_mode="prompt_ocr")

if not result:
    print("[no result]")
else:
    info = result[0]
    text = None

    md_path = info.get("md_content_path")
    if md_path and os.path.exists(md_path):
        with open(md_path, "r", encoding="utf-8") as f:
            text = f.read()

    if text is None and isinstance(info.get("content"), str):
        text = info["content"]

    print(text or "[empty]")
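
Reading the output by eye only goes so far; a common way to put a number on OCR quality is the character error rate (CER), the edit distance between the predicted text and the corrected reference divided by the reference length. The sketch below is a plain-Python illustration rather than anything provided by dots.ocr: `edit_distance` and `cer` are hypothetical helper names, and returning 1.0 for an empty reference with non-empty output is a convention chosen here.

```python
def edit_distance(a, b):
    """Levenshtein distance between strings a and b (insert, delete, substitute)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]

def cer(reference, hypothesis):
    """Character error rate: edit distance divided by reference length."""
    if not reference:
        # Convention chosen for this sketch: any output against an empty reference is a full error
        return 0.0 if not hypothesis else 1.0
    return edit_distance(reference, hypothesis) / len(reference)

print(cer("kitten", "sitting"))  # 0.5
```

To use it here, pass the corrected `.md` text of a sample as `reference` and the `text` printed by the test cell as `hypothesis`. Scoring images that were held out of training gives a more honest picture than scoring training samples.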


In [None]:
## Save the Fine-tuned Model
print("Saving fine-tuned model to Google Drive...")

!cp -r /content/local_checkpoints/final_model /content/drive/MyDrive/dots_ocr_finetuned

print("\nFine-tuned model saved to Google Drive")
print("Location: /content/drive/MyDrive/dots_ocr_finetuned")
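
Copies to Google Drive can be worth verifying, particularly because `cp -r` nests the source folder inside the destination if `dots_ocr_finetuned` already exists (the files then end up in `dots_ocr_finetuned/final_model`). The following sketch compares relative file paths and sizes between the two folders; `dir_manifest` and `copy_differences` are hypothetical helpers, and matching sizes are only a quick indicator, not a checksum.

```python
import os

def dir_manifest(root):
    """Map each file path (relative to root) to its size in bytes."""
    manifest = {}
    for dirpath, _, filenames in os.walk(root):
        for filename in filenames:
            full = os.path.join(dirpath, filename)
            manifest[os.path.relpath(full, root)] = os.path.getsize(full)
    return manifest

def copy_differences(src, dst):
    """Return source files that are missing from dst or differ in size."""
    src_files, dst_files = dir_manifest(src), dir_manifest(dst)
    return sorted(p for p, size in src_files.items() if dst_files.get(p) != size)

src = "/content/local_checkpoints/final_model"
dst = "/content/drive/MyDrive/dots_ocr_finetuned"
if not os.path.isdir(src) or not os.path.isdir(dst):
    print("Source or destination folder not found")
else:
    diffs = copy_differences(src, dst)
    print("Copy verified" if not diffs else f"Missing or different: {diffs}")
```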





### Resources:
- [dots.ocr GitHub](https://github.com/rednote-hilab/dots.ocr)
- [Training Documentation](https://github.com/wjbmattingly/dots.ocr/blob/main/README_model_training.md)

Happy training!