# Bangla Handwritten OCR Model Training

This notebook trains and tests the OCR model using datasets from `banglaWrittenWordOCR-main`.

## Pipeline:
1. **Detection**: YOLOv8 for character detection
2. **Recognition**: ResNet34 for character recognition (grapheme root + vowel diacritic + consonant diacritic)
3. **Spelling Correction**: Word2Vec for post-processing


In [None]:
# Install required packages
!pip install torch torchvision pillow opencv-python ultralytics pandas numpy gensim pretrainedmodels


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image, ImageOps, ImageEnhance
import pandas as pd
import os
import numpy as np
import cv2
import json
import warnings
from ultralytics import YOLO
import sys

# Make the OCR repository and its recognition package importable
sys.path.extend([
    'banglaWrittenWordOCR-main',
    'banglaWrittenWordOCR-main/recongnition_model',
])

# Silence every Python warning for the rest of the session
warnings.filterwarnings("ignore")

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f"Using device: {device}")


## Step 1: Load Dataset and Create DataLoaders


In [None]:
# Dataset class for BanglaGrapheme dataset
class BanglaGraphemeDataset(torch.utils.data.Dataset):
    """Map-style dataset of grapheme images with three classification labels.

    Labels are read from ``<data_path>/<type>.csv`` (columns ``image_id``,
    ``grapheme_root``, ``vowel_diacritic``, ``consonant_diacritic``) and the
    images from ``<data_path>/<type>/<image_id>.jpg``.
    """

    def __init__(self, img_H, img_W, type='train', data_path='banglaWrittenWordOCR-main/recongnition_model/data/BanglaGrapheme'):
        """
        Args:
            img_H: height (pixels) every image is resized to.
            img_W: width (pixels) every image is resized to.
            type: split name; selects both the CSV file and the image folder.
                (Shadows the builtin ``type``; the name is kept so existing
                keyword callers keep working.)
            data_path: directory holding the split CSV files and image folders.
        """
        csv_path = os.path.join(data_path, f'{type}.csv')
        if not os.path.exists(csv_path):
            # Best-effort fallback: a missing split becomes an empty dataset.
            print(f"Warning: {csv_path} not found. Using empty dataset.")
            self.image_ids = []
            self.grapheme_root = []
            self.vowel_diacritic = []
            self.consonant_diacritic = []
        else:
            df = pd.read_csv(csv_path)
            df = df[['image_id', 'grapheme_root', 'vowel_diacritic', 'consonant_diacritic']]

            self.image_ids = df.image_id.values
            self.grapheme_root = df.grapheme_root.values
            self.vowel_diacritic = df.vowel_diacritic.values
            self.consonant_diacritic = df.consonant_diacritic.values

        self.width = img_W
        self.height = img_H
        self.type = type
        self.data_path = data_path

        # Built once here instead of on every __getitem__ call.
        # Maps each RGB channel from [0, 1] to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Set after the first missing image has been reported.
        self._warned_missing = False

    def __len__(self):
        """Return the number of labelled samples in this split."""
        return len(self.image_ids)

    def __getitem__(self, item):
        """Return one sample as a dict with keys ``image``, ``grapheme_root``,
        ``vowel_diacritic`` and ``consonant_diacritic``.

        ``image`` is the transformed RGB image tensor; the three labels are
        scalar ``torch.long`` tensors.
        """
        image_path = os.path.join(self.data_path, self.type, f"{self.image_ids[item]}.jpg")

        if os.path.exists(image_path):
            # Context manager so the underlying file handle is always released.
            with Image.open(image_path) as source:
                image = source.resize((self.width, self.height)).convert('RGB')
        else:
            # Keep the best-effort blank-image fallback, but make it visible:
            # a blank image still carries a real label and would otherwise
            # silently end up in training/evaluation.
            if not getattr(self, '_warned_missing', False):
                print(f"Warning: {image_path} not found. Substituting a blank image "
                      "(further missing files are not reported).")
                self._warned_missing = True
            image = Image.new('RGB', (self.width, self.height), color='white')

        image = self.transform(image)
        return {
            'image': image,
            'grapheme_root': torch.tensor(self.grapheme_root[item], dtype=torch.long),
            'vowel_diacritic': torch.tensor(self.vowel_diacritic[item], dtype=torch.long),
            'consonant_diacritic': torch.tensor(self.consonant_diacritic[item], dtype=torch.long)
        }

