In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import Xception
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.metrics import confusion_matrix, accuracy_score
import seaborn as sns
from tensorflow.keras.mixed_precision import set_global_policy

# Configure TensorFlow for CPU (12 threads).
# Done first: thread-pool sizes can only be set before TensorFlow initializes
# its runtime.
tf.config.threading.set_intra_op_parallelism_threads(12)
tf.config.threading.set_inter_op_parallelism_threads(12)

# Enable mixed precision only when a GPU is available. float16 compute is
# accelerated on recent GPUs; on a CPU-only machine the same policy makes
# training considerably slower, so the default float32 policy is kept there.
if tf.config.list_physical_devices('GPU'):
    set_global_policy('mixed_float16')

# Global constants
batch_size = 16
TRAIN_PERCENT = 0.8
img_size = (128, 128)
dataset_dir = "dataset"
# Extensions that create_dataset() loads; counting uses the same filter so the
# class weights describe the images that are actually trained on.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

# Get class names: one non-hidden sub-directory per class, sorted so that a
# class's position in this list is its integer label.
class_names = sorted(
    d for d in os.listdir(dataset_dir)
    if os.path.isdir(os.path.join(dataset_dir, d)) and not d.startswith('.')
)
print(f"Detected classes: {class_names}")
print(f"Total number of classes: {len(class_names)}")

# Count images per class
class_counts = {}
for class_name in class_names:
    class_dir = os.path.join(dataset_dir, class_name)
    num_images = sum(
        1 for f in os.listdir(class_dir)
        if f.lower().endswith(IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(class_dir, f))
    )
    class_counts[class_name] = num_images
    print(f"Class '{class_name}': {num_images} images")
total_images = sum(class_counts.values())
print(f"Total images in dataset: {total_images}")

# Compute class weights: total / (num_classes * class_count), so a class of
# average size gets weight 1.0 and rarer classes get proportionally more.
# Keys are the integer labels (index into class_names). A class with no images
# gets weight 0.0 instead of raising ZeroDivisionError; it contributes no
# samples, so the value is never applied.
total_samples = total_images
class_weights = {
    class_idx: (
        total_samples / (len(class_names) * class_counts[class_name])
        if class_counts[class_name] > 0 else 0.0
    )
    for class_idx, class_name in enumerate(class_names)
}
print("Class weights:", class_weights)

def process_image(file_path, label, img_size=(128, 128)):
    """Load one image file and return it as a fixed-size float tensor.

    Args:
        file_path: Path of the image file to read.
        label: Label associated with the image; returned untouched.
        img_size: Target (height, width) used for resizing.

    Returns:
        A ``(image, label)`` pair where ``image`` is a float32 tensor with
        static shape ``(img_size[0], img_size[1], 3)``.
    """
    target_h, target_w = img_size[0], img_size[1]

    raw_bytes = tf.io.read_file(file_path)
    # Decode to 3 channels, then convert the integer pixels to float32
    # (convert_image_dtype rescales integer images into [0, 1]).
    pixels = tf.image.convert_image_dtype(
        tf.image.decode_jpeg(raw_bytes, channels=3), tf.float32
    )
    resized = tf.image.resize(pixels, img_size)
    # Pin the static shape so downstream layers see fully defined dimensions.
    resized.set_shape([target_h, target_w, 3])
    return resized, label

# Data augmentation layers
# Applied to training images only, which process_image() delivers as float32
# values in [0, 1].
augmentation = models.Sequential([
    layers.RandomRotation(0.0556),        # fraction of a full turn (about +/-20 degrees)
    layers.RandomTranslation(0.1, 0.1),
    layers.RandomZoom(0.1),
    layers.RandomFlip("horizontal"),
    layers.RandomFlip("vertical"),
    # RandomBrightness assumes pixels in [0, 255] unless told otherwise. With
    # the default range, a factor of 0.1 would shift [0, 1] images by up to
    # +/-25.5 and wash them out, so the real input range is given explicitly.
    layers.RandomBrightness(0.1, value_range=(0.0, 1.0)),
    layers.RandomContrast(0.1)
])

