In [3]:
!pip install gradio
!



In [4]:
"""
OCR Model Testing User Interface
================================
A Gradio-based UI for testing different OCR architectures trained on Arabic text.

Supported Architectures:
1. CNN-BiLSTM-CTC (Baseline)
2. Conformer-CTC
3. TrOCR-SMALL
4. ViT + OCR

Usage:
    python user_interface.py
"""

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
from pathlib import Path
import math

# ====================== Device Setup ======================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# ====================== Default Charset (Arabic characters + common symbols) ======================
# Fallback charset used when a checkpoint carries neither 'charset' nor 'char2idx'.
# It should match the one used during training: index 0 is the CTC blank.
#
# The Arabic characters are generated from their Unicode code points instead of
# being typed literally: an earlier copy of this file was re-encoded as Mac Roman,
# which turned every Arabic letter into a two-character mojibake pair and silently
# corrupted the vocabulary.
_DEFAULT_TEXT_CHARS = (
    "".join(chr(cp) for cp in range(0x0621, 0x063B))    # U+0621 hamza .. U+063A ghain
    + "".join(chr(cp) for cp in range(0x0640, 0x0653))  # U+0640 tatweel .. U+064A yeh, then harakat U+064B..U+0652
    + " 0123456789.,"
    + "\u061F!\u061B:()-"                                # Arabic question mark, '!', Arabic semicolon, ':', '(', ')', '-'
)
DEFAULT_CHARSET = ['<BLANK>'] + list(_DEFAULT_TEXT_CHARS)

# Special tokens for seq2seq models (TrOCR, ViT+OCR)
PAD_TOKEN = '<PAD>'
BOS_TOKEN = '<BOS>'
SOS_TOKEN = '<SOS>'
EOS_TOKEN = '<EOS>'
UNK_TOKEN = '<UNK>'

# ====================== Image Processing Constants ======================
H_TARGET = 64           # target image height for every model
W_MAX = 512             # maximum width for the variable-width CTC inputs
W_TARGET = 384          # fixed canvas width for TrOCR and ViT+OCR
MEAN, STD = 0.5, 0.5    # pixel normalization: (x / 255 - MEAN) / STD

# ====================== Image Preprocessing Functions ======================
def preprocess_image_ctc(img_array, h=H_TARGET, w_max=W_MAX):
    """Turn an image into a normalized (1, 1, h, W) tensor for the CTC models.

    The image is converted to grayscale, resized to height ``h`` keeping its
    aspect ratio (width clamped to the range [16, w_max]), right-padded with
    white until the width is a multiple of 4, and normalized with MEAN/STD.

    Returns:
        A float32 tensor, or None when ``img_array`` is None.
    """
    if img_array is None:
        return None

    # 3-D arrays are treated as RGB; anything else is assumed to be grayscale already.
    gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY) if len(img_array.shape) == 3 else img_array

    src_h, src_w = gray.shape[:2]
    ratio = h / float(src_h)
    target_w = max(min(int(np.ceil(src_w * ratio)), w_max), 16)
    method = cv2.INTER_CUBIC if ratio >= 1.0 else cv2.INTER_AREA
    gray = cv2.resize(gray, (target_w, h), interpolation=method)

    # White right-padding so the width is divisible by 4.
    extra = -target_w % 4
    if extra:
        gray = cv2.copyMakeBorder(gray, 0, 0, 0, extra, cv2.BORDER_CONSTANT, value=255)

    normalized = (gray.astype(np.float32) / 255.0 - MEAN) / STD
    return torch.from_numpy(normalized).unsqueeze(0).unsqueeze(0)

def preprocess_image_trocr(img_array, h=H_TARGET, w=W_TARGET):
    """Preprocess image for TrOCR (3-channel, fixed size).

    The grayscale image is fitted into an ``h`` x ``w`` white canvas while keeping
    its aspect ratio. It is either scaled to height ``h`` and right-padded, or, if
    that would exceed ``w``, scaled to width ``w`` and padded top/bottom. The gray
    plane is then replicated to 3 channels and normalized with MEAN/STD.

    Returns:
        A float32 tensor of shape (1, 3, h, w), or None if ``img_array`` is None.
    """
    if img_array is None:
        return None

    # Convert to grayscale if needed
    if len(img_array.shape) == 3:
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
    else:
        gray = img_array

    h0, w0 = gray.shape[:2]
    scale = h / h0
    new_w = int(w0 * scale)

    if new_w > w:
        # Too wide for the canvas: fit the width, then pad top/bottom with white.
        scale = w / w0
        # Never ask cv2.resize for a 0-pixel dimension (extremely wide, thin strips).
        new_h = max(1, int(h0 * scale))
        new_w = w
        gray = cv2.resize(gray, (new_w, new_h), interpolation=cv2.INTER_AREA if scale < 1 else cv2.INTER_CUBIC)
        pad_top = (h - new_h) // 2
        pad_bottom = h - new_h - pad_top
        gray = cv2.copyMakeBorder(gray, pad_top, pad_bottom, 0, 0, cv2.BORDER_CONSTANT, value=255)
    else:
        # Fits: scale to height h, then pad on the right with white.
        # Clamp to >= 1 px so very tall, narrow inputs do not request a 0-width resize.
        new_w = max(1, new_w)
        gray = cv2.resize(gray, (new_w, h), interpolation=cv2.INTER_AREA if scale < 1 else cv2.INTER_CUBIC)
        pad_right = w - new_w
        gray = cv2.copyMakeBorder(gray, 0, 0, 0, pad_right, cv2.BORDER_CONSTANT, value=255)

    # Replicate gray into 3 channels and normalize like the other preprocessors.
    img = np.stack([gray, gray, gray], axis=-1)
    img = img.astype(np.float32) / 255.0
    img = (img - MEAN) / STD
    img_t = torch.from_numpy(img).permute(2, 0, 1)[None, ...]  # (1, 3, H, W)

    return img_t

