In [9]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # Force CPU

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import numpy as np
from difflib import SequenceMatcher

# ✅ Arabic Char Normalization
def normalize_arabic(text):
    import re
    text = re.sub('[ًٌٍَُِّْ]', '', text)  # Remove harakat
    text = re.sub('[إأآا]', 'ا', text)
    text = re.sub('ى', 'ي', text)
    text = re.sub('ؤ', 'و', text)
    text = re.sub('ئ', 'ي', text)
    text = re.sub('ة', 'ه', text)
    return text

# ✅ Define Arabic Characters
arabic_chars = list("ابتثجحخدذرزسشصضطظعغفقكلمنهوي")
extra_tokens = ['<PAD>', '<SOS>', '<EOS>']
all_chars = extra_tokens + arabic_chars
char_to_idx = {char: idx for idx, char in enumerate(all_chars)}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
PAD_IDX = char_to_idx['<PAD>']

# ✅ Config
device = torch.device("cpu")
BATCH_SIZE = 16
IMG_HEIGHT = 64
IMG_WIDTH = 256
MAX_TEXT_LEN = 32
EPOCHS = 3

# ✅ Dataset
class OCRDataset(Dataset):
    def __init__(self, df, img_dir, transform=None, is_test=False):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        self.is_test = is_test

    def __len__(self):
        return len(self.df)

    def encode_text(self, text):
        text = normalize_arabic(text)
        encoded = [char_to_idx[c] for c in text if c in char_to_idx]
        return encoded

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row["image"])
        image = Image.open(img_path).convert("L")
        if self.transform:
            image = self.transform(image)
        else:
            image = transforms.ToTensor()(image)
        if not self.is_test:
            target = self.encode_text(row["text"])
            return image, torch.tensor(target, dtype=torch.long)
        else:
            return image, row["image"]

# ✅ Collate
def collate_fn(batch):
    images, targets = zip(*batch)
    images = torch.stack(images)
    lengths = torch.tensor([len(t) for t in targets])
    targets = torch.cat(targets)
    return images, targets, lengths

# ✅ Model
class SimpleCRNN(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),  # (32, 32, 128)
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),  # (64, 16, 64)
        )
        self.rnn = nn.LSTM(64 * 16, 128, num_layers=2, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(128 * 2, vocab_size)

    def forward(self, x):
        x = self.cnn(x)  # (B, C, H, W)
        b, c, h, w = x.size()
        x = x.permute(0, 3, 1, 2)  # (B, W, C, H)
        x = x.view(b, w, -1)  # (B, W, C*H)
        x, _ = self.rnn(x)  # (B, W, 2*H)
        x = self.fc(x)  # (B, W, vocab)
        return x.log_softmax(2)

# ✅ CER
def cer(s1, s2):
    return 1 - SequenceMatcher(None, s1, s2).ratio()

# ✅ Decode
def decode_prediction(logits):
    pred_indices = logits.argmax(2).cpu().numpy()
    texts = []
    for pred in pred_indices:
        tokens = []
        last = -1
        for p in pred:
            if p != last and p != PAD_IDX:
                tokens.append(idx_to_char.get(p, ''))
            last = p
        texts.append("".join(tokens))
    return texts

# ✅ Load Data
train_df = pd.read_csv("/kaggle/input/mc-datathon-2025-arabic-manuscripts-digitization/train_df.csv", encoding="utf-8")
test_df = pd.read_csv("/kaggle/input/mc-datathon-2025-arabic-manuscripts-digitization/test_df.csv", encoding="utf-8")

# ✅ Image directories
TRAIN_IMG_DIR = "/kaggle/input/mc-datathon-2025-arabic-manuscripts-digitization/train/train"
TEST_IMG_DIR = "/kaggle/input/mc-datathon-2025-arabic-manuscripts-digitization/test/test"

transform = transforms.Compose([
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.ToTensor()
])

# ✅ Dataloaders
train_dataset = OCRDataset(train_df, TRAIN_IMG_DIR, transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

test_dataset = OCRDataset(test_df, TEST_IMG_DIR, transform, is_test=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# ✅ Model + Optim
model = SimpleCRNN(len(all_chars)).to(device)
criterion = nn.CTCLoss(blank=PAD_IDX, zero_infinity=True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# ✅ Training Loop
for epoch in range(1, EPOCHS + 1):
    model.train()
    total_loss = 0
    for images, targets, lengths in train_loader:
        images = images.to(device)
        targets = targets.to(device)
        logits = model(images)
        log_probs = logits.permute(1, 0, 2)
        input_lengths = torch.full((logits.size(0),), log_probs.size(0), dtype=torch.long)
        loss = criterion(log_probs, targets, input_lengths, lengths)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch}/{EPOCHS} - Loss: {total_loss/len(train_loader):.4f}")

    # CER Eval
    model.eval()
    with torch.no_grad():
        cers = []
        for i in range(min(20, len(train_dataset))):
            image, target = train_dataset[i]
            image = image.unsqueeze(0).to(device)
            logits = model(image)
            decoded = decode_prediction(logits)[0]
            truth = "".join([idx_to_char[i.item()] for i in target])
            cers.append(cer(decoded, truth))
        print(f"Epoch {epoch} CER: {np.mean(cers):.4f}")

# ✅ Submission
model.eval()
predictions = []
with torch.no_grad():
    for image, file_name in test_loader:
        image = image.to(device)
        logits = model(image)
        decoded = decode_prediction(logits)[0]
        predictions.append({"image": file_name[0], "text": decoded})

submission = pd.DataFrame(predictions)
submission.to_csv("submission.csv", index=False)
print("✅ Saved submission.csv")


Epoch 1/3 - Loss: 1.5568
Epoch 1 CER: 0.9956
Epoch 2/3 - Loss: 1.3946
Epoch 2 CER: 0.9862
Epoch 3/3 - Loss: 1.3262
Epoch 3 CER: 0.9033
✅ Saved submission.csv
