In [None]:

# Install GPU dependencies
!pip uninstall -y torch torchvision torchaudio torchao trdg pillow transformers datasets opencv-python jiwer
!pip install torch==2.3.0+cu121 torchvision==0.18.0+cu121 --index-url https://download.pytorch.org/whl/cu121
!pip install transformers==4.39.3 datasets==3.6.0 opencv-python jiwer pillow>=9.4.0

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from datasets import load_dataset
import cv2
import numpy as np
from PIL import Image
from jiwer import cer, wer
from tqdm import tqdm
import logging
import os
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR

# Turn off the HuggingFace tokenizers library's internal parallelism
# (presumably to avoid its fork warnings once DataLoader workers are used).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Setup: root logger at INFO (force=True replaces any handlers the notebook
# kernel already installed), then select the GPU when one is visible.
logging.basicConfig(force=True, level=logging.INFO)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
_device_banner = f"Using device: {device}"
logging.info(_device_banner)
print(_device_banner)

# Data augmentation
def rotate_image(image, angle):
    """Rotate *image* by *angle* degrees about its integer centre, keeping its size.

    Uses OpenCV's convention (positive angle = counter-clockwise, scale 1.0).
    Areas uncovered by the rotation get ``cv2.warpAffine``'s default constant
    border, i.e. 0 (black).
    """
    height, width = image.shape[:2]
    rotation = cv2.getRotationMatrix2D((width // 2, height // 2), angle, 1.0)
    return cv2.warpAffine(image, rotation, (width, height))

def add_noise(image, sigma=10.0):
    """Return a copy of *image* with additive zero-mean Gaussian noise.

    The noise is added in floating point and the sum is clipped to [0, 255]
    before rounding back to uint8. Casting the noise to uint8 *before* adding
    (the previous approach) wraps negative samples to values near 255, so the
    saturating add pushed roughly half of all pixels to white instead of
    applying small +/- perturbations.

    Args:
        image: uint8 image array of any shape.
        sigma: standard deviation of the noise, in intensity units (default 10).

    Returns:
        uint8 array with the same shape as *image*.
    """
    noise = np.random.normal(0.0, sigma, image.shape)
    noisy = np.clip(image.astype(np.float64) + noise, 0, 255)
    return np.rint(noisy).astype(np.uint8)

# Custom dataset
class OCRDataset(Dataset):
    """Map-style dataset of preprocessed handwriting images and tokenized labels.

    Each item of the wrapped ``dataset`` must provide ``"image"`` (a PIL image
    or an array accepted by ``PIL.Image.fromarray``) and ``"text"``. Samples
    whose text is empty/non-string, or whose image cannot be decoded, are
    dropped once at construction time.

    Args:
        dataset: indexable source dataset (e.g. a HuggingFace ``Dataset``).
        processor: ``TrOCRProcessor`` used for pixel normalization and label
            tokenization.
        augment: apply random rotation (uniform in +/-10 degrees) and Gaussian
            noise. Defaults to True for backward compatibility; pass False for
            validation/test data so their losses and metrics are deterministic.
        max_target_length: token length labels are padded/truncated to.
    """

    def __init__(self, dataset, processor, augment=True, max_target_length=128):
        self.dataset = dataset
        self.processor = processor
        self.augment = augment
        self.max_target_length = max_target_length
        self.valid_indices = self._validate_dataset()

    @staticmethod
    def _to_rgb(raw_image):
        """Return *raw_image* (PIL image or array) as an RGB ``PIL.Image``."""
        image = raw_image if isinstance(raw_image, Image.Image) else Image.fromarray(raw_image)
        # Convert PIL inputs too: an "L"/"RGBA" image would otherwise break the
        # RGB->GRAY conversion in __getitem__.
        return image.convert("RGB")

    def _validate_dataset(self):
        """Return the indices of samples with a non-empty string label and image."""
        valid_indices = []
        for idx in range(len(self.dataset)):
            try:
                item = self.dataset[idx]
                image = self._to_rgb(item["image"])
                text = item["text"]
                img_array = np.array(image)
            except Exception as exc:  # undecodable/corrupt sample: skip, but leave a trace
                logging.debug(f"Skipping sample {idx}: {exc}")
                continue
            if not text or not isinstance(text, str):
                continue
            if img_array.size == 0:
                continue
            valid_indices.append(idx)
        logging.info(f"Validated dataset: {len(valid_indices)}/{len(self.dataset)} samples valid")
        print(f"Validated dataset: {len(valid_indices)}/{len(self.dataset)} samples valid")
        return valid_indices

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """Return ``{"pixel_values", "labels", "text"}`` for the idx-th valid sample.

        ``labels`` has length ``max_target_length``; padding positions are set
        to -100 so the model's cross-entropy loss ignores them.
        """
        item = self.dataset[self.valid_indices[idx]]
        image = self._to_rgb(item["image"])
        text = item["text"]

        # Deterministic binarization pipeline, applied to every split.
        img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
        img = cv2.GaussianBlur(img, (3, 3), 0)
        img = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2)
        img = cv2.convertScaleAbs(img, alpha=1.2, beta=10)
        # Random augmentation only when requested (training split).
        if self.augment:
            img = rotate_image(img, angle=np.random.uniform(-10, 10))
            img = add_noise(img)
        img = cv2.resize(img, (384, 384))  # Per report
        image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_GRAY2RGB))

        tokenizer = self.processor.tokenizer
        pixel_values = self.processor(image, return_tensors="pt").pixel_values.squeeze(0)
        labels = tokenizer(
            text, padding="max_length", max_length=self.max_target_length, truncation=True, return_tensors="pt"
        ).input_ids.squeeze(0)
        # -100 is the ignore_index of the model's CrossEntropyLoss; leaving the
        # pad id here would make most of the loss "predict padding".
        labels[labels == tokenizer.pad_token_id] = -100

        return {"pixel_values": pixel_values, "labels": labels, "text": text}

