# Arabic Handwritten OCR Pipeline

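This notebook demonstrates an end-to-end pipeline for Arabic handwritten-text OCR using a CNN-BLSTM-CTC model trained on the KHATT dataset. It builds the combined label CSV, previews preprocessed samples, trains the model, saves the weights and vocabulary, and reports greedy-decoding CER/WER on a validation loader.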

In [17]:
# Optional Colab setup: mount Google Drive so that data/checkpoints stored
# there are reachable from this notebook.
# from google.colab import drive
# drive.mount('/content/drive')

# Install requirements (uncomment if needed). Prefer `%pip` over `!pip`:
# the magic installs into the running kernel's environment.
# NOTE(review): the other cells use '../project/...' paths relative to the
# notebook directory; this path may need the same '../' prefix.
# %pip install -r project/requirements.txt

In [None]:
import sys
import os

# Make the directory one level above this notebook importable so that
# `project.*` resolves. The membership check keeps the cell idempotent:
# re-running it no longer appends a duplicate '..' entry to sys.path.
if '..' not in sys.path:
    sys.path.append('..')
from project.data.loaders import build_master_csv

# Build the master CSV for all datasets (here, just KHATT)
build_master_csv(
    config='../project/data/config.yaml',
    output_csv='../project/data/combined_labels.csv'
)

# Sanity check: does the KHATT labels file exist, and where does the
# notebook-relative path actually resolve to?
khatt_labels_path = '../project/data/khatt/labels.csv'
print(os.path.exists(khatt_labels_path))
print(os.path.abspath(khatt_labels_path))


Master CSV saved to ../project/data/combined_labels.csv with 1633 samples.
True
c:\Users\riadh\Desktop\ocr\project\data\khatt\labels.csv



In [None]:

from project.data.loaders import get_dataloader
import matplotlib.pyplot as plt

# Number of samples to preview; drives both the batch size and the subplot grid.
N_PREVIEW = 4

# Get a batch of preprocessed images (train=False, presumably no augmentation)
dataloader = get_dataloader(['khatt'], batch_size=N_PREVIEW, shuffle=False, train=False, master_csv='../project/data/combined_labels.csv')
images, texts = next(iter(dataloader))

# The batch can hold fewer than N_PREVIEW samples if the dataset is that small;
# never index past what was actually returned.
n_show = min(N_PREVIEW, len(texts))
# squeeze=False keeps `axes` 2-D even when n_show == 1, so iteration is uniform.
fig, axes = plt.subplots(1, n_show, figsize=(4 * n_show, 4), squeeze=False)
for i, ax in enumerate(axes[0]):
    # assumes images are (B, C, H, W) and channel 0 is the grayscale plane — TODO confirm
    ax.imshow(images[i][0].numpy(), cmap='gray')
    # NOTE(review): matplotlib does not apply Arabic shaping/bidi reordering,
    # so titles may render as disconnected, left-to-right glyphs.
    ax.set_title(texts[i])
    ax.axis('off')
plt.show()

In [None]:
import random

import numpy as np
import torch

from project.training.train_ctc import train_ctc

# Seed the global Python, NumPy and PyTorch RNGs so repeated runs are
# comparable. train_ctc's internals are not visible here; this presumably
# covers the generators it draws from for shuffling and initialisation.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Train the model (adjust epochs and batch_size as needed).
# master_csv uses the same notebook-relative path that the build step writes
# and the preview cell reads ('../project/data/combined_labels.csv'), so
# training consumes exactly the CSV that was just generated.
model, vocab = train_ctc(
    epochs=10,
    batch_size=16,
    lr=1e-3,
    dataset_names=['khatt'],
    master_csv='../project/data/combined_labels.csv',
    img_height=32,
    img_width=128
)

In [None]:
# Persist the trained weights (state_dict only) and the vocabulary object so
# the model can be rebuilt for inference without retraining.
# Only ever unpickle vocab.pkl from a trusted source: pickle.load can execute
# arbitrary code.
import pickle
import torch

MODEL_PATH = 'cnn_blstm_ctc_khatt.pth'
VOCAB_PATH = 'vocab.pkl'

weights = model.state_dict()
torch.save(weights, MODEL_PATH)

with open(VOCAB_PATH, 'wb') as vocab_file:
    pickle.dump(vocab, vocab_file)

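## Next Steps

- Evaluate the model on held-out validation/test data (a first greedy-decoding CER/WER evaluation follows below).
- Visualize predictions alongside their input images.
- Integrate the model with the interactive app.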

In [None]:
from project.data.loaders import get_dataloader
from project.evaluation.metrics import cer, wer
import torch

BLANK_IDX = 0   # assumed CTC blank index; must match the one used in training — TODO confirm
N_EXAMPLES = 5  # number of predictions to print


def _ctc_greedy_collapse(seq, blank=BLANK_IDX):
    """Turn a greedy (argmax) CTC path into a label sequence.

    Consecutive repeated indices are merged into one, and blank indices are
    dropped. A blank between two equal labels keeps them separate, because
    `prev` is updated on every step, including blanks.

    Args:
        seq: iterable of int class indices, one per time step.
        blank: index of the CTC blank symbol.

    Returns:
        list[int]: the collapsed label indices.
    """
    collapsed = []
    prev = None
    for idx in seq:
        if idx != prev and idx != blank:
            collapsed.append(idx)
        prev = idx
    return collapsed


# Set model to evaluation mode (switches layers such as dropout/batch-norm to inference behaviour)
model.eval()
device = next(model.parameters()).device  # loop-invariant: look it up once

# Get a validation dataloader (train=False, so no augmentation). Reads the same
# master CSV as the preview cell.
# NOTE(review): no split argument is visible here. Confirm this is held-out
# data, not the same samples the model was trained on.
val_loader = get_dataloader(['khatt'], batch_size=16, shuffle=False, train=False, master_csv='../project/data/combined_labels.csv')

total_cer, total_wer, n = 0, 0, 0
all_preds, all_gts = [], []

with torch.no_grad():
    for images, texts in val_loader:
        images = images.to(device)
        logits = model(images)  # assumed (T, B, C); the permute below relies on it
        # Greedy decoding: best class per time step, reordered to (B, T)
        pred_indices = logits.argmax(-1).permute(1, 0)
        for pred_seq, gt_text in zip(pred_indices, texts):
            # .tolist() yields plain Python ints rather than numpy scalars
            pred = _ctc_greedy_collapse(pred_seq.cpu().tolist())
            pred_text = vocab.decode(pred)
            all_preds.append(pred_text)
            all_gts.append(gt_text)
            # NOTE(review): confirm cer/wer take (prediction, reference) in this
            # order. CER/WER are conventionally normalised by reference length,
            # so swapping the arguments changes the score.
            total_cer += cer(pred_text, gt_text)
            total_wer += wer(pred_text, gt_text)
            n += 1

if n == 0:
    print("Validation set is empty; no CER/WER computed.")
else:
    # Mean of per-sample scores (not corpus-level CER/WER)
    print(f"Validation CER: {total_cer/n:.3f}")
    print(f"Validation WER: {total_wer/n:.3f}")

# Show a few predictions (fewer if the validation set is small)
for gt_text, pred_text in zip(all_gts[:N_EXAMPLES], all_preds[:N_EXAMPLES]):
    print(f"GT: {gt_text}")
    print(f"PR: {pred_text}")
    print('-'*30)