# Code Overview: Update these constants before running this code:
---
Before loading this data, run 'iam/dataset/process_dataset.ipynb' and 'imgur/dataset/process_dataset.ipynb'. The Data Processing section below assumes the data has been prepared exactly as those notebooks produce it.
When loading this data, update the directory constants 'IMGUR_DATA_DIR' and 'IAM_DATA_DIR' with your own directories:
- The Imgur directory should contain 11 pickle files of word data images.
- The IAM directory should contain a 'words.txt' file containing the text labels and '/words' subfolder containing additional organized subfolders of word images.

Before training the model, update the model constant 'MODEL' to the model of your choice:
- 'microsoft/trocr-base-handwritten'
- 'microsoft/trocr-base-stage1'
- 'microsoft/trocr-small-stage1'

Update the model's output directory constant 'OUTPUT_DIR' to a directory of your choice. If you are running and saving multiple models, change this output directory before each run or the save will be overwritten.

Before evaluating the model, update the saved model checkpoint directory 'CHECKPOINT_DIR' to the saved model checkpoint that you would like to evaluate.

In [None]:
# Update these constants with your own directories.
# They are module-level settings read by the later cells: data loading
# (IMGUR_DATA_DIR, IAM_DATA_DIR), model loading (MODEL), training (OUTPUT_DIR)
# and evaluation (CHECKPOINT_DIR).
IMGUR_DATA_DIR = '/home/user/imgur' # Directory of Imgur data
IAM_DATA_DIR = '/home/user/iam' # Directory of IAM data - should contain '/words' subfolder containing the word images and 'words.txt' file containing the text labels
MODEL = 'microsoft/trocr-base-handwritten' # change to the model you would like to use
OUTPUT_DIR = '/home/user/output/models' # Directory to save the model
# '####' is a placeholder; replace it with the step number of a saved checkpoint.
CHECKPOINT_DIR = '/home/user/output/models/checkpoint-####' # Directory of the saved model checkpoint that you would like to evaluate

## Step 1. Data Processing

### 1.1 Loading the Data

In [None]:
import os

# Make the Imgur data directory the working directory; the pickle files in the
# next cell are opened with paths relative to it.
os.chdir(IMGUR_DATA_DIR)

# Confirm it's changed
print("Current directory:", os.getcwd())

Imgur Data:

In [None]:
import pickle

# Names of the Imgur word pickles, in load order. The paths are relative, so
# they resolve against the current working directory.
_imgur_pickle_names = (
    './dfwords_0_20000.pkl',
    './dfwords_20000_40000.pkl',
    './dfwords_40000_60000.pkl',
    './dfwords_60000_80000.pkl',
    './dfwords_80000_100000.pkl',
    './dfwords_100000_120000.pkl',
    './dfwords_120000_140000.pkl',
    './dfwords_140000_160000.pkl',
    './dfwords_160000_180000.pkl',
    './dfwords_180000_200000.pkl',
    './dfwords_200000_227055.pkl',
)


def _read_pickle(path):
    """Return the object stored in the pickle file at ``path``."""
    # NOTE: pickle.load can execute arbitrary code; only load files that you
    # produced yourself.
    with open(path, 'rb') as file:
        return pickle.load(file)


# Bind one variable per file, in the same order as the names above.
(
    imgur_df1,
    imgur_df2,
    imgur_df3,
    imgur_df4,
    imgur_df5,
    imgur_df6,
    imgur_df7,
    imgur_df8,
    imgur_df9,
    imgur_df10,
    imgur_df11,
) = map(_read_pickle, _imgur_pickle_names)

In [None]:
import pandas as pd

# Stack all eleven chunks with a single concat call. Concatenating them one
# pair at a time copied the accumulated frame again at every step; the
# resulting frame (rows in chunk order, fresh 0..n-1 index) is the same.
imgur_df = pd.concat(
    [
        imgur_df1,
        imgur_df2,
        imgur_df3,
        imgur_df4,
        imgur_df5,
        imgur_df6,
        imgur_df7,
        imgur_df8,
        imgur_df9,
        imgur_df10,
        imgur_df11,
    ],
    ignore_index=True,
)

