In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import EfficientNetB0
from sklearn.model_selection import train_test_split
import numpy as np
import os
from pathlib import Path

# --- Training hyper-parameters -----------------------------------------------
IMG_SIZE = 224        # side length (pixels) every image is resized to
BATCH_SIZE = 32       # examples per batch for training and validation
EPOCHS = 50           # total epoch budget; train_model gives each phase EPOCHS // 2
LEARNING_RATE = 1e-3  # Adam learning rate for phase 1 (phase 2 uses a tenth)

# --- Dataset folders: edit these to match where the images live ---------------
# Image counts vary between copies of the data; the dataset builder prints how
# many files it finds in each folder.
REAL_ACNE_PATH = "data/real_acne"  # real acne photos (label 1)
AI_ACNE_PATH = "data/ai_acne"      # AI-generated acne images (label 1)
NO_ACNE_PATH = "data/no_acne"      # photos without acne (label 0)

class AcneDatasetBuilder:
    """Collect image paths for the acne classifier and split them into train/val.

    Three source folders are read: real acne photos, AI-generated acne images
    and photos without acne. Acne images (real + AI) get label 1, no-acne
    images get label 0.
    """

    # File suffixes (compared case-insensitively) that are treated as images.
    IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")

    def __init__(self, real_path, ai_path, no_acne_path, img_size=224):
        """Store the source folders and the square output image size.

        Args:
            real_path: Folder containing real acne images.
            ai_path: Folder containing AI-generated acne images.
            no_acne_path: Folder containing images without acne.
            img_size: Height and width (pixels) images are resized to.
        """
        self.real_path = real_path
        self.ai_path = ai_path
        self.no_acne_path = no_acne_path
        self.img_size = img_size

    def load_and_preprocess_image(self, path):
        """Read one image file and return it as a float32 tensor scaled to [0, 1].

        ``tf.io.decode_image`` detects the format (JPEG, PNG, ...) from the file
        header. A Python ``try/except`` cannot do this when the function is
        traced into a graph, because decode errors only surface at run time.
        ``expand_animations=False`` guarantees a 3-D (height, width, 3) result
        that ``tf.image.resize`` accepts.
        """
        raw = tf.io.read_file(path)
        img = tf.io.decode_image(raw, channels=3, expand_animations=False)
        img = tf.image.resize(img, [self.img_size, self.img_size])
        return tf.cast(img, tf.float32) / 255.0

    @classmethod
    def _list_images(cls, folder):
        """Return the image files directly inside *folder*, sorted by path.

        Sorting makes the order independent of the filesystem, so the seeded
        sampling and splitting in ``create_dataset`` select the same files on
        every machine. A folder that does not exist yields an empty list.
        """
        return sorted(
            p for p in Path(folder).glob("*")
            if p.is_file() and p.suffix.lower() in cls.IMAGE_SUFFIXES
        )

    def create_dataset(self):
        """Build class-balanced train/validation file lists.

        Every acne image (real + AI) is used. An equal number of no-acne images
        (all of them, if there are fewer) is drawn without replacement using
        NumPy's global random state. Real acne photos are split 70/30 on their
        own, so validation always contains real images. The AI + no-acne pool
        is split 80/20.

        Returns:
            ``(train_files, train_labels, val_files, val_labels)``: lists of
            path strings and matching int labels (1 = acne, 0 = no acne).

        Raises:
            ValueError: if the real-acne or no-acne folder contains no images.
        """
        real_acne_files = self._list_images(self.real_path)
        ai_acne_files = self._list_images(self.ai_path)
        no_acne_files = self._list_images(self.no_acne_path)

        print(f"Found {len(real_acne_files)} real acne images")
        print(f"Found {len(ai_acne_files)} AI acne images")
        print(f"Found {len(no_acne_files)} no acne images")

        # A mistyped path makes glob return nothing; fail here with the folder
        # name rather than deep inside train_test_split. AI images are optional:
        # without them both classes are still present.
        for folder, files in ((self.real_path, real_acne_files),
                              (self.no_acne_path, no_acne_files)):
            if not files:
                raise ValueError(
                    f"No .jpg/.jpeg/.png images found in {folder!s}; "
                    "check the dataset path configuration."
                )

        # Balance the classes by sampling no-acne indices (same draw as sampling
        # the path list directly, but avoids a NumPy object array of Paths).
        acne_files = real_acne_files + ai_acne_files
        n_negatives = min(len(acne_files), len(no_acne_files))
        picked = np.random.choice(len(no_acne_files), size=n_negatives, replace=False)
        no_acne_sample = [no_acne_files[i] for i in picked]

        all_files = acne_files + no_acne_sample
        labels = [1] * len(acne_files) + [0] * len(no_acne_sample)

        # Source of each file: 0 = real acne, 1 = AI acne, 2 = no acne.
        img_types = ([0] * len(real_acne_files)
                     + [1] * len(ai_acne_files)
                     + [2] * len(no_acne_sample))

        real_indices = [i for i, t in enumerate(img_types) if t == 0]
        other_indices = [i for i, t in enumerate(img_types) if t != 0]

        real_train_idx, real_val_idx = train_test_split(
            real_indices, test_size=0.3, random_state=42
        )
        other_train_idx, other_val_idx = train_test_split(
            other_indices, test_size=0.2, random_state=42
        )

        train_idx = real_train_idx + other_train_idx
        val_idx = real_val_idx + other_val_idx

        train_files = [str(all_files[i]) for i in train_idx]
        train_labels = [labels[i] for i in train_idx]
        val_files = [str(all_files[i]) for i in val_idx]
        val_labels = [labels[i] for i in val_idx]

        print("\nDataset split:")
        print(f"Training samples: {len(train_files)}")
        print(f"Validation samples: {len(val_files)}")
        print(f"Validation real images: {len(real_val_idx)}")

        return train_files, train_labels, val_files, val_labels

