In [None]:
# Install necessary libraries
!pip install transformers==4.45.0
!pip install -q sentencepiece
!pip install -q jiwer
!pip install -q datasets
!pip install -q evaluate
!pip install -q -U accelerate

!pip install -q matplotlib
!pip install -q protobuf==3.20.1
!pip install -q tensorboard

In [None]:
import os
import torch
import evaluate
import numpy as np
import pandas as pd
import glob as glob
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

from PIL import Image
from zipfile import ZipFile
from tqdm.notebook import tqdm
from dataclasses import dataclass
from torch.utils.data import Dataset
from urllib.request import urlretrieve
from transformers import (
    VisionEncoderDecoderModel,
    TrOCRProcessor,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
    get_scheduler
)

In [None]:
# Plot display settings used by the rest of the notebook.
block_plot = False
plt.rcParams['figure.figsize'] = (12, 9)

# ANSI escape sequences for bold console text and for resetting the style.
# Plain string literals are enough: nothing is interpolated into them.
bold = "\033[1m"
reset = "\033[0m"

In [None]:
# Seed setting for reproducibility
def seed_everything(seed_value):
    """Seed every random number generator this notebook can draw from.

    Seeds Python's built-in ``random`` module, NumPy's global generator and
    PyTorch (CPU and all CUDA devices), and switches cuDNN to deterministic
    mode with autotuning disabled.

    Args:
        seed_value: Integer seed applied to every generator.
    """
    # Local import: the notebook's import cell does not bring in `random`.
    import random

    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    # Trade some speed for run-to-run repeatability of cuDNN kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [None]:
# Download and unzip the dataset
def download_and_unzip(url, save_path):
    """Download a zip archive and extract it next to where it is saved.

    Args:
        url: Location of the archive to fetch.
        save_path: File path the downloaded archive is written to. Its parent
            directory is also the extraction target.

    Extraction problems are reported on stdout instead of being raised. When
    the downloaded file is not a valid zip archive it is deleted again, so a
    later "download only if the file is missing" check retries instead of
    skipping over a corrupt file forever.
    """
    # Local import: the notebook's import cell only brings in `ZipFile`.
    from zipfile import BadZipFile

    print("Downloading and extracting assets....", end="")

    urlretrieve(url, save_path)

    try:
        with ZipFile(save_path) as archive:
            archive.extractall(os.path.dirname(save_path))

        print("Done")

    except BadZipFile as e:
        print("\nInvalid file.", e)
        # Remove the unusable download so the next run fetches it again.
        os.remove(save_path)

    except Exception as e:
        # Any other extraction problem is reported as before; the archive is kept.
        print("\nInvalid file.", e)

In [None]:
# Fetch the dataset archive into the current working directory, but only when
# no file with that name is already there.
URL = r"https://www.dropbox.com/scl/fi/vyvr7jbdvu8o174mbqgde/scut_data.zip?rlkey=fs8axkpxunwu6if9a2su71kxs&dl=1"
asset_zip_path = os.path.join(os.getcwd(), "scut_data.zip")

archive_already_present = os.path.exists(asset_zip_path)
if not archive_already_present:
    download_and_unzip(URL, asset_zip_path)

In [None]:
# Configurations
@dataclass(frozen=True)
class TrainingConfig:
    """Immutable bundle of the training hyperparameters used in this notebook."""

    # Per-device batch size for both training and evaluation.
    BATCH_SIZE: int = 48
    # Number of passes over the training set.
    EPOCHS: int = 35
    # Learning rate given to the AdamW optimizer and the training arguments.
    LEARNING_RATE: float = 5e-5
    # Warm-up length handed to the learning-rate scheduler.
    WARMUP_STEPS: int = 2_000
    # NOTE(review): not referenced anywhere else in the visible notebook.
    LR_MAX: float = 1e-4

In [None]:
@dataclass(frozen=True)
class DatasetConfig:
    """Immutable dataset settings."""

    # Directory (relative to the working directory) that holds the annotation
    # files and image folders read later in the notebook.
    DATA_ROOT: str = "scut_data"

@dataclass(frozen=True)
class ModelConfig:
    """Immutable model settings."""

    # Checkpoint identifier passed to `from_pretrained` for both the processor
    # and the model.
    MODEL_NAME: str = "microsoft/trocr-small-printed"