BATCH_SIZE = 64
IMG_H = 128
IMG_W = 224

def create_dataloaders():
    """Build the train/val/test datasets and wrap each one in a DataLoader.

    Only the training loader shuffles. Returns the tuple
    ``(train_loader, val_loader, test_loader)``.
    """
    split_names = ('train', 'val', 'test')

    # All datasets are created first, then the loaders, then the summary is printed.
    datasets = [BanglaGraphemeDataset(IMG_H, IMG_W, split) for split in split_names]
    loaders = tuple(
        DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=(split == 'train'), num_workers=2)
        for split, dataset in zip(split_names, datasets)
    )

    train_dataset, val_dataset, test_dataset = datasets
    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    return loaders

train_loader, val_loader, test_loader = create_dataloaders()


## Step 2: Load ResNet34 Model Architecture


In [None]:
# Import ResNet34 model
try:
    # Prefer the architecture shipped with the OCR repository.
    from recongnition_model.models.model import resnet34
    print("Loaded model from recongnition_model")
except Exception as repo_error:
    # `except Exception` instead of a bare `except:` so Ctrl-C / SystemExit are
    # not swallowed, and the reason for falling back is shown.
    print(f"Could not import the repository model ({repo_error!r}); building a local fallback.")
    try:
        import pretrainedmodels
        import torch.nn.functional as F
    except Exception as e:
        print(f"Error loading model: {e}")
        print("Please ensure pretrainedmodels is installed: pip install pretrainedmodels")
        # Without a model class the rest of the notebook cannot run; re-raise
        # the real error instead of failing later with a NameError on `resnet34`.
        raise

    class resnet34(nn.Module):
        """ResNet34 feature extractor with three linear classification heads."""

        def __init__(self):
            super(resnet34, self).__init__()
            # Randomly initialised backbone (no pretrained weights are loaded).
            self.model = pretrainedmodels.__dict__["resnet34"](pretrained=None)
            self.l0 = nn.Linear(512, 168)  # grapheme_root
            self.l1 = nn.Linear(512, 11)   # vowel_diacritic
            self.l2 = nn.Linear(512, 7)    # consonant_diacritic

        def forward(self, x):
            """Return the (grapheme_root, vowel_diacritic, consonant_diacritic) logits."""
            bs, _, _, _ = x.shape
            x = self.model.features(x)
            # Global average pooling, then flatten to one vector per sample.
            x = F.adaptive_avg_pool2d(x, 1).reshape(bs, -1)
            return self.l0(x), self.l1(x), self.l2(x)

    print("Created model from scratch")

# Initialize model
model = resnet34().to(device)
print(f"Model initialized on {device}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")


## Step 3: Load Grapheme Mappings


In [None]:
# Load grapheme mappings (class index as string -> text), or synthesise placeholders
mapping_file = 'bangla_ocr_pipeline/grapheme_maps.json'
if not os.path.exists(mapping_file):
    # No mapping file: fall back to placeholder labels sized like the model heads.
    print(f"Warning: {mapping_file} not found. Creating default mappings.")
    grapheme_maps = {
        "grapheme_root": dict((str(i), f"root_{i}") for i in range(168)),
        "vowel_diacritic": dict((str(i), f"vowel_{i}") for i in range(11)),
        "consonant_diacritic": dict((str(i), f"cons_{i}") for i in range(7)),
    }
