In [None]:
# Install necessary libraries
!pip install -q transformers
!pip install -q datasets jiwer evaluate
import pandas as pd

In [None]:
# Step 1: Load and preprocess data
# The ground-truth file is parsed as fixed-width columns without a header row:
# column 0 becomes "file_name", column 1 becomes "text", column 2 is discarded.
df = (
    pd.read_fwf('/teamspace/studios/this_studio/IAM/gt_test.txt', header=None)
    .rename(columns={0: "file_name", 1: "text"})
    .drop(columns=2)
)
# File names that end in 'jp' get the missing 'g' appended; all others are kept as-is.
df['file_name'] = [name + 'g' if name.endswith('jp') else name for name in df['file_name']]
df.head()

In [None]:
# Split into train and test datasets
from sklearn.model_selection import train_test_split

# A fixed random_state makes the 80/20 split (and therefore every metric reported
# further down) reproducible across kernel restarts.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

# The dataset classes look rows up by position, so give both frames a fresh 0..n-1 index.
train_df = train_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

In [None]:
# Step 2: Define the dataset class
import torch
from torch.utils.data import Dataset
from PIL import Image

class IAMDataset(Dataset):
    """Map-style dataset pairing line images with their tokenized transcriptions.

    Args:
        root_dir: Prefix that is concatenated directly with each ``file_name``
            to build the image path (so it must end with a path separator).
        df: DataFrame with a ``file_name`` and a ``text`` column.
        processor: Object that is called on the image to obtain ``pixel_values``
            and that exposes the text tokenizer as ``processor.tokenizer``.
        max_target_length: Exact length of every returned label sequence; shorter
            texts are padded, longer ones are truncated.
    """

    def __init__(self, root_dir, df, processor, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        """Return the number of rows (samples) in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{"pixel_values", "labels"}`` for the sample at position ``idx``."""
        # Positional lookup: stays correct even if the frame's index is not 0..n-1.
        file_name = self.df['file_name'].iloc[idx]
        text = self.df['text'].iloc[idx]

        # Close the file handle as soon as the pixel data has been converted.
        with Image.open(self.root_dir + file_name) as raw_image:
            image = raw_image.convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # truncation=True guarantees that every label sequence has exactly
        # max_target_length entries, which the default batch collation requires.
        labels = self.processor.tokenizer(text,
                                          padding="max_length",
                                          max_length=self.max_target_length,
                                          truncation=True).input_ids
        # Padding positions are replaced by -100 so they can be told apart from real tokens.
        pad_token_id = self.processor.tokenizer.pad_token_id
        labels = [token if token != pad_token_id else -100 for token in labels]

        # Drop only the leading batch dimension added by return_tensors="pt".
        return {"pixel_values": pixel_values.squeeze(0), "labels": torch.tensor(labels)}

In [None]:
# Step 3: Load processor and datasets
from transformers import TrOCRProcessor
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")

# Both splits read their images from the same directory and share one processor.
train_dataset, eval_dataset = (
    IAMDataset(root_dir='/teamspace/studios/this_studio/IAM/image/',
               df=split_df,
               processor=processor)
    for split_df in (train_df, test_df)
)

In [None]:
# Report how many samples ended up in each split.
n_train_examples, n_eval_examples = len(train_dataset), len(eval_dataset)
print("Number of training examples:", n_train_examples)
print("Number of validation examples:", n_eval_examples)

In [None]:
# Peek at the first five training file names.
print(train_df['file_name'].iloc[:5])


In [None]:
# Fetch the first training sample and show the shape of each entry it contains.
encoding = train_dataset[0]
for key in encoding:
    print(key, encoding[key].shape)

In [None]:
# Display the image that belongs to the first training row.
first_image_path = train_dataset.root_dir + train_df['file_name'][0]
image = Image.open(first_image_path).convert("RGB")
image

In [None]:
# Decode the label ids of the inspected sample back to text.
# Work on a copy so that `encoding['labels']` keeps its -100 padding markers
# and can still be inspected unchanged after this cell has run.
labels = encoding['labels'].clone()
labels[labels == -100] = processor.tokenizer.pad_token_id
label_str = processor.decode(labels, skip_special_tokens=True)
print(label_str)

In [None]:
# Step 4: Define dataloaders
from torch.utils.data import DataLoader

# Training batches are reshuffled every epoch; evaluation keeps the dataset order.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)
eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=8, shuffle=False)

In [None]:
# Step 5: Load model
from transformers import VisionEncoderDecoderModel

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")
model.to(device)

In [None]:
# Set special tokens and model parameters
tokenizer = processor.tokenizer
model_config = model.config

# Special token ids are taken from the processor's tokenizer.
model_config.decoder_start_token_id = tokenizer.cls_token_id
model_config.pad_token_id = tokenizer.pad_token_id
model_config.eos_token_id = tokenizer.sep_token_id

