In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import os
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Flatten, MaxPooling2D, BatchNormalization, Dropout, Input, Lambda, Conv2D, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications import MobileNetV2
import tensorflow.keras.backend as K
import random
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
from skimage.segmentation import slic
from skimage.util import img_as_float
from torch_geometric.data import Data, DataLoader
import time
import warnings
warnings.filterwarnings('ignore')

# Global configuration shared by every model in this notebook.
# Each split directory is expected to contain one sub-directory per class.
train_dir = 'Dataset/pest/train'
test_dir = 'Dataset/pest/test'

# Square input resolution used for every image, and the Keras generator batch size.
IMG_SIZE = (224, 224)
BATCH_SIZE = 32

# Common data loading for TensorFlow models
def load_tf_data(shuffle=True):
    """Create the Keras directory iterators for training, validation and test.

    Training images are augmented (rotation, shifts, shear, zoom, flips) and
    rescaled to [0, 1]. Validation and test images are only rescaled, so the
    metrics that drive early stopping and LR scheduling are measured on
    undistorted images.

    Args:
        shuffle: Passed to all three ``flow_from_directory`` calls.

    Returns:
        Tuple ``(training, validing, testing, num_classes, class_labels)``
        where the first three are directory iterators, ``num_classes`` is the
        number of class sub-directories found under ``train_dir`` and
        ``class_labels`` lists their names in index order.
    """
    validation_split = 0.1
    train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        zca_epsilon=1e-06,
        rotation_range=30,
        width_shift_range=0.1,
        height_shift_range=0.2,
        shear_range=20,
        zoom_range=0.8,
        fill_mode="nearest",
        horizontal_flip=True,
        vertical_flip=True,
        validation_split=validation_split,
        rescale=1./255
    )
    # Same validation_split as the training generator so both generators
    # partition the files identically, but without any augmentation.
    valid_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        validation_split=validation_split,
        rescale=1./255
    )
    test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
    training = train_datagen.flow_from_directory(
        train_dir,
        batch_size=BATCH_SIZE,
        target_size=IMG_SIZE,
        subset="training",
        shuffle=shuffle
    )
    validing = valid_datagen.flow_from_directory(
        train_dir,
        batch_size=BATCH_SIZE,
        target_size=IMG_SIZE,
        subset='validation',
        shuffle=shuffle
    )
    testing = test_datagen.flow_from_directory(
        test_dir,
        batch_size=BATCH_SIZE,
        target_size=IMG_SIZE,
        shuffle=shuffle
    )
    num_classes = len(training.class_indices)
    class_labels = list(training.class_indices.keys())
    return training, validing, testing, num_classes, class_labels

# Utility function to plot and save training history
def plot_training_history(history, model_name):
    """Save a two-panel accuracy/loss figure for a Keras training run.

    Args:
        history: Object whose ``history`` mapping holds the 'accuracy',
            'val_accuracy', 'loss' and 'val_loss' series.
        model_name: Used in the panel titles and as the prefix of the output
            file ``<model_name>_training_history.png``.
    """
    curves = history.history
    panels = (
        ('accuracy', 'val_accuracy', 'Accuracy'),
        ('loss', 'val_loss', 'Loss'),
    )

    plt.figure(figsize=(12, 5))
    for column, (train_key, val_key, pretty) in enumerate(panels, start=1):
        plt.subplot(1, 2, column)
        plt.plot(curves[train_key])
        plt.plot(curves[val_key])
        plt.title(f'{model_name} - Model {pretty}')
        plt.ylabel(pretty)
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'], loc='upper left')

    plt.tight_layout()
    plt.savefig(f'{model_name}_training_history.png')
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close()

# Utility function to evaluate and visualize predictions
def visualize_prediction(model, img_path, class_labels, model_name, is_tf_model=True):
    """Save the test image and, for Keras models, a bar chart of class probabilities.

    Args:
        model: Model exposing ``predict``; only used when ``is_tf_model`` is true.
        img_path: Path of the image to classify.
        class_labels: Class names, indexed by the model's output position.
        model_name: Prefix for the saved PNG files and the printed messages.
        is_tf_model: When false, only the test image is saved and a note is
            printed, because PyTorch models handle their own visualisation.

    Returns:
        None. Returns early with a warning when the file is missing or cannot
        be decoded as an image.
    """
    if not os.path.exists(img_path):
        print(f"Warning: Test image path {img_path} not found. Skipping prediction.")
        return
    # Load and display the test image
    test_img = cv2.imread(img_path)
    if test_img is None:
        # cv2.imread signals an undecodable file by returning None instead of raising.
        print(f"Warning: Test image {img_path} could not be read. Skipping prediction.")
        return
    test_img = cv2.resize(test_img, IMG_SIZE)
    plt.figure(figsize=(6, 6))
    plt.imshow(cv2.cvtColor(test_img, cv2.COLOR_BGR2RGB))
    plt.title("Test Image")
    plt.axis('off')
    plt.savefig(f'{model_name}_test_image.png')
    plt.close()
    if is_tf_model:
        # For TensorFlow models: same [0, 1] scaling as the data generators.
        img_array = image.img_to_array(image.load_img(img_path, target_size=IMG_SIZE))
        img_array = np.expand_dims(img_array, axis=0) / 255.0
        prediction = model.predict(img_array)
        predicted_class = np.argmax(prediction[0])
        predicted_label = class_labels[predicted_class]
        # Plot prediction probabilities
        plt.figure(figsize=(10, 5))
        plt.bar(class_labels, prediction[0])
        plt.xlabel("Class")
        plt.ylabel("Probability")
        plt.title(f"{model_name} - Prediction: {predicted_label}")
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.savefig(f'{model_name}_prediction.png')
        plt.close()
        print(f"{model_name} Predicted Class: {predicted_label}")
    else:
        # For PyTorch models (handled differently by each model's predict function)
        print(f"{model_name} prediction visualization is handled by the model's own function")

