In [None]:
# Mount Google Drive when running in Google Colab; elsewhere this cell is a no-op.
try:
    from google.colab import drive
except ImportError:
    # Not in Colab: there is no Drive to mount, so skip it instead of aborting
    # Restart & Run All with ModuleNotFoundError.
    drive = None
    print("google.colab not available; skipping Google Drive mount.")
else:
    drive.mount('/content/drive')

In [None]:
%cd drive/MyDrive/TraditionalMedicineChatbot/

# PaddleOCR

In [None]:
# Install PaddleOCR with a Paddle backend, plus a pinned LangChain.
# NOTE(review): `paddlepaddle` (CPU) and `paddlepaddle-gpu` are alternative builds
# of the same `paddle` package; installing both in one command lets one overwrite
# the other. Pick ONE build for the runtime — TODO confirm which is intended.
# NOTE(review): `%pip` is generally preferred over `!pip` so packages land in the
# kernel's own environment; paddle/paddleocr are also unpinned here.
!pip install paddlepaddle paddlepaddle-gpu
!pip install paddleocr
!pip install "langchain==0.0.353"

In [None]:
# --- PaddleOCR Initialization ---
# Vietnamese recognition model, with angle classification enabled (use_angle_cls).
from paddleocr import PaddleOCR

ocr = PaddleOCR(lang='vi', use_angle_cls=True)

In [None]:
# --- OCR Usage: Process Image ---
import os

img_filename = 'sample.png'
if not os.path.exists(img_filename):
    raise FileNotFoundError("Could not find 'sample.png' locally. Upload it or mount Google Drive and set img_path accordingly.")
img_path = os.path.abspath(img_filename)
print(f'Using image: {img_path}')

# Choose the entry point by API presence rather than catching every exception:
# a blanket `except Exception` also swallowed genuine failures from `predict`
# (unreadable image, out of memory, ...) and silently retried via another path.
# NOTE(review): the cells below read `result[0]` as a dict with 'rec_texts',
# 'rec_scores', 'rec_polys' — presumably the `predict` output format; the legacy
# `ocr()` return value likely has a different shape. Confirm before relying on it.
if hasattr(ocr, 'predict'):
    result = ocr.predict(img_path)
else:
    result = ocr.ocr(img_path, cls=True)

In [None]:
import json
import numpy as np

# Flatten the first PaddleOCR page into JSON-serializable detections and save them.

# 1. Prepare data container
output_data = []
# An empty result (no pages) or a None page is treated as "no detections"
# instead of failing with IndexError / AttributeError.
ocr_data = (result[0] if result else None) or {}

rec_texts = ocr_data.get('rec_texts', [])
rec_scores = ocr_data.get('rec_scores', [])
rec_polys = ocr_data.get('rec_polys', [])

# 2. Convert data to standard Python types (json cannot serialize numpy values)
for i, text in enumerate(rec_texts):
    # Box: numpy array -> nested Python list
    box = rec_polys[i]
    if isinstance(box, np.ndarray):
        box = box.tolist()

    # Score: any numpy float (float16/32/64, ...) -> Python float
    score = rec_scores[i]
    if isinstance(score, np.floating):
        score = float(score)

    output_data.append({
        "id": i + 1,
        "text": text,
        "confidence": score,
        "box": box,
    })

# 3. Write to JSON file (ensure_ascii=False keeps Vietnamese characters readable)
output_filename = 'ocr_result.json'
try:
    with open(output_filename, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=4)
    print(f"Successfully saved {len(output_data)} detections to '{output_filename}'")
except Exception as e:
    print(f"Error saving JSON: {e}")

In [None]:
# Dependencies for the Hugging Face text-correction stage.
import json

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Load the ProtonX seq2seq model used to fix accents and grammar in OCR lines.
print("Loading Text Correction model...")
model_path = "protonx-models/protonx-legal-tc"

tokenizer = AutoTokenizer.from_pretrained(model_path)
# nn.Module.to() and .eval() both return the module itself, so they chain.
model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device).eval()
print("Model loaded successfully.")