def preprocess_image_vit(img_array, h=H_TARGET, w=W_TARGET):
    """Preprocess image for ViT+OCR (1-channel, fixed size).

    The grayscale image is scaled to height ``h`` (its width clamped to the range
    [1, w]), right-padded with white up to width ``w``, and normalized with MEAN/STD.

    Returns:
        A float32 tensor of shape (1, 1, h, w), or None if ``img_array`` is None.
    """
    if img_array is None:
        return None

    # Convert to grayscale if needed
    if len(img_array.shape) == 3:
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
    else:
        gray = img_array

    h0, w0 = gray.shape[:2]
    scale = h / h0
    # Clamp to [1, w]: a 0-pixel width (very tall, narrow input) makes cv2.resize fail.
    new_w = max(1, min(int(w0 * scale), w))

    gray = cv2.resize(gray, (new_w, h), interpolation=cv2.INTER_AREA if scale < 1 else cv2.INTER_CUBIC)

    if new_w < w:
        pad_w = w - new_w
        gray = cv2.copyMakeBorder(gray, 0, 0, 0, pad_w, cv2.BORDER_CONSTANT, value=255)

    # Normalize
    img = gray.astype(np.float32) / 255.0
    img = (img - MEAN) / STD
    img_t = torch.from_numpy(img)[None, None, ...]  # (1, 1, H, W)

    return img_t


# ====================== Model Definitions ======================

# ----- 1. CNN-BiLSTM-CTC (Baseline) -----
from torchvision import models

class ResNet34OCRBackbone(nn.Module):
    """ResNet34 backbone adapted for OCR (preserves width dimension).

    Every spatial downsampling step (conv1, maxpool, and the first block of
    layer2-layer4) uses stride (2, 1), so height shrinks by 32x while width is kept.
    ``forward`` returns ``(seq, fmap)``:
    - ``seq`` is the height-pooled map permuted to (W, B, 512).
    - ``fmap`` is the raw layer4 output.

    Args:
        in_ch: Number of input channels. conv1 is initialized from the channel-mean
            of the original conv1 filters.
        pretrained: Start from torchvision's ImageNet weights (may trigger a
            download). Pass False when a full checkpoint is loaded right afterwards.
    """
    def __init__(self, in_ch=1, pretrained=True):
        super().__init__()
        try:
            weights = models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
            m = models.resnet34(weights=weights)
        except (AttributeError, TypeError):
            # torchvision < 0.13 has neither ResNet34_Weights nor the `weights=` keyword.
            m = models.resnet34(pretrained=pretrained)

        # Collapse the RGB stem filters to `in_ch` input channels by averaging them.
        with torch.no_grad():
            w = m.conv1.weight.data
            new_w = w.mean(dim=1, keepdim=True)
        m.conv1 = nn.Conv2d(in_ch, 64, kernel_size=7, stride=(2, 1), padding=3, bias=False)
        with torch.no_grad():
            m.conv1.weight.copy_(new_w)  # broadcasts across channels when in_ch > 1

        # Downsample height only, so the width stays usable as the sequence axis.
        m.maxpool = nn.MaxPool2d(kernel_size=3, stride=(2, 1), padding=1)

        def set_layer_height_only(layer):
            # The first block of each stage carries the stride (and its shortcut conv).
            b0 = layer[0]
            b0.conv1.stride = (2, 1)
            if b0.downsample is not None and isinstance(b0.downsample, nn.Sequential):
                ds0 = b0.downsample[0]
                if isinstance(ds0, nn.Conv2d):
                    ds0.stride = (2, 1)

        set_layer_height_only(m.layer2)
        set_layer_height_only(m.layer3)
        set_layer_height_only(m.layer4)

        self.stem = nn.Sequential(m.conv1, m.bn1, m.relu, m.maxpool)
        self.layer1 = m.layer1
        self.layer2 = m.layer2
        self.layer3 = m.layer3
        self.layer4 = m.layer4
        self.out = nn.AdaptiveAvgPool2d((1, None))

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        fmap = x
        # (B, C, 1, W) -> (B, C, W) -> (W, B, C): sequence-first for the temporal layers.
        seq = self.out(fmap).squeeze(2).permute(2, 0, 1)
        return seq, fmap

