In [2]:
# Install MLflow for experiment tracking
!pip install mlflow -q

In [3]:
import os
import gc
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from PIL import Image
import mlflow
import mlflow.tensorflow

In [4]:
# 1. FORCE CLEAR HANGING MEMORY
# Reset Keras' global state left over from earlier executions in this kernel,
# then run a garbage-collection pass to reclaim whatever that released.
keras.backend.clear_session()
gc.collect()

# 2. SAFE CONFIGURATION
# Settings shared by the cells below.
data_dir = '/content/PetImages'  # dataset root: scanned by the cleanup step and read by the dataset loaders
IMG_SIZE = (224, 224)            # passed as image_size when the datasets are built
BATCH_SIZE = 16                  # passed as batch_size when the datasets are built
EPOCHS = 10                      # number of epochs passed to model.fit

In [5]:
# 3. AGGRESSIVE DEEP-CLEAN OF IMAGES
# Walks data_dir and permanently deletes every file that PIL cannot fully
# decode, or that decodes to a mode other than RGB. This is destructive:
# rejected files are removed from disk, not moved aside.
if not os.path.isdir(data_dir):
    # os.walk() yields nothing for a missing directory, so without this check a
    # wrong path would be reported as a clean scan with 0 deletions.
    raise FileNotFoundError(f"Dataset directory not found: {data_dir}")

print("Starting deep image scan... (This will take a minute)")
bad_count = 0
reject_reasons = {}  # reason text -> number of files deleted for that reason
undeletable = []     # (path, error) for rejected files that could not be removed
for root, dirs, files in os.walk(data_dir):
    for file in files:
        file_path = os.path.join(root, file)
        reason = None
        try:
            # Force PIL to load every single pixel to catch deep corruption
            with Image.open(file_path) as img:
                img.load()
                # Check if image is standard RGB (drops Grayscale, CMYK, etc.)
                if img.mode != 'RGB':
                    reason = f"non-RGB mode {img.mode}"
        except Exception as exc:
            # Deliberately broad: any failure to open or decode the file marks
            # it as unusable, whatever exception type PIL happens to raise.
            reason = f"unreadable ({type(exc).__name__})"

        if reason is None:
            continue

        # Delete outside the decode try/except (the file handle is closed by
        # now) so that a failed removal is recorded instead of aborting the
        # whole scan from inside an exception handler.
        try:
            os.remove(file_path)
        except OSError as exc:
            undeletable.append((file_path, exc))
            continue
        bad_count += 1
        reject_reasons[reason] = reject_reasons.get(reason, 0) + 1

print(f"Cleanup complete! Deleted {bad_count} corrupted/incompatible images.")
# Short audit trail of why files were removed, grouped by reason.
for reason_text, reason_count in sorted(reject_reasons.items()):
    print(f"  {reason_count:>6}  {reason_text}")
if undeletable:
    # These files were rejected but are still on disk; later cells will still
    # see them.
    print(f"WARNING: {len(undeletable)} rejected files could not be deleted:")
    for stuck_path, stuck_error in undeletable[:10]:
        print(f"  {stuck_path}: {stuck_error}")

Starting deep image scan... (This will take a minute)




Cleanup complete! Deleted 0 corrupted/incompatible images.


In [6]:
# 4. BUILD DATASETS
print("\nBuilding datasets...")

# The training and validation subsets must be requested with the same
# validation_split and seed to partition the files consistently; sharing one
# kwargs dict keeps the two calls from drifting apart.
split_kwargs = dict(
    validation_split=0.2, seed=42,
    image_size=IMG_SIZE, batch_size=BATCH_SIZE,
)
train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir, subset="training", **split_kwargs
)
val_ds_full = tf.keras.utils.image_dataset_from_directory(
    data_dir, subset="validation", **split_kwargs
)

