## Step 1. Data Processing

### 1.1 Loading the Data

In [None]:
import os

# Change to your desired directory
# os.chdir('/common/users/$USER/df_words') # change $USER to netid

# Confirm it's changed
# print("Current directory:", os.getcwd())

# Root folder of the IAM words dataset (Windows-style relative path).
# Raw string: '\i' is not a valid escape sequence, so the plain literal only
# worked by accident and triggers a SyntaxWarning on recent Python versions.
images_dir = r'.\iam_words'

In [None]:
import pandas as pd
from PIL import Image
# Raw strings keep the Windows separators literal ('\w' and '\i' are not
# valid escape sequences and would otherwise raise a SyntaxWarning).
label_file_path = images_dir + r'\words.txt'
image_file_path = images_dir + r'\iam_words\words'

data = []
with open(label_file_path, 'r') as f:
    lines = f.readlines()

# The first 18 lines are skipped — presumably the descriptive header of
# words.txt (TODO confirm against the dataset file).
for idx, line in enumerate(lines[18:]):
    if idx % 1000 == 0:
        print(f"Processing line {idx}")
    tokens = line.strip().split()
    # tokens[1], tokens[2] and tokens[-1] are read below, so at least three
    # fields are required; shorter lines are skipped instead of crashing.
    if len(tokens) < 3:
        continue

    # The word id "aaa-bbb-..." is stored as <root>\aaa\aaa-bbb\<word id>.png
    id_parts = tokens[0].split('-')
    if len(id_parts) < 2:
        continue
    subfolder = id_parts[0]
    subfolder2 = subfolder + "-" + id_parts[1]
    image_file_name = subfolder + "\\" + subfolder2 + "\\" + tokens[0] + ".png"
    image_path = os.path.join(image_file_path, image_file_name)
    try:
        with Image.open(image_path) as img:
            if img.size[0] < 10 or img.size[1] < 10:
                # Too small to be useful: skip the row. Previously the row was
                # still appended with the image left over from the previous
                # iteration (or a NameError on the very first row).
                continue
            # convert() returns a new in-memory image, so it stays usable
            # after the file is closed.
            img_copy = img.convert("RGB")
    except FileNotFoundError as e:
        print(f"Image file not found: {image_path}. Error: {e}")
        continue
    except Image.UnidentifiedImageError as e:
        print(f"Unidentified image error for file {image_path}: {e}")
        continue
    except Exception as e:
        print(f"Error opening image file {image_path}: {e}")
        continue
    # Field order must match the DataFrame columns declared below.
    data.append([image_path, tokens[1], tokens[2], tokens[-1], img_copy])

if not data:
    raise RuntimeError(f"No images could be loaded from {image_file_path}")

print(f"Length of a row in data: {len(data[0])}")  # 5, one entry per column below

print(data[0])
loaded = pd.DataFrame(data, columns=['image_id', 'segmentation_status', 'graylevel','text', 'image'])

In [None]:
loaded_dfwords = loaded.copy()
print(loaded_dfwords.info())
print(loaded_dfwords.head(30))

### 1.2 Cleaning the Data

In [None]:
import matplotlib.pyplot as plt

def show_image(df, row):
    """Display the image file referenced by one row of ``df``.

    Args:
        df: DataFrame with an ``image_id`` column holding image file paths.
        row: Positional (``iloc``) index of the row to display.
    """
    # The file is re-read from disk (not taken from the 'image' column) and is
    # closed again as soon as matplotlib has consumed the pixel data.
    with Image.open(df.iloc[row]['image_id']) as img:
        plt.imshow(img, cmap='gray')
    plt.show()

In [None]:
show_image(loaded_dfwords, 10)  # Show the image at positional row 10 (0-based)

In [None]:
import re

# Pattern for any special character: anything that is neither alphanumeric
# nor whitespace.
special_char_pattern = r'[^a-zA-Z0-9\s]'

# Rows whose text contains at least one special character (missing text
# counts as "no match").
special_char_rows = loaded_dfwords.loc[
    loaded_dfwords['text'].str.contains(
        special_char_pattern,
        regex=True,
        na=False,
    )
]

In [None]:
special_char_rows.head(10)