# ----- MODEL 1: MobileNetV2 Transfer Learning -----
def build_mobilenetv2_model(num_classes):
    """Build a frozen ImageNet MobileNetV2 with a small trainable classifier head.

    The model takes images scaled to [0, 1] (what the data generators and
    ``visualize_prediction`` produce) and remaps them internally to [-1, 1],
    the range the ImageNet MobileNetV2 weights were trained on.

    Args:
        num_classes: Size of the softmax output layer.

    Returns:
        A compiled ``Sequential`` model (Adam, lr 0.001, categorical
        cross-entropy, accuracy metric).
    """
    input_shape = (IMG_SIZE[0], IMG_SIZE[1], 3)
    base_model = MobileNetV2(
        weights='imagenet',
        include_top=False,
        input_shape=input_shape
    )

    # Freeze the whole backbone; only the head below is trained.
    base_model.trainable = False

    model = Sequential([
        Input(shape=input_shape),
        # [0, 1] -> [-1, 1], equivalent to mobilenet_v2.preprocess_input on 0-255 data.
        tf.keras.layers.Rescaling(scale=2.0, offset=-1.0),
        base_model,
        GlobalAveragePooling2D(),
        Dense(512, activation='relu'),
        BatchNormalization(),
        Dropout(0.5),
        Dense(128, activation='relu'),
        BatchNormalization(),
        Dropout(0.3),
        Dense(num_classes, activation='softmax')
    ])

    model.compile(
        optimizer=Adam(0.001),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    return model

def train_mobilenetv2(training, validing, testing, num_classes):
    """Train the MobileNetV2 transfer-learning model and evaluate it on the test set.

    Args:
        training: Training data iterator passed to ``model.fit``.
        validing: Validation data iterator.
        testing: Test data iterator passed to ``model.evaluate``.
        num_classes: Number of output classes.

    Returns:
        Tuple ``(model, test_acc, training_time)`` where ``test_acc`` is the
        accuracy value returned by ``model.evaluate`` and ``training_time`` is
        the wall-clock duration of ``fit`` in seconds.

    Side effects:
        Writes ``mobilenetv2_model.h5`` and the training-history PNG.
    """
    print("\n----- Training MobileNetV2 Transfer Learning Model -----")
    model = build_mobilenetv2_model(num_classes)

    early_stop_patience = 100
    # The LR scheduler has to react well before early stopping does; with a
    # larger patience than EarlyStopping it could never fire.
    lr_patience = early_stop_patience // 4
    callbacks = [
        EarlyStopping(monitor='val_loss', patience=early_stop_patience,
                      restore_best_weights=True, verbose=1),
        ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=lr_patience,
                          min_lr=1e-6, verbose=1)
    ]

    start_time = time.time()
    history = model.fit(
        training,
        validation_data=validing,
        epochs=500,
        callbacks=callbacks,
        verbose=1
    )
    training_time = time.time() - start_time

    # Evaluate on test data
    test_loss, test_acc = model.evaluate(testing, verbose=0)
    print(f"MobileNetV2 Test Accuracy: {test_acc * 100:.2f}%")
    print(f"Training time: {training_time:.2f} seconds")

    # Save model and plot training history
    model.save("mobilenetv2_model.h5")
    plot_training_history(history, "MobileNetV2")

    return model, test_acc, training_time

# ----- MODEL 2: Siamese Network -----
# Euclidean distance function
def euclidean_distance(vectors):
    """Return the row-wise Euclidean distance between two embedding batches.

    Args:
        vectors: A pair ``(x, y)`` of tensors with matching shapes.

    Returns:
        Tensor of distances reduced over axis 1 with that axis kept.
    """
    left, right = vectors
    squared_diff = K.square(left - right)
    squared_norm = K.sum(squared_diff, axis=1, keepdims=True)
    # Clamp to epsilon so the square root never receives an exact zero.
    clamped = K.maximum(squared_norm, K.epsilon())
    return K.sqrt(clamped)

