<a href="https://colab.research.google.com/github/axel-sirota/cnn-visual-mastery/blob/main/Module2Demo1_BasicCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Creating a Basic CNN model



In this demo we build a basic convolutional neural network (CNN) to solve the Cats vs Dogs problem: classifying each image as either a cat or a dog.

In [None]:
%pip install tensorflow-datasets


In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

# Load the Cats vs Dogs dataset
(train_ds, validation_ds), ds_info = tfds.load(
    'cats_vs_dogs',
    split=['train[:80%]', 'train[80%:]'],
    as_supervised=True,  # Yield (image, label) tuples instead of feature dicts
    with_info=True,
)


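First, we preprocess the data. The images in this dataset come in varying sizes, so we resize each one to a fixed 64x64 and scale its pixel values from [0, 255] to [0, 1]. We do this with the `tf.data` pipeline methods (`map`, `batch`, `prefetch`) on the datasets returned by TensorFlow Datasets.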

In [None]:

# Preprocess the data
def preprocess(image, label):
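    image = tf.image.resize(image, [64, 64])  # Resize to a fixed 64x64 (bilinear by default; aspect ratio is not preserved)
    image = tf.cast(image, tf.float32) / 255.0  # Scale pixel values from [0, 255] to [0, 1]
    return image, label

# Shuffle the training examples so batches are not drawn in the same fixed order every epoch
train_ds = train_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE).shuffle(1000).batch(32).prefetch(tf.data.AUTOTUNE)
# Validation order does not affect the metrics, so no shuffling here
validation_ds = validation_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE).batch(32).prefetch(tf.data.AUTOTUNE)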
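To make the arithmetic of the normalization and batching steps concrete without TensorFlow, here is a minimal NumPy sketch. It assumes the images have already been resized to 64x64 (the resize itself is left to `tf.image.resize`), uses random hypothetical pixel data, and `normalize` and `batch` are illustrative helper names, not TensorFlow APIs.

```python
import numpy as np

def normalize(image_uint8):
    # Same arithmetic as tf.cast(image, tf.float32) / 255.0 in preprocess()
    return image_uint8.astype(np.float32) / 255.0

def batch(images, batch_size):
    # Group consecutive examples like Dataset.batch(); the last batch may be smaller
    return [np.stack(images[i:i + batch_size]) for i in range(0, len(images), batch_size)]

rng = np.random.default_rng(0)
# Hypothetical data: 70 already-resized 64x64 RGB images with uint8 pixels in [0, 255]
images = [rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8) for _ in range(70)]

batches = [normalize(b) for b in batch(images, 32)]
print([b.shape for b in batches])
# [(32, 64, 64, 3), (32, 64, 64, 3), (6, 64, 64, 3)]
print(min(b.min() for b in batches) >= 0.0, max(b.max() for b in batches) <= 1.0)
```

As with `Dataset.batch` and its default `drop_remainder=False`, the final batch holds the 6 leftover examples rather than being dropped.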


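Now the model. We use two rounds of convolution followed by max pooling, then flatten the resulting feature maps and feed them into a small feed-forward network (FFN) that ends in a single sigmoid unit for the binary cat/dog prediction.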
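The layer shapes and parameter counts of this architecture can be worked out by hand. The sketch below assumes the Keras defaults the code relies on: `'valid'` padding and stride 1 for `Conv2D`, and `pool_size=2, strides=2` with `'valid'` padding for `MaxPooling2D(2, 2)`. The helper function names are hypothetical.

```python
def conv2d_valid(size, kernel=3):
    # 'valid' padding, stride 1: the output shrinks by kernel - 1 pixels per dimension
    return size - kernel + 1

def maxpool(size, pool=2):
    # pool_size=2, strides=2, 'valid' padding: integer division drops an odd leftover row/column
    return (size - pool) // pool + 1

def conv_params(in_ch, out_ch, kernel=3):
    # One kernel x kernel x in_ch weight tensor per filter, plus one bias per filter
    return kernel * kernel * in_ch * out_ch + out_ch

def dense_params(in_units, out_units):
    # Full weight matrix plus one bias per output unit
    return in_units * out_units + out_units

size = 64
size = maxpool(conv2d_valid(size))  # 64 -> 62 -> 31
size = maxpool(conv2d_valid(size))  # 31 -> 29 -> 14
flat = size * size * 64             # Flatten: 14 * 14 * 64 features

params = {
    "conv1": conv_params(3, 32),
    "conv2": conv_params(32, 64),
    "dense128": dense_params(flat, 128),
    "output": dense_params(128, 1),
}
total = sum(params.values())  # MaxPooling2D and Flatten have no parameters
print(flat, params, total)
# 12544 {'conv1': 896, 'conv2': 18496, 'dense128': 1605760, 'output': 129} 1625281
```

These totals should agree with `model.summary()`. The first `Dense` layer holds about 98.8% of all weights because it connects every one of the 12,544 flattened features to each of its 128 units.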

In [None]:

# Build the CNN model
model = Sequential([
    tf.keras.Input(shape=(64, 64, 3)),  # 64x64 RGB images produced by preprocess()
    Conv2D(32, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(1, activation='sigmoid')  # Binary classification (cat or dog)
])

# Compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Train the model
history = model.fit(train_ds, validation_data=validation_ds, epochs=10)

# Evaluate on the validation split (cats_vs_dogs ships only a 'train' split, so there is no separate test set here)
val_loss, val_accuracy = model.evaluate(validation_ds)
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
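The `compile` call asks Keras for binary cross-entropy as the loss and `'accuracy'` as the metric. For a single sigmoid output, Keras interprets `'accuracy'` as binary accuracy with a 0.5 threshold. The NumPy sketch below shows that arithmetic on a hypothetical batch. The labels and probabilities are made up, and the clipping epsilon of `1e-7` is an assumption meant to mirror the small epsilon Keras uses to avoid `log(0)`.

```python
import numpy as np

EPS = 1e-7  # assumed clipping epsilon; keeps log() finite when a probability is exactly 0 or 1

def binary_crossentropy(y_true, y_prob):
    # Mean of -[y * log(p) + (1 - y) * log(1 - p)] over the batch
    p = np.clip(y_prob, EPS, 1 - EPS)
    return float(np.mean(-(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))))

def binary_accuracy(y_true, y_prob, threshold=0.5):
    # A prediction counts as class 1 when the sigmoid output exceeds the threshold
    predictions = (y_prob > threshold).astype(np.float32)
    return float(np.mean(predictions == y_true))

y_true = np.array([0, 1, 1, 0], dtype=np.float32)          # hypothetical labels (0 = cat, 1 = dog)
y_prob = np.array([0.1, 0.8, 0.4, 0.3], dtype=np.float32)  # hypothetical sigmoid outputs
print(binary_crossentropy(y_true, y_prob), binary_accuracy(y_true, y_prob))
```

Here the third example (a dog predicted at 0.4) is the only mistake, so the accuracy is 0.75. It also contributes the largest loss term, `-log(0.4)`, which is why cross-entropy keeps penalizing confident wrong predictions that accuracy alone only counts once.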
