# In[ ]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Conv2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras import backend as K
import tensorflow as tf

# Custom Focal Loss Function
def focal_loss(gamma=2., alpha=0.25):
    """Build a binary focal loss callable for ``model.compile``.

    Per element: FL = -alpha_t * (1 - p_t) ** gamma * log(p_t), where p_t is
    the predicted probability of the true label.

    Args:
        gamma: Focusing exponent. Larger values shrink the loss of examples the
            model already classifies confidently; ``gamma=0`` reduces to
            alpha-weighted binary cross-entropy.
        alpha: Weight for positive targets (label 1); negative targets are
            weighted by ``1 - alpha``.

    Returns:
        A ``loss(y_true, y_pred)`` callable. ``y_pred`` must hold probabilities
        (e.g. sigmoid outputs). The callable returns one loss value per sample
        (mean over the last axis) so Keras can apply ``sample_weight`` /
        ``class_weight`` per sample before reducing over the batch.
    """
    def focal_loss_fixed(y_true, y_pred):
        # tf.cast rather than convert_to_tensor(..., tf.float32): the latter
        # raises for tensors that already carry another dtype (e.g. float16
        # predictions under mixed precision, integer labels).
        y_pred = tf.cast(y_pred, tf.float32)
        y_true = tf.cast(y_true, tf.float32)
        # Align shapes explicitly: (batch,) labels against (batch, 1)
        # predictions would otherwise broadcast to a (batch, batch) matrix.
        y_true = tf.reshape(y_true, tf.shape(y_pred))
        # Keep probabilities strictly inside (0, 1) so log(p_t) stays finite.
        eps = K.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1. - eps)

        p_t = y_true * y_pred + (1. - y_true) * (1. - y_pred)
        alpha_t = y_true * alpha + (1. - y_true) * (1. - alpha)
        per_element = -alpha_t * tf.pow(1. - p_t, gamma) * tf.math.log(p_t)
        # Return per-sample values, not a scalar: with a scalar, Keras can only
        # scale the whole batch loss by the mean sample weight, which
        # effectively disables class_weight.
        return tf.reduce_mean(per_element, axis=-1)
    return focal_loss_fixed

# Paths
train_dir = "D:/DATASET/CNN/ballooning/train"
val_dir = "D:/DATASET/CNN/ballooning/test"

IMAGE_SIZE = (299, 299)
BATCH_SIZE = 32

# ResNet50's ImageNet weights expect inputs prepared by
# resnet50.preprocess_input (RGB->BGR, per-channel ImageNet mean subtraction,
# no [0, 1] scaling). rescale=1./255 would feed the frozen backbone inputs far
# from what it was trained on.
# NOTE: any inference code that loads the saved model must apply the same
# preprocess_input to its images.
from tensorflow.keras.applications.resnet50 import preprocess_input

# Data Augmentation (training only). ImageDataGenerator runs
# preprocessing_function after resizing and augmentation, so the augmentations
# below (including channel_shift_range) act on raw 0-255 pixels.
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    brightness_range=[0.8, 1.2],
    channel_shift_range=50,
    fill_mode='nearest'
)
val_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

# Data Generators
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary'
)
# shuffle=False keeps batches in file order, so predictions line up with
# val_generator.classes when the model is evaluated later.
val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    shuffle=False
)

# Compute Class Weights
# Integer label of every training file, as assigned by flow_from_directory.
y_train = train_generator.classes
present_classes = np.unique(y_train)
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=present_classes,
    y=y_train
)
# Key each weight by its actual class index: enumerate() would shift the
# weights onto the wrong classes if some class index had no training samples.
# Plain int/float keys and values keep the dict clean for Keras and for printing.
class_weights_dict = {
    int(class_index): float(weight)
    for class_index, weight in zip(present_classes, class_weights)
}
print(f"Class weights: {class_weights_dict}")

# Initialize ResNet50 Base Model
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(299, 299, 3))

# Fine-tune only the last 10 layers of the backbone. BatchNormalization layers
# stay frozen even inside that window. A frozen BN layer runs in inference mode
# and keeps its ImageNet moving statistics, which fine-tuning would otherwise
# overwrite; the Keras transfer-learning guide recommends this.
from tensorflow.keras.layers import BatchNormalization

