# Import Libraries

In [2]:
import pandas as pd
import evaluate
from datasets import Dataset, load_dataset, DatasetDict
from transformers import (
    TrOCRProcessor, 
    VisionEncoderDecoderModel,
    TrainingArguments,
    Trainer,
    default_data_collator
)
from PIL import Image
import os
import io
import torch

KeyboardInterrupt: 

# Run Configuration

In [None]:
print("--- RUN CONFIGURATION: PRINTED MODEL ---")

# Dataset column names read by the preprocessing function.
IMAGE_DATA_COLUMN = "image"
TEXT_LABEL_COLUMN = "text"

# Project directories (relative paths).
PROCESSED_DATA_DIR = "../output/processed_data/"
RAW_DATA_DIR = "../../data/"
IMAGES_BASE_DIR = "../../data/images/"
OUTPUT_DIR = "../output/printed_model/"

# Pretrained checkpoint and the Parquet file names for each split.
MODEL_NAME = "microsoft/trocr-base-printed"
TRAIN_PARQUET = "printed_streaming.parquet"
VAL_PARQUET = "val_printed.parquet"

# Resolved Parquet locations.
TRAIN_PARQUET_PATH = os.path.join(PROCESSED_DATA_DIR, TRAIN_PARQUET)
# NOTE(review): the validation file is resolved under IMAGES_BASE_DIR while
# RAW_DATA_DIR is not used in this cell -- confirm this is the intended folder.
VAL_PARQUET_PATH = os.path.join(IMAGES_BASE_DIR, VAL_PARQUET)

# Echo the effective configuration so the run is self-describing.
print(f"Model: {MODEL_NAME}")
print(f"Output: {OUTPUT_DIR}")
print(f"Training data: {TRAIN_PARQUET_PATH}")
print(f"Validation data: {VAL_PARQUET_PATH}")

# Model and Data Loading

## Load Processor and Model

In [None]:
# --- 4. Load the Processor & Model ---
# Loads the pretrained processor and encoder-decoder model named by MODEL_NAME,
# then writes the special-token and generation settings onto model.config.
# Any failure is reported and re-raised so the notebook stops here.
try:
    processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
    model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
    print(f"Processor and Model loaded from {MODEL_NAME}")

    _tokenizer = processor.tokenizer
    _model_config = model.config

    # Special tokens and vocabulary size used for fine-tuning.
    _model_config.decoder_start_token_id = _tokenizer.cls_token_id
    _model_config.pad_token_id = _tokenizer.pad_token_id
    _model_config.vocab_size = _model_config.decoder.vocab_size

    # Generation settings (end-of-sequence token plus beam-search parameters).
    # NOTE(review): newer transformers releases expect generation parameters on
    # model.generation_config rather than model.config -- confirm against the
    # installed version before relying on these at generate()/save time.
    _generation_settings = {
        "eos_token_id": _tokenizer.sep_token_id,
        "max_length": 64,
        "early_stopping": True,
        "no_repeat_ngram_size": 3,
        "length_penalty": 2.0,
        "num_beams": 4,
    }
    for _setting_name, _setting_value in _generation_settings.items():
        setattr(_model_config, _setting_name, _setting_value)

except Exception as e:
    print(f"Error loading processor/model: {e}")
    raise

## Load Datasets

In [None]:
def _load_parquet_split(parquet_path, split_name):
    """Load one Parquet file as the split `split_name` and return that Dataset."""
    # Passing split= makes load_dataset return the Dataset itself rather than
    # a dict keyed by split name.
    return load_dataset("parquet", data_files={split_name: parquet_path}, split=split_name)


# Load the training and validation Parquet files as two independent datasets.
# Any failure is reported and re-raised so later cells never see partial state.
try:
    print(f"Loading training data from: {TRAIN_PARQUET_PATH}")
    # Training data is expected to carry 'image', 'text' and 'source' columns.
    train_dataset = _load_parquet_split(TRAIN_PARQUET_PATH, "train")

    print(f"Loading validation data from: {VAL_PARQUET_PATH}")
    # Validation data is expected to carry 'image' (struct) and 'label' columns.
    val_dataset = _load_parquet_split(VAL_PARQUET_PATH, "validation")

    print("Successfully loaded datasets separately.")
    print(f"\nRaw train dataset: {train_dataset}")
    print(f"Raw val dataset: {val_dataset}")

except Exception as e:
    print(f"Error loading Parquet files: {e}")
    raise

In [None]:
# --- Standardize Column Names ---
# Bring both splits to the same column layout before preprocessing.
print("\nStandardizing column schemas...")

# The validation split may store its transcription under 'label'; rename it to
# the shared text column unless that column already exists.
_val_columns = val_dataset.column_names
_needs_rename = 'label' in _val_columns and TEXT_LABEL_COLUMN not in _val_columns
if _needs_rename:
    print("  - Renaming 'label' to 'text' in validation set.")
    val_dataset = val_dataset.rename_column('label', TEXT_LABEL_COLUMN)

# The training split may carry an extra 'source' column that validation lacks.
_train_has_source = 'source' in train_dataset.column_names
if _train_has_source:
    print("  - Removing 'source' column from training set for consistency.")
    train_dataset = train_dataset.remove_columns(['source'])