In [None]:
def correct_text_with_hf(raw_text, max_length=128, max_new_tokens=128, num_beams=10):
    """
    Correct one line of raw OCR text with the ProtonX Legal TC seq2seq model.

    Uses the module-level ``tokenizer``, ``model`` and ``device`` loaded in
    earlier cells.

    Args:
        raw_text: OCR output for a single line. Falsy or whitespace-only input
            returns ``""`` without touching the model; other non-string values
            are converted with ``str()`` (the tokenizer only accepts text).
        max_length: maximum number of input tokens; longer input is truncated.
        max_new_tokens: maximum number of tokens to generate.
        num_beams: beam width used by ``model.generate``.

    Returns:
        The decoded corrected text with special tokens removed.
    """
    if not raw_text:
        return ""
    text = str(raw_text)
    if not text.strip():
        return ""

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    ).to(device)

    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            length_penalty=1.0,
            early_stopping=True,
            repetition_penalty=1.2,
            # NOTE(review): forbids any repeated 2-gram in the output, which may
            # also suppress legitimately repeated word pairs within a line.
            no_repeat_ngram_size=2,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [None]:
# Pull the recognized lines out of the first PaddleOCR page, or fall back to empty lists.
if 'result' not in locals() or not result or result[0] is None:
    print("Variable 'result' is empty or not defined. Please run PaddleOCR first.")
    rec_texts, rec_scores, rec_polys = [], [], []
else:
    ocr_data = result[0]
    # Missing keys default to empty lists so the downstream loop simply does nothing.
    rec_texts, rec_scores, rec_polys = (
        ocr_data.get(key, []) for key in ('rec_texts', 'rec_scores', 'rec_polys')
    )
    print(f"Loaded {len(rec_texts)} detected text lines.")

In [None]:
# Run every OCR line through the correction model, keeping both raw and corrected text.
output_data = []

total = len(rec_texts)
if total:
    for i, raw_text in enumerate(rec_texts):
        # 1. Status Update
        if i % 5 == 0:
            print(f"Processing line {i+1}/{total}...")

        # 2. Run Correction
        corrected = correct_text_with_hf(raw_text)

        # 3. Convert numpy values to built-ins for JSON serialization
        box = rec_polys[i]
        if isinstance(box, np.ndarray):
            box = box.tolist()

        # np.floating covers every numpy float width (float16/32/64, ...)
        score = rec_scores[i]
        if isinstance(score, np.floating):
            score = float(score)

        # 4. Build Dictionary
        output_data.append({
            "id": i + 1,
            "original_text": raw_text,
            "corrected_text": corrected,
            "confidence": score,
            "box": box,
        })

    print("Processing complete.")
else:
    print("No text to process.")

In [None]:
# Persist the original + corrected OCR lines as UTF-8 JSON (non-ASCII kept as-is).
output_filename = 'ocr_result_corrected.json'

try:
    with open(output_filename, mode='w', encoding='utf-8') as json_file:
        json.dump(output_data, json_file, indent=4, ensure_ascii=False)
except Exception as exc:
    print(f"Error saving JSON: {exc}")
else:
    print(f"Successfully saved {len(output_data)} detections to '{output_filename}'")

# Enhanced Vietnamese OCR Pipeline

This alternative pipeline (an alternative to the PaddleOCR section above) combines:
- **CRAFT** for text detection
- **VietOCR** with VGG-Transformer for recognition

In [None]:
# Pinned dependency set for the CRAFT + VietOCR pipeline.
# NOTE(review): numpy and torch are already imported by earlier cells; modules that
# are already loaded keep their old version, so restart the runtime after these
# installs for the pins to take effect.
# NOTE(review): the torch wheel comes from the CUDA 11.7 index — confirm it matches
# the runtime's CUDA driver.
%pip install numpy==1.26.4 --no-cache-dir
%pip install Pillow==9.5.0
%pip install opencv-python==4.7.0.72 opencv-contrib-python==4.7.0.72 --no-cache-dir
%pip install torch==1.13.1 torchvision==0.14.1 --extra-index-url https://download.pytorch.org/whl/cu117
%pip install vietocr==1.3.1
%pip install craft-text-detector==0.4.3

In [None]:
# Imports and model setup for the CRAFT (detection) + VietOCR (recognition) pipeline.
from PIL import Image
import matplotlib.pyplot as plt
import json
import numpy as np
import cv2
import torch
from craft_text_detector import Craft
from vietocr.tool.predictor import Predictor
from vietocr.tool.config import Cfg

