In [None]:

import os
import torch
import evaluate
import numpy as np
import pandas as pd
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

from PIL import Image
from tqdm.notebook import tqdm
from dataclasses import dataclass
from torch.utils.data import Dataset
from torchvision import datasets
from transformers import (
    VisionEncoderDecoderModel,
    TrOCRProcessor,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator
)

# Notebook-wide plotting defaults: figure size is (width, height).
plt.rcParams.update({'figure.figsize': (12, 9)})
block_plot = False  # plotting flag; not read anywhere in the visible cells

def seed_everything(seed_value):
    """Seed every random number generator this notebook can draw from.

    Args:
        seed_value (int): Seed applied to Python's ``random`` module, NumPy's
            global RNG, and PyTorch (CPU generator plus all CUDA devices).

    Side effects:
        Switches cuDNN to deterministic mode and disables its benchmark
        autotuner.
    """
    # Local import: the notebook's import cell does not bring in ``random``.
    import random

    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)

    # Favor repeatable kernel selection over autotuned speed.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Fix RNG state first, then pick the compute device (GPU when present, else CPU).
seed_everything(42)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

@dataclass(frozen=True)
class TrainingConfig:
    """Immutable bundle of fine-tuning hyperparameters."""

    # Per-device batch size, used for both training and evaluation batches.
    BATCH_SIZE: int = 30
    # Number of passes over the training set.
    EPOCHS: int = 5
    # Learning rate for the optimizer.
    LEARNING_RATE: float = 1e-4

@dataclass(frozen=True)
class ModelConfig:
    """Immutable holder for the pretrained checkpoint identifier."""

    # Checkpoint name handed to the processor's and model's `from_pretrained`.
    MODEL_NAME: str = "microsoft/trocr-small-handwritten"

def _fix_emnist_orientation(img):
    """Rotate a PIL image 90 degrees clockwise, then mirror it left-to-right."""
    return img.rotate(-90).transpose(Image.FLIP_LEFT_RIGHT)


# Orientation correction followed by grayscale -> 3-channel conversion
_orientation_steps = [
    transforms.Lambda(_fix_emnist_orientation),
    transforms.Grayscale(num_output_channels=3),
]
emnist_correction = transforms.Compose(_orientation_steps)

# Augmentations: brightness jitter, then Gaussian blur with a sampled sigma
_augmentation_steps = [
    transforms.ColorJitter(brightness=0.5),
    transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 3)),
]
train_transforms = transforms.Compose(_augmentation_steps)

class EMNISTDataset(Dataset):
    """Adapt an EMNIST split into samples of pixel values plus token labels.

    Each item is a dict with ``"pixel_values"`` (the processor's output for the
    image, batch axis removed) and ``"labels"`` (token ids of the character,
    padded to ``max_target_length`` with padding positions set to -100).

    Args:
        emnist_dataset: Indexable dataset yielding ``(image, label_index)``.
        processor: Callable image processor that also exposes ``.tokenizer``.
        max_target_length (int): Length the label sequence is padded to.
        augment (bool | None): Whether to apply ``train_transforms`` to each
            image. ``None`` (default) reads the wrapped dataset's ``train``
            attribute and augments only when it is truthy; if the attribute is
            missing, augmentation stays on, matching the previous behavior.
    """

    def __init__(self, emnist_dataset, processor, max_target_length=4, augment=None):
        self.emnist_dataset = emnist_dataset
        self.processor = processor
        self.max_target_length = max_target_length
        if augment is None:
            # Random augmentation on an evaluation split makes the metric noisy,
            # so follow the wrapped dataset's own train/test flag by default.
            augment = bool(getattr(emnist_dataset, 'train', True))
        self.augment = augment
        self.label_to_char = self._create_label_mapping()

    def _create_label_mapping(self):
        """Map label indices 0-61 to '0'-'9', then 'A'-'Z', then 'a'-'z'."""
        digits = [str(d) for d in range(10)]
        uppercase = [chr(ord('A') + i) for i in range(26)]
        lowercase = [chr(ord('a') + i) for i in range(26)]
        return dict(enumerate(digits + uppercase + lowercase))

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.emnist_dataset)

    def __getitem__(self, idx):
        """Build the model inputs for sample ``idx``.

        Returns:
            dict: ``{"pixel_values": Tensor, "labels": Tensor}``.
        """
        image, label_idx = self.emnist_dataset[idx]
        if self.augment:
            image = train_transforms(image)
        pixel_values = self.processor(image, return_tensors='pt').pixel_values
        text = self.label_to_char[label_idx]

        token_ids = self.processor.tokenizer(
            text,
            padding='max_length',
            max_length=self.max_target_length
        ).input_ids

        # Replace padding ids with -100 (PyTorch's default ignore_index for
        # cross-entropy) so padded positions do not contribute to the loss.
        pad_id = self.processor.tokenizer.pad_token_id
        labels = [tok if tok != pad_id else -100 for tok in token_ids]
        return {
            # squeeze(0): drop only the leading batch axis added by the processor
            "pixel_values": pixel_values.squeeze(0),
            "labels": torch.tensor(labels)
        }