def build_siamese_base():
    """Build the shared CNN encoder used by both branches of the Siamese model.

    Three Conv2D + MaxPooling stages (32, 64, 128 filters) are followed by a
    flattened 128-unit dense embedding with batch normalisation and dropout.

    Returns:
        An uncompiled ``Sequential`` model taking ``IMG_SIZE`` RGB images.
    """
    image_shape = (IMG_SIZE[0], IMG_SIZE[1], 3)

    encoder = Sequential()
    encoder.add(Conv2D(32, (3, 3), activation='relu', input_shape=image_shape))
    encoder.add(MaxPooling2D(2, 2))
    encoder.add(Conv2D(64, (3, 3), activation='relu'))
    encoder.add(MaxPooling2D(2, 2))
    encoder.add(Conv2D(128, (3, 3), activation='relu'))
    encoder.add(MaxPooling2D(2, 2))
    encoder.add(Flatten())
    encoder.add(Dense(128, activation='relu'))
    encoder.add(BatchNormalization())
    encoder.add(Dropout(0.3))
    return encoder

def build_siamese_model():
    """Assemble and compile the two-input Siamese similarity model.

    Both inputs go through the same encoder (shared weights); the Euclidean
    distance between the two embeddings feeds a single sigmoid unit.

    Returns:
        A compiled ``Model`` (Adam, lr 0.001, binary cross-entropy, accuracy).
    """
    shared_encoder = build_siamese_base()

    image_shape = (IMG_SIZE[0], IMG_SIZE[1], 3)
    input_a = Input(shape=image_shape)
    input_b = Input(shape=image_shape)

    # Calling the same encoder object twice is what ties the weights together.
    embeddings = [shared_encoder(input_a), shared_encoder(input_b)]
    distance = Lambda(euclidean_distance)(embeddings)
    similarity = Dense(1, activation='sigmoid')(distance)

    twin_model = Model(inputs=[input_a, input_b], outputs=similarity)
    twin_model.compile(
        optimizer=Adam(0.001),
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )
    return twin_model

def create_pairs(images_by_class, num_pairs_per_class=100):
    """Sample labelled image pairs for Siamese training or evaluation.

    Args:
        images_by_class: Mapping of class name to a list of image arrays.
        num_pairs_per_class: Number of same-class pairs drawn for each class
            that has at least two images; the same number of different-class
            pairs is drawn per non-empty class.

    Returns:
        Tuple ``(pairs, labels)`` of NumPy arrays. ``pairs[i]`` holds two
        images and ``labels[i]`` is 1 for a same-class pair and 0 otherwise.
        Same-class pairs come first. Different-class pairs are only produced
        when at least two classes contain images.
    """
    pairs = []
    labels = []

    # Same class pairs (label=1)
    for class_images in images_by_class.values():
        if len(class_images) < 2:
            continue

        for _ in range(num_pairs_per_class):
            # Two distinct images of the same class
            first, second = random.sample(range(len(class_images)), 2)
            pairs.append([class_images[first], class_images[second]])
            labels.append(1)  # 1 indicates same class

    # Different class pairs (label=0). Classes without images cannot
    # contribute, and sampling two classes needs at least two candidates.
    usable_classes = [name for name, imgs in images_by_class.items() if len(imgs) > 0]
    if len(usable_classes) >= 2:
        for _ in range(num_pairs_per_class * len(usable_classes)):
            class1, class2 = random.sample(usable_classes, 2)

            img1 = random.choice(images_by_class[class1])
            img2 = random.choice(images_by_class[class2])

            pairs.append([img1, img2])
            labels.append(0)  # 0 indicates different classes

    return np.array(pairs), np.array(labels)

def load_images_from_directory(directory, max_per_class=50):
    """Load up to ``max_per_class`` images from every class sub-directory.

    Class directories and files are processed in sorted order so class
    indices and the selected files are the same on every run and platform.

    Args:
        directory: Root directory containing one sub-directory per class.
        max_per_class: Maximum number of image files read per class (limits
            memory use). ``None`` reads every matching file.

    Returns:
        Tuple ``(images_by_class, class_indices)``. ``images_by_class`` maps
        each class name to a list of RGB float32 arrays resized to
        ``IMG_SIZE`` and scaled to [0, 1]; ``class_indices`` maps each class
        name to its position in the sorted class list. Files that OpenCV
        cannot decode are skipped.
    """
    images_by_class = {}
    class_indices = {}
    image_extensions = ('.jpg', '.jpeg', '.png')

    # Get all subdirectories (classes)
    classes = sorted(
        d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))
    )

    for i, class_name in enumerate(classes):
        class_indices[class_name] = i
        class_dir = os.path.join(directory, class_name)

        # Case-insensitive match so files such as "IMG_1.JPG" are not ignored.
        image_files = sorted(
            f for f in os.listdir(class_dir) if f.lower().endswith(image_extensions)
        )[:max_per_class]

        loaded = []
        for img_file in image_files:
            img = cv2.imread(os.path.join(class_dir, img_file))
            if img is None:
                continue

            img = cv2.resize(img, IMG_SIZE)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # float32 halves the memory of the default float64 division result.
            loaded.append(img.astype(np.float32) / 255.0)

        images_by_class[class_name] = loaded

    return images_by_class, class_indices

def prepare_pairs_for_model(pairs):
    """Split a sequence of image pairs into the two input arrays of the Siamese model.

    Args:
        pairs: Iterable whose items each hold a left and a right image.

    Returns:
        List ``[left_array, right_array]`` with the first and the second
        element of every pair stacked in order.
    """
    left_branch = np.array([pair[0] for pair in pairs])
    right_branch = np.array([pair[1] for pair in pairs])
    return [left_branch, right_branch]