class TemporalGLUBlock(nn.Module):
    """Gated 1-D temporal convolution with a 1x1 residual projection.

    Input and output are sequence-first, shaped (T, B, c). The length T is
    preserved for odd kernel sizes ``k``.
    """
    def __init__(self, c=512, k=5, dilation=1, p=0.1):
        super().__init__()
        same_pad = dilation * ((k - 1) // 2)
        # The conv emits 2*c channels; GLU halves them back to c.
        self.conv = nn.Conv1d(c, 2 * c, kernel_size=k, padding=same_pad, dilation=dilation)
        self.glu = nn.GLU(dim=1)
        self.drop = nn.Dropout(p)
        self.res = nn.Conv1d(c, c, kernel_size=1)

    def forward(self, x):
        channels_first = x.permute(1, 2, 0)  # (T, B, C) -> (B, C, T) for Conv1d
        gated = self.drop(self.glu(self.conv(channels_first)))
        return (gated + self.res(channels_first)).permute(2, 0, 1)

class BiLSTMStack(nn.Module):
    """Stack of bidirectional LSTMs, each followed by dropout.

    Sequence-first (``batch_first=False``); the output feature size is ``2 * hidden``.
    Modules are stored alternately as [LSTM, Dropout, LSTM, Dropout, ...] in
    ``self.layers``; checkpoint keys depend on this layout.
    """
    def __init__(self, in_dim, hidden=256, num_layers=2, dropout=0.25):
        super().__init__()
        in_sizes = [in_dim if i == 0 else 2 * hidden for i in range(num_layers)]
        modules = []
        for size in in_sizes:
            modules.append(nn.LSTM(size, hidden, bidirectional=True, batch_first=False))
            modules.append(nn.Dropout(dropout))
        self.layers = nn.ModuleList(modules)

    def forward(self, x):
        for rnn, drop in zip(self.layers[::2], self.layers[1::2]):
            x = drop(rnn(x)[0])  # keep the output sequence, discard (h, c)
        return x

class OCR_MultiTask_V3(nn.Module):
    """Baseline CNN-BiLSTM-CTC Model

    The pipeline is ResNet34 backbone -> two temporal GLU blocks (dilation 1 and 2)
    -> 2-layer BiLSTM -> main CTC head. An auxiliary CTC head reads the temporal-block
    output, and a style classifier reads the detached backbone feature map.

    ``forward`` returns ``(ctc_logits_main, ctc_logits_aux, style_logits)``.
    """
    def __init__(self, num_chars, num_styles, blank_idx=0, aux_weight=0.3):
        super().__init__()
        style_hidden = max(256, num_styles * 4)
        self.backbone = ResNet34OCRBackbone(in_ch=1)
        self.tcn1 = TemporalGLUBlock(512, k=5, dilation=1, p=0.1)
        self.tcn2 = TemporalGLUBlock(512, k=5, dilation=2, p=0.1)
        self.lstm = BiLSTMStack(in_dim=512, hidden=256, num_layers=2, dropout=0.25)
        self.ctc_head_main = nn.Linear(512, num_chars)
        self.ctc_head_aux = nn.Linear(512, num_chars)
        self.style_head = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Dropout(0.25),
            nn.Linear(512, style_hidden),
            nn.ReLU(True),
            nn.Linear(style_hidden, num_styles),
        )
        self.blank_idx = blank_idx
        self.aux_weight = aux_weight  # stored only; forward() does not use it

    def forward(self, x):
        seq_feats, fmap = self.backbone(x)
        temporal = self.tcn2(self.tcn1(seq_feats))
        main_logits = self.ctc_head_main(self.lstm(temporal))
        aux_logits = self.ctc_head_aux(temporal)
        # detach(): the style loss does not backpropagate into the backbone.
        style_logits = self.style_head(fmap.detach())
        return main_logits, aux_logits, style_logits


