## Step 1. Data Processing

### 1.1 Get Dataset
1. Follow the instruction from [IMGUR5K Handwriting Dataset Repository](https://github.com/facebookresearch/IMGUR5K-Handwriting-Dataset) to collect the dataset.

2. Uploaded the dataset to your Google Drive

### 1.2 Import Dataset

In [1]:
# your path to image dataset
image_dataset_path = './dataset'

In [2]:
'''
return a DataFrame of image Dataset
with two column:
id (as labeled in the dataset): each image has a unique id
url: url of the source of the image
'''
import os
import pandas as pd

def create_image_df():
    assert image_dataset_path is not None, "Set image dataset to the folder that include IMGUR5K Handwriting Dataset "
    # List to hold file names (without extensions)
    ids = []

    # Iterate over all files in the directory
    for filename in os.listdir(image_dataset_path):
        # Get the full path of the file
        full_path = os.path.join(image_dataset_path, filename)

        # Skip directories, only process files
        if os.path.isfile(full_path):
            # Split the filename into name and extension
            name, _ = os.path.splitext(filename)

            # Add the file name (without extension) to the list
            ids.append(name)

    # Create a DataFrame with two columns: ids and url
    df = pd.DataFrame({
        "id": ids
    })

    return df


# call function
df = create_image_df()

In [3]:
# overview of df
df.head()

Unnamed: 0,id
0,A3b6IRB
1,PgtVlbz
2,5xiSlCK
3,ZE1HUdf
4,9VHX0vQ


In [4]:
from PIL import Image

def read_image(path):
  image = Image.open(path).convert("RGB")
  return image

In [5]:

# quick access to image in the dataset
def read_my_image(index):
  assert image_dataset_path is not None, "Set image dataset to the folder that include IMGUR5K Handwriting Dataset "
  path = image_dataset_path+"/"+ str(index) + '.jpg'
  return read_image(path)

In [6]:
import matplotlib.pyplot as plt

def show_image(index):
  # Read the image
  img = read_my_image(index)  # Replace 'index' with your actual index

  # Display the image in a smaller size
  plt.figure(figsize=(10, 6))  # Adjust width and height in inches
  plt.imshow(img)
  plt.axis('off')  # Turn off axis numbers and ticks
  plt.show()

In [None]:
# check image in df
show_image("qjXbC0Y")

### 1.3 Label Ground Truth Text

In [8]:
import csv
# Get data infomation from the lst
data_info = './dataset_info/imgur5k_data.lst'

# Read url: actual url of the image
# Location: the location of the text based on the image
# text: Ground Truth Text
df_info = pd.read_csv(data_info, sep='\t', header=None, names=['url', 'location', 'text'], quoting=csv.QUOTE_NONE, )

# Extract the 'id' from the URL
df_info['id'] = df_info['url'].apply(lambda x: x.split('/')[-1].split('.')[0])

# Reorder columns to match the desired output
df_info = df_info[['id', 'url', 'location', 'text']]

In [9]:
df = pd.merge(df, df_info, on='id', how='left')

In [10]:
df.head()

Unnamed: 0,id,url,location,text
0,A3b6IRB,https://i.imgur.com/A3b6IRB.jpg,"[808.0, 179.33, 239.0, 47.0, -21.0]",SURGERY
1,A3b6IRB,https://i.imgur.com/A3b6IRB.jpg,"[775.67, 202.67, 90.33, 34.0, -18.67]",#2
2,A3b6IRB,https://i.imgur.com/A3b6IRB.jpg,"[932.67, 371.33, 139.67, 55.67, -18.0]",N2O2
3,A3b6IRB,https://i.imgur.com/A3b6IRB.jpg,"[661.67, 374.0, 152.0, 40.33, -25.0]",CRASH
4,A3b6IRB,https://i.imgur.com/A3b6IRB.jpg,"[803.33, 431.33, 142.33, 40.0, -22.0]",CART


In [11]:
df = df.dropna()

### 1.4 Data Cleaning

We need to check if all the ground-truth text matches the actual content.

##### Speical characters
Check if the special characters are being processed corrected

In [12]:
import pandas as pd

# Patter for all Special characters
special_char_pattern = r'[^a-zA-Z0-9\s]'  # Matches anything not alphanumeric or whitespace

# Select rows with special characters
special_char_rows = df[df['text'].str.contains(special_char_pattern, regex=True, na=False)]

In [13]:
special_char_rows.head()

Unnamed: 0,id,url,location,text
1,A3b6IRB,https://i.imgur.com/A3b6IRB.jpg,"[775.67, 202.67, 90.33, 34.0, -18.67]",#2
5,A3b6IRB,https://i.imgur.com/A3b6IRB.jpg,"[611.0, 393.0, 91.5, 47.5, -27.0]",W/
24,PgtVlbz,https://i.imgur.com/PgtVlbz.jpg,"[1110.5, 650.5, 201.0, 66.5, 0.0]",NCM's
61,PgtVlbz,https://i.imgur.com/PgtVlbz.jpg,"[1604.5, 1972.5, 177.5, 53.5, -1.0]","come,"
74,PgtVlbz,https://i.imgur.com/PgtVlbz.jpg,"[2150.5, 1133.5, 249.0, 73.0, -0.5]",2:30pm


In [None]:
show_image("A3b6IRB")

The special charcters are being processed correctly. Confirm more cases.

In [15]:
# Math Symbols

In [16]:
# Define the pattern of allowed characters:
# - \w for word characters (letters, numbers, underscore)
# - \s for whitespace
# - Specific punctuation, math symbols, and escape characters
# - Double backslash (\\) represents a single backslash
allowed_pattern = r'^[\w\s\.,!?;:\-+*/=()\[\]{}<>@#\$%^&_\'"\t\n]+$'
mask = ~df['text'].str.contains(allowed_pattern, regex=True)
non_standard_rows = df[mask]

In [17]:
non_standard_rows.head()

Unnamed: 0,id,url,location,text
335,sNpIWnz,https://i.imgur.com/sNpIWnz.jpg,"[153.0, 851.0, 23.0, 13.0, 162.0]",·USA
341,sNpIWnz,https://i.imgur.com/sNpIWnz.jpg,"[130.0, 785.0, 22.0, 11.0, -21.0]",Japan|
342,sNpIWnz,https://i.imgur.com/sNpIWnz.jpg,"[150.0, 791.0, 23.0, 10.0, -16.0]",Japan|
407,hhk8nvy,https://i.imgur.com/hhk8nvy.jpg,"[611.0, 317.5, 145.0, 72.5, 2.5]",4\/22\/2016.
645,J8E8rsy,https://i.imgur.com/J8E8rsy.jpg,"[756.0, 646.0, 506.0, 162.0, -2.0]",\/r\/handwriting!


In [None]:
show_image("sNpIWnz")

In [None]:
show_image("hhk8nvy")

The text that contain "\\/" seems to be just "/", check more to confirm. 

In [20]:
mask = df['text'].str.contains(r'\\', regex=True)
check_rows= df[mask]

In [21]:
check_rows.head()

Unnamed: 0,id,url,location,text
407,hhk8nvy,https://i.imgur.com/hhk8nvy.jpg,"[611.0, 317.5, 145.0, 72.5, 2.5]",4\/22\/2016.
645,J8E8rsy,https://i.imgur.com/J8E8rsy.jpg,"[756.0, 646.0, 506.0, 162.0, -2.0]",\/r\/handwriting!
5632,QikXzEi,https://i.imgur.com/QikXzEi.jpg,"[657.0, 181.0, 71.0, 44.0, 1.0]",\/L=
5662,QikXzEi,https://i.imgur.com/QikXzEi.jpg,"[357.0, 778.0, 121.0, 56.0, 0.0]",K+\/H+
5680,QikXzEi,https://i.imgur.com/QikXzEi.jpg,"[382.0, 998.0, 158.0, 67.0, 0.0]",mEq\/L


In [None]:
show_image("QikXzEi")

It appears that the text that contain "\\\/" seems to be just "/", replace them

In [23]:
df['text'] = df['text'].str.replace('\\/', '/', regex=False)

In [24]:
df[df['text'].str.contains(r'/', regex=True)].head()

Unnamed: 0,id,url,location,text
5,A3b6IRB,https://i.imgur.com/A3b6IRB.jpg,"[611.0, 393.0, 91.5, 47.5, -27.0]",W/
320,IvfnYNp,https://i.imgur.com/IvfnYNp.jpg,"[735.0, 877.33, 161.33, 106.0, -12.0]",41/365
407,hhk8nvy,https://i.imgur.com/hhk8nvy.jpg,"[611.0, 317.5, 145.0, 72.5, 2.5]",4/22/2016.
450,85jTx48,https://i.imgur.com/85jTx48.jpg,"[1328.5, 1837.0, 176.0, 89.5, -0.5]",320/365
645,J8E8rsy,https://i.imgur.com/J8E8rsy.jpg,"[756.0, 646.0, 506.0, 162.0, -2.0]",/r/handwriting!


In [25]:
mask = ~df['text'].str.contains(allowed_pattern, regex=True)
non_standard_rows2 = df[mask]

In [26]:
print("Words with special character:", len(non_standard_rows2), ", Percentage: ", len(non_standard_rows2)/len(df))
print("Images with special charatcer:", len(non_standard_rows2['id'].unique()), ", Percentage: ", len(non_standard_rows2['id'].unique())/len(df['id'].unique()))

Words with special character: 2004 , Percentage:  0.008826055361035872
Images with special charatcer: 565 , Percentage:  0.07076653306613226


Since there aren't too many images that contain uncommon special characters, we would remove these images that contains uncommon special characters to ensure the ground-truth text are correct and limited vocabulary.

In [27]:
df=df[~mask]

In [28]:
print("total number of words", len(df))
print("total number of images", len(df['id'].unique()))

total number of words 225051
total number of images 7977


In [29]:
# confirm there is no special characters
count_matching = df['text'].str.contains(allowed_pattern, regex=True, na=False).sum()
print(f"Number of rows with allowed characters: {count_matching}")

Number of rows with allowed characters: 225051


#### All-symbols text
There might be some error if a word only contains symbols

In [30]:
pattern = r'^[^a-zA-Z0-9]+$'  # Matches strings with no alphanumeric chars at all
non_alnum_rows = df[df['text'].str.contains(pattern, regex=True, na=False)]

In [31]:
non_alnum_rows.head(20)

Unnamed: 0,id,url,location,text
105,5xiSlCK,https://i.imgur.com/5xiSlCK.jpg,"[151.0, 729.0, 82.5, 58.5, -0.5]",.
107,5xiSlCK,https://i.imgur.com/5xiSlCK.jpg,"[349.5, 738.0, 27.5, 53.0, -1.5]",.
213,5xiSlCK,https://i.imgur.com/5xiSlCK.jpg,"[750.5, 2156.5, 80.0, 51.0, 2.0]",.
226,5xiSlCK,https://i.imgur.com/5xiSlCK.jpg,"[523.0, 821.0, 152.0, 65.0, 0.0]",.
234,tMQDhwl,https://i.imgur.com/tMQDhwl.jpg,"[633.0, 508.0, 784.0, 123.0, -1.0]",.
236,tMQDhwl,https://i.imgur.com/tMQDhwl.jpg,"[1049.0, 306.0, 1041.0, 158.0, 2.0]",.
237,tMQDhwl,https://i.imgur.com/tMQDhwl.jpg,"[783.0, 651.0, 624.0, 160.0, -3.0]",.
241,tMQDhwl,https://i.imgur.com/tMQDhwl.jpg,"[979.0, 843.0, 532.0, 94.0, -4.0]",.
242,tMQDhwl,https://i.imgur.com/tMQDhwl.jpg,"[715.0, 1006.0, 673.0, 177.0, -2.0]",.
287,FEEMbfi,https://i.imgur.com/FEEMbfi.jpg,"[1755.5, 2735.5, 90.0, 46.0, 90.0]",-


In [None]:
show_image("5xiSlCK")

In [None]:
show_image("tMQDhwl")

In [None]:
show_image("IvfnYNp")

There appear to be some inital fault in labels, all the ones that contains only one period seem to be incorrect. Remove them all

In [35]:
only_period_rows= df[df['text'] == '.']

In [36]:
print("Words with only one period:", len(only_period_rows), ", Percentage: ", len(only_period_rows)/len(df))
print("Images with special charatcer:", len(only_period_rows['id'].unique()), ", Percentage: ", len(only_period_rows['id'].unique())/len(df['id'].unique()))

Words with only one period: 19567 , Percentage:  0.08694473697073107
Images with special charatcer: 3762 , Percentage:  0.4716058668672433


In [37]:
df = df[df['text'] != '.']

In [38]:
# check other rows that have ony characters
pattern = r'^[^a-zA-Z0-9]+$'  # Matches strings with no alphanumeric chars at all
non_alnum_rows2 = df[df['text'].str.contains(pattern, regex=True, na=False)]

In [39]:
non_alnum_rows2.head()

Unnamed: 0,id,url,location,text
287,FEEMbfi,https://i.imgur.com/FEEMbfi.jpg,"[1755.5, 2735.5, 90.0, 46.0, 90.0]",-
363,hhk8nvy,https://i.imgur.com/hhk8nvy.jpg,"[20.33, 544.33, 17.67, 25.67, 0.0]",*
402,hhk8nvy,https://i.imgur.com/hhk8nvy.jpg,"[630.0, 646.5, 34.0, 51.0, 1.0]",*
890,Q5aAt8W,https://i.imgur.com/Q5aAt8W.jpg,"[265.0, 1172.0, 23.0, 27.0, 10.0]",Σ
895,Q5aAt8W,https://i.imgur.com/Q5aAt8W.jpg,"[359.0, 1159.0, 8.0, 10.0, 0.0]",=


In [None]:
show_image("hhk8nvy")

In [53]:
df.to_csv('df_cleaned_info.csv')

### 1.4 Final Test Set
Create a final test set to evalute the final model and comapre between different model

In [41]:
import numpy as np

In [45]:
# Get unique groups
unique_images = df['id'].unique()

# Randomly select 10% for test 
np.random.seed(42)
test_images = np.random.choice(unique_images, 
                              size=int(len(unique_images)*0.2), 
                              replace=False)

In [49]:
test_df = df[df['id'].isin(test_images)]
training_df= df[~df['id'].isin(test_images)]

In [50]:
print("Words in Test Dataset:", len(test_df), ", Percentage: ", len(test_df)/len(df))
print("Images in Test Dataset:", len(test_df['id'].unique()), ", Percentage: ", len(test_df['id'].unique())/len(df['id'].unique()))

Words in Test Dataset: 41885 , Percentage:  0.20383582176714488
Images in Test Dataset: 1565 , Percentage:  0.19992335206949413


In [51]:
test_df.to_csv('df_test_info.csv')

In [52]:
training_df.to_csv('df_train_info.csv')

## Step 2. Base Model Evaluation

### 2.1 Import Base Model
The base model is from [TrOCR](https://huggingface.co/microsoft/trocr-base-handwritten) (base-sized model, fine-tuned on IAM) model fine-tuned on the IAM dataset. We accessed it with Hugging Face

In [None]:
# import base model
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import requests

processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')

In [None]:
# input: image as RGB and model to process
# output: text
def image_to_text(model, image):
  pixel_values = processor(images=image, return_tensors="pt").pixel_values
  generated_ids = model.generate(pixel_values)
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  return generated_text

### 2.2 Test Pre-trained Model

We select several different complexity and format of images in our dataset to test the capacity of the pre-trained model

In [None]:
#  single word, no angle
print_image("a0zqSCF")

In [None]:
# test base model on single word image
image = read_my_image("a0zqSCF")
text = image_to_text(model, image)
print(text)

In [None]:
# multi-line, single word no angle
print_image("ZIsH9sg")

In [None]:
image = read_my_image("ZIsH9sg")
text = image_to_text(model, image)
print(text)

In [None]:
# single line, multi-word, slightly angled
print_image("9W9vgCw")

In [None]:
image = read_my_image("9W9vgCw")
text = image_to_text(model, image)
print(text)

In [None]:
df[df['id']=='9W9vgCw']

In [None]:
# single line, multi word, no angle
print_image("7buo10E")

In [None]:
image = read_my_image("7buo10E")
text = image_to_text(model, image)
print(text)

In [None]:
print_image("y7YkG7L")

In [None]:
image = read_my_image("y7YkG7L")
text = image_to_text(model, image)
print(text)

In [None]:
print_image("d3JaHtH")

In [None]:
image = read_my_image("d3JaHtH")
text = image_to_text(model, image)
print(text)

#### Conlusion
1. *Single-word, single-line, no angle*: Highest successes, depedning on the complexity of each style, but the model is able to process the characters and read the basics
2. *Multi-word, single-line, no angle*: The model has some more difficulties, a little less acurate than single word
3. *Single-word, single-line, no angle*:
4. *Single-word, single-line, no angle*:

## Step 3. Fine-tune On Whole-page Images
We would first fine-tune the model on original dataset, which are full-page images.  

### 3.1 Get Dataset based on full image

In [None]:
# remove unrelated columns for full images
df_full_images = df.drop(['url','location'], axis=1)

In [None]:
df_full_images= df.groupby('id')['text'].agg(lambda x: ' '.join(map(str, x))).reset_index()

In [None]:
df_full_images['text']=df_full_images['text'].astype(str)

In [None]:
df_full_images.head()

After trials, we realize that the model has a limit od 512 token size and the GPU doesn't support high token size. So we would take the subset with text that have less than 128 tokens.

In [None]:
subdf_full_images = df_full_images[df_full_images['text'].apply(lambda x: len(x) <= 128)]

In [None]:
len(df_full_images)

In [None]:
len(subdf_full_images)

### 3.2 Format Dataset for training

In [None]:
import torch
from torch.utils.data import Dataset
from PIL import Image

class StyleDataset(Dataset):
    def __init__(self, df, processor, max_target_length=128):
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
      try:
          text = self.df['text'][idx]
          if not isinstance(text, str) or not text.strip():
              raise ValueError(f"Invalid text at index {idx}: {repr(text)}")
          image_id = self.df['id'][idx]
          try:
              image = read_my_image(image_id)
          except Exception as e:
              raise ValueError(f"Failed to load image for ID {image_id} at index {idx}") from e
          try:
              pixel_values = self.processor(image, return_tensors="pt").pixel_values
          except Exception as e:
              raise ValueError(f"Image processing failed at index {idx}") from e

          if torch.isnan(pixel_values).any() or torch.isinf(pixel_values).any():
              raise ValueError(f"Invalid pixel values (NaN/inf) at index {idx}")
          try:
              labels = self.processor.tokenizer(
                  text,
                  padding="max_length",
                  max_length=self.max_target_length
              ).input_ids
          except Exception as e:
              raise ValueError(f"Tokenization failed for text at index {idx}") from e

          # Replace pad_token_id with -100 for loss masking
          labels = [
              label if label != self.processor.tokenizer.pad_token_id else -100
              for label in labels
          ]
          encoding = {
              "pixel_values": pixel_values.squeeze(),
              "labels": torch.tensor(labels)
          }

          if encoding["pixel_values"].dim() != 3:
              raise ValueError(f"Invalid pixel_values shape at index {idx}")

          if encoding["labels"].numel() != self.max_target_length:
              raise ValueError(f"Labels length mismatch at index {idx}")

          return encoding

      except Exception as e:
          print(f"\nError in sample {idx}:")
          print(f"   Error type: {type(e).__name__}")
          print(f"   Details: {str(e)}")
          if hasattr(e, '__cause__') and e.__cause__:
              print(f"   Underlying error: {type(e.__cause__).__name__}: {str(e.__cause__)}")
          print(f"   DataFrame row:\n{self.df.iloc[idx]}")
          return None

### 3.2 Divide Into Validation and Training sets

In [None]:
from sklearn.model_selection import train_test_split
import pandas as pd

train_df, eval_df = train_test_split(subdf_full_images, test_size=0.2, random_state=42)

train_df = train_df.reset_index(drop=True)
eval_df = eval_df.reset_index(drop=True)

# Split into validation and training sets
train_dataset = StyleDataset(df=train_df,processor=processor)
eval_dataset= StyleDataset(df=eval_df,processor=processor)

In [None]:
print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))

