In [9]:
# =====================================================
# STAGE 1: OCR with TrOCR (extract text from handwriting PNGs)
# =====================================================

!pip install -q transformers datasets
!pip install wandb

import os
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch
from tqdm import tqdm

# Path to dataset (root directory with 657 subdirectories)
data_dir = "/kaggle/input/iam-handwritten-forms-dataset/data"
output_corpus = "iam_corpus.txt"

# Load TrOCR pretrained model
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten").to("cuda")

texts = []

# Iterate through all PNG files with batching
batch_size = 16
total_pngs = 0
for root, dirs, files in os.walk(data_dir):
    png_files = [f for f in files if f.endswith(".png")]
    total_pngs += len(png_files)
    for i in tqdm(range(0, len(png_files), batch_size), desc=f"OCR in {os.path.basename(root)}"):
        batch_files = png_files[i:i + batch_size]
        images = []
        for file in batch_files:
            img_path = os.path.join(root, file)
            try:
                image = Image.open(img_path).convert("RGB")
                images.append(image)
            except Exception as e:
                print(f"Error loading {img_path}: {e}")
                continue
        if not images:
            continue
        
        pixel_values = processor(images=images, return_tensors="pt").pixel_values.to("cuda")
        generated_ids = model.generate(pixel_values, max_length=512)
        batch_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
        texts.extend(batch_texts)

# Save extracted text corpus
with open(output_corpus, "w", encoding="utf-8") as f:
    for line in texts:
        f.write(line.strip() + "\n")

print(f"OCR complete. Processed {total_pngs} PNGs, extracted {len(texts)} lines of text")
print("Example OCR result:", texts[0] if texts else "No text extracted")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
OCR in data: 0it [00:00, ?it/s]
OCR in 515: 100%|██████████| 1/1 [00:00<00:00,  3.14it/s]
OCR in 248: 100%|██████████| 1/1 [00:00<00:00,  1.01it/s]
OCR in 625: 100%|██████████| 1/1 [00:00<00:00,  5.34it/s]
OCR in 135: 100%|██████████| 1/1 [00:00<00:00,  2.62it/s]
OCR in 479: 100%|██████████| 1/1 [00:00<00:00,  4.69it/s]
OCR in 183: 100%|██████████| 1/1 [00:00<00:00,  4.78it/s]
OCR in 642: 100%|██████████| 1/1 [00:00<00:00,  4.89it/s]
OCR in 313: 100%|██████████| 1/1 [00:00<00:00,  4.66it/s]
OCR in 600: 100%|██████████| 1/1 [00:00<00:00,  4.63it/s]
OCR in 086: 100%|██████████| 1/1 [00:00<00:00,  5.31it/s]
OCR in 466: 100%|██████████| 1/1 [00:00<00:00,  4.58it/s]
OCR in

OCR complete. Processed 1539 PNGs, extracted 1539 lines of text
Example OCR result: 0 0000





In [10]:
# =====================================================
# STAGE 2a: Train LSTM-based Language Model
# =====================================================

import torch.nn as nn
import torch.optim as optim

# Load text corpus and preprocess
with open(output_corpus, "r", encoding="utf-8") as f:
    corpus = f.read()

# Filter out non-printable chars and normalize
import re
corpus = re.sub(r'[^\x20-\x7E]', ' ', corpus)  # Keep ASCII printable
chars = sorted(list(set(corpus)))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}

def encode(s): return [stoi.get(c, 0) for c in s]  # Default to 0 for unknown
def decode(l): return ''.join([itos.get(i, '') for i in l])

data = torch.tensor(encode(corpus), dtype=torch.long)
split = int(0.9 * len(data))
train_data, val_data = data[:split], data[split:]

block_size = 128
batch_size = 64

def get_batch(split):
    d = train_data if split == 'train' else val_data
    ix = torch.randint(len(d) - block_size, (batch_size,))
    x = torch.stack([d[i:i+block_size] for i in ix])
    y = torch.stack([d[i+1:i+block_size+1] for i in ix])
    return x.cuda(), y.cuda()

class CharLSTM(nn.Module):
    def __init__(self, vocab_size, hidden_size=256, n_layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        if hidden is None:
            batch_size = x.size(0)
            h0 = torch.zeros(self.lstm.num_layers, batch_size, self.lstm.hidden_size).cuda()
            c0 = torch.zeros(self.lstm.num_layers, batch_size, self.lstm.hidden_size).cuda()
            hidden = (h0, c0)
        x = self.embed(x)
        out, hidden = self.lstm(x, hidden)
        logits = self.fc(out)
        return logits, hidden

vocab_size = len(chars)
model_lstm = CharLSTM(vocab_size).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_lstm.parameters(), lr=0.003)