# Report the library versions actually loaded in this kernel, one per line.
print(*(mod.__version__ for mod in (np, cv2, torch)), sep="\n")

# Device string shared by CRAFT and VietOCR below.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

print("Loading CRAFT detector...")
craft = Craft(output_dir='./craft_output', crop_type="poly", cuda=(device == 'cuda'))

# VietOCR with the VGG-Transformer config and beam-search decoding enabled.
print("Loading VietOCR recognizer...")
config = Cfg.load_config_from_name('vgg_transformer')
config['device'] = device
config['cnn']['pretrained'] = False
config['predictor']['beamsearch'] = True
recognizer = Predictor(config)

print("✓ Models loaded successfully!")

In [None]:
# correct imports for modern craft-text-detector
from craft_text_detector.image_utils import (
    read_image,
    normalizeMeanVariance,
    resize_aspect_ratio
)
from craft_text_detector.craft_utils import (
    getDetBoxes,
    adjustResultCoordinates
)


def craft_detect_safe(craft, image_path, text_threshold=0.7, link_threshold=0.8, low_text=0.4):
    """
    Run CRAFT text detection on one image by driving the network directly.

    Args:
        craft: an initialized ``craft_text_detector.Craft``; its ``long_size``,
            ``cuda`` and ``craft_net`` attributes are used.
        image_path: path to the image file (loaded with ``read_image``).
        text_threshold, link_threshold, low_text: score-map thresholds passed to
            ``getDetBoxes`` (defaults are the previously hard-coded values).

    Returns:
        dict with
          - "boxes": detected boxes mapped back to original-image coordinates,
          - "polys": one polygon per box, in the same order as "boxes" (the box
            itself is used where ``getDetBoxes`` produced no polygon),
          - "heatmap": the text score map at network resolution.
    """
    image = read_image(image_path)

    # Step 1 — resize with the aspect ratio preserved, bounded by craft.long_size
    img_resized, target_ratio, _size_heatmap = resize_aspect_ratio(
        image,
        craft.long_size,
        interpolation=cv2.INTER_LINEAR
    )
    ratio_h = ratio_w = 1 / target_ratio

    # Step 2 — normalize, then HWC -> 1xCxHxW float tensor
    x = normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0).float()
    if craft.cuda:
        x = x.cuda()

    # Step 3 — forward pass; output is indexed as (batch, H, W, channel) below
    with torch.no_grad():
        y, _ = craft.craft_net(x)

    score_text = y[0, :, :, 0].cpu().numpy()
    score_link = y[0, :, :, 1].cpu().numpy()

    # Step 4 — threshold the score maps into boxes (and polygons, when produced)
    boxes, polys = getDetBoxes(
        score_text,
        score_link,
        text_threshold=text_threshold,
        link_threshold=link_threshold,
        low_text=low_text
    )

    # Step 5 — replace missing polygons with their box. Dropping the Nones would
    # leave `polys` shorter than `boxes` and no longer index-aligned with it.
    polys = [box if poly is None else poly for box, poly in zip(boxes, polys)]

    # Step 6 — map score-map coordinates back to the original image
    boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = adjustResultCoordinates(polys, ratio_w, ratio_h)

    return {
        "boxes": boxes,
        "polys": polys,
        "heatmap": score_text
    }

def group_boxes_to_lines(boxes, y_threshold=10):
    """
    Group word-level boxes into line-level boxes.

    Boxes are sorted by the mean y of their points. A box joins the current line
    when its center lies within ``y_threshold`` of the line's mean center (the
    arithmetic mean over all boxes already in that line); otherwise it starts a
    new line. Each line is then replaced by the axis-aligned rectangle that
    encloses all of its points.

    Args:
        boxes: sequence of boxes, each a (k, 2) array-like of [x, y] points
            (e.g. a list of 4x2 arrays or an (N, 4, 2) array).
        y_threshold: maximum vertical distance between a box center and the
            line's mean center for the box to join that line.

    Returns:
        list of (4, 2) arrays ordered top to bottom, one per line, with corners
        [top-left, top-right, bottom-right, bottom-left].
    """
    if len(boxes) == 0:
        return []

    boxes = [np.asarray(box) for box in boxes]
    centers = np.array([box[:, 1].mean() for box in boxes])
    order = np.argsort(centers, kind="stable")

    lines = []
    current_line = []
    center_sum = 0.0
    for i in order:
        center = centers[i]
        # Compare against the true mean of the line so far. Averaging only the
        # previous estimate with the newest center over-weights the latest box and
        # lets a line creep downward into the next one.
        if current_line and abs(center - center_sum / len(current_line)) > y_threshold:
            lines.append(current_line)
            current_line, center_sum = [], 0.0
        current_line.append(boxes[i])
        center_sum += center
    lines.append(current_line)

    # Merge each line into its enclosing axis-aligned rectangle
    merged_lines = []
    for line in lines:
        all_points = np.vstack(line)
        x_min, x_max = all_points[:, 0].min(), all_points[:, 0].max()
        y_min, y_max = all_points[:, 1].min(), all_points[:, 1].max()
        merged_lines.append(np.array([[x_min, y_min], [x_max, y_min],
                                      [x_max, y_max], [x_min, y_max]]))
    return merged_lines