else:
    with open(mapping_file, 'r', encoding='utf-8') as f:
        grapheme_maps = json.load(f)
    print("Loaded grapheme mappings")
    print(f"Grapheme roots: {len(grapheme_maps['grapheme_root'])}")
    print(f"Vowel diacritics: {len(grapheme_maps['vowel_diacritic'])}")
    print(f"Consonant diacritics: {len(grapheme_maps['consonant_diacritic'])}")


## Step 4: Training Setup


In [None]:
# Loss function
def loss_fc(outputs, targets):
    """Average the cross-entropy losses of the three classification heads.

    Args:
        outputs: three logit tensors (root, vowel, consonant heads).
        targets: three class-index tensors in the same order.
    Returns:
        Scalar tensor: mean of the three per-head cross-entropy losses.
    """
    root_logits, vowel_logits, consonant_logits = outputs
    root_labels, vowel_labels, consonant_labels = targets

    criterion = nn.CrossEntropyLoss()
    total = criterion(root_logits, root_labels)
    total = total + criterion(vowel_logits, vowel_labels)
    total = total + criterion(consonant_logits, consonant_labels)

    return total / 3

# Optimizer and scheduler
LEARNING_RATE = 1e-2
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Halve the learning rate when the value passed to scheduler.step() stops
# improving (patience=3). `verbose=True` is intentionally not passed: the
# argument is deprecated in PyTorch and not accepted by newer releases.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

# Multi-GPU support. The isinstance guard keeps this cell idempotent:
# re-running it must not wrap an already wrapped model a second time.
if torch.cuda.device_count() > 1 and not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model)
    print(f"Using {torch.cuda.device_count()} GPUs")


## Step 5: Training Loop


In [None]:
NUM_EPOCHS = 5  # Adjust as needed
train_losses = []
val_losses = []


def _batch_to_device(batch):
    """Move one DataLoader batch to `device`.

    Returns (images, (grapheme_root, vowel_diacritic, consonant_diacritic)).
    """
    images = batch["image"].to(device, dtype=torch.float)
    labels = tuple(
        batch[key].to(device, dtype=torch.long)
        for key in ("grapheme_root", "vowel_diacritic", "consonant_diacritic")
    )
    return images, labels


def _mean_or_nan(total, count):
    """Average `total` over `count` batches; NaN (not a misleading 0.0) when none ran."""
    return total / count if count else float("nan")


# A checkpoint is written after every epoch, so the directory must exist up front.
os.makedirs('models', exist_ok=True)

for epoch in range(NUM_EPOCHS):
    # Training phase
    model.train()
    running_loss = 0.0
    num_batches = 0

    for batch_idx, data in enumerate(train_loader):
        try:
            image, targets = _batch_to_device(data)

            optimizer.zero_grad()
            outputs = model(image)
            total_loss = loss_fc(outputs, targets)
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()
            num_batches += 1

            if (batch_idx + 1) % 100 == 0:
                print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Batch {batch_idx+1}, Loss: {total_loss.item():.4f}")
        except Exception as e:
            # Best-effort: report the failing batch and keep training.
            print(f"Error in batch {batch_idx}: {e}")
            continue

    avg_train_loss = _mean_or_nan(running_loss, num_batches)
    train_losses.append(avg_train_loss)

    # Validation phase
    model.eval()
    val_loss = 0.0
    val_batches = 0

    with torch.no_grad():
        for val_idx, data in enumerate(val_loader):
            try:
                image, targets = _batch_to_device(data)
                outputs = model(image)
                loss = loss_fc(outputs, targets)

                val_loss += loss.item()
                val_batches += 1
            except Exception as e:
                # Report validation failures too; skipping them silently hides
                # the fact that the reported loss covers only part of the data.
                print(f"Error in validation batch {val_idx}: {e}")
                continue

    avg_val_loss = _mean_or_nan(val_loss, val_batches)
    val_losses.append(avg_val_loss)

    if val_batches:
        scheduler.step(avg_val_loss)
    else:
        # No measurement this epoch: do not feed a fake value to the scheduler.
        print("  Warning: no validation batches were evaluated; skipping LR scheduler step.")

    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    print(f"  Train Loss: {avg_train_loss:.4f}")
    print(f"  Val Loss: {avg_val_loss:.4f}")
    print(f"  Learning rate: {optimizer.param_groups[0]['lr']:.6g}")
    print("-" * 50)

    # Save checkpoint (unwrap nn.DataParallel so the weights load into a plain model)
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.module.state_dict() if hasattr(model, 'module') else model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses
    }
    torch.save(checkpoint, f'models/ocr_model_epoch_{epoch+1}.pth')
    print(f"Saved checkpoint: models/ocr_model_epoch_{epoch+1}.pth")

