# In [None]:
import tensorflow as tf
import numpy as np
import cv2
import glob
import os
import matplotlib.pyplot as plt
from keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
from keras.layers import Dense, GlobalAveragePooling2D, Dropout, BatchNormalization
from keras.models import Model
from keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.utils import class_weight
import pandas as pd
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc, precision_recall_curve
import seaborn as sns

# Function to load and preprocess images
def load_images_from_directory(directory, label, target_size=(224, 224), patterns=("*.jpg",)):
    """Load every matching image in ``directory`` and preprocess it for MobileNetV2.

    Args:
        directory: Folder to scan (non-recursive).
        label: Label appended once per successfully loaded image.
        target_size: Passed straight to ``cv2.resize``, i.e. ``(width, height)`` in
            OpenCV order. The default is square, so the order does not matter for it.
        patterns: Glob patterns (relative to ``directory``) selecting the files to
            load. Defaults to ``("*.jpg",)``; pass e.g. ``("*.jpg", "*.jpeg", "*.png")``
            to include other formats.

    Returns:
        ``(images, labels)``: two equal-length lists. Each image is a float32 RGB
        array that has been passed through ``preprocess_input``.

    Files that cannot be read or processed are reported with ``print`` and skipped.
    """
    images = []
    labels = []
    # Sorted (and de-duplicated across patterns) so the load order -- and therefore any
    # seeded train/validation split done downstream -- does not depend on filesystem order.
    files = sorted({
        path
        for pattern in patterns
        for path in glob.glob(os.path.join(directory, pattern))
    })

    for file in files:
        try:
            img = cv2.imread(file)
            if img is None:
                print(f"Warning: Could not read image {file}")
                continue

            # cv2 decodes to BGR channel order, but MobileNet expects RGB.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            # cv2.resize takes (width, height); cast to float32 for preprocess_input.
            img = cv2.resize(img, target_size).astype(np.float32)

            # MobileNetV2-specific input scaling.
            img = preprocess_input(img)

            images.append(img)
            labels.append(label)
        except Exception as e:  # best-effort loader: report the failing file and skip it
            print(f"Error processing {file}: {e}")

    return images, labels

# Load data from exactly two folders
def load_binary_data(base_path="data/", category_a="plastic", category_b="waste"):
    """Load a two-class image dataset from ``base_path/<category>`` folders.

    Images found in ``category_a`` get label 0; images in ``category_b`` get label 1.

    Args:
        base_path: Directory containing one sub-folder per category.
        category_a: Folder name of the class labelled 0 (modify to match your data).
        category_b: Folder name of the class labelled 1 (modify to match your data).

    Returns:
        ``(X, y, [category_a, category_b])`` on success, where ``X`` is the stacked
        array of preprocessed images and ``y`` the matching 0/1 labels. On failure
        (a folder is missing or yields no images) returns ``(None, None, None)``, so
        callers that unpack three values can still detect the error instead of
        crashing on the unpack.
    """
    categories = [category_a, category_b]
    all_images = []
    all_labels = []

    # A category's position in `categories` doubles as its integer label (0, then 1).
    for label, category in enumerate(categories):
        category_path = os.path.join(base_path, category)
        if not os.path.exists(category_path):
            print(f"Error: Directory {category_path} does not exist")
            return None, None, None

        print(f"Loading images from {category}...")
        images, labels = load_images_from_directory(category_path, label)
        if not images:
            # An empty class would break the stratified split and two-class reporting later.
            print(f"Error: No images could be loaded from {category_path}")
            return None, None, None
        all_images.extend(images)
        all_labels.extend(labels)

    # Convert to numpy arrays
    X = np.array(all_images)
    y = np.array(all_labels)

    print(f"Loaded {len(X)} images with shape {X.shape}")

    # Check class distribution
    unique, counts = np.unique(y, return_counts=True)
    for u, c in zip(unique, counts):
        print(f"Class {u} ({categories[u]}): {c} samples")

    return X, y, categories  # Return category names for later use