# Mirror the decoder's vocabulary size on the top-level config.
model_config.vocab_size = model_config.decoder.vocab_size

# Generation settings stored on the model config.
model_config.max_length = 64
model_config.num_beams = 4
model_config.early_stopping = True
model_config.no_repeat_ngram_size = 3
model_config.length_penalty = 2.0

In [None]:
# Step 6: Define optimizer, scheduler, and metrics
from transformers import get_scheduler

# transformers' own AdamW is deprecated in favour of torch.optim.AdamW and may be
# absent from newer releases, so the PyTorch implementation is used instead.
# eps and weight_decay are set explicitly to the values the transformers
# implementation used by default, keeping the optimisation hyperparameters unchanged.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

In [None]:
from transformers import get_scheduler

# The schedule length is expressed in optimizer steps: batches per epoch times 10 epochs.
steps_per_epoch = len(train_dataloader)
num_training_steps = steps_per_epoch * 10  # 10 epochs
# The first 10% of those steps are used as warmup.
num_warmup_steps = int(0.1 * num_training_steps)

lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)


In [None]:

from evaluate import load

# Character error rate metric, used to score decoded predictions against references.
cer_metric = load(path="cer")

In [None]:
def compute_cer(pred_ids, label_ids):
    """Decode predictions and labels and return the value computed by ``cer_metric``.

    Args:
        pred_ids: Batch of generated token ids.
        label_ids: Tensor of label token ids in which padding is marked with -100.
            The tensor is left untouched; the replacement happens on a copy.

    Returns:
        The result of ``cer_metric.compute`` for the decoded prediction and
        reference strings.
    """
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)

    # -100 is not a valid token id, so swap it for the pad id before decoding.
    # Cloning first keeps the caller's tensor (and its -100 markers) intact.
    decodable_ids = label_ids.clone()
    decodable_ids[decodable_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(decodable_ids, skip_special_tokens=True)

    return cer_metric.compute(predictions=pred_str, references=label_str)

In [None]:
# Step 7: Add data augmentation
from torchvision.transforms import Compose, RandomRotation, ColorJitter, ToTensor

# The pipeline is kept PIL-in / PIL-out. Its result is handed to the TrOCR processor,
# which does its own tensor conversion, rescaling and normalisation; converting to a
# [0, 1] float tensor here (ToTensor) would have the pixel values rescaled a second time.
augmentation = Compose([
    RandomRotation(degrees=10),
    ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
])

In [None]:
# Custom dataset class with augmentation
class IAMAugmentedDataset(IAMDataset):
    """Variant of ``IAMDataset`` that runs ``augmentation`` on every image it loads.

    Construction and ``__len__`` are inherited; only sample loading differs.
    """

    def __getitem__(self, idx):
        """Return ``{"pixel_values", "labels"}`` for the augmented sample at position ``idx``."""
        # Positional lookup: stays correct even if the frame's index is not 0..n-1.
        file_name = self.df['file_name'].iloc[idx]
        text = self.df['text'].iloc[idx]

        # Close the file handle as soon as the pixel data has been converted.
        with Image.open(self.root_dir + file_name) as raw_image:
            image = raw_image.convert("RGB")
        image = augmentation(image)  # Apply augmentation
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # truncation=True guarantees that every label sequence has exactly
        # max_target_length entries, which the default batch collation requires.
        labels = self.processor.tokenizer(text,
                                          padding="max_length",
                                          max_length=self.max_target_length,
                                          truncation=True).input_ids
        # Padding positions are replaced by -100 so they can be told apart from real tokens.
        pad_token_id = self.processor.tokenizer.pad_token_id
        labels = [token if token != pad_token_id else -100 for token in labels]

        # Drop only the leading batch dimension added by return_tensors="pt".
        return {"pixel_values": pixel_values.squeeze(0), "labels": torch.tensor(labels)}

train_dataset = IAMAugmentedDataset(root_dir='/teamspace/studios/this_studio/IAM/image/',
                                    df=train_df,
                                    processor=processor)

# A DataLoader keeps a reference to the dataset object it was built with, so the
# loader created earlier would still serve the un-augmented samples. Rebuild it
# (same batch size and shuffling) so that training actually sees the augmented data.
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

In [None]:
# Step 8: Train the model
from tqdm.notebook import tqdm
import torch.cuda.amp as amp  # For mixed precision training

scaler = amp.GradScaler()  # Initialize scaler for mixed precision

for epoch in range(10):  # Train for 10 epochs
    model.train()
    train_loss = 0.0
    for batch in tqdm(train_dataloader):
        for k, v in batch.items():
            batch[k] = v.to(device)

        with amp.autocast():  # Mixed precision
            outputs = model(**batch)
            loss = outputs.loss

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # The schedule was sized as len(train_dataloader) * 10 steps, i.e. in optimizer
        # steps, so it has to advance once per batch. Stepping it once per epoch would
        # leave the learning rate at the very start of the warmup for the whole run.
        lr_scheduler.step()
        optimizer.zero_grad()
        train_loss += loss.item()

    print(f"Loss after epoch {epoch}:", train_loss / len(train_dataloader))

    # Evaluate
    model.eval()
    valid_cer = 0.0
    with torch.no_grad():
        for batch in tqdm(eval_dataloader):
            outputs = model.generate(batch["pixel_values"].to(device))
            cer = compute_cer(pred_ids=outputs, label_ids=batch["labels"])
            valid_cer += cer

    # Note: this is the mean of the per-batch CER values, not a CER over all samples at once.
    print(f"Validation CER after epoch {epoch}:", valid_cer / len(eval_dataloader))

# Step 9: Save the model
model.save_pretrained("improved_trocr_model")

### Overall Validation CER: ~5%

In [None]:
from tqdm.notebook import tqdm

# Prepare the test dataloader (it wraps the same eval_dataset used for validation above)
test_dataloader = DataLoader(eval_dataset, batch_size=4)

# Switch model to evaluation mode
model.eval()

# List to store predictions and ground truth
predictions = []
ground_truths = []

# Iterate through the test data
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        # Only the images are needed on the model's device; labels are just decoded
        outputs = model.generate(batch["pixel_values"].to(device))

        # Decode predictions
        pred_str = processor.batch_decode(outputs, skip_special_tokens=True)

        # Handle the -100 padding tokens in labels (on a copy of the batch tensor)
        labels = batch["labels"].clone()
        labels[labels == -100] = processor.tokenizer.pad_token_id  # replace -100 with pad_token_id

        # Decode labels
        label_str = processor.batch_decode(labels, skip_special_tokens=True)

        predictions.extend(pred_str)
        ground_truths.extend(label_str)

# Displaying some results: up to 5 pairs, without failing when fewer are available
for truth, prediction in zip(ground_truths[:5], predictions[:5]):
    print(f"Ground Truth: {truth}")
    print(f"Prediction:   {prediction}")
    print("-" * 50)

# Optionally, calculate CER on the whole test set
from evaluate import load
cer_metric = load("cer")

# Compute CER for the entire test set
test_cer = cer_metric.compute(predictions=predictions, references=ground_truths)
print(f"Test CER: {test_cer}")


### Test CER: ~5.7% (measured on the same held-out split that is used for validation)

In [None]:
from PIL import Image
import torch

# Function to predict text from a given image
def predict_from_image(image_path, model, processor):
    """Return the text the model generates for the image stored at ``image_path``.

    The image is opened and converted to RGB, passed through ``processor`` to get
    pixel values, moved to ``model.device`` and fed to ``model.generate`` with
    gradients disabled. Only the first generated sequence is decoded.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    model_inputs = processor(rgb_image, return_tensors="pt").pixel_values

    # Inference only: evaluation mode, no gradient tracking.
    model.eval()
    with torch.no_grad():
        sequences = model.generate(model_inputs.to(model.device))

    first_sequence = sequences[0]
    return processor.decode(first_sequence, skip_special_tokens=True)

# Test the function with a test image
image_path = '/teamspace/studios/this_studio/Screenshot 2024-12-22 132529.png'  # Specify the path to your test image
predicted_text = predict_from_image(image_path=image_path, model=model, processor=processor)

# Display the predicted text
print("Predicted Text:", predicted_text)


In [None]:
from PIL import Image
import torch

def predict_from_custom_image(image_path, model, processor, device):
    """
    Predict the text from a custom image using the trained model.

    The prediction is printed and also returned, so callers can reuse it.

    Parameters:
    - image_path (str): Path to the image file.
    - model (PreTrainedModel): TrOCR model.
    - processor (TrOCRProcessor): Processor for TrOCR.
    - device (torch.device): Device to run the model on ('cuda' or 'cpu').

    Returns:
    - str: Decoded text of the first generated sequence.
    """
    # Load the image; the file handle is closed once the RGB copy exists
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # Preprocess the image
    pixel_values = processor(image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    # Run the model to generate predictions
    model.eval()
    with torch.no_grad():
        generated_ids = model.generate(pixel_values)

    # Decode the output to text
    decoded_output = processor.batch_decode(generated_ids, skip_special_tokens=True)
    predicted_text = decoded_output[0]

    # Print the predicted text
    print(f"Predicted Text: {predicted_text}")
    return predicted_text

# Example usage:
image_path = '/teamspace/studios/this_studio/WhatsApp Image 2024-12-29 at 19.00.00_6f4550fl.jpg'  # Replace with your image path

# Pick the GPU when torch can see one, otherwise the CPU, and place the model there
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

predict_from_custom_image(image_path=image_path, model=model, processor=processor, device=device)