print("Training completed!")


## Step 6: Save Final Model


In [None]:
# Persist the final weights; an nn.DataParallel wrapper is unwrapped first so the
# saved state dict holds the underlying module's parameter names.
os.makedirs('models', exist_ok=True)
torch.save(
    (model.module if hasattr(model, 'module') else model).state_dict(),
    'models/ocr_model_final.pth',
)
print("Model saved to models/ocr_model_final.pth")


# Bangla Handwritten OCR Model Training

This notebook trains and tests the OCR model using datasets from `banglaWrittenWordOCR-main`.

## Pipeline:
1. **Detection**: YOLOv8 for character detection
2. **Recognition**: ResNet34 for character recognition (grapheme root + vowel diacritic + consonant diacritic)
3. **Spelling Correction**: Word2Vec for post-processing


In [None]:
# Install required packages
!pip install torch torchvision pillow opencv-python ultralytics pandas numpy gensim pretrainedmodels


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image, ImageOps, ImageEnhance
import pandas as pd
import os
import numpy as np
import cv2
import json
import warnings
from ultralytics import YOLO
import sys

# Make the OCR repository and its recognition package importable
sys.path.extend([
    'banglaWrittenWordOCR-main',
    'banglaWrittenWordOCR-main/recongnition_model',
])

# Silence every Python warning for the rest of the session
warnings.filterwarnings("ignore")

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f"Using device: {device}")


## Step 1: Load Dataset and Create DataLoaders


In [None]:
# Dataset class for BanglaGrapheme dataset
class BanglaGraphemeDataset(torch.utils.data.Dataset):
    """Map-style dataset of grapheme images with three classification labels.

    Labels are read from ``<data_path>/<type>.csv`` (columns ``image_id``,
    ``grapheme_root``, ``vowel_diacritic``, ``consonant_diacritic``) and the
    images from ``<data_path>/<type>/<image_id>.jpg``.
    """

    def __init__(self, img_H, img_W, type='train', data_path='banglaWrittenWordOCR-main/recongnition_model/data/BanglaGrapheme'):
        """
        Args:
            img_H: height (pixels) every image is resized to.
            img_W: width (pixels) every image is resized to.
            type: split name; selects both the CSV file and the image folder.
                (Shadows the builtin ``type``; the name is kept so existing
                keyword callers keep working.)
            data_path: directory holding the split CSV files and image folders.
        """
        csv_path = os.path.join(data_path, f'{type}.csv')
        if not os.path.exists(csv_path):
            # Best-effort fallback: a missing split becomes an empty dataset.
            print(f"Warning: {csv_path} not found. Using empty dataset.")
            self.image_ids = []
            self.grapheme_root = []
            self.vowel_diacritic = []
            self.consonant_diacritic = []
        else:
            df = pd.read_csv(csv_path)
            df = df[['image_id', 'grapheme_root', 'vowel_diacritic', 'consonant_diacritic']]

            self.image_ids = df.image_id.values
            self.grapheme_root = df.grapheme_root.values
            self.vowel_diacritic = df.vowel_diacritic.values
            self.consonant_diacritic = df.consonant_diacritic.values

        self.width = img_W
        self.height = img_H
        self.type = type
        self.data_path = data_path

        # Built once here instead of on every __getitem__ call.
        # Maps each RGB channel from [0, 1] to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Set after the first missing image has been reported.
        self._warned_missing = False

    def __len__(self):
        """Return the number of labelled samples in this split."""
        return len(self.image_ids)

    def __getitem__(self, item):
        """Return one sample as a dict with keys ``image``, ``grapheme_root``,
        ``vowel_diacritic`` and ``consonant_diacritic``.

        ``image`` is the transformed RGB image tensor; the three labels are
        scalar ``torch.long`` tensors.
        """
        image_path = os.path.join(self.data_path, self.type, f"{self.image_ids[item]}.jpg")

        if os.path.exists(image_path):
            # Context manager so the underlying file handle is always released.
            with Image.open(image_path) as source:
                image = source.resize((self.width, self.height)).convert('RGB')
        else:
            # Keep the best-effort blank-image fallback, but make it visible:
            # a blank image still carries a real label and would otherwise
            # silently end up in training/evaluation.
            if not getattr(self, '_warned_missing', False):
                print(f"Warning: {image_path} not found. Substituting a blank image "
                      "(further missing files are not reported).")
                self._warned_missing = True
            image = Image.new('RGB', (self.width, self.height), color='white')

        image = self.transform(image)
        return {
            'image': image,
            'grapheme_root': torch.tensor(self.grapheme_root[item], dtype=torch.long),
            'vowel_diacritic': torch.tensor(self.vowel_diacritic[item], dtype=torch.long),
            'consonant_diacritic': torch.tensor(self.consonant_diacritic[item], dtype=torch.long)
        }