NUM_FINE_TUNED_LAYERS = 10
first_trainable_index = len(base_model.layers) - NUM_FINE_TUNED_LAYERS
for layer_index, layer in enumerate(base_model.layers):
    layer.trainable = (
        layer_index >= first_trainable_index
        and not isinstance(layer, BatchNormalization)
    )

# Add Custom Layers: pool the backbone's feature map into one vector per
# image, then a single sigmoid unit gives the probability of label 1.
x = base_model.output
x = GlobalAveragePooling2D()(x)
predictions = Dense(1, activation='sigmoid')(x)  # Binary classification

# Final Model
model = Model(inputs=base_model.input, outputs=predictions)

# Compile Model with Focal Loss. The small learning rate limits how far the
# unfrozen pretrained layers drift from their ImageNet weights.
# NOTE(review): focal_loss weights label-1 targets by alpha=0.25 and label-0
# targets by 0.75. The 'balanced' class_weight passed to fit() reweights the
# classes again by frequency, so the two compound. Confirm this is intended.
model.compile(
    optimizer=Adam(learning_rate=1e-5),
    loss=focal_loss(gamma=2., alpha=0.25),
    metrics=['accuracy']
)

# Train the Model
history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=10,
    class_weight=class_weights_dict
)




# In[ ]:
# Save Model
model_path = 'D:/DATASET/Models/model_ballooning_focal.h5'
# Create the target folder first. Otherwise a missing directory raises at save
# time and the freshly trained weights exist only in this session.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
model.save(model_path)
# NOTE: the model was compiled with a custom loss, so reloading it will likely
# need that loss passed via custom_objects (or load_model(..., compile=False)).

# Plot Training History: one panel per metric, training vs. validation curves.
plt.figure(figsize=(12, 6))
panels = [
    ('accuracy', 'Model Accuracy', 'Accuracy'),
    ('loss', 'Model Loss', 'Loss'),
]
for position, (metric, title, axis_label) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(history.history[metric], label=f'Train {axis_label}')
    plt.plot(history.history[f'val_{metric}'], label=f'Validation {axis_label}')
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(axis_label)
    plt.legend()

plt.tight_layout()
plt.show()

# Evaluate Model and Generate Classification Report
from sklearn.metrics import classification_report

# Take labels and predictions from the same batch so they stay aligned even if
# val_generator shuffles. Iterating over len(val_generator) batches also covers
# the final partial batch; samples // batch_size would drop it and leave
# y_true and y_pred with different lengths.
true_label_batches, probability_batches = [], []
for batch_index in range(len(val_generator)):
    batch_images, batch_labels = val_generator[batch_index]
    true_label_batches.append(batch_labels)
    probability_batches.append(model.predict_on_batch(batch_images))

y_true = np.concatenate(true_label_batches).astype(int)
y_pred = (np.concatenate(probability_batches).ravel() > 0.5).astype(int)

print("\nClassification Report:")
# NOTE(review): target_names must follow val_generator.class_indices order
# (0 then 1). Confirm index 0 is the "absent" folder.
# labels=[0, 1] keeps the report valid even if one class is missing from the
# validation labels and predictions.
print(classification_report(
    y_true,
    y_pred,
    labels=[0, 1],
    target_names=['Ballooning Absent', 'Ballooning Present'],
))

# SHAP Analysis (Optional)
import shap

# Explain a few fixed, un-augmented validation images. Fetch the batch ONCE:
# each index into an augmenting generator produces a new random batch, so
# indexing twice would explain one set of images and plot another.
explain_images = val_generator[0][0][:4]

# Image masker + Partition explainer: hidden image regions are replaced by
# zeros, and the model is queried on the masked images.
masker = shap.maskers.Image(np.zeros_like(explain_images[0]), explain_images[0].shape)
explainer = shap.Explainer(lambda batch: model.predict(batch, verbose=0), masker)
shap_values = explainer(explain_images, max_evals=500, batch_size=32)

# Visualize SHAP values for the explained images (pixel data is taken from the
# Explanation object, so it matches what was explained).
shap.image_plot(shap_values)