# Arabic Handwritten Text Recognition
## Optimized for WER < 0.6 and CER < 0.7

In [1]:
!pip install transformers



In [2]:
!pip install torchvision datasets --quiet
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.optim import AdamW  # PyTorch's built-in AdamW
from torchvision.models import resnet18
from torch import nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import numpy as np

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.5/207.5 MB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.1/21.1 MB[0m [31m80.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━

2025-05-12 10:18:48.503970: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747045128.708408      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747045128.765786      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [4]:
import os
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

class HandwritingDataset(Dataset):
    def __init__(self, images_dir, labels_dir, tokenizer, transform=None):
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.tokenizer = tokenizer
        
        # Transform par défaut pour ViT si non fourni
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5]
            )
        ])

        # Lister les fichiers
        self.image_files = sorted([
            f for f in os.listdir(images_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        ])
        self.label_files = sorted([
            f for f in os.listdir(labels_dir)
            if f.lower().endswith('.txt')
        ])

        # Validation stricte des noms (sans extension)
        image_stems = [os.path.splitext(f)[0] for f in self.image_files]
        label_stems = [os.path.splitext(f)[0] for f in self.label_files]

        assert image_stems == label_stems, \
            f"Les noms d'images et de labels ne correspondent pas:\n{image_stems[:5]} vs {label_stems[:5]}"

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Charger l'image
        img_path = os.path.join(self.images_dir, self.image_files[idx])
        image = Image.open(img_path).convert('RGB')
        image = self.transform(image)

        # Charger le label
        lbl_path = os.path.join(self.labels_dir, self.label_files[idx])
        with open(lbl_path, 'r', encoding='windows-1256') as f:
            text = f.read().strip()

        # Tokenization
        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            padding='max_length',
            max_length=128,
            truncation=True
        )

        return {
            'pixel_values': image,
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'labels': inputs['input_ids'].squeeze(0),
            'raw_text': text
        }


In [6]:
from torchvision import transforms
from transformers import GPT2Tokenizer

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5, 0.5, 0.5],  # ViT expects normalized input
        std=[0.5, 0.5, 0.5]
    )
])

tokenizer = GPT2Tokenizer.from_pretrained("aubmindlab/aragpt2-base")
tokenizer.pad_token = tokenizer.eos_token

dataset = HandwritingDataset(
    images_dir="/kaggle/input/khatt-arabic-hand-written-lines/images",
    labels_dir="/kaggle/input/khatt-arabic-hand-written-lines/labels",
    tokenizer=tokenizer,
    transform=transform
)


vocab.json:   0%|          | 0.00/1.94M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.50M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/4.52M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/843 [00:00<?, ?B/s]

In [7]:
import torch.nn as nn
from transformers import ViTModel, GPT2LMHeadModel

class HandwritingGPT2(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Charger ViT
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        
        # Geler le ViT
        for param in self.vit.parameters():
            param.requires_grad = False
        
        # Adapter la sortie de ViT à GPT2 (ViT hidden_size = 768 → GPT2 = 768, donc optionnel ici)
        self.linear = nn.Linear(self.vit.config.hidden_size, 768)
        
        # Charger GPT2
        self.gpt2 = GPT2LMHeadModel.from_pretrained("aubmindlab/aragpt2-base")
        
    def forward(self, pixel_values, input_ids=None, attention_mask=None, labels=None):
        # Extraire les features de l'image avec ViT
        outputs = self.vit(pixel_values=pixel_values)
        features = outputs.pooler_output  # (batch_size, hidden_size)

        # Adapter les dimensions à GPT2
        features = self.linear(features)  # (batch_size, 768)

        # Étendre à (batch_size, seq_len, 768)
        if input_ids is not None:
            seq_len = input_ids.shape[1]
            features = features.unsqueeze(1).expand(-1, seq_len, -1)

        return self.gpt2(
            inputs_embeds=features,
            attention_mask=attention_mask,
            labels=labels
        )


In [9]:
from torch.utils.data import DataLoader
from torch.optim import AdamW

# Tokenizer
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model
model = HandwritingGPT2().to(device)

# Optimizer: ne pas inclure les paramètres gelés
optimizer = AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=5e-5
)

# Dataloader
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)


config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/553M [00:00<?, ?B/s]

In [10]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset

# Séparation train/val
indices = list(range(len(dataset)))
train_idx, val_idx = train_test_split(indices, test_size=0.1, random_state=42)

train_dataset = Subset(dataset, train_idx)
val_dataset = Subset(dataset, val_idx)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8)


In [12]:
from torchmetrics.text import CharErrorRate, WordErrorRate
from tqdm import tqdm
from torch.optim import AdamW
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split

# --- Device and model setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HandwritingGPT2().to(device)

# --- Optimizer (ignoring frozen ViT) ---
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-5)

# --- Tokenizer config ---
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# --- Metrics ---
cer_metric = CharErrorRate().to(device)
wer_metric = WordErrorRate().to(device)

