In [1]:
# Step 1: Import libraries and suppress warnings
import os
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
tf.get_logger().setLevel('ERROR')

from tensorflow import keras
from keras import layers, models, losses
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

In [2]:
# Step 2: Configuration
BASE_DIR = '/home/hamid/ML/Datasets/cat-vs-dog/split_data/'
IMG_HEIGHT = 180
IMG_WIDTH = 180
BATCH_SIZE = 32
EPOCHS = 40

print("="*60)
print("Cat vs Dog Classification with AlexNet")
print("="*60)

# =============================================================================
# NEW SECTION: Validate and Clean Images (ADD THIS BEFORE TRAINING)
# =============================================================================

def validate_and_clean_images(base_dir):
    """Validate and remove corrupted images from all datasets"""
    
    datasets = ['train', 'val', 'test']
    categories = ['Cat', 'Dog']
    
    total_removed = 0
    
    for dataset in datasets:
        for category in categories:
            data_dir = os.path.join(base_dir, dataset, category)
            
            if not os.path.exists(data_dir):
                print(f"Directory not found: {data_dir}")
                continue
            
            print(f'\nValidating {dataset}/{category} files...')
            
            # Get all image filenames
            image_filenames = [os.path.join(data_dir, filename) 
                             for filename in os.listdir(data_dir)
                             if filename.lower().endswith(('.jpg', '.jpeg', '.png'))]
            
            removed_count = 0
            
            for image_path in image_filenames:
                try:
                    # Load image to check format
                    img = tf.keras.utils.load_img(image_path)
                    
                    # Check if it's JPEG format
                    if img.format != 'JPEG' and img.format != 'jpg':
                        print(f'Not jpeg. Removing... {img.format} {image_path}')
                        os.remove(image_path)
                        removed_count += 1
                        continue
                    
                    # Read image to check shape
                    img_array = mpimg.imread(image_path)
                    
                    # Check if image has valid shape
                    if len(img_array.shape) < 3:
                        print(f'Removing... Invalid shape {img_array.shape} {image_path}')
                        os.remove(image_path)
                        removed_count += 1
                        continue
                    
                    # Check number of channels
                    if img_array.shape[2] < 1 or img_array.shape[2] > 4 or img_array.shape[2] == 2:
                        print(f'Removing... Invalid channels {img_array.shape} {image_path}')
                        os.remove(image_path)
                        removed_count += 1
                        continue
                        
                except Exception as e:
                    print(f'Error with {image_path}: {e}')
                    print(f'Removing corrupted file: {image_path}')
                    try:
                        os.remove(image_path)
                        removed_count += 1
                    except:
                        pass
            
            # Count remaining files
            remaining = len([f for f in os.listdir(data_dir) 
                           if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
            
            print(f'✓ {dataset}/{category}: {remaining} images remaining, {removed_count} removed')
            total_removed += removed_count
    
    print(f'\n{"="*60}')
    print(f'Validation Complete! Total removed: {total_removed} images')
    print(f'{"="*60}\n')

# Run validation BEFORE creating datasets
print("\nStep 1: Validating and cleaning images...")
validate_and_clean_images(BASE_DIR)


Cat vs Dog Classification with AlexNet

Step 1: Validating and cleaning images...

Validating train/Cat files...
✓ train/Cat: 8685 images remaining, 0 removed

Validating train/Dog files...
✓ train/Dog: 8648 images remaining, 0 removed

Validating val/Cat files...
✓ val/Cat: 2470 images remaining, 0 removed

Validating val/Dog files...
✓ val/Dog: 2474 images remaining, 0 removed

Validating test/Cat files...
✓ test/Cat: 1247 images remaining, 0 removed

Validating test/Dog files...
✓ test/Dog: 1237 images remaining, 0 removed

Validation Complete! Total removed: 0 images



In [3]:
# =============================================================================
# Step 3: Create datasets (AFTER validation)
# =============================================================================

def create_dataset(directory, img_size=(IMG_HEIGHT, IMG_WIDTH), batch_size=BATCH_SIZE, shuffle=True):
    """Create a tf.data dataset with proper error handling"""
    
    dataset = tf.keras.utils.image_dataset_from_directory(
        directory,
        image_size=img_size,
        batch_size=batch_size,
        label_mode='int',
        shuffle=shuffle
    )
    
    # Normalize pixel values to [0, 1]
    dataset = dataset.map(lambda x, y: (x/255.0, y))
    
    # Skip any remaining corrupted images
    dataset = dataset.apply(tf.data.experimental.ignore_errors())
    
    # Prefetch for performance
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    
    return dataset

# Create datasets
print("\nStep 2: Loading datasets...")
train_dataset = create_dataset(os.path.join(BASE_DIR, 'train'), shuffle=True)
val_dataset = create_dataset(os.path.join(BASE_DIR, 'val'), shuffle=False)
test_dataset = create_dataset(os.path.join(BASE_DIR, 'test'), shuffle=False)
print("✓ Datasets loaded successfully!\n")



Step 2: Loading datasets...
Found 17333 files belonging to 2 classes.


I0000 00:00:1764946941.301434   11054 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 5603 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4060, pci bus id: 0000:01:00.0, compute capability: 8.9


Found 4944 files belonging to 2 classes.
Found 2484 files belonging to 2 classes.
✓ Datasets loaded successfully!



In [4]:
img_height = 180
img_width = 180
model = models.Sequential()
model.add(layers.Input(shape=(img_height, img_width, 3)))
model.add(layers.Rescaling(1./255))
model.add(layers.Conv2D(96, 11, strides=4, padding='same'))
model.add(layers.Lambda(tf.nn.local_response_normalization))
model.add(layers.Activation('relu'))
model.add(layers.MaxPooling2D(3, strides=2))
model.add(layers.Conv2D(256, 5, strides=4, padding='same'))
model.add(layers.Lambda(tf.nn.local_response_normalization))
model.add(layers.Activation('relu'))
model.add(layers.MaxPooling2D(3, strides=2))
model.add(layers.Conv2D(384, 3, strides=4, padding='same'))
model.add(layers.Activation('relu'))
model.add(layers.Conv2D(384, 3, strides=4, padding='same'))
model.add(layers.Activation('relu'))
model.add(layers.Conv2D(256, 3, strides=4, padding='same'))
model.add(layers.Activation('relu'))
model.add(layers.Flatten())
model.add(layers.Dense(4096, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(4096, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(2, activation='softmax'))

model.summary()

In [5]:
model.compile(
    optimizer='adam',
    loss=losses.sparse_categorical_crossentropy,
    metrics=['accuracy']
    )
history = model.fit(
    train_dataset,
    epochs=EPOCHS,
    validation_data=val_dataset,
    verbose=1
    )

Epoch 1/40
     11/Unknown [1m6s[0m 16ms/step - accuracy: 0.5066 - loss: 0.6949

I0000 00:00:1764946948.003580   11142 device_compiler.h:196] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


    542/Unknown [1m18s[0m 23ms/step - accuracy: 0.4956 - loss: 0.6948



[1m542/542[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m22s[0m 30ms/step - accuracy: 0.4922 - loss: 0.6941 - val_accuracy: 0.5004 - val_loss: 0.6932
Epoch 2/40
[1m542/542[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - accuracy: 0.4990 - loss: 0.6934

KeyboardInterrupt: 

In [None]:
fig, axs = plt.subplots(2, 1, figsize=(15,15))
axs[0].plot(history.history['loss'])
axs[0].plot(history.history['val_loss'])
axs[0].title.set_text('Training Loss vs Validation Loss')
axs[0].set_xlabel('Epochs')
axs[0].set_ylabel('Loss')
axs[0].legend(['Train', 'Val'])
axs[1].plot(history.history['accuracy'])
axs[1].plot(history.history['val_accuracy'])
axs[1].title.set_text('Training Accuracy vs Validation Accuracy')
axs[1].set_xlabel('Epochs')
axs[1].set_ylabel('Accuracy')
axs[1].legend(['Train', 'Val'])

In [None]:
model.evaluate(test_ds)