# Augmentations
# Two image augmentations, applied in this order: colour jitter, then blur.
color_jitter = transforms.ColorJitter(brightness=0.5, hue=0.3)
gaussian_blur = transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))
train_transforms = transforms.Compose([color_jitter, gaussian_blur])

In [None]:
# Custom Dataset
class CustomOCRDataset(Dataset):
    """Pairs each image listed in a dataframe with its tokenized transcription.

    Args:
        root_dir: Directory containing the image files.
        df: DataFrame with a ``file_name`` and a ``text`` column.
        processor: Object that turns an image into ``pixel_values`` and exposes
            a ``tokenizer`` for the text (a ``TrOCRProcessor`` in this notebook).
        max_target_length: Length every label sequence is padded/truncated to.
        augment: When True (the default, matching the previous behaviour) the
            module-level ``train_transforms`` are applied to every image. Pass
            ``augment=False`` for validation/test data so evaluation runs on
            unmodified images.
    """

    def __init__(self, root_dir, df, processor, max_target_length=128, augment=True):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length
        self.augment = augment

    def __len__(self):
        """Return the number of rows (samples) in the dataframe."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": tensor, "labels": tensor}`` for sample ``idx``."""
        # Positional lookup: `idx` is a sample position, which only coincides
        # with the index label when the dataframe has a default RangeIndex.
        file_name = self.df['file_name'].iloc[idx]
        text = self.df['text'].iloc[idx]

        # Close the file handle as soon as the pixels are converted.
        with Image.open(os.path.join(self.root_dir, file_name)) as raw_image:
            image = raw_image.convert('RGB')
        if self.augment:
            image = train_transforms(image)

        pixel_values = self.processor(image, return_tensors='pt').pixel_values

        tokenizer = self.processor.tokenizer
        token_ids = tokenizer(
            text,
            padding='max_length',
            max_length=self.max_target_length,
            # Without truncation an over-long text yields a longer sequence
            # than the others, and fixed-size batching fails.
            truncation=True
        ).input_ids
        # Replace padding with -100, the default ignore index of PyTorch's
        # cross-entropy loss, so padded positions do not contribute to the loss.
        pad_token_id = tokenizer.pad_token_id
        labels = [token if token != pad_token_id else -100 for token in token_ids]

        # Drop only the leading batch dimension added by the processor.
        return {"pixel_values": pixel_values.squeeze(0), "labels": torch.tensor(labels)}

In [None]:
# Load the processor and datasets
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)

# The annotation files are read as fixed-width text without a header row;
# column 0 is exposed as `file_name` and column 1 as `text`.
annotation_columns = {0: 'file_name', 1: 'text'}
train_annotations_path = os.path.join(DatasetConfig.DATA_ROOT, 'scut_train.txt')
test_annotations_path = os.path.join(DatasetConfig.DATA_ROOT, 'scut_test.txt')

train_df = pd.read_fwf(train_annotations_path, header=None).rename(columns=annotation_columns)
test_df = pd.read_fwf(test_annotations_path, header=None).rename(columns=annotation_columns)

# Image folders keep their trailing slash, exactly as the dataset class receives them.
train_image_dir = os.path.join(DatasetConfig.DATA_ROOT, 'scut_train/')
test_image_dir = os.path.join(DatasetConfig.DATA_ROOT, 'scut_test/')

train_dataset = CustomOCRDataset(root_dir=train_image_dir, df=train_df, processor=processor)
valid_dataset = CustomOCRDataset(root_dir=test_image_dir, df=test_df, processor=processor)


In [None]:
# Initialize the model
# Load the pretrained checkpoint and move it to the selected device in one go.
model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME).to(device)

# Set pad_token_id
# Both special-token ids are taken from the processor's tokenizer.
ocr_tokenizer = processor.tokenizer
model.config.pad_token_id = ocr_tokenizer.pad_token_id
model.config.decoder_start_token_id = ocr_tokenizer.cls_token_id

# Set up the optimizer
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=TrainingConfig.LEARNING_RATE,
    weight_decay=5e-4,
)

In [None]:
# Learning rate scheduler
# The scheduler is stepped once per optimizer update, not once per batch. The
# trainer below accumulates gradients over several batches before each update,
# so the schedule length must be counted in updates; counting batches would make
# it several times too long and the learning rate would never decay to zero.
GRADIENT_ACCUMULATION_STEPS = 4  # keep in sync with `gradient_accumulation_steps` in Seq2SeqTrainingArguments

