# Training CNN Model for Text Validation

This notebook trains a CNN model (`BrailleValidator`) that is used to validate the characters recognized by EasyOCR.

In [None]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from cnn_validator import BrailleValidator

## 1. Load and Prepare Dataset

In [None]:
# Define paths
train_dir = 'dataset/train'
test_dir = 'dataset/test'

# Fraction of each class's training images held out for validation, and the
# seed that makes the split reproducible across runs.
VAL_SPLIT = 0.2
SPLIT_SEED = 42

# Stop here with a clear error instead of failing later inside os.listdir().
missing_dirs = [d for d in (train_dir, test_dir) if not os.path.isdir(d)]
if missing_dirs:
    raise FileNotFoundError(f"Error: Dataset directories not found: {missing_dirs}")
print("Dataset found. Proceeding with training...")

# Each sub-directory of train_dir is one class; its integer label is its index
# in this sorted list.
# NOTE(review): the model is trained on these integer labels, so this order
# must match the label mapping BrailleValidator uses at prediction time — confirm.
classes = sorted(
    name for name in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, name))
)

# Count samples per class and build the train/validation split used by the
# data generators in section 2.
class_counts = {}
train_images, train_labels = [], []
val_images, val_labels = [], []
rng = np.random.default_rng(SPLIT_SEED)

for label, cls in enumerate(classes):
    class_path = os.path.join(train_dir, cls)
    files = sorted(
        os.path.join(class_path, name)
        for name in os.listdir(class_path)
        if os.path.isfile(os.path.join(class_path, name))
    )
    class_counts[cls] = len(files)

    # Stratified split: shuffle this class's files and hold out VAL_SPLIT of
    # them, so every class is represented in both subsets.
    rng.shuffle(files)
    n_val = int(len(files) * VAL_SPLIT)
    val_images.extend(files[:n_val])
    val_labels.extend([label] * n_val)
    train_images.extend(files[n_val:])
    train_labels.extend([label] * (len(files) - n_val))

print(f"Found {len(class_counts)} classes with {sum(class_counts.values())} total training samples")
print(f"Split into {len(train_images)} training and {len(val_images)} validation images")

# Display class distribution
plt.figure(figsize=(15, 5))
plt.bar(class_counts.keys(), class_counts.values())
plt.title('Number of Training Samples per Class')
plt.xlabel('Class')
plt.ylabel('Number of Samples')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

## 2. Create and Train the Model

