In [1]:
import tensorflow as tf
from tensorflow.keras.applications import Xception
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import os

# Define constants
IMG_HEIGHT, IMG_WIDTH = 224, 224  # Xception input size
BATCH_SIZE = 20
NUM_CLASSES = 10  # Adjust based on your tomato dataset classes
EPOCHS = 10
DATA_DIR = 'dataset'  # Update to your dataset path

# Data augmentation and preprocessing
train_datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.xception.preprocess_input,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
    validation_split=0.2  # 20% for validation
)

# Load training data
train_generator = train_datagen.flow_from_directory(
    DATA_DIR,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='training'
)

# Load validation data
validation_generator = train_datagen.flow_from_directory(
    DATA_DIR,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation'
)

# Load Xception model with pre-trained weights
base_model = Xception(weights='imagenet', include_top=False, input_shape=(IMG_HEIGHT, IMG_WIDTH, 3))

# Freeze base model layers
base_model.trainable = False

# Add custom layers
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(512, activation='relu')(x)
x = Dropout(0.5)(x)
predictions = Dense(NUM_CLASSES, activation='softmax')(x)

# Create final model
model = Model(inputs=base_model.input, outputs=predictions)

# Compile model
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Print model summary
model.summary()

# Train the model
history = model.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=validation_generator
)

# Fine-tuning: Unfreeze some layers of the base model
base_model.trainable = True
fine_tune_at = 100  # Unfreeze from this layer onwards
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

# Recompile with a lower learning rate for fine-tuning
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Continue training for fine-tuning
fine_tune_epochs = 5
history_fine = model.fit(
    train_generator,
    epochs=EPOCHS + fine_tune_epochs,
    initial_epoch=history.epoch[-1] + 1,
    validation_data=validation_generator
)

# Evaluate on validation set
val_loss, val_accuracy = model.evaluate(validation_generator)
print(f"Validation Loss: {val_loss:.4f}")
print(f"Validation Accuracy: {val_accuracy:.4f}")

# Generate predictions for confusion matrix
val_generator = train_datagen.flow_from_directory(
    DATA_DIR,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation',
    shuffle=False
)
y_pred = model.predict(val_generator)
y_pred_classes = np.argmax(y_pred, axis=1)
y_true = val_generator.classes

# Plot confusion matrix
cm = confusion_matrix(y_true, y_pred_classes)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=val_generator.class_indices.keys(), yticklabels=val_generator.class_indices.keys())
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.savefig('confusion_matrix.png')
plt.close()

# Plot training history
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'] + history_fine.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'] + history_fine.history['val_accuracy'], label='Validation Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history.history['loss'] + history_fine.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'] + history_fine.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('training_history.png')
plt.close()

# Print classification report
print("Classification Report:")
print(classification_report(y_true, y_pred_classes, target_names=val_generator.class_indices.keys()))

# Save the model
model.save('xception_tomato_leaf_model.h5')

Found 12813 images belonging to 10 classes.
Found 3198 images belonging to 10 classes.


  self._warn_if_super_not_called()


Epoch 1/10
[1m641/641[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m593s[0m 921ms/step - accuracy: 0.5781 - loss: 1.2576 - val_accuracy: 0.7983 - val_loss: 0.6140
Epoch 2/10
[1m641/641[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m537s[0m 838ms/step - accuracy: 0.7482 - loss: 0.7296 - val_accuracy: 0.8133 - val_loss: 0.5485
Epoch 3/10
[1m641/641[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m546s[0m 852ms/step - accuracy: 0.7702 - loss: 0.6591 - val_accuracy: 0.8274 - val_loss: 0.4978
Epoch 4/10
[1m641/641[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m511s[0m 797ms/step - accuracy: 0.7990 - loss: 0.5953 - val_accuracy: 0.8418 - val_loss: 0.4756
Epoch 5/10
[1m641/641[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m539s[0m 841ms/step - accuracy: 0.7995 - loss: 0.5694 - val_accuracy: 0.8377 - val_loss: 0.4574
Epoch 6/10
[1m641/641[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m538s[0m 838ms/step - accuracy: 0.8135 - loss: 0.5476 - val_accuracy: 0.8465 - val_loss: 0.4558
Epoc



Classification Report:
                                             precision    recall  f1-score   support

                      Tomato_Bacterial_spot       0.97      0.99      0.98       425
                        Tomato_Early_blight       0.94      0.94      0.94       200
                         Tomato_Late_blight       0.99      0.97      0.98       381
                           Tomato_Leaf_Mold       0.98      0.99      0.99       190
                  Tomato_Septoria_leaf_spot       0.95      0.98      0.97       354
Tomato_Spider_mites_Two_spotted_spider_mite       0.98      0.96      0.97       335
                        Tomato__Target_Spot       0.94      0.91      0.93       280
      Tomato__Tomato_YellowLeaf__Curl_Virus       1.00      0.99      0.99       641
                Tomato__Tomato_mosaic_virus       0.89      1.00      0.94        74
                             Tomato_healthy       0.98      0.99      0.99       318

                                   accur