In [None]:
def get_label_str(encoding):
  labels = encoding['labels']
  labels[labels == -100] = processor.tokenizer.pad_token_id
  label_str = processor.decode(labels, skip_special_tokens=True)
  return label_str


In [None]:
get_label_str(train_dataset[42])

In [None]:
get_label_str(eval_dataset[42])

### 3.3 Full-page Image


In [None]:
# Analyze your dataset first
avg_target_len = df['labels'].apply(len).mean()
max_target_len = int(df['labels'].apply(len).quantile(0.95))

# ======================
# Model Configuration
# Aligned with Dataset
# ======================

# Token Alignment
model.config.update({
    # Must match tokenizer's vocabulary
    'vocab_size': len(processor.tokenizer),
    
    # Special tokens (critical for sequence processing)
    'pad_token_id': processor.tokenizer.pad_token_id,       # Padding token
    'bos_token_id': processor.tokenizer.cls_token_id,       # Beginning of sequence
    'eos_token_id': processor.tokenizer.sep_token_id,       # End of sequence
    'decoder_start_token_id': processor.tokenizer.cls_token_id,  # First decoder token
    
    # Sequence Length (adjust based on dataset statistics)
    'max_length': 64,               # Should cover 95% of target sequences
    'min_length': 1,                # Minimum generation length
    
    # Beam Search Parameters
    'num_beams': 4,                 # Balance quality (4) vs speed (1)
    'early_stopping': True,         # Stop when all beams finish
    'length_penalty': 2.0,          # >1 = longer outputs, <1 = shorter
    
    # Repetition Control
    'no_repeat_ngram_size': 3,      # Block repeating 3-grams
    'repetition_penalty': 1.2,      # 1.0=no penalty, >1 reduces repeats
    
    # Sampling Parameters
    'temperature': 0.7,             # 0.5-1.0 (lower = more deterministic)
    'top_k': 50,                    # Top 50 tokens considered
    'top_p': 0.9,                   # Nucleus sampling threshold
    'do_sample': False,             # Set True for creative tasks
    
    # Architecture (typically keep pretrained defaults)
    'encoder_layers': model.config.encoder_layers,
    'decoder_layers': model.config.decoder_layers,
    'd_model': model.config.d_model,
    'attention_heads': model.config.attention_heads
})