BATCH_SIZE = 64
IMG_H = 128
IMG_W = 224

def create_dataloaders():
    """Build the train/val/test datasets and wrap each one in a DataLoader.

    Only the training loader shuffles. Returns the tuple
    ``(train_loader, val_loader, test_loader)``.
    """
    split_names = ('train', 'val', 'test')

    # All datasets are created first, then the loaders, then the summary is printed.
    datasets = [BanglaGraphemeDataset(IMG_H, IMG_W, split) for split in split_names]
    loaders = tuple(
        DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=(split == 'train'), num_workers=2)
        for split, dataset in zip(split_names, datasets)
    )

    train_dataset, val_dataset, test_dataset = datasets
    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    return loaders

train_loader, val_loader, test_loader = create_dataloaders()


## Step 2: Load ResNet34 Model Architecture


In [None]:
# Import ResNet34 model
try:
    # Prefer the architecture shipped with the OCR repository.
    from recongnition_model.models.model import resnet34
    print("Loaded model from recongnition_model")
except Exception as repo_error:
    # `except Exception` instead of a bare `except:` so Ctrl-C / SystemExit are
    # not swallowed, and the reason for falling back is shown.
    print(f"Could not import the repository model ({repo_error!r}); building a local fallback.")
    try:
        import pretrainedmodels
        import torch.nn.functional as F
    except Exception as e:
        print(f"Error loading model: {e}")
        print("Please ensure pretrainedmodels is installed: pip install pretrainedmodels")
        # Without a model class the rest of the notebook cannot run; re-raise
        # the real error instead of failing later with a NameError on `resnet34`.
        raise

    class resnet34(nn.Module):
        """ResNet34 feature extractor with three linear classification heads."""

        def __init__(self):
            super(resnet34, self).__init__()
            # Randomly initialised backbone (no pretrained weights are loaded).
            self.model = pretrainedmodels.__dict__["resnet34"](pretrained=None)
            self.l0 = nn.Linear(512, 168)  # grapheme_root
            self.l1 = nn.Linear(512, 11)   # vowel_diacritic
            self.l2 = nn.Linear(512, 7)    # consonant_diacritic

        def forward(self, x):
            """Return the (grapheme_root, vowel_diacritic, consonant_diacritic) logits."""
            bs, _, _, _ = x.shape
            x = self.model.features(x)
            # Global average pooling, then flatten to one vector per sample.
            x = F.adaptive_avg_pool2d(x, 1).reshape(bs, -1)
            return self.l0(x), self.l1(x), self.l2(x)

    print("Created model from scratch")

