# Quick Draw Sketch Recognition Model Training

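This notebook trains a sketch-recognition CNN on the processed Quick, Draw! dataset. It evaluates the model on the test split and exports it as a Keras `.h5` model (with metadata) and as a TensorFlow Lite model for deployment.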

In [None]:
# Import required libraries
import json
import os
import sys
import time
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Add parent directory to path to import from app modules
sys.path.append('..')

In [None]:
# Import custom modules
from app.services.data_loader_processed import ProcessedDataLoader
from app.services.model_builder import QuickDrawModelBuilder
from app.utils.visualization import plot_confusion_matrix, visualize_model_predictions

## 1. Define Paths and Configuration

In [None]:
# Define paths (relative to this notebook's directory)
base_dir = Path('..')
data_dir = base_dir / "app" / "datasets" / "processed"
model_dir = base_dir / "app" / "models" / "quickdraw"

# Ensure the output directory exists (safe to re-run)
model_dir.mkdir(parents=True, exist_ok=True)

# Configuration parameters
config = {
    'model_type': 'advanced',  # 'simple', 'advanced', or 'mobilenet'
    'batch_size': 64,
    'epochs': 20,
    # NOTE(review): learning_rate is not passed to the model builder or to training
    # anywhere in this notebook - confirm it is actually applied.
    'learning_rate': 0.001,
    'max_per_class': None,  # Limit samples per class (None for all)
    'data_augmentation': True,
    'seed': 42,  # Seed for the NumPy / TensorFlow global RNGs so runs are repeatable
}

# Seed the global RNGs early, before any data sampling or weight initialisation.
np.random.seed(config['seed'])
tf.random.set_seed(config['seed'])

# Show configuration
for key, value in config.items():
    print(f"{key}: {value}")

## 2. Load and Prepare Dataset

In [None]:
# Open the processed dataset directory and report which sketch classes it contains
data_loader = ProcessedDataLoader(data_dir)
class_names = data_loader.class_names
num_found = len(class_names)
print(f"Found {num_found} classes: {class_names}")

In [None]:
# Build train / validation / test generators. Batch size, augmentation and the
# per-class cap all come from `config`; the loader's fourth return value is unused here.
generator_options = {
    'batch_size': config['batch_size'],
    'augmentation': config['data_augmentation'],
    'max_per_class': config['max_per_class'],
}
train_generator, val_generator, test_generator, _ = data_loader.get_data_generators(**generator_options)

In [None]:
# Load a small in-memory sample (capped at SAMPLE_PER_CLASS per class) for visualization.
# NOTE: X_test / y_test loaded here are also what the confusion matrix, the
# prediction examples and the inference example below use - i.e. a capped
# subsample, not the full test set that `test_generator` evaluates.
SAMPLE_PER_CLASS = 100
dataset = data_loader.load_dataset(max_per_class=SAMPLE_PER_CLASS)
(X_train, y_train, _), (X_val, y_val, _), (X_test, y_test, _) = (
    dataset['train'],
    dataset['validation'],
    dataset['test'],
)

for split_name, split_images in (('X_train', X_train), ('X_val', X_val), ('X_test', X_test)):
    print(f"{split_name} shape: {split_images.shape}")

In [None]:
# Visualize a grid of training examples with their class labels.
n_rows, n_cols = 3, 5
# Guard against a sample smaller than the grid (indexing past the end would raise IndexError).
num_to_show = min(n_rows * n_cols, len(X_train))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 8))
for i, ax in enumerate(axes.flat):
    ax.axis('off')
    if i >= num_to_show:
        continue  # leave unused grid cells blank
    img = X_train[i]
    if img.shape[-1] == 1:  # Grayscale: drop the trailing channel axis for imshow
        ax.imshow(img[..., 0], cmap='gray')
    else:  # RGB
        ax.imshow(img)
    # argmax recovers the class index - assumes y_train rows are one-hot / score vectors
    ax.set_title(class_names[np.argmax(y_train[i])])
fig.tight_layout()
plt.show()

## 3. Build and Train the Model

In [None]:
# Check if GPU is available.
# NOTE: memory growth can only be configured before TensorFlow initialises the GPUs.
# If an earlier cell already touched them, set_memory_growth raises RuntimeError
# (reported below); restart the kernel and run this cell before the data cells if so.
gpus = tf.config.list_physical_devices('GPU')  # stable API; the experimental alias is deprecated
if gpus:
    print(f"Using {len(gpus)} GPU(s)")
    for gpu in gpus:
        print(f"  {gpu}")
    # Configure memory growth to prevent OOM errors (allocate GPU memory on demand)
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(f"Error configuring GPU: {e}")
else:
    print("No GPU found, using CPU")

In [None]:
# Build the architecture named by config['model_type'] through a lookup table
# mapping each supported type to the builder method that constructs it.
model_builder = QuickDrawModelBuilder()

builder_method_names = {
    'simple': 'build_simple_cnn',
    'advanced': 'build_advanced_cnn',
    'mobilenet': 'build_mobilenet_based',
}
selected_type = config['model_type']
if selected_type not in builder_method_names:
    raise ValueError(f"Unknown model type: {config['model_type']}")
model = getattr(model_builder, builder_method_names[selected_type])(len(class_names))

# Print the layer-by-layer architecture
model.summary()

In [None]:
# Set up callbacks
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_filename = f"quickdraw_model_{config['model_type']}_{timestamp}.h5"
model_path = model_dir / model_filename

# TensorBoard callback (log_dir passed as str for compatibility with older TF releases)
log_dir = model_dir / "logs" / f"{config['model_type']}_{timestamp}"
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=str(log_dir),
    histogram_freq=1
)

# Every "best model" decision uses the same metric. EarlyStopping restores the
# best-val_loss weights when training stops, and the final model is later saved
# to this same `model_path`. Monitoring val_loss in the checkpoint too keeps the
# checkpoint file and the in-memory model in agreement.
monitor_metric = 'val_loss'

# ModelCheckpoint callback
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=str(model_path),
    monitor=monitor_metric,
    save_best_only=True,
    verbose=1
)

# Early stopping
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor=monitor_metric,
    patience=5,
    restore_best_weights=True,
    verbose=1
)

# Reduce learning rate on plateau
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor=monitor_metric,
    factor=0.2,
    patience=3,
    min_lr=0.00001,
    verbose=1
)

callbacks_list = [tensorboard_callback, checkpoint_callback, early_stopping, reduce_lr]

In [None]:
# Train the model. Duration is measured with perf_counter, a monotonic clock that
# is unaffected by system clock adjustments (unlike time.time()).
start_time = time.perf_counter()
history = model_builder.train(
    train_generator,
    val_generator,
    epochs=config['epochs'],
    batch_size=config['batch_size'],
    callbacks_list=callbacks_list
)

# Calculate training time
training_time = time.perf_counter() - start_time
print(f"Training completed in {training_time:.2f} seconds")

## 4. Evaluate and Visualize Model Performance

In [None]:
# Plot the recorded training history and save it as a PNG next to the model artifacts
history_plot_name = f"training_history_{config['model_type']}_{timestamp}.png"
history_plot_path = model_dir / history_plot_name
model_builder.plot_training_history(save_path=str(history_plot_path))

In [None]:
# Evaluate on the test generator; report accuracy (0 if the builder returned no 'accuracy' key)
print("Evaluating model on test data...")
metrics = model_builder.evaluate(test_generator)
test_accuracy = metrics.get('accuracy', 0)
print(f"Test accuracy: {test_accuracy:.4f}")

In [None]:
# Run the model on the in-memory test sample, then reduce both the model outputs
# and the labels to integer class indices via argmax over the last axis.
y_pred = model.predict(X_test)
y_pred_classes, y_test_classes = (np.argmax(scores, axis=1) for scores in (y_pred, y_test))

In [None]:
# Plot confusion matrix (rows = true class, columns = predicted class).
# Counts are accumulated with NumPy over *all* classes, so the matrix always lines
# up with `class_names` even when some class is absent from this test sample.
num_classes = len(class_names)
cm = np.zeros((num_classes, num_classes), dtype=np.int64)
np.add.at(cm, (y_test_classes, y_pred_classes), 1)  # unbuffered add: repeated index pairs all count

plt.figure(figsize=(12, 10))
plot_confusion_matrix(cm, class_names, normalize=True)
plt.savefig(model_dir / f"confusion_matrix_{config['model_type']}_{timestamp}.png")
plt.show()

In [None]:
# Visualize some predictions and save the figure the helper produced.
fig = visualize_model_predictions(model, X_test, y_test, class_names, num_images=8)
# Save the returned figure rather than pyplot's "current figure", which can differ
# if the helper opened or closed other figures. Falls back to gcf() in case the
# helper returns None - NOTE(review): confirm its return type.
prediction_fig = fig if fig is not None else plt.gcf()
prediction_fig.savefig(model_dir / f"prediction_examples_{config['model_type']}_{timestamp}.png")
plt.show()

## 5. Save the Model with Metadata

In [None]:
# Create metadata describing this training run
metadata = {
    'input_shape': model.input_shape[1:],
    'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    'num_classes': len(class_names),
    'class_names': list(class_names),  # plain list so it serialises to JSON
    'model_type': config['model_type'],
    'config': config,
    'metrics': {
        'accuracy': float(metrics.get('accuracy', 0)),
        'loss': float(metrics.get('loss', 0))
    },
    'training_time_seconds': training_time
}


def _history_summary(key, reducer):
    """Apply `reducer` (max/min) to the recorded values for `key`.

    Returns 0.0 when the key is missing *or* its list is empty; max()/min() would
    raise ValueError on an empty sequence.
    """
    values = model_builder.history.get(key) or [0]
    return float(reducer(values))


# Add training history (best value per curve).
# NOTE(review): model_builder.history is used as a dict of per-epoch lists - confirm.
if model_builder.history is not None:
    metadata['training_history'] = {
        'accuracy': _history_summary('accuracy', max),
        'val_accuracy': _history_summary('val_accuracy', max),
        'loss': _history_summary('loss', min),
        'val_loss': _history_summary('val_loss', min),
        'epochs_trained': len(model_builder.history.get('accuracy', [])),
    }

# Save model with metadata
result = model_builder.save_model(str(model_path), class_names=class_names)
print(f"Model saved to {result['model_path']}")
print(f"Metadata saved to {result['metadata_path']}")

# save_model only receives class_names, so persist the full training metadata built
# above ourselves; otherwise the dict would be discarded when the kernel stops.
training_metadata_path = model_dir / f"training_metadata_{config['model_type']}_{timestamp}.json"
with open(training_metadata_path, 'w', encoding='utf-8') as metadata_file:
    json.dump(metadata, metadata_file, indent=2, default=str)  # default=str: stringify any non-JSON values
print(f"Training metadata saved to {training_metadata_path}")

In [None]:
# Convert model to TensorFlow Lite format for inference
from app.utils.model_utils import convert_model_to_tflite

# with_suffix swaps only the final extension; str.replace('.h5', ...) would also
# rewrite any '.h5' that happens to appear elsewhere in the path.
tflite_path = str(model_path.with_suffix('.tflite'))
tflite_file = convert_model_to_tflite(model, tflite_path, quantize=True)
print(f"TFLite model saved to {tflite_file}")

## 6. Model Inference Example

In [None]:
# Run single-image inference with the in-memory trained model on a random test image
from app.utils.model_utils import load_model_with_metadata  # NOTE(review): unused in this cell

# Select a random test image
test_idx = np.random.randint(0, len(X_test))
test_image = X_test[test_idx]
true_label = np.argmax(y_test[test_idx])

# Add a batch dimension: model.predict expects a batch of inputs
input_image = np.expand_dims(test_image, axis=0)

# Make prediction (verbose=0 suppresses the Keras progress bar for a single sample)
predictions = model.predict(input_image, verbose=0)
scores = predictions[0]
predicted_class = np.argmax(scores)

# Show the image and prediction
plt.figure(figsize=(6, 6))
if test_image.shape[-1] == 1:  # Grayscale: drop the trailing channel axis for imshow
    plt.imshow(test_image.reshape(test_image.shape[0], test_image.shape[1]), cmap='gray')
else:  # RGB
    plt.imshow(test_image)
plt.title(f"True: {class_names[true_label]}\nPredicted: {class_names[predicted_class]}")
plt.axis('off')
plt.show()  # render the figure before the text output below

# Show the top-k predictions (k capped at the number of classes)
top_k = min(3, len(class_names))
top_indices = np.argsort(scores)[::-1][:top_k]
print(f"Top {top_k} predictions:")
for rank, idx in enumerate(top_indices, start=1):
    # Percent formatting assumes the model outputs probabilities (e.g. softmax) - confirm
    print(f"{rank}. {class_names[idx]}: {scores[idx]*100:.2f}%")

## 7. Summary

The model training process is complete. Here's a summary of what we've accomplished:

1. Loaded and visualized the processed Quick Draw dataset
2. Built and trained a CNN model for sketch recognition
3. Evaluated the model on the test set (accuracy, confusion matrix, and example predictions)
4. Saved the trained model with metadata for inference
5. Converted the model to TensorFlow Lite format for efficient deployment