# Arabic Handwritten Text Recognition
## Optimized for WER < 0.6 and CER < 0.7

In [None]:
!pip install transformers

In [None]:
!pip install torchvision datasets --quiet
import os

import numpy as np
import torch
from PIL import Image
from torch import nn
from torch.optim import AdamW  # PyTorch's built-in AdamW
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet18
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [None]:
class HandwritingDataset(Dataset):
    """Paired handwriting-image / transcription dataset.

    Every image in ``images_dir`` (``.png``, ``.jpg`` or ``.jpeg``) must have a
    ``.txt`` file with the same stem in ``labels_dir``. Label files are decoded
    as windows-1256.

    Args:
        images_dir: Directory containing the image files.
        labels_dir: Directory containing the ``.txt`` transcriptions.
        tokenizer: Callable tokenizer; it is invoked with ``return_tensors='pt'``,
            ``padding='max_length'`` and ``truncation=True`` and must return a
            mapping with ``input_ids`` and ``attention_mask``.
        transform: Optional callable applied to the RGB PIL image. When ``None``
            the image is resized to 128x128 and converted to a tensor.
        max_length: Length every label is padded/truncated to (default 128).

    Raises:
        AssertionError: If the number of images and labels differ, or an
            image/label pair does not share the same file stem. The exception
            type is unchanged from the earlier ``assert``-based checks, but it
            is now raised explicitly so the validation survives ``python -O``.
    """

    def __init__(self, images_dir, labels_dir, tokenizer, transform=None,
                 max_length=128):
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.tokenizer = tokenizer
        self.max_length = max_length

        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((128, 128)),
                transforms.ToTensor(),
            ])
        self.transform = transform

        # Sort by stem (file name as tie-breaker) instead of by full file name:
        # ordering on the full name lets the extension take part in the
        # comparison, so e.g. "a.png"/"a.q.png" and "a.txt"/"a.q.txt" end up in
        # different orders and valid pairs are reported as mismatched.
        def stem_key(name):
            return os.path.splitext(name)[0], name

        self.image_files = sorted(
            (f for f in os.listdir(images_dir)
             if f.lower().endswith(('.png', '.jpg', '.jpeg'))),
            key=stem_key,
        )
        self.label_files = sorted(
            (f for f in os.listdir(labels_dir) if f.lower().endswith('.txt')),
            key=stem_key,
        )

        # Verify 1:1 correspondence
        if len(self.image_files) != len(self.label_files):
            raise AssertionError("Image/label count mismatch")
        for img, lbl in zip(self.image_files, self.label_files):
            if os.path.splitext(img)[0] != os.path.splitext(lbl)[0]:
                raise AssertionError(f"Mismatched pair: {img} vs {lbl}")

    def __len__(self):
        """Return the number of image/label pairs."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return the transformed image and tokenized label of sample ``idx``.

        Returns:
            dict with ``pixel_values`` (transformed image), ``input_ids`` and
            ``attention_mask`` (tokenizer output with the leading batch axis
            squeezed away) and ``raw_text`` (the stripped label string).
        """
        # Load image; the context manager closes the file handle right away
        # instead of leaving it to the garbage collector.
        img_path = os.path.join(self.images_dir, self.image_files[idx])
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # Load label with windows-1256 encoding
        lbl_path = os.path.join(self.labels_dir, self.label_files[idx])
        with open(lbl_path, 'r', encoding='windows-1256') as f:
            text = f.read().strip()

        image = self.transform(image)

        # Tokenize with attention mask
        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            padding='max_length',
            max_length=self.max_length,
            truncation=True
        )

        return {
            'pixel_values': image,
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'raw_text': text  # Store original text
        }

In [None]:
# Load the AraGPT2 tokenizer and reuse its EOS token as the padding token;
# the dataset tokenizes with padding='max_length', which needs a pad token.
tokenizer = GPT2Tokenizer.from_pretrained("aubmindlab/aragpt2-base")
tokenizer.pad_token = tokenizer.eos_token