# --- Train/Val split ---
indices = list(range(len(dataset)))
train_idx, val_idx = train_test_split(indices, test_size=0.1, random_state=42)
train_dataset = Subset(dataset, train_idx)
val_dataset = Subset(dataset, val_idx)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8)

# --- Training Loop ---
num_epochs = 15
for epoch in range(num_epochs):
    print(f"\n=== Epoch {epoch+1}/{num_epochs} ===")

    # ---------- TRAIN ----------
    model.train()
    train_loss = 0
    train_cers, train_wers = [], []

    for batch in tqdm(train_loader, desc="Training"):
        pixel_values = batch["pixel_values"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        raw_text = batch["raw_text"]

        optimizer.zero_grad()
        outputs = model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # Prédictions
        with torch.no_grad():
            vit_features = model.vit(pixel_values=pixel_values).pooler_output
            features = model.linear(vit_features).unsqueeze(1).expand(-1, input_ids.shape[1], -1)
            generated = model.gpt2.generate(
                inputs_embeds=features,
                max_new_tokens=128,  # corrigé ici
                num_beams=5,
                early_stopping=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
            preds = [tokenizer.decode(g, skip_special_tokens=True) for g in generated]

            # Metrics
            batch_cer = cer_metric(preds, raw_text)
            batch_wer = wer_metric(preds, raw_text)
            train_cers.append(batch_cer)
            train_wers.append(batch_wer)

    avg_train_loss = train_loss / len(train_loader)
    avg_train_cer = torch.stack(train_cers).mean().item()
    avg_train_wer = torch.stack(train_wers).mean().item()
    print(f"\n[Train] Loss: {avg_train_loss:.4f} | CER: {avg_train_cer:.4f} | WER: {avg_train_wer:.4f}")

    # ---------- VALIDATION ----------
    model.eval()
    val_loss = 0
    val_cers, val_wers = [], []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            raw_text = batch["raw_text"]

            outputs = model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids
            )
            val_loss += outputs.loss.item()

            vit_features = model.vit(pixel_values=pixel_values).pooler_output
            features = model.linear(vit_features).unsqueeze(1).expand(-1, input_ids.shape[1], -1)
            generated = model.gpt2.generate(
                inputs_embeds=features,
                max_new_tokens=128,  # corrigé ici 
                num_beams=5,
                early_stopping=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
            preds = [tokenizer.decode(g, skip_special_tokens=True) for g in generated]

            val_cers.append(cer_metric(preds, raw_text))
            val_wers.append(wer_metric(preds, raw_text))

    avg_val_loss = val_loss / len(val_loader)
    avg_val_cer = torch.stack(val_cers).mean().item()
    avg_val_wer = torch.stack(val_wers).mean().item()
    print(f"[Validation] Loss: {avg_val_loss:.4f} | CER: {avg_val_cer:.4f} | WER: {avg_val_wer:.4f}")



=== Epoch 1/15 ===


Training: 100%|██████████| 1280/1280 [14:50<00:00,  1.44it/s]



[Train] Loss: 1.1532 | CER: 0.9952 | WER: 0.9987


Validation: 100%|██████████| 143/143 [01:20<00:00,  1.79it/s]


[Validation] Loss: 1.6906 | CER: 0.9726 | WER: 0.9827

=== Epoch 2/15 ===


Training: 100%|██████████| 1280/1280 [12:29<00:00,  1.71it/s]



[Train] Loss: 0.9476 | CER: 0.9862 | WER: 0.9951


Validation: 100%|██████████| 143/143 [01:04<00:00,  2.20it/s]


[Validation] Loss: 1.1186 | CER: 0.9648 | WER: 0.9926

=== Epoch 3/15 ===


Training: 100%|██████████| 1280/1280 [12:28<00:00,  1.71it/s]



[Train] Loss: 0.9145 | CER: 0.9870 | WER: 0.9941


Validation: 100%|██████████| 143/143 [01:05<00:00,  2.19it/s]


[Validation] Loss: 1.0316 | CER: 0.9775 | WER: 0.9896

=== Epoch 4/15 ===


Training: 100%|██████████| 1280/1280 [12:36<00:00,  1.69it/s]



[Train] Loss: 0.8830 | CER: 0.9876 | WER: 0.9918


Validation: 100%|██████████| 143/143 [01:06<00:00,  2.16it/s]


[Validation] Loss: 0.9915 | CER: 0.9613 | WER: 0.9881

=== Epoch 5/15 ===


Training: 100%|██████████| 1280/1280 [12:35<00:00,  1.69it/s]



[Train] Loss: 0.8599 | CER: 0.9896 | WER: 0.9926


Validation: 100%|██████████| 143/143 [01:05<00:00,  2.20it/s]


[Validation] Loss: 0.9480 | CER: 0.9702 | WER: 0.9814

=== Epoch 6/15 ===


Training: 100%|██████████| 1280/1280 [12:36<00:00,  1.69it/s]



[Train] Loss: 0.8446 | CER: 0.9939 | WER: 0.9950


Validation: 100%|██████████| 143/143 [01:04<00:00,  2.22it/s]


[Validation] Loss: 0.9066 | CER: 0.9905 | WER: 0.9948

=== Epoch 7/15 ===


Training: 100%|██████████| 1280/1280 [12:31<00:00,  1.70it/s]



[Train] Loss: 0.8288 | CER: 0.9955 | WER: 0.9965


Validation: 100%|██████████| 143/143 [01:06<00:00,  2.16it/s]


[Validation] Loss: 0.9230 | CER: 0.9899 | WER: 0.9951

=== Epoch 8/15 ===


Training: 100%|██████████| 1280/1280 [12:31<00:00,  1.70it/s]



[Train] Loss: 0.8115 | CER: 0.9974 | WER: 0.9980


Validation: 100%|██████████| 143/143 [01:06<00:00,  2.14it/s]


[Validation] Loss: 0.8980 | CER: 0.9959 | WER: 0.9946

=== Epoch 9/15 ===


Training: 100%|██████████| 1280/1280 [12:30<00:00,  1.71it/s]



[Train] Loss: 0.7972 | CER: 0.9984 | WER: 0.9986


Validation: 100%|██████████| 143/143 [01:04<00:00,  2.21it/s]


[Validation] Loss: 0.8899 | CER: 0.9985 | WER: 0.9980

=== Epoch 10/15 ===


Training: 100%|██████████| 1280/1280 [12:27<00:00,  1.71it/s]



[Train] Loss: 0.7809 | CER: 0.9991 | WER: 0.9992


Validation: 100%|██████████| 143/143 [01:05<00:00,  2.19it/s]


[Validation] Loss: 0.8852 | CER: 0.9952 | WER: 0.9957

=== Epoch 11/15 ===


Training: 100%|██████████| 1280/1280 [12:27<00:00,  1.71it/s]



[Train] Loss: 0.7651 | CER: 0.9990 | WER: 0.9992


Validation: 100%|██████████| 143/143 [01:05<00:00,  2.17it/s]


[Validation] Loss: 0.8886 | CER: 0.9969 | WER: 0.9982

=== Epoch 12/15 ===


Training: 100%|██████████| 1280/1280 [12:28<00:00,  1.71it/s]



[Train] Loss: 0.7476 | CER: 0.9987 | WER: 0.9992


Validation: 100%|██████████| 143/143 [01:04<00:00,  2.21it/s]


[Validation] Loss: 0.8913 | CER: 0.9991 | WER: 0.9991

=== Epoch 13/15 ===


Training: 100%|██████████| 1280/1280 [12:27<00:00,  1.71it/s]



[Train] Loss: 0.7277 | CER: 0.9993 | WER: 0.9995


Validation: 100%|██████████| 143/143 [01:04<00:00,  2.20it/s]


[Validation] Loss: 0.8790 | CER: 0.9997 | WER: 0.9997

=== Epoch 14/15 ===


Training: 100%|██████████| 1280/1280 [12:25<00:00,  1.72it/s]



[Train] Loss: 0.7065 | CER: 0.9991 | WER: 0.9995


Validation: 100%|██████████| 143/143 [01:04<00:00,  2.22it/s]


[Validation] Loss: 0.8957 | CER: 0.9973 | WER: 0.9976

=== Epoch 15/15 ===


Training: 100%|██████████| 1280/1280 [12:16<00:00,  1.74it/s]



[Train] Loss: 0.6820 | CER: 0.9994 | WER: 0.9998


Validation: 100%|██████████| 143/143 [01:03<00:00,  2.24it/s]

[Validation] Loss: 0.9114 | CER: 0.9966 | WER: 0.9977





In [13]:
# Sauvegarder uniquement les poids du modèle
torch.save(model.state_dict(), "handwriting_gpt2_vit.pt")
#from transformers import GPT2Tokenizer
#tokenizer = GPT2Tokenizer.from_pretrained("aragpt2-base")


In [15]:
from torchvision import transforms
from PIL import Image

def predict(image_path, model, tokenizer):
    model.eval()

    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # ✅ adapté pour ViT
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5])
    ])

    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)  # (1, 3, 224, 224)

    with torch.no_grad():
        vit_features = model.vit(pixel_values=image).pooler_output  # (1, 768)
        features = model.linear(vit_features).unsqueeze(1)  # (1, 1, 768)

        generated = model.gpt2.generate(
            inputs_embeds=features,
            max_new_tokens=128,  # ✅ éviter ValueError
            num_beams=5,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    return tokenizer.decode(generated[0], skip_special_tokens=True)


In [None]:
test_img = "/kaggle/input/dataset1/test2.png"
print("Predicted:", predict(test_img, model, tokenizer))
