<a href="https://colab.research.google.com/github/NobodydeBunny/Cat_Dog_breed_classifire_AI_model/blob/main/Cats_breed/Cats_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

In [None]:
# Load the Oxford-IIIT Pet dataset together with its metadata object.
# as_supervised=True makes every element an (image, label) pair.
dataset, info = tfds.load(
    name="oxford_iiit_pet",
    as_supervised=True,
    with_info=True,
)

# Handles to the two splits, plus the label vocabulary (index -> breed name).
train_ds, test_ds = dataset['train'], dataset['test']
class_names = info.features['label'].names

In [None]:
# Original label ids of the classes treated as cat breeds: every class whose
# name starts with an upper-case letter.
# NOTE(review): this assumes the dataset spells cat breeds capitalised and dog
# breeds lower-case -- confirm against `class_names`.
cat_labels = tf.constant(
    [label_id for label_id, breed in enumerate(class_names) if breed[0].isupper()],
    dtype=tf.int64,
)

def is_cat(image, label):
    """Return a boolean tensor that is True when `label` is one of `cat_labels`.

    `image` is not inspected; it is only part of the signature so the function
    can be passed to `Dataset.filter` on (image, label) elements.
    """
    matches = tf.math.equal(label, cat_labels)
    return tf.math.reduce_any(matches)

# Keep only the cat examples in both splits.
train_ds, test_ds = train_ds.filter(is_cat), test_ds.filter(is_cat)

# Breed names in the same order as `cat_labels`, and how many there are.
cat_class_names = [class_names[label_id] for label_id in cat_labels.numpy().tolist()]
NUM_CLASSES = len(cat_class_names)

In [None]:
# Map every original cat label id to its position in `cat_labels`, which is
# also its index into `cat_class_names`.
label_map = dict(zip(cat_labels.numpy(), range(cat_labels.shape[0])))

# Preview the first six training images in a 2x3 grid, titled with the breed.
fig = plt.figure(figsize=(10, 6))

for position, (image, label) in enumerate(train_ds.take(6), start=1):
    ax = fig.add_subplot(2, 3, position)
    ax.imshow(image)
    ax.set_title(cat_class_names[label_map[label.numpy()]])
    ax.axis('off')

plt.show()

In [None]:
# Side length (in pixels) that images are resized to.
IMG_SIZE = 224

# Graph-compatible lookup table built from `label_map` (which must already
# exist): original cat label id -> 0-based class index.
label_table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        tf.constant(list(label_map), dtype=tf.int64),
        tf.constant(list(label_map.values()), dtype=tf.int64),
    ),
    default_value=-1,  # returned for unknown keys; not expected for cat-only data
)

def preprocess(image, label):
    """Turn one (image, label) example into model-ready tensors.

    The image is cast to float32, resized to IMG_SIZE x IMG_SIZE and divided
    by 255 (which maps 8-bit pixel values into [0, 1]). The label is translated
    through `label_table` from its original id to the 0-based class index.

    Returns:
        (image, remapped_label)
    """
    # Original label id -> sequential class index.
    remapped_label = label_table.lookup(label)

    # Cast, resize, then scale the pixel values.
    resized = tf.image.resize(tf.cast(image, tf.float32), (IMG_SIZE, IMG_SIZE))
    scaled = resized / 255.0

    return scaled, remapped_label

# Pipeline settings.
BATCH_SIZE = 32
SHUFFLE_BUFFER = 1000

# Build both pipelines from the raw splits instead of from the current
# `train_ds` / `test_ds`. On a fresh run this is exactly the filtered dataset
# those names hold at this point, but it keeps the cell idempotent: re-running
# it no longer preprocesses already-preprocessed data (which would send the
# remapped labels through `label_table` a second time and turn them into -1).
train_ds = (
    dataset['train']
    .filter(is_cat)                                        # cat breeds only
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)  # resize / scale / remap labels
    .shuffle(SHUFFLE_BUFFER)                               # shuffle individual examples
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)                            # overlap input work with training
)

# Same steps for the test split, without shuffling so its order stays stable.
test_ds = (
    dataset['test']
    .filter(is_cat)
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Count the training examples per remapped class index. The dataset is
# un-batched so that each element carries a single scalar label.
counter = Counter(
    label.numpy().item()  # .item() converts the NumPy scalar to a plain Python int
    for _, label in train_ds.unbatch()
)

print("Class Distribution:")
for class_idx, breed in enumerate(cat_class_names[:NUM_CLASSES]):
    print(breed, ":", counter[class_idx])

In [None]:
# Load pre-trained MobileNetV2 without its ImageNet classification head.
# The input size follows IMG_SIZE so it always matches the preprocessing.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
    include_top=False,       # Exclude the default classifier
    weights='imagenet'       # Use ImageNet pre-trained weights
)

# Freeze the convolutional base to retain learned features
base_model.trainable = False