def create_tf_dataset(file_paths, labels, batch_size, augment=False, img_size=224):
    """Build a batched ``tf.data`` pipeline from image paths and labels.

    Args:
        file_paths: Sequence of image file paths (str).
        labels: Sequence of labels aligned with ``file_paths``.
        batch_size: Number of examples per batch.
        augment: When True (training), reshuffle the full file list every epoch
            and apply random flip / brightness / contrast / saturation.
        img_size: Square size images are resized to.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, labels)`` batches, with images
        as float32 in [0, 1].
    """
    def load_image(path, label):
        raw = tf.io.read_file(path)
        # decode_image detects the format (JPEG, PNG, ...) from the header. A
        # Python try/except cannot pick the decoder here, because dataset.map
        # traces this function once into a graph and decode errors happen later.
        img = tf.io.decode_image(raw, channels=3, expand_animations=False)
        img = tf.image.resize(img, [img_size, img_size])
        img = tf.cast(img, tf.float32) / 255.0
        return img, label

    def augment_image(img, label):
        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_brightness(img, 0.2)
        img = tf.image.random_contrast(img, 0.8, 1.2)
        img = tf.image.random_saturation(img, 0.8, 1.2)
        # Brightness/contrast shifts can push pixels outside [0, 1]; clip back
        # to the range the un-augmented validation images use.
        img = tf.clip_by_value(img, 0.0, 1.0)
        return img, label

    dataset = tf.data.Dataset.from_tensor_slices((file_paths, labels))

    if augment:
        # Shuffle the cheap (path, label) pairs before decoding. The buffer can
        # then hold the whole dataset (a full shuffle each epoch) instead of
        # keeping ~1000 decoded float images in memory.
        dataset = dataset.shuffle(max(len(file_paths), 1),
                                  reshuffle_each_iteration=True)

    dataset = dataset.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)

    if augment:
        dataset = dataset.map(augment_image, num_parallel_calls=tf.data.AUTOTUNE)

    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

def build_model(img_size=224):
    """Create the from-scratch CNN used for binary acne classification.

    Four convolutional stages use 32, 64, 128 and 256 filters. Each stage
    applies two (3x3 Conv2D + ReLU, BatchNorm) pairs, then 2x2 max-pooling and
    25% dropout. Global average pooling feeds a 512 -> 256 dense head, with
    batch norm and dropout (0.5, then 0.3), ending in one sigmoid unit.

    Args:
        img_size: Height and width of the RGB input images.

    Returns:
        ``(model, base_model)``. ``base_model`` is always ``None`` because no
        pretrained backbone is used; the pair is kept for callers that unpack it.
    """
    # Drop models/graphs left over from earlier calls (e.g. re-running a cell).
    tf.keras.backend.clear_session()

    stack = [layers.Input(shape=(img_size, img_size, 3))]

    # Convolutional feature extractor.
    for filters in (32, 64, 128, 256):
        for _ in range(2):
            stack.append(layers.Conv2D(filters, 3, activation='relu', padding='same'))
            stack.append(layers.BatchNormalization())
        stack.append(layers.MaxPooling2D(2))
        stack.append(layers.Dropout(0.25))

    # Classification head.
    stack.append(layers.GlobalAveragePooling2D())
    for units, rate in ((512, 0.5), (256, 0.3)):
        stack.extend([
            layers.Dense(units, activation='relu'),
            layers.BatchNormalization(),
            layers.Dropout(rate),
        ])
    stack.append(layers.Dense(1, activation='sigmoid'))

    return keras.Sequential(stack), None

def _compile_model(model, learning_rate):
    """Compile *model* with Adam, binary cross-entropy and the tracked metrics.

    New metric objects are created on every call, so recompiling between
    training phases never shares metric state with an earlier compile.

    Args:
        model: Keras model with a single sigmoid output.
        learning_rate: Adam learning rate.
    """
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=['accuracy',
                 keras.metrics.Precision(name='precision'),
                 keras.metrics.Recall(name='recall'),
                 keras.metrics.AUC(name='auc')],
    )


def _print_banner(title):
    """Print *title* between two 50-character '=' rules."""
    print("\n" + "=" * 50)
    print(title)
    print("=" * 50)


