# In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
from tensorflow.keras.applications import MobileNetV2, DenseNet121
from tensorflow.keras.layers import (
    Input, Dense, Concatenate, GlobalAveragePooling2D, Conv2D,
    BatchNormalization, Activation, MultiHeadAttention, LayerNormalization, Reshape
)
from tensorflow.keras.models import Model
from tensorflow.keras import regularizers

# Image geometry and training hyper-parameters
img_height = img_width = 224
batch_size = 32
epochs = 100
seed = 123  # fixed so the training/validation split is the same on every run

# Root of the full dataset directory (contains all classes)
data_dir = '/kaggle/input/S2RMCMD/sugarcane-disease_9026img/train'

# Arguments shared by both subsets. The split fraction and seed must be
# identical in the two calls, otherwise the subsets could overlap.
split_kwargs = dict(
    validation_split=0.2,
    seed=seed,
    label_mode='int',
    image_size=(img_height, img_width),
    batch_size=batch_size,
)

# 80% training subset
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    subset='training',
    **split_kwargs,
)

# 20% held-out subset (used as validation data during training and for the
# final confusion matrix / classification report)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    subset='validation',
    **split_kwargs,
)

# Class names: read here, before the datasets are wrapped by cache()/prefetch()
class_names = train_ds.class_names

# Cache + prefetch for performance
AUTOTUNE = tf.data.AUTOTUNE
# Shuffle AFTER cache(): cache() stores the elements in the order produced on
# the first pass, so without this step every epoch would replay exactly the
# same batches in exactly the same order. The dataset is already batched, so
# this shuffles whole batches between epochs.
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
# The held-out subset is only evaluated, so its order does not matter.
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

# Build model
input_layer = Input(shape=(img_height, img_width, 3))

# The input pipeline yields raw pixel values in [0, 255], while each ImageNet
# backbone was trained with its own input scaling. Apply the matching
# preprocessing per backbone so the pretrained weights see the value range
# they expect.
mobilenet_input = tf.keras.applications.mobilenet_v2.preprocess_input(input_layer)
densenet_input = tf.keras.applications.densenet.preprocess_input(input_layer)

mobilenet_base = MobileNetV2(include_top=False, weights='imagenet', input_tensor=mobilenet_input)
densenet_base = DenseNet121(include_top=False, weights='imagenet', input_tensor=densenet_input)

# Fine-tune both backbones end to end
mobilenet_base.trainable = True
densenet_base.trainable = True


def _conv_attention_branch(feature_map):
    """Return a pooled feature vector from a conv + self-attention head.

    The backbone feature map is passed through four 1x1 Conv-BN-ReLU stages
    (32, 64, 128, 256 filters), flattened into a sequence of spatial tokens,
    refined with multi-head self-attention plus a residual connection and
    layer normalisation, reshaped back to a spatial grid and finally
    global-average-pooled.

    Args:
        feature_map: 4-D Keras tensor (batch, height, width, channels).

    Returns:
        2-D Keras tensor (batch, 256).
    """
    x = feature_map
    for filters in [32, 64, 128, 256]:
        x = Conv2D(filters, (1, 1), padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    height, width, channels = x.shape[1], x.shape[2], x.shape[3]
    # Flatten the spatial grid into a token sequence for attention
    tokens = Reshape((-1, channels))(x)
    # Self-attention: the same sequence is used as query and value
    attended = MultiHeadAttention(num_heads=4, key_dim=32)(tokens, tokens)
    # Residual connection followed by layer normalisation
    attended = LayerNormalization()(attended + tokens)
    # Restore the spatial layout before pooling
    restored = Reshape((height, width, channels))(attended)
    return GlobalAveragePooling2D()(restored)


# Global Average Pooling branch
gap1 = GlobalAveragePooling2D()(mobilenet_base.output)
gap2 = GlobalAveragePooling2D()(densenet_base.output)
gap_concat = Concatenate()([gap1, gap2])

# Conv + Self-Attention branch - MobileNetV2
x1 = _conv_attention_branch(mobilenet_base.output)

# Conv + Self-Attention branch - DenseNet121
x2 = _conv_attention_branch(densenet_base.output)

# Final concatenation
conv_concat = Concatenate()([x1, x2])
final_concat = Concatenate()([gap_concat, conv_concat])

# Output: one softmax unit per class, with L2 weight decay on the classifier
output = Dense(len(class_names), activation='softmax', kernel_regularizer=regularizers.l2(0.01))(final_concat)
model = Model(inputs=input_layer, outputs=output)
model.summary()
# Configure training: Adam optimiser, integer-label cross-entropy, accuracy metric
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Fit on the training subset, monitoring the held-out subset after each epoch
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
)

# Persist the trained model to disk
model.save('mobilenet_densenet_attention_80_20_split.h5')

# Evaluate on validation set
# Explicit label ids keep the confusion matrix and the report aligned with
# class_names even if some class has no samples in the held-out subset
# (otherwise classification_report raises on a target_names length mismatch).
label_ids = list(range(len(class_names)))

y_true, y_pred = [], []
for images, labels_batch in val_ds:
    # predict_on_batch runs a single forward pass on the given batch; calling
    # predict() inside a manual loop adds per-call setup and progress output.
    preds = model.predict_on_batch(images)
    y_pred.extend(np.argmax(preds, axis=1))
    y_true.extend(labels_batch.numpy())

# Confusion Matrix and Classification Report
print("Confusion Matrix:")
print(confusion_matrix(y_true, y_pred, labels=label_ids))

print("\nClassification Report:")
print(classification_report(y_true, y_pred, labels=label_ids, target_names=class_names))

# Training curves: accuracy on the left panel, loss on the right panel
plt.figure(figsize=(12, 4))

# (history key, human-readable name) for each panel, in left-to-right order
panels = (('accuracy', 'Accuracy'), ('loss', 'Loss'))
for position, (metric, pretty) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(history.history[metric], label=f'Train {pretty}')
    plt.plot(history.history[f'val_{metric}'], label=f'Validation {pretty}')
    plt.title(f'Model {pretty}')
    plt.xlabel('Epoch')
    plt.ylabel(pretty)
    plt.legend()

plt.tight_layout()
plt.show()