IAM data:

In [None]:
import pandas as pd
from PIL import Image
# Build the paths with os.path.join so they resolve on every OS (a literal
# backslash separator only works on Windows).
label_file_path = os.path.join(IAM_DATA_DIR, 'words.txt')
image_file_path = os.path.join(IAM_DATA_DIR, 'words')

# Number of leading lines of words.txt that are skipped before parsing
# (presumably the file's header - confirm against your copy of words.txt).
IAM_HEADER_LINES = 18
# Images narrower or shorter than this many pixels are discarded.
MIN_IMAGE_SIDE = 10

data = []
with open(label_file_path, 'r') as f:
    lines = f.readlines()

for idx, line in enumerate(lines[IAM_HEADER_LINES:]):
    if idx % 1000 == 0:
        print(f"Processing line {idx}")
    tokens = line.strip().split()
    if len(tokens) < 2:
        continue

    # The first token is the word id; its first two dash-separated parts name
    # the nested subfolders that hold the image, e.g. <a>/<a>-<b>/<id>.png.
    word_id = tokens[0]
    id_parts = word_id.split('-')
    if len(id_parts) < 2:
        print(f"Skipping line with unexpected id format: {word_id}")
        continue
    subfolder = id_parts[0]
    subfolder2 = subfolder + "-" + id_parts[1]
    image_path = os.path.join(image_file_path, subfolder, subfolder2, word_id + ".png")

    try:
        with Image.open(image_path) as img:
            if img.size[0] < MIN_IMAGE_SIDE or img.size[1] < MIN_IMAGE_SIDE:
                # Skip the row entirely; otherwise the image of a previous
                # row would be stored under this row's label.
                continue
            img_copy = img.convert("RGB").copy()  # detached from the open file
    except FileNotFoundError as e:
        print(f"Image file not found: {image_path}. Error: {e}")
        continue
    except Image.UnidentifiedImageError as e:
        print(f"Unidentified image error for file {image_path}: {e}")
        continue
    except Exception as e:
        print(f"Error opening image file {image_path}: {e}")
        continue

    # One row per word: id, label text, image - matching the three DataFrame
    # columns below. The last token is taken as the label text
    # (NOTE(review): confirm against the words.txt column layout).
    data.append([word_id, tokens[-1], img_copy])


if data:
    print(f"Length of a row in data: {len(data[0])}")  # Should print 3
    print(data[0])
else:
    print("No IAM samples were loaded; check IAM_DATA_DIR.")
iam_df = pd.DataFrame(data, columns=['id', 'text', 'image'])

### 1.2 Cleaning the Data

Before cleaning the data, set loaded_dfwords to the data frame with the data you would like to clean (either 'imgur_df' or 'iam_df'). This code is the same for cleaning both dataframes.

In [None]:
# This binds a second name to the same DataFrame object (no copy is made), so
# the in-place column assignment in the cleaning cell below also changes the
# frame you select here.
loaded_dfwords = df # replace 'df' with either 'imgur_df' or 'iam_df'

In [None]:
import matplotlib.pyplot as plt

def show_image(row):
    """Display the image stored in the third column of ``loaded_dfwords``.

    Args:
        row: Positional (0-based) row number of the sample to display.
    """
    picture = loaded_dfwords.iloc[row, 2]
    plt.imshow(picture)
    plt.show()