# Build the classifier: rescaling -> frozen backbone -> small dense head.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3)),
    # The input pipeline delivers pixels in [0, 1], but MobileNetV2's ImageNet
    # weights expect inputs in [-1, 1]. Doing the conversion inside the model
    # means training, evaluation and the saved model all stay consistent.
    tf.keras.layers.Rescaling(scale=2.0, offset=-1.0),
    base_model,                                      # Pre-trained feature extractor
    tf.keras.layers.GlobalAveragePooling2D(),       # Pool feature maps into a single vector
    tf.keras.layers.Dense(512, activation='relu'),  # Fully connected layer on top of the pooled features
    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')  # Output layer: one probability per breed
])

# Compile the model. Labels are integer class indices, hence the sparse loss.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),  # Small learning rate for transfer learning
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

In [None]:
# Compile `model` again with Adam (learning rate 1e-4), sparse categorical
# cross-entropy and accuracy. These are the same settings the model-building
# cell already compiles with, so this cell is redundant on a top-to-bottom run.
model.compile(
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
)

In [None]:
# Train for 20 epochs so the new top layers have enough time to learn.
# The test split is passed as validation data, so the per-epoch validation
# metrics are computed on it.
history = model.fit(
    x=train_ds,
    epochs=20,
    validation_data=test_ds,
)

In [None]:
# Colab-specific step: the `google.colab` package is imported here rather than
# in the import cell at the top of the notebook.
from google.colab import drive
# Mount Google Drive at /content/drive; the save cell writes below this path.
drive.mount('/content/drive')

In [None]:
# Save the trained model below the mounted Drive path; the `.keras`
# extension selects the file format.
model_path = "/content/drive/MyDrive/Colab Notebooks/cat_breed_model.keras"
model.save(model_path)
print("Model saved successfully.")

In [None]:
# Training curves: one panel per metric, each showing the train and
# validation series recorded in `history`.
plt.figure(figsize=(12, 6))

for position, (metric, title) in enumerate(
    [('accuracy', "Accuracy"), ('loss', "Loss")], start=1
):
    plt.subplot(1, 2, position)
    plt.plot(history.history[metric], label='Train')
    plt.plot(history.history['val_' + metric], label='Validation')
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(title)
    plt.legend()

plt.show()

In [None]:
# Evaluate the trained model on the batched test pipeline. The result unpacks
# into the loss followed by the single compiled metric (accuracy).
test_loss, test_acc = model.evaluate(test_ds)
# Report both values with four decimal places.
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")

In [None]:
# Collect ground-truth and predicted class indices over the whole test split.
y_true = []
y_pred = []

for images, labels in test_ds:
    # `predict_on_batch` runs the model once on this already-formed batch.
    # `model.predict` is meant for whole datasets and adds per-call overhead
    # when it is invoked inside a Python loop.
    preds = model.predict_on_batch(images)
    y_true.extend(labels.numpy())
    y_pred.extend(np.argmax(preds, axis=1))  # most probable class per image

# Passing `labels` explicitly keeps the report rows aligned with
# `cat_class_names` even if a class happens to be absent from the data.
print(classification_report(
    y_true,
    y_pred,
    labels=list(range(NUM_CLASSES)),
    target_names=cat_class_names,
))

In [None]:
# Confusion matrix of the collected test predictions (rows: actual class,
# columns: predicted class), drawn as an annotated heatmap.
cm = confusion_matrix(y_true, y_pred)

fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=cat_class_names,
    yticklabels=cat_class_names,
    ax=ax,
)
ax.set_title("Confusion Matrix")
ax.set_xlabel("Predicted")
ax.set_ylabel("Actual")
plt.show()

In [None]:
def predict_and_display(index):
    """Show one test image next to the model's class probabilities for it.

    Args:
        index: Position of the example in the un-batched test set. Negative
            values count from the end, as with list indexing.

    Raises:
        IndexError: If `index` is outside the test set.
    """
    examples = test_ds.unbatch()

    if index >= 0:
        # Stream to the requested example instead of loading every test image
        # into a Python list just to pick one of them.
        selected = list(examples.skip(index).take(1))
        if not selected:
            raise IndexError(f"test set index {index} is out of range")
        image, label = selected[0]
    else:
        # A negative index needs the total length, so the whole split has to
        # be materialised in this case.
        image, label = list(examples)[index]

    # The model expects a batch dimension, so wrap the single image.
    input_img = tf.expand_dims(image, axis=0)

    probs = model.predict(input_img, verbose=0)[0]
    pred_class = np.argmax(probs)

    plt.figure(figsize=(12,5))

    # Left panel: the image with its true breed.
    plt.subplot(1,2,1)
    plt.imshow(image)
    plt.title(f"True: {cat_class_names[label.numpy()]}")
    plt.axis('off')

    # Right panel: predicted probability per breed, in percent.
    plt.subplot(1,2,2)
    plt.barh(cat_class_names, probs*100)
    plt.xlim(0, 100)
    plt.title(f"Predicted: {cat_class_names[pred_class]}")
    plt.show()

In [None]:
# Show the prediction for the test example at position 12.
predict_and_display(index=12)