# Load EMNIST datasets: both 'byclass' splits share everything but the train flag
_emnist_kwargs = dict(
    root='./data',
    split='byclass',
    download=True,
    transform=emnist_correction,
)
emnist_train = datasets.EMNIST(train=True, **_emnist_kwargs)
emnist_test = datasets.EMNIST(train=False, **_emnist_kwargs)

# The processor is created first because both dataset wrappers take it
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
train_dataset, valid_dataset = (
    EMNISTDataset(emnist_train, processor),
    EMNISTDataset(emnist_test, processor),
)

# Initialize model
model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
model.to(device)

# Special-token ids and vocabulary size, taken from the processor's tokenizer
# and the decoder's own config
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id

# Generation settings
model.config.max_length = 4  # Adjusted for single character prediction
model.config.num_beams = 1
# early_stopping only affects beam search, so with num_beams = 1 it has no
# effect on decoding. NOTE(review): recent transformers releases reportedly
# reject early_stopping=True together with num_beams=1 when the generation
# config is validated on save — confirm against the pinned version.
model.config.early_stopping = False
model.config.no_repeat_ngram_size = 0
model.config.length_penalty = 1.0

# Metrics
cer_metric = evaluate.load('cer')

def compute_cer(pred):
    """Compute the character error rate between predictions and references.

    Args:
        pred: Object exposing ``predictions`` (generated token ids, or a tuple
            whose first element holds them) and ``label_ids`` (reference token
            ids in which -100 marks ignored positions).

    Returns:
        dict: ``{"cer": value}`` as returned by ``cer_metric.compute``.
    """
    pad_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 is not a valid token id, so swap it for the pad id before decoding.
    # np.where builds new arrays, leaving the caller's arrays untouched.
    pred_ids = np.where(np.asarray(pred_ids) == -100, pad_id, pred_ids)
    labels_ids = np.where(np.asarray(pred.label_ids) == -100, pad_id, pred.label_ids)

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}

# Training arguments
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    # NOTE(review): newer transformers releases rename this to `eval_strategy`;
    # kept as-is for the version pinned by this notebook — confirm on upgrade.
    evaluation_strategy='epoch',
    per_device_train_batch_size=TrainingConfig.BATCH_SIZE,
    per_device_eval_batch_size=TrainingConfig.BATCH_SIZE,
    # Use the configured rate instead of silently falling back to the library default.
    learning_rate=TrainingConfig.LEARNING_RATE,
    # Mixed precision only when a CUDA device exists, mirroring the CPU
    # fallback used when `device` is selected.
    fp16=torch.cuda.is_available(),
    output_dir='trocr-emnist-model',
    logging_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=3,
    report_to='tensorboard',
    num_train_epochs=TrainingConfig.EPOCHS
)

# Initialize trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_cer,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=default_data_collator  # Use the default data collator
)

# Start training
trainer.train()

# Save the final model and processor side by side so they can be reloaded together
model.save_pretrained("trocr-emnist-final")
processor.save_pretrained("trocr-emnist-final")

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-small-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss


In [7]:
!pip install -q sentencepiece
!pip install -q jiwer
!pip install -q datasets
!pip install -q evaluate
!pip install -q -U accelerate

!pip install -q matplotlib
!pip install -q protobuf==3.20.1
!pip install -q tensorboard



In [None]:
!pip install transformers==4.44