# ----- 2. Conformer-CTC -----
class ConvSubsampling(nn.Module):
    """Four conv-BN-ReLU stages that halve only the height, then pool the height away.

    Maps (B, in_channels, H, W) to a sequence-first (W, B, out_channels) tensor.
    """
    def __init__(self, in_channels=1, out_channels=256):
        super().__init__()
        widths = [in_channels, 32, 64, 128, out_channels]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=(2, 1), padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)
        self.pool = nn.AdaptiveAvgPool2d((1, None))

    def forward(self, x):
        pooled = self.pool(self.conv(x))  # (B, C, 1, W)
        return pooled.squeeze(2).permute(2, 0, 1)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for sequence-first inputs.

    Adds ``pe[t]`` to position ``t`` of an input shaped (T, B, d_model), then
    applies dropout. Even feature indices hold sin terms and odd indices hold cos.

    Args:
        d_model: Feature size. Odd sizes are supported; the last sin column then
            has no cos partner.
        max_len: Longest sequence the table covers.
        dropout: Dropout probability applied after the addition.
    """
    def __init__(self, d_model, max_len=2048, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cos column than sin column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(1)  # (max_len, 1, d_model): broadcasts over the batch dim
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class ConformerConvModule(nn.Module):
    """Conformer convolution module with a residual connection.

    The path is LayerNorm -> pointwise conv (2x channels) -> GLU -> depthwise conv
    -> BatchNorm -> SiLU -> pointwise conv -> dropout, and the result is added to
    the input. Input and output are sequence-first (T, B, d_model).
    """
    def __init__(self, d_model, kernel_size=31, dropout=0.1):
        super().__init__()
        half_kernel = (kernel_size - 1) // 2
        self.layer_norm = nn.LayerNorm(d_model)
        self.pointwise_conv1 = nn.Conv1d(d_model, 2 * d_model, kernel_size=1)
        self.glu = nn.GLU(dim=1)
        self.depthwise_conv = nn.Conv1d(
            d_model, d_model, kernel_size=kernel_size, padding=half_kernel, groups=d_model
        )
        self.batch_norm = nn.BatchNorm1d(d_model)
        self.swish = nn.SiLU()
        self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        y = self.layer_norm(x).permute(1, 2, 0)  # (B, C, T) for Conv1d
        y = self.glu(self.pointwise_conv1(y))
        y = self.swish(self.batch_norm(self.depthwise_conv(y)))
        y = self.dropout(self.pointwise_conv2(y))
        return y.permute(2, 0, 1) + x

class ConformerFeedForward(nn.Module):
    """Pre-norm feed-forward module added to its input with a 0.5 weight.

    LayerNorm -> Linear(d_model, d_ff) -> SiLU -> dropout -> Linear(d_ff, d_model)
    -> dropout, then ``0.5 * out + x``.
    """
    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, d_ff)
        self.swish = nn.SiLU()
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.linear1(self.layer_norm(x))
        hidden = self.dropout1(self.swish(hidden))
        hidden = self.dropout2(self.linear2(hidden))
        return 0.5 * hidden + x

class ConformerMultiHeadAttention(nn.Module):
    """Pre-norm multi-head self-attention with dropout and a residual connection.

    Inputs are sequence-first (T, B, d_model); ``mask`` is forwarded to
    ``nn.MultiheadAttention`` as ``key_padding_mask``.
    """
    def __init__(self, d_model, num_heads=8, dropout=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        self.mha = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        normed = self.layer_norm(x)
        attended = self.mha(normed, normed, normed, key_padding_mask=mask)[0]
        return self.dropout(attended) + x

class ConformerBlock(nn.Module):
    """One Conformer block: half-step FF -> self-attention -> conv -> half-step FF -> LayerNorm.

    Every sub-module except the final LayerNorm adds its own residual.
    Input and output are sequence-first (T, B, d_model).
    """
    def __init__(self, d_model=256, d_ff=1024, num_heads=4, kernel_size=31, dropout=0.1):
        super().__init__()
        self.ff1 = ConformerFeedForward(d_model, d_ff, dropout)
        self.mha = ConformerMultiHeadAttention(d_model, num_heads, dropout)
        self.conv = ConformerConvModule(d_model, kernel_size, dropout)
        self.ff2 = ConformerFeedForward(d_model, d_ff, dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        out = self.mha(self.ff1(x), mask)
        for stage in (self.conv, self.ff2, self.layer_norm):
            out = stage(out)
        return out

class ConformerEncoder(nn.Module):
    """A stack of ``num_layers`` identical-config ConformerBlocks applied in order."""
    def __init__(self, d_model=256, d_ff=1024, num_heads=4, num_layers=6, kernel_size=31, dropout=0.1):
        super().__init__()
        blocks = [
            ConformerBlock(d_model, d_ff, num_heads, kernel_size, dropout)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, mask=None):
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return hidden

class ConformerCTC(nn.Module):
    """Conformer-CTC Model for Arabic OCR

    The pipeline is conv subsampling (height pooled away, width kept) -> sinusoidal
    positions -> Conformer blocks -> per-frame CTC logits. A style classifier reads
    the time-averaged encoder output.
    ``forward`` returns ``(ctc_logits, style_logits)``; ctc_logits is (T, B, num_chars).
    """
    def __init__(self, num_chars, num_styles, d_model=256, d_ff=1024, num_heads=4,
                 num_layers=8, kernel_size=31, dropout=0.1, blank_idx=0):
        super().__init__()
        self.d_model = d_model
        self.blank_idx = blank_idx
        self.conv_subsample = ConvSubsampling(in_channels=1, out_channels=d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len=2048, dropout=dropout)
        self.conformer = ConformerEncoder(
            d_model=d_model, d_ff=d_ff, num_heads=num_heads,
            num_layers=num_layers, kernel_size=kernel_size, dropout=dropout,
        )
        self.ctc_head = nn.Linear(d_model, num_chars)
        self.style_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(True),
            nn.Dropout(0.25),
            nn.Linear(d_model, num_styles),
        )

    def forward(self, x, mask=None):
        frames = self.conv_subsample(x)
        encoded = self.conformer(self.pos_encoding(frames), mask)
        ctc_logits = self.ctc_head(encoded)
        # Average over time (dim 0) for the sequence-level style prediction.
        return ctc_logits, self.style_head(encoded.mean(dim=0))


# ----- 3. TrOCR-SMALL -----
# timm is optional: only the TrOCR-SMALL encoder needs it.
try:
    import timm
except ImportError:
    TIMM_AVAILABLE = False
    print("Warning: timm not available. TrOCR-SMALL will not work.")
else:
    TIMM_AVAILABLE = True

class TrOCRDecoder(nn.Module):
    """Autoregressive Transformer decoder used by TrOCR-Small (batch-first).

    Token embeddings plus a fixed sinusoidal position table feed a causal
    ``nn.TransformerDecoder`` that cross-attends to the encoder output.
    A linear layer then projects to ``vocab_size`` logits.
    """
    def __init__(self, vocab_size, d_model=384, nhead=6, num_layers=6,
                 dim_feedforward=1536, dropout=0.1, max_seq_len=128, pad_idx=0):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)

        # Sinusoidal table stored as a buffer of shape (1, max_seq_len, d_model).
        positions = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        table = torch.zeros(max_seq_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        self.register_buffer('pe', table.unsqueeze(0))
        self.pos_drop = nn.Dropout(dropout)

        layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout, batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(layer, num_layers=num_layers)
        self.output_proj = nn.Linear(d_model, vocab_size)

    def generate_square_subsequent_mask(self, sz, device):
        """Boolean causal mask: True above the diagonal marks future positions to block."""
        return torch.ones(sz, sz, device=device).triu(diagonal=1).bool()

    def forward(self, encoder_output, tgt, tgt_key_padding_mask=None):
        length = tgt.size(1)
        emb = self.pos_drop(self.token_embedding(tgt) + self.pe[:, :length])
        causal = self.generate_square_subsequent_mask(length, tgt.device)
        hidden = self.transformer_decoder(emb, encoder_output, tgt_mask=causal,
                                          tgt_key_padding_mask=tgt_key_padding_mask)
        return self.output_proj(hidden)

class TrOCRSmall(nn.Module):
    """TrOCR-Small: Vision Transformer encoder + Transformer decoder

    The encoder is timm's ``deit_small_patch16_224`` built for ``img_size`` with
    3 input channels and no pretrained weights. Its tokens are projected to
    ``d_model`` if the sizes differ and decoded by ``TrOCRDecoder``. The style
    head reads encoder token 0.
    """
    def __init__(self, vocab_size, num_styles, d_model=384, nhead=6,
                 num_decoder_layers=6, dim_feedforward=1536, dropout=0.1,
                 max_seq_len=128, img_size=(64, 384), pad_idx=0, bos_idx=1, eos_idx=2):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.bos_idx = bos_idx
        self.eos_idx = eos_idx

        if TIMM_AVAILABLE:
            self.encoder = timm.create_model('deit_small_patch16_224', pretrained=False,
                                              img_size=img_size, in_chans=3, num_classes=0)
            encoder_dim = self.encoder.embed_dim
        else:
            raise ImportError("timm is required for TrOCR-SMALL")

        if encoder_dim != d_model:
            self.encoder_proj = nn.Linear(encoder_dim, d_model)
        else:
            self.encoder_proj = nn.Identity()

        self.decoder = TrOCRDecoder(vocab_size=vocab_size, d_model=d_model, nhead=nhead,
                                     num_layers=num_decoder_layers, dim_feedforward=dim_feedforward,
                                     dropout=dropout, max_seq_len=max_seq_len, pad_idx=pad_idx)

        self.style_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(True),
            nn.Dropout(0.25),
            nn.Linear(d_model // 2, num_styles)
        )

    def encode(self, x):
        """Return encoder tokens projected to ``d_model`` (presumably (B, tokens, d_model))."""
        features = self.encoder.forward_features(x)
        features = self.encoder_proj(features)
        return features

    def forward(self, x, tgt, tgt_key_padding_mask=None):
        encoder_output = self.encode(x)
        logits = self.decoder(encoder_output, tgt, tgt_key_padding_mask)
        cls_token = encoder_output[:, 0]
        style_logits = self.style_head(cls_token)
        return logits, style_logits

    @torch.no_grad()
    def generate(self, x, max_len=None, temperature=1.0):
        """Greedy-decode token ids for a batch of images.

        Args:
            x: Image batch for the encoder (3 channels, ``img_size``).
            max_len: Maximum output length including the BOS token. It defaults to
                ``max_seq_len`` and is capped there, because the decoder's positional
                table has no entries beyond that length.
            temperature: Logit divisor. For positive values it does not change the
                greedy argmax.

        Returns:
            A LongTensor (B, L) starting with ``bos_idx``. Decoding stops early once
            every sequence has emitted ``eos_idx`` at some step.
        """
        max_len = self.max_seq_len if max_len is None else min(max_len, self.max_seq_len)
        batch_size = x.size(0)
        device = x.device
        encoder_output = self.encode(x)
        generated = torch.full((batch_size, 1), self.bos_idx, dtype=torch.long, device=device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_len - 1):
            logits = self.decoder(encoder_output, generated)
            next_token_logits = logits[:, -1, :] / temperature
            next_token = next_token_logits.argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=1)
            # Track EOS per sequence so a batch stops even if sequences end on different steps.
            finished |= next_token.squeeze(1) == self.eos_idx
            if finished.all():
                break
        return generated


# ----- 4. ViT + OCR -----
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping square patches and project each to ``embed_dim``.

    Maps (B, in_chans, img_h, img_w) to (B, n_patches, embed_dim), with patches
    in row-major order.
    """
    def __init__(self, img_h=64, img_w=384, patch_size=16, in_chans=1, embed_dim=384):
        super().__init__()
        self.n_patches_h, self.n_patches_w = img_h // patch_size, img_w // patch_size
        self.n_patches = self.n_patches_h * self.n_patches_w
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        grid = self.proj(x)  # (B, E, img_h/p, img_w/p)
        return grid.flatten(start_dim=2).transpose(1, 2)

