# OCR Training Walkthrough

This notebook walks through preparing a dataset, converting it to LMDB format, and training the UTRNet model on it.

## Steps
0. **Setup**: Import libraries and configure paths.
1. **Data Preparation**: Convert raw images and labels to LMDB format.
2. **Model Training**: Initialize the model and run the training loop.

In [None]:
import os
import random
import sys

import numpy as np
import torch

# Add project root to path to import modules
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

from src.create_lmdb_dataset import createDataset
from src.train import train
from models.model import Model
from src.utils import CTCLabelConverter, AttnLabelConverter

## 1. Data Preparation

The model expects data in LMDB format. We will convert the raw dataset (images + labels file) into LMDB.

We will use the `IIITH` dataset located in `data/IIITH` as an example.
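
Before converting, it can help to check the labels file for malformed entries. The sketch below assumes each line holds an image name and its transcription separated by a tab; that layout is an assumption about `labels.txt`, not something this notebook confirms. `parse_label_lines` is a hypothetical helper and is not part of the project.

```python
def parse_label_lines(lines):
    """Split label lines into (image_name, text) pairs and collect bad lines.

    Assumes one 'image_name<TAB>text' entry per line (an assumption about
    the labels file format, not a documented guarantee).
    """
    pairs, bad = [], []
    for lineno, line in enumerate(lines, start=1):
        line = line.rstrip("\n")
        if not line.strip():
            continue  # ignore blank lines
        name, sep, text = line.partition("\t")
        if not sep or not name or not text:
            bad.append(lineno)  # no tab, or an empty name/transcription
            continue
        pairs.append((name, text))
    return pairs, bad


# Toy input: two valid entries, one line without a tab, one blank line
sample = [
    "img_001.jpg\tسلام\n",
    "img_002.jpg missing tab\n",
    "\n",
    "img_003.jpg\tدنیا\n",
]
pairs, bad = parse_label_lines(sample)
print(len(pairs), "valid entries; malformed line numbers:", bad)
```

With the real file, `open(gt_file, encoding="utf-8")` could be passed as `lines`, since a file object iterates line by line.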

In [None]:
# Configuration for data conversion
input_path = os.path.join(project_root, 'data', 'IIITH', 'images')
gt_file = os.path.join(project_root, 'data', 'IIITH', 'labels.txt')
output_path = os.path.join(project_root, 'data', 'IIITH_lmdb')

# Check if input exists
if os.path.exists(input_path) and os.path.exists(gt_file):
    print(f"Converting data from {input_path} to LMDB at {output_path}...")
    createDataset(input_path, gt_file, output_path)
else:
    print("Data not found. Please ensure data/IIITH/images and data/IIITH/labels.txt exist.")

## 2. Model Training

Now that we have the LMDB dataset, we can configure and train the model.

We define a configuration class `Opt` to hold all training parameters.
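
A plain class is all the notebook needs. For comparison, a few of the same settings could be held in a dataclass, which makes it easy to derive variants of a base configuration. This is only a sketch: `DemoOpt` is a hypothetical name, it covers a small subset of the fields, and it is not the object passed to `train`.

```python
from dataclasses import dataclass, replace


@dataclass
class DemoOpt:
    """Hypothetical subset of the notebook's training options."""
    exp_name: str = "demo_training"
    imgH: int = 32
    imgW: int = 400
    batch_size: int = 8
    num_epochs: int = 1
    FeatureExtraction: str = "HRNet"
    output_channel: int = 512

    def __post_init__(self):
        # Mirrors the notebook's override for the HRNet feature extractor
        if self.FeatureExtraction == "HRNet":
            self.output_channel = 32


base = DemoOpt()
# replace() builds a new instance, changing only the named fields
longer = replace(base, exp_name="longer_run", num_epochs=5)
print(base.output_channel, longer.exp_name, longer.num_epochs)
```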

In [None]:
class Opt:
    def __init__(self):
        self.exp_name = 'demo_training'
        self.train_data = os.path.join(project_root, 'data') # Root folder containing LMDB datasets
        self.valid_data = os.path.join(project_root, 'data') # Using same root for demo
        self.select_data = 'IIITH_lmdb' # Select the specific LMDB folder
        self.batch_ratio = '1.0'
        self.total_data_usage_ratio = '1.0'
        self.batch_max_length = 100
        self.imgH = 32
        self.imgW = 400
        self.rgb = False
        self.character = '' # Will be loaded from file
        self.sensitive = False
        self.PAD = False
        self.data_filtering_off = False
        self.FeatureExtraction = 'HRNet'
        self.SequenceModeling = 'DBiLSTM'
        self.Prediction = 'CTC'
        self.num_fiducial = 20
        self.input_channel = 1
        self.output_channel = 512
        self.hidden_size = 256
        self.manualSeed = 1111
        self.workers = 0 # Set to 0 for interactive debugging
        self.batch_size = 8 # Small batch size for demo
        self.num_epochs = 1 # Just 1 epoch for demo
        self.valInterval = 10
        self.saved_model = ''
        self.FT = False
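        self.adam = True
        self.lr = 1.0 # NOTE: 1.0 is the usual Adadelta setting; if train() switches to Adam when adam is True, a much smaller rate (e.g. 1e-3) is typical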
        self.beta1 = 0.9
        self.rho = 0.95
        self.eps = 1e-8
        self.grad_clip = 5
        self.device_id = None

opt = Opt()

# Override the default output_channel (512) when the HRNet feature extractor is selected
if opt.FeatureExtraction == "HRNet":
    opt.output_channel = 32

# Load the character set: join all lines of the glyph file and append a space
with open(os.path.join(project_root, "data", "UrduGlyphs.txt"), "r", encoding="utf-8") as f:
    opt.character = ''.join(line.strip('\n') for line in f) + " "

print(f"Loaded {len(opt.character)} characters.")
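
With `Prediction = 'CTC'`, every character in `opt.character` needs an integer class, plus one extra class for the CTC blank. The sketch below illustrates that mapping together with greedy decoding on a toy alphabet. It is a simplified stand-in and not the project's `CTCLabelConverter`; reserving index 0 for the blank is an assumption, and the function names are hypothetical.

```python
def build_vocab(characters):
    """Map each character to an index starting at 1; index 0 is the blank (assumed)."""
    return {ch: i + 1 for i, ch in enumerate(characters)}


def encode(text, char_to_idx):
    """Turn a string into a list of class indices."""
    return [char_to_idx[ch] for ch in text]


def greedy_decode(indices, characters, blank=0):
    """Collapse consecutive repeats, then drop blanks."""
    out = []
    prev = None
    for idx in indices:
        if idx != prev and idx != blank:
            out.append(characters[idx - 1])
        prev = idx
    return "".join(out)


characters = "abc "  # toy alphabet; the notebook loads its own from UrduGlyphs.txt
vocab = build_vocab(characters)
print(encode("cab", vocab))

# Per-frame predictions: repeats collapse, and a blank separates genuine doubles
print(greedy_decode([3, 3, 0, 1, 1, 0, 1, 2], characters))
```

The second call prints `caab`: the blank between the two runs of `1` is what lets the decoder keep both `a` characters.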

### Run Training

We will now run the training loop. This will:
1. Load the dataset.
2. Initialize the model.
3. Train for the specified number of epochs.
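
To get a feel for how long the demo run is, the number of steps can be estimated from the settings above. The sketch assumes one optimizer step per batch and that `valInterval` counts iterations; both are assumptions about `train`, and the dataset size used here is a made-up number. `estimate_schedule` is a hypothetical helper.

```python
import math


def estimate_schedule(num_samples, batch_size, num_epochs, val_interval):
    """Estimate step counts, assuming one step per batch and
    validation every `val_interval` steps (assumptions about train())."""
    steps_per_epoch = math.ceil(num_samples / batch_size)
    total_steps = steps_per_epoch * num_epochs
    num_validations = total_steps // val_interval
    return steps_per_epoch, total_steps, num_validations


# 1000 samples is a placeholder, not the size of the IIITH dataset
steps_per_epoch, total_steps, num_validations = estimate_schedule(
    num_samples=1000, batch_size=8, num_epochs=1, val_interval=10
)
print(f"{steps_per_epoch} steps/epoch, {total_steps} steps total, "
      f"{num_validations} validation passes")
```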

In [None]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create output directory
os.makedirs(f'./saved_models/{opt.exp_name}', exist_ok=True)

# Set seeds
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(opt.manualSeed)

# Run training; report the error, then re-raise so the full traceback stays visible
try:
    train(opt, device)
except Exception as e:
    print(f"An error occurred during training: {e}")
    raise