# Split the held-out batches in two: the first half becomes the test set, the
# remainder stays as the validation set.
# NOTE(review): val_ds_full is created with the loader's default shuffle
# setting. If it reorders elements on each pass, the membership of the
# take()/skip() halves can differ between iterations -- confirm against the
# image_dataset_from_directory documentation if strictly disjoint test and
# validation sets are required.
val_batches = tf.data.experimental.cardinality(val_ds_full)
test_ds = val_ds_full.take(val_batches // 2)
val_ds = val_ds_full.skip(val_batches // 2)

# train_ds is already batched, so shuffle() buffers whole batches. Assuming
# float32 images (the loader's documented output), one 16 x 224 x 224 x 3 batch
# is about 9.6 MB: the previous buffer of 500 batches could hold ~4.8 GB, which
# defeats the RAM cap this cell is aiming for. 64 batches is roughly 0.6 GB.
SHUFFLE_BUFFER_BATCHES = 64

# Use manual prefetch size instead of AUTOTUNE to cap RAM usage
train_ds = train_ds.shuffle(SHUFFLE_BUFFER_BATCHES).prefetch(buffer_size=2)
val_ds = val_ds.prefetch(buffer_size=2)
test_ds = test_ds.prefetch(buffer_size=2)

# 5. BUILD MODEL
# Random horizontal flips and small rotations, applied as part of the model.
data_augmentation = keras.Sequential([
  layers.RandomFlip("horizontal"),
  layers.RandomRotation(0.1),
])

# Baseline CNN: rescale pixels by 1/255, augment, three Conv2D/MaxPooling2D
# stages (32 -> 64 -> 128 filters), then a 128-unit dense layer and a single
# sigmoid output for the two-class problem.
model = models.Sequential([
    # Declare the input with an explicit Input object instead of passing
    # input_shape= to the first layer (which Keras warns about), and derive the
    # shape from IMG_SIZE so it cannot disagree with the dataset image size.
    keras.Input(shape=(*IMG_SIZE, 3)),
    layers.Rescaling(1./255),
    data_augmentation,
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(1, activation='sigmoid')
])

# Binary cross-entropy pairs with the single sigmoid output; accuracy is the
# only extra metric, so model.evaluate() returns [loss, accuracy].
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])


Building datasets...
Found 24979 files belonging to 2 classes.
Using 19984 files for training.
Found 24979 files belonging to 2 classes.
Using 4995 files for validation.


  super().__init__(**kwargs)


In [7]:
# 6. TRAIN WITH MLFLOW
print("\nStarting Training Pipeline...")

# Store runs in a local ./mlruns directory under one named experiment.
mlflow.set_tracking_uri("file:./mlruns")
mlflow.set_experiment("Cat_Dog_Classification")

# A run left open by an earlier, interrupted execution would still be active;
# close it before starting a new one.
stale_run = mlflow.active_run()
if stale_run:
    mlflow.end_run()

with mlflow.start_run(run_name="Baseline_CNN_Safe"):
    # Turn on MLflow's TensorFlow autologging for the fit() call below.
    mlflow.tensorflow.autolog()

    # Fit on the training batches, validating on val_ds after every epoch.
    history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)

    # Score the held-out test batches and record both numbers on the run.
    loss, accuracy = model.evaluate(test_ds)
    for metric_name, metric_value in (("test_accuracy", accuracy), ("test_loss", loss)):
        mlflow.log_metric(metric_name, metric_value)

    # Write the trained model to the working directory as an HDF5 file.
    model.save("baseline_cnn.h5")
    print("\nTraining Complete & Model Saved!")


Starting Training Pipeline...


