In [19]:
# =========================================================
# Install required packages (run in Colab / notebook)
# =========================================================
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install transformers timm
!pip install torch-geometric -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
!pip install scikit-learn

# =========================================================
# TensorFlow imports (updated for newer versions)
# =========================================================
import tensorflow as tf
from tensorflow.keras.utils import image_dataset_from_directory
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Dense, Flatten, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.losses import CategoricalCrossentropy

# Alternative imports for newer TensorFlow versions
try:
    from tensorflow.keras.preprocessing.image import ImageDataGenerator
except ImportError:
    from tensorflow.keras.utils import ImageDataGenerator

try:
    from tensorflow.keras.optimizers import Adam
except ImportError:
    from tensorflow.keras.optimizers.legacy import Adam

# sklearn
from sklearn.metrics import (classification_report, confusion_matrix, roc_auc_score,
                           matthews_corrcoef, roc_curve, auc, accuracy_score,
                           precision_score, recall_score, f1_score)
from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors
from sklearn.preprocessing import StandardScaler, label_binarize
from sklearn.utils.class_weight import compute_class_weight

# Replace vit-keras with Hugging Face TF ViT
from transformers import TFViTModel, ViTConfig

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

# Alternative loss function import
try:
    from tensorflow.keras.losses import CategoricalFocalCrossentropy
except ImportError:
    CategoricalFocalCrossentropy = None

# Check TensorFlow version and available devices
print('TensorFlow version:', tf.__version__)
print('GPU Available:', tf.config.list_physical_devices('GPU'))

# Standard python imports used later
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import random
import time
import gc

# Mount Google Drive (Colab)
from google.colab import drive
drive.mount('/content/drive')
train_path = '/content/drive/MyDrive/Colab Notebooks/New Plant Diseases Dataset(Augmented)/train'
test_path = '/content/drive/MyDrive/Colab Notebooks/New Plant Diseases Dataset(Augmented)/valid'

if not os.path.exists(train_path):
    print(f"Warning: Training path {train_path} does not exist.")
    print("Please update the paths or upload your dataset to Google Colab.")

if not os.path.exists(test_path):
    print(f"Warning: Test path {test_path} does not exist.")
    print("Please update the paths or upload your dataset to Google Colab.")