In [None]:
import re
# Whole-string pattern: word characters, whitespace and a fixed set of
# punctuation/symbol characters.
allowed_pattern = r'^[\w\s\.,!?;:\-+*/=()\[\]{}<>@#\$%^&_\'"\t\n]+$'
# Un-escape slashes that were stored as '\/'.
loaded_dfwords['text'] = loaded_dfwords['text'].str.replace('\\/', '/', regex=False)
# na=False makes rows with missing text count as "not allowed"; without it the
# result contains NaN and cannot be negated or used as a boolean mask.
is_allowed = loaded_dfwords['text'].str.contains(allowed_pattern, regex=True, na=False)
# Rows to discard (kept under its original name).
mask = ~is_allowed
loaded_dfwords = loaded_dfwords[is_allowed]

In [25]:
# Discard samples whose whole label is a single period.
_not_single_dot = loaded_dfwords['text'].ne('.')
loaded_dfwords = loaded_dfwords[_not_single_dot]

In [None]:
# Discard samples whose label is exactly this run of dashes.
_not_dash_line = loaded_dfwords['text'].ne('-----------------------------------------------------')
loaded_dfwords = loaded_dfwords[_not_dash_line]

In [None]:
# Print a summary of the cleaned frame.
loaded_dfwords.info()

### 1.3 Splitting the Data into Training and Testing Subsets

In [29]:
import numpy as np

# Get unique groups: every distinct value of the 'id' column
unique_images = loaded_dfwords['id'].unique()


# Randomly select 20% of the unique ids for the test set (the seed is fixed so
# the same ids are chosen on every run)
np.random.seed(42)
test_images = np.random.choice(unique_images, 
                              size=int(len(unique_images)*0.2), 
                              replace=False)

In [30]:
# Rows whose id was drawn for the test set; computed once and reused for both
# halves of the split.
in_test = loaded_dfwords['id'].isin(test_images)
# Give the test frame a fresh 0..n-1 index: the inference code in step 4.3
# addresses rows as df[column][i] with i counted from 0, which only finds the
# right rows when the index labels equal the row positions.
test_df = loaded_dfwords[in_test].reset_index(drop=True)
training_df = loaded_dfwords[~in_test]

### 1.4 Splitting the Training Data into Training and Validation Subsets

In [34]:
from sklearn.model_selection import train_test_split
import pandas as pd

# Hold out 20% of the remaining rows for validation (fixed seed), then give
# both parts a fresh 0..n-1 index.
train_df, eval_df = (
    part.reset_index(drop=True)
    for part in train_test_split(training_df, test_size=0.2, random_state=42)
)

## Step 2. Running the Model

Reminder to update model constant 'MODEL' with the model of your choice.

### 2.1 Loading the Model

In [39]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

In [None]:
# get base model: the processor and the encoder-decoder weights are both
# loaded from the identifier stored in MODEL
processor = TrOCRProcessor.from_pretrained(MODEL)
model = VisionEncoderDecoderModel.from_pretrained(MODEL)

### 2.2 Loading the Datasets

In [41]:
import torch
from torch.utils.data import Dataset
from PIL import Image