In [None]:
# Whitelist: the whole string must consist of word characters, whitespace and
# the listed punctuation / symbols.
allowed_pattern = r'^[\w\s\.,!?;:\-+*/=()\[\]{}<>@#\$%^&_\'"\t\n]+$'
# True for every row whose text contains at least one character outside the
# whitelist.
mask = ~(loaded_dfwords['text'].str.contains(allowed_pattern, regex=True))
non_standard_rows = loaded_dfwords.loc[mask]

In [None]:
non_standard_rows.head()

In [None]:
mask = loaded_dfwords['text'].str.contains(r'\\', regex=True)
check_rows= loaded_dfwords[mask]

In [None]:
loaded_dfwords['text'] = loaded_dfwords['text'].str.replace('\\/', '/', regex=False)

In [None]:
mask = ~loaded_dfwords['text'].str.contains(allowed_pattern, regex=True)
non_standard_rows2 = loaded_dfwords[mask]

In [None]:
print("Words with special character:", len(non_standard_rows2), ", Percentage: ", len(non_standard_rows2)/len(loaded_dfwords))
print("Images with special character:", len(non_standard_rows2['image_id'].unique()), ", Percentage: ", len(non_standard_rows2['image_id'].unique())/len(loaded_dfwords['image_id'].unique()))

In [None]:
loaded_dfwords=loaded_dfwords[~mask]

In [None]:
print("total number of words", len(loaded_dfwords))


In [None]:
# confirm there is no special characters
count_matching = loaded_dfwords['text'].str.contains(allowed_pattern, regex=True, na=False).sum()
print(f"Number of rows with allowed characters: {count_matching}")

In [None]:
pattern = r'^[^a-zA-Z0-9]+$'  # Matches strings with no alphanumeric chars at all
non_alnum_rows = loaded_dfwords[loaded_dfwords['text'].str.contains(pattern, regex=True, na=False)]

In [None]:
print("total number of words", len(loaded_dfwords))


In [None]:
non_alnum_rows.head(20)

In [None]:
# check other rows whose text contains only non-alphanumeric characters
pattern = r'^[^a-zA-Z0-9]+$'  # Matches strings with no alphanumeric chars at all
non_alnum_rows2 = loaded_dfwords[loaded_dfwords['text'].str.contains(pattern, regex=True, na=False)]

In [None]:
non_alnum_rows2.head()

In [None]:
non_alnum_rows2['text'].value_counts()

In [None]:
# remove these to match the other imgur dataset preprocessing (it also removes all instances of text as '.' due to incorrect labels)
only_period_rows= loaded_dfwords[loaded_dfwords['text'] == '.']
loaded_dfwords = loaded_dfwords[loaded_dfwords['text'] != '.']

In [None]:

# Inspect the single label that consists only of a long run of hyphens.
hyphen_row = loaded_dfwords[loaded_dfwords['text'] == '-----------------------------------------------------']
# Named hyphen_image_id (not `id`) so the builtin id() is not shadowed for the
# rest of the notebook.
hyphen_image_id = hyphen_row['image_id'].to_string()
print(hyphen_image_id)
print("image:", hyphen_image_id.split('\\')[-1])  # Print the image file name
# Raw string: the backslashes of the Windows path must stay literal.
r = loaded_dfwords[loaded_dfwords['image_id'] == r'.\iam_words\iam_words\words\p02\p02-109\p02-109-01-00.png']  # Get the row with the hyphen image
print(r)
show_image(r, 0)  # Show the hyphen image
# show_image(loaded_dfwords, hyphen_row.index[0] + 1)  # Show the hyphen image

In [None]:
loaded_dfwords = loaded_dfwords[loaded_dfwords['text'] != '-----------------------------------------------------']


In [None]:
loaded_dfwords = loaded_dfwords.reset_index(drop=True)

In [None]:
loaded_dfwords.info()

### 1.3 Splitting the Data into Training and Testing Subsets

In [None]:
import numpy as np

# Get unique groups
unique_images = loaded_dfwords['image_id'].unique()


# Randomly select 20% of the unique images for the test set
np.random.seed(42)
test_images = np.random.choice(unique_images, 
                              size=int(len(unique_images)*0.2), 
                              replace=False)

In [None]:
test_df = loaded_dfwords[loaded_dfwords['image_id'].isin(test_images)]
training_df = loaded_dfwords[~loaded_dfwords['image_id'].isin(test_images)]

In [None]:
print("Words in Train Dataset:", len(training_df), ", Percentage: ", len(training_df)/len(loaded_dfwords))