# Load datasets (full IAM)
def load_datasets(processor=None):
    """Load IAM sentences, shuffle (seed 42) and split 80/10/10 into train/val/test.

    Args:
        processor: ``TrOCRProcessor`` for the datasets. When omitted, the
            module-level ``processor`` global (set by ``main``) is used.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)`` as ``OCRDataset``s; only
        the training split is augmented.
    """
    if processor is None:
        processor = globals().get("processor")
        if processor is None:
            raise RuntimeError("load_datasets() needs a processor: pass one or set the module-level `processor`")

    iam_dataset = load_dataset("alpayariyak/IAM_Sentences", split="train").shuffle(seed=42)

    total_size = len(iam_dataset)
    train_end = int(0.8 * total_size)  # 80%
    val_end = train_end + int(0.1 * total_size)  # next 10%; the remainder (~10%) is test

    train_dataset = OCRDataset(iam_dataset.select(range(train_end)), processor, augment=True)
    val_dataset = OCRDataset(iam_dataset.select(range(train_end, val_end)), processor, augment=False)
    test_dataset = OCRDataset(iam_dataset.select(range(val_end, total_size)), processor, augment=False)

    logging.info(f"Dataset sizes: Train={len(train_dataset)}, Val={len(val_dataset)}, Test={len(test_dataset)}")
    print(f"Dataset sizes: Train={len(train_dataset)}, Val={len(val_dataset)}, Test={len(test_dataset)}")
    return train_dataset, val_dataset, test_dataset