def vietnamese_ocr_pipeline(image_path, visualize=True, line_y_threshold=10,
                            row_height=30, padding=5, min_region_size=10):
    """
    Detect and recognize Vietnamese text in an image with CRAFT + VietOCR.

    Steps: CRAFT word boxes (``craft_detect_safe``) -> merged into line boxes
    (``group_boxes_to_lines``) -> sorted into reading order -> each padded crop
    recognized by the module-level VietOCR ``recognizer``.

    Args:
        image_path: Path to the image file.
        visualize: Whether to show the detection overlay via ``visualize_results``.
        line_y_threshold: vertical distance (px) used to merge CRAFT boxes into
            one line (default 10).
        row_height: height (px) of the row buckets used for reading order;
            regions whose centers fall in the same bucket are ordered left to
            right (default 30).
        padding: pixels added around each line box before cropping (default 5).
        min_region_size: padded crops narrower or shorter than this many pixels
            are skipped (default 10).

    Returns:
        List of dicts, one per non-empty recognized line, with keys 'id', 'text',
        'box' (list of [x, y] int points), 'bbox' ([x_min, y_min, x_max, y_max])
        and 'confidence'. Empty list when nothing is detected.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    # Step 1: Read image (OpenCV gives BGR) and keep an RGB copy for PIL/plotting
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"Cannot read image: {image_path}")

    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h, w = img.shape[:2]

    # Step 2: Detect word-level boxes with CRAFT, then merge them into line boxes
    prediction_result = craft_detect_safe(craft, image_path)
    regions = group_boxes_to_lines(prediction_result['boxes'], y_threshold=line_y_threshold)

    if len(regions) == 0:
        print("No text detected!")
        return []

    # Step 3: Reading order — bucket by row, then left to right within a row
    def get_region_center(box):
        if isinstance(box, dict):
            box = box['points']
        box = np.array(box)
        center_y = np.mean(box[:, 1])
        center_x = np.mean(box[:, 0])
        return (int(center_y // row_height), center_x)

    regions_sorted = sorted(regions, key=get_region_center)

    # Step 4: Recognize text in each region
    results = []
    img_pil = Image.fromarray(img_rgb)

    for idx, region in enumerate(regions_sorted):
        try:
            points = region['points'] if isinstance(region, dict) else region
            box = np.array(points, dtype=np.int32)

            # Padded bounding rectangle, clamped to the image
            x_min = max(0, np.min(box[:, 0]) - padding)
            x_max = min(w, np.max(box[:, 0]) + padding)
            y_min = max(0, np.min(box[:, 1]) - padding)
            y_max = min(h, np.max(box[:, 1]) + padding)

            # Skip regions too small to recognize reliably
            if (x_max - x_min) < min_region_size or (y_max - y_min) < min_region_size:
                continue

            cropped = img_pil.crop((x_min, y_min, x_max, y_max))
            text = recognizer.predict(cropped)

            if text.strip():  # Only include non-empty results
                results.append({
                    # idx also counts skipped/empty regions, so ids may have gaps.
                    'id': idx + 1,
                    'text': text,
                    'box': box.tolist(),
                    'bbox': [int(x_min), int(y_min), int(x_max), int(y_max)],
                    # Placeholder, not a measured score: the recognizer is not asked
                    # for a probability here, so every segment gets the same value.
                    'confidence': 0.95
                })

            if (idx + 1) % 10 == 0:
                print(f"Processed {idx + 1}/{len(regions_sorted)} regions...")

        except Exception as e:
            # Best-effort: one failing crop should not abort the whole page.
            print(f"Error processing region {idx}: {e}")
            continue

    print(f"\n✓ Successfully extracted {len(results)} text segments")

    # Step 5: Visualization
    if visualize and len(results) > 0:
        visualize_results(img_rgb, results)

    return results


def visualize_results(img_rgb, results):
    """
    Draw OCR detections (box outline + "[id] text" label) on an image and show it.

    Args:
        img_rgb: image array accepted by ``Image.fromarray`` (RGB as produced by
            ``vietnamese_ocr_pipeline``).
        results: list of dicts with at least 'box' (list of [x, y] points),
            'text' and 'id' keys.
    """
    from PIL import ImageDraw, ImageFont

    img_pil = Image.fromarray(img_rgb)
    draw = ImageDraw.Draw(img_pil)

    # Pillow raises OSError when the TrueType font file cannot be found
    # (e.g. no Arial on Linux); fall back to the built-in bitmap font.
    try:
        font = ImageFont.truetype("arial.ttf", 12)
    except OSError:
        font = ImageFont.load_default()

    for result in results:
        box = np.array(result['box'])
        text = result['text']
        idx = result['id']

        # Draw bounding box
        points = [tuple(int(v) for v in p) for p in box]
        draw.polygon(points, outline='#00FF00', width=2)

        # Label above the box, clamped so it is not drawn off the top of the image
        label = f"[{idx}] {text[:30]}..." if len(text) > 30 else f"[{idx}] {text}"
        x = int(np.min(box[:, 0]))
        y = max(0, int(np.min(box[:, 1])) - 20)

        # Background rectangle; older Pillow lacks textbbox (AttributeError) or
        # rejects non-TrueType fonts (ValueError) — then draw the label without it.
        try:
            bbox = draw.textbbox((x, y), label, font=font)
            draw.rectangle(bbox, fill='white', outline='red')
        except (AttributeError, ValueError):
            pass

        draw.text((x, y), label, fill='red', font=font)

    # Display
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.imshow(img_pil)
    ax.axis('off')
    ax.set_title(f'Detected {len(results)} text regions')
    plt.show()

In [None]:
# Run the CRAFT + VietOCR pipeline on the sample image.
import os

img_filename = 'sample.png'
if not os.path.exists(img_filename):
    raise FileNotFoundError("Could not find 'sample.png' locally. Upload it or mount Google Drive and set img_path accordingly.")
img_path = os.path.abspath(img_filename)
ocr_results = vietnamese_ocr_pipeline(img_path, visualize=True)

In [None]:
# Export the CRAFT + VietOCR results as JSON and plain text, then print a short summary.
output_filename = 'vietnamese_ocr_results.json'

if 'ocr_results' in locals() and ocr_results:
    export_data = {
        'image': img_filename,
        'total_detections': len(ocr_results),
        'results': ocr_results
    }

    with open(output_filename, 'w', encoding='utf-8') as f:
        json.dump(export_data, f, ensure_ascii=False, indent=2)

    print(f"✓ Saved {len(ocr_results)} detections to '{output_filename}'")

    # Plain-text version, one segment per line. A 'text_corrected' field is used
    # when present; otherwise the raw 'text'.
    # The loop variable is deliberately not named `result`: that global holds the
    # PaddleOCR output read by earlier cells, and rebinding it here would break
    # re-running them.
    text_filename = 'vietnamese_ocr_text.txt'
    with open(text_filename, 'w', encoding='utf-8') as f:
        for segment in ocr_results:
            f.write(f"{segment.get('text_corrected', segment['text'])}\n")

    print(f"✓ Saved plain text to '{text_filename}'")

    # Display summary (first 10 segments)
    print("\n" + "=" * 60)
    print("OCR SUMMARY")
    print("=" * 60)
    print(f"Total text segments: {len(ocr_results)}")
    print("\nExtracted Text:\n")
    for i, segment in enumerate(ocr_results[:10], 1):
        print(f"{i}. {segment.get('text_corrected', segment['text'])}")

    if len(ocr_results) > 10:
        print(f"\n... and {len(ocr_results) - 10} more segments")
    print("=" * 60)
else:
    print("No OCR results to export. Run the pipeline first.")