# Image files and their windows-1256 transcriptions live in sibling
# Kaggle input directories.
dataset = HandwritingDataset(
    tokenizer=tokenizer,
    labels_dir="/kaggle/input/labels",
    images_dir="/kaggle/input/images",
)

In [None]:
class HandwritingGPT2(nn.Module):
    """ResNet-18 image encoder that conditions AraGPT2 through a visual prefix.

    The CNN turns each image into one vector of the decoder's hidden size. That
    vector is placed in front of the token embeddings, so the language model is
    trained with ordinary teacher forcing: the image position predicts the
    first token and every token position predicts the next one. This is the
    same layout ``self.gpt2.generate(inputs_embeds=<image prefix>)`` uses at
    inference time.
    """

    def __init__(self):
        super().__init__()
        # NOTE: `pretrained=` is deprecated in torchvision >= 0.13 in favour of
        # `weights=`; it is kept so older torchvision versions still work.
        self.cnn = resnet18(pretrained=True)
        self.gpt2 = GPT2LMHeadModel.from_pretrained("aubmindlab/aragpt2-base")
        # Project CNN features to the decoder's embedding width instead of
        # hard-coding 512 -> 768.
        self.cnn.fc = nn.Linear(self.cnn.fc.in_features, self.gpt2.config.n_embd)

    def forward(self, pixel_values, input_ids=None, attention_mask=None, labels=None):
        """Run the encoder and the prefix-conditioned language model.

        Args:
            pixel_values: Image batch accepted by the ResNet encoder.
            input_ids: Optional token ids, shape (batch, seq_len). When given,
                their embeddings follow the image prefix (teacher forcing).
            attention_mask: Optional mask for ``input_ids``; it is extended by
                one always-attended position for the image prefix.
            labels: Optional targets aligned with ``input_ids``; the prefix
                position is given the ignore index -100.

        Returns:
            The ``GPT2LMHeadModel`` output. With ``input_ids`` the logits cover
            ``seq_len + 1`` positions (image prefix + tokens); without them,
            only the prefix position.

        Raises:
            ValueError: If ``labels`` are passed without ``input_ids``.
        """
        # Image feature as a one-step prefix: (batch_size, 1, hidden)
        prefix = self.cnn(pixel_values).unsqueeze(1)

        if input_ids is None:
            if labels is not None:
                raise ValueError("labels require input_ids")
            return self.gpt2(inputs_embeds=prefix)

        batch_size = prefix.size(0)
        token_embeds = self.gpt2.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat([prefix, token_embeds], dim=1)

        if attention_mask is not None:
            prefix_mask = attention_mask.new_ones((batch_size, 1))
            attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)

        if labels is not None:
            # -100 is the ignore index of the LM loss; after the model's
            # internal one-step shift the prefix position predicts labels[:, 0].
            ignored = labels.new_full((batch_size, 1), -100)
            labels = torch.cat([ignored, labels], dim=1)

        return self.gpt2(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels
        )

In [None]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the model and move its parameters onto the selected device.
model = HandwritingGPT2()
model.to(device)

# AdamW over all parameters; batches of 8 samples, reshuffled every epoch.
optimizer = AdamW(params=model.parameters(), lr=5e-5)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=8)

In [None]:
from torchmetrics.text import CharErrorRate, WordErrorRate
from tqdm import tqdm  # For progress bars


def _build_lm_labels(input_ids, attention_mask):
    """Return LM targets with padding ignored by the loss.

    Padded positions are set to -100 (the ignore index of the LM loss) so the
    loss is not dominated by padding. The first padded position of each row is
    kept as a target: the pad token is the EOS token, so this is what teaches
    the model where a transcription ends.

    Assumes right-padding (padding after the real tokens).
    """
    labels = input_ids.masked_fill(attention_mask == 0, -100)
    first_pad = attention_mask.sum(dim=1)  # index of the first padded position
    rows = torch.nonzero(first_pad < input_ids.size(1), as_tuple=True)[0]
    cols = first_pad[rows]
    labels[rows, cols] = input_ids[rows, cols]
    return labels