def classify_with_siamese(siamese_model, reference_images_by_class, test_image, threshold=0.5):
    """Classify an image by its average Siamese similarity to each class.

    The test image is paired with every reference image of a class and all
    pairs of that class are scored in a single ``predict`` call.

    Args:
        siamese_model: Model whose ``predict([left, right], verbose=0)``
            returns one similarity score per pair (1 = same class).
        reference_images_by_class: Mapping of class name to reference images.
        test_image: Image array preprocessed like the reference images.
        threshold: Minimum average similarity required to accept a class.

    Returns:
        Tuple ``(predicted_class, results)``. ``results`` maps each class
        that has reference images to its mean similarity. ``predicted_class``
        is the best-scoring class, or ``"Unknown"`` when its score does not
        exceed ``threshold`` or no reference images are available.
    """
    results = {}
    test_image = np.asarray(test_image)

    for class_name, ref_images in reference_images_by_class.items():
        # A class without references has no defined similarity; skip it
        # rather than averaging an empty list into NaN.
        if len(ref_images) == 0:
            continue

        right = np.asarray(ref_images)
        left = np.repeat(test_image[np.newaxis, ...], len(right), axis=0)

        # One batched call per class instead of one call per reference image.
        similarities = siamese_model.predict([left, right], verbose=0)
        results[class_name] = float(np.mean(similarities))

    if not results:
        return "Unknown", results

    # Find the class with highest similarity
    predicted_class = max(results, key=results.get)
    confidence = results[predicted_class]

    # Only classify if confidence is above threshold
    if confidence > threshold:
        return predicted_class, results
    return "Unknown", results

def train_siamese_network():
    """Train the Siamese similarity model on image pairs and evaluate it.

    Pairs are sampled from ``train_dir``, split 80/20 into training and
    validation pairs, and fresh pairs from ``test_dir`` are used for the
    final evaluation. A single sample image is also classified against the
    training images when it exists and can be decoded.

    Returns:
        Tuple ``(siamese_model, test_acc, training_time)`` where ``test_acc``
        is the pair-classification accuracy returned by ``model.evaluate``
        and ``training_time`` is the duration of ``fit`` in seconds.

    Side effects:
        Writes ``siamese_model.h5``, ``siamese_training_history.png`` and,
        when the sample image is classified, ``siamese_class_similarities.png``.
    """
    print("\n----- Training Siamese Network Model -----")
    # Load images for training
    train_images_by_class, class_indices = load_images_from_directory(train_dir)
    print(f"Found {len(class_indices)} classes")
    # Create pairs for training
    print("Creating image pairs for training...")
    pairs, labels = create_pairs(train_images_by_class)
    # Split into train and validation
    train_pairs, val_pairs, train_labels, val_labels = train_test_split(
        pairs, labels, test_size=0.2, random_state=42, shuffle=True
    )
    # Prepare data for model
    train_pair_data = prepare_pairs_for_model(train_pairs)
    val_pair_data = prepare_pairs_for_model(val_pairs)
    # Build and train model
    siamese_model = build_siamese_model()
    # Define callbacks. The LR scheduler must be more eager than early
    # stopping; with a larger patience it could never fire.
    early_stop_patience = 100
    lr_patience = early_stop_patience // 4
    early_stopping = EarlyStopping(monitor='val_loss', patience=early_stop_patience,
                                   restore_best_weights=True, verbose=1)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=lr_patience,
                                  min_lr=1e-6, verbose=1)
    # Train the model
    start_time = time.time()
    history = siamese_model.fit(
        train_pair_data, train_labels,
        validation_data=(val_pair_data, val_labels),
        epochs=500,
        batch_size=BATCH_SIZE,
        callbacks=[early_stopping, reduce_lr],
        verbose=1
    )
    training_time = time.time() - start_time
    # Save model
    siamese_model.save("siamese_model.h5")
    # Plot training history
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(history.history['accuracy'])
    plt.plot(history.history['val_accuracy'])
    plt.title('Siamese Network - Model Accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')
    plt.subplot(1, 2, 2)
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Siamese Network - Model Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')

    plt.tight_layout()
    plt.savefig('siamese_training_history.png')
    plt.close()

    # Evaluate on test data
    print("Loading test images for evaluation...")
    test_images_by_class, _ = load_images_from_directory(test_dir)

    # Create pairs for testing
    print("Creating image pairs for testing...")
    test_pairs, test_labels = create_pairs(test_images_by_class, num_pairs_per_class=50)
    test_pair_data = prepare_pairs_for_model(test_pairs)

    # Evaluate
    test_loss, test_acc = siamese_model.evaluate(test_pair_data, test_labels, verbose=0)
    print(f"Siamese Network Test Accuracy: {test_acc * 100:.2f}%")
    print(f"Training time: {training_time:.2f} seconds")

    # Test on a single image. cv2.imread returns None for files it cannot
    # decode, so an existing path alone is not enough.
    img_test_path = 'Dataset/pest/test/beetle/jpg_33.jpg'
    img = cv2.imread(img_test_path) if os.path.exists(img_test_path) else None
    if img is not None:
        # Same preprocessing as load_images_from_directory
        img = cv2.resize(img, IMG_SIZE)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img / 255.0

        # Classify the test image
        predicted_class, class_similarities = classify_with_siamese(
            siamese_model, train_images_by_class, img
        )

        print(f"Siamese Network Predicted Class: {predicted_class}")

        # Plot class similarities
        plt.figure(figsize=(10, 5))
        classes = list(class_similarities.keys())
        similarities = list(class_similarities.values())

        plt.bar(classes, similarities)
        plt.xlabel("Class")
        plt.ylabel("Similarity Score")
        plt.title("Siamese Network - Class Similarity Scores")
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.savefig('siamese_class_similarities.png')
        plt.close()

    return siamese_model, test_acc, training_time