In [None]:
print("Words in Test Dataset:", len(test_df), ", Percentage: ", len(test_df)/len(loaded_dfwords))

In [None]:
print("total number of words", len(loaded_dfwords))

### 1.4 Splitting the Training Data into Training and Validation Subsets

In [None]:
from sklearn.model_selection import train_test_split
import pandas as pd

train_df, eval_df = train_test_split(training_df, test_size=0.2, random_state=42)

train_df = train_df.reset_index(drop=True)
eval_df = eval_df.reset_index(drop=True)

In [None]:
train_df.info()

In [None]:
eval_df.info()

In [None]:
import matplotlib.pyplot as plt

print(train_df.iloc[0])
show_image(train_df, 0)

print(eval_df.iloc[0])
show_image(eval_df, 0)

print(test_df.iloc[0])
show_image(test_df, 0)

### 1.5 Saving the Dataset to CSV

In [None]:
# test_df_copy = test_df
# train_df_copy = training_df

In [None]:
# test_df_copy = test_df_copy.drop('image', axis=1)

In [None]:
# train_df_copy = train_df_copy.drop('image', axis=1)

In [None]:
# test_df_copy.info()

In [None]:
# train_df_copy.info()

In [None]:
# test_df_copy = test_df_copy.reset_index()
# test_df_copy['word_id'] = test_df_copy.index
# test_df_copy = test_df_copy.drop('index', axis=1)
# test_df_copy = test_df_copy.drop('level_0', axis=1)

# print(test_df_copy)

In [None]:
# train_df_copy = train_df_copy.reset_index()
# train_df_copy['word_id'] = train_df_copy.index
# train_df_copy = train_df_copy.drop('index', axis=1)
# train_df_copy = train_df_copy.drop('level_0', axis=1)

# print(train_df_copy)

In [None]:
# train_df_copy = train_df_copy[[train_df_copy.columns[2]] + train_df_copy.columns[:2].tolist()]

In [None]:
# test_df_copy = test_df_copy[[test_df_copy.columns[2]] + test_df_copy.columns[:2].tolist()]

In [None]:
# train_df_copy.info()

In [None]:
# test_df_copy.info()

In [None]:
# train_df_copy.head(10)

In [None]:
# test_df_copy.head(10)

In [None]:
# test_df_copy.to_csv('df_test.csv', index=False)
# train_df_copy.to_csv('df_train.csv', index=False)

## Step 2. Running the Model

### 2.1 Loading the Model

In [None]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

In [None]:
# get base model
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-small-stage1')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-small-stage1')

### 2.2 Wrapping the Training and Validation Data in Datasets

In [None]:
import torch
from torch.utils.data import Dataset
from PIL import Image

