# Training and Saving Models for the CIFAR Classifier

This notebook demonstrates how to train a model on the CIFAR-100 dataset and save it in a format that can be loaded by our TensorFlow.js application.

## Setup and Dependencies

First, let's install the required packages:

In [None]:
%pip install tensorflow tensorflowjs matplotlib numpy

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

## Load and Preprocess CIFAR-100 Dataset

In [None]:
# Load CIFAR-100 dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()

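# Normalize pixel values to [0, 1]; cast to float32 first (dividing the uint8
# arrays by 255.0 directly would produce float64 and double the memory use)
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0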

# Print dataset shapes
print(f'Training data shape: {x_train.shape}')
print(f'Training labels shape: {y_train.shape}')
print(f'Test data shape: {x_test.shape}')
print(f'Test labels shape: {y_test.shape}')
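Dividing the `uint8` arrays by `255.0` produces `float64` arrays in NumPy, which take twice the memory of `float32`. As a rough back-of-the-envelope sketch (the helper `array_nbytes` is hypothetical), the footprint of the 50,000 training images at CIFAR resolution, and after resizing to 224x224 for a model such as MobileNet, is:

```python
from math import prod

def array_nbytes(shape, bytes_per_value):
    """Memory needed for a dense array of the given shape and element size."""
    return prod(shape) * bytes_per_value

# CIFAR-100 training set: 50,000 RGB images of 32x32 pixels
cifar_shape = (50000, 32, 32, 3)
for name, size in [("uint8", 1), ("float32", 4), ("float64", 8)]:
    print(f"{name:>8}: {array_nbytes(cifar_shape, size) / 1e9:.2f} GB")

# The same images resized to 224x224 and stored as float32
resized_shape = (50000, 224, 224, 3)
print(f"224x224 float32: {array_nbytes(resized_shape, 4) / 1e9:.1f} GB")
```

Keeping the arrays as `float32` also matches the default dtype Keras layers compute in.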

## Visualize Some Examples

In [None]:
# Define CIFAR-100 class names
class_names = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
    'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
    'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
    'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
    'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
    'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
    'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
    'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
    'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
    'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
    'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
    'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
    'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman',
    'worm'
]

# Display some training images
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(x_train[i])
    # CIFAR-100 labels have shape (N, 1), so index [0] to get the integer class ID
    plt.xlabel(class_names[y_train[i][0]])
plt.tight_layout()
plt.show()

## Create a CNN Model for CIFAR-100

In [None]:
def create_model():
    model = tf.keras.Sequential([
        # Input layer - expects 32x32 RGB images
        tf.keras.layers.Input(shape=(32, 32, 3)),
        
        # First convolutional block
        tf.keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Dropout(0.2),
        
        # Second convolutional block
        tf.keras.layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Dropout(0.3),
        
        # Third convolutional block
        tf.keras.layers.Conv2D(128, (3, 3), padding='same', activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Conv2D(128, (3, 3), padding='same', activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Dropout(0.4),
        
        # Flatten and dense layers
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(100, activation='softmax')  # 100 classes for CIFAR-100
    ])
    
    return model

# Create and compile the model
model = create_model()
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Display model summary
model.summary()
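As a cross-check on `model.summary()`, the parameter count can be recomputed by hand from the standard formulas: a `k x k` convolution has `k * k * C_in * C_out` weights plus `C_out` biases, a dense layer has `n_in * n_out + n_out`, and `BatchNormalization` stores four values per channel (gamma and beta are trainable; the moving mean and variance are not). `MaxPooling2D`, `Dropout`, and `Flatten` add no parameters. This is a sketch with illustrative helper names that mirrors `create_model()` above:

```python
def conv2d_params(k, c_in, c_out):
    # One k x k kernel per (input, output) channel pair, plus one bias per filter
    return k * k * c_in * c_out + c_out

def dense_params(n_in, n_out):
    return n_in * n_out + n_out

def batchnorm_params(channels):
    # gamma, beta (trainable) + moving mean, moving variance (non-trainable)
    return 4 * channels

def cnn_param_count():
    total = 0
    c_in, size = 3, 32  # 32x32 RGB input
    for filters in (32, 64, 128):
        # Two 'same'-padded Conv2D + BatchNormalization pairs per block
        total += conv2d_params(3, c_in, filters) + batchnorm_params(filters)
        total += conv2d_params(3, filters, filters) + batchnorm_params(filters)
        c_in, size = filters, size // 2  # MaxPooling2D(2, 2) halves height and width
    flat = size * size * c_in  # 4 * 4 * 128 = 2048 features after Flatten
    total += dense_params(flat, 128) + batchnorm_params(128)
    total += dense_params(128, 100)
    return total

params = cnn_param_count()
print(f"{params:,} parameters, about {params * 4 / 1e6:.1f} MB as float32")
```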

## Train the Model

Training for up to 50 epochs can take a long time on a CPU, so use a GPU runtime if one is available. The early-stopping callback may end training before the epoch limit.

In [None]:
# Define callbacks for training
callbacks = [
    tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True),
    tf.keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=3, min_lr=1e-6)
]

# Train the model
history = model.fit(
    x_train, y_train,
    epochs=50,  # You may want to reduce this for testing
    batch_size=64,
    validation_split=0.2,
    callbacks=callbacks
)
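With `validation_split=0.2`, Keras holds out the last 20% of `x_train` (10,000 images) before any shuffling, and both callbacks monitor `val_loss` on that split by default. The sketch below is a simplified, pure-Python model of how the two callbacks interact with the settings above; it assumes Adam's default initial learning rate of 0.001 and Keras' default `min_delta` values (`1e-4` for `ReduceLROnPlateau`, `0` for `EarlyStopping`), and it ignores the cooldown option. It is not the Keras implementation, and the loss values are made up for illustration.

```python
def simulate_callbacks(val_losses, lr=1e-3, factor=0.5, lr_patience=3,
                       min_lr=1e-6, stop_patience=5, lr_min_delta=1e-4):
    """Simplified model of ReduceLROnPlateau + EarlyStopping on val_loss.

    Returns (epochs_run, learning rate used in each epoch).
    """
    best_for_lr = best_for_stop = float("inf")
    lr_wait = stop_wait = 0
    lrs = []
    for epoch, loss in enumerate(val_losses):
        lrs.append(lr)  # learning rate used while training this epoch

        # ReduceLROnPlateau: an improvement must beat the best loss by min_delta
        if loss < best_for_lr - lr_min_delta:
            best_for_lr, lr_wait = loss, 0
        else:
            lr_wait += 1
            if lr_wait >= lr_patience:
                lr = max(lr * factor, min_lr)  # halve, but never below min_lr
                lr_wait = 0

        # EarlyStopping: any decrease counts as an improvement (min_delta=0)
        if loss < best_for_stop:
            best_for_stop, stop_wait = loss, 0
        else:
            stop_wait += 1
            if stop_wait >= stop_patience:
                return epoch + 1, lrs
    return len(val_losses), lrs

# Hypothetical validation losses: quick progress, then a plateau
losses = [2.0, 1.8, 1.7, 1.7, 1.7, 1.7, 1.69, 1.7, 1.7, 1.7, 1.7, 1.7]
epochs_run, lrs = simulate_callbacks(losses)
print(epochs_run, lrs)
```

Because `restore_best_weights=True`, the real `EarlyStopping` callback also rolls the model back to the weights from the epoch with the lowest `val_loss` when it stops.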

## Evaluate the Model

In [None]:
# Evaluate on test data
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f'Test loss: {test_loss:.4f}')
print(f'Test accuracy: {test_acc:.4f}')

# Predict class probabilities for the test set (one row of 100 scores per image)
y_pred = model.predict(x_test)

# Plot training history
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.tight_layout()
plt.show()
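`model.predict(x_test)` returns one row of 100 softmax probabilities per test image. The `argmax` of each row is the predicted class, so the share of rows whose `argmax` equals the true label should match the accuracy reported by `model.evaluate`. With 100 classes, top-5 accuracy is also commonly reported. The following is a minimal pure-Python sketch of top-k accuracy; `top_k_accuracy` is a hypothetical helper, and with the NumPy arrays above you could call it as `top_k_accuracy(y_pred.tolist(), y_test.ravel().tolist(), k=5)`.

```python
def top_k_accuracy(probs, labels, k=5):
    """Fraction of samples whose true label is among the k highest scores."""
    hits = 0
    for row, label in zip(probs, labels):
        # Indices of the k largest probabilities in this row
        top_k = sorted(range(len(row)), key=row.__getitem__, reverse=True)[:k]
        if label in top_k:
            hits += 1
    return hits / len(labels)

# Toy example: 3 samples, 4 classes (made-up probabilities)
probs = [
    [0.10, 0.60, 0.20, 0.10],
    [0.50, 0.05, 0.40, 0.05],
    [0.28, 0.22, 0.20, 0.30],
]
labels = [1, 2, 0]
print(top_k_accuracy(probs, labels, k=1), top_k_accuracy(probs, labels, k=2))
```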

## Save the Model for TensorFlow.js

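Now we'll save the model in Keras HDF5 format and convert it to the TensorFlow.js Layers format with `tensorflowjs_converter`, which is installed with the `tensorflowjs` package.

In [None]:
# Save the model in Keras HDF5 format, then convert it for TensorFlow.js
model.save('cifar100_model.h5')
!mkdir -p models/cifar100
!tensorflowjs_converter --input_format=keras cifar100_model.h5 models/cifar100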

## Test Model Loading

Let's verify that we can load the saved model back.

In [None]:
# Load the saved model
loaded_model = tf.keras.models.load_model('cifar100_model.h5')

# Test on a few examples
test_images = x_test[:5]
predictions = loaded_model.predict(test_images)

# Display the images and predictions
plt.figure(figsize=(15, 3))
for i in range(5):
    plt.subplot(1, 5, i+1)
    plt.imshow(test_images[i])
    plt.title(f'True: {class_names[y_test[i][0]]}\nPred: {class_names[np.argmax(predictions[i])]}')
    plt.axis('off')
plt.tight_layout()
plt.show()

## Instructions for Using the Model in the Web App

After running this notebook, you'll have a TensorFlow.js model saved in the `models/cifar100` directory. To use this model in your web application:

1. Copy the entire `models/cifar100` directory to your web application's public directory
2. Make sure the model files are accessible at the path specified in your `loadModel` function (`./models/cifar100/model.json`)
3. The application should now be able to load your trained model instead of using the mock model

### Model Directory Structure

The converted model will have the following structure:
```
models/
  cifar100/
    model.json           # Model topology
    group1-shard1of1.bin # Model weights (larger models are split into several 4 MB shards)
```

### Notes on Model Size

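- The CNN defined above has roughly 0.56 million parameters, so its converted weights come to only a few MB; transfer-learning backbones such as MobileNetV2 produce considerably larger files
- For production use, consider weight quantization (supported by `tensorflowjs_converter`) or a smaller architecture
- TensorFlow.js loads the model over HTTP(S) with `tf.loadLayersModel`, so the files can be hosted on any static server or CDN; a loaded model can also be cached in the browser's IndexedDB via `model.save` with an `indexeddb://` URL to avoid downloading it again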
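Weight size scales linearly with the bytes stored per parameter, so the effect of quantization can be estimated before converting. The sketch below assumes 564,484 parameters (the total `model.summary()` should report for the CNN in `create_model()`) and ignores `model.json` and the small per-tensor metadata that quantized weights carry. Recent versions of `tensorflowjs_converter` expose weight quantization through flags such as `--quantize_float16` and `--quantize_uint8`; check `tensorflowjs_converter --help` for the options your installed version supports, and re-evaluate accuracy afterwards, since quantization is lossy.

```python
def weight_size_mb(num_params, bytes_per_param):
    """Approximate size of the weight files in megabytes (1 MB = 1e6 bytes)."""
    return num_params * bytes_per_param / 1e6

NUM_PARAMS = 564_484  # assumed total parameter count of the CNN above

for label, nbytes in [("float32 (no quantization)", 4),
                      ("float16", 2),
                      ("uint8", 1)]:
    print(f"{label:<26} ~{weight_size_mb(NUM_PARAMS, nbytes):.2f} MB")
```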

## Training a MobileNet or EfficientNet Model

If you want to train a MobileNet or EfficientNet model instead, you can use transfer learning with pre-trained models from TensorFlow Hub. Here's a brief example for MobileNet:

In [None]:
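# This is just a code example - not meant to be run in this notebook
# unless you want to train a MobileNet model.
# Requires the tensorflow_hub package: %pip install tensorflow-hub

import tensorflow_hub as hub

IMG_SIZE = 224  # Input resolution expected by this MobileNetV2 module
BATCH_SIZE = 64

def resize(images, labels):
    # Resize one batch at a time; resizing all 50,000 images up front
    # would need roughly 30 GB of float32 memory
    return tf.image.resize(images, (IMG_SIZE, IMG_SIZE)), labels

# Hold out the last 20% of the training set for validation
# (validation_split is not supported for tf.data inputs)
n_val = int(0.2 * len(x_train))
train_ds = (tf.data.Dataset.from_tensor_slices((x_train[:-n_val], y_train[:-n_val]))
            .shuffle(10000)
            .batch(BATCH_SIZE)
            .map(resize, num_parallel_calls=tf.data.AUTOTUNE)
            .prefetch(tf.data.AUTOTUNE))
val_ds = (tf.data.Dataset.from_tensor_slices((x_train[-n_val:], y_train[-n_val:]))
          .batch(BATCH_SIZE)
          .map(resize, num_parallel_calls=tf.data.AUTOTUNE)
          .prefetch(tf.data.AUTOTUNE))

# Create a MobileNet model with transfer learning (frozen feature extractor).
# TF Hub image modules expect pixel values in [0, 1], which matches x_train.
base_model = hub.KerasLayer('https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
                            trainable=False)

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3)),
    base_model,
    tf.keras.layers.Dense(100, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train the model
model.fit(train_ds, validation_data=val_ds, epochs=10)

# Save for TensorFlow.js.
# Caution: TensorFlow.js has no built-in equivalent of hub.KerasLayer, so
# tf.loadLayersModel may fail to load this converted model; building the model
# with tf.keras.applications.MobileNetV2 (built-in layers only) avoids this.
model.save('mobilenet_cifar100.h5')
!mkdir -p models/mobilenet
!tensorflowjs_converter --input_format=keras mobilenet_cifar100.h5 models/mobilenet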