# Train with validation
for epoch in range(3):
    model_lstm.train()
    total_train_loss = 0
    for _ in range(100):  # Limit iterations for demo
        x, y = get_batch('train')
        optimizer.zero_grad()
        logits, _ = model_lstm(x)
        loss = criterion(logits.view(-1, vocab_size), y.view(-1))
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()
    print(f"Epoch {epoch+1}, Train Loss: {total_train_loss/100:.4f}")
    
    model_lstm.eval()
    val_loss = 0
    with torch.no_grad():
        x, y = get_batch('val')
        logits, _ = model_lstm(x)
        val_loss = criterion(logits.view(-1, vocab_size), y.view(-1)).item()
    print(f"Val Loss: {val_loss:.4f}")

torch.save(model_lstm.state_dict(), "lstm_lm.pth")
print("LSTM LM trained and saved.")

Epoch 1, Train Loss: 0.7941
Val Loss: 0.7910
Epoch 2, Train Loss: 0.2664
Val Loss: 0.6952
Epoch 3, Train Loss: 0.2157
Val Loss: 0.8229
LSTM LM trained and saved.


In [11]:
# =====================================================
# STAGE 2b: Fine-tune DistilGPT2 on OCR corpus
# =====================================================

from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
import os

# Prepare dataset
dataset = Dataset.from_dict({"text": texts})
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token  # Handle padding

def tokenize(batch):
    # Tokenize the text
    encodings = tokenizer(batch["text"], truncation=True, padding="max_length", max_length=128, return_tensors="pt")
    # Create labels by shifting input_ids (predict next token)
    encodings["labels"] = encodings["input_ids"].clone()
    return encodings.data  # Return as dict without extra tensor structure

tokenized_dataset = dataset.map(tokenize, batched=True, remove_columns=["text"])
print(f"Tokenized dataset size: {len(tokenized_dataset)}")

# Split train/val
train_test = tokenized_dataset.train_test_split(test_size=0.1)
train_dataset = train_test["train"]
eval_dataset = train_test["test"]
print(f"Train dataset size: {len(train_dataset)}")
print(f"Eval dataset size: {len(eval_dataset)}")

# Load model and move to GPU
model_gpt = AutoModelForCausalLM.from_pretrained("distilgpt2").to("cuda")

# Training args
training_args = TrainingArguments(
    output_dir="./gpt2-finetuned",
    do_eval=True,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=2,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=50,
    save_strategy="epoch"
)

# Disable WandB to avoid conflicts
os.environ["WANDB_DISABLED"] = "true"

# Initialize Trainer with GPU model
trainer = Trainer(
    model=model_gpt,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset
)

print("Starting training...")
trainer.train()
print("Training completed.")
model_gpt.save_pretrained("gpt2-finetuned")
tokenizer.save_pretrained("gpt2-finetuned")

print("DistilGPT-2 fine-tuned and saved.")

Map:   0%|          | 0/1539 [00:00<?, ? examples/s]

Tokenized dataset size: 1539
Train dataset size: 1385
Eval dataset size: 154


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Starting training...


Step,Training Loss
50,0.2319
100,0.0149
150,0.0107


Training completed.
DistilGPT-2 fine-tuned and saved.


In [12]:
# =====================================================
# STAGE 2c: Evaluate Language Models
# =====================================================

from math import exp
import numpy as np

# LSTM Perplexity and Top-k Accuracy
def compute_metrics_lstm():
    model_lstm.eval()
    x, y = get_batch('val')
    with torch.no_grad():
        logits, _ = model_lstm(x)
        loss = criterion(logits.view(-1, vocab_size), y.view(-1)).item()
        ppl = exp(loss)
        _, topk_ids = logits.topk(5, dim=-1)  # Top-5
        y_flat = y.view(-1)
        topk_flat = topk_ids.view(-1, 5)
        top1_correct = (topk_ids[:, :, 0] == y).float().mean().item() * 100
        top5_correct = (topk_flat == y_flat.unsqueeze(-1)).any(dim=-1).float().mean().item() * 100
    return {
        'Perplexity (↓)': ppl,
        'Top-1 Accuracy (↑)': top1_correct,
        'Top-5 Accuracy (↑)': top5_correct
    }

lstm_metrics = compute_metrics_lstm()
print("LSTM Metrics:", {k: f"{v:.2f}" for k, v in lstm_metrics.items()})