class ViTEncoder(nn.Module):
    """Plain ViT encoder: patch embedding, learned positions, then a pre-norm Transformer.

    Maps (B, in_chans, img_h, img_w) to (B, n_patches, embed_dim). The position
    table has a fixed number of patches, so inputs are expected at ``img_h`` x ``img_w``.
    """
    def __init__(self, img_h=64, img_w=384, patch_size=16, in_chans=1,
                 embed_dim=384, depth=6, num_heads=6, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_h, img_w, patch_size, in_chans, embed_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.n_patches, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        block = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(block, num_layers=depth)
        self.norm = nn.LayerNorm(embed_dim)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        tokens = self.pos_drop(self.patch_embed(x) + self.pos_embed)
        return self.norm(self.encoder(tokens))

class TransformerDecoderVit(nn.Module):
    """Pre-norm Transformer decoder with learned positions, used by ViT + OCR.

    Batch-first: ``tgt`` is (B, T) token ids and ``memory`` is (B, N, embed_dim).
    The output is (B, T, vocab_size) logits. A causal mask is built when
    ``tgt_mask`` is None.
    """
    def __init__(self, vocab_size, embed_dim=384, depth=4, num_heads=6,
                 mlp_ratio=4.0, dropout=0.1, max_seq_len=64, pad_idx=0):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.token_embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.pos_embed = nn.Parameter(torch.zeros(1, max_seq_len, embed_dim))
        self.embed_drop = nn.Dropout(dropout)

        block = nn.TransformerDecoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            dropout=dropout, activation='gelu',
            batch_first=True, norm_first=True,
        )
        self.decoder = nn.TransformerDecoder(block, num_layers=depth)
        self.norm = nn.LayerNorm(embed_dim)
        self.output_proj = nn.Linear(embed_dim, vocab_size)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, tgt, memory, tgt_mask=None, tgt_key_padding_mask=None):
        _, steps = tgt.shape  # unpacking also rejects a non-2-D tgt
        hidden = self.embed_drop(self.token_embed(tgt) + self.pos_embed[:, :steps, :])
        if tgt_mask is None:
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(steps, device=tgt.device)
        hidden = self.decoder(hidden, memory, tgt_mask=tgt_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask)
        return self.output_proj(self.norm(hidden))