# Initialize model
model = resnet34().to(device)
print(f"Model initialized on {device}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")


## Step 3: Load Grapheme Mappings


In [None]:
# Load grapheme mappings (class index as string -> text), or synthesise placeholders
mapping_file = 'bangla_ocr_pipeline/grapheme_maps.json'
if not os.path.exists(mapping_file):
    # No mapping file: fall back to placeholder labels sized like the model heads.
    print(f"Warning: {mapping_file} not found. Creating default mappings.")
    grapheme_maps = {
        "grapheme_root": dict((str(i), f"root_{i}") for i in range(168)),
        "vowel_diacritic": dict((str(i), f"vowel_{i}") for i in range(11)),
        "consonant_diacritic": dict((str(i), f"cons_{i}") for i in range(7)),
    }
else:
    with open(mapping_file, 'r', encoding='utf-8') as f:
        grapheme_maps = json.load(f)
    print("Loaded grapheme mappings")
    print(f"Grapheme roots: {len(grapheme_maps['grapheme_root'])}")
    print(f"Vowel diacritics: {len(grapheme_maps['vowel_diacritic'])}")
    print(f"Consonant diacritics: {len(grapheme_maps['consonant_diacritic'])}")


## Step 4: Training Setup


In [None]:
# Loss function
def loss_fc(outputs, targets):
    """Average the cross-entropy losses of the three classification heads.

    Args:
        outputs: three logit tensors (root, vowel, consonant heads).
        targets: three class-index tensors in the same order.
    Returns:
        Scalar tensor: mean of the three per-head cross-entropy losses.
    """
    root_logits, vowel_logits, consonant_logits = outputs
    root_labels, vowel_labels, consonant_labels = targets

    criterion = nn.CrossEntropyLoss()
    total = criterion(root_logits, root_labels)
    total = total + criterion(vowel_logits, vowel_labels)
    total = total + criterion(consonant_logits, consonant_labels)

    return total / 3

# Optimizer and scheduler
LEARNING_RATE = 1e-2
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Halve the learning rate when the value passed to scheduler.step() stops
# improving (patience=3). `verbose=True` is intentionally not passed: the
# argument is deprecated in PyTorch and not accepted by newer releases.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

# Multi-GPU support. The isinstance guard keeps this cell idempotent:
# re-running it must not wrap an already wrapped model a second time.
if torch.cuda.device_count() > 1 and not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model)
    print(f"Using {torch.cuda.device_count()} GPUs")


## Step 5: Training Loop


In [None]:
NUM_EPOCHS = 5  # Adjust as needed
train_losses = []
val_losses = []


def _batch_to_device(batch):
    """Move one DataLoader batch to `device`.

    Returns (images, (grapheme_root, vowel_diacritic, consonant_diacritic)).
    """
    images = batch["image"].to(device, dtype=torch.float)
    labels = tuple(
        batch[key].to(device, dtype=torch.long)
        for key in ("grapheme_root", "vowel_diacritic", "consonant_diacritic")
    )
    return images, labels


def _mean_or_nan(total, count):
    """Average `total` over `count` batches; NaN (not a misleading 0.0) when none ran."""
    return total / count if count else float("nan")


# A checkpoint is written after every epoch, so the directory must exist up
# front; without this torch.save fails on a fresh run.
os.makedirs('models', exist_ok=True)