LSTM Metrics: {'Perplexity (↓)': '2.53', 'Top-1 Accuracy (↑)': '78.61', 'Top-5 Accuracy (↑)': '95.96'}


In [13]:
from torch.nn.utils.rnn import pad_sequence

def compute_metrics_gpt2(eval_dataset, tokenizer, model_gpt, batch_size=8, device="cuda"):
    model_gpt.eval()
    losses = []
    top1_correct = 0
    top5_correct = 0
    total = 0

    # Determine dataset type
    is_hf_dataset = hasattr(eval_dataset, "column_names")  # Hugging Face dataset

    for i in tqdm(range(0, len(eval_dataset), batch_size), desc="GPT-2 Eval"):
        # Get batch depending on type
        if is_hf_dataset:
            # Hugging Face dataset -> dict of columns
            batch_input_ids = eval_dataset[i:i + batch_size]["input_ids"]
            input_ids_list = [torch.tensor(x, dtype=torch.long) for x in batch_input_ids]
            input_ids = pad_sequence(input_ids_list, batch_first=True, padding_value=tokenizer.pad_token_id).to(device)
        else:
            batch = eval_dataset[i:i + batch_size]
            if isinstance(batch[0], dict) and "input_ids" in batch[0]:
                input_ids_list = [torch.tensor(x["input_ids"], dtype=torch.long) for x in batch]
                input_ids = pad_sequence(input_ids_list, batch_first=True, padding_value=tokenizer.pad_token_id).to(device)
            elif isinstance(batch[0], list) or isinstance(batch[0], torch.Tensor):
                input_ids_list = [torch.tensor(x, dtype=torch.long) for x in batch]
                input_ids = pad_sequence(input_ids_list, batch_first=True, padding_value=tokenizer.pad_token_id).to(device)
            elif isinstance(batch[0], str):
                encodings = tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
                input_ids = encodings["input_ids"].to(device)
            else:
                raise ValueError(f"Unsupported data type: {type(batch[0])}")

        # Shift for next-token prediction
        labels = input_ids[:, 1:].contiguous()
        input_ids = input_ids[:, :-1].contiguous()

        with torch.no_grad():
            outputs = model_gpt(input_ids, labels=labels)
            loss = outputs.loss
            losses.append(loss.item())

            logits = outputs.logits
            _, topk_ids = logits.topk(5, dim=-1)
            top1_correct += (topk_ids[:, :, 0] == labels).float().sum().item()
            top5_correct += (topk_ids == labels.unsqueeze(-1)).any(dim=-1).float().sum().item()
            total += labels.numel()

    ppl = exp(np.mean(losses))
    top1_acc = (top1_correct / total) * 100
    top5_acc = (top5_correct / total) * 100

    return {
        'Perplexity (↓)': ppl,
        'Top-1 Accuracy (↑)': top1_acc,
        'Top-5 Accuracy (↑)': top5_acc
    }

# Example usage
gpt2_metrics = compute_metrics_gpt2(eval_dataset, tokenizer, model_gpt, batch_size=8, device="cuda")
print("GPT-2 Metrics:", {k: f"{v:.2f}" for k, v in gpt2_metrics.items()})

GPT-2 Eval: 100%|██████████| 20/20 [00:01<00:00, 15.72it/s]

GPT-2 Metrics: {'Perplexity (↓)': '2.82', 'Top-1 Accuracy (↑)': '99.01', 'Top-5 Accuracy (↑)': '99.87'}





In [14]:
# =====================================================
# STAGE 3: Predict Next Words with Confidence
# =====================================================

import torch
from transformers import AutoTokenizer
import numpy as np