# `device`, `model`, `optimizer` and `dataloader` are created by the setup cell
# that precedes this one. They are deliberately not rebuilt here: doing so
# loaded both pretrained networks a second time.
print(device)

# Metrics
cer = CharErrorRate().to(device)
wer = WordErrorRate().to(device)
for epoch in range(1):
    # Training phase
    model.train()
    total_loss = 0
    train_cer, train_wer = [], []

    for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch}")):
        pixel_values = batch['pixel_values'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = _build_lm_labels(input_ids, attention_mask)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        # Backward pass
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Generate predictions for metrics. Switch to eval mode so dropout is
        # off and BatchNorm neither uses batch statistics nor updates its
        # running averages a second time for the same batch.
        # NOTE: 5-beam search on every training batch is slow; consider
        # evaluating on a subset of batches if throughput matters.
        model.eval()
        with torch.no_grad():
            features = model.cnn(pixel_values).unsqueeze(1)
            generated = model.gpt2.generate(
                inputs_embeds=features,
                max_length=128,
                num_beams=5,
                early_stopping=True,
                pad_token_id=tokenizer.pad_token_id
            )
        model.train()

        pred_texts = [tokenizer.decode(g, skip_special_tokens=True) for g in generated]

        # Compute metrics; keep plain floats so no device tensors pile up.
        train_cer.append(cer(pred_texts, batch['raw_text']).item())
        train_wer.append(wer(pred_texts, batch['raw_text']).item())

        # Logging
        if batch_idx % 20 == 0:
            avg_loss = total_loss / (batch_idx + 1)
            avg_cer = sum(train_cer) / len(train_cer)
            avg_wer = sum(train_wer) / len(train_wer)
            print(f"Batch {batch_idx}: Loss={avg_loss:.4f}, CER={avg_cer:.4f}, WER={avg_wer:.4f}")

    # Epoch summary
    avg_loss = total_loss / len(dataloader)
    epoch_cer = sum(train_cer) / len(train_cer)
    epoch_wer = sum(train_wer) / len(train_wer)
    print(f"\nEpoch {epoch} Results:")
    print(f"Avg Loss: {avg_loss:.4f}")
    print(f"Training CER: {epoch_cer:.4f}")
    print(f"Training WER: {epoch_wer:.4f}\n")

In [None]:
# 7. PREDICTION FUNCTION (FINAL)
def predict(image_path, model, tokenizer):
    """Transcribe the handwriting in a single image file.

    Args:
        image_path: Path of the image to read.
        model: Model exposing ``cnn`` (image encoder) and ``gpt2`` (language
            model with ``generate``), e.g. ``HandwritingGPT2``.
        tokenizer: Tokenizer used to decode the generated ids.

    Returns:
        The decoded text with special tokens removed.
    """
    # Same preprocessing as the dataset's default transform.
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ])

    # Follow the model's own device instead of relying on a global `device`.
    model_device = next(model.parameters()).device

    with Image.open(image_path) as img:
        image = transform(img.convert('RGB')).unsqueeze(0).to(model_device)

    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Inference must run in eval mode (dropout off, BatchNorm using running
    # statistics); the previous mode is restored afterwards so calling this
    # between training steps does not change the training state.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            features = model.cnn(image).unsqueeze(1)
            generated = model.gpt2.generate(
                inputs_embeds=features,
                max_length=128,
                num_beams=5,
                early_stopping=True,
                pad_token_id=pad_token_id
            )
    finally:
        model.train(was_training)

    return tokenizer.decode(generated[0], skip_special_tokens=True)

# TEST PREDICTION
# Smoke-test the prediction helper on one image from the input directory.
test_img = "/kaggle/input/images/AHTD3A0001_Para1_4.jpg"
prediction = predict(test_img, model, tokenizer)
print("Predicted:", prediction)