# Machine Learning Pipeline for Enhanced OCR Accuracy (PyTorch)

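This notebook demonstrates how to build and train a machine learning model to improve OCR accuracy and data extraction from receipts using PyTorch. As a first, minimal step, the model below learns to predict the length of each receipt's raw OCR text from the receipt image: it is a pipeline skeleton, not yet a text-recognition model.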

## 1. Import Required Libraries
We will use PyTorch, torchvision, pandas, and PIL for data handling, model building, and image processing.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import os
import numpy as np

## 2. Prepare Dataset
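Load each receipt's OCR text from the SQLite database and the matching image from the uploads folder. Then preprocess the images (resize and convert to tensors) for model input. The raw OCR text is kept as-is and later serves as the training target.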

In [None]:
# Example: Load OCR data from SQLite and images from uploads folder
import sqlite3

# `__file__` is not defined when this code runs as a notebook cell, so fall back to the
# kernel's working directory (normally the folder that contains the notebook).
_NOTEBOOK_DIR = (
    os.path.dirname(os.path.abspath(__file__)) if "__file__" in globals() else os.getcwd()
)
# The project root is one level above that folder; data/ and uploads/ are resolved from it.
PROJECT_ROOT = os.path.dirname(_NOTEBOOK_DIR)
DB_PATH = os.path.join(PROJECT_ROOT, 'data', 'ocr.sqlite3')
UPLOADS_PATH = os.path.join(PROJECT_ROOT, 'uploads')

def load_ocr_data(db_path=None):
    """Return every ``(file_name, raw_ocr, parsed_json)`` row from ``vouchers_master``.

    Args:
        db_path: Path to the SQLite database file. Defaults to the module-level ``DB_PATH``.

    Returns:
        list of ``(file_name, raw_ocr, parsed_json)`` tuples, as produced by ``fetchall()``.

    Raises:
        FileNotFoundError: if ``db_path`` does not exist. This is checked up front because
            ``sqlite3.connect`` would otherwise silently create a new, empty database file.
        sqlite3.Error: if the query itself fails (e.g. the table is missing).
    """
    if db_path is None:
        db_path = DB_PATH
    if not os.path.exists(db_path):
        raise FileNotFoundError(f"OCR database not found: {db_path}")
    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()
        cur.execute('SELECT file_name, raw_ocr, parsed_json FROM vouchers_master')
        return cur.fetchall()
    finally:
        # Close the connection even when the query raises, so the DB file handle is not leaked.
        conn.close()

data = load_ocr_data()
# Fail fast with a clear message. An empty table would otherwise only fail later and less
# clearly: the shuffled DataLoader rejects an empty dataset, and the loss averaging divides
# by len(dataset).
if not data:
    raise RuntimeError(f"No rows found in vouchers_master at {DB_PATH}; nothing to train on.")
print(f"Loaded {len(data)} records from DB.")

# Example transform for images: resize every receipt to 256x256, then convert the PIL image
# into a float tensor. 256x256 is the input size SimpleOCRModel's first Linear layer
# (32 * 64 * 64 features after two 2x max-pools) was sized for.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)

class ReceiptDataset(Dataset):
    """Pairs each receipt image under ``uploads_path`` with its raw OCR text.

    Args:
        data: sequence of ``(file_name, raw_ocr, parsed_json)`` rows, e.g. the result of
            ``load_ocr_data()``. ``parsed_json`` is currently unused.
        uploads_path: directory that each ``file_name`` is resolved against.
        transform: optional callable applied to the RGB ``PIL.Image`` (e.g. a torchvision
            ``Compose``).

    Each item is ``(image, raw_ocr)``. ``image`` is whatever ``transform`` returns, or the
    RGB PIL image when no transform is given.
    """

    def __init__(self, data, uploads_path, transform=None):
        self.data = data
        self.uploads_path = uploads_path
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        file_name, raw_ocr, _parsed_json = self.data[idx]
        img_path = os.path.join(self.uploads_path, file_name)
        # The context manager releases the file handle. convert() loads the pixels into a
        # new image first, so the result stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        # For demo: use raw_ocr as target (could be structured fields).
        # A NULL raw_ocr (if the column allows it) becomes '' so batching and len() still work.
        if raw_ocr is None:
            raw_ocr = ''
        return image, raw_ocr

# Wrap the DB rows; each image is opened from UPLOADS_PATH only when its item is requested.
dataset = ReceiptDataset(data, UPLOADS_PATH, transform=transform)
print(f"Dataset size: {len(dataset)}")

## 3. Create DataLoader for Batching
Use DataLoader to efficiently batch and shuffle data for training.

In [None]:
batch_size = 8
# A seeded generator makes the shuffle order reproducible under Restart & Run All.
shuffle_generator = torch.Generator().manual_seed(42)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, generator=shuffle_generator)

# Example: peek at one batch to sanity-check shapes. The targets are the raw OCR strings.
images, targets = next(iter(dataloader))
print(f"Batch images shape: {images.shape}")
# Print only a short preview per receipt: the full OCR text is long and may contain personal data.
print(f"Batch targets: {[t[:40] for t in targets]}")

## 4. Define Neural Network Model
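Implement a simple CNN regression model for document understanding. The current model has no recurrent (RNN) layers yet: two convolutional stages feed two fully connected layers, which output a single number per image.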

In [None]:
class SimpleOCRModel(nn.Module):
    """Small CNN that regresses one scalar per image (the demo target is OCR text length).

    Input: float tensor of shape ``(batch, 3, H, W)``. Two conv/ReLU/max-pool stages produce
    32 channels. An adaptive pool then fixes the spatial grid at 64x64, so the flattened size
    always matches ``self.fc``. For 256x256 inputs (the notebook's transform) the two
    max-pools already give 64x64, and the adaptive pool leaves the values unchanged.

    Output: tensor of shape ``(batch, 1)``.
    """

    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Parameter-free, so state_dict keys are unchanged; lets input sizes other than
            # 256x256 reach self.fc with the expected 32 * 64 * 64 features.
            nn.AdaptiveAvgPool2d((64, 64)),
        )
        self.fc = nn.Linear(32 * 64 * 64, 128)
        self.out = nn.Linear(128, 1)  # For demo: regression on OCR text length

    def forward(self, x):
        x = self.cnn(x)
        x = torch.flatten(x, 1)
        # A non-linearity between the two Linear layers. Without it, fc -> out collapses into
        # a single linear map, and the 128-unit hidden layer adds no capacity.
        x = torch.relu(self.fc(x))
        return self.out(x)

# Seed torch's global RNG so weight initialisation (and any unseeded randomness after this
# point) is reproducible under Restart & Run All.
torch.manual_seed(42)
model = SimpleOCRModel()
print(model)

## 5. Set Up Loss Function and Optimizer
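Choose a loss function and optimizer for training. The demo target is a single number per receipt (its OCR text length), so we use mean-squared error with the Adam optimizer.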

In [None]:
# Plain mean-squared error: the demo target is one float per receipt.
criterion = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# For demo: target is OCR text length (regression)
def ocr_target(text):
    """Return the character count of ``text`` as a 1-element float32 tensor.

    The shape is ``(1,)``, so ``torch.stack`` over a batch gives ``(batch, 1)``, the same
    shape as the model's single-unit output layer.
    """
    return torch.full((1,), len(text), dtype=torch.float32)

## 6. Batch Training Loop
Train the model using batches of data, updating weights and tracking loss.

In [None]:
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    epoch_loss_sum = 0.0
    for images, targets in dataloader:
        # The targets come in as a batch of OCR strings; turn them into a (batch, 1) float tensor.
        batch_targets = torch.stack([ocr_target(text) for text in targets])
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, batch_targets)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch figure is a per-sample average, not a per-batch one.
        epoch_loss_sum += loss.item() * images.size(0)
    epoch_loss = epoch_loss_sum / len(dataset)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

## 7. Evaluate Model Performance
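Assess the trained model's loss. No separate validation or test set has been created yet, so this evaluation runs on the same data used for training. To get an honest estimate, split the dataset first (e.g. with `torch.utils.data.random_split`).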

In [None]:
# NOTE: there is no held-out split yet, so this measures loss on the training data itself.
# A separate, unshuffled loader keeps the evaluation pass deterministic.
eval_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
model.eval()
with torch.no_grad():
    total_loss = 0.0
    for images, targets in eval_loader:
        outputs = model(images)
        batch_targets = torch.stack([ocr_target(t) for t in targets])
        loss = criterion(outputs, batch_targets)
        # Weight by batch size so the result is a per-sample average.
        total_loss += loss.item() * images.size(0)
    avg_loss = total_loss / len(dataset)
    # Labelled as a training-set figure: the previous "Validation Loss" label was misleading.
    print(f"Training-set loss (no held-out split): {avg_loss:.4f}")