# Roman Numeral Recognition - Data Cleaning & Augmentation
This notebook helps you improve the dataset quality to achieve >90% accuracy.

## Setup Instructions for Google Colab:
1. Upload this notebook to Google Colab
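2. Upload your `dataset.zip` (it must extract to a `dataset/` folder containing `train/` and `val/` subfolders, each with one subfolder per class)
3. Uncomment the upload/unzip lines in the third code cell, then run all cells sequentially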
4. Download the cleaned dataset and train locally

In [None]:
# Install required packages
!pip install pillow matplotlib numpy scikit-learn opencv-python -q

In [None]:
import os
import shutil
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageEnhance, ImageFilter
from collections import defaultdict
import cv2
from sklearn.cluster import DBSCAN
import random

# Seed Python's and NumPy's global RNGs so the random choices made below
# (sampling, outlier removal, augmentation) repeat on a fresh kernel run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

In [None]:
# If running on Colab, upload and extract your dataset
# Uncomment the following lines if you're on Colab:

# from google.colab import files
# uploaded = files.upload()  # Upload your dataset.zip
# !unzip -q dataset.zip

## 1. Explore Dataset Distribution

In [None]:
def count_images(base_path):
    """Return a mapping of class folder name -> number of entries in it.

    Only non-hidden sub-directories of ``base_path`` are treated as classes.
    Inside a class folder every entry whose name does not start with '.' is
    counted. Keys are inserted in sorted class-name order.
    """
    def _visible(names):
        # Hidden entries such as .DS_Store are never counted
        return [name for name in names if not name.startswith('.')]

    class_dirs = [
        entry for entry in sorted(os.listdir(base_path))
        if os.path.isdir(os.path.join(base_path, entry)) and not entry.startswith('.')
    ]
    return {
        entry: len(_visible(os.listdir(os.path.join(base_path, entry))))
        for entry in class_dirs
    }

# Count the entries per class in each split
train_counts = count_images('dataset/train')
val_counts = count_images('dataset/val')

# Text summary: one line per class followed by the split total
for heading, split_counts in (
    ("Training set distribution:", train_counts),
    ("\nValidation set distribution:", val_counts),
):
    print(heading)
    for class_name, count in split_counts.items():
        print(f"  {class_name}: {count}")
    print(f"  Total: {sum(split_counts.values())}")

# Visualize distribution: one bar chart per split, side by side
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
for ax, (title, split_counts) in zip(axes, (
    ('Training Set Distribution', train_counts),
    ('Validation Set Distribution', val_counts),
)):
    ax.bar(split_counts.keys(), split_counts.values())
    ax.set_title(title)
    ax.set_xlabel('Class')
    ax.set_ylabel('Number of Images')
    ax.tick_params(axis='x', rotation=45)
plt.tight_layout()
plt.show()

## 2. Visualize Sample Images from Each Class

In [None]:
def visualize_samples(base_path, samples_per_class=10):
    """Show a grid of randomly chosen images, one row per class.

    Args:
        base_path: Directory whose non-hidden sub-directories are the classes.
        samples_per_class: Number of grid columns; at most this many images
            are drawn (without replacement) from each class.

    Returns:
        None. The figure is displayed with ``plt.show()``. When ``base_path``
        contains no class directories a message is printed and nothing is drawn.
    """
    classes = sorted(
        d for d in os.listdir(base_path)
        if os.path.isdir(os.path.join(base_path, d)) and not d.startswith('.')
    )
    if not classes:
        # plt.subplots cannot build a grid with zero rows
        print(f"No class folders found in {base_path}")
        return

    # squeeze=False keeps `axes` two-dimensional even with a single class or a
    # single column, so axes[i, j] indexing below is always valid.
    fig, axes = plt.subplots(
        len(classes), samples_per_class,
        figsize=(20, 2 * len(classes)), squeeze=False,
    )

    # Blank every cell up front: classes with fewer images than columns would
    # otherwise leave empty axes with ticks and frames in the grid.
    for ax in axes.ravel():
        ax.axis('off')

    for i, class_name in enumerate(classes):
        class_path = os.path.join(base_path, class_name)
        images = [f for f in os.listdir(class_path) if not f.startswith('.')]
        samples = random.sample(images, min(samples_per_class, len(images)))

        # Label the row on its first cell (also when the class has no images)
        axes[i, 0].set_title(f'{class_name}', fontsize=12, fontweight='bold')

        for j, img_name in enumerate(samples):
            # Close each file as soon as it has been handed to matplotlib
            with Image.open(os.path.join(class_path, img_name)) as img:
                axes[i, j].imshow(img, cmap='gray')

    plt.tight_layout()
    plt.show()

# Preview ten random images from every training class
print("Training set samples:")
visualize_samples(base_path='dataset/train', samples_per_class=10)

## 3. Identify Outliers and Problematic Images
We'll use per-image brightness statistics (mean, standard deviation, min, max) and a per-class 1.5×IQR rule to flag potential outliers for manual review.

In [None]:
def get_image_features(img_path):
    """Extract simple brightness statistics from an image for outlier detection.

    Args:
        img_path: Path of the image file to read.

    Returns:
        A list ``[mean, std, min, max]`` computed over the pixel values of the
        image after conversion to 8-bit grayscale (mode 'L').
    """
    # The context manager releases the file handle as soon as the pixel data
    # has been copied into the array, instead of leaving it to the garbage collector.
    with Image.open(img_path) as img:
        img_array = np.array(img.convert('L'))  # Convert to grayscale

    return [
        np.mean(img_array),  # Mean brightness
        np.std(img_array),   # Standard deviation
        np.min(img_array),   # Min value
        np.max(img_array),   # Max value
    ]

def find_outliers(base_path, output_dir='outliers'):
    """Flag statistically unusual images in every class folder of ``base_path``.

    For each non-hidden class directory the features returned by
    ``get_image_features`` are computed for every non-hidden entry. An image
    is flagged when any single feature lies outside
    ``[Q1 - 1.5*IQR, Q3 + 1.5*IQR]``, with quartiles taken per feature over
    that class only. Up to 20 flagged images per class are copied to
    ``output_dir/<class_name>`` for manual review.

    Args:
        base_path: Directory whose sub-directories are the classes.
        output_dir: Directory receiving the review copies (created if missing).

    Returns:
        dict mapping class name -> list of flagged image paths. A class with
        no readable images maps to an empty list.
    """
    os.makedirs(output_dir, exist_ok=True)
    outliers_info = {}

    for class_name in sorted(os.listdir(base_path)):
        class_path = os.path.join(base_path, class_name)
        if not os.path.isdir(class_path) or class_name.startswith('.'):
            continue

        print(f"\nAnalyzing class: {class_name}")
        images = [f for f in os.listdir(class_path) if not f.startswith('.')]

        # Extract features; entries that cannot be read are reported and skipped
        feature_rows = []
        image_paths = []
        for img_name in images:
            img_path = os.path.join(class_path, img_name)
            try:
                feat = get_image_features(img_path)
                feature_rows.append(feat)
                image_paths.append(img_path)
            except Exception as e:
                print(f"  Error processing {img_name}: {e}")

        # Without at least one feature row the array below would be 1-D and the
        # axis=1 reduction would raise, so record an empty result and move on.
        if not feature_rows:
            print("  No readable images found - skipping")
            outliers_info[class_name] = []
            continue

        features = np.array(feature_rows)

        # Find outliers using statistical method (IQR), one fence per feature column
        Q1 = np.percentile(features, 25, axis=0)
        Q3 = np.percentile(features, 75, axis=0)
        IQR = Q3 - Q1

        # NOTE: when a feature has IQR == 0 (e.g. max is the same for most
        # images), every image whose value differs from that quartile is flagged.
        outlier_mask = np.any((features < (Q1 - 1.5 * IQR)) | (features > (Q3 + 1.5 * IQR)), axis=1)
        outlier_indices = np.where(outlier_mask)[0]

        print(f"  Found {len(outlier_indices)} potential outliers out of {len(images)} images")
        outliers_info[class_name] = [image_paths[i] for i in outlier_indices]

        # Save outliers for manual review
        class_outlier_dir = os.path.join(output_dir, class_name)
        os.makedirs(class_outlier_dir, exist_ok=True)
        for idx in outlier_indices[:20]:  # Save up to 20 outliers per class
            shutil.copy(image_paths[idx], class_outlier_dir)

    return outliers_info

# Scan the training split; review copies of flagged images go to outliers_detected/
outliers = find_outliers(base_path='dataset/train', output_dir='outliers_detected')

In [None]:
# Visualize detected outliers
def visualize_outliers(outliers_dict, max_per_class=10):
    """Display one row of flagged images for each class that has any.

    ``outliers_dict`` maps class name -> list of image paths. Classes with an
    empty list are skipped. At most ``max_per_class`` images are shown per
    class, each titled with its file name.
    """
    for class_name, outlier_paths in outliers_dict.items():
        total = len(outlier_paths)
        if total == 0:
            continue

        print(f"\nClass: {class_name} ({total} outliers)")
        n_show = min(max_per_class, total)

        # squeeze=False always yields a 2-D grid; its single row is iterated below
        fig, axes = plt.subplots(1, n_show, figsize=(20, 3), squeeze=False)

        for ax, path in zip(axes[0], outlier_paths):
            picture = Image.open(path)
            ax.imshow(picture, cmap='gray')
            ax.axis('off')
            ax.set_title(os.path.basename(path), fontsize=8)

        plt.tight_layout()
        plt.show()

# Show the flagged images so they can be judged by eye before cleaning
visualize_outliers(outliers_dict=outliers)

## 4. Clean Dataset - Remove Outliers

In [None]:
# Create cleaned dataset
def create_cleaned_dataset(base_path, outliers_dict, output_path='dataset_cleaned/train', removal_rate=0.3):
    """Copy the dataset to ``output_path``, dropping a random share of flagged images.

    Every non-hidden image of every non-hidden class folder in ``base_path``
    is copied to ``output_path/<class_name>``, except that each image listed
    in ``outliers_dict`` is skipped with probability ``removal_rate``. The
    selection is random (global ``random`` module); it does not rank outliers
    by how extreme they are.

    Files already present in ``output_path`` are not removed.

    Args:
        base_path: Source directory with one sub-directory per class.
        outliers_dict: Mapping class name -> iterable of flagged image paths,
            spelled as ``os.path.join(base_path, class_name, file_name)``.
        output_path: Destination directory (created if missing).
        removal_rate: Probability in [0, 1] of dropping each flagged image.

    Returns:
        ``output_path``.
    """
    os.makedirs(output_path, exist_ok=True)

    removed_count = 0
    kept_count = 0

    for class_name in sorted(os.listdir(base_path)):
        class_path = os.path.join(base_path, class_name)
        if not os.path.isdir(class_path) or class_name.startswith('.'):
            continue

        output_class_path = os.path.join(output_path, class_name)
        os.makedirs(output_class_path, exist_ok=True)

        outlier_set = set(outliers_dict.get(class_name, []))

        # os.listdir order is arbitrary; sorting makes the seeded random draws
        # hit the same files on every run and every filesystem.
        for img_name in sorted(os.listdir(class_path)):
            if img_name.startswith('.'):
                continue
            img_path = os.path.join(class_path, img_name)

            # A random number is drawn only for flagged images
            if img_path in outlier_set and random.random() < removal_rate:
                removed_count += 1
                continue

            shutil.copy(img_path, output_class_path)
            kept_count += 1

    print(f"\nCleaning complete:")
    print(f"  Kept: {kept_count} images")
    print(f"  Removed: {removed_count} images")

    return output_path

# Build the cleaned training split from the raw one and the flagged paths
cleaned_train_path = create_cleaned_dataset(
    base_path='dataset/train',
    outliers_dict=outliers,
    output_path='dataset_cleaned/train',
)

## 5. Data Augmentation
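Apply random augmentations (small rotation, brightness, contrast, blur, shift) to top up every class that has fewer images than the target count. Classes already at or above the target are copied unchanged.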

In [None]:
def augment_image(img, augmentation_type):
    """Return a randomly perturbed copy of ``img`` for the requested augmentation.

    Supported ``augmentation_type`` values: 'rotate_small', 'brightness',
    'contrast', 'blur', 'sharpen' and 'shift'. Any other value returns ``img``
    itself, unchanged. Random parameters come from the global ``random`` module.
    """
    if augmentation_type == 'rotate_small':
        # Whole-degree rotation in [-15, 15]; uncovered corners are filled with 255
        degrees = random.randint(-15, 15)
        return img.rotate(degrees, fillcolor=255)

    if augmentation_type == 'brightness':
        return ImageEnhance.Brightness(img).enhance(random.uniform(0.7, 1.3))

    if augmentation_type == 'contrast':
        return ImageEnhance.Contrast(img).enhance(random.uniform(0.8, 1.2))

    if augmentation_type == 'blur':
        blur_radius = random.uniform(0.5, 1.5)
        return img.filter(ImageFilter.GaussianBlur(radius=blur_radius))

    if augmentation_type == 'sharpen':
        return img.filter(ImageFilter.SHARPEN)

    if augmentation_type == 'shift':
        # Small random shift: offsets of up to 3 pixels on each axis
        dx = random.randint(-3, 3)
        dy = random.randint(-3, 3)
        affine_coeffs = (1, 0, dx, 0, 1, dy)
        return img.transform(img.size, Image.AFFINE, affine_coeffs, fillcolor=255)

    # Unknown type: hand the input back untouched
    return img

def balance_and_augment(cleaned_path, output_path='dataset_augmented/train', target_per_class=300):
    """Copy every class to ``output_path`` and top up small classes with augmented images.

    All non-hidden images of each non-hidden class folder are copied as-is.
    If a class has fewer than ``target_per_class`` images, randomly chosen
    originals are passed through ``augment_image`` with a randomly chosen
    augmentation type and saved until the target is reached. Classes that
    already meet or exceed the target are not reduced.

    Args:
        cleaned_path: Source directory with one sub-directory per class.
        output_path: Destination directory (created if missing).
        target_per_class: Minimum image count wanted for every class.

    Returns:
        ``output_path``.
    """
    os.makedirs(output_path, exist_ok=True)

    augmentation_types = ['rotate_small', 'brightness', 'contrast', 'blur', 'shift']

    for class_name in sorted(os.listdir(cleaned_path)):
        class_path = os.path.join(cleaned_path, class_name)
        if not os.path.isdir(class_path) or class_name.startswith('.'):
            continue

        output_class_path = os.path.join(output_path, class_name)
        os.makedirs(output_class_path, exist_ok=True)

        # Sorted so that the seeded random.choice below picks the same files on
        # every run; os.listdir order is arbitrary.
        images = sorted(f for f in os.listdir(class_path) if not f.startswith('.'))
        current_count = len(images)

        print(f"\nClass {class_name}: {current_count} images -> targeting {target_per_class}")

        # Copy original images
        for img_name in images:
            shutil.copy(os.path.join(class_path, img_name), output_class_path)

        # Augment to reach target
        needed = target_per_class - current_count
        if needed <= 0:
            continue
        if not images:
            # random.choice cannot draw from an empty list
            print("  No source images to augment - skipping")
            continue

        for augmented in range(needed):
            # Pick a random image from this class, then a random augmentation
            img_name = random.choice(images)
            aug_type = random.choice(augmentation_types)

            # The running index keeps names unique even when the same source
            # image and augmentation type are drawn more than once.
            base_name, ext = os.path.splitext(img_name)
            aug_name = f"{base_name}_aug{augmented}_{aug_type}{ext}"

            # Keep the source open only while it is transformed and saved
            with Image.open(os.path.join(class_path, img_name)) as img:
                augment_image(img, aug_type).save(os.path.join(output_class_path, aug_name))

        print(f"  Added {needed} augmented images")

    return output_path

# Top every training class up to 280 images
augmented_train_path = balance_and_augment(
    cleaned_path=cleaned_train_path,
    output_path='dataset_augmented/train',
    target_per_class=280,
)

## 6. Copy Validation Set

In [None]:
# Copy validation set verbatim (validation images are not augmented)
def copy_validation_set(source_path='dataset/val', dest_path='dataset_augmented/val'):
    """Replace ``dest_path`` with a fresh copy of the ``source_path`` directory tree.

    The source is checked before anything is deleted, so a wrong
    ``source_path`` leaves an existing ``dest_path`` untouched.

    Args:
        source_path: Directory to copy.
        dest_path: Destination; removed first if it already exists.

    Raises:
        FileNotFoundError: ``source_path`` does not exist.
        NotADirectoryError: ``source_path`` exists but is not a directory.
    """
    if not os.path.exists(source_path):
        raise FileNotFoundError(f"Validation source not found: {source_path}")
    if not os.path.isdir(source_path):
        raise NotADirectoryError(f"Validation source is not a directory: {source_path}")

    # copytree requires that the destination does not exist yet
    if os.path.exists(dest_path):
        shutil.rmtree(dest_path)
    shutil.copytree(source_path, dest_path)
    print(f"Validation set copied to {dest_path}")

# Mirror the untouched validation split next to the augmented training split
copy_validation_set(source_path='dataset/val', dest_path='dataset_augmented/val')

## 7. Verify Final Dataset

In [None]:
# Recount both splits of the final dataset
final_train_counts = count_images('dataset_augmented/train')
final_val_counts = count_images('dataset_augmented/val')

# Text summary: one line per class followed by the split total
for heading, split_counts in (
    ("\nFinal Training set distribution:", final_train_counts),
    ("\nFinal Validation set distribution:", final_val_counts),
):
    print(heading)
    for class_name, count in split_counts.items():
        print(f"  {class_name}: {count}")
    print(f"  Total: {sum(split_counts.values())}")

grand_total = sum(final_train_counts.values()) + sum(final_val_counts.values())
print(f"\nTotal dataset size: {grand_total}")
print("(Should be under 10,000 limit)")

# Visualize final distribution: one bar chart per split, side by side
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
for ax, (title, split_counts) in zip(axes, (
    ('Final Training Set Distribution', final_train_counts),
    ('Final Validation Set Distribution', final_val_counts),
)):
    ax.bar(split_counts.keys(), split_counts.values())
    ax.set_title(title)
    ax.set_xlabel('Class')
    ax.set_ylabel('Number of Images')
    ax.tick_params(axis='x', rotation=45)
plt.tight_layout()
plt.show()

## 8. Package for Download

In [None]:
# Create data_original directory as expected by train.py
if os.path.exists('data_original'):
    shutil.rmtree('data_original')
shutil.copytree('dataset_augmented', 'data_original')
print("Created data_original directory ready for training!")

# Create zip for download.
# make_archive writes data_original.zip from scratch on every run. `zip -r`
# would update an existing archive in place and keep entries for files that
# are no longer in data_original/.
shutil.make_archive('data_original', 'zip', root_dir='.', base_dir='data_original')
print("\nCreated data_original.zip - download this and extract in your project directory")

In [None]:
# If on Colab, download the prepared dataset
# from google.colab import files
# files.download('data_original.zip')

## Next Steps:
1. Download `data_original.zip` from Colab
2. Extract it in your project directory (where train.py is located)
3. Run `python train.py` locally
4. Monitor the validation accuracy - aim for >90% (ideally >93%)
5. Submit the `best_model.weights.h5` file

If accuracy is still below target, consider:
- Removing more outliers (raise the removal rate in step 4, currently 0.3, i.e. 30% of the flagged images)
- Adding more aggressive augmentation
- Manually reviewing and removing clearly mislabeled images