class VisionTransformerOCR(nn.Module):
    """ViT + Transformer Decoder for OCR

    A ``ViTEncoder`` over 1-channel images feeds a causal ``TransformerDecoderVit``.
    A style classifier reads the mean of the encoder tokens.
    """
    def __init__(self, vocab_size, num_styles, img_h=64, img_w=384, patch_size=16,
                 embed_dim=384, enc_depth=6, dec_depth=4, num_heads=6,
                 mlp_ratio=4.0, dropout=0.1, max_seq_len=64,
                 pad_idx=0, sos_idx=1, eos_idx=2):
        super().__init__()
        self.encoder = ViTEncoder(img_h=img_h, img_w=img_w, patch_size=patch_size, in_chans=1,
                                   embed_dim=embed_dim, depth=enc_depth, num_heads=num_heads,
                                   mlp_ratio=mlp_ratio, dropout=dropout)
        self.decoder = TransformerDecoderVit(vocab_size=vocab_size, embed_dim=embed_dim, depth=dec_depth,
                                              num_heads=num_heads, mlp_ratio=mlp_ratio, dropout=dropout,
                                              max_seq_len=max_seq_len, pad_idx=pad_idx)
        self.style_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, num_styles)
        )
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.sos_idx = sos_idx
        self.eos_idx = eos_idx

    def forward(self, imgs, tgt_ids=None, tgt_padding_mask=None):
        """Return ``(logits, style_logits)`` when ``tgt_ids`` is given, else ``(memory, style_logits)``."""
        memory = self.encoder(imgs)
        style_logits = self.style_head(memory.mean(dim=1))
        if tgt_ids is not None:
            logits = self.decoder(tgt_ids, memory, tgt_key_padding_mask=tgt_padding_mask)
            return logits, style_logits
        else:
            return memory, style_logits

    @torch.no_grad()
    def generate(self, imgs, max_len=None, temperature=1.0):
        """Greedy-decode token ids for a batch of images.

        Args:
            imgs: (B, 1, img_h, img_w) image batch.
            max_len: Maximum output length including the SOS token. It defaults to
                ``max_seq_len`` and is capped there, because the decoder's learned
                position table has only that many entries.
            temperature: Logit divisor. For positive values it does not change the
                greedy argmax.

        Returns:
            A LongTensor (B, L) starting with ``sos_idx``. Decoding stops early once
            every sequence has emitted ``eos_idx`` at some step.
        """
        max_len = self.max_seq_len if max_len is None else min(max_len, self.max_seq_len)
        B = imgs.size(0)
        device = imgs.device
        memory = self.encoder(imgs)
        generated = torch.full((B, 1), self.sos_idx, dtype=torch.long, device=device)
        finished = torch.zeros(B, dtype=torch.bool, device=device)

        for _ in range(max_len - 1):
            logits = self.decoder(generated, memory)
            next_logits = logits[:, -1, :] / temperature
            next_token = next_logits.argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=1)
            # Track EOS per sequence so a batch stops even if sequences end on different steps.
            finished |= next_token.squeeze(1) == self.eos_idx
            if finished.all():
                break
        return generated


# ====================== Decode Functions ======================
def decode_ctc(indices, idx2char, blank_idx=0, rtl_reverse=True):
    """Greedy CTC collapse: merge repeated indices, drop blanks, map to characters.

    Indices missing from ``idx2char`` are skipped. A blank between two equal
    indices keeps both occurrences. When ``rtl_reverse`` is true, the decoded
    string is returned reversed.
    """
    seq = list(indices)
    prevs = [None] + seq[:-1]
    chars = [
        idx2char[cur]
        for before, cur in zip(prevs, seq)
        if cur != blank_idx and cur != before and cur in idx2char
    ]
    text = "".join(chars)
    if rtl_reverse:
        text = text[::-1]
    return text

def decode_seq2seq(indices, idx2char, pad_idx=0, bos_idx=1, eos_idx=2):
    """Map generated token ids to text, stopping at the first EOS.

    PAD and BOS ids, and ids missing from ``idx2char``, are skipped.
    """
    skipped = (pad_idx, bos_idx)
    out = []
    for token in indices:
        if token == eos_idx:
            break
        if token in skipped or token not in idx2char:
            continue
        out.append(idx2char[token])
    return "".join(out)


