In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchvision import transforms

from datasets import load_dataset, load_from_disk, DatasetDict, concatenate_datasets
from transformers import AutoTokenizer

from PIL import Image
import cv2

from collections import Counter
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

import requests
from io import BytesIO
import json
import os
from google.colab import drive, files

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _use_cuda else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
# Persistent storage for cached/preprocessed data lives on Google Drive.
drive_base_path = os.path.join("/content/drive", "MyDrive", "hf_ocr_data")
drive.mount('/content/drive')
os.makedirs(drive_base_path, exist_ok=True)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
dataset_dict = load_dataset(path="MonlamAI/OCR-Tibetan_line_to_text_benchmark")  # all upstream splits

def is_valid(example):
    """Return True when the row has a transcription containing non-whitespace text."""
    label = example['label']
    if label is None:
        return False
    return label.strip() != ""

# Drop rows without a usable transcription from every upstream split (order preserved).
filtered_splits = [split_ds.filter(is_valid) for split_ds in dataset_dict.values()]

print(filtered_splits)


[Dataset({
    features: ['filename', 'label', 'image_url', 'BDRC_work_id', 'char_len', 'script', 'writing_style', 'print_method'],
    num_rows: 223080
}), Dataset({
    features: ['filename', 'label', 'image_url', 'BDRC_work_id', 'char_len', 'script', 'writing_style', 'print_method'],
    num_rows: 41784
}), Dataset({
    features: ['filename', 'label', 'image_url', 'BDRC_work_id', 'char_len', 'script', 'writing_style', 'print_method'],
    num_rows: 16326
}), Dataset({
    features: ['filename', 'label', 'image_url', 'BDRC_work_id', 'char_len', 'script', 'writing_style', 'print_method'],
    num_rows: 1217
}), Dataset({
    features: ['filename', 'label', 'image_url', 'BDRC_work_id', 'char_len', 'script', 'writing_style', 'print_method'],
    num_rows: 75168
}), Dataset({
    features: ['filename', 'label', 'image_url', 'BDRC_work_id', 'char_len', 'script', 'writing_style', 'print_method'],
    num_rows: 1317
}), Dataset({
    features: ['filename', 'label', 'image_url', 'BDRC_work_

In [None]:
full_dataset = concatenate_datasets(filtered_splits)

# Random 90/10 re-split of all upstream splits, seeded for reproducibility.
# NOTE(review): the split is per line, so lines from the same BDRC work can land in both
# train and test; test scores are therefore likely optimistic. Consider grouping by BDRC_work_id.
split = full_dataset.train_test_split(seed=42, test_size=0.1)
train_dataset, test_dataset = split['train'], split['test']

for part in (train_dataset, test_dataset):
    print(part)

Dataset({
    features: ['filename', 'label', 'image_url', 'BDRC_work_id', 'char_len', 'script', 'writing_style', 'print_method'],
    num_rows: 379442
})
Dataset({
    features: ['filename', 'label', 'image_url', 'BDRC_work_id', 'char_len', 'script', 'writing_style', 'print_method'],
    num_rows: 42161
})


In [None]:
# Training-time preprocessing: single channel, fixed 64x512 size, float tensor [C x H x W].
transform = transforms.Compose(
    [transforms.Grayscale(), transforms.Resize((64, 512)), transforms.ToTensor()]
)

def build_char_vocab(dataset):
    """Build character <-> id maps from the ``label`` column of ``dataset``.

    Real characters get ids 1..V in sorted order; id 0 is reserved for ``'<pad>'``
    (inserted last). Returns ``(char2id, id2char)``.
    """
    alphabet = set()
    for line in dataset['label']:
        alphabet.update(line)

    char2id = {}
    for position, symbol in enumerate(sorted(alphabet), start=1):
        char2id[symbol] = position
    char2id['<pad>'] = 0

    id2char = dict((idx, symbol) for symbol, idx in char2id.items())
    return char2id, id2char

char2id, id2char = build_char_vocab(train_dataset)

# The CTC blank takes the first id after all real ids (which include <pad> at 0).
ctc_blank_id = len(char2id)            # last index is CTC blank
num_classes = ctc_blank_id + 1         # +1 for CTC blank token

print(list(char2id.items())[:10])
print(f"Vocab size (w/ pad): {len(char2id)}")
print(f"Total output classes (with blank): {num_classes}")
print(f"CTC blank index: {ctc_blank_id}")

[('\t', 1), (' ', 2), ('.', 3), (':', 4), ('ऽ', 5), ('।', 6), ('ༀ', 7), ('༄', 8), ('༅', 9), ('༆', 10)]
Vocab size (w/ pad): 137
Total output classes (with blank): 138
CTC blank index: 137


In [None]:
# Converts images and Tibetan text into tensors + character IDs
def process_example(example, timeout=3):
    """Download one line image and encode its transcription for CTC training.

    Args:
        example: a row with at least ``image_url`` and ``label`` fields.
        timeout: seconds to wait for the HTTP request.

    Returns:
        dict with ``image_tensor`` (output of the global ``transform``, 1x64x512),
        ``label`` (the original string) and ``label_ids`` (list[int]). Characters that
        are missing from ``char2id`` are encoded as the ``<pad>`` id.

        If downloading or decoding fails, a placeholder row is returned instead:
        zero image, empty label and ``label_ids == [<pad>]``. The id list is kept
        non-empty so ``Dataset.map`` sees a consistent schema. Such rows are identifiable
        by the pad id in ``label_ids`` and should be filtered out before training.
    """
    pad_id = char2id["<pad>"]
    try:
        response = requests.get(example['image_url'], timeout=timeout)
        # Fail explicitly on HTTP errors instead of handing an error page to PIL.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert('RGB')
        image = transform(image)  # [1, 64, 512]

        label_ids = [char2id.get(c, pad_id) for c in example['label']]

        return {
            'image_tensor': image,
            'label': example['label'],
            'label_ids': label_ids
        }

    except Exception as exc:
        # Best-effort: one bad URL must not abort a multi-process map, but the failure
        # should not be invisible either. (Python's default warning filter may show
        # only the first such warning per process.)
        import warnings
        warnings.warn(
            f"process_example: failed to load {example.get('image_url')!r}: {exc}",
            RuntimeWarning,
        )
        return {
            'image_tensor': torch.zeros(1, 64, 512),  # match expected shape
            'label': '',
            'label_ids': [pad_id]
        }

In [None]:
# Incremental preprocessing of the TRAIN split into `processed_path` on Drive.
# NOTE: the processing logic below is wrapped in a bare triple-quoted string, so running this
# cell only sets up the paths; the string is the cell's last expression and is echoed as output.
# To process the next chunk, remove the surrounding quotes. When enabled, each run:
#   1. loads what was already processed (if any) to find the resume offset,
#   2. maps the next `batch_size` rows through process_example (num_proc=4),
#   3. concatenates old + new, saves to `tmp_path`, then replaces `processed_path` with it.
# NOTE(review): the test-split cell duplicates this logic; a shared function would remove the copy.
import shutil
drive_base_path = "/content/drive/MyDrive/hf_ocr_data"  # re-declared so this cell runs standalone
os.makedirs(drive_base_path, exist_ok=True)

processed_path = os.path.join(drive_base_path, "processed_train")
tmp_path = os.path.join(drive_base_path, "processed_train_tmp")
'''
if os.path.exists(processed_path):
    processed_old = load_from_disk(processed_path)
    already_processed = len(processed_old)
    print(f"Already processed: {already_processed}")
else:
    processed_old = None
    already_processed = 0
    print("Starting fresh.")

batch_size = 10000
end_index = already_processed + batch_size

if already_processed >= len(train_dataset):
    print("All training data has been processed!")
else:
    new_raw_dataset = train_dataset.select(range(already_processed, min(end_index, len(train_dataset))))
    print(f"Processing entries {already_processed} to {min(end_index, len(train_dataset))}")

    processed_new = new_raw_dataset.map(
        process_example,
        remove_columns=new_raw_dataset.column_names,
        num_proc=4  # Adjust based on your CPU cores
    )

    if processed_old:
        merged_dataset = concatenate_datasets([processed_old, processed_new])
    else:
        merged_dataset = processed_new

    merged_dataset.save_to_disk(tmp_path)

    if os.path.exists(processed_path):
        shutil.rmtree(processed_path)
    shutil.move(tmp_path, processed_path)

    print("Updated processed dataset saved.")'''

'\n# === 4. Determine Already Processed Count ===\nif os.path.exists(processed_path):\n    processed_old = load_from_disk(processed_path)\n    already_processed = len(processed_old)\n    print(f"Already processed: {already_processed}")\nelse:\n    processed_old = None\n    already_processed = 0\n    print("Starting fresh.")\n\n# === 5. Get Next Batch and Preprocess ===\nbatch_size = 10000\nend_index = already_processed + batch_size\n\nif already_processed >= len(train_dataset):\n    print("All training data has been processed!")\nelse:\n    new_raw_dataset = train_dataset.select(range(already_processed, min(end_index, len(train_dataset))))\n    print(f"Processing entries {already_processed} to {min(end_index, len(train_dataset))}")\n\n    processed_new = new_raw_dataset.map(\n        process_example,\n        remove_columns=new_raw_dataset.column_names,\n        num_proc=4  # Adjust based on your CPU cores\n    )\n\n    # === 6. Merge and Save Safely ===\n    if processed_old:\n       

In [None]:
# Incremental preprocessing of the TEST split into `test_processed_path` on Drive.
# Adjust paths as per your Drive setup.
# NOTE: the processing logic below is wrapped in a bare triple-quoted string, so running this
# cell only defines the paths; the string is the cell's last expression and is echoed as output.
# Remove the quotes to process the next `batch_size` (10000) test rows with process_example,
# merge them with what is already on disk, save to `test_tmp_path`, then swap it into place.
# When enabled it relies on `shutil`, which is imported in the train-preprocessing cell.
drive_base_path = "/content/drive/MyDrive/hf_ocr_data"
test_processed_path = os.path.join(drive_base_path, "processed_test")
test_tmp_path = os.path.join(drive_base_path, "processed_test_tmp")
'''
if os.path.exists(test_processed_path):
    processed_old_test = load_from_disk(test_processed_path)
    already_processed_test = len(processed_old_test)
    print(f"Test set already processed: {already_processed_test}")
else:
    processed_old_test = None
    already_processed_test = 0
    print("Starting fresh test preprocessing.")

batch_size = 10000
end_index_test = already_processed_test + batch_size

if already_processed_test >= len(test_dataset):
    print("All test data has been processed!")
else:
    new_raw_test = test_dataset.select(range(already_processed_test, min(end_index_test, len(test_dataset))))
    print(f"Processing test entries {already_processed_test} to {min(end_index_test, len(test_dataset))}")

    processed_new_test = new_raw_test.map(
        process_example,
        remove_columns=new_raw_test.column_names,
        num_proc=4  # adjust as needed
    )

    if processed_old_test:
        merged_test = concatenate_datasets([processed_old_test, processed_new_test])
    else:
        merged_test = processed_new_test

    merged_test.save_to_disk(test_tmp_path)

    if os.path.exists(test_processed_path):
        shutil.rmtree(test_processed_path)
    shutil.move(test_tmp_path, test_processed_path)

    print("Updated processed test dataset saved.")'''

'\n# === Check already processed count for test dataset ===\nif os.path.exists(test_processed_path):\n    processed_old_test = load_from_disk(test_processed_path)\n    already_processed_test = len(processed_old_test)\n    print(f"Test set already processed: {already_processed_test}")\nelse:\n    processed_old_test = None\n    already_processed_test = 0\n    print("Starting fresh test preprocessing.")\n\n# === Process next batch of test data ===\nbatch_size = 10000\nend_index_test = already_processed_test + batch_size\n\nif already_processed_test >= len(test_dataset):\n    print("All test data has been processed!")\nelse:\n    new_raw_test = test_dataset.select(range(already_processed_test, min(end_index_test, len(test_dataset))))\n    print(f"Processing test entries {already_processed_test} to {min(end_index_test, len(test_dataset))}")\n\n    processed_new_test = new_raw_test.map(\n        process_example,\n        remove_columns=new_raw_test.column_names,\n        num_proc=4  # adjust

In [None]:
processed_train_dataset = load_from_disk(processed_path)
processed_val_dataset = load_from_disk(test_processed_path)

# The cached datasets store label *ids*, so they are only meaningful together with the exact
# vocabulary used to encode them; that vocabulary is persisted next to them on Drive.
char2id_path = os.path.join(drive_base_path, "char2id.json")
id2char_path = os.path.join(drive_base_path, "id2char.json")

# Nothing in this notebook appears to write these files, so on a first run persist the vocab
# built by build_char_vocab() instead of failing with FileNotFoundError.
if not (os.path.exists(char2id_path) and os.path.exists(id2char_path)):
    with open(char2id_path, "w", encoding="utf-8") as f:
        json.dump(char2id, f, ensure_ascii=False)
    with open(id2char_path, "w", encoding="utf-8") as f:
        json.dump(id2char, f, ensure_ascii=False)

with open(char2id_path, "r", encoding="utf-8") as f:
    char2id = json.load(f)
with open(id2char_path, "r", encoding="utf-8") as f:
    id2char = json.load(f)

# JSON object keys are always strings; restore integer ids.
id2char = {int(k): v for k, v in id2char.items()}

# Ids must be exactly 0..V-1 with <pad> at 0, so that the CTC blank (V) is an unused class.
assert char2id.get("<pad>") == 0, "Expected '<pad>' to map to id 0"
assert sorted(id2char) == list(range(len(char2id))), "Vocabulary ids are not contiguous 0..V-1"

num_classes = len(char2id) + 1
ctc_blank_id = len(char2id)

print(f"Characters in vocab: {len(char2id)}")
print(f"Total output classes (with CTC blank): {num_classes}")
print(f"CTC blank index: {ctc_blank_id}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Characters in vocab: 137
Total output classes (with CTC blank): 138
CTC blank index: 137
Using device: cuda


In [None]:
class TibetanImageTensorDataset(Dataset):
    """Torch ``Dataset`` over preprocessed OCR rows (``image_tensor`` + ``label_ids``).

    Args:
        data: a Hugging Face ``datasets.Dataset`` with ``image_tensor`` and ``label_ids``
            columns (as produced by ``process_example`` in this notebook).
        target_height, target_width: every image is bilinearly resized to this size.
        max_label_len: rows with longer labels are dropped.
        pad_id: id reserved for padding. ``process_example`` also uses it for unknown
            characters and for failed downloads, so rows whose label contains it are
            dropped. Otherwise the model would be trained to emit a literal "<pad>".

    ``__getitem__`` returns ``(image [C, target_height, target_width], label_ids, len(label_ids))``.
    """

    def __init__(self, data, target_height=64, target_width=768, max_label_len=128, pad_id=0):
        self.target_height = target_height
        self.target_width = target_width
        keep = type(self)._is_trainable_label  # plain function: avoids capturing `self` in the filter
        # input_columns means only label_ids is read; the image column is not decoded here.
        self.data = data.filter(
            lambda label_ids: keep(label_ids, max_label_len, pad_id),
            input_columns=['label_ids'],
        )
        self.data.set_format(type='torch', columns=['image_tensor', 'label_ids'])
        print(f"Dataset loaded: {len(self.data)} samples "
              f"(after filtering label length ≤ {max_label_len} and labels containing pad id {pad_id})")

    @staticmethod
    def _is_trainable_label(label_ids, max_label_len, pad_id=0):
        """True if ``label_ids`` is non-empty, at most ``max_label_len`` long and pad-free."""
        return 0 < len(label_ids) <= max_label_len and pad_id not in label_ids

    def __getitem__(self, idx):
        item = self.data[idx]
        image = item['image_tensor']
        label = item['label_ids']
        if image.ndim == 2:
            # Add a channel axis for single-channel [H, W] images.
            image = image.unsqueeze(0)
        # F.interpolate expects a batch axis: [1, C, H, W] -> resize -> [C, H', W'].
        image = F.interpolate(
            image.unsqueeze(0),
            size=(self.target_height, self.target_width),
            mode='bilinear',
            align_corners=False
        ).squeeze(0)
        return image, label, len(label)

    def __len__(self):
        return len(self.data)


In [None]:

batch_size = 32

# Wrap both cached splits; each constructor prints how many rows survive its filter.
train_ds, val_ds = (
    TibetanImageTensorDataset(source)
    for source in (processed_train_dataset, processed_val_dataset)
)

def collate_fn(batch):
    """Stack images and right-pad label sequences with 0.

    Returns ``(images [B, C, H, W], padded_labels [B, max_len], label_lengths [B] long)``.
    The lengths let the loss ignore the padding.
    """
    image_list = [img for img, _, _ in batch]
    label_list = [lab for _, lab, _ in batch]
    length_list = [count for _, _, count in batch]

    images = torch.stack(image_list)
    padded_labels = pad_sequence(label_list, batch_first=True, padding_value=0)
    label_lengths = torch.tensor(length_list, dtype=torch.long)
    return images, padded_labels, label_lengths

# Shared DataLoader settings; only shuffling differs between training and validation.
loader_settings = dict(batch_size=batch_size, num_workers=4, collate_fn=collate_fn)
train_loader = DataLoader(train_ds, shuffle=True, **loader_settings)
val_loader = DataLoader(val_ds, shuffle=False, **loader_settings)

Dataset loaded: 12146 samples (after filtering label length ≤ 128)
Dataset loaded: 12103 samples (after filtering label length ≤ 128)




In [None]:
# Pull one batch to check tensor shapes and per-sample label lengths.
images, padded_labels, label_lengths = next(iter(train_loader))
for shown in (images.shape, padded_labels.shape, label_lengths):
    print(shown)


torch.Size([32, 1, 64, 768])
torch.Size([32, 113])
tensor([ 63,  60,  10,  72,  60,  65,  70,  69,  58,  61,  12,  74,  57,  72,
         59,  67,  65,  63,  65,  73,  55, 113,  67,  51,  66,  76,  55,  23,
         77,  65,  61,  31])


In [None]:
# Read only the label column: iterating whole rows would also materialise every image tensor.
label_id_lists = processed_train_dataset["label_ids"]
max_label_id = max((max(ids) for ids in label_id_lists if len(ids) > 0), default=-1)

# process_example encodes unknown characters and failed downloads with the <pad> id,
# so rows containing it would teach the model a bogus "<pad>" target.
pad_id = char2id["<pad>"]
rows_with_pad = sum(1 for ids in label_id_lists if pad_id in ids)

print("Maximum label ID found in dataset:", max_label_id)
print("Vocabulary size (char2id):", len(char2id))
print("CTC blank index (should be vocab size):", ctc_blank_id)
print("Rows whose labels contain the <pad> id:", rows_with_pad)
assert max_label_id < ctc_blank_id, "Your label IDs exceed expected vocab size!"

In [None]:
# One output per vocabulary id (including <pad>), plus a final extra class for the CTC blank.
ctc_blank_id = len(char2id)            # blank is last class
num_classes = ctc_blank_id + 1         # real chars + CTC blank

class MiniResNetCTC(nn.Module):
    """Small CNN feature extractor + per-column classifier producing CTC log-probabilities.

    NOTE(review): despite the name there are no residual connections.

    Args:
        num_classes: number of output classes (including the CTC blank).
        input_height: image height the model is built for (default 64). The two 2x2
            max-pools reduce it to ``input_height // 4`` rows of 64 channels per time step.
        hidden_dim: width of the hidden projection (default 128).

    Input ``[B, 1, input_height, W]`` -> output ``[B, W // 4, num_classes]`` log-probabilities.
    """

    def __init__(self, num_classes, input_height=64, hidden_dim=128):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        feature_dim = 64 * (input_height // 4)  # 1024 for the default height of 64
        self.linear = nn.Linear(feature_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = self.cnn(x)                          # [B, 64, H/4, W/4]
        x = x.permute(0, 3, 1, 2)                # [B, W/4, 64, H/4]: width becomes time
        x = x.reshape(x.size(0), x.size(1), -1)  # [B, T, 64 * H/4]
        # Nonlinearity between the two projections; without it they collapse into one affine map.
        x = F.relu(self.linear(x))
        x = self.classifier(x)
        return x.log_softmax(dim=2)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = MiniResNetCTC(num_classes=num_classes)
model = model.to(device)

In [None]:
# zero_infinity: impossible alignments contribute zero loss instead of inf.
criterion = nn.CTCLoss(blank=ctc_blank_id, zero_infinity=True)
optimizer = Adam(model.parameters(), lr=1e-4)

first_row = processed_train_dataset[0]
print(first_row.keys())

In [None]:
# Sanity check that the loaders built earlier yield tensors for images, labels and lengths.
# They are deliberately not rebuilt here: re-creating train_ds/val_ds would re-run the
# dataset filters and silently replace loaders configured with num_workers=4.
images, labels, label_lengths = next(iter(train_loader))
print("Images:", type(images), "Is Tensor:", torch.is_tensor(images))
print("Labels:", type(labels), "Is Tensor:", torch.is_tensor(labels))
print("Label Lengths:", type(label_lengths), "Is Tensor:", torch.is_tensor(label_lengths))

In [None]:
def ctc_greedy_decode_single(pred_ids, blank):
    """Collapse one greedy CTC path: merge consecutive repeats, then drop blanks.

    A token is kept when it is not ``blank`` and differs from the token right before it,
    so a blank between two equal tokens keeps both. Returns a list of the kept ids.
    """
    ids = list(pred_ids)
    return [
        tok for pos, tok in enumerate(ids)
        if tok != blank and (pos == 0 or tok != ids[pos - 1])
    ]

In [None]:
def evaluate(model, val_loader, id2char, device, blank=None):
    """Whole-line exact-match accuracy of greedy CTC decoding on ``val_loader``.

    Args:
        model: maps images to per-time-step class scores shaped [B, T, C].
        val_loader: yields ``(images, padded_labels, label_lengths)`` batches or ``None``.
        id2char: integer id -> character; ids missing from it decode to ''.
        device: device the images are moved to.
        blank: CTC blank index. Defaults to ``C - 1`` (the last class), which is how this
            notebook defines ``ctc_blank_id``.

    Returns:
        Fraction of samples whose decoded text equals the reference text (0.0 if none).
        The model's train/eval mode is restored on return.
    """
    was_training = model.training
    model.eval()
    exact_matches = 0
    total = 0

    try:
        with torch.no_grad():
            for batch in val_loader:
                if batch is None:
                    continue

                images, labels, label_lengths = batch
                outputs = model(images.to(device))
                batch_blank = outputs.size(2) - 1 if blank is None else blank

                # argmax is invariant to log_softmax, so no renormalisation is needed.
                # One host transfer per batch; plain ints make id2char lookups exact.
                pred_rows = outputs.argmax(2).cpu().tolist()
                label_rows = labels.cpu().tolist()
                lengths = label_lengths.cpu().tolist()

                for pred_ids, label_ids, length in zip(pred_rows, label_rows, lengths):
                    pred_decoded = ctc_greedy_decode_single(pred_ids, blank=batch_blank)
                    pred_text = ''.join(id2char.get(idx, '') for idx in pred_decoded)
                    true_text = ''.join(id2char.get(idx, '') for idx in label_ids[:length])
                    if pred_text == true_text:
                        exact_matches += 1
                    total += 1
    finally:
        model.train(was_training)

    return exact_matches / total if total > 0 else 0.0

In [None]:
num_epochs = 10
loss_history = []

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    processed_batches = 0
    skipped_batches = 0

    for batch in train_loader:
        if batch is None:
            skipped_batches += 1
            continue

        images, labels, label_lengths = batch
        images = images.to(device)
        labels = labels.to(device)
        label_lengths = label_lengths.to(device)

        # Forward pass. The model already returns log-probabilities; log_softmax is
        # idempotent on them and only guards against a model that returns raw logits.
        outputs = model(images)
        log_probs = outputs.log_softmax(2).permute(1, 0, 2)  # [B, T, C] -> [T, B, C] for CTCLoss
        T, B = log_probs.size(0), log_probs.size(1)
        input_lengths = torch.full(size=(B,), fill_value=T, dtype=torch.long, device=device)

        # CTC cannot align a target longer than the output sequence.
        if (label_lengths > T).any():
            print("Skipping batch: label length > input length")
            skipped_batches += 1
            continue

        # Empty targets carry no signal (the dataset filter should already exclude them).
        if (label_lengths == 0).any():
            print("Skipping batch: empty label found")
            skipped_batches += 1
            continue

        # CTCLoss accepts padded (B, S) targets together with per-sample target lengths,
        # so the zero padding added by collate_fn is ignored without manual flattening.
        loss = criterion(log_probs, labels, input_lengths, label_lengths)

        if not torch.isfinite(loss):
            print("Skipping batch: loss is not finite")
            skipped_batches += 1
            continue

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        total_loss += loss.item()
        processed_batches += 1

    if processed_batches == 0:
        print("All batches skipped! Check your data and model.")
        break

    average_loss = total_loss / processed_batches
    loss_history.append(average_loss)

    # Whole-line exact-match accuracy on the validation set.
    val_accuracy = evaluate(model, val_loader, id2char, device)
    print(f"Epoch {epoch + 1}/{num_epochs} | Loss: {average_loss:.4f} | "
          f"Validation Accuracy: {val_accuracy*100:.2f}% | Skipped Batches: {skipped_batches}")

In [None]:
from difflib import SequenceMatcher

def char_level_accuracy(pred, target):
    """
    Character-level similarity between a prediction and the ground truth.

    Uses difflib's ``SequenceMatcher.ratio()`` (2 * matches / total length of both
    strings). This is a similarity score, not 1 - CER.

    Returns a float in [0, 1]. Two empty strings count as a perfect match (1.0).
    An empty reference with a non-empty prediction scores 0.0.
    """
    if not target:
        return 1.0 if not pred else 0.0
    # autojunk=False: by default, characters whose duplicates exceed 1% of a reference of
    # 200+ characters are treated as junk. That would ignore frequent symbols (e.g. the
    # Tibetan tsheg) and deflate the score on long lines.
    matcher = SequenceMatcher(None, pred, target, autojunk=False)
    return matcher.ratio()

def ctc_greedy_decode_batch(pred_ids, blank):
    """
    Greedy CTC post-processing for a batch of argmax paths.

    For each row, consecutive repeats are merged and blanks removed. Returns one list of
    kept ids per input row.
    """
    results = []
    for row in pred_ids:
        tokens = list(row)
        kept = [
            tok for pos, tok in enumerate(tokens)
            if tok != blank and (pos == 0 or tok != tokens[pos - 1])
        ]
        results.append(kept)
    return results

def evaluate_character_accuracy(model, val_loader, id2char, device, blank=None):
    """Mean ``char_level_accuracy`` of greedy CTC decoding over ``val_loader``.

    Args:
        model: maps images to per-time-step class scores shaped [B, T, C].
        val_loader: yields ``(images, padded_labels, label_lengths)`` batches or ``None``.
        id2char: integer id -> character; ids missing from it decode to ''.
        device: device the images are moved to.
        blank: CTC blank index. Defaults to ``C - 1`` (the last class), which is how this
            notebook defines ``ctc_blank_id``.

    Returns:
        Average per-sample similarity in [0, 1] (0.0 if there are no samples).
        The model's train/eval mode is restored on return.
    """
    was_training = model.training
    model.eval()
    total_score = 0.0
    total_samples = 0

    try:
        with torch.no_grad():
            for batch in val_loader:
                if batch is None:
                    continue

                images, labels, label_lengths = batch
                outputs = model(images.to(device))  # [B, T, C]
                batch_blank = outputs.size(2) - 1 if blank is None else blank

                # One host transfer per batch; plain ints make id2char lookups exact.
                pred_ids = outputs.argmax(2).cpu().tolist()  # [B, T]
                decoded_batch = ctc_greedy_decode_batch(pred_ids, blank=batch_blank)
                label_rows = labels.cpu().tolist()
                lengths = label_lengths.cpu().tolist()

                for decoded, label_ids, length in zip(decoded_batch, label_rows, lengths):
                    pred_text = ''.join(id2char.get(idx, '') for idx in decoded)
                    true_text = ''.join(id2char.get(idx, '') for idx in label_ids[:length])
                    total_score += char_level_accuracy(pred_text, true_text)
                    total_samples += 1
    finally:
        model.train(was_training)

    return total_score / total_samples if total_samples > 0 else 0.0

char_acc = evaluate_character_accuracy(
    model=model, val_loader=val_loader, id2char=id2char, device=device
)
print(f"Validation Character-Level Accuracy: {char_acc * 100:.2f}%")

In [None]:
model.eval()
model.to(device)

# Preprocessing must mirror training. process_example() applied Grayscale -> Resize((64, 512))
# -> ToTensor() with NO Normalize, and TibetanImageTensorDataset then bilinearly resized each
# image to 64x768. A separate name keeps the global `transform` used by process_example() intact.
inference_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((64, 512)),
    transforms.ToTensor(),
])
inference_size = (64, 768)  # (height, width): TibetanImageTensorDataset's defaults

# image upload
uploaded_files = files.upload()
if not uploaded_files:
    raise ValueError("No image was uploaded.")
img_filename = next(iter(uploaded_files))

img = cv2.imread(img_filename)
if img is None:
    raise FileNotFoundError(f"Image not found or failed to load: {img_filename}")
img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

image_tensor = inference_transform(img_pil).unsqueeze(0)  # [1, 1, 64, 512]
image_tensor = F.interpolate(
    image_tensor, size=inference_size, mode='bilinear', align_corners=False
).to(device)  # [1, 1, 64, 768]

with torch.no_grad():
    output = model(image_tensor)  # [B, T, C] log-probabilities
    # argmax is unaffected by (re)applying log_softmax, so decode directly.
    pred_ids = output.argmax(2).squeeze(0).cpu().tolist()  # [T]

decoded_ids = ctc_greedy_decode_single(pred_ids, blank=ctc_blank_id)

#blank info
print(f"CTC blank index: {ctc_blank_id}")
if ctc_blank_id in id2char:
    print(f"Warning: Blank index maps to: {repr(id2char[ctc_blank_id])}")
else:
    print("Blank index is not in id2char (safe)")

# character frequency (Counter is imported in the first cell)
counter = Counter(decoded_ids)
print("Top 10 predicted character IDs and frequencies:")
for idx, count in counter.most_common(10):
    ch = id2char.get(idx, '?')
    print(f"ID {idx}: {repr(ch)} — {count} times")

# The <pad> id is not a real character; never render it into the transcription.
pad_id = char2id["<pad>"]
pred_text = ''.join(id2char[i] for i in decoded_ids if i in id2char and i != pad_id)
print("Final Predicted Text:", pred_text)
