# Complete Handwriting OCR + Q&A Segmentation Pipeline
## TrOCR + CRF Integration

**What this does:**
1. 📷 Upload handwritten exam image
2. 🔍 Remove ruled lines (preprocessing)
3. 📝 Detect text lines (segmentation)
4. 🤖 Recognize text with TrOCR (OCR)
5. 🏷️ Tag lines with CRF (Q&A segmentation)
6. 📊 Extract Q&A pairs (structured output)
7. 💾 Download JSON output

**Complete end-to-end solution!**

## 1️⃣ Install Dependencies

In [None]:
!pip install -q transformers torch pillow opencv-python-headless sklearn-crfsuite

import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image, ImageDraw
import cv2
import numpy as np
from scipy.ndimage import gaussian_filter1d
import matplotlib.pyplot as plt
from google.colab import files
import json
import pickle
import re
from typing import List, Dict, Tuple

print("✅ All dependencies installed")

## 2️⃣ Load Models

In [None]:
# Load the pretrained TrOCR handwriting checkpoint. Both objects are notebook
# globals that recognize_line() reads later.
print("Loading TrOCR model...")
# Processor: turns a PIL image into `pixel_values` and decodes generated ids back to text.
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
# Encoder-decoder model whose generate() produces the token ids for one line image.
trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
print("✅ TrOCR model loaded")

# Note: CRF model will be loaded after upload

## 3️⃣ Upload CRF Model

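Upload your trained CRF model file, `qa_segmentation_crf_squad.pkl`. It must be a pickled dictionary with the keys `model`, `labels`, and `training_stats` (which holds `val_f1`).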

In [None]:
print("Upload qa_segmentation_crf_squad.pkl:")
uploaded_model = files.upload()

# Resolve which file to open. Prefer the expected name; if the upload was saved
# under a different name (e.g. a renamed re-upload), use the name that was
# actually uploaded. With no upload at all, fall back to the expected name so a
# file already present in the working directory still loads.
crf_model_path = 'qa_segmentation_crf_squad.pkl'
if uploaded_model and crf_model_path not in uploaded_model:
    crf_model_path = next(iter(uploaded_model))

# Load CRF model
# SECURITY: pickle.load can execute arbitrary code contained in the file.
# Only upload model files that you produced yourself or fully trust.
with open(crf_model_path, 'rb') as f:
    crf_data = pickle.load(f)

crf_model = crf_data['model']
labels = crf_data['labels']

print("\n✅ CRF model loaded")
# Training statistics are informational only; do not fail when they are absent.
val_f1 = crf_data.get('training_stats', {}).get('val_f1')
if val_f1 is not None:
    print(f"   Validation F1: {val_f1:.4f}")
print(f"   Labels: {labels}")

## 4️⃣ Preprocessing Functions

In [None]:
def remove_ruled_lines(img: np.ndarray, line_thickness_range: Tuple[int, int] = (1, 3)) -> np.ndarray:
    """
    Erase horizontal ruled lines from a page image.

    Args:
        img: Page image. A 3-dimensional array is converted from BGR to
            grayscale; any other array is copied and used as-is.
        line_thickness_range: Accepted for API compatibility; the current
            implementation does not read it.

    Returns:
        A grayscale copy of the page in which every pixel belonging to a
        detected horizontal line is set to 255 (white).
    """
    if len(img.shape) == 3:
        grayscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    else:
        grayscale = img.copy()

    # Inverted Otsu threshold: dark pixels become 255 in the mask.
    otsu_flags = cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
    ink_mask = cv2.threshold(grayscale, 0, 255, otsu_flags)[1]

    # Opening with a wide, 1-pixel-tall kernel keeps only long horizontal runs.
    rule_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40, 1))
    rule_mask = cv2.morphologyEx(ink_mask, cv2.MORPH_OPEN, rule_kernel, iterations=2)

    # Paint the detected runs white on a copy of the grayscale page.
    cleaned = grayscale.copy()
    cleaned[rule_mask > 0] = 255
    return cleaned


def segment_lines(img: np.ndarray, remove_lines: bool = True,
                  min_gap: int = 15, min_line_height: int = 15) -> List[Tuple[np.ndarray, Tuple]]:
    """
    Segment a page image into individual text lines using a horizontal
    projection profile.

    Args:
        img: Page image, grayscale (2-D) or BGR (3-D).
        remove_lines: When True, ruled lines are erased with
            remove_ruled_lines() before segmentation.
        min_gap: Minimum number of rows that must lie between the end of the
            last accepted line and the start of the next one.
        min_line_height: A region is kept only if it is taller than this many
            rows.

    Returns:
        [(line_image, (x, y, w, h)), ...] in top-to-bottom order. Each
        line_image is cropped from the denoised grayscale page and the
        bounding box values are plain Python ints. An empty list is returned
        when the page contains no ink.
    """
    from scipy.ndimage import gaussian_filter1d

    # Preprocess. adaptiveThreshold below needs a single-channel image, so a
    # colour page must be converted when ruled-line removal (which also
    # converts to grayscale) is skipped.
    if remove_lines:
        img = remove_ruled_lines(img)
    elif img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    page_height, page_width = img.shape[:2]

    # Denoise
    denoised = cv2.fastNlMeansDenoising(img, None, 10, 7, 21)

    # Binarize with adaptive threshold for better handwriting (ink -> 255)
    binary = cv2.adaptiveThreshold(denoised, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                   cv2.THRESH_BINARY_INV, 15, 10)

    # Dilate slightly to connect broken characters
    binary = cv2.dilate(binary, np.ones((2, 2), np.uint8), iterations=1)

    # Horizontal projection. Accumulate as float so the Gaussian smoothing
    # below is not truncated back to an integer dtype.
    h_projection = np.sum(binary, axis=1, dtype=np.float64)

    # Blank page: nothing to segment (also avoids a median over an empty array).
    inked_rows = h_projection[h_projection > 0]
    if inked_rows.size == 0:
        return []

    # A row counts as text when it exceeds 30% of the median inked-row sum.
    threshold = np.median(inked_rows) * 0.3

    # Smooth projection to reduce noise
    h_projection_smooth = gaussian_filter1d(h_projection, sigma=2)

    # Find line boundaries with minimum gap
    line_regions = []
    in_line = False
    start_y = 0
    # There is no previous line above the page, so text touching the top edge
    # must not be delayed by the min_gap rule.
    last_end = -(min_gap + 1)

    for y, val in enumerate(h_projection_smooth):
        if not in_line:
            # Start a new line only if far enough from the last accepted line
            if val > threshold and y - last_end > min_gap:
                start_y = y
                in_line = True
        elif val <= threshold:
            # Keep the region only if it has significant height
            if y - start_y > min_line_height:
                line_regions.append((start_y, y))
                last_end = y
            in_line = False

    # If still in line at the bottom edge of the page
    if in_line and len(h_projection_smooth) - start_y > min_line_height:
        line_regions.append((start_y, len(h_projection_smooth)))

    # Extract line images with bounding boxes
    lines = []
    for top, bottom in line_regions:
        # Add a vertical margin, clamped to the page
        top = max(0, top - 5)
        bottom = min(page_height, bottom + 5)

        # Columns of the strip that contain any ink give the x bounds
        ink_columns = np.flatnonzero(binary[top:bottom, :].any(axis=0))
        if ink_columns.size == 0:
            continue

        left = max(0, int(ink_columns[0]) - 10)
        right = min(page_width, int(ink_columns[-1]) + 10)

        # Extract line (from denoised grayscale, not binary)
        line_img = denoised[top:bottom, left:right]
        bbox = (int(left), int(top), int(right - left), int(bottom - top))
        lines.append((line_img, bbox))

    return lines

print("✅ Preprocessing functions defined")

## 5️⃣ TrOCR Recognition

In [None]:
def recognize_line(line_img: np.ndarray) -> str:
    """
    Run TrOCR on one cropped line image and return the decoded text.

    A 2-D array is treated as grayscale and expanded to RGB; any other array
    is converted from BGR to RGB. Uses the notebook-level `processor` and
    `trocr_model`.
    """
    is_grayscale = len(line_img.shape) == 2
    if is_grayscale:
        rgb_image = Image.fromarray(line_img).convert('RGB')
    else:
        rgb_image = Image.fromarray(cv2.cvtColor(line_img, cv2.COLOR_BGR2RGB))

    # Encode the image, generate token ids, then decode the single result.
    encoded = processor(rgb_image, return_tensors='pt')
    token_ids = trocr_model.generate(encoded.pixel_values)
    decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0]


print("✅ TrOCR recognition function defined")

## 6️⃣ CRF Feature Extraction

In [None]:
def extract_line_features(texts: List[str], bboxes: List[Tuple], line_idx: int, prev_label: str = 'O') -> Dict:
    """
    Build the 12-entry CRF feature dict for one recognized line.

    Args:
        texts: Recognized text of every line on the page, in order.
        bboxes: (x, y, w, h) bounding box of every line, parallel to `texts`.
        line_idx: Index of the line to describe.
        prev_label: Value stored under the 'prev_label' feature.

    Returns:
        Dict with visual features (indent_level, vertical_gap, x_position),
        marker/pattern flags, textual features and 'prev_label'. All flag
        features are real booleans, including for empty lines.
    """
    text = texts[line_idx].strip()
    x, y, _w, _h = bboxes[line_idx]

    # Indent: x position normalized by a typical margin of 100 px
    indent_level = x / 100.0

    # Vertical gap to the bottom edge of the previous line, in units of 20 px
    vertical_gap = 0.0
    if line_idx > 0:
        _px, prev_top, _pw, prev_height = bboxes[line_idx - 1]
        vertical_gap = (y - (prev_top + prev_height)) / 20.0

    # Slicing instead of indexing yields '' for an empty line, so the str
    # predicates below return False rather than leaking '' into the features.
    first_char = text[:1]

    features = {
        # Visual
        'indent_level': min(indent_level, 3.0),
        'vertical_gap': min(vertical_gap, 5.0),
        'x_position': min(indent_level / 3.0, 1.0),

        # Pattern matching
        'starts_with_q_marker': bool(re.match(r'^Q\d+[:.\s]', text)),
        'starts_with_a_marker': text.startswith(('A:', 'A. ')),
        'starts_with_number': first_char.isdigit(),

        # Textual
        'ends_with_question': text.endswith('?'),
        'has_colon_start': ':' in text[:15],
        'is_capitalized': first_char.isupper(),
        'word_count': min(len(text.split()), 20),
        'line_length': min(len(text), 100),

        # Context
        'prev_label': prev_label,
    }

    return features


def texts_to_crf_features(texts: List[str], bboxes: List[Tuple]) -> List[Dict]:
    """
    Build the CRF input sequence: one feature dict per recognized line.

    Every line is described with extract_line_features(); the 'prev_label'
    feature is the constant 'O' for all lines.
    """
    initial_label = 'O'
    return [
        extract_line_features(texts, bboxes, line_no, initial_label)
        for line_no, _ in enumerate(texts)
    ]


print("✅ CRF feature extraction functions defined")

## 7️⃣ Q&A Pair Extraction

In [None]:
def _append_qa_pair(pairs: List[Dict], question_number: int,
                    question_lines: List[str], answer_lines: List[str]) -> None:
    """Append a pair to `pairs` only when both question and answer have lines."""
    if question_lines and answer_lines:
        pairs.append({
            'question_number': question_number,
            'question': ' '.join(question_lines),
            'answer': ' '.join(answer_lines),
        })


def extract_qa_pairs(texts: List[str], tags: List[str]) -> List[Dict]:
    """
    Group tagged lines into question/answer pairs.

    Args:
        texts: Recognized line texts.
        tags: One tag per line ('B-Q', 'I-Q', 'B-A', 'I-A'; any other tag is
            ignored). Lines are paired with tags via zip().

    Returns:
        List of dicts with 'question_number', 'question' and 'answer'.
        Questions are numbered by the order of their 'B-Q' tags, and a
        question with no answer lines is omitted (its number is still used).
        All answer lines tagged between two 'B-Q' tags are kept, even when
        'B-A' occurs more than once for the same question.
    """
    pairs = []
    current_q = []
    current_a = []
    q_number = 0

    for text, tag in zip(texts, tags):
        line = text.strip()
        if tag == 'B-Q':
            # Close the previous pair before starting a new question
            _append_qa_pair(pairs, q_number, current_q, current_a)
            current_q, current_a = [line], []
            q_number += 1
        elif tag == 'I-Q':
            current_q.append(line)
        elif tag in ('B-A', 'I-A'):
            # A repeated B-A extends the answer instead of discarding the
            # answer text collected so far for this question.
            current_a.append(line)

    # Don't forget the last pair
    _append_qa_pair(pairs, q_number, current_q, current_a)

    return pairs


print("✅ Q&A extraction function defined")

## 8️⃣ Upload & Process Image

In [None]:
print("Upload your handwritten exam image:")
uploaded_img = files.upload()

if uploaded_img:
    # Only the first uploaded file is processed.
    img_filename = next(iter(uploaded_img))
    print(f"\n✅ Uploaded: {img_filename}")

    # Load image as single-channel grayscale
    img = cv2.imread(img_filename, cv2.IMREAD_GRAYSCALE)
    # cv2.imread does not raise on failure; it returns None when the file is
    # missing or cannot be decoded. Fail here with a clear message instead of
    # an AttributeError on `img.shape` below.
    if img is None:
        raise ValueError(
            f"Could not read '{img_filename}' as an image. "
            "Please upload a valid image file (e.g. PNG or JPEG)."
        )
    print(f"   Image size: {img.shape[1]}x{img.shape[0]} pixels")

    # Display original
    plt.figure(figsize=(12, 8))
    plt.imshow(img, cmap='gray')
    plt.title("Original Image")
    plt.axis('off')
    plt.show()
else:
    # Later cells need `img` and `img_filename`; re-run this cell with a file.
    print("No image uploaded")

## 9️⃣ Step 1: Line Segmentation

In [None]:
# Split the uploaded page (`img`) into line crops with their bounding boxes.
print("Segmenting text lines...")
lines = segment_lines(img, remove_lines=True)
print(f"✅ Found {len(lines)} text lines")

# Visualize line detection on a BGR copy of the grayscale page so the
# annotations can be drawn in colour.
img_display = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
for i, (line_img, (x, y, w, h)) in enumerate(lines):
    # Green box (BGR) around the line, red 1-based label just above it.
    cv2.rectangle(img_display, (x, y), (x+w, y+h), (0, 255, 0), 2)
    cv2.putText(img_display, f'Line {i+1}', (x, y-5), 
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)

# matplotlib expects RGB, so convert from OpenCV's BGR before showing.
plt.figure(figsize=(14, 10))
plt.imshow(cv2.cvtColor(img_display, cv2.COLOR_BGR2RGB))
plt.title(f"Detected {len(lines)} Lines")
plt.axis('off')
plt.show()

## 🔟 Step 2: OCR with TrOCR

In [None]:
# Run TrOCR on every segmented line. `recognized_texts` and `bboxes` are
# parallel lists (same order as `lines`) consumed by the CRF feature step.
print("Recognizing text with TrOCR...\n")
recognized_texts = []
bboxes = []

for i, (line_img, bbox) in enumerate(lines):
    # Progress is printed on one row: counter first, recognized text after.
    print(f"Processing line {i+1}/{len(lines)}...", end=' ')
    text = recognize_line(line_img)
    recognized_texts.append(text)
    bboxes.append(bbox)
    print(f"✓ Text: {text}")

print(f"\n✅ OCR complete! Recognized {len(recognized_texts)} lines")

## 1️⃣1️⃣ Step 3: CRF Q&A Segmentation

In [None]:
# Build one feature dict per recognized line from its text and bounding box.
print("Extracting features for CRF...")
crf_features = texts_to_crf_features(recognized_texts, bboxes)
print(f"✅ Extracted features for {len(crf_features)} lines")

# The whole page is passed as a single sequence: wrap it in a list for
# predict() and take the first (only) result.
print("\nPredicting Q&A tags with CRF...")
predicted_tags = crf_model.predict([crf_features])[0]
print(f"✅ Predicted tags")

# Display predictions, one line per row with the tag padded to 5 characters
print("\n" + "="*80)
print("TAGGED OUTPUT")
print("="*80)
for text, tag in zip(recognized_texts, predicted_tags):
    print(f"[{tag:5s}] {text}")
print("="*80)

## 1️⃣2️⃣ Step 4: Extract Q&A Pairs

In [None]:
# Group the tagged lines into question/answer dicts.
print("Extracting Q&A pairs...")
qa_pairs = extract_qa_pairs(recognized_texts, predicted_tags)
print(f"✅ Found {len(qa_pairs)} question-answer pairs\n")

# Display pairs: question number and text, then the answer, separated by rules
print("="*80)
print("QUESTION-ANSWER PAIRS")
print("="*80)
for pair in qa_pairs:
    print(f"\nQ{pair['question_number']}: {pair['question']}")
    print(f"A: {pair['answer']}")
    print("-" * 80)
print("="*80)

## 1️⃣3️⃣ Save & Download Results

In [None]:
# Create output JSON: summary counts, the extracted pairs, the full text
# joined with spaces, and every line together with its predicted tag.
output = {
    'input_image': img_filename,
    'num_lines': len(recognized_texts),
    'num_qa_pairs': len(qa_pairs),
    'qa_pairs': qa_pairs,
    'raw_text': ' '.join(recognized_texts),
    'tagged_lines': [
        {'text': text, 'tag': tag}
        for text, tag in zip(recognized_texts, predicted_tags)
    ]
}

# Save JSON to the working directory (overwrites any previous result file)
output_filename = 'qa_extraction_results.json'
with open(output_filename, 'w') as f:
    json.dump(output, f, indent=2)

print(f"\n✅ Results saved to: {output_filename}")
print(f"\n📊 Summary:")
print(f"   Lines detected: {output['num_lines']}")
print(f"   Q&A pairs found: {output['num_qa_pairs']}")

# Download: hand the saved file to the Colab download helper
files.download(output_filename)
print(f"\n✅ Downloaded: {output_filename}")

## ✅ Complete Pipeline Success!

**What was accomplished:**
1. ✅ Preprocessed image (removed ruled lines)
2. ✅ Segmented text lines
3. ✅ Recognized handwriting with TrOCR
4. ✅ Tagged lines with CRF (Q vs A)
5. ✅ Extracted structured Q&A pairs
6. ✅ Exported JSON output

**Output includes:**
- `qa_pairs`: Structured question-answer pairs
- `raw_text`: Full recognized text
- `tagged_lines`: Line-by-line with BIO tags

**Ready for production use!** 🎉