# ====================== Model Loading and Inference ======================
class ModelWrapper:
    """Wrapper class to handle model loading and inference for all architectures.

    Holds the currently loaded model plus the character mapping needed to decode
    its outputs. ``load_model`` replaces this state only after a load has fully
    succeeded, so a failed load leaves the previous model (if any) usable.
    """

    # Architectures decoded with greedy CTC.
    _CTC_TYPES = ("CNN-BiLSTM-CTC (Baseline)", "Conformer-CTC")
    # Autoregressive architectures, mapped to the max_len passed to generate().
    _SEQ2SEQ_MAX_LEN = {"TrOCR-SMALL": 128, "ViT + OCR": 64}

    def __init__(self):
        self.model = None        # loaded nn.Module in eval mode, or None
        self.model_type = None   # architecture name (UI dropdown value) of self.model
        self.charset = None      # index -> character list used for decoding
        self.char2idx = None
        self.idx2char = None
        self.num_styles = 1
        self.config = {}

    # ------------------------------------------------------------------ loading
    def load_model(self, weights_path, model_type):
        """Load a checkpoint and build the matching architecture.

        Args:
            weights_path: Path to a file written by ``torch.save``. It may hold a raw
                state_dict, or a dict with the weights under 'model_state_dict',
                'model' or 'state_dict' plus optional 'charset'/'char2idx', 'config'
                and 'num_styles'/'style_names'.
            model_type: Architecture name as shown in the UI dropdown.

        Returns:
            A status message. Errors are reported as strings, never raised.
        """
        if not Path(weights_path).exists():
            return f"Error: Weights file not found at {weights_path}"
        if model_type not in self._CTC_TYPES and model_type not in self._SEQ2SEQ_MAX_LEN:
            return f"Error: Unknown model type '{model_type}'"
        if model_type == "TrOCR-SMALL" and not TIMM_AVAILABLE:
            return "Error: TrOCR-SMALL requires the 'timm' library. Install with: pip install timm"

        try:
            # SECURITY: weights_only=False unpickles arbitrary Python objects, so a
            # crafted upload can execute code on this machine. Only load trusted files.
            checkpoint = torch.load(weights_path, map_location=DEVICE, weights_only=False)
            charset, config, num_styles, state_dict = self._unpack_checkpoint(checkpoint)
            model = self._build_model(model_type, len(charset), num_styles)
            load_warning = self._load_weights(model, state_dict)
            model.to(DEVICE)
            model.eval()
        except Exception as e:
            return f"Error loading model: {str(e)}"

        # Commit the new state only now that every step above has succeeded.
        self.charset = charset
        self.char2idx = {c: i for i, c in enumerate(charset)}
        self.idx2char = dict(enumerate(charset))  # keeps every index even if chars repeat
        self.config = config
        self.num_styles = num_styles
        self.model_type = model_type
        self.model = model

        return (f"\u2705 Model loaded successfully!\nType: {model_type}\n"
                f"Charset size: {len(charset)}\nDevice: {DEVICE}{load_warning}")

    def _unpack_checkpoint(self, checkpoint):
        """Split a ``torch.load`` result into ``(charset, config, num_styles, state_dict)``.

        Missing metadata falls back to DEFAULT_CHARSET and the current config. The
        style count falls back to the size of the weights' style head, else 5.
        """
        if isinstance(checkpoint, nn.Module):
            # A whole pickled model was saved instead of its state_dict.
            checkpoint = checkpoint.state_dict()
        if not isinstance(checkpoint, dict):
            # Unrecognized container: pass it through and let load_state_dict decide.
            return DEFAULT_CHARSET, self.config, self.num_styles, checkpoint

        if 'charset' in checkpoint:
            charset = checkpoint['charset']
        elif 'char2idx' in checkpoint:
            char2idx = checkpoint['char2idx']
            # Size by the largest index so a mapping with gaps does not raise IndexError.
            charset = [''] * (max(char2idx.values(), default=-1) + 1)
            for char, idx in char2idx.items():
                charset[idx] = char
        else:
            charset = DEFAULT_CHARSET

        state_dict = checkpoint.get('model_state_dict',
                                    checkpoint.get('model', checkpoint.get('state_dict', checkpoint)))

        if 'num_styles' in checkpoint:
            num_styles = checkpoint['num_styles']
        elif 'style_names' in checkpoint:
            num_styles = len(checkpoint['style_names'])
        else:
            num_styles = self._infer_num_styles(state_dict) or 5  # Default
        return charset, checkpoint.get('config', self.config), num_styles, state_dict

    @staticmethod
    def _infer_num_styles(state_dict):
        """Return the output size of the last ``style_head.<i>.weight`` matrix, or None."""
        if not isinstance(state_dict, dict):
            return None
        candidates = []
        for key, tensor in state_dict.items():
            parts = key.split('.') if isinstance(key, str) else []
            if (len(parts) >= 3 and parts[-3] == 'style_head' and parts[-2].isdigit()
                    and parts[-1] == 'weight' and getattr(tensor, 'ndim', 0) == 2):
                candidates.append((int(parts[-2]), int(tensor.shape[0])))
        return max(candidates)[1] if candidates else None

    @staticmethod
    def _build_model(model_type, num_chars, num_styles):
        """Instantiate an untrained ``model_type`` with the hyper-parameters hard-coded here.

        These must match the checkpoint being loaded.
        """
        if model_type == "CNN-BiLSTM-CTC (Baseline)":
            return OCR_MultiTask_V3(num_chars=num_chars, num_styles=num_styles, blank_idx=0)
        if model_type == "Conformer-CTC":
            return ConformerCTC(num_chars=num_chars, num_styles=num_styles,
                                d_model=256, d_ff=1024, num_heads=4, num_layers=8,
                                kernel_size=31, dropout=0.1, blank_idx=0)
        if model_type == "TrOCR-SMALL":
            return TrOCRSmall(vocab_size=num_chars, num_styles=num_styles,
                              d_model=384, nhead=6, num_decoder_layers=6,
                              dim_feedforward=1536, dropout=0.1, max_seq_len=128,
                              img_size=(64, 384), pad_idx=0, bos_idx=1, eos_idx=2)
        if model_type == "ViT + OCR":
            return VisionTransformerOCR(vocab_size=num_chars, num_styles=num_styles,
                                        img_h=64, img_w=384, patch_size=16,
                                        embed_dim=384, enc_depth=6, dec_depth=4,
                                        num_heads=6, mlp_ratio=4.0, dropout=0.1,
                                        max_seq_len=64, pad_idx=0, sos_idx=1, eos_idx=2)
        raise ValueError(f"Unknown model type '{model_type}'")

    @staticmethod
    def _load_weights(model, state_dict):
        """Copy ``state_dict`` into ``model``.

        Returns an empty string for an exact key match, otherwise a warning line to
        append to the status message. Tensor shape mismatches still raise.
        """
        # Checkpoints saved from nn.DataParallel / DDP prefix every key with 'module.'.
        prefix = 'module.'
        if state_dict and all(isinstance(k, str) and k.startswith(prefix) for k in state_dict):
            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
        try:
            model.load_state_dict(state_dict, strict=True)
            return ""
        except RuntimeError as e:
            # Retry leniently (missing/unexpected keys only) and tell the user about it.
            result = model.load_state_dict(state_dict, strict=False)
            print(f"Warning: Some weights may not have loaded correctly: {e}")
            return (f"\nWarning: partial load - {len(result.missing_keys)} missing and "
                    f"{len(result.unexpected_keys)} unexpected weight keys; predictions may be wrong.")

    # ---------------------------------------------------------------- inference
    @torch.no_grad()
    def predict(self, image):
        """Run OCR on one image with the loaded model.

        Args:
            image: Image as a numpy array (``gr.Image(type="numpy")`` value).

        Returns:
            The decoded text, "(empty prediction)" if nothing was decoded, or an
            error message string.
        """
        if self.model is None:
            return "Error: No model loaded. Please load a model first."

        if image is None:
            return "Error: No image provided."

        try:
            # Preprocess based on model type
            if self.model_type in self._CTC_TYPES:
                img_t = preprocess_image_ctc(image)
            elif self.model_type == "TrOCR-SMALL":
                img_t = preprocess_image_trocr(image)
            elif self.model_type == "ViT + OCR":
                img_t = preprocess_image_vit(image)
            else:
                return "Error: Unknown model type"

            img_t = img_t.to(DEVICE)

            if self.model_type in self._CTC_TYPES:
                # Both CTC models return their main CTC logits first, shaped (T, B, num_chars).
                ctc_logits = self.model(img_t)[0]
                # Greedy path: argmax of raw logits (softmax would not change it), batch item 0.
                # Indexing [:, 0] keeps a 1-D array even when T == 1.
                pred_indices = ctc_logits.argmax(dim=-1)[:, 0].cpu().numpy()
                text = decode_ctc(pred_indices, self.idx2char, blank_idx=0, rtl_reverse=True)
            else:
                max_len = self._SEQ2SEQ_MAX_LEN[self.model_type]
                generated = self.model.generate(img_t, max_len=max_len, temperature=1.0)
                pred_indices = generated[0].cpu().numpy()
                text = decode_seq2seq(pred_indices, self.idx2char, pad_idx=0, bos_idx=1, eos_idx=2)

            return text if text else "(empty prediction)"

        except Exception as e:
            return f"Error during inference: {str(e)}"