def create_dataset(dataset_dir, class_names, img_size=(128, 128), train_percent=0.8):
    """Build batched train/test ``tf.data`` pipelines from a class-per-folder tree.

    Every ``.png``/``.jpg``/``.jpeg`` file under ``dataset_dir/<class_name>`` is
    labelled with the index of ``class_name`` in ``class_names``. The file list
    is shuffled once with a fixed seed and split into a training part (first
    ``train_percent`` of the files) and a testing part (the rest). The split is
    fixed for the lifetime of the datasets, so an image never moves between the
    two sets from one epoch to the next.

    Args:
        dataset_dir: Root directory containing one sub-directory per class.
        class_names: Class sub-directory names; list position is the label.
        img_size: Target (height, width) passed on to ``process_image``.
        train_percent: Fraction of the files assigned to the training set.

    Returns:
        ``(train_dataset, test_dataset)``. Both are batched with the
        module-level ``batch_size`` and prefetched. The training set is
        reshuffled every epoch and augmented; the test set is not.

    Raises:
        ValueError: If no image files are found.
    """
    file_paths = []
    labels = []
    for class_idx, class_name in enumerate(class_names):
        class_dir = os.path.join(dataset_dir, class_name)
        # Sorted because os.listdir order is arbitrary; a stable input order
        # is what makes the seeded shuffle (and thus the split) reproducible.
        files = sorted(os.path.join(class_dir, f) for f in os.listdir(class_dir)
                       if f.lower().endswith(('.png', '.jpg', '.jpeg')))
        file_paths.extend(files)
        labels.extend([class_idx] * len(files))
        print(f"Processing class {class_name}... ({len(files)} images)")

    total_size = len(file_paths)
    if total_size == 0:
        raise ValueError(f"No image files found under '{dataset_dir}'")
    train_size = int(train_percent * total_size)

    path_dataset = tf.data.Dataset.from_tensor_slices(
        (tf.constant(file_paths), tf.constant(labels))
    )
    # reshuffle_each_iteration=False is essential: with the default (True) the
    # order changes on every pass, so take()/skip() would select different
    # files each epoch and test images would leak into training.
    path_dataset = path_dataset.shuffle(
        buffer_size=total_size, seed=42, reshuffle_each_iteration=False
    )

    def load(path, label):
        # Forward the requested size instead of relying on process_image's default.
        return process_image(path, label, img_size)

    train_dataset = path_dataset.take(train_size)
    test_dataset = path_dataset.skip(train_size)

    # Reshuffle only the training paths each epoch (cheap: strings, not images).
    train_dataset = train_dataset.shuffle(buffer_size=max(train_size, 1))
    train_dataset = train_dataset.map(load, num_parallel_calls=tf.data.AUTOTUNE)
    train_dataset = train_dataset.map(lambda x, y: (augmentation(x, training=True), y),
                                      num_parallel_calls=tf.data.AUTOTUNE)
    test_dataset = test_dataset.map(load, num_parallel_calls=tf.data.AUTOTUNE)

    train_dataset = train_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    test_dataset = test_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return train_dataset, test_dataset

# Load datasets
# TRAIN_PERCENT is passed explicitly so the split that is built is the same
# one the summary below reports.
train_dataset, test_dataset = create_dataset(
    dataset_dir, class_names, img_size=img_size, train_percent=TRAIN_PERCENT
)
print(f"Dataset loaded: ~{total_images} images, {img_size} resolution")
print(f"Training set: ~{int(total_images * TRAIN_PERCENT)} images")
print(f"Testing set: ~{total_images - int(total_images * TRAIN_PERCENT)} images")