# ======================
# Verification Checks
# ======================
assert model.config.vocab_size == len(processor.tokenizer), "Vocabulary mismatch!"
assert model.config.pad_token_id == processor.tokenizer.pad_token_id, "Pad token mismatch!"

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=12,  # Increased from 8 (if GPU memory allows)
    per_device_eval_batch_size=12,
    fp16=True,  # Keep mixed precision
    output_dir="./output/models/",
    logging_steps=100,  # Reduced logging frequency
    save_steps=500,  # Less frequent checkpointing
    eval_steps=250,  # Slightly less frequent evaluation
    gradient_accumulation_steps=1,  # Remove if previously set higher
    warmup_steps=100,  # Reduced warmup
    learning_rate=3e-5,  # Slightly higher LR for faster convergence
    num_train_epochs=3,  # Reduced if your data allows
    optim="adamw_torch",  # More efficient optimizer
    report_to="none",  # Disable expensive logging integrations
)

In [None]:
from evaluate import load

cer_metric = load("cer")

In [None]:
from evaluate import load
import numpy as np
from collections import defaultdict

# Initialize metrics
cer_metric = load("cer")
wer_metric = load("wer")

def compute_word_level_metrics(pred_str, label_str):
    """Analyze errors at word level while preserving context"""
    word_cer = []
    error_types = {
        'missing_words': 0,
        'correct_words': 0,
        'partial_words': 0,
        'extra_words': 0
    }
    
    for pred, label in zip(pred_str.split(), label_str.split()):
        word_cer.append(cer_metric.compute(predictions=[pred], references=[label]))
        
        if pred == label:
            error_types['correct_words'] += 1
        elif pred in label_str:
            error_types['partial_words'] += 1
        elif pred not in label_str and len(pred) > 0:
            error_types['extra_words'] += 1
    
    # Count missing words (present in label but not prediction)
    error_types['missing_words'] = len(set(label_str.split()) - set(pred_str.split()))
    
    return {
        'word_cer_avg': np.mean(word_cer),
        'word_cer_distribution': np.percentile(word_cer, [25, 50, 75]),
        **error_types
    }

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # Decode with punctuation preservation
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    # Standard metrics
    metrics = {
        'cer': cer_metric.compute(predictions=pred_str, references=label_str),
        'wer': wer_metric.compute(predictions=pred_str, references=label_str),
    }

    # Word-level analysis
    word_metrics = defaultdict(list)
    for p, l in zip(pred_str, label_str):
        word_metrics.update(compute_word_level_metrics(p, l))
    
    # Combine results
    return {
        **metrics,
        'word_level': {
            'avg_cer_per_word': np.mean(word_metrics['word_cer_avg']),
            'correct_words_ratio': np.mean([m['correct_words']/(m['correct_words']+m['partial_words']) for m in word_metrics]),
            'missing_words_ratio': np.mean([m['missing_words']/len(l.split()) for m, l in zip(word_metrics, label_str)]),
            'cer_distribution': {
                '25th': np.percentile(word_metrics['word_cer_avg'], 25),
                'median': np.percentile(word_metrics['word_cer_avg'], 50),
                '75th': np.percentile(word_metrics['word_cer_avg'], 75)
            }
        },
        # Normalized scores (0-100 scale)
        'scores': {
            'overall_quality': 100 * (1 - metrics['wer']),  # WER-based
            'word_accuracy': 100 * np.mean([m['correct_words']/len(l.split()) for m, l in zip(word_metrics, label_str)])
        }
    }

In [None]:
from transformers import default_data_collator

# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
)
trainer.train()