class StyleDataset(Dataset):
    """Dataset turning DataFrame rows into image/label training samples.

    Each item is a dict with:
      * ``pixel_values`` - output of ``processor(image, return_tensors="pt")``
        with the batch dimension squeezed away;
      * ``labels`` - token ids of the row's text, padded to
        ``max_target_length``, with padding ids replaced by -100.

    Args:
        df: DataFrame with ``text``, ``image_id`` and ``image`` columns. Rows
            are addressed by position, so any index is accepted.
        processor: Object that is callable on an image and exposes a
            ``tokenizer`` attribute (callable, with ``pad_token_id``).
        max_target_length: Length every label sequence is padded/truncated to.
    """

    def __init__(self, df, processor, max_target_length=512):
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the encoded sample at position ``idx``.

        Returns:
            The encoding dict, or ``None`` when the sample could not be built
            (the problem is printed; callers must handle ``None``).

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``. This is what
                lets ``for sample in dataset`` terminate cleanly.
        """
        if not 0 <= idx < len(self.df):
            raise IndexError(f"Index {idx} out of range for dataset of size {len(self.df)}")

        try:
            return self._encode(idx)
        except Exception as e:
            # Best effort: report the broken sample instead of aborting.
            print(f"\nError in sample {idx}:")
            print(f"   Error type: {type(e).__name__}")
            print(f"   Details: {str(e)}")
            if e.__cause__ is not None:
                print(f"   Underlying error: {type(e.__cause__).__name__}: {str(e.__cause__)}")
            print(f"   DataFrame row:\n{self.df.iloc[idx]}")
            return None

    def _encode(self, idx):
        """Build the encoding for row ``idx``; raise ValueError on bad data."""
        # Positional access: works even when the DataFrame index is not 0..n-1
        # (e.g. a filtered frame that was never reset).
        text = self.df['text'].iloc[idx]
        if not isinstance(text, str) or not text.strip():
            raise ValueError(f"Invalid text at index {idx}: {repr(text)}")
        image_id = self.df['image_id'].iloc[idx]
        try:
            image = self.df['image'].iloc[idx]
        except Exception as e:
            raise ValueError(f"Failed to load image for ID {image_id} at index {idx}") from e
        try:
            pixel_values = self.processor(image, return_tensors="pt").pixel_values
        except Exception as e:
            raise ValueError(f"Image processing failed at index {idx}") from e

        if torch.isnan(pixel_values).any() or torch.isinf(pixel_values).any():
            raise ValueError(f"Invalid pixel values (NaN/inf) at index {idx}")

        tokenizer = self.processor.tokenizer
        try:
            # truncation=True keeps over-long texts at max_target_length so
            # they do not fail the length check below.
            labels = tokenizer(
                text,
                padding="max_length",
                max_length=self.max_target_length,
                truncation=True,
            ).input_ids
        except Exception as e:
            raise ValueError(f"Tokenization failed for text at index {idx}") from e

        # Replace pad_token_id with -100 for loss masking
        pad_id = tokenizer.pad_token_id
        labels = [label if label != pad_id else -100 for label in labels]
        encoding = {
            "pixel_values": pixel_values.squeeze(),
            "labels": torch.tensor(labels),
        }

        if encoding["pixel_values"].dim() != 3:
            raise ValueError(f"Invalid pixel_values shape at index {idx}")

        if encoding["labels"].numel() != self.max_target_length:
            raise ValueError(f"Labels length mismatch at index {idx}")

        return encoding

In [None]:
# Tokenized
train_dataset = StyleDataset(df=train_df,processor=processor)
eval_dataset= StyleDataset(df=eval_df,processor=processor)

In [None]:
print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))

In [None]:
# get the label string from encoding
def get_label_str(encoding):
    """Decode the ``labels`` tensor of one dataset item back into text.

    Args:
        encoding: Dict with a ``labels`` tensor in which padding is marked
            with -100.

    Returns:
        The decoded label string with special tokens removed.
    """
    # Work on a copy so the caller's tensor keeps its -100 loss masking.
    labels = encoding['labels'].clone()
    labels[labels == -100] = processor.tokenizer.pad_token_id
    return processor.decode(labels, skip_special_tokens=True)

In [None]:
print(1)
get_label_str(train_dataset[0])


In [None]:
print(train_df.iloc[0])
show_image(train_df, 0)

### 2.3 Model Configuration

In [None]:
# Analyze your dataset first
# Lengths are character counts of the label text (len() of each string), not
# token counts.
avg_target_len = training_df['text'].apply(len).mean()
print("average target length", avg_target_len)
# NOTE: despite its name this is the 95th percentile of the character lengths,
# not the true maximum.
max_target_len = int(training_df['text'].apply(len).quantile(0.95))
print("maximum target length", max_target_len)

In [None]:
# Token Alignment
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = len(processor.tokenizer)

In [None]:
from transformers import GenerationConfig

# Beam-search settings used whenever the trainer generates text for evaluation.
# The special-token ids are taken from the tokenizer so generation and label
# encoding agree. skip_special_tokens is deliberately not set here: it is a
# decoding option (passed to batch_decode in compute_metrics), not a
# generation parameter.
generation_config = GenerationConfig(
    max_length=64,
    early_stopping=True,
    num_beams=4,
    length_penalty=2.0,
    no_repeat_ngram_size=3,
    eos_token_id=processor.tokenizer.sep_token_id,
    decoder_start_token_id=processor.tokenizer.cls_token_id,
    pad_token_id=processor.tokenizer.pad_token_id,
)

### 2.4 Metrics

In [None]:
from evaluate import load
cer_metric = load("cer")

def compute_metrics(pred):
    """Compute the character error rate (CER) for a batch of predictions.

    Args:
        pred: Object with ``predictions`` (generated token ids) and
            ``label_ids`` (reference token ids, padding marked with -100).

    Returns:
        Dict ``{"cer": value}``.
    """
    import numpy as np

    pad_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 cannot be decoded. Replace it in copies (np.where) so the arrays
    # owned by the caller are left untouched. Predictions are sanitised too
    # because they may be padded with -100 when batches are gathered
    # (NOTE(review): depends on the transformers version — confirm).
    pred_ids = np.where(pred_ids == -100, pad_id, pred_ids)
    labels_ids = np.where(pred.label_ids == -100, pad_id, pred.label_ids)

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer}

## Step 3. Fine-tune

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Training configuration handed to the Seq2SeqTrainer below.
training_args = Seq2SeqTrainingArguments(
    # Evaluate with generated sequences so compute_metrics receives token ids.
    predict_with_generate=True,
    # Step-based evaluation; the interval is eval_steps below.
    eval_strategy="steps",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # NOTE(review): fp16 presumably requires a CUDA device — confirm the runtime.
    fp16=True,
    output_dir="./output/models/",
    logging_steps=2,
    save_steps=1000,
    eval_steps=200,
    num_train_epochs=6,
    # Generation settings defined in the model-configuration section.
    generation_config=generation_config)

In [None]:
# Sanity pass over the training set: report every sample that could not be
# encoded. The loop variable is called `sample` so it does not overwrite the
# `data` list built while loading, and explicit indexing over len() avoids
# relying on the dataset raising IndexError to stop the iteration.
for idx in range(len(train_dataset)):
    sample = train_dataset[idx]
    if sample is None or any(value is None for value in sample.values()):
        print(f"None found in dataset at index {idx}: {sample}")

In [None]:
from transformers import default_data_collator
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    # `image_processor` is the current attribute name; `feature_extractor` is
    # a deprecated alias for the same object.
    processing_class=processor.image_processor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
)
trainer.train()
# Store the processor next to the model checkpoints.
processor.save_pretrained("./output/models/")


## Step 4. Evaluate

### 4.1 Load model

In [None]:
from transformers import VisionEncoderDecoderModel
from transformers import TrOCRProcessor
import torch

model_path = "./output/models/checkpoint-110333"
# Use the GPU when there is one and fall back to the CPU otherwise, matching
# the `device` selection used for inference below.
model = VisionEncoderDecoderModel.from_pretrained(model_path).to(
    "cuda" if torch.cuda.is_available() else "cpu"
)
# processor = TrOCRProcessor.from_pretrained(model_path)

### 4.2 Load Testing Dataset

In [None]:
test_df.info()

### 4.3 Do Inference

In [None]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from tqdm import tqdm

def readText_batch(df, indices):
    """Run the OCR model on several rows of ``df`` at once.

    Args:
        df: DataFrame with an ``image_id`` column holding image file paths.
        indices: Positional (``iloc``) row indices forming the batch.

    Returns:
        List of decoded strings, one per index, in the order of ``indices``.
    """
    paths = [df['image_id'].iloc[idx] for idx in indices]
    images = []
    for path in paths:
        # convert() returns an independent in-memory image, so the file can be
        # closed right away instead of leaking one handle per image.
        with Image.open(path) as img:
            images.append(img.convert("RGB"))
    pixel_values = processor(images=images, return_tensors="pt").pixel_values.to(device)
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)

def process_all_rows_batched(df, batch_size=8):
    """Run batched inference over every row of ``df``.

    Args:
        df: DataFrame with ``image_id`` (image path) and ``text`` (ground
            truth) columns.
        batch_size: Number of rows sent to the model per call.

    Returns:
        DataFrame with one row per input row: ``id`` (image path), ``labels``
        (ground-truth text) and ``preds`` (model output, ``None`` when the
        batch failed). Rows of failed batches also carry an ``error`` message.
    """
    results = []
    for i in tqdm(range(0, len(df), batch_size), desc="Processing batches"):
        batch_indices = range(i, min(i+batch_size, len(df)))
        try:
            batch_texts = readText_batch(df, batch_indices)
            for idx, text in zip(batch_indices, batch_texts):
                # 'labels' is the reference and 'preds' the model output, the
                # same meaning as in the error branch below (the two were
                # previously swapped here, which skewed the CER).
                results.append({
                    'id': df['image_id'].iloc[idx],
                    'labels': df['text'].iloc[idx],
                    'preds': text
                })
        except Exception as e:
            # Best effort: keep the rows of a failed batch with an empty
            # prediction so the output still lines up with the input.
            print(f"Error in batch {i//batch_size}: {str(e)}")
            for idx in batch_indices:
                results.append({
                    'id': df['image_id'].iloc[idx],
                    'labels': df['text'].iloc[idx],
                    'preds': None,
                    'error': str(e)
                })
    return pd.DataFrame(results)

In [None]:
results_df = process_all_rows_batched(test_df, batch_size=8)

In [None]:
results_df.head()

### 4.4 Evaluate

In [None]:
from evaluate import load
cer = load("cer")

def compute_eval_metrics(pred_str, label_str):
    """Return the character error rate of one prediction, or ``None``.

    Args:
        pred_str: Predicted text. Non-string values (``None``/NaN from a
            failed batch) are accepted and yield ``None``.
        label_str: Reference text; non-string values also yield ``None``.

    Returns:
        The CER score, or ``None`` when an input is missing or the metric
        raised an error.
    """
    # Rows of failed batches carry no prediction; report "no score" instead of
    # crashing on .strip().
    if not isinstance(pred_str, str) or not isinstance(label_str, str):
        return None

    pred_str = pred_str.strip()
    label_str = label_str.strip()
    try:
        return cer.compute(predictions=[pred_str], references=[label_str])
    except Exception as e:
        print("error", e)
        print(type(pred_str), len(pred_str), pred_str)
        print(type(label_str), len(label_str), label_str)
        return None

In [None]:
from tqdm import tqdm
tqdm.pandas()  # Enable progress_apply for pandas

results_df["metrics"] = results_df.progress_apply(
    lambda row: compute_eval_metrics(row["preds"], row["labels"]),
    axis=1
)

### 4.5 Analyze Performance

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import math

def plot_eval(values):
    """Plot a KDE and a horizontal boxplot of the given scores.

    Args:
        values: Sequence or Series of numeric scores. Missing entries
            (``None``/NaN) are ignored.
    """
    # Missing scores would turn the boxplot into an empty plot, so they are
    # dropped first; a plain array also avoids surprises with a gappy index.
    clean = pd.Series(values, dtype=float).dropna().to_numpy()

    # Plotting
    plt.figure(figsize=(8, 5))
    # `fill` is the current name of the deprecated `shade` argument.
    sns.kdeplot(clean, fill=True)
    plt.xlabel("Edit Distance")
    plt.title("KDE of Edit Distances")
    plt.show()

    # Boxplot
    plt.boxplot(clean, vert=False, patch_artist=True)
    plt.xlabel("Edit Distance")
    plt.title("Boxplot of Edit Distances")

    plt.tight_layout()
    plt.show()

In [None]:
plot_eval(results_df["metrics"])

In [None]:
# remove outliers
results_normal= results_df[results_df["metrics"]<5]
plot_eval(results_normal["metrics"])

In [None]:
import numpy as np
def show_state(values):
    """Print summary statistics for a collection of scores.

    Args:
        values: Sequence, array or Series of numeric scores. Missing entries
            (``None``/NaN) are excluded from every statistic.
    """
    # A float array makes `== 0` element-wise for plain lists too, and lets
    # missing scores be filtered out instead of turning every statistic to NaN.
    arr = np.asarray(values, dtype=float)
    valid = arr[~np.isnan(arr)]
    n_missing = arr.size - valid.size

    print("Summary Statistics:")
    if valid.size == 0:
        print("- No valid values to summarise")
        return

    stats = {
        "mean": np.mean(valid),
        "median": np.median(valid),
        "std": np.std(valid),
        "min": np.min(valid),
        "max": np.max(valid),
        "quantiles": np.quantile(valid, [0.25, 0.5, 0.75]),
        "perfect": np.sum(valid == 0),
    }

    print(f"- Mean ± Std: {stats['mean']:.2f} ± {stats['std']:.2f}")
    print(f"- Median (IQR): {stats['median']:.2f} ({stats['quantiles'][0]:.2f}–{stats['quantiles'][2]:.2f})")
    print(f"- Range: [{stats['min']}, {stats['max']}]")
    print(f"- Quantiles (25th, 50th, 75th): {stats['quantiles'].round(2)}")
    print(f"- Perfect Predictions: {stats['perfect']} ({stats['perfect']/len(valid)*100:.2f}%)")
    if n_missing:
        print(f"- Skipped missing values: {n_missing}")



In [None]:
show_state(results_df["metrics"])

In [None]:
print(results_df)