def build_xception_model(input_shape=(128, 128, 3), num_classes=len(class_names)):
    """Build and compile an Xception-based classifier with a frozen backbone.

    The model expects float images with pixel values in [0, 1] (what
    ``process_image`` produces) and rescales them internally to the [-1, 1]
    range the ImageNet Xception weights were trained with.

    Args:
        input_shape: (height, width, channels) of the input images.
        num_classes: Number of output classes (size of the softmax layer).

    Returns:
        A compiled ``Sequential`` model. Its first layer is the frozen
        backbone (input rescaling followed by the Xception layers), followed
        by global average pooling, batch normalization, a 64-unit ReLU layer,
        dropout and the softmax output.
    """
    # The rescaling is wired in through `input_tensor` so that it lives inside
    # the backbone: the backbone stays the first layer of the returned model
    # and its last layers are still the last Xception layers.
    inputs = layers.Input(shape=input_shape)
    scaled = layers.Rescaling(scale=2.0, offset=-1.0)(inputs)  # [0, 1] -> [-1, 1]
    base = Xception(include_top=False, input_tensor=scaled, weights='imagenet')

    # Train only the new classification head at first.
    base.trainable = False

    model = models.Sequential([
        base,
        layers.GlobalAveragePooling2D(),
        layers.BatchNormalization(),
        layers.Dense(64, activation='relu'),
        layers.Dropout(0.4),
        # Output kept in float32 so the softmax stays numerically stable when
        # a mixed-precision policy is active.
        layers.Dense(num_classes, activation='softmax', dtype='float32')
    ])

    model.compile(optimizer=optimizers.Adam(learning_rate=1e-4),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    return model

# Initialize model
model = build_xception_model()

# Callbacks
# Both use their default monitored quantity. NOTE(review): the set passed as
# validation_data below is the same one later reported as "test" accuracy, so
# early stopping and checkpointing are tuned on the test data.
early_stopping = EarlyStopping(patience=5, restore_best_weights=True)
best_checkpoint = ModelCheckpoint('best_xception.h5', save_best_only=True)
callbacks = [early_stopping, best_checkpoint]

# First training phase: only the classification head is trainable here.
print("Training Xception (frozen base)...")
history = model.fit(
    train_dataset,
    validation_data=test_dataset,
    epochs=10,
    class_weight=class_weights,
    callbacks=callbacks,
    verbose=1,
)

# Fine-tuning phase: Unfreeze last 10 layers
# model.layers[0] is the pretrained backbone placed first in the Sequential model.
backbone = model.layers[0]
backbone.trainable = True
for layer in backbone.layers[:-10]:
    layer.trainable = False
# Keep BatchNormalization layers frozen even inside the unfrozen tail. A
# trainable BatchNormalization layer runs in training mode and re-estimates its
# moving statistics from small batches, which can undo the pretrained features.
for layer in backbone.layers[-10:]:
    if isinstance(layer, layers.BatchNormalization):
        layer.trainable = False

# Recompile so the changed trainable flags take effect; the learning rate is
# lowered for fine-tuning.
model.compile(optimizer=optimizers.Adam(learning_rate=1e-5),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

print("Fine-tuning Xception (last 10 layers)...")
history_finetune = model.fit(
    train_dataset,
    epochs=3,
    validation_data=test_dataset,
    callbacks=callbacks,
    class_weight=class_weights,
    verbose=1
)

# Combine histories for plotting
# The fine-tuning curves are appended to the first-phase curves in place, so
# history.history afterwards covers both phases.
for metric_key in ('accuracy', 'val_accuracy', 'loss', 'val_loss'):
    history.history[metric_key].extend(history_finetune.history[metric_key])

# Plot accuracy and loss side by side
fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(10, 4))
for ax, metric_key, metric_title in ((acc_ax, 'accuracy', 'Accuracy'),
                                     (loss_ax, 'loss', 'Loss')):
    ax.plot(history.history[metric_key], label=f'Train {metric_title}')
    ax.plot(history.history[f'val_{metric_key}'], label=f'Validation {metric_title}')
    ax.set_title(metric_title)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(metric_title)
    ax.legend()

fig.tight_layout()
fig.savefig('training_metrics.png')
# The figure is written to disk only; close it to release its memory.
plt.close(fig)

# Confusion matrix
# Predictions and labels are gathered in the same pass over the dataset so
# they stay aligned whatever order the batches arrive in.
y_pred = []
y_true = []
for images, batch_labels in test_dataset:
    # predict_on_batch avoids the per-call setup cost of predict(), which is
    # meant for whole datasets rather than single batches inside a loop.
    batch_probs = model.predict_on_batch(images)
    y_pred.extend(np.argmax(batch_probs, axis=1))
    y_true.extend(batch_labels.numpy())
y_pred = np.array(y_pred)
y_true = np.array(y_true)

# Passing labels= forces one row/column per class, so the matrix always lines
# up with the class_names tick labels even when a class is missing from both
# the true and the predicted labels.
cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix - Xception')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.savefig('confusion_matrix.png')
plt.close()

# Evaluate model
test_loss, test_acc = model.evaluate(test_dataset, verbose=0)
print(f"Xception Test Accuracy (after fine-tuning): {test_acc:.4f}")

# Save model
# The final-epoch model gets its own file. 'best_xception.h5' is written by the
# ModelCheckpoint callback (save_best_only=True); saving here under that name
# would replace the best checkpoint with whatever the last epoch produced.
model.save('final_xception.h5')
print("Xception model saved as final_xception.h5")

Detected classes: ['Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot', 'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot', 'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus', 'Tomato_healthy']
Total number of classes: 10
Class 'Tomato_Bacterial_spot': 2127 images
Class 'Tomato_Early_blight': 1000 images
Class 'Tomato_Late_blight': 1909 images
Class 'Tomato_Leaf_Mold': 952 images
Class 'Tomato_Septoria_leaf_spot': 1771 images
Class 'Tomato_Spider_mites_Two_spotted_spider_mite': 1676 images
Class 'Tomato__Target_Spot': 1404 images
Class 'Tomato__Tomato_YellowLeaf__Curl_Virus': 3209 images
Class 'Tomato__Tomato_mosaic_virus': 373 images
Class 'Tomato_healthy': 1591 images
Total images in dataset: 16012
Class weights: {0: 0.752797367183827, 1: 1.6012, 2: 0.8387637506547931, 3: 1.6819327731092437, 4: 0.9041219649915302, 5: 0.9553699284009547, 6: 1.1404558404558405, 7: 0.4989716422561546,

KeyboardInterrupt: 