# Ceiling division: the last, partially filled batch of an epoch still counts.
batches_per_epoch = -(-len(train_dataset) // TrainingConfig.BATCH_SIZE)
# At least one update per epoch, even for a tiny dataset.
updates_per_epoch = max(batches_per_epoch // GRADIENT_ACCUMULATION_STEPS, 1)
num_training_steps = TrainingConfig.EPOCHS * updates_per_epoch
# NOTE(review): assumes a single training device; with several devices the
# effective batch size (and therefore this count) changes — confirm.
# NOTE(review): WARMUP_STEPS may exceed `num_training_steps` for a small
# dataset, in which case the warm-up never completes — confirm against the data size.

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=TrainingConfig.WARMUP_STEPS,
    num_training_steps=num_training_steps,
)

In [None]:
# Set up Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    # `eval_strategy` replaces the deprecated `evaluation_strategy` argument
    # in the transformers version this notebook pins.
    eval_strategy='epoch',
    per_device_train_batch_size=TrainingConfig.BATCH_SIZE,
    per_device_eval_batch_size=TrainingConfig.BATCH_SIZE,
    # Mixed precision needs a GPU; stay in full precision on the CPU fallback.
    fp16=torch.cuda.is_available(),
    gradient_accumulation_steps=4,  # Simulate larger batch size
    max_grad_norm=1.0,  # Gradient clipping
    output_dir='seq2seq_model_printed/',
    # Log, evaluate and checkpoint once per epoch; keep at most 5 checkpoints.
    logging_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=5,
    report_to='tensorboard',
    num_train_epochs=TrainingConfig.EPOCHS,
    # The notebook also builds its own AdamW optimizer with this same value
    # and hands it to the trainer.
    learning_rate=TrainingConfig.LEARNING_RATE,
)

In [None]:
# Initialize the trainer
trainer = Seq2SeqTrainer(
    # Model, run configuration and the hand-built optimizer/scheduler pair.
    model=model,
    args=training_args,
    optimizers=(optimizer, lr_scheduler),
    # Data: training split, evaluation split and how samples are batched.
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=default_data_collator,
    # Tokenizer taken from the TrOCR processor.
    tokenizer=processor.tokenizer,
)

In [None]:
# Start training
# Runs the full training loop configured above; the value returned by
# `trainer.train()` is kept in `res` for later inspection.
res = trainer.train()

In [None]:
from transformers import AutoTokenizer, VisionEncoderDecoderModel
import evaluate
from tqdm.notebook import tqdm
import torch

In [None]:
# Load CER metric
# Loaded once at module level; `compute_cer` below reads this global.
cer_metric = evaluate.load("cer")

In [None]:
# Function to compute CER on the validation set
def compute_cer(model, tokenizer, valid_dataset):
    """Return the mean per-sample character error rate over ``valid_dataset``.

    Args:
        model: Model exposing ``eval()`` and ``generate(pixel_values)``.
        tokenizer: Tokenizer exposing ``decode`` and ``pad_token_id``.
        valid_dataset: Indexable collection whose items are dicts with a
            ``"pixel_values"`` tensor and a ``"labels"`` tensor.

    Returns:
        The average of the per-sample CER values, or ``nan`` when the dataset
        is empty (there is nothing to average).
    """
    model.eval()
    total_samples = len(valid_dataset)
    if total_samples == 0:
        return float("nan")

    pad_token_id = tokenizer.pad_token_id
    cer_score = 0

    # Loop through the validation set
    for i in tqdm(range(total_samples)):
        encoding = valid_dataset[i]
        pixel_values = encoding["pixel_values"].unsqueeze(0).to(device)

        # Labels use -100 in place of padding; that is not a valid token id, so
        # restore the pad id before decoding. Work on a copy so the dataset's
        # own tensor is left untouched.
        label_ids = encoding["labels"].clone()
        label_ids[label_ids == -100] = pad_token_id

        # Get predictions from the model
        generated_ids = model.generate(pixel_values)
        pred_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        true_text = tokenizer.decode(label_ids, skip_special_tokens=True)

        # Compute CER
        cer_score += cer_metric.compute(predictions=[pred_text], references=[true_text])

    return cer_score / total_samples

In [None]:
# Calculate CER for the validation set
# The notebook never defines a standalone `tokenizer` variable; the tokenizer
# that was used to build the labels lives on the processor.
validation_cer = compute_cer(model, processor.tokenizer, valid_dataset)
print(f"Validation CER: {validation_cer}")