# ====================== Gradio Interface ======================
# One module-level wrapper used by both Gradio handlers below, so every UI event
# sees the same loaded model (this is shared state, not per-session state).
model_wrapper = ModelWrapper()

def load_model_handler(weights_file, model_type):
    """Gradio callback for the "Load Model" button.

    Args:
        weights_file: Value of the ``gr.File`` component, or None when nothing was
            uploaded. It is presumably either a file wrapper exposing the temp path
            as ``.name`` or a plain path string, depending on the Gradio version.
        model_type: Selected architecture name.

    Returns:
        Status text for the "Load Status" box.
    """
    if weights_file is None:
        return "Please upload a weights file (.pt)"
    # Accept both file wrappers (``.name`` holds the path) and plain path strings.
    path = getattr(weights_file, "name", weights_file)
    return model_wrapper.load_model(path, model_type)

def predict_handler(image):
    """Gradio callback for the "Predict" button: OCR the uploaded image with the shared model."""
    result = model_wrapper.predict(image)
    return result

# Build the UI
with gr.Blocks(title="Arabic OCR Model Tester", theme=gr.themes.Soft()) as demo:
    # Emoji are spelled as Unicode escapes so they survive any re-encoding of this
    # file (an earlier copy had them garbled into mojibake).
    gr.Markdown("""
    # \U0001F524 Arabic OCR Model Tester

    Test your trained OCR models on Arabic text images. Upload your model weights and an image to see the prediction.

    **Supported Architectures:**
    - CNN-BiLSTM-CTC (Baseline)
    - Conformer-CTC
    - TrOCR-SMALL
    - ViT + OCR
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1\ufe0f\u20e3 Load Model")
            model_type = gr.Dropdown(
                choices=[
                    "CNN-BiLSTM-CTC (Baseline)",
                    "Conformer-CTC",
                    "TrOCR-SMALL",
                    "ViT + OCR"
                ],
                value="CNN-BiLSTM-CTC (Baseline)",
                label="Model Architecture"
            )
            weights_file = gr.File(label="Upload Model Weights (.pt / .pth file)", file_types=[".pt", ".pth"])
            load_btn = gr.Button("Load Model", variant="primary")
            load_status = gr.Textbox(label="Load Status", lines=4, interactive=False)

        with gr.Column(scale=1):
            gr.Markdown("### 2\ufe0f\u20e3 Test Prediction")
            input_image = gr.Image(label="Upload Image", type="numpy")
            predict_btn = gr.Button("Predict", variant="primary")
            # Right-to-left display for Arabic output.
            prediction_output = gr.Textbox(label="Prediction", lines=3, interactive=False, rtl=True)

    # Connect handlers
    load_btn.click(
        fn=load_model_handler,
        inputs=[weights_file, model_type],
        outputs=load_status
    )

    predict_btn.click(
        fn=predict_handler,
        inputs=input_image,
        outputs=prediction_output
    )

    gr.Markdown("""
    ---
    ### \u2139\ufe0f Instructions

    1. **Select the model architecture** that matches your trained model
    2. **Upload the weights file** (.pt or .pth file containing the model state_dict)
    3. **Upload an image** containing Arabic text
    4. **Click Predict** to see the OCR result

    **Note:** The model expects the weights file to be in the same format as saved during training.
    For best results, use images with clear Arabic text on a light background.
    """)

# Launch the app
# Runs only when this code executes as the main module (a script or a notebook
# cell), not when it is imported.
if __name__ == "__main__":
    demo.launch(share=False)  # share=False: no public Gradio link is created


Using device: cuda


  with gr.Blocks(title="Arabic OCR Model Tester", theme=gr.themes.Soft()) as demo:


Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Note: opening Chrome Inspector may crash demo inside Colab notebooks.
* To create a public link, set `share=True` in `launch()`.


<IPython.core.display.Javascript object>