In [9]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from peft import get_peft_config, get_peft_model, PrefixTuningConfig, TaskType
from datasets import Dataset
import pandas as pd
import numpy as np

# Load your data
df = pd.read_csv("dataset/train.csv")

# Prepare dataset
def prepare_dataset(df):
    dataset_dict = {
        "ocr_text": df["OCR Text"].tolist(),
        "ground_truth": df["Ground Truth"].tolist()
    }
    return Dataset.from_dict(dataset_dict)

# Initialize tokenizer and model
model_name = "facebook/opt-350m"  # You can change this to other models
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Configure prefix tuning
peft_config = PrefixTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=20,  # You can adjust this
    prefix_projection=True,
    token_dim=model.config.hidden_size,
    num_layers=model.config.num_hidden_layers
)

# Get PEFT model
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Prepare dataset
dataset = prepare_dataset(df)

def preprocess_function(examples):
    # Format inputs
    inputs = [f"Correct OCR: {text} -> " for text in examples["ocr_text"]]
    targets = examples["ground_truth"]

    # Tokenize inputs and targets
    model_inputs = tokenizer(inputs, truncation=True, max_length=256)
    labels = tokenizer(targets, truncation=True, max_length=256)

    # Prepare the labels
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Preprocess the dataset
tokenized_dataset = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset.column_names
)

# Define training arguments
training_args = TrainingArguments(
    output_dir="ocr_prefix_tuned",
    learning_rate=1e-3,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    save_steps=1000,
    save_total_limit=2,
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
)

# Train the model
trainer.train()

# Save the model
model.save_pretrained("ocr_prefix_tuned_final")

# Test the model
def correct_ocr(text):
    inputs = tokenizer(f"Correct OCR: {text} -> ", return_tensors="pt")
    outputs = model.generate(**inputs, max_length=256, num_beams=5)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

Loading base model...
Loading PEFT model...
Model loaded on cuda

Testing OCR correction:

Original: Ths is an exmple of OCR txt wth errors.
Error during generation: 'past_key_values'
Corrected: Error correcting text: 'past_key_values'

Original: The queck brown tox jumps over tho lazy dag.
Error during generation: 'past_key_values'
Corrected: Error correcting text: 'past_key_values'

Original: Artiticial lntelligence is changing the worid.
Error during generation: 'past_key_values'
Corrected: Error correcting text: 'past_key_values'


In [17]:
import torch
from transformers import MistralConfig, MistralForCausalLM, AutoModelForCausalLM
from peft import PrefixTuningConfig, get_peft_model


def get_model(name):
    if name == "mistral":
        model_config = MistralConfig(
            vocab_size=32000,
            hidden_size=512,
            max_position_embeddings=32768,
            num_attention_heads=16,
            num_hidden_layers=8,
            num_key_value_heads=4,
        )
        return MistralForCausalLM(model_config)

    return AutoModelForCausalLM.from_pretrained(name)


for name in ("gpt2", "facebook/opt-125m", "bigscience/bloomz-560m",  "HuggingFaceH4/tiny-random-LlamaForCausalLM", "mistral"):
    config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=30)
    model = get_model(name)
    model = get_peft_model(model, config)
    model.config.use_cache = False

    input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]])
    attention_mask = torch.tensor([[1, 1, 1], [1, 0, 1]])
    try:
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        print(f"PASS: model {name} passed")
    except Exception as e:
        print(f"FAIL: model {name} failed with {e}")

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

PASS: model gpt2 passed


config.json:   0%|          | 0.00/651 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/251M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

PASS: model facebook/opt-125m passed


config.json:   0%|          | 0.00/715 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

FAIL: model bigscience/bloomz-560m failed with 'tuple' object has no attribute 'get_seq_length'


config.json:   0%|          | 0.00/466 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.07M [00:00<?, ?B/s]



generation_config.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

FAIL: model HuggingFaceH4/tiny-random-LlamaForCausalLM failed with 'tuple' object has no attribute 'get_seq_length'
FAIL: model mistral failed with 'tuple' object has no attribute 'get_seq_length'