# Load and prepare data
print("Loading dataset...")
X, y, categories = load_binary_data()

# Check if data loaded successfully
if X is None or y is None:
    print("Failed to load data. Please check the folder names and structure.")
    # Exit with a non-zero status so shells/CI see the failure (bare exit() reports
    # success), without relying on the `site`-injected exit() helper.
    raise SystemExit(1)

# Calculate class weights for imbalanced dataset
# ('balanced' weights are n_samples / (n_classes * count_of_class)).
classes = np.unique(y)
class_weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=classes,
    y=y
)
# Key each weight by its actual class label rather than its position in `classes`,
# using plain Python numbers for Keras' `class_weight` argument.
class_weights_dict = {int(cls): float(weight) for cls, weight in zip(classes, class_weights)}
print("Class weights:", class_weights_dict)

# Split data into training and validation sets (stratified to keep the class ratio in both)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

# Create data generators with augmentation
train_datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

val_datagen = ImageDataGenerator()  # No augmentation for validation data

batch_size = 32

# Number of batches drawn from each generator per epoch
train_samples = len(X_train)
val_samples = len(X_val)
# Training uses floor division: the final partial batch (at most batch_size - 1 samples)
# is skipped each epoch, which also avoids feeding a tiny batch to the head's
# BatchNormalization layers.
steps_per_epoch = train_samples // batch_size
# Validation uses ceiling division so every validation image is scored each epoch.
# Floor division silently left the last partial batch out of val_loss, which is the
# value the training callbacks monitor. The ceiling matches the number of batches the
# validation iterator yields per pass, so it cannot run out of data.
validation_steps = (val_samples + batch_size - 1) // batch_size

# Ensure at least one step (e.g. fewer training samples than one batch)
steps_per_epoch = max(1, steps_per_epoch)
validation_steps = max(1, validation_steps)

# Training batches are shuffled; validation keeps a fixed order.
train_generator = train_datagen.flow(
    X_train, y_train,
    batch_size=batch_size,
    shuffle=True
)

val_generator = val_datagen.flow(
    X_val, y_val,
    batch_size=batch_size,
    shuffle=False
)

print(f"Training with {steps_per_epoch} steps per epoch, {validation_steps} validation steps")

# Build model with proper output layer for binary classification
print("Building model...")
base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Phase 1 trains only the new head, so every pretrained layer starts frozen.
for layer in base_model.layers:
    layer.trainable = False

# Classification head: pool the feature map, then two Dense -> BatchNorm -> Dropout stages
# (512 units, then 256 units).
x = GlobalAveragePooling2D()(base_model.output)
for units in (512, 256):
    x = Dense(units, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(0.5)(x)
# One sigmoid unit: the model outputs the probability of label 1.
predictions = Dense(1, activation='sigmoid')(x)

model = Model(inputs=base_model.input, outputs=predictions)

# Binary cross-entropy matches the single sigmoid output.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss='binary_crossentropy',
    metrics=['accuracy', tf.keras.metrics.AUC(), tf.keras.metrics.Precision(), tf.keras.metrics.Recall()],
)

# Callbacks shared by both training phases: early stopping, LR decay on plateau, and
# saving the best model seen so far to `checkpoint_path`.
checkpoint_path = "best_model_binary.keras"
callbacks = [
    EarlyStopping(patience=10, restore_best_weights=True, verbose=1),
    ReduceLROnPlateau(factor=0.1, patience=5, min_lr=1e-6, verbose=1),
    ModelCheckpoint(checkpoint_path, save_best_only=True, verbose=1),
]

# Step 1: Train with frozen base model layers (only the new head learns)
print("Phase 1: Training with frozen base model...")
history1 = model.fit(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_generator,
    validation_steps=validation_steps,
    epochs=15,
    callbacks=callbacks,
    class_weight=class_weights_dict,
    verbose=1
)