Epoch 1/10
[1m1248/1249[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 37ms/step - accuracy: 0.5687 - loss: 0.6846



[1m1249/1249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m81s[0m 54ms/step - accuracy: 0.5688 - loss: 0.6846 - val_accuracy: 0.6959 - val_loss: 0.5876
Epoch 2/10
[1m1248/1249[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 37ms/step - accuracy: 0.7053 - loss: 0.5648



[1m1249/1249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m86s[0m 61ms/step - accuracy: 0.7053 - loss: 0.5648 - val_accuracy: 0.7539 - val_loss: 0.5150
Epoch 3/10
[1m1249/1249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.7505 - loss: 0.5062



[1m1249/1249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m82s[0m 55ms/step - accuracy: 0.7505 - loss: 0.5062 - val_accuracy: 0.7787 - val_loss: 0.4619
Epoch 4/10
[1m1248/1249[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 37ms/step - accuracy: 0.7835 - loss: 0.4627



[1m1249/1249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m81s[0m 55ms/step - accuracy: 0.7835 - loss: 0.4627 - val_accuracy: 0.8059 - val_loss: 0.4218
Epoch 5/10
[1m1248/1249[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 37ms/step - accuracy: 0.7991 - loss: 0.4289



[1m1249/1249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m89s[0m 61ms/step - accuracy: 0.7991 - loss: 0.4289 - val_accuracy: 0.8147 - val_loss: 0.4050
Epoch 6/10
[1m1249/1249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m68s[0m 44ms/step - accuracy: 0.8190 - loss: 0.3974 - val_accuracy: 0.8119 - val_loss: 0.4085
Epoch 7/10
[1m1249/1249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.8323 - loss: 0.3693



[1m1249/1249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m88s[0m 60ms/step - accuracy: 0.8323 - loss: 0.3693 - val_accuracy: 0.8235 - val_loss: 0.4023
Epoch 8/10
[1m1248/1249[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 37ms/step - accuracy: 0.8439 - loss: 0.3529



[1m1249/1249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m84s[0m 56ms/step - accuracy: 0.8439 - loss: 0.3529 - val_accuracy: 0.8303 - val_loss: 0.3746
Epoch 9/10
[1m1248/1249[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 37ms/step - accuracy: 0.8446 - loss: 0.3514



[1m1249/1249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 53ms/step - accuracy: 0.8446 - loss: 0.3514 - val_accuracy: 0.8355 - val_loss: 0.3662
Epoch 10/10
[1m1249/1249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m68s[0m 44ms/step - accuracy: 0.8577 - loss: 0.3232 - val_accuracy: 0.8351 - val_loss: 0.3793
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 574ms/step




[1m156/156[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 41ms/step - accuracy: 0.8467 - loss: 0.3408





Training Complete & Model Saved!


In [8]:
# Zip the mlruns folder and the model
!zip -r experiment_data.zip mlruns baseline_cnn.h5

from google.colab import files
files.download('experiment_data.zip')

  adding: mlruns/ (stored 0%)
  adding: mlruns/561519266848247737/ (stored 0%)
  adding: mlruns/561519266848247737/2840f751ec4b4a76b1b26be2d3114365/ (stored 0%)
  adding: mlruns/561519266848247737/2840f751ec4b4a76b1b26be2d3114365/meta.yaml (deflated 39%)
  adding: mlruns/561519266848247737/2840f751ec4b4a76b1b26be2d3114365/metrics/ (stored 0%)
  adding: mlruns/561519266848247737/2840f751ec4b4a76b1b26be2d3114365/tags/ (stored 0%)
  adding: mlruns/561519266848247737/2840f751ec4b4a76b1b26be2d3114365/tags/mlflow.source.type (stored 0%)
  adding: mlruns/561519266848247737/2840f751ec4b4a76b1b26be2d3114365/tags/mlflow.runName (stored 0%)
  adding: mlruns/561519266848247737/2840f751ec4b4a76b1b26be2d3114365/tags/mlflow.source.name (stored 0%)
  adding: mlruns/561519266848247737/2840f751ec4b4a76b1b26be2d3114365/tags/mlflow.user (stored 0%)
  adding: mlruns/561519266848247737/2840f751ec4b4a76b1b26be2d3114365/artifacts/ (stored 0%)
  adding: mlruns/561519266848247737/2840f751ec4b4a76b1b26be2d311436

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>