for epoch in range(NUM_EPOCHS):
    # Training phase
    model.train()
    running_loss = 0.0
    num_batches = 0

    for batch_idx, data in enumerate(train_loader):
        try:
            image, targets = _batch_to_device(data)

            optimizer.zero_grad()
            outputs = model(image)
            total_loss = loss_fc(outputs, targets)
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()
            num_batches += 1

            if (batch_idx + 1) % 100 == 0:
                print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Batch {batch_idx+1}, Loss: {total_loss.item():.4f}")
        except Exception as e:
            # Best-effort: report the failing batch and keep training.
            print(f"Error in batch {batch_idx}: {e}")
            continue

    avg_train_loss = _mean_or_nan(running_loss, num_batches)
    train_losses.append(avg_train_loss)

    # Validation phase
    model.eval()
    val_loss = 0.0
    val_batches = 0

    with torch.no_grad():
        for val_idx, data in enumerate(val_loader):
            try:
                image, targets = _batch_to_device(data)
                outputs = model(image)
                loss = loss_fc(outputs, targets)

                val_loss += loss.item()
                val_batches += 1
            except Exception as e:
                # Report validation failures too; skipping them silently hides
                # the fact that the reported loss covers only part of the data.
                print(f"Error in validation batch {val_idx}: {e}")
                continue

    avg_val_loss = _mean_or_nan(val_loss, val_batches)
    val_losses.append(avg_val_loss)

    if val_batches:
        scheduler.step(avg_val_loss)
    else:
        # No measurement this epoch: do not feed a fake value to the scheduler.
        print("  Warning: no validation batches were evaluated; skipping LR scheduler step.")

    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    print(f"  Train Loss: {avg_train_loss:.4f}")
    print(f"  Val Loss: {avg_val_loss:.4f}")
    print(f"  Learning rate: {optimizer.param_groups[0]['lr']:.6g}")
    print("-" * 50)

    # Save checkpoint (unwrap nn.DataParallel so the weights load into a plain model)
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.module.state_dict() if hasattr(model, 'module') else model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses
    }
    torch.save(checkpoint, f'models/ocr_model_epoch_{epoch+1}.pth')
    print(f"Saved checkpoint: models/ocr_model_epoch_{epoch+1}.pth")

print("Training completed!")


## Step 6: Load Trained Model for Inference


In [None]:
# Load a saved checkpoint (defaults to the epoch-5 checkpoint; no best-model selection is performed)
def load_model(checkpoint_path='models/ocr_model_epoch_5.pth'):
    """Create a ResNet34 on `device` and load weights from `checkpoint_path`.

    Accepts either a training checkpoint (dict with a ``model_state_dict``
    entry) or a bare state dict. When the file does not exist the freshly
    initialised, untrained model is returned. The model is always returned
    in eval mode.
    """
    net = resnet34().to(device)

    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint not found: {checkpoint_path}")
        print("Using untrained model")
    else:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        payload = torch.load(checkpoint_path, map_location=device)
        is_training_checkpoint = 'model_state_dict' in payload
        net.load_state_dict(payload['model_state_dict'] if is_training_checkpoint else payload)
        if is_training_checkpoint:
            print(f"Loaded model from {checkpoint_path}")
        else:
            print(f"Loaded model weights from {checkpoint_path}")

    net.eval()
    return net

# Uncomment to load a specific checkpoint
# model = load_model('models/ocr_model_epoch_5.pth')


## Step 7: Character Recognition Function


In [None]:
def recognize_character(patch, model, grapheme_maps, input_size=(224, 128)):
    """
    Recognize a single character patch.

    Args:
        patch: numpy array or PIL Image of the character
        model: trained ResNet34 model; called with a (1, 3, H, W) tensor and
            expected to return three logit tensors (root, vowel, consonant)
        grapheme_maps: dictionary with the keys 'grapheme_root',
            'vowel_diacritic' and 'consonant_diacritic', each mapping
            str(class index) to its text
        input_size: (width, height) the patch is resized to before inference;
            the default matches the previously hard-coded 224x128
    Returns:
        recognized_char: root + vowel + consonant text concatenated
        confidence: mean of the three top-1 softmax probabilities
        details: dict with keys 'root', 'vowel', 'consonant', each a
            (text, probability) tuple
    """
    # Preprocess image
    img = Image.fromarray(patch) if isinstance(patch, np.ndarray) else patch
    # Convert to RGB up front: ImageOps.invert does not accept every mode
    # (e.g. RGBA patches raise). For L/RGB input the result is unchanged.
    img = img.convert('RGB')

    # NOTE(review): BanglaGraphemeDataset in this notebook applies neither the
    # contrast boost nor the inversion during training - confirm that inference
    # patches really need them to match the training image format.
    img = ImageEnhance.Contrast(img).enhance(2.0)
    img = ImageOps.invert(img)
    img = img.resize(tuple(input_size))

    # Same tensor conversion/normalisation as the training dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    img_tensor = transform(img).unsqueeze(0).to(device)

    # Inference
    model.eval()
    with torch.no_grad():
        outputs = model(img_tensor)

    # (details key, logits, mapping key, text used when the index has no mapping)
    heads = (
        ('root', outputs[0], 'grapheme_root', '?'),
        ('vowel', outputs[1], 'vowel_diacritic', ''),
        ('consonant', outputs[2], 'consonant_diacritic', ''),
    )
    details = {}
    for name, logits, map_key, fallback in heads:
        class_idx = torch.argmax(logits, dim=1).item()
        class_conf = torch.softmax(logits, dim=1)[0][class_idx].item()
        details[name] = (grapheme_maps[map_key].get(str(class_idx), fallback), class_conf)

    recognized_char = details['root'][0] + details['vowel'][0] + details['consonant'][0]
    confidence = (details['root'][1] + details['vowel'][1] + details['consonant'][1]) / 3

    return recognized_char, confidence, details


## Step 8: Test on Sample Image


In [None]:
# Test on a sample image from the dataset
def test_sample_image(image_path=None):
    """Run `recognize_character` on one image and print the prediction.

    Args:
        image_path: path of an image file; when None the first sample of the
            'test' split of BanglaGraphemeDataset is used.
    Returns:
        (char, conf, details) as returned by `recognize_character`, or None
        when no image path is given and the test split is empty.
    """
    if image_path is None:
        # Use first image from test set
        test_dataset = BanglaGraphemeDataset(IMG_H, IMG_W, 'test')
        if len(test_dataset) == 0:
            print("No test images available")
            return None

        image = test_dataset[0]['image']
        # Undo Normalize(0.5, 0.5): [-1, 1] -> [0, 255]. Round and clip before
        # the cast; a plain astype(uint8) truncates (e.g. 126.99999 -> 126).
        img_array = (image.permute(1, 2, 0).numpy() + 1) / 2 * 255
        img_array = np.clip(np.rint(img_array), 0, 255).astype(np.uint8)
        img_pil = Image.fromarray(img_array)
    else:
        # Context manager so the file handle is released after decoding.
        with Image.open(image_path) as source:
            img_pil = source.convert('RGB').resize((IMG_W, IMG_H))

    # Recognize
    char, conf, details = recognize_character(img_pil, model, grapheme_maps)

    print(f"Recognized Character: {char}")
    print(f"Confidence: {conf:.2%}")
    print(f"Details:")
    print(f"  Root: {details['root'][0]} ({details['root'][1]:.2%})")
    print(f"  Vowel: {details['vowel'][0]} ({details['vowel'][1]:.2%})")
    print(f"  Consonant: {details['consonant'][0]} ({details['consonant'][1]:.2%})")

    return char, conf, details

# Uncomment to test
# test_sample_image()


## Step 9: Save Model for Web Application


In [None]:
# Save the final model
def save_final_model(model, save_path='models/ocr_model_final.pth'):
    """Save the model's state dict to `save_path`.

    Args:
        model: the network to save; an nn.DataParallel wrapper (anything with
            a `.module` attribute) is unwrapped first so the saved keys carry
            no `module.` prefix.
        save_path: destination file; its parent directory is created if needed.
    """
    # Create the directory that save_path actually points into (previously
    # 'models' was always created, even for paths in other directories).
    target_dir = os.path.dirname(save_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    network = model.module if hasattr(model, 'module') else model
    torch.save(network.state_dict(), save_path)
    print(f"Model saved to {save_path}")

# Uncomment to save
# save_final_model(model, 'models/ocr_model_final.pth')
