# CNN Model Training for Product Classification

This notebook handles the training of a Convolutional Neural Network (CNN) model for product image classification.

## Objectives:
- Prepare training data from scraped product images
- Design and implement CNN architecture
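- Train the model on a held-out validation split, with early stopping, learning-rate reduction, and checkpointing
- Evaluate model performance from the training history (accuracy/loss curves and an overfitting check)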
- Save trained model for deployment

In [None]:
# Import required libraries
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import cv2
from PIL import Image
import json

# Deep learning libraries
try:
    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers, models, optimizers, callbacks
    from tensorflow.keras.preprocessing.image import ImageDataGenerator
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import classification_report, confusion_matrix
    print(f"TensorFlow version: {tf.__version__}")
    TENSORFLOW_AVAILABLE = True
except ImportError:
    print("TensorFlow not available. Please install: pip install tensorflow")
    TENSORFLOW_AVAILABLE = False

# Add parent directory to path
sys.path.append('..')
from services.cnn_model import CNNModel
from services.scraper import WebScraper

# Set random seeds for reproducibility.
# NumPy is always seeded; TensorFlow only when its guarded import succeeded.
np.random.seed(42)
if TENSORFLOW_AVAILABLE:
    tf.random.set_seed(42)

# Set style for plots.
# 'seaborn-v0_8' is the name matplotlib >= 3.6 uses for the legacy seaborn
# style; older matplotlib only knows it as 'seaborn' and raises OSError for
# the new name, which would otherwise abort this whole setup cell.
try:
    plt.style.use('seaborn-v0_8')
except OSError:
    plt.style.use('seaborn')
sns.set_palette("husl")

print("Libraries imported successfully!")

In [None]:
# Configuration
# Single source of truth for the hyperparameters and paths used by the cells
# below. 'image_size' is handed to flow_from_directory as target_size,
# i.e. (height, width).
# NOTE(review): 'test_split' is not read by any cell in this notebook; only
# 'validation_split' is applied when the generators are built.
# NOTE(review): 'learning_rate' is only echoed in the report; it is not
# passed to CNNModel.create_model here, so confirm the service uses the same value.
CONFIG = dict(
    image_size=(224, 224),
    batch_size=32,
    epochs=50,
    learning_rate=0.001,
    validation_split=0.2,
    test_split=0.1,
    data_dir='../data/scraped_images',
    model_save_path='../models/cnn_product_classifier.h5',
    min_images_per_class=10,
)

print("Configuration:")
print("\n".join(f"  {setting}: {value}" for setting, value in CONFIG.items()))

In [None]:
# Check if we have scraped data, if not create sample data
data_dir = Path(CONFIG['data_dir'])

# File extensions Keras' flow_from_directory accepts (matched case-insensitively).
# Counting only these keeps the per-class numbers below consistent with what
# the data generators will actually load.
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.tif', '.tiff'}


def _count_images(directory):
    """Return how many image files live under *directory*, searching recursively."""
    return sum(
        1
        for path in directory.rglob('*')
        if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS
    )


# any(iterdir()) stops at the first entry instead of materializing a full listing.
if not data_dir.exists() or not any(data_dir.iterdir()):
    print("No scraped data found. Creating sample training data...")
    
    # Create sample product categories and images
    sample_products = [
        'wireless_headphones',
        'laptop_computer', 
        'smartphone',
        'coffee_maker',
        'running_shoes',
        'backpack',
        'watch',
        'camera',
        'tablet',
        'speaker'
    ]
    
    # Initialize web scraper
    scraper = WebScraper(download_dir=str(data_dir))
    
    # Scrape images for each product
    scraped_data = scraper.scrape_product_images(
        sample_products, 
        images_per_product=CONFIG['min_images_per_class']
    )
    
    # Create training dataset CSV
    csv_path = scraper.create_training_dataset(scraped_data)
    print(f"Training dataset CSV created: {csv_path}")
    
    # Get scraping statistics
    stats = scraper.get_scraping_stats()
    print(f"Scraping statistics: {stats}")
else:
    print(f"Found existing data directory: {data_dir}")

# List available product categories. Sorted so the listing matches the
# alphabetical class order flow_from_directory assigns to subdirectories.
if data_dir.exists():
    categories = sorted(d.name for d in data_dir.iterdir() if d.is_dir())
    print(f"\nAvailable product categories: {len(categories)}")
    min_images = CONFIG['min_images_per_class']
    for i, category in enumerate(categories, 1):
        image_count = _count_images(data_dir / category)
        print(f"  {i}. {category}: {image_count} images")
        # Surface under-populated classes instead of silently training on them.
        if image_count < min_images:
            print(f"     ⚠️  fewer than {min_images} images (min_images_per_class)")
else:
    print("No training data available. Please run the web scraper first.")
    categories = []

In [None]:
# Data preparation and augmentation
if TENSORFLOW_AVAILABLE and categories:
    print("Preparing data generators...")
    
    # Data augmentation for training. Both generators use the same
    # validation_split, so Keras splits each class directory consistently:
    # the 'training' subset and the 'validation' subset do not overlap.
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True,
        zoom_range=0.2,
        fill_mode='nearest',
        validation_split=CONFIG['validation_split']
    )
    
    # No augmentation for validation (rescale only)
    val_datagen = ImageDataGenerator(
        rescale=1./255,
        validation_split=CONFIG['validation_split']
    )
    
    # Create data generators
    train_generator = train_datagen.flow_from_directory(
        str(data_dir),
        target_size=CONFIG['image_size'],
        batch_size=CONFIG['batch_size'],
        class_mode='categorical',
        subset='training',
        shuffle=True
    )
    
    # shuffle=False keeps validation order stable across epochs
    validation_generator = val_datagen.flow_from_directory(
        str(data_dir),
        target_size=CONFIG['image_size'],
        batch_size=CONFIG['batch_size'],
        class_mode='categorical',
        subset='validation',
        shuffle=False
    )
    
    # Get class information
    num_classes = train_generator.num_classes
    class_names = list(train_generator.class_indices.keys())
    
    print("\nDataset Information:")
    print(f"  Number of classes: {num_classes}")
    print(f"  Training samples: {train_generator.samples}")
    print(f"  Validation samples: {validation_generator.samples}")
    print(f"  Class names: {class_names}")
    
    # Per-class counts come from the generators' label arrays, so they count
    # exactly the files Keras will feed the model: image files only, after
    # the train/validation split. A raw directory glob would also count
    # stray non-image files.
    train_counts = np.bincount(train_generator.classes, minlength=num_classes)
    val_counts = np.bincount(validation_generator.classes, minlength=num_classes)
    class_counts = {
        name: int(train_counts[idx] + val_counts[idx])
        for name, idx in train_generator.class_indices.items()
    }
    
    print("\nClass distribution:")
    for class_name, count in class_counts.items():
        idx = train_generator.class_indices[class_name]
        print(f"  {class_name}: {count} images "
              f"(train {int(train_counts[idx])}, val {int(val_counts[idx])})")
else:
    print("Cannot prepare data - TensorFlow not available or no categories found")

In [None]:
# Visualize sample images
if TENSORFLOW_AVAILABLE and categories:
    print("Visualizing sample images...")
    
    # Pull one batch from the training generator. These images already have
    # the training augmentation applied and are rescaled by 1/255, so imshow
    # can display them without conversion.
    images, labels = next(train_generator)
    n_shown = min(8, len(images))
    
    # Plot sample images on a 2x4 grid
    fig, axes = plt.subplots(2, 4, figsize=(15, 8))
    axes = axes.ravel()
    
    for i in range(n_shown):
        class_name = class_names[np.argmax(labels[i])]  # one-hot -> class name
        axes[i].imshow(images[i])
        axes[i].set_title(f'Class: {class_name}')
    
    # Hide every frame, including unused panels when the batch has < 8 images.
    for ax in axes:
        ax.axis('off')
    
    plt.tight_layout()
    
    # Save BEFORE plt.show(): the inline backend closes figures on show, so a
    # later plt.savefig would write a blank image.
    os.makedirs('../static/images', exist_ok=True)
    fig.savefig('../static/images/sample_training_images.png', dpi=300, bbox_inches='tight')
    print("Sample images saved to ../static/images/sample_training_images.png")
    
    plt.show()
    
    # Reset the generator's batch position so training starts fresh
    train_generator.reset()

In [None]:
# Build CNN model
if TENSORFLOW_AVAILABLE and categories:
    print("Building CNN model...")
    
    # Initialize CNN model service
    cnn_service = CNNModel()
    
    # Create model architecture; input is (height, width, RGB channels)
    model = cnn_service.create_model(
        num_classes=num_classes,
        input_shape=(*CONFIG['image_size'], 3)
    )
    
    # Display model architecture
    print("\nModel Architecture:")
    model.summary()
    
    # Plot model architecture. plot_model depends on optional tools
    # (pydot/graphviz), so a failure here is reported rather than fatal.
    try:
        os.makedirs('../static/images', exist_ok=True)
        tf.keras.utils.plot_model(
            model, 
            to_file='../static/images/model_architecture.png',
            show_shapes=True,
            show_layer_names=True,
            rankdir='TB'
        )
        print("Model architecture diagram saved to ../static/images/model_architecture.png")
    except Exception as e:
        print(f"Could not save model diagram: {e}")
    
    # Calculate parameter counts. Trainable parameters are summed from each
    # weight's shape rather than via tf.keras.backend.count_params, which is
    # not part of every Keras release's public backend API.
    total_params = model.count_params()
    trainable_params = sum(int(np.prod(tuple(w.shape))) for w in model.trainable_weights)
    non_trainable_params = total_params - trainable_params
    
    print("\nModel Parameters:")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")
    print(f"  Non-trainable parameters: {non_trainable_params:,}")
else:
    print("Cannot build model - TensorFlow not available or no categories found")

In [None]:
# Train the model
if TENSORFLOW_AVAILABLE and categories and 'model' in locals():
    print("Starting model training...")
    
    # Ensure the checkpoint/metadata directory exists before anything writes to it.
    model_save_path = Path(CONFIG['model_save_path'])
    model_save_path.parent.mkdir(parents=True, exist_ok=True)
    
    # Define callbacks
    callbacks_list = [
        callbacks.EarlyStopping(
            monitor='val_accuracy',
            patience=10,
            restore_best_weights=True,
            verbose=1
        ),
        callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.2,
            patience=5,
            min_lr=0.0001,
            verbose=1
        ),
        # Keeps only the checkpoint with the best val_accuracy seen so far
        callbacks.ModelCheckpoint(
            filepath=CONFIG['model_save_path'],
            monitor='val_accuracy',
            save_best_only=True,
            verbose=1
        )
    ]
    
    # Steps per epoch: len(generator) is ceil(samples / batch_size), so every
    # image is seen each epoch (the last batch may be partial). Floor division
    # would drop the remainder and yield 0 steps whenever a subset holds fewer
    # than batch_size images, e.g. 20 validation images with batch_size=32.
    steps_per_epoch = len(train_generator)
    validation_steps = len(validation_generator)
    
    print("Training configuration:")
    print(f"  Epochs: {CONFIG['epochs']}")
    print(f"  Batch size: {CONFIG['batch_size']}")
    print(f"  Steps per epoch: {steps_per_epoch}")
    print(f"  Validation steps: {validation_steps}")
    
    # Train the model
    history = model.fit(
        train_generator,
        steps_per_epoch=steps_per_epoch,
        epochs=CONFIG['epochs'],
        validation_data=validation_generator,
        validation_steps=validation_steps,
        callbacks=callbacks_list,
        verbose=1
    )
    
    print("\nTraining completed!")
    
    # Save class names and metadata (class order matches the model's outputs)
    metadata = {
        'class_names': class_names,
        'num_classes': num_classes,
        'input_shape': [*CONFIG['image_size'], 3],
        'training_config': CONFIG
    }
    
    # Metadata sits next to the model as <stem>_metadata.json. Building it from
    # the stem (instead of str.replace('.h5', ...)) can never return the model
    # path itself, so the JSON cannot overwrite the checkpoint when the model
    # file uses another extension.
    metadata_path = str(model_save_path.with_name(model_save_path.stem + '_metadata.json'))
    with open(metadata_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)
    
    print(f"Model metadata saved to {metadata_path}")
else:
    print("Cannot train model - requirements not met")

In [None]:
# Visualize training history
if 'history' in locals():
    print("Visualizing training history...")
    metrics = history.history  # metric name -> list of per-epoch values
    
    # Plot training history on a 2x2 grid
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Accuracy
    axes[0, 0].plot(metrics['accuracy'], label='Training Accuracy')
    axes[0, 0].plot(metrics['val_accuracy'], label='Validation Accuracy')
    axes[0, 0].set_title('Model Accuracy')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Accuracy')
    axes[0, 0].legend()
    axes[0, 0].grid(True)
    
    # Loss
    axes[0, 1].plot(metrics['loss'], label='Training Loss')
    axes[0, 1].plot(metrics['val_loss'], label='Validation Loss')
    axes[0, 1].set_title('Model Loss')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Loss')
    axes[0, 1].legend()
    axes[0, 1].grid(True)
    
    # Top-5 accuracy: only present if the model was compiled with a metric
    # named 'top_5_accuracy'. Hide the panel otherwise.
    if 'top_5_accuracy' in metrics and 'val_top_5_accuracy' in metrics:
        axes[1, 0].plot(metrics['top_5_accuracy'], label='Training Top-5 Accuracy')
        axes[1, 0].plot(metrics['val_top_5_accuracy'], label='Validation Top-5 Accuracy')
        axes[1, 0].set_title('Top-5 Accuracy')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Top-5 Accuracy')
        axes[1, 0].legend()
        axes[1, 0].grid(True)
    else:
        axes[1, 0].axis('off')
    
    # Learning rate: ReduceLROnPlateau logs it as 'lr' in older Keras releases
    # and as 'learning_rate' in newer ones, so accept either key.
    lr_key = next((key for key in ('lr', 'learning_rate') if key in metrics), None)
    if lr_key is not None:
        axes[1, 1].plot(metrics[lr_key], label='Learning Rate')
        axes[1, 1].set_title('Learning Rate Schedule')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Learning Rate')
        axes[1, 1].set_yscale('log')
        axes[1, 1].legend()
        axes[1, 1].grid(True)
    else:
        axes[1, 1].axis('off')
    
    plt.tight_layout()
    
    # Save BEFORE plt.show(): the inline backend closes figures on show, so a
    # later plt.savefig would write a blank image.
    os.makedirs('../static/images', exist_ok=True)
    fig.savefig('../static/images/training_history.png', dpi=300, bbox_inches='tight')
    print("Training history saved to ../static/images/training_history.png")
    
    plt.show()
    
    # Print final (last-epoch) metrics
    final_train_acc = metrics['accuracy'][-1]
    final_val_acc = metrics['val_accuracy'][-1]
    final_train_loss = metrics['loss'][-1]
    final_val_loss = metrics['val_loss'][-1]
    
    print("\nFinal Training Metrics:")
    print(f"  Training Accuracy: {final_train_acc:.4f}")
    print(f"  Validation Accuracy: {final_val_acc:.4f}")
    print(f"  Training Loss: {final_train_loss:.4f}")
    print(f"  Validation Loss: {final_val_loss:.4f}")
    
    # ModelCheckpoint(save_best_only) keeps the best-val_accuracy epoch, and
    # EarlyStopping(restore_best_weights) may roll the model back to it. So the
    # last-epoch numbers above can differ from the weights actually kept.
    best_epoch = int(np.argmax(metrics['val_accuracy']))
    best_val_acc = metrics['val_accuracy'][best_epoch]
    print(f"  Best Validation Accuracy: {best_val_acc:.4f} (epoch {best_epoch + 1})")
    
    # Check for overfitting (train/val accuracy gap at the last epoch)
    overfitting_threshold = 0.1
    acc_diff = final_train_acc - final_val_acc
    if acc_diff > overfitting_threshold:
        print(f"\n⚠️  Warning: Possible overfitting detected (accuracy difference: {acc_diff:.4f})")
    else:
        print(f"\n✅ Good generalization (accuracy difference: {acc_diff:.4f})")

In [None]:
# Generate training summary report
if 'history' in locals():
    # e.g. "224 x 224 x 3" (height x width x RGB channels)
    input_shape_str = ' x '.join(str(dim) for dim in (*CONFIG['image_size'], 3))
    summary_report = f"""
CNN MODEL TRAINING SUMMARY REPORT
=================================

Model Configuration:
- Architecture: Custom CNN
- Input Shape: {input_shape_str}
- Number of Classes: {num_classes}
- Total Parameters: {total_params:,}
- Trainable Parameters: {trainable_params:,}

Training Configuration:
- Epochs: {len(history.history['accuracy'])}
- Batch Size: {CONFIG['batch_size']}
- Learning Rate: {CONFIG['learning_rate']}
- Validation Split: {CONFIG['validation_split']}

Dataset Information:
- Training Samples: {train_generator.samples}
- Validation Samples: {validation_generator.samples}
- Classes: {', '.join(class_names)}

Final Performance:
- Training Accuracy: {final_train_acc:.4f} ({final_train_acc*100:.2f}%)
- Validation Accuracy: {final_val_acc:.4f} ({final_val_acc*100:.2f}%)
- Training Loss: {final_train_loss:.4f}
- Validation Loss: {final_val_loss:.4f}

Model Files:
- Model: {CONFIG['model_save_path']}
- Metadata: {metadata_path}
- Training History: ../static/images/training_history.png
- Sample Images: ../static/images/sample_training_images.png

Next Steps:
1. Test model on new images
2. Integrate with Flask application
3. Deploy for production use
4. Monitor performance and retrain if needed

Training completed successfully! 🎉
"""
    
    print(summary_report)
    
    # Save report to file. UTF-8 is explicit because the report contains an
    # emoji, which platform-default encodings such as cp1252 (Windows) cannot
    # encode.
    report_path = Path('../data/training_report.txt')
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text(summary_report, encoding='utf-8')
    
    print("\nTraining report saved to ../data/training_report.txt")
else:
    print("No training history available to generate report.")