print("\n...Standardization complete.")
print(f"Cleaned train dataset features: {train_dataset.features}")
print(f"Cleaned val dataset features: {val_dataset.features}")

In [None]:
# Downsample Validation Set
# Cap the validation split at VAL_SUBSET_SIZE samples; smaller splits are kept whole.
VAL_SUBSET_SIZE = 10000

if len(val_dataset) <= VAL_SUBSET_SIZE:
    print(f"\nUsing full validation set of {len(val_dataset)} samples.")
else:
    print(f"\nValidation set is very large ({len(val_dataset)}).")

    # A fixed seed keeps the random subset identical across runs.
    _shuffled_val = val_dataset.shuffle(seed=42)
    val_dataset = _shuffled_val.select(range(VAL_SUBSET_SIZE))

    print(f"Using a random subset of {len(val_dataset)} samples for faster validation.")

# Data Processing

In [None]:
# Define Processing Function
def prepare_sample(example):
    """
    Turn one raw dataset row into model inputs.

    The image stored under IMAGE_DATA_COLUMN is decoded to an RGB PIL image and
    passed, together with the text under TEXT_LABEL_COLUMN, to the global
    `processor`, which pads/truncates the text to 64 tokens.

    Accepted image representations:
      * raw encoded bytes (`bytes` / `bytearray`),
      * a non-decoded image struct, i.e. a dict holding `bytes` and/or `path`,
      * any object exposing `.convert("RGB")` (e.g. an already-decoded PIL image).

    Args:
        example: Mapping containing IMAGE_DATA_COLUMN and TEXT_LABEL_COLUMN.

    Returns:
        The processor output for the sample, or None when the text label is not
        a non-blank string or the image cannot be decoded/processed.

    Note:
        A None return does not make a non-batched `Dataset.map` drop the row;
        callers that want failed samples removed must discard None results
        themselves.
    """
    image_source = example[IMAGE_DATA_COLUMN]
    text = example[TEXT_LABEL_COLUMN]

    # Reject missing, non-string (e.g. NaN) and empty/whitespace-only labels.
    if not isinstance(text, str) or not text.strip():
        return None

    try:
        if isinstance(image_source, (bytes, bytearray)):
            # Raw encoded image bytes.
            image = Image.open(io.BytesIO(image_source)).convert("RGB")
        elif isinstance(image_source, dict):
            # Non-decoded image struct: prefer the embedded bytes and fall back
            # to the file path when no bytes are stored.
            raw_bytes = image_source.get("bytes")
            if raw_bytes is not None:
                image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
            else:
                image = Image.open(image_source["path"]).convert("RGB")
        else:
            # Already-decoded image object; only normalize the colour mode.
            image = image_source.convert("RGB")

        # NOTE(review): the processor is called on a single image, so the
        # returned pixel values may carry a leading batch dimension of 1 --
        # confirm the downstream collator accounts for that.
        return processor(
            images=image,
            text=text,
            padding="max_length",
            truncation=True,
            max_length=64,
        )

    except Exception as e:
        # Deliberate best-effort: one unreadable sample must not abort the run.
        print(f"Warning: Error processing sample with text '{str(text)[:50]}'. Skipping. Error: {e}")
        return None

In [None]:
# Apply the Function to the Datasets
print("\nApplying processing function to datasets...")

# Make sure progress bars are off
import datasets
datasets.disable_progress_bar()

# Small batches bound peak memory: every processed sample holds a full
# pixel-value array.
_MAP_BATCH_SIZE = 32


def _make_batch_preparer():
    """Build a batched `map` function that drops rows `prepare_sample` rejects.

    A non-batched `map` cannot remove rows (a None return does not delete the
    row), whereas a batched function may return fewer rows than it received.
    """
    # Output column names seen so far, so that a batch in which every sample
    # fails still returns the same (empty) columns as the other batches.
    output_columns = []

    def _prepare_batch(batch):
        input_columns = list(batch.keys())
        kept = {name: [] for name in output_columns}
        # Walk the batch row by row, rebuilding the per-sample dict that
        # prepare_sample expects.
        for row_values in zip(*(batch[name] for name in input_columns)):
            result = prepare_sample(dict(zip(input_columns, row_values)))
            if result is None:
                continue  # invalid label or unreadable image: drop the row
            for name, value in result.items():
                if name not in kept:
                    output_columns.append(name)
                    kept[name] = []
                kept[name].append(value)
        return kept

    return _prepare_batch


processed_train_ds = train_dataset.map(
    _make_batch_preparer(),
    batched=True,
    batch_size=_MAP_BATCH_SIZE,
    remove_columns=train_dataset.column_names,
)
processed_val_ds = val_dataset.map(
    _make_batch_preparer(),
    batched=True,
    batch_size=_MAP_BATCH_SIZE,
    remove_columns=val_dataset.column_names,
)

# Failed samples were already dropped inside the batched map above; the former
# `filter(lambda x: x is not None)` pass never removed anything, because filter
# always receives a dict.

print("...Processing complete.")
print(f"Total processed training samples: {len(processed_train_ds)}")
print(f"Total processed validation samples: {len(processed_val_ds)}")