# Learning rate scheduler with warmup
def get_scheduler(optimizer, num_warmup_steps, total_steps):
    """Build the (warmup, cosine) scheduler pair used by the training loop.

    ``warmup`` is a LambdaLR that ramps the LR linearly from 0 to the base LR
    over ``num_warmup_steps`` steps and then holds it. ``cosine`` is a
    CosineAnnealingLR spanning the remaining ``total_steps - num_warmup_steps``
    steps. The caller steps exactly one of them per optimizer step.

    Construction order matters: LambdaLR is created first (setting the LR to
    0 and recording ``initial_lr``), then CosineAnnealingLR attaches to the
    same optimizer and leaves the current LR untouched.
    """
    ramp_denominator = float(max(1, num_warmup_steps))

    def _linear_ramp(step):
        return float(step) / ramp_denominator if step < num_warmup_steps else 1.0

    warmup_sched = LambdaLR(optimizer, _linear_ramp)
    cosine_sched = CosineAnnealingLR(optimizer, T_max=total_steps - num_warmup_steps)
    return warmup_sched, cosine_sched

# Fine-tune model
def _validation_loss(model, val_loader):
    """Return the mean per-batch loss of *model* over *val_loader*, without gradients."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for batch in val_loader:
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)
            with autocast():
                total += model(pixel_values=pixel_values, labels=labels).loss.item()
    return total / max(1, len(val_loader))


def _save_checkpoint(model, processor, path, epoch_number):
    """Save *model* and *processor* to *path*; log (rather than raise) on failure."""
    try:
        model.save_pretrained(path)
        processor.save_pretrained(path)
        logging.info(f"Model saved to {path}")
        print(f"Model saved to {path}")
    except Exception as e:
        logging.error(f"Error saving model at epoch {epoch_number}: {str(e)}")
        print(f"Error saving model at epoch {epoch_number}: {str(e)}")


def fine_tune_model(model, train_loader, val_loader, processor, epochs=20, lr=5e-5, *,
                    patience=10, max_grad_norm=1.0, save_every=5,
                    checkpoint_dir="/kaggle/working", restore_best=True):
    """Fine-tune *model* with AdamW, mixed precision, warmup+cosine LR and early stopping.

    Args:
        model: model whose forward accepts ``pixel_values`` and ``labels`` and
            returns an object with ``.loss``.
        train_loader / val_loader: loaders yielding dicts with
            ``"pixel_values"`` and ``"labels"`` tensors.
        processor: saved alongside each checkpoint.
        epochs: maximum number of epochs (the first 3 are LR warmup).
        lr: peak learning rate.
        patience: epochs without a validation-loss improvement before stopping.
        max_grad_norm: gradient-norm clipping threshold (on unscaled gradients).
        save_every: save a checkpoint every this many epochs (0/None disables).
        checkpoint_dir: directory receiving ``fine_tuned_trocr_epoch_<n>``.
        restore_best: reload the weights from the epoch with the lowest
            validation loss before returning (kept as a CPU copy).
    """
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scaler = GradScaler()

    steps_per_epoch = len(train_loader)
    num_warmup_steps = 3 * steps_per_epoch  # 3-epoch warmup
    total_steps = epochs * steps_per_epoch
    warmup_scheduler, cosine_scheduler = get_scheduler(optimizer, num_warmup_steps, total_steps)

    best_val_loss = float("inf")
    best_state = None
    patience_counter = 0
    global_step = 0

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)
            with autocast():
                loss = model(pixel_values=pixel_values, labels=labels).loss
            scaler.scale(loss).backward()
            # The gradients are still multiplied by the loss scale here; unscale
            # them first so clipping acts on the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            train_loss += loss.item()
            # Exactly one scheduler advances per step: warmup first, then cosine.
            if global_step < num_warmup_steps:
                warmup_scheduler.step()
            else:
                cosine_scheduler.step()
            global_step += 1

        train_loss /= max(1, steps_per_epoch)
        logging.info(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f}")
        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f}")

        val_loss = _validation_loss(model, val_loader)
        logging.info(f"Epoch {epoch+1}/{epochs} - Val Loss: {val_loss:.4f}")
        print(f"Epoch {epoch+1}/{epochs} - Val Loss: {val_loss:.4f}")

        # Early stopping on validation loss; remember the best weights.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            if restore_best:
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logging.info(f"Stopping early at epoch {epoch+1}, val loss {val_loss:.4f}")
                print(f"Stopping early at epoch {epoch+1}, val loss {val_loss:.4f}")
                break

        if save_every and (epoch + 1) % save_every == 0:
            _save_checkpoint(model, processor, f"{checkpoint_dir}/fine_tuned_trocr_epoch_{epoch+1}", epoch + 1)

    if restore_best and best_state is not None:
        model.load_state_dict(best_state)
        logging.info(f"Restored best model (val loss {best_val_loss:.4f})")
        print(f"Restored best model (val loss {best_val_loss:.4f})")

# Evaluate model
def evaluate_model(model, test_loader, processor, max_new_tokens=128):
    """Generate transcriptions for *test_loader* and return ``(CER, WER)``.

    Args:
        model: model exposing ``generate(pixel_values, ...)``.
        test_loader: loader yielding dicts with ``"pixel_values"`` and the raw
            ``"text"`` ground truths.
        processor: used to decode generated token ids.
        max_new_tokens: generation length limit. It matches the 128-token
            label length used in training. Without it, ``generate`` falls back
            to the model's generation config, which may be the library default
            of 20 tokens and would truncate sentence-level lines (TODO confirm
            for this checkpoint).

    Returns:
        ``(cer_score, wer_score)``, or ``(None, None)`` if the metrics cannot be
        computed (e.g. every batch failed). Batches that raise are logged and
        skipped, so predictions and ground truths stay aligned.
    """
    model.eval()
    predictions = []
    ground_truths = []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            try:
                pixel_values = batch["pixel_values"].to(device)
                outputs = model.generate(pixel_values, max_new_tokens=max_new_tokens)
                preds = processor.batch_decode(outputs, skip_special_tokens=True)
            except Exception as e:
                logging.error(f"Error in evaluation batch: {str(e)}")
                print(f"Error in evaluation batch: {str(e)}")
                continue
            # Placeholder for empty hypotheses, presumably so jiwer does not
            # reject them; note it adds insertion errors to CER.
            predictions.extend(p if p else "<empty>" for p in preds)
            ground_truths.extend(batch["text"])

    try:
        cer_score = cer(ground_truths, predictions)
        wer_score = wer(ground_truths, predictions)
    except Exception:
        logging.exception("Error computing metrics")
        print("Error computing metrics")
        return None, None
    logging.info(f"Test CER: {cer_score:.4f}, WER: {wer_score:.4f}")
    print(f"Test CER: {cer_score:.4f}, WER: {wer_score:.4f}")
    return cer_score, wer_score

# Main
def main():
    """Fine-tune microsoft/trocr-base-handwritten on IAM, save it, then report test CER/WER."""
    global processor  # load_datasets() reads the module-level `processor`
    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
    model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

    tokenizer = processor.tokenizer
    if model.config.decoder_start_token_id is None:
        # Explicit None test: token id 0 is valid but falsy, so `a or b` would
        # wrongly skip it.
        model.config.decoder_start_token_id = (
            tokenizer.cls_token_id if tokenizer.cls_token_id is not None else tokenizer.bos_token_id
        )
    if model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    train_dataset, val_dataset, test_dataset = load_datasets()

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1)
    test_loader = DataLoader(test_dataset, batch_size=1)

    fine_tune_model(model, train_loader, val_loader, processor)

    output_dir = "/kaggle/working/fine_tuned_trocr"
    try:
        model.save_pretrained(output_dir)
        processor.save_pretrained(output_dir)
        logging.info(f"Model and processor saved to {output_dir}")
        print(f"Model and processor saved to {output_dir}")
    except Exception as e:
        logging.error(f"Error saving model: {str(e)}")
        print(f"Error saving model: {str(e)}")

    cer_score, wer_score = evaluate_model(model, test_loader, processor)
    if cer_score is not None and wer_score is not None:
        logging.info(f"Final CER: {cer_score:.4f}, WER: {wer_score:.4f}")
        print(f"Final CER: {cer_score:.4f}, WER: {wer_score:.4f}")
    else:
        logging.error("Evaluation failed")
        print("Evaluation failed")

# In a notebook cell __name__ is also "__main__", so this runs on cell execution.
if __name__ == "__main__":
    main()


Found existing installation: torch 2.3.0+cu121
Uninstalling torch-2.3.0+cu121:
  Successfully uninstalled torch-2.3.0+cu121
Found existing installation: torchvision 0.18.0+cu121
Uninstalling torchvision-0.18.0+cu121:
  Successfully uninstalled torchvision-0.18.0+cu121
[0mFound existing installation: pillow 11.0.0
Uninstalling pillow-11.0.0:
  Successfully uninstalled pillow-11.0.0
Found existing installation: transformers 4.39.3
Uninstalling transformers-4.39.3:
  Successfully uninstalled transformers-4.39.3
Found existing installation: datasets 3.6.0
Uninstalling datasets-3.6.0:
  Successfully uninstalled datasets-3.6.0
Found existing installation: opencv-python 4.11.0.86
Uninstalling opencv-python-4.11.0.86:
  Successfully uninstalled opencv-python-4.11.0.86
Found existing installation: jiwer 3.1.0
Uninstalling jiwer-3.1.0:
  Successfully uninstalled jiwer-3.1.0
Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch==2.3.0+cu121
  Using cached https://download.p

INFO:root:Using device: cuda


Using device: cuda


Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:root:Validated dataset: 4530/4530 samples valid


Validated dataset: 4530/4530 samples valid


INFO:root:Validated dataset: 566/566 samples valid


Validated dataset: 566/566 samples valid


INFO:root:Validated dataset: 567/567 samples valid
INFO:root:Dataset sizes: Train=4530, Val=566, Test=567


Validated dataset: 567/567 samples valid
Dataset sizes: Train=4530, Val=566, Test=567


Epoch 1/20: 100%|██████████| 4530/4530 [33:29<00:00,  2.25it/s]
INFO:root:Epoch 1/20 - Train Loss: 1.2547


Epoch 1/20 - Train Loss: 1.2547


INFO:root:Epoch 1/20 - Val Loss: 0.8010


Epoch 1/20 - Val Loss: 0.8010


Epoch 2/20: 100%|██████████| 4530/4530 [33:34<00:00,  2.25it/s]
INFO:root:Epoch 2/20 - Train Loss: 0.7747


Epoch 2/20 - Train Loss: 0.7747


INFO:root:Epoch 2/20 - Val Loss: 0.7539


Epoch 2/20 - Val Loss: 0.7539


Epoch 3/20: 100%|██████████| 4530/4530 [33:34<00:00,  2.25it/s]
INFO:root:Epoch 3/20 - Train Loss: 0.7634


Epoch 3/20 - Train Loss: 0.7634


INFO:root:Epoch 3/20 - Val Loss: 0.7751


Epoch 3/20 - Val Loss: 0.7751


Epoch 4/20: 100%|██████████| 4530/4530 [33:38<00:00,  2.24it/s]
INFO:root:Epoch 4/20 - Train Loss: 0.6560


Epoch 4/20 - Train Loss: 0.6560


INFO:root:Epoch 4/20 - Val Loss: 0.6268


Epoch 4/20 - Val Loss: 0.6268


Epoch 5/20: 100%|██████████| 4530/4530 [33:11<00:00,  2.28it/s]
INFO:root:Epoch 5/20 - Train Loss: 0.5176


Epoch 5/20 - Train Loss: 0.5176


INFO:root:Epoch 5/20 - Val Loss: 0.5932


Epoch 5/20 - Val Loss: 0.5932


INFO:root:Model saved to /kaggle/working/fine_tuned_trocr_epoch_5


Model saved to /kaggle/working/fine_tuned_trocr_epoch_5


Epoch 6/20:  11%|█▏        | 513/4530 [03:42<27:45,  2.41it/s]