def predict_next_words(input_text, tokenizer, model_lstm, model_gpt, k=5, max_length=10):
    # Ensure models are in eval mode
    model_lstm.eval()
    model_gpt.eval()

    # Tokenize input for both models
    # For LSTM (character-based)
    chars = sorted(list(set("".join(texts) + input_text)))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for ch, i in stoi.items()}
    input_ids_lstm = torch.tensor([stoi.get(c, 0) for c in input_text], dtype=torch.long).unsqueeze(0).to("cuda")
    
    # For GPT-2 (token-based)
    inputs_gpt = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids_gpt = inputs_gpt["input_ids"].to("cuda")
    attention_mask = inputs_gpt["attention_mask"].to("cuda")

    # Predict with LSTM
    with torch.no_grad():
        logits, _ = model_lstm(input_ids_lstm)
        probs = torch.softmax(logits[:, -1, :], dim=-1)  # Probabilities for the last character
        topk_values_lstm, topk_indices_lstm = probs.topk(k)
        next_chars_lstm = [itos[i] for i in topk_indices_lstm[0].cpu().numpy()]
        confidences_lstm = topk_values_lstm[0].cpu().numpy()

    # Predict with GPT-2
    with torch.no_grad():
        outputs = model_gpt(input_ids_gpt, attention_mask=attention_mask)
        logits = outputs.logits[:, -1, :]  # Logits for the last token
        probs = torch.softmax(logits, dim=-1)
        topk_values_gpt, topk_indices_gpt = probs.topk(k)
        next_tokens_gpt = tokenizer.convert_ids_to_tokens(topk_indices_gpt[0].cpu().numpy())
        confidences_gpt = topk_values_gpt[0].cpu().numpy()

    # Filter and map to words (simplified word boundary assumption)
    def get_words_from_chars(chars, confs):
        words = []
        current_word = ""
        current_conf = 1.0
        for char, conf in zip(chars, confs):
            if char == " " or len(current_word) > max_length:
                if current_word:
                    words.append((current_word, current_conf))
                current_word = ""
                current_conf = 1.0
            else:
                current_word += char
                current_conf *= conf
        if current_word:
            words.append((current_word, current_conf))
        return words[:k]

    def get_words_from_tokens(tokens, confs):
        words = []
        for token, conf in zip(tokens, confs):
            if token.startswith("##") or token in [".", ",", "!", "?"]:
                continue
            word = token.replace("##", "")
            words.append((word, conf))
        return words[:k]

    # Process predictions
    next_words_lstm = get_words_from_chars(next_chars_lstm, confidences_lstm)
    next_words_gpt = get_words_from_tokens(next_tokens_gpt, confidences_gpt)

    # Print results
    print(f"\nInput: '{input_text}'")
    print("LSTM Next Words with Confidence:")
    for word, conf in next_words_lstm:
        print(f"  - {word}: {conf:.4f}")
    print("GPT-2 Next Words with Confidence:")
    for word, conf in next_words_gpt:
        print(f"  - {word}: {conf:.4f}")

# Example usage
custom_input = "Please write a"
predict_next_words(custom_input, tokenizer, model_lstm, model_gpt)


Input: 'Please write a'
LSTM Next Words with Confidence:
  - nspl: 0.0000
GPT-2 Next Words with Confidence:
  - Ġnote: 0.1753
  - Ġcomment: 0.1711
  - Ġletter: 0.1178
  - Ġresponse: 0.0509
  - Ġreview: 0.0427


In [15]:
# =====================================================
#Predict Next Words from Image Input
# =====================================================

from PIL import Image
import requests  # If loading from URL; otherwise use local file path
from io import BytesIO

# Load TrOCR if not already loaded
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten").to("cuda")

def extract_text_from_image(image_path_or_url):
    """
    Extracts text from a handwritten image using TrOCR.
    - image_path_or_url: Local file path or URL to the image.
    """
    try:
        if image_path_or_url.startswith("http"):  # Load from URL
            response = requests.get(image_path_or_url)
            image = Image.open(BytesIO(response.content)).convert("RGB")
        else:  # Local file
            image = Image.open(image_path_or_url).convert("RGB")
        
        pixel_values = processor(images=image, return_tensors="pt").pixel_values.to("cuda")
        generated_ids = trocr_model.generate(pixel_values, max_length=512)
        extracted_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        
        print(f"Extracted Text: '{extracted_text}'")
        return extracted_text.strip()
    except Exception as e:
        print(f"Error processing image: {e}")
        return ""

# Modified predict_next_words to handle image input
def predict_next_words_from_image(image_path_or_url, tokenizer, model_lstm, model_gpt, k=5, max_length=10):
    input_text = extract_text_from_image(image_path_or_url)
    if not input_text:
        print("No text extracted from image.")
        return
    
    # Now use the existing predict_next_words function
    predict_next_words(input_text, tokenizer, model_lstm, model_gpt, k, max_length)

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
image_input = "/kaggle/input/test-image/WhatsApp Image 2025-10-14 at 17.40.54_ed2fab4d.jpg"
predict_next_words_from_image(image_input, tokenizer, model_lstm, model_gpt)

Extracted Text: '2 . Please write a'

Input: '2 . Please write a'
LSTM Next Words with Confidence:
  - nspl: 0.0000
GPT-2 Next Words with Confidence:
  - Ġnote: 0.1793
  - Ġcomment: 0.0898
  - Ġcopy: 0.0749
  - Ġletter: 0.0371
  - Ġreview: 0.0259