# ----- MODEL 3: GNN Model -----
# Only run if PyTorch and PyTorch Geometric are available
def train_gnn_model():
    """Train and evaluate a graph-attention classifier on superpixel graphs.

    Every image is segmented into SLIC superpixels. Each superpixel becomes a
    node whose features are the per-channel mean and standard deviation of
    its pixels, and touching superpixels are connected by edges in both
    directions. A two-layer GAT with mean pooling classifies each graph.

    Returns:
        Tuple ``(model, test_acc, training_time)``. ``test_acc`` is a
        fraction in [0, 1], the same unit the Keras trainers in this file
        return, and ``training_time`` is in seconds. If anything raises, the
        error is printed and ``(None, 0, 0)`` is returned.

    Side effects:
        Writes ``best_gnn_model.pth`` and ``gnn_training_curves.png``.
    """
    try:
        print("\n----- Training Graph Neural Network Model -----")

        image_extensions = ('.jpg', '.jpeg', '.png')
        checkpoint_path = "best_gnn_model.pth"

        # 3x3 structuring element: growing a segment mask by one pixel in
        # every direction reveals which other segments touch it.
        adjacency_kernel = np.ones((3, 3), np.uint8)

        def touching_higher_nodes(segments, nodes, i):
            """Return the indices j > i of the segments touching segment ``nodes[i]``."""
            mask = (segments == nodes[i]).astype(np.uint8)
            # One dilation per segment instead of one per segment pair.
            grown = cv2.dilate(mask, adjacency_kernel, iterations=1).astype(bool)
            touched = np.unique(segments[grown])
            touched = touched[touched > nodes[i]]
            # ``nodes`` comes from np.unique and is sorted, so searchsorted
            # maps segment labels back to node indices.
            return np.searchsorted(nodes, touched).tolist()

        # Function to convert an image into a graph
        def image_to_graph(img_path, label):
            """Build a graph ``Data`` object for one image, or None on failure."""
            img = cv2.imread(img_path)
            if img is None:
                return None

            img = cv2.resize(img, IMG_SIZE)
            img = img_as_float(img)

            try:
                segments = slic(img, n_segments=100, compactness=10)

                nodes = np.unique(segments)
                features = []

                for node in nodes:
                    mask = segments == node
                    if np.sum(mask) > 0:
                        mean_color = np.mean(img[mask], axis=0)
                        std_color = np.std(img[mask], axis=0)
                        features.append(np.concatenate([mean_color, std_color]))

                features = np.array(features)

                # Create spatial edge connections (both directions)
                edges = []
                for i in range(len(nodes)):
                    for j in touching_higher_nodes(segments, nodes, i):
                        edges.append([i, j])
                        edges.append([j, i])

                if len(edges) == 0:  # Fallback if no adjacency detected
                    for i in range(len(nodes)):
                        for j in range(i + 1, min(i + 5, len(nodes))):
                            edges.append([i, j])
                            edges.append([j, i])

                if edges:
                    edges = torch.tensor(edges, dtype=torch.long).t().contiguous()
                else:
                    # Single-node graph: keep the (2, num_edges) layout.
                    edges = torch.empty((2, 0), dtype=torch.long)
                features = torch.tensor(features, dtype=torch.float)

                # Create a single label for the whole graph
                y = torch.tensor(label, dtype=torch.long)

                return Data(x=features, edge_index=edges, y=y)

            except Exception as e:
                print(f"Error processing image: {e}")
                return None

        # GNN Model
        class GNNModel(nn.Module):
            """Two GAT layers with batch norm, ReLU and dropout, mean pooling, linear head."""

            def __init__(self, input_dim, hidden_dim, output_dim, dropout_rate=0.3):
                super(GNNModel, self).__init__()
                self.conv1 = GATConv(input_dim, hidden_dim)
                self.conv2 = GATConv(hidden_dim, hidden_dim)
                self.batch_norm1 = nn.BatchNorm1d(hidden_dim)
                self.batch_norm2 = nn.BatchNorm1d(hidden_dim)
                self.fc = nn.Linear(hidden_dim, output_dim)
                self.relu = nn.ReLU()
                self.dropout = nn.Dropout(dropout_rate)

            def forward(self, data):
                x, edge_index, batch = data.x, data.edge_index, data.batch

                # A single un-batched graph has no batch vector: treat all
                # nodes as belonging to graph 0.
                if batch is None:
                    batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

                x = self.conv1(x, edge_index)
                x = self.batch_norm1(x)
                x = self.relu(x)
                x = self.dropout(x)

                x = self.conv2(x, edge_index)
                x = self.batch_norm2(x)
                x = self.relu(x)
                x = self.dropout(x)

                x = global_mean_pool(x, batch)
                x = self.fc(x)

                return x

        # Function to create dataset from directory
        def create_dataset(root_dir, class_indices):
            """Convert up to 50 images per class under ``root_dir`` into graphs."""
            dataset = []
            class_counts = {}

            for class_name, idx in class_indices.items():
                class_dir = os.path.join(root_dir, class_name)
                if not os.path.isdir(class_dir):
                    continue

                # Sorted so the 50-file selection is the same on every run.
                image_files = sorted(
                    f for f in os.listdir(class_dir) if f.lower().endswith(image_extensions)
                )

                max_images = 50
                selected_files = image_files[:max_images]
                class_counts[class_name] = len(selected_files)

                for img_file in selected_files:
                    img_path = os.path.join(class_dir, img_file)
                    graph_data = image_to_graph(img_path, idx)
                    if graph_data is not None:
                        dataset.append(graph_data)

            print(f"Dataset creation summary: {class_counts}")
            return dataset

        # Get class indices from TensorFlow's ImageDataGenerator for consistency
        datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
        temp_gen = datagen.flow_from_directory(
            train_dir,
            batch_size=1,
            target_size=IMG_SIZE,
            shuffle=True
        )
        class_indices = temp_gen.class_indices
        num_classes = len(class_indices)

        # Create datasets
        print("Creating training dataset...")
        train_dataset = create_dataset(train_dir, class_indices)
        print("Creating testing dataset...")
        test_dataset = create_dataset(test_dir, class_indices)

        if not train_dataset:
            raise ValueError("No training data could be created.")

        # Split training data into train and validation
        train_data, val_data = train_test_split(train_dataset, test_size=0.2, random_state=42, shuffle=True)

        # Create data loaders
        train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
        val_loader = DataLoader(val_data, batch_size=8, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=8, shuffle=True)

        # Get feature dimension from data
        input_dim = train_data[0].x.shape[1] if train_data else 6

        # Create and train model
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")
        model = GNNModel(input_dim=input_dim, hidden_dim=64, output_dim=num_classes).to(device)

        # NOTE(review): lr=0.1 is 100x Adam's default; the plateau scheduler
        # below halves it over time, but confirm this start value is intended.
        optimizer = optim.Adam(model.parameters(), lr=0.1, weight_decay=1e-4)
        criterion = nn.CrossEntropyLoss()
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

        def run_epoch(loader, phase, train=False):
            """Run one pass over ``loader``.

            Returns ``(mean_loss, accuracy_percent, samples_seen)`` computed
            over the batches that completed. Batches that raise are skipped
            (best effort) but reported once per pass instead of being
            swallowed silently.
            """
            if train:
                model.train()
            else:
                model.eval()

            loss_sum = 0.0
            ok_batches = 0
            correct = 0
            seen = 0
            skipped = 0
            last_error = None

            for batch in loader:
                batch = batch.to(device)
                try:
                    if train:
                        optimizer.zero_grad()
                    output = model(batch)
                    loss = criterion(output, batch.y)
                    if train:
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                        optimizer.step()
                except Exception as e:
                    skipped += 1
                    last_error = e
                    continue

                loss_sum += loss.item()
                ok_batches += 1
                _, predicted = output.max(dim=1)
                seen += batch.y.size(0)
                correct += predicted.eq(batch.y).sum().item()

            if skipped:
                print(f"Warning: skipped {skipped} {phase} batch(es); last error: {last_error}")

            if seen == 0:
                return 0.0, 0.0, 0
            # Average over the batches that actually contributed a loss.
            return loss_sum / ok_batches, 100.0 * correct / seen, seen

        # Training loop with early stopping
        best_val_loss = float('inf')
        patience = 100
        counter = 0
        epochs = 500
        checkpoint_saved = False

        train_losses = []
        val_losses = []
        train_accs = []
        val_accs = []

        start_time = time.time()

        for epoch in range(epochs):
            # Training
            train_loss, train_acc, train_seen = run_epoch(train_loader, "training", train=True)
            if train_seen == 0:
                continue
            train_losses.append(train_loss)
            train_accs.append(train_acc)

            # Validation
            with torch.no_grad():
                val_loss, val_acc, val_seen = run_epoch(val_loader, "validation")
            if val_seen == 0:
                print("Warning: No valid batches in validation epoch")
                continue
            val_losses.append(val_loss)
            val_accs.append(val_acc)

            scheduler.step(val_loss)

            print(f"Epoch {epoch+1}/{epochs}, "
                  f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                counter = 0
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_saved = True
            else:
                counter += 1
                if counter >= patience:
                    print("Early stopping triggered!")
                    break

        training_time = time.time() - start_time

        # Plot training curves
        if train_losses and val_losses:
            plt.figure(figsize=(12, 5))
            plt.subplot(1, 2, 1)
            plt.plot(train_losses, label='Train Loss')
            plt.plot(val_losses, label='Validation Loss')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.legend()

            plt.subplot(1, 2, 2)
            plt.plot(train_accs, label='Train Accuracy')
            plt.plot(val_accs, label='Validation Accuracy')
            plt.xlabel('Epoch')
            plt.ylabel('Accuracy (%)')
            plt.legend()
            plt.savefig('gnn_training_curves.png')
            plt.close()

        # Evaluate best model on test set. Only load the checkpoint if this
        # run wrote it; a file left by an earlier run may not match the model.
        if checkpoint_saved:
            model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        else:
            print("Warning: no checkpoint was saved during training; evaluating the final weights")

        with torch.no_grad():
            _, test_acc_percent, test_seen = run_epoch(test_loader, "test")

        if test_seen > 0:
            # Returned as a fraction: the caller multiplies accuracies by 100.
            test_acc = test_acc_percent / 100.0
            print(f"GNN Test Accuracy: {test_acc_percent:.2f}%")
        else:
            test_acc = 0
            print("Warning: No valid batches in test evaluation")

        print(f"Training time: {training_time:.2f} seconds")

        return model, test_acc, training_time
    except Exception as e:
        print(f"Could not train GNN model: {e}")
        return None, 0, 0


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class PatchExtractor(tf.keras.layers.Layer):
    """Split a batch of images into non-overlapping square patches.

    The output has shape ``(batch, num_patches, patch_values)`` where each
    row is one flattened ``patch_size`` x ``patch_size`` patch.

    Args:
        patch_size: Side length of each patch in pixels; also used as stride.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...),
            needed so the layer can be rebuilt from its saved config.
    """

    def __init__(self, patch_size, **kwargs):
        super(PatchExtractor, self).__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID"
        )
        # Flatten the patch grid into a sequence of patch vectors.
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

    def get_config(self):
        """Return the layer config so models containing it can be saved and reloaded."""
        config = super(PatchExtractor, self).get_config()
        config.update({"patch_size": self.patch_size})
        return config