class StyleDataset(Dataset):
    """Dataset that turns rows of a word-image DataFrame into model inputs.

    Each item is a dict with two tensors:

    * ``pixel_values``: the processed image (3-dimensional).
    * ``labels``: the tokenized text padded to ``max_target_length``, with
      every padding position replaced by -100.

    Args:
        df: DataFrame with ``id``, ``text`` and ``image`` columns.
        processor: Callable that maps an image to ``pixel_values`` and exposes
            a ``tokenizer`` attribute (a ``TrOCRProcessor`` in this notebook).
        max_target_length: Length every label sequence is padded to.
    """

    def __init__(self, df, processor, max_target_length=512):
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the encoded sample at position ``idx``.

        Raises:
            IndexError: If ``idx`` is outside the dataset.
            ValueError: If the sample cannot be encoded. Diagnostics are
                printed first. The error is raised instead of returning None,
                because a None sample only fails later, inside the data
                collator, with a far less helpful message.
        """
        try:
            return self._encode(idx)
        except IndexError:
            # Propagate untouched so plain iteration over the dataset can
            # terminate normally.
            raise
        except Exception as e:
            self._report_failure(idx, e)
            raise

    def _encode(self, idx):
        """Build the ``pixel_values``/``labels`` dict for position ``idx``."""
        # Positional access (.iloc) works whatever index labels the frame
        # carries, e.g. after filtering without reset_index.
        text = self.df['text'].iloc[idx]
        if not isinstance(text, str) or not text.strip():
            raise ValueError(f"Invalid text at index {idx}: {repr(text)}")
        image_id = self.df['id'].iloc[idx]
        try:
            image = self.df['image'].iloc[idx]
        except Exception as e:
            raise ValueError(f"Failed to load image for ID {image_id} at index {idx}") from e
        try:
            pixel_values = self.processor(image, return_tensors="pt").pixel_values
        except Exception as e:
            raise ValueError(f"Image processing failed at index {idx}") from e

        if torch.isnan(pixel_values).any() or torch.isinf(pixel_values).any():
            raise ValueError(f"Invalid pixel values (NaN/inf) at index {idx}")
        try:
            labels = self.processor.tokenizer(
                text,
                padding="max_length",
                max_length=self.max_target_length
            ).input_ids
        except Exception as e:
            raise ValueError(f"Tokenization failed for text at index {idx}") from e

        # Replace pad_token_id with -100 for loss masking
        pad_token_id = self.processor.tokenizer.pad_token_id
        labels = [label if label != pad_token_id else -100 for label in labels]
        encoding = {
            # Drop only the leading batch dimension added by the processor.
            "pixel_values": pixel_values.squeeze(0),
            "labels": torch.tensor(labels)
        }

        if encoding["pixel_values"].dim() != 3:
            raise ValueError(f"Invalid pixel_values shape at index {idx}")

        if encoding["labels"].numel() != self.max_target_length:
            raise ValueError(f"Labels length mismatch at index {idx}")

        return encoding

    def _report_failure(self, idx, e):
        """Print what went wrong while encoding the sample at ``idx``."""
        print(f"\nError in sample {idx}:")
        print(f"   Error type: {type(e).__name__}")
        print(f"   Details: {str(e)}")
        if e.__cause__:
            print(f"   Underlying error: {type(e.__cause__).__name__}: {str(e.__cause__)}")
        try:
            row = self.df.iloc[idx]
        except (IndexError, TypeError):
            # The row itself is unavailable; the messages above must suffice.
            return
        print(f"   DataFrame row:\n{row}")

In [42]:
# Wrap the training and validation frames; images and labels are encoded
# lazily, one sample at a time, when an item is requested.
train_dataset = StyleDataset(df=train_df,processor=processor)
eval_dataset= StyleDataset(df=eval_df,processor=processor)

In [None]:
# Report how many samples each split contains.
print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))

In [None]:
# get the labeled word from dataset encoding
def get_label_str(encoding):
    """Decode the ``labels`` tensor of a dataset item back into text.

    Args:
        encoding: Dict with a ``labels`` tensor in which padding positions
            hold -100.

    Returns:
        The decoded label string with special tokens removed.
    """
    # Work on a copy so the caller's tensor keeps its -100 loss mask.
    labels = encoding['labels'].clone()
    labels[labels == -100] = processor.tokenizer.pad_token_id
    return processor.decode(labels, skip_special_tokens=True)

In [None]:
# Sanity check: decode the label of the first sample of each split.
print(get_label_str(train_dataset[0]))
print(get_label_str(eval_dataset[0]))

### 2.3 Model Configuration

In [None]:
# Analyze your dataset first
# Length of every training label, computed once and reused below.
text_lengths = training_df['text'].apply(len)

avg_target_len = text_lengths.mean()
print("average target length", avg_target_len)

# 95th percentile of the label lengths, truncated to an integer.
max_target_len = int(text_lengths.quantile(0.95))
print("maximum target length", max_target_len)

In [None]:
from transformers import GenerationConfig

# Generation settings; the same object is attached to the model and passed to
# the training arguments in later cells.
# NOTE(review): max_length is fixed at 12 rather than derived from the
# max_target_len computed in the previous cell - confirm this is intended.
generation_config = GenerationConfig(
    max_length=12,
    num_beams=4,
    early_stopping=True,
    length_penalty=1.0,
    repetition_penalty=1.5,
    no_repeat_ngram_size=3,
    decoder_start_token_id=processor.tokenizer.cls_token_id
)

In [None]:
# Token Alignment
# set special tokens used for creating the decoder_input_ids from the labels;
# both ids are taken from the processor's tokenizer
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = len(processor.tokenizer)
# use the generation settings defined in the previous cell
model.generation_config = generation_config

### 2.4 Metrics

In [50]:
from evaluate import load
cer_metric = load("cer")

def compute_metrics(pred):
    """Compute the character error rate (CER) of a batch of predictions.

    Args:
        pred: Object with ``label_ids`` and ``predictions`` arrays of token
            ids, in which -100 marks positions to ignore.

    Returns:
        Dict with a single ``"cer"`` entry.
    """
    import numpy as np  # local import keeps this cell self-contained

    pad_token_id = processor.tokenizer.pad_token_id
    # -100 is not a real token id and cannot be decoded, so replace it with
    # the pad id. np.where builds new arrays, leaving pred's own arrays
    # untouched. Predictions get the same treatment as labels.
    # NOTE(review): the trainer appears to pad gathered predictions with -100
    # as well - confirm against the transformers documentation.
    labels_ids = np.asarray(pred.label_ids)
    labels_ids = np.where(labels_ids == -100, pad_token_id, labels_ids)
    pred_ids = np.asarray(pred.predictions)
    pred_ids = np.where(pred_ids == -100, pad_token_id, pred_ids)

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer}

## Step 3. Fine-tune

Reminder to update the model output directory 'OUTPUT_DIR' before fine-tuning. These configs are the same as the ones we used to fine-tune each model.

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Fine-tuning configuration.
# - Evaluation and checkpointing both happen every 2000 steps, so every
#   evaluation has a matching checkpoint (presumably needed for
#   load_best_model_at_end - confirm against the transformers docs).
# - The best checkpoint is chosen by the "cer" value returned by
#   compute_metrics; greater_is_better=False because a lower error rate is
#   better.
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    eval_strategy="steps",
    num_train_epochs=1.87,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=2,
    warmup_steps=500,    # Essential for stability
    lr_scheduler_type="cosine",  # Smooth LR decay
    fp16=True,
    output_dir=OUTPUT_DIR,
    logging_steps=500,
    save_steps=2000,
    eval_steps=2000,
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="cer",       
    greater_is_better=False,                
    generation_config=generation_config,
)

In [None]:
from transformers import default_data_collator
# instantiate trainer: it ties together the model, the two datasets, the CER
# metric function and the training arguments defined in the earlier cells
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
)
# start fine-tuning
trainer.train()


## Step 4. Evaluate

### 4.1 Load the Model

In [None]:
import torch
from transformers import VisionEncoderDecoderModel
from transformers import TrOCRProcessor

# Load the processor here as well so evaluation also works in a fresh session
# in which the training cells were not run.
processor = TrOCRProcessor.from_pretrained(MODEL)
# Use the GPU when one is available and fall back to the CPU otherwise,
# instead of failing on a hard-coded "cuda".
model = VisionEncoderDecoderModel.from_pretrained(CHECKPOINT_DIR).to(
    "cuda" if torch.cuda.is_available() else "cpu"
)

### 4.2 Load Testing Dataset

In [None]:
# test_df should already exist from step 1.3; print a summary of it
test_df.info()

### 4.3 Model Inference

There is separate code for running inference on Imgur data and IAM data, since the image id and the path to the image file are different. Make sure to run the code corresponding to the testing dataset used.

In [None]:
import torch

# Select the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Model Inference on Imgur Data:

In [None]:
from tqdm import tqdm
from PIL import Image
def readText_batch(df, indices, model, processor):
    """Run text recognition on several rows of ``df`` in one batch.

    Args:
        df: DataFrame with an ``image`` column holding the images.
        indices: Positional (0-based) row numbers of the rows to read.
        model: Model exposing ``generate``. When it has a ``device``
            attribute, the inputs are moved to that device first.
        processor: Processor that converts images to ``pixel_values`` and
            decodes generated ids.

    Returns:
        List of decoded strings, one per entry of ``indices``, in order.
    """
    # Positional lookup: callers pass range() offsets, which match the index
    # labels only when the frame has a fresh 0..n-1 index.
    image_column = df['image']
    images = [image_column.iloc[idx] for idx in indices]
    pixel_values = processor(images=images, return_tensors="pt").pixel_values
    # The inputs must live on the same device as the model weights.
    target_device = getattr(model, "device", None)
    if target_device is not None:
        pixel_values = pixel_values.to(target_device)
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)

def process_all_rows_batched(df, model, processor, batch_size=8):
    """Run batched inference over every row of ``df``.

    Args:
        df: DataFrame with ``id`` and ``text`` columns (plus whatever
            ``readText_batch`` needs).
        model: Model passed through to ``readText_batch``.
        processor: Processor passed through to ``readText_batch``.
        batch_size: Number of rows handled per call to ``readText_batch``.

    Returns:
        DataFrame with one row per input row and the columns ``id``,
        ``true_text`` and ``predicted_text``. Rows of a failed batch have
        ``predicted_text`` set to None and carry the message in ``error``.
    """
    results = []
    # Positional access (.iloc) keeps the row bookkeeping correct even when
    # the frame's index labels are not 0..n-1.
    ids = df['id']
    texts = df['text']
    for i in tqdm(range(0, len(df), batch_size), desc="Processing batches"):
        batch_indices = range(i, min(i + batch_size, len(df)))
        error = None
        # Only the inference call is guarded, so a failure can never leave a
        # batch recorded twice.
        try:
            batch_texts = readText_batch(df, batch_indices, model, processor)
        except Exception as e:
            print(f"Error in batch {i//batch_size}: {str(e)}")
            error = str(e)
            batch_texts = [None] * len(batch_indices)
        for idx, text in zip(batch_indices, batch_texts):
            record = {
                'id': ids.iloc[idx],
                'true_text': texts.iloc[idx],
                'predicted_text': text
            }
            if error is not None:
                record['error'] = error
            results.append(record)
    return pd.DataFrame(results)

Model Inference on IAM data:

In [None]:
# read image
def get_image(df, image_id):
    """Load the IAM word image for ``image_id`` as an RGB image.

    The file is looked up at
    ``<IAM_DATA_DIR>/words/<a>/<a>-<b>/<image_id>.png``, where ``<a>`` and
    ``<b>`` are the first two dash-separated parts of ``image_id``.

    Args:
        df: Unused; kept so existing callers keep working.
        image_id: Word id whose first two dash-separated parts name the
            subfolders.

    Returns:
        An RGB copy of the image, or None when the file cannot be opened or
        the image is smaller than 10 pixels in width or height.
    """
    image_file_path = os.path.join(IAM_DATA_DIR, 'words')
    id_parts = image_id.split('-')
    subfolder = id_parts[0]
    subfolder2 = subfolder + "-" + id_parts[1]
    image_file_name = image_id + ".png"
    image_path = os.path.join(image_file_path, subfolder, subfolder2, image_file_name)

    try:
        with Image.open(image_path) as img:
            if img.size[0] < 10 or img.size[1] < 10:
                # Too small to be usable; report it the same way as a file
                # that could not be read.
                return None
            # Convert to RGB and copy so the result no longer depends on the
            # file handle that is closed when the with-block ends.
            return img.convert("RGB").copy()
    except Exception as e:
        print(f"Error opening image file {image_path}: {e}")
        return None

    # NOTE: statements that used to follow here re-loaded the processor and
    # model from not-yet-assigned local names. They were reached only for
    # undersized images, where they raised UnboundLocalError, and have been
    # removed. The processor and model are loaded in steps 2.1 and 4.1.
    
# 2.2 Do Inference- OCRing
from tqdm import tqdm
from PIL import Image
def readText_batch(df, indices, model, processor):
    """Run text recognition on several IAM rows of ``df`` in one batch.

    Images are loaded from disk through ``get_image`` using each row's id.

    Args:
        df: DataFrame with the word ids in an ``id`` column (an ``image_id``
            column is used instead when present).
        indices: Positional (0-based) row numbers of the rows to read.
        model: Model exposing ``generate``. When it has a ``device``
            attribute, the inputs are moved to that device first.
        processor: Processor that converts images to ``pixel_values`` and
            decodes generated ids.

    Returns:
        List of decoded strings, one per entry of ``indices``, in order.

    Raises:
        ValueError: If any image of the batch could not be loaded.
    """
    # The rest of this notebook stores the word id in the 'id' column;
    # 'image_id' is still honoured for frames that carry it.
    id_column = df['image_id'] if 'image_id' in df.columns else df['id']
    # Positional lookup: callers pass range() offsets, not index labels.
    image_ids = [id_column.iloc[index] for index in indices]
    images = [get_image(df, image_id) for image_id in image_ids]
    missing = [image_id for image_id, image in zip(image_ids, images) if image is None]
    if missing:
        # Fail with a clear message instead of passing None to the processor.
        raise ValueError(f"Could not load image(s) for id(s): {missing}")
    pixel_values = processor(images=images, return_tensors="pt").pixel_values
    # The inputs must live on the same device as the model weights.
    target_device = getattr(model, "device", None)
    if target_device is not None:
        pixel_values = pixel_values.to(target_device)
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)

def process_all_rows_batched(df, model, processor, batch_size=8):
    """Run batched inference over every row of ``df``.

    Args:
        df: DataFrame with ``id`` and ``text`` columns (plus whatever
            ``readText_batch`` needs).
        model: Model passed through to ``readText_batch``.
        processor: Processor passed through to ``readText_batch``.
        batch_size: Number of rows handled per call to ``readText_batch``.

    Returns:
        DataFrame with one row per input row and the columns ``id``,
        ``true_text`` and ``predicted_text``. Rows of a failed batch have
        ``predicted_text`` set to None and carry the message in ``error``.
    """
    results = []
    # Positional access (.iloc) keeps the row bookkeeping correct even when
    # the frame's index labels are not 0..n-1.
    ids = df['id']
    texts = df['text']
    for i in tqdm(range(0, len(df), batch_size), desc="Processing batches"):
        batch_indices = range(i, min(i + batch_size, len(df)))
        error = None
        # Only the inference call is guarded, so a failure can never leave a
        # batch recorded twice.
        try:
            batch_texts = readText_batch(df, batch_indices, model, processor)
        except Exception as e:
            print(f"Error in batch {i//batch_size}: {str(e)}")
            error = str(e)
            batch_texts = [None] * len(batch_indices)
        for idx, text in zip(batch_indices, batch_texts):
            record = {
                'id': ids.iloc[idx],
                'true_text': texts.iloc[idx],
                'predicted_text': text
            }
            if error is not None:
                record['error'] = error
            results.append(record)
    return pd.DataFrame(results)

In [None]:
# process_all_rows_batched requires the model and processor as its second and
# third arguments; they are the objects loaded in the earlier cells.
results_df = process_all_rows_batched(test_df, model, processor, batch_size=8)

### 4.4 Evaluate

In [None]:
from evaluate import load
cer = load("cer")

def compute_metrics(pred_str, label_str):
    """Return the character error rate of one prediction against its label.

    Args:
        pred_str: Predicted text. Rows from a failed inference batch carry
            None (or NaN) here.
        label_str: Reference text.

    Returns:
        The CER score, or None when either input is not a string or the
        metric cannot be computed.
    """
    # A missing prediction cannot be scored; report "no score" instead of
    # failing on .strip().
    if not isinstance(pred_str, str) or not isinstance(label_str, str):
        return None
    pred_str = pred_str.strip()
    label_str = label_str.strip()
    try:
        return cer.compute(predictions=[pred_str], references=[label_str])
    except Exception as e:
        print("error", e)
        print(type(pred_str), len(pred_str), pred_str)
        print(type(label_str), len(label_str), label_str)
        return None

In [None]:
from tqdm import tqdm
tqdm.pandas()  # enables DataFrame.progress_apply
# Run evaluation: score every row's prediction against its true text
results_df["metrics"] = results_df.progress_apply(
    lambda row: compute_metrics(row["predicted_text"], row["true_text"]),
    axis=1
)

### 4.5 Analyze Performance

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import math

def plot_eval(values):
    """Show a KDE plot followed by a horizontal boxplot of ``values``.

    Args:
        values: Numeric scores. Entries without a score (None/NaN) are
            ignored.
    """
    # Entries without a score cannot be plotted; drop them up front.
    values = pd.to_numeric(pd.Series(values), errors="coerce").dropna()
    if values.empty:
        print("No valid values to plot.")
        return

    # Plotting
    plt.figure(figsize=(8, 5))
    # `fill` replaces the deprecated `shade` argument of kdeplot.
    sns.kdeplot(values, fill=True)
    plt.xlabel("Edit Distance")
    plt.title("KDE of Edit Distances")
    plt.show()

    # Boxplot
    plt.boxplot(values, vert=False, patch_artist=True)
    plt.xlabel("Edit Distance")
    plt.title("Boxplot of Edit Distances")

    plt.tight_layout()
    plt.show()

In [None]:
# Plot the distribution of the per-row scores.
plot_eval(results_df["metrics"])

In [None]:
# remove outliers: keep only the rows whose score is below 5, then plot again
results_normal= results_df[results_df["metrics"]<5]
plot_eval(results_normal["metrics"])

In [None]:
import numpy as np
def show_state(values):
    """Print summary statistics of ``values``.

    Entries without a score (None/NaN) are left out of every statistic and
    reported separately, so one failed row does not turn the whole summary
    into NaN.

    Args:
        values: Sequence of numeric scores.
    """
    scores = np.asarray(values, dtype=float)
    valid = scores[~np.isnan(scores)]
    skipped = scores.size - valid.size
    if valid.size == 0:
        # Nothing to summarize; also avoids dividing by zero below.
        print("Summary Statistics: no valid values to summarize.")
        return

    stats = {
        "mean": np.mean(valid),
        "median": np.median(valid),
        "std": np.std(valid),
        "min": np.min(valid),
        "max": np.max(valid),
        "quantiles": np.quantile(valid, [0.25, 0.5, 0.75]),
        "perfect": np.sum(valid == 0)
    }

    print("Summary Statistics:")
    print(f"- Mean ± Std: {stats['mean']:.2f} ± {stats['std']:.2f}")
    print(f"- Median (IQR): {stats['median']:.2f} ({stats['quantiles'][0]:.2f}–{stats['quantiles'][2]:.2f})")
    print(f"- Range: [{stats['min']}, {stats['max']}]")
    print(f"- Quantiles (25th, 50th, 75th): {stats['quantiles'].round(2)}")
    print(f"- Perfect Predictions: {stats['perfect']} ({stats['perfect']/valid.size*100:.2f}%)")
    if skipped:
        print(f"- Skipped (no score): {skipped}")

In [None]:
# Print the summary statistics of the per-row scores.
show_state(results_df["metrics"])

In [None]:
# Show the full results table (ids, true text, predictions and scores).
print(results_df)