def train_model():
    """Build data and model, train in two phases, evaluate and save.

    Phase 1 trains at ``LEARNING_RATE`` for ``EPOCHS // 2`` epochs. If
    ``build_model`` returned a backbone, it is frozen for this phase. Phase 2
    unfreezes the backbone (if any), recompiles at ``LEARNING_RATE / 10`` and
    trains for another ``EPOCHS // 2`` epochs. Its epoch numbering continues
    from where phase 1 stopped, so TensorBoard and checkpoint logs form one
    timeline.

    Returns:
        ``(model, history1, history2)``: the trained model and the two
        ``fit`` histories.
    """
    print("Building dataset...")
    builder = AcneDatasetBuilder(REAL_ACNE_PATH, AI_ACNE_PATH, NO_ACNE_PATH, IMG_SIZE)
    train_files, train_labels, val_files, val_labels = builder.create_dataset()

    print("\nCreating TensorFlow datasets...")
    train_ds = create_tf_dataset(train_files, train_labels, BATCH_SIZE, augment=True, img_size=IMG_SIZE)
    val_ds = create_tf_dataset(val_files, val_labels, BATCH_SIZE, augment=False, img_size=IMG_SIZE)

    print("\nBuilding model...")
    model, base_model = build_model(IMG_SIZE)
    uses_backbone = base_model is not None
    if uses_backbone:
        # Phase 1 trains only the new head; the backbone is unfrozen in phase 2.
        base_model.trainable = False

    _compile_model(model, LEARNING_RATE)

    print("\nModel summary:")
    model.summary()

    # The same callback instances serve both phases. ModelCheckpoint keeps its
    # best val_auc across phases, so only a real improvement overwrites the file.
    callbacks = [
        keras.callbacks.ModelCheckpoint(
            'best_acne_model.keras',
            monitor='val_auc',
            mode='max',
            save_best_only=True,
            verbose=1
        ),
        keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=10,
            restore_best_weights=True,
            verbose=1
        ),
        keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=5,
            min_lr=1e-7,
            verbose=1
        ),
        keras.callbacks.TensorBoard(
            log_dir='logs',
            histogram_freq=1
        )
    ]

    phase_epochs = EPOCHS // 2

    # Phase 1
    _print_banner("Phase 1: Training with frozen base model" if uses_backbone
                  else "Phase 1: Training full model")
    history1 = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=phase_epochs,
        callbacks=callbacks,
        verbose=1
    )

    # EarlyStopping may end phase 1 before phase_epochs. Resume numbering after
    # the last completed epoch instead of restarting at epoch 1, which would
    # overwrite TensorBoard's phase-1 epochs.
    completed_epochs = len(history1.epoch)

    # Phase 2
    if uses_backbone:
        _print_banner("Phase 2: Fine-tuning entire model")
        base_model.trainable = True
    else:
        _print_banner("Phase 2: Continued training at reduced learning rate")

    # Recompile so any trainable change takes effect and Adam restarts with
    # the lower learning rate.
    _compile_model(model, LEARNING_RATE / 10)

    history2 = model.fit(
        train_ds,
        validation_data=val_ds,
        initial_epoch=completed_epochs,
        epochs=completed_epochs + phase_epochs,
        callbacks=callbacks,
        verbose=1
    )

    # Evaluate. NOTE: val_ds also drove checkpointing and early stopping, so
    # these numbers are an optimistic estimate, not a held-out test score.
    _print_banner("Final Evaluation on Validation Set")
    # return_dict keys the results by metric name. Positional indexing would
    # depend on the metric order of the installed Keras version.
    results = model.evaluate(val_ds, verbose=1, return_dict=True)

    display_names = {
        "loss": "Loss",
        "accuracy": "Accuracy",
        "precision": "Precision",
        "recall": "Recall",
        "auc": "AUC",
    }
    print("\nFinal Metrics:")
    for key, value in results.items():
        print(f"{display_names.get(key, key)}: {value:.4f}")

    # Save final model
    model.save('acne_detection_model_final.keras')
    print("\nModel saved as 'acne_detection_model_final.keras'")

    return model, history1, history2

if __name__ == "__main__":
    # Fix NumPy's and TensorFlow's global seeds before anything random runs:
    # the no-acne sampling uses np.random, and TensorFlow-level random ops
    # (shuffling, augmentation) derive their seeds from tf.random.set_seed.
    np.random.seed(42)
    tf.random.set_seed(42)

    model, history1, history2 = train_model()

    print("\n".join([
        "\nTraining complete!",
        "Best model saved as 'best_acne_model.keras'",
        "Final model saved as 'acne_detection_model_final.keras'",
    ]))

Building dataset...
Found 299 real acne images
Found 390 AI acne images
Found 5000 no acne images

Dataset split:
Training samples: 1072
Validation samples: 306
Validation real images: 90

Creating TensorFlow datasets...

Building model...


ValueError: Shape mismatch in layer #1 (named stem_conv)for weight stem_conv/kernel. Weight expects shape (3, 3, 1, 32). Received saved weight with shape (3, 3, 3, 32)