# ----- MODEL 4: Vision Transformer -----
class _PatchPositionEmbedding(tf.keras.layers.Layer):
    """Add a learned position embedding to a sequence of projected patches.

    The embedding table lives inside this layer so that its weights belong to
    the model and are updated during training.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super(_PatchPositionEmbedding, self).__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.position_embedding = tf.keras.layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, projected_patches):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        # (num_patches, dim) broadcasts over the batch axis.
        return projected_patches + self.position_embedding(positions)

    def get_config(self):
        config = super(_PatchPositionEmbedding, self).get_config()
        config.update({
            "num_patches": self.num_patches,
            "projection_dim": self.projection_dim,
        })
        return config


def build_vit_model(num_classes, image_size=224, patch_size=16, num_heads=8,
                   transformer_layers=8, hidden_units=64, mlp_units=128):
    """Build and compile a small Vision Transformer classifier.

    Args:
        num_classes: Size of the softmax output layer.
        image_size: Side length of the square RGB input images.
        patch_size: Side length of each image patch.
        num_heads: Attention heads per transformer block.
        transformer_layers: Number of transformer blocks.
        hidden_units: Patch projection (token) dimension.
        mlp_units: Hidden width of the per-block MLP.

    Returns:
        A compiled functional ``tf.keras.Model`` (Adam, categorical
        cross-entropy, accuracy metric).
    """
    input_shape = (image_size, image_size, 3)
    num_patches = (image_size // patch_size) ** 2
    projection_dim = hidden_units
    # Input layer
    inputs = tf.keras.layers.Input(shape=input_shape)
    # Create patches and project them
    patches = PatchExtractor(patch_size)(inputs)
    patch_projection = tf.keras.layers.Dense(projection_dim)(patches)
    # Add positional embeddings through a layer so the embedding is trainable
    encoded_patches = _PatchPositionEmbedding(num_patches, projection_dim)(patch_projection)
    # Create transformer blocks (pre-norm: LayerNorm before attention and MLP)
    for _ in range(transformer_layers):
        # Layer normalization 1
        x1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Multi-head attention
        attention_output = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim // num_heads, dropout=0.1
        )(x1, x1)

        # Skip connection 1
        x2 = tf.keras.layers.Add()([attention_output, encoded_patches])

        # Layer normalization 2
        x3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x2)

        # MLP
        x4 = tf.keras.layers.Dense(mlp_units, activation="gelu")(x3)
        x4 = tf.keras.layers.Dropout(0.1)(x4)
        x4 = tf.keras.layers.Dense(projection_dim)(x4)
        x4 = tf.keras.layers.Dropout(0.1)(x4)

        # Skip connection 2
        encoded_patches = tf.keras.layers.Add()([x4, x2])

    # Final layer normalization and global pooling
    representation = tf.keras.layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = tf.keras.layers.GlobalAveragePooling1D()(representation)
    # Classification head
    representation = tf.keras.layers.Dropout(0.3)(representation)
    features = tf.keras.layers.Dense(256, activation="relu")(representation)
    features = tf.keras.layers.BatchNormalization()(features)
    features = tf.keras.layers.Dropout(0.3)(features)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(features)
    # Create the model
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    # Compile the model. 1e-3 is Adam's default and matches the other Keras
    # models in this file; the previous 0.1 is too large for the plateau
    # scheduler to correct before early stopping ends training.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss=tf.keras.losses.CategoricalCrossentropy(),
        metrics=["accuracy"]
    )

    return model

def train_vit_model(training, validing, testing, num_classes):
    """Train the Vision Transformer and evaluate it on the test set.

    Args:
        training: Training data iterator passed to ``model.fit``.
        validing: Validation data iterator.
        testing: Test data iterator passed to ``model.evaluate``.
        num_classes: Number of output classes.

    Returns:
        Tuple ``(model, test_acc, training_time)``; ``training_time`` is the
        duration of ``fit`` in seconds.

    Side effects:
        Writes ``vit_model.h5`` and the training-history PNG.
    """
    print("\n----- Training Vision Transformer Model -----")

    vit_settings = dict(
        num_classes=num_classes,
        image_size=IMG_SIZE[0],
        patch_size=16,
        num_heads=8,
        transformer_layers=6,
        hidden_units=64,
        mlp_units=128,
    )
    model = build_vit_model(**vit_settings)

    stop_when_stalled = EarlyStopping(
        monitor='val_loss', patience=10, restore_best_weights=True, verbose=1
    )
    shrink_lr_on_plateau = ReduceLROnPlateau(
        monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6, verbose=1
    )

    started_at = time.time()
    history = model.fit(
        training,
        validation_data=validing,
        epochs=500,
        callbacks=[stop_when_stalled, shrink_lr_on_plateau],
        verbose=1,
    )
    training_time = time.time() - started_at

    # Evaluate on test data; only the accuracy is reported.
    _, test_acc = model.evaluate(testing, verbose=0)
    print(f"Vision Transformer Test Accuracy: {test_acc * 100:.2f}%")
    print(f"Training time: {training_time:.2f} seconds")

    # Persist the model and its learning curves.
    model.save("vit_model.h5")
    plot_training_history(history, "VisionTransformer")
    return model, test_acc, training_time

# Main function to run all models and compare results
def main():
    """Train every model in turn and save a comparison chart.

    Runs MobileNetV2, the Siamese network, the GNN (skipped if it fails) and
    the Vision Transformer, prints each model's accuracy and training time,
    and writes ``model_comparison.png``. Accuracies are multiplied by 100
    for display.
    """
    # Load data
    training, validing, testing, num_classes, class_labels = load_tf_data()
    print(f"Number of classes: {num_classes}")
    print(f"Class labels: {class_labels}")

    sample_img = 'Dataset/pest/test/beetle/jpg_33.jpg'
    # Insertion order of this dict is the order of the bars in the chart.
    results = {}

    def record(label, accuracy, seconds):
        results[label] = {"accuracy": accuracy, "training_time": seconds}

    # MobileNetV2
    mobilenetv2_model, mobilenetv2_acc, mobilenetv2_time = train_mobilenetv2(
        training, validing, testing, num_classes
    )
    record("MobileNetV2", mobilenetv2_acc, mobilenetv2_time)
    if os.path.exists(sample_img):
        visualize_prediction(mobilenetv2_model, sample_img, class_labels, "MobileNetV2")

    # Siamese network
    siamese_model, siamese_acc, siamese_time = train_siamese_network()
    record("Siamese", siamese_acc, siamese_time)

    # GNN: optional, only recorded when training produced a model
    try:
        gnn_model, gnn_acc, gnn_time = train_gnn_model()
        if gnn_model is not None:
            record("GNN", gnn_acc, gnn_time)
    except Exception as e:
        print(f"Could not train GNN model: {e}")

    # Vision Transformer
    vit_model, vit_acc, vit_time = train_vit_model(training, validing, testing, num_classes)
    record("VisionTransformer", vit_acc, vit_time)
    if os.path.exists(sample_img):
        visualize_prediction(vit_model, sample_img, class_labels, "VisionTransformer")

    # Compare model performances
    print("\n----- Model Performance Comparison -----")
    for model_name, metrics in results.items():
        print(f"{model_name}: Accuracy = {metrics['accuracy'] * 100:.2f}%, Training Time = {metrics['training_time']:.2f} seconds")

    model_names = list(results.keys())
    accuracies = [results[name]["accuracy"] * 100 for name in model_names]
    training_times = [results[name]["training_time"] for name in model_names]

    # Plot comparison chart: accuracy on the left, training time on the right
    plt.figure(figsize=(12, 6))

    plt.subplot(1, 2, 1)
    plt.bar(model_names, accuracies, color='skyblue')
    plt.ylabel('Accuracy (%)')
    plt.title('Model Accuracy Comparison')
    plt.ylim([0, 100])

    plt.subplot(1, 2, 2)
    plt.bar(model_names, training_times, color='salmon')
    plt.ylabel('Training Time (seconds)')
    plt.title('Model Training Time Comparison')

    plt.tight_layout()
    plt.savefig('model_comparison.png')
    plt.close()

    print("Comparison chart saved as 'model_comparison.png'")

In [None]:

# Run the full model comparison only when this cell/script is executed
# directly, not when the definitions above are imported from elsewhere.
if __name__ == "__main__":
    main()