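# Quick Draw Sketch Recognition with ResNet: Custom Architecture and Pre-Trained ResNet50
- A custom (CIFAR-style, 6n+2-layer) ResNet is used for faster training.
- The pre-trained ResNet50 was designed for 224x224 inputs; the optional `create_resnet_model` path only upsamples the 28x28 sketches to 56x56.
- Reported test accuracy: ~93%. This figure comes from the separate evaluation cell; it is only trustworthy if that cell's train/validation/test split matches the split the checkpoint was trained on.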

In [None]:
# Mount Google Drive so the notebook can read the Quick Draw .npy files
# (loaded from /content/drive/MyDrive/quick_draw_data) and write checkpoints,
# TensorBoard logs and the exported model under MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks
from tensorflow.keras.applications import ResNet50
import numpy as np
import matplotlib.pyplot as plt
import os
import json
from sklearn.model_selection import train_test_split


In [None]:
BATCH_SIZE = 128  # mini-batch size for training and validation generators
IMAGE_SIZE = 28  # Quick Draw bitmaps are 28x28 grayscale (loaders/models below hard-code 28)
NUM_CLASSES = 30  # number of Quick Draw categories to train on (all 30 in the category list)
EPOCHS = 100  # upper bound; the EarlyStopping callback may end training earlier
LEARNING_RATE = 0.0001  # initial Adam learning rate; ReduceLROnPlateau can lower it


In [None]:
def load_quickdraw_data(num_classes=10, samples_per_class=10000,
                        data_dir="/content/drive/MyDrive/quick_draw_data"):
    """
    Loads Quick Draw bitmap sketches from local ``.npy`` files.

    Each category is read from ``<data_dir>/<category lower-cased>.npy``
    (the ``numpy_bitmap`` exports published at
    https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/),
    where every row is one flattened 28x28 sketch.

    Args:
        num_classes: Number of categories to load, taken from the start of the
            built-in category list. Also the width of the one-hot labels, so it
            should match the model's output size.
        samples_per_class: Maximum number of sketches kept per category (the
            first ``samples_per_class`` rows of each file).
        data_dir: Directory that holds the ``.npy`` files.

    Returns:
        ``(X, y, categories)``:
        - ``X``: float32 array of shape ``(N, 28, 28, 1)`` scaled to [0, 1].
        - ``y``: float32 one-hot labels of shape ``(N, num_classes)``.
        - ``categories``: names of the categories that were actually loaded,
          in label order (``categories[k]`` is the name of label ``k``).
          Categories whose file is missing/unreadable are skipped and do not
          consume a label index.

    Raises:
        ValueError: if ``num_classes`` or ``samples_per_class`` is not
            positive, or if no category could be loaded at all.
    """
    # List of available categories
    categories = ['Airplane', 'Apple', 'Bicycle', 'Book', 'Car', 'Cat', 'Chair', 'Clock', 'Dog', 'Door', 'Eye', 'Fish', 'Flower', 'Fork', 'House', 'Key', 'Ladder', 'Moon', 'Mountain', 'Pizza', 'Rainbow', 'Shoe', 'Smiley Face', 'Star', 'Stop Sign', 'Sun', 'Table', 'Tennis Racquet', 'Tree', 'Wheel']

    if num_classes < 1:
        raise ValueError(f"num_classes must be positive, got {num_classes}")
    if samples_per_class < 1:
        raise ValueError(f"samples_per_class must be positive, got {samples_per_class}")

    images = []
    labels = []
    loaded = []

    print("Loading data...")

    for category in categories[:num_classes]:
        print(f"Loading {category} data...")
        path = os.path.join(data_dir, f"{category.lower()}.npy")
        try:
            # Memory-map so only the requested rows are read from disk; the
            # full per-category files can hold far more sketches than we keep.
            mapped = np.load(path, mmap_mode='r')
            data = np.array(mapped[:samples_per_class])
            del mapped
            # Reshape per file so a malformed file is reported and skipped
            # instead of breaking the concatenation of all classes.
            data = data.reshape(-1, 28, 28, 1)
        except (OSError, ValueError) as e:
            print(f"Failed to load {category}: {e}")
            continue
        if data.shape[0] == 0:
            print(f"Failed to load {category}: file contains no samples")
            continue

        # Labels are assigned contiguously over the categories that loaded,
        # so `loaded[label]` always names the class of `label`.
        label = len(loaded)
        images.append(data)
        labels.append(np.full(data.shape[0], label))
        loaded.append(category)
        print(f"Successfully loaded {len(data)} samples for {category}")

    if not images:
        raise ValueError(f"No Quick Draw data could be loaded from {data_dir!r}")

    # Concatenate the compact raw arrays first, then convert once and scale in
    # place to avoid an extra full-size float temporary.
    X = np.concatenate(images).astype('float32')
    X /= 255.0

    # One-hot encode (float32, width num_classes) without needing Keras here.
    y = np.eye(num_classes, dtype='float32')[np.concatenate(labels)]

    return X, y, loaded


In [None]:
# Best-epoch weights are written here by the ModelCheckpoint callback and read
# back after training. The path is relative to the working directory
# (presumably /content on Colab, which puts it on the mounted Drive — verify).
checkpoint_dir = './drive/MyDrive/checkpoints'
os.makedirs(name=checkpoint_dir, exist_ok=True)

In [None]:
# Function to create data generators with augmentation
def create_data_generators(X_train, y_train, X_val, y_val, batch_size=None, seed=None):
    """
    Creates Keras image iterators for training (with augmentation) and
    validation (without augmentation).

    Args:
        X_train, y_train: Training images and one-hot labels.
        X_val, y_val: Validation images and one-hot labels.
        batch_size: Batch size for both iterators; defaults to the module-level
            ``BATCH_SIZE`` when None.
        seed: Optional seed for the training iterator's shuffling and random
            augmentation, for reproducible runs.

    Returns:
        ``(train_generator, val_generator)`` as returned by
        ``ImageDataGenerator.flow``.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    # Light geometric augmentation for training data.
    train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        rotation_range=10,
        width_shift_range=0.1,
        height_shift_range=0.1,
        shear_range=0.1,
        zoom_range=0.1,
        horizontal_flip=False,  # Sketches might lose meaning when flipped
        fill_mode='nearest'
    )

    # No augmentation for validation data
    val_datagen = tf.keras.preprocessing.image.ImageDataGenerator()

    train_generator = train_datagen.flow(
        X_train, y_train,
        batch_size=batch_size,
        shuffle=True,
        seed=seed
    )

    # Keep validation order fixed: `flow` shuffles by default, which (combined
    # with any truncated `validation_steps`) scored a different random subset
    # of the validation set every epoch and made val metrics noisy.
    val_generator = val_datagen.flow(
        X_val, y_val,
        batch_size=batch_size,
        shuffle=False
    )

    return train_generator, val_generator


In [None]:
# Transfer-learning alternative: ImageNet-pretrained ResNet50 backbone + small head
def create_resnet_model(input_shape, num_classes):
    """
    Builds a sketch classifier on top of an ImageNet-pretrained ResNet50.

    Single-channel inputs are tiled to three channels (ResNet50's stem expects
    RGB), inputs shorter than 32 px are upsampled 2x (Keras requires at least
    32x32 for ``include_top=False``), and a GAP -> Dense(512) -> Dropout ->
    softmax head is attached. The first 100 layers of the backbone model are
    frozen.

    NOTE(review): the ImageNet weights were trained on inputs processed by
    ``tf.keras.applications.resnet50.preprocess_input``; this model feeds the
    [0, 1] grayscale values unchanged — confirm whether that is intended.

    Args:
        input_shape: Shape of one input image, e.g. ``(28, 28, 1)``.
        num_classes: Number of output classes.

    Returns:
        An uncompiled ``tf.keras`` Model.
    """
    image_in = layers.Input(shape=input_shape)

    # Grayscale -> 3 identical channels.
    rgb = (layers.Concatenate()([image_in, image_in, image_in])
           if input_shape[-1] == 1 else image_in)

    # Very small inputs are doubled in height and width.
    backbone_in = layers.UpSampling2D(size=(2, 2))(rgb) if input_shape[0] < 32 else rgb

    backbone = ResNet50(include_top=False, weights='imagenet', input_tensor=backbone_in)

    head = layers.GlobalAveragePooling2D()(backbone.output)
    head = layers.Dense(512, activation='relu')(head)
    head = layers.Dropout(0.5)(head)
    predictions = layers.Dense(num_classes, activation='softmax')(head)

    classifier = models.Model(inputs=image_in, outputs=predictions)

    # Keep the earliest, most generic backbone features fixed during fine-tuning.
    for frozen_layer in backbone.layers[:100]:
        frozen_layer.trainable = False

    return classifier


In [None]:
def create_callbacks(checkpoint_path=None, log_dir='./drive/MyDrive/logs'):
    """
    Creates the training callbacks.

    - ModelCheckpoint: saves weights whenever ``val_accuracy`` improves.
    - EarlyStopping: stops after 5 epochs without ``val_loss`` improvement and
      restores the best-``val_loss`` weights.
    - ReduceLROnPlateau: multiplies the learning rate by 0.2 after 3 stagnant
      epochs of ``val_loss`` (floor 1e-6).
    - TensorBoard: writes logs to a fresh timestamped subdirectory of
      ``log_dir`` so successive runs do not overwrite/mix each other's events.

    Note the checkpoint tracks ``val_accuracy`` while the other callbacks track
    ``val_loss``; the checkpoint file and the restored weights can therefore
    come from different epochs.

    Args:
        checkpoint_path: Where to save the best weights; defaults to
            ``<checkpoint_dir>/model_weights.weights.h5``.
        log_dir: Parent directory for TensorBoard run directories.

    Returns:
        ``[model_checkpoint, early_stopping, reduce_lr, tensorboard]``.
    """
    import datetime

    if checkpoint_path is None:
        checkpoint_path = os.path.join(checkpoint_dir, 'model_weights.weights.h5')

    model_checkpoint = callbacks.ModelCheckpoint(
        checkpoint_path,
        monitor='val_accuracy',
        save_best_only=True,
        save_weights_only=True,
        mode='max',
        verbose=1
    )

    early_stopping = callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
        verbose=1
    )

    reduce_lr = callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=3,
        min_lr=1e-6,
        verbose=1
    )

    # One subdirectory per run; TensorBoard pointed at `log_dir` lists them
    # side by side instead of interleaving their event files.
    run_log_dir = os.path.join(log_dir, datetime.datetime.now().strftime('%Y%m%d-%H%M%S'))
    os.makedirs(run_log_dir, exist_ok=True)
    tensorboard = callbacks.TensorBoard(
        log_dir=run_log_dir,
        histogram_freq=1,
        write_graph=True
    )

    return [model_checkpoint, early_stopping, reduce_lr, tensorboard]


In [None]:
def residual_block(x, filters, kernel_size=3, strides=1, conv_shortcut=False):
    """
    Builds one basic residual block: two convolutions plus a shortcut.

    Main path: Conv2D(strides) -> BN -> ReLU -> Conv2D -> BN. The shortcut is
    added to it and the sum passes through a final ReLU (post-activation, as
    in the original ResNet paper).

    Args:
        x: Input feature-map tensor. Channels-last is assumed, since the
            channel count is read from ``x.shape[-1]``.
        filters: Output channels of both main-path convolutions.
        kernel_size: Kernel size of both main-path convolutions.
        strides: Stride of the first main-path convolution (int or pair);
            values other than 1 downsample.
        conv_shortcut: If True, the shortcut is a 1x1 Conv2D + BN projection.
            A projection is also used automatically whenever an identity
            shortcut could not be added to the main path — i.e. when the block
            downsamples or the input channel count differs from ``filters``.

    Returns:
        The block's output tensor.
    """
    stride_pair = tuple(strides) if isinstance(strides, (tuple, list)) else (strides, strides)
    downsamples = any(s != 1 for s in stride_pair)
    in_channels = x.shape[-1]
    channels_change = in_channels is not None and in_channels != filters

    shortcut = x

    # Without a projection, `add` would fail on mismatched spatial/channel shapes.
    if conv_shortcut or downsamples or channels_change:
        shortcut = layers.Conv2D(filters, 1, strides=strides, padding='same')(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)

    x = layers.Conv2D(filters, kernel_size, strides=strides, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters, kernel_size, padding='same')(x)
    x = layers.BatchNormalization()(x)

    x = layers.add([shortcut, x])
    x = layers.Activation('relu')(x)

    return x

# Custom ResNet implementation specifically for sketch recognition
def create_custom_resnet(input_shape, num_classes, depth=20):
    """
    Builds a small CIFAR-style ResNet for sketch classification.

    Layout: a 16-filter 3x3 stem conv (+BN, ReLU), then three stages of
    ``(depth - 2) // 6`` residual blocks with 16, 32 and 64 filters. The first
    block of every stage uses a projection shortcut; stages two and three
    also downsample with stride 2 in that first block. Global average pooling
    and a softmax Dense layer produce the class probabilities.

    Args:
        input_shape: Shape of one input image, e.g. ``(28, 28, 1)``.
        num_classes: Number of output classes.
        depth: Total depth; must be of the form 6n+2 (20, 32, 44, ...).

    Returns:
        An uncompiled ``tf.keras`` Model.

    Raises:
        ValueError: if ``depth`` is not of the form 6n+2.
    """
    if (depth - 2) % 6:
        raise ValueError('depth should be 6n+2 (e.g., 20, 32, 44)')

    blocks_per_stage = (depth - 2) // 6

    # (filters, stride applied by the first block of the stage)
    stages = ((16, 1), (32, 2), (64, 2))

    inputs = layers.Input(shape=input_shape)

    # Stem
    features = layers.Conv2D(16, 3, padding='same')(inputs)
    features = layers.BatchNormalization()(features)
    features = layers.Activation('relu')(features)

    for stage_filters, stage_stride in stages:
        for block_idx in range(blocks_per_stage):
            opens_stage = block_idx == 0
            features = residual_block(
                features,
                stage_filters,
                strides=stage_stride if opens_stage else 1,
                conv_shortcut=opens_stage,
            )

    # Classifier head
    pooled = layers.GlobalAveragePooling2D()(features)
    probabilities = layers.Dense(num_classes, activation='softmax')(pooled)

    return models.Model(inputs=inputs, outputs=probabilities)


In [None]:
def plot_training_history(history):
    """
    Saves side-by-side accuracy and loss curves (train vs. validation) to
    ``training_history.png`` in the working directory, then closes the figure.

    Args:
        history: The object returned by ``model.fit``; its ``history`` dict
            must contain 'accuracy', 'val_accuracy', 'loss' and 'val_loss'.
    """
    curves = history.history
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # (train key, validation key, title, y-label, legend location) per panel
    panels = (
        ('accuracy', 'val_accuracy', 'Model Accuracy', 'Accuracy', 'lower right'),
        ('loss', 'val_loss', 'Model Loss', 'Loss', 'upper right'),
    )

    for ax, (train_key, val_key, title, y_label, legend_loc) in zip(axes, panels):
        ax.plot(curves[train_key])
        ax.plot(curves[val_key])
        ax.set_title(title)
        ax.set_ylabel(y_label)
        ax.set_xlabel('Epoch')
        ax.legend(['Train', 'Validation'], loc=legend_loc)

    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.close()

# Function to evaluate model on test data
def _plot_confusion_matrix(cm, categories, path):
    """Draws an annotated confusion-matrix heat map and saves it to ``path``."""
    n = len(categories)
    # Grow the canvas with the class count so a 30x30 grid stays legible.
    side = max(10.0, 0.4 * n)
    fig, ax = plt.subplots(figsize=(side, side * 0.8))
    image = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.set_title('Confusion Matrix')
    fig.colorbar(image, ax=ax)

    ticks = np.arange(n)
    ax.set_xticks(ticks)
    ax.set_xticklabels(categories, rotation=45, ha='right')
    ax.set_yticks(ticks)
    ax.set_yticklabels(categories)

    thresh = cm.max() / 2.
    cell_font = 8 if n > 15 else 10
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], 'd'),
                    ha="center", va="center", fontsize=cell_font,
                    color="white" if cm[i, j] > thresh else "black")

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    # Lay out only after every label exists; earlier calls can clip them.
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


def _plot_sample_predictions(X, y_true, y_pred, categories, num_samples, seed, path):
    """Saves a grid of random test images titled with true/predicted labels."""
    count = min(num_samples, len(X))
    if count <= 0:
        return

    def _name(label):
        # Guard against predicted indices with no category name (e.g. when a
        # category failed to load but the model still has an output for it).
        return categories[label] if label < len(categories) else str(label)

    # Without replacement, so no image appears twice in the grid.
    picks = np.random.default_rng(seed).choice(len(X), size=count, replace=False)
    cols = min(5, count)
    rows = -(-count // cols)
    fig = plt.figure(figsize=(3 * cols, 8 * rows / 3))
    for pos, idx in enumerate(picks, start=1):
        ax = fig.add_subplot(rows, cols, pos)
        # Drop the trailing channel axis (e.g. (28, 28, 1) -> (28, 28)) for imshow.
        ax.imshow(np.squeeze(X[idx]), cmap='gray')
        ax.set_title(f"True: {_name(y_true[idx])}\nPred: {_name(y_pred[idx])}")
        ax.axis('off')

    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


def evaluate_model(model, X_test, y_test, categories, num_samples=15, seed=None):
    """
    Evaluates a trained classifier on held-out data and writes diagnostic plots.

    Prints the overall accuracy and a per-class classification report, and
    saves ``confusion_matrix.png`` and ``sample_predictions.png`` to the
    working directory.

    Args:
        model: Object with a Keras-style ``predict(X)`` returning per-class
            scores of shape ``(N, C)``.
        X_test: Test images; each sample is squeezed to 2-D for display
            (assumed ``(N, 28, 28, 1)`` as produced by the loader — confirm for
            other inputs).
        y_test: One-hot test labels of shape ``(N, C)``.
        categories: Class names in label order.
        num_samples: Number of random test images in the sample grid.
        seed: Optional seed for choosing the sample images.

    Returns:
        float: Test accuracy in [0, 1].
    """
    from sklearn.metrics import confusion_matrix, classification_report

    y_pred = model.predict(X_test)
    y_pred_classes = np.argmax(y_pred, axis=1)
    y_true_classes = np.argmax(y_test, axis=1)

    accuracy = float(np.mean(y_pred_classes == y_true_classes))
    print(f"\nTest Accuracy: {accuracy * 100:.2f}%")

    # Fix the label set so every category gets a row/column and a report line
    # even if it never occurs in y_true or y_pred; otherwise the matrix shrinks
    # (misaligning tick labels) and classification_report rejects target_names.
    label_ids = np.arange(len(categories))
    cm = confusion_matrix(y_true_classes, y_pred_classes, labels=label_ids)
    _plot_confusion_matrix(cm, categories, 'confusion_matrix.png')

    report = classification_report(y_true_classes, y_pred_classes,
                                   labels=label_ids, target_names=categories,
                                   zero_division=0)
    print("\nClassification Report:")
    print(report)

    _plot_sample_predictions(X_test, y_true_classes, y_pred_classes, categories,
                             num_samples, seed, 'sample_predictions.png')

    return accuracy

In [None]:
def main():
    """
    Trains the custom ResNet on Quick Draw, evaluates it on a held-out test
    split, and saves the model and class mapping to Google Drive.

    The data is split 60/20/20 into train/validation/test with
    ``random_state=42``. Any separate evaluation of the saved checkpoint must
    rebuild exactly this split; otherwise its "test" images overlap the
    training data and the reported accuracy is inflated.
    """
    print("Setting up the Quick Draw sketch recognition model...")

    # Seed Python, NumPy and TensorFlow so weight init, shuffling and
    # augmentation are repeatable across runs.
    tf.keras.utils.set_random_seed(42)

    # 20000 sketches per class: the amount the loader has always kept in
    # practice (it previously ignored samples_per_class), so results stay
    # comparable with earlier runs.
    X, y, categories = load_quickdraw_data(num_classes=NUM_CLASSES, samples_per_class=20000)

    # Split the data into training, validation, and test sets (60/20/20)
    X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.4, random_state=42)
    X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

    print(f"Data loaded: Training: {X_train.shape[0]}, Validation: {X_val.shape[0]}, Test: {X_test.shape[0]}")

    train_generator, val_generator = create_data_generators(X_train, y_train, X_val, y_val)

    # Choose one of the model architectures:
    # 1. Pre-trained ResNet50
    # model = create_resnet_model(input_shape=(28, 28, 1), num_classes=NUM_CLASSES)
    # 2. Custom ResNet implementation
    model = create_custom_resnet(input_shape=(28, 28, 1), num_classes=NUM_CLASSES, depth=20)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    model.summary()

    model_callbacks = create_callbacks()

    print("\nTraining the model...")
    # No steps_per_epoch / validation_steps: the generators report their own
    # length (ceil(N / BATCH_SIZE)). Passing floor(N / BATCH_SIZE) left one
    # batch unconsumed per epoch, and the logged run shows epochs 2 and 4
    # finishing after a single step on that leftover batch.
    history = model.fit(
        train_generator,
        epochs=EPOCHS,
        validation_data=val_generator,
        callbacks=model_callbacks
    )

    plot_training_history(history)

    # Load the best (highest val_accuracy) checkpoint written during training
    model.load_weights(os.path.join(checkpoint_dir, 'model_weights.weights.h5'))

    print("\nEvaluating the model...")
    evaluate_model(model, X_test, y_test, categories)

    model_dir = './drive/MyDrive/Model'
    os.makedirs(model_dir, exist_ok=True)

    model_path = os.path.join(model_dir, 'quickdraw_resnet_model_50_000.h5')
    model.save(model_path)
    print(f"\nModel saved as '{model_path}'")

    # Label index -> category name (json turns the int keys into strings)
    mapping_path = os.path.join(model_dir, 'class_mapping.json')
    with open(mapping_path, 'w') as f:
        json.dump({i: category for i, category in enumerate(categories)}, f)
    print(f"Class mapping saved as '{mapping_path}'")

if __name__ == "__main__":
    main()

Setting up the Quick Draw sketch recognition model...
Loading data...
Loading Airplane data...
Successfully loaded 20000 samples for Airplane
Loading Apple data...
Successfully loaded 20000 samples for Apple
Loading Bicycle data...
Successfully loaded 20000 samples for Bicycle
Loading Book data...
Successfully loaded 20000 samples for Book
Loading Car data...
Successfully loaded 20000 samples for Car
Loading Cat data...
Successfully loaded 20000 samples for Cat
Loading Chair data...
Successfully loaded 20000 samples for Chair
Loading Clock data...
Successfully loaded 20000 samples for Clock
Loading Dog data...
Successfully loaded 20000 samples for Dog
Loading Door data...
Successfully loaded 20000 samples for Door
Loading Eye data...
Successfully loaded 20000 samples for Eye
Loading Fish data...
Successfully loaded 20000 samples for Fish
Loading Flower data...
Successfully loaded 20000 samples for Flower
Loading Fork data...
Successfully loaded 20000 samples for Fork
Loading House data


Training the model...


  self._warn_if_super_not_called()


Epoch 1/100
[1m2811/2812[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 51ms/step - accuracy: 0.5374 - loss: 1.7253
Epoch 1: val_accuracy improved from -inf to 0.83706, saving model to ./drive/MyDrive/checkpoints/model_weights.weights.h5
[1m2812/2812[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m180s[0m 55ms/step - accuracy: 0.5375 - loss: 1.7249 - val_accuracy: 0.8371 - val_loss: 0.5817 - learning_rate: 1.0000e-04
Epoch 2/100
[1m   1/2812[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m1:17[0m 27ms/step - accuracy: 0.7969 - loss: 0.7368




Epoch 2: val_accuracy improved from 0.83706 to 0.83745, saving model to ./drive/MyDrive/checkpoints/model_weights.weights.h5
[1m2812/2812[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 3ms/step - accuracy: 0.7969 - loss: 0.7368 - val_accuracy: 0.8374 - val_loss: 0.5819 - learning_rate: 1.0000e-04
Epoch 3/100
[1m2812/2812[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 48ms/step - accuracy: 0.8552 - loss: 0.5220
Epoch 3: val_accuracy improved from 0.83745 to 0.86234, saving model to ./drive/MyDrive/checkpoints/model_weights.weights.h5
[1m2812/2812[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m144s[0m 51ms/step - accuracy: 0.8552 - loss: 0.5220 - val_accuracy: 0.8623 - val_loss: 0.4775 - learning_rate: 1.0000e-04
Epoch 4/100
[1m   1/2812[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m1:43[0m 37ms/step - accuracy: 0.8359 - loss: 0.4891
Epoch 4: val_accuracy improved from 0.86234 to 0.86439, saving model to ./drive/MyDrive/checkpoints/model_weights.weights.h5
[1m2812/2812[0m 

KeyboardInterrupt: 

In [None]:
# Rebuild exactly the data and train/validation/test split that main() trains
# with (20000 sketches per class, 60/20/20, random_state=42). A different
# test_size reshuffles the partition, so the "test" set would overlap the
# data the checkpoint was trained on and the accuracy would be inflated.
X, y, categories = load_quickdraw_data(num_classes=NUM_CLASSES, samples_per_class=20000)

X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.4, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

Loading data...
Loading Airplane data...
Successfully loaded 20000 samples for Airplane
Loading Apple data...
Successfully loaded 20000 samples for Apple
Loading Bicycle data...
Successfully loaded 20000 samples for Bicycle
Loading Book data...
Successfully loaded 20000 samples for Book
Loading Car data...
Successfully loaded 20000 samples for Car
Loading Cat data...
Successfully loaded 20000 samples for Cat
Loading Chair data...
Successfully loaded 20000 samples for Chair
Loading Clock data...
Successfully loaded 20000 samples for Clock
Loading Dog data...
Successfully loaded 20000 samples for Dog
Loading Door data...
Successfully loaded 20000 samples for Door
Loading Eye data...
Successfully loaded 20000 samples for Eye
Loading Fish data...
Successfully loaded 20000 samples for Fish
Loading Flower data...
Successfully loaded 20000 samples for Flower
Loading Fork data...
Successfully loaded 20000 samples for Fork
Loading House data...
Successfully loaded 20000 samples for House
Loadin

In [None]:
# Rebuild the custom ResNet used by main() and load the best checkpoint weights.
model = create_custom_resnet(input_shape=(28, 28, 1), num_classes=NUM_CLASSES, depth=20)

# Compile with the same settings as main() so the exported file carries the
# same training configuration.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)
model.load_weights(os.path.join(checkpoint_dir, 'model_weights.weights.h5'))

print("\nEvaluating the model...")
evaluate_model(model, X_test, y_test, categories)

model_dir = './drive/MyDrive/Model'
os.makedirs(model_dir, exist_ok=True)

model_path = os.path.join(model_dir, 'quickdraw_resnet_model_20_000.h5')
model.save(model_path)
print(f"\nModel saved as '{model_path}'")

# Label index -> category name (json turns the int keys into strings)
mapping_path = os.path.join(model_dir, 'class_mapping.json')
with open(mapping_path, 'w') as f:
    json.dump({i: category for i, category in enumerate(categories)}, f)
print(f"Class mapping saved as '{mapping_path}'")

  saveable.load_own_variables(weights_store.get(inner_path))



Evaluating the model...
[1m2813/2813[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 4ms/step

Test Accuracy: 93.40%

Classification Report:
                precision    recall  f1-score   support

      Airplane       0.93      0.90      0.92      2970
         Apple       0.98      0.97      0.97      3009
       Bicycle       0.95      0.98      0.97      2941
          Book       0.94      0.96      0.95      3061
           Car       0.96      0.95      0.96      3070
           Cat       0.86      0.87      0.86      3007
         Chair       0.95      0.94      0.94      2941
         Clock       0.96      0.94      0.95      2946
           Dog       0.80      0.84      0.82      3011
          Door       0.97      0.94      0.96      3042
           Eye       0.94      0.95      0.94      2950
          Fish       0.95      0.94      0.95      2960
        Flower       0.93      0.94      0.93      3048
          Fork       0.95      0.94      0.95      3016
         




Model saved as 'quickdraw_resnet_model.h5'
Class mapping saved as 'class_mapping.json'