def display_sample_images(path):
    if os.path.exists(path):
        disease_list = [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]
        if len(disease_list) > 0:
            fig, axes = plt.subplots(2, 5, figsize=(15, 8))
            fig.suptitle('Sample Images of Leaf Diseases', fontsize=16)
            for i, disease in enumerate(disease_list[:10]):
                disease_path = os.path.join(path, disease)
                if os.path.exists(disease_path):
                    images = [f for f in os.listdir(disease_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
                    if images:
                        sample_image = random.choice(images)
                        try:
                            img = tf.keras.utils.load_img(os.path.join(disease_path, sample_image))
                        except AttributeError:
                            img = tf.keras.preprocessing.image.load_img(os.path.join(disease_path, sample_image))
                        row, col = i // 5, i % 5
                        if row < 2 and col < 5:
                            axes[row, col].imshow(img)
                            axes[row, col].set_title(disease, fontsize=12)
                            axes[row, col].axis('off')
            for i in range(len(disease_list), 10):
                row, col = i // 5, i % 5
                if row < 2 and col < 5:
                    axes[row, col].axis('off')

            plt.tight_layout()
            plt.show()
    else:
        print(f"Path {path} not found. Skipping image display.")

# Display sample images
display_sample_images(train_path)

# Image Data Generators
train_datagen = ImageDataGenerator(
    rescale=1.0/255.0,
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    brightness_range=[0.8, 1.2],
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode='nearest'
)

test_datagen = ImageDataGenerator(rescale=1.0/255.0)

# Load datasets
def create_data_generators():
    try:
        train_set = train_datagen.flow_from_directory(
            train_path,
            target_size=(224, 224),
            batch_size=32,
            class_mode='categorical'
        )

        test_set = test_datagen.flow_from_directory(
            test_path,
            target_size=(224, 224),
            batch_size=32,
            class_mode='categorical'
        )

        return train_set, test_set
    except Exception as e:
        print(f"Error loading datasets: {e}")
        print("Creating dummy datasets for demonstration...")
        # Create dummy data for testing
        return None, None

train_set, test_set = create_data_generators()

def display_class_distribution(train_set):
    if train_set is not None:
        # Get class names and counts
        disease_classes = list(train_set.class_indices.keys())
        class_counts = []

        for class_name in disease_classes:
            class_dir = os.path.join(train_path, class_name)
            if os.path.exists(class_dir):
                num_images = len([f for f in os.listdir(class_dir)
                                if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
                class_counts.append(num_images)
            else:
                class_counts.append(0)

        plt.figure(figsize=(12, 6))
        bars = plt.bar(range(len(disease_classes)), class_counts)
        plt.title('Class Distribution in Training Dataset')
        plt.xlabel('Disease Classes')
        plt.ylabel('Image Count')
        plt.xticks(range(len(disease_classes)), disease_classes, rotation=45, ha='right')

        # Add value labels on bars
        for bar, count in zip(bars, class_counts):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                    str(count), ha='center', va='bottom')

        plt.tight_layout()
        plt.show()
    else:
        print("No training data available for class distribution plot.")

display_class_distribution(train_set)

# Function definitions
def predict_with_tta(model, image, n_aug=5):
    """Test Time Augmentation for better predictions"""
    if len(image.shape) == 3:
        image = np.expand_dims(image, axis=0)
    elif len(image.shape) == 4 and image.shape[0] == 1:
        # Already has batch dimension
        pass
    else:
        raise ValueError(f"Invalid image shape: {image.shape}")

    aug = ImageDataGenerator(
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.1,
        horizontal_flip=True
    )

    preds = []
    try:
        for _ in range(n_aug):
            # Create augmented version
            aug_img = aug.random_transform(image[0])
            aug_img = np.expand_dims(aug_img, axis=0)
            preds.append(model.predict(aug_img, verbose=0))

        return np.mean(preds, axis=0)
    except Exception as e:
        print(f"TTA failed, using original prediction: {e}")
        return model.predict(image, verbose=0)

def get_features(generator, feature_extractor):
    """Extract features from data generator"""
    features = []
    labels = []

    print("Extracting features...")
    try:
        for i in range(len(generator)):
            x, y = generator[i]
            batch_features = feature_extractor.predict(x, verbose=0)
            features.append(batch_features)
            labels.append(y)

            if i % 10 == 0:
                print(f"Processed batch {i+1}/{len(generator)}")

            # Clear memory periodically
            if i % 50 == 0:
                gc.collect()

        return np.vstack(features), np.vstack(labels)
    except Exception as e:
        print(f"Error during feature extraction: {e}")
        if features and labels:
            print("Returning partial features...")
            return np.vstack(features), np.vstack(labels)
        else:
            raise e

class SimpleGNN(torch.nn.Module):
    def __init__(self, num_features, hidden_dim, num_classes):
        super(SimpleGNN, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, num_classes)
        self.dropout = 0.5

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training, p=self.dropout)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

def create_graph(features, k=5):
    """Create k-NN graph from features"""
    nbrs = NearestNeighbors(n_neighbors=k+1).fit(features)
    distances, indices = nbrs.kneighbors(features)

    # Remove self-connections
    indices = indices[:, 1:]

    sources = np.repeat(np.arange(len(features)), k)
    targets = indices.flatten()
    edge_index = torch.tensor(np.vstack([sources, targets]), dtype=torch.long)

    return edge_index

# =========================================================
# Vision Transformer Model (Hugging Face TFViT replacement)
# =========================================================
def create_vit_model(num_classes):
    """Create and compile ViT model using Hugging Face TFViTModel
       This replaces vit.vit_b16 from vit-keras to avoid tensorflow-addons dependency.
    """
    # Load pretrained TF ViT backbone
    config = ViTConfig.from_pretrained("google/vit-base-patch16-224")
    # set config.num_labels though we will add our own head
    config.num_labels = num_classes
    vit_backbone = TFViTModel.from_pretrained("google/vit-base-patch16-224", from_pt=False, config=config)

    # Freeze ViT backbone for transfer learning (same intent as original)
    vit_backbone.trainable = False

    # custom classification head (matches original head)
    inputs = Input(shape=(224, 224, 3))
    # Note: Your ImageDataGenerator already rescales images to [0,1], so pass inputs directly.
    # If you prefer to match HF preprocess exactly, integrate ViTImageProcessor on raw images before model input.
    vit_outputs = vit_backbone(inputs)  # TFViTModel returns a TFBaseModelOutput-like object
    # Use CLS token embedding (position 0)
    try:
        x = vit_outputs.last_hidden_state[:, 0, :]  # [batch, hidden_dim]
    except Exception:
        # fallback if the output shape differs; try to flatten the backbone output
        x = tf.keras.layers.Flatten()(vit_outputs[0])

    x = Dense(512, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(0.4)(x)
    x = Dense(256, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(0.3)(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs, outputs)

    model.compile(
        optimizer=Adam(learning_rate=0.0001),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    return model, vit_backbone

# Model Training and Evaluation Functions
def evaluate_model_corrected(model, test_generator, model_name, is_hybrid=False):
    """Evaluate model performance"""
    if test_generator is None:
        print(f"Cannot evaluate {model_name}: No test data available")
        return {}

    print(f"\nEvaluating {model_name}...")

    # Get predictions for all test data
    all_predictions = []
    all_true_labels = []

    for i in range(len(test_generator)):
        test_images, test_labels = test_generator[i]

        if is_hybrid:
            batch_preds = []
            for img in test_images:
                pred = model(np.expand_dims(img, axis=0))
                batch_preds.append(pred)
            predictions = np.vstack(batch_preds)
        else:
            predictions = model.predict(test_images, verbose=0)

        all_predictions.append(predictions)
        all_true_labels.append(test_labels)

    all_predictions = np.vstack(all_predictions)
    all_true_labels = np.vstack(all_true_labels)

    predicted_classes = np.argmax(all_predictions, axis=1)
    true_classes = np.argmax(all_true_labels, axis=1)

    # Classification Report
    class_names = list(test_generator.class_indices.keys())
    print(f'\n{model_name} Classification Report:')
    print(classification_report(true_classes, predicted_classes, target_names=class_names))

    # Confusion Matrix
    conf_matrix = confusion_matrix(true_classes, predicted_classes)
    plt.figure(figsize=(10, 8))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title(f'{model_name} - Confusion Matrix')
    plt.xlabel('Predicted Classes')
    plt.ylabel('True Classes')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

    # Performance Metrics
    metrics = {
        'Accuracy': accuracy_score(true_classes, predicted_classes),
        'Precision': precision_score(true_classes, predicted_classes, average='weighted', zero_division=0),
        'Recall': recall_score(true_classes, predicted_classes, average='weighted', zero_division=0),
        'F1 Score': f1_score(true_classes, predicted_classes, average='weighted', zero_division=0),
        'MCC': matthews_corrcoef(true_classes, predicted_classes)
    }

    print(f"\n{model_name} Performance Metrics:")
    for metric, value in metrics.items():
        print(f"{metric}: {value:.4f}")

    return metrics

def main():
    global vit_classifier, feature_extractor, scaler, knn, gnn_model, edge_index

    if train_set is None or test_set is None:
        print("Cannot proceed without dataset. Please check your dataset paths.")
        return

    num_classes = len(train_set.class_indices)
    print(f"Number of classes: {num_classes}")

    try:
        vit_classifier = load_model('vit_leaf_classifier.keras')
        print("Loaded pretrained Vision Transformer model.")
    except Exception as e:
        print(f"Could not load existing model: {e}")
        print("Creating new Vision Transformer model...")
        vit_classifier, vit_base = create_vit_model(num_classes)

        # Training callbacks
        reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=1e-6)
        early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

        # Train the model
        print("Training Vision Transformer...")
        start_time = time.time()

        vit_history = vit_classifier.fit(
            train_set,
            validation_data=test_set,
            epochs=100,
            steps_per_epoch=len(train_set),
            validation_steps=len(test_set),
            callbacks=[reduce_lr, early_stopping],
            verbose=1
        )

        training_time = time.time() - start_time
        print(f"Training completed in {training_time:.2f} seconds")

        # Save model
        vit_classifier.save('vit_leaf_classifier.keras')
        print("Saved trained Vision Transformer model.")

        # Plot training history
        plt.figure(figsize=(12, 5))
        plt.subplot(1, 2, 1)
        plt.plot(vit_history.history['accuracy'], label='Train Accuracy')
        plt.plot(vit_history.history['val_accuracy'], label='Validation Accuracy')
        plt.title('Accuracy Over Epochs')
        plt.xlabel('Epochs')
        plt.ylabel('Accuracy')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(vit_history.history['loss'], label='Train Loss')
        plt.plot(vit_history.history['val_loss'], label='Validation Loss')
        plt.title('Loss Over Epochs')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()

        plt.tight_layout()
        plt.show()

    print("Setting up feature extractor...")
    feature_extractor = Model(
        inputs=vit_classifier.input,
        outputs=vit_classifier.layers[-3].output
    )

    # Extract features
    train_features, train_labels = get_features(train_set, feature_extractor)
    test_features, test_labels = get_features(test_set, feature_extractor)

    # Scale features
    print("Scaling features...")
    scaler = StandardScaler()
    train_features_scaled = scaler.fit_transform(train_features)
    test_features_scaled = scaler.transform(test_features)

    # Train KNN
    print("Training KNN classifier...")
    knn = KNeighborsClassifier(n_neighbors=5, weights='distance')
    knn.fit(train_features_scaled, np.argmax(train_labels, axis=1))

    # Create and train GNN
    print("Creating and training GNN...")
    edge_index = create_graph(train_features_scaled, k=5)
    gnn_model = SimpleGNN(
        num_features=train_features_scaled.shape[1],
        hidden_dim=64,
        num_classes=num_classes
    )

    optimizer = torch.optim.Adam(gnn_model.parameters(), lr=0.01)

    # Train GNN
    gnn_model.train()
    for epoch in range(100):
        optimizer.zero_grad()
        out = gnn_model(torch.FloatTensor(train_features_scaled), edge_index)
        loss = F.nll_loss(out, torch.LongTensor(np.argmax(train_labels, axis=1)))
        loss.backward()
        optimizer.step()

        if epoch % 20 == 0:
            print(f'GNN Epoch {epoch}, Loss: {loss.item():.4f}')

    #hybrid prediction function
    def hybrid_predict(image):
        """Hybrid prediction combining ViT, KNN, and GNN"""
        if len(image.shape) == 3:
            image = np.expand_dims(image, axis=0)

        # ViT prediction
        vit_pred = predict_with_tta(vit_classifier, image)

        # Feature extraction
        features = feature_extractor.predict(image, verbose=0)
        features_scaled = scaler.transform(features)

        # KNN prediction
        knn_pred = knn.predict_proba(features_scaled)

        # GNN prediction
        with torch.no_grad():
            gnn_model.eval()
            single_edge_index = torch.tensor([[0], [0]], dtype=torch.long)
            gnn_out = gnn_model(torch.FloatTensor(features_scaled), single_edge_index)
            gnn_pred = torch.exp(gnn_out).numpy()

        vit_conf = np.max(vit_pred)
        knn_conf = np.max(knn_pred)
        gnn_conf = np.max(gnn_pred)

        # Weighted average
        final_pred = 0.7 * vit_pred + 0.2 * knn_pred + 0.1 * gnn_pred
        return final_pred

    # Evaluate models
    print("\n" + "="*50)
    print("MODEL EVALUATION")
    print("="*50)

    # Evaluate ViT model
    vit_metrics = evaluate_model_corrected(vit_classifier, test_set, "Vision Transformer")

    # Evaluate Hybrid model
    hybrid_metrics = evaluate_model_corrected(hybrid_predict, test_set, "Hybrid ViT+KNN+GNN", is_hybrid=True)

    # Compare results
    if vit_metrics and hybrid_metrics:
        print("\n" + "="*50)
        print("MODEL COMPARISON")
        print("="*50)
        print(f"{'Metric':<15} {'ViT':<10} {'Hybrid':<10}")
        print("-" * 35)
        for metric in vit_metrics.keys():
            print(f"{metric:<15} {vit_metrics[metric]:<10.4f} {hybrid_metrics[metric]:<10.4f}")

if __name__ == "__main__":
    try:
        import tensorflow as tf
        if hasattr(tf.keras.mixed_precision, 'set_global_policy'):
            tf.keras.mixed_precision.set_global_policy('mixed_float16')
            print("Mixed precision enabled")
        else:
            print("Mixed precision not available in this TensorFlow version")
    except Exception as e:
        print(f"Mixed precision setup failed: {e}")

    main()


Looking in indexes: https://download.pytorch.org/whl/cu121
Looking in links: https://data.pyg.org/whl/torch-2.1.0+cu121.html
TensorFlow version: 2.19.0
GPU Available: []


MessageError: Error: credential propagation was unsuccessful