In [None]:
# Create a custom data generator for the split data
class CustomDataGenerator:
    """Batch loader for grayscale images given as a list of file paths.

    Each batch is a ``(batch_x, batch_y)`` pair. ``batch_x`` has shape
    ``(n, height, width, 1)`` with pixel values scaled to [0, 1], and
    ``batch_y`` holds the labels of the same samples. The sample order is
    shuffled at construction and again every time ``on_epoch_end`` runs.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels aligned index-for-index with ``image_paths``.
        batch_size: Maximum samples per batch; the last batch may be smaller.
        target_size: ``(height, width)`` that every image is resized to.
        augment: If True, apply random rotation, shift and zoom to each image.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, batch_size=32, target_size=(64, 64), augment=False):
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and labels ({len(labels)}) must have the same length"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.batch_size = batch_size
        self.target_size = target_size
        self.augment = augment
        self.n_samples = len(image_paths)
        self.indices = np.arange(self.n_samples)
        self.current_idx = 0  # position of the next batch served by next()
        np.random.shuffle(self.indices)

    def __len__(self):
        """Return the number of batches per epoch."""
        return int(np.ceil(self.n_samples / self.batch_size))

    def on_epoch_end(self):
        """Reshuffle the sample order and restart from the first batch."""
        np.random.shuffle(self.indices)
        self.current_idx = 0

    def _load_image(self, img_path):
        """Read ``img_path`` as grayscale, resized to target_size; zeros if unreadable."""
        height, width = self.target_size
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            print(f"Warning: Could not read {img_path}")
            return np.zeros((height, width), dtype=np.uint8)
        # OpenCV sizes are (width, height), the reverse of numpy's (rows, cols).
        return cv2.resize(img, (width, height))

    def _augment(self, img):
        """Randomly rotate (±10°), shift (±3 px) and zoom (0.9–1.1x), each with p=0.5."""
        height, width = self.target_size
        dsize = (width, height)
        center = (width // 2, height // 2)

        # Random rotation
        if np.random.rand() > 0.5:
            angle = np.random.uniform(-10, 10)
            matrix = cv2.getRotationMatrix2D(center, angle, 1)
            img = cv2.warpAffine(img, matrix, dsize)

        # Random shift
        if np.random.rand() > 0.5:
            tx = np.random.uniform(-3, 3)
            ty = np.random.uniform(-3, 3)
            matrix = np.float32([[1, 0, tx], [0, 1, ty]])
            img = cv2.warpAffine(img, matrix, dsize)

        # Random zoom
        if np.random.rand() > 0.5:
            scale = np.random.uniform(0.9, 1.1)
            matrix = cv2.getRotationMatrix2D(center, 0, scale)
            img = cv2.warpAffine(img, matrix, dsize)

        return img

    def __getitem__(self, idx):
        """Return batch number ``idx`` as ``(batch_x, batch_y)``."""
        height, width = self.target_size
        batch_indices = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_x = np.zeros((len(batch_indices), height, width, 1), dtype=np.float32)
        batch_y = np.zeros(len(batch_indices))

        for i, sample_idx in enumerate(batch_indices):
            img = self._load_image(self.image_paths[sample_idx])
            if self.augment:
                img = self._augment(img)
            # Normalize to [0, 1]; the trailing axis is the single channel.
            batch_x[i, ..., 0] = img.astype(np.float32) / 255.0
            batch_y[i] = self.labels[sample_idx]

        return batch_x, batch_y

    def next(self):
        """Return the next batch, reshuffling and wrapping around after the last one."""
        if self.current_idx >= len(self):
            self.on_epoch_end()

        batch = self[self.current_idx]
        self.current_idx += 1
        return batch

# Set training parameters
epochs = 20
batch_size = 32

# Create data generators (augmentation only for the training split)
train_generator = CustomDataGenerator(train_images, train_labels, batch_size=batch_size, augment=True)
val_generator = CustomDataGenerator(val_images, val_labels, batch_size=batch_size, augment=False)

# Create validator instance and build its model
validator = BrailleValidator()
model = validator.build_model()
# NOTE(review): model.fit below assumes build_model() returns a compiled model
# that reports an 'accuracy' metric (monitored by both callbacks) — confirm in
# cnn_validator.

# Create callbacks
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    'braille_cnn_model.h5',
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
    verbose=1
)

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy',
    patience=5,
    restore_best_weights=True,
    mode='max',
    verbose=1
)


def _endless_batches(generator):
    """Yield ``generator.next()`` forever.

    model.fit accepts a Python generator but not a bound method, so the
    custom generator's ``next`` is wrapped here. Epoch length is bounded by
    steps_per_epoch / validation_steps, which equal len(generator), so each
    epoch consumes exactly one pass before next() reshuffles.
    """
    while True:
        yield generator.next()


# Train the model
history = model.fit(
    _endless_batches(train_generator),
    steps_per_epoch=len(train_generator),
    epochs=epochs,
    validation_data=_endless_batches(val_generator),
    validation_steps=len(val_generator),
    callbacks=[checkpoint, early_stopping]
)

# Save the model
# NOTE(review): this overwrites the best-val_accuracy file written by
# ModelCheckpoint with the in-memory weights. They match only if EarlyStopping
# restored the best weights (some Keras versions restore only when stopping
# actually triggers) — consider reloading the checkpoint instead.
validator.model = model
validator.save_model('braille_cnn_model.h5')

## 3. Evaluate Training Results

In [None]:
# Plot training vs. validation curves, accuracy on the left and loss on the right.
# Each row: (panel, train key, val key, train label, val label, title, y label)
_history_panels = (
    (1, 'accuracy', 'val_accuracy', 'Training Accuracy', 'Validation Accuracy', 'Model Accuracy', 'Accuracy'),
    (2, 'loss', 'val_loss', 'Training Loss', 'Validation Loss', 'Model Loss', 'Loss'),
)

plt.figure(figsize=(12, 4))

for _pos, _train_key, _val_key, _train_lbl, _val_lbl, _title, _ylabel in _history_panels:
    plt.subplot(1, 2, _pos)
    plt.plot(history.history[_train_key], label=_train_lbl)
    plt.plot(history.history[_val_key], label=_val_lbl)
    plt.title(_title)
    plt.xlabel('Epoch')
    plt.ylabel(_ylabel)
    plt.legend()

plt.tight_layout()
plt.show()

## 4. Save the Trained Model

In [None]:
# Persist the trained model. The training cell already writes this same file;
# saving again here lets this cell be re-run on its own.
model_path = 'braille_cnn_model.h5'
validator.save_model(model_path)
print("Model saved to " + model_path)

## 5. Test the Model with Sample Images

In [None]:
import cv2
from braille_utils import magic_filter_bw

def test_with_sample(image_path):
    """Show an image next to its filtered version and print the model's guess.

    This is a smoke test, not a real evaluation. No character segmentation
    is done: the whole filtered image is shrunk to 64x64 and handed to
    ``validator.predict`` as a single character.

    Args:
        image_path: Path of the image file to load with ``cv2.imread``.

    Returns:
        None. Results are displayed with matplotlib and printed to stdout;
        an unreadable file only prints an error message.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        print(f"Error: Could not load image {image_path}")
        return

    # Collapse a 3-channel image to one channel; 2-D input is used unchanged.
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) if len(bgr.shape) == 3 else bgr
    filtered = magic_filter_bw(gray)

    def show_panel(position, picture, title, **imshow_kwargs):
        # One axis-free subplot in a 1x2 grid.
        plt.subplot(1, 2, position)
        plt.imshow(picture, **imshow_kwargs)
        plt.title(title)
        plt.axis('off')

    # Left: the input as loaded (BGR -> RGB for matplotlib); right: filter output.
    plt.figure(figsize=(10, 5))
    show_panel(1, cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), 'Original Image')
    show_panel(2, filtered, 'Processed Image', cmap='gray')
    plt.tight_layout()
    plt.show()

    # NOTE(review): assumes validator.predict accepts a raw 64x64 filtered
    # array and returns (label, confidence) — confirm in cnn_validator.
    predicted_char, score = validator.predict(cv2.resize(filtered, (64, 64)))
    print(f"Predicted character: {predicted_char}")
    print(f"Confidence: {score:.2f}")

In [None]:
# Run test_with_sample on a few images from the test set.
# Replace these with paths to your own test images as needed.
sample_images = [
    'dataset/test/A/sample1.jpg',
    'dataset/test/B/sample1.jpg',
    'dataset/test/C/sample1.jpg'
]

for img_path in sample_images:
    # Report missing files and move on rather than stopping the loop.
    if not os.path.exists(img_path):
        print(f"Image not found: {img_path}")
        continue
    print(f"Testing with {img_path}")
    test_with_sample(img_path)
    print()