# Step 2: Unfreeze the top of the base model and fine-tune with a lower learning rate
print("Phase 2: Fine-tuning with unfrozen base model...")
# Unfreeze the last 40 layers of the base model, but leave BatchNormalization layers
# frozen. A BatchNormalization layer with trainable=False runs in inference mode and keeps
# its pretrained moving statistics; unfreezing it would let small, augmented batches of
# this dataset overwrite those statistics and degrade the pretrained features.
for layer in base_model.layers[-40:]:
    if not isinstance(layer, BatchNormalization):
        layer.trainable = True

# Recompile so the new trainable flags take effect; 1e-5 keeps fine-tuning updates small.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-5),
    loss='binary_crossentropy',
    metrics=['accuracy', tf.keras.metrics.AUC(), tf.keras.metrics.Precision(), tf.keras.metrics.Recall()]
)

# Continue training with unfrozen layers. NOTE: `epochs` is the index of the final epoch,
# not an additional count. Starting at `initial_epoch`, this runs at most
# 25 - initial_epoch more epochs (10 if phase 1 used all 15).
history2 = model.fit(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_generator,
    validation_steps=validation_steps,
    epochs=25,
    callbacks=callbacks,
    class_weight=class_weights_dict,
    verbose=1,
    initial_epoch=len(history1.history['accuracy'])  # Resume the epoch count after phase 1
)

# Reload the best model that ModelCheckpoint wrote (save_best_only=True).
model = tf.keras.models.load_model(checkpoint_path)

# Score the full validation set. The results are unpacked positionally as
# loss, accuracy, AUC, precision, recall (the order of the compiled metrics).
val_loss, val_acc, val_auc, val_precision, val_recall = model.evaluate(X_val, y_val, batch_size=batch_size)
for metric_label, metric_value in (
    ("accuracy", val_acc),
    ("AUC", val_auc),
    ("precision", val_precision),
    ("recall", val_recall),
):
    print(f"Validation {metric_label}: {metric_value:.4f}")

# Hard 0/1 predictions: threshold the sigmoid probabilities at 0.5.
y_pred = model.predict(X_val)
y_pred_classes = (y_pred > 0.5).astype(int).flatten()

# Rows are true classes, columns are predicted classes.
cm = confusion_matrix(y_val, y_pred_classes)

plt.figure(figsize=(8, 6))
ax = sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=categories, yticklabels=categories)
ax.set(title='Confusion Matrix', ylabel='True Label', xlabel='Predicted Label')
plt.tight_layout()
plt.savefig('binary_confusion_matrix.png')
plt.show()

# Per-class precision / recall / F1 summary
print("Classification Report:")
print(classification_report(y_val, y_pred_classes, target_names=categories))

# ROC curve built from the continuous sigmoid scores (not the thresholded classes).
fpr, tpr, thresholds = roc_curve(y_val, y_pred)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(8, 6))
curve_label = f'ROC curve (area = {roc_auc:.2f})'
plt.plot(fpr, tpr, color='blue', lw=2, label=curve_label)
# Dashed diagonal: the curve traced by a random classifier.
plt.plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--')
plt.gca().set(
    xlim=(0.0, 1.0),
    ylim=(0.0, 1.05),  # headroom so a curve hugging TPR = 1 stays visible
    xlabel='False Positive Rate',
    ylabel='True Positive Rate',
    title='Receiver Operating Characteristic',
)
plt.legend(loc="lower right")
plt.savefig('binary_roc_curve.png')
plt.show()

# Show which metric names Keras recorded in each phase. The AUC/precision/recall metric
# objects are re-created by the second compile(), so their logged names may differ between
# the two histories (check this printout); only the fixed string-named keys are merged below.
print("Available keys in history1:", history1.history.keys())
print("Available keys in history2:", history2.history.keys())

# Concatenate the per-epoch curves of both phases for the stable metric names.
history_combined = {
    metric: history1.history[metric] + history2.history[metric]
    for metric in ('accuracy', 'val_accuracy', 'loss', 'val_loss')
}

# Plot training history: accuracy and loss panels, each with a marker at the phase switch.
plt.figure(figsize=(15, 10))

panels = (
    (1, 'accuracy', 'Model Accuracy', 'Accuracy', 'lower right'),
    (2, 'loss', 'Model Loss', 'Loss', 'upper right'),
)
for position, metric, panel_title, panel_ylabel, legend_loc in panels:
    plt.subplot(2, 2, position)
    plt.plot(history_combined[metric])
    plt.plot(history_combined['val_' + metric])
    # Dashed red line at the first fine-tuning epoch (x counts epochs from 0).
    plt.axvline(x=len(history1.history[metric]), color='r', linestyle='--')
    plt.title(panel_title)
    plt.ylabel(panel_ylabel)
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation', 'Phase Change'], loc=legend_loc)

plt.tight_layout()
plt.savefig('binary_training_history.png')
plt.show()

# Create visualizations for sample predictions
def visualize_predictions(X_val, y_val, model, categories, num_samples=10, cols=5):
    """Plot a random subset of validation images with their true and predicted labels.

    Args:
        X_val: NumPy array of images preprocessed with MobileNetV2 ``preprocess_input``
            (assumed RGB, as produced by this script's loader).
        y_val: NumPy array of integer 0/1 labels aligned with ``X_val``.
        model: Model whose ``predict`` returns one sigmoid probability per image
            (indexed as ``preds[i][0]``), read as the probability of label 1.
        categories: Class names; ``categories[label]`` is shown in each title.
        num_samples: Number of images to show; capped at ``len(X_val)``.
        cols: Maximum number of subplot columns; rows are added as needed.

    Each title is green for a correct prediction and red otherwise. The figure is saved
    to ``binary_sample_predictions.png`` and shown. With no images, prints a message
    and returns without plotting.
    """
    # np.random.choice(..., replace=False) raises if asked for more samples than exist,
    # e.g. with a small validation split, so cap the request.
    num_samples = min(num_samples, len(X_val))
    if num_samples == 0:
        print("No validation samples to visualize.")
        return

    # Get random samples
    indices = np.random.choice(len(X_val), num_samples, replace=False)
    images = X_val[indices]
    true_labels = y_val[indices]

    # Get predictions (probability of label 1) and threshold them at 0.5
    preds = model.predict(images)
    pred_classes = (preds > 0.5).astype(int).flatten()

    # The grid grows with num_samples instead of a fixed 2x5, which failed beyond 10
    # images; the default (10 samples, 5 columns) keeps the original 2x5 layout and size.
    cols = max(1, min(cols, num_samples))
    rows = -(-num_samples // cols)  # ceiling division
    plt.figure(figsize=(4 * cols, 5 * rows))
    for i in range(num_samples):
        plt.subplot(rows, cols, i + 1)

        # MobileNetV2 preprocessing maps pixels to [-1, 1]; map back to [0, 1] for display.
        img = np.clip((images[i] + 1) / 2, 0, 1)
        plt.imshow(img)

        true_label = true_labels[i]
        pred_label = pred_classes[i]
        color = "green" if true_label == pred_label else "red"

        # Confidence in the predicted class: p for class 1, 1 - p for class 0.
        prob_one = preds[i][0]
        conf = prob_one if pred_label == 1 else 1 - prob_one

        plt.title(f"True: {categories[true_label]}\nPred: {categories[pred_label]}\nConf: {conf:.2f}",
                  color=color)
        plt.axis('off')

    plt.tight_layout()
    plt.savefig('binary_sample_predictions.png')
    plt.show()

# Sample visualization
print("Creating visualization of sample predictions...")
visualize_predictions(X_val, y_val, model, categories)

# Save the model currently held in `model`, named after the two class folders.
model_file = "binary_classification_{}_vs_{}.keras".format(*categories)
model.save(model_file)
print(f"Model saved as {model_file}")

print("Binary classification model training and analysis complete!")