<a href="https://colab.research.google.com/github/Kpreya/Real-Time-crop-disease-analysis-and-prevention-recommendation/blob/main/annam.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install opencv-python
!pip install seaborn
import logging
logging.basicConfig(level=logging.INFO)

import os
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import cv2
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
import requests
import json
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class WheatPestDatasetBuilder:
    """Class to build and manage wheat pest dataset with weather integration"""

    def __init__(self, data_dir="wheat_pest_data", img_size=(224, 224)):
        self.data_dir = Path(data_dir)
        self.img_size = img_size
        self.weather_scaler = StandardScaler()
        self.label_encoder = LabelEncoder()

        # Common wheat pests in India
        self.wheat_pests = [
            'healthy',
            'wheat_rust',
            'aphids',
            'stem_borer',
            'armyworm',
            'bacterial_blight',
            'powdery_mildew',
            'leaf_spot'
        ]

        # Create directory structure
        self.setup_directories()



    def setup_directories(self):
        """Setup directory structure for the dataset"""
        directories = [
            self.data_dir / "images" / "train",
            self.data_dir / "images" / "val",
            self.data_dir / "images" / "test",
            self.data_dir / "weather",
            self.data_dir / "processed",
            self.data_dir / "models"
        ]

        for dir_path in directories:
            dir_path.mkdir(parents=True, exist_ok=True)

        # Create pest class subdirectories
        for split in ['train', 'val', 'test']:
            for pest in self.wheat_pests:
                (self.data_dir / "images" / split / pest).mkdir(exist_ok=True)

    def download_sample_dataset(self):
        """Download and organize sample wheat pest images"""
        logger.info("Setting up sample dataset structure...")

        # Create sample metadata
        sample_data = []
        districts = ['Haryana_Karnal', 'Punjab_Ludhiana', 'UP_Meerut', 'MP_Indore']

        for i in range(1000):  # Sample 1000 entries
            pest_class = np.random.choice(self.wheat_pests)
            district = np.random.choice(districts)
            date = datetime.now() - timedelta(days=np.random.randint(0, 365))

            sample_data.append({
                'image_id': f'wheat_{i:04d}.jpg',
                'pest_class': pest_class,
                'district': district,
                'state': district.split('_')[0],
                'location': district.split('_')[1],
                'date': date.strftime('%Y-%m-%d'),
                'severity': np.random.choice(['mild', 'moderate', 'severe']),
                'crop_stage': np.random.choice(['seedling', 'tillering', 'heading', 'maturity'])
            })

        # Save metadata
        df = pd.DataFrame(sample_data)
        df.to_csv(self.data_dir / "wheat_pest_metadata.csv", index=False)
        logger.info(f"Created metadata for {len(df)} samples")

        return df

    def fetch_weather_data(self, district, date, api_key=None):
        """Fetch weather data for given district and date"""
        # Mock weather data (replace with actual IMD API calls)
        np.random.seed(hash(f"{district}_{date}") % 2**32)

        weather_data = {
            'date': date,
            'district': district,
            'temperature_max': np.random.normal(28, 5),
            'temperature_min': np.random.normal(15, 3),
            'humidity': np.random.normal(65, 15),
            'rainfall': np.random.exponential(2),
            'wind_speed': np.random.normal(8, 3),
            'soil_moisture': np.random.normal(45, 10)
        }

        return weather_data

    def build_weather_dataset(self, metadata_df):
        """Build weather dataset corresponding to pest images"""
        weather_data = []

        logger.info("Fetching weather data...")
        for _, row in metadata_df.iterrows():
            weather = self.fetch_weather_data(row['district'], row['date'])
            weather['image_id'] = row['image_id']
            weather_data.append(weather)

        weather_df = pd.DataFrame(weather_data)
        weather_df.to_csv(self.data_dir / "weather_data.csv", index=False)

        return weather_df

    def create_synthetic_images(self, metadata_df):
        """Create synthetic wheat pest images for demonstration"""
        logger.info("Creating synthetic images...")

        for _, row in metadata_df.iterrows():
            # Create synthetic image based on pest class
            img = self.generate_synthetic_pest_image(row['pest_class'])

            # Determine split (80% train, 10% val, 10% test)
            rand_val = np.random.random()
            if rand_val < 0.8:
                split = 'train'
            elif rand_val < 0.9:
                split = 'val'
            else:
                split = 'test'

            # Save image
            img_path = self.data_dir / "images" / split / row['pest_class'] / row['image_id']
            cv2.imwrite(str(img_path), img)

    def generate_synthetic_pest_image(self, pest_class):
        """Generate synthetic pest image based on class"""
        # Create base wheat crop image
        img = np.random.randint(50, 200, (*self.img_size, 3), dtype=np.uint8)

        # Add pest-specific patterns
        if pest_class == 'wheat_rust':
            # Add rust-like orange/brown spots
            for _ in range(np.random.randint(5, 15)):
                center = (np.random.randint(0, self.img_size[0]), np.random.randint(0, self.img_size[1]))
                radius = np.random.randint(5, 20)
                cv2.circle(img, center, radius, (30, 100, 200), -1)  # Orange-brown color

        elif pest_class == 'aphids':
            # Add small green/black dots
            for _ in range(np.random.randint(10, 30)):
                center = (np.random.randint(0, self.img_size[0]), np.random.randint(0, self.img_size[1]))
                cv2.circle(img, center, 2, (50, 50, 200), -1)

        elif pest_class == 'powdery_mildew':
            # Add white powdery patches
            for _ in range(np.random.randint(3, 8)):
                center = (np.random.randint(0, self.img_size[0]), np.random.randint(0, self.img_size[1]))
                axes = (np.random.randint(15, 30), np.random.randint(10, 25))
                cv2.ellipse(img, center, axes, 0, 0, 360, (200, 200, 200), -1)

        # Add some noise and blur for realism
        img = cv2.GaussianBlur(img, (3, 3), 0)
        noise = np.random.randint(-20, 20, img.shape, dtype=np.int16)
        img = np.clip(img.astype(np.int16) + noise, 0, 255).astype(np.uint8)

        return img

class WeatherImageFusionModel:
    """MSFNet-inspired model for fusing wheat pest images with weather data"""

    def __init__(self, num_classes, img_size=(224, 224), weather_features=6):
        self.num_classes = num_classes
        self.img_size = img_size
        self.weather_features = weather_features
        self.model = None

    def build_image_branch(self):
        """Build CNN branch for image processing"""
        img_input = layers.Input(shape=(*self.img_size, 3), name='image_input')

        # CNN backbone (similar to ResNet blocks)
        x = layers.Conv2D(64, (7, 7), strides=2, padding='same')(img_input)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
        x = layers.MaxPooling2D((3, 3), strides=2, padding='same')(x)

        # Residual blocks
        for filters in [64, 128, 256, 512]:
            x = self.residual_block(x, filters)
            x = self.residual_block(x, filters)
            if filters < 512:
                x = layers.MaxPooling2D((2, 2))(x)

        # Global pooling
        x = layers.GlobalAveragePooling2D()(x)
        img_features = layers.Dense(256, activation='relu', name='img_features')(x)

        return img_input, img_features

    def residual_block(self, x, filters):
        """Residual block for CNN"""
        shortcut = x

        x = layers.Conv2D(filters, (3, 3), padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)

        x = layers.Conv2D(filters, (3, 3), padding='same')(x)
        x = layers.BatchNormalization()(x)

        # Adjust shortcut if needed
        if shortcut.shape[-1] != filters:
            shortcut = layers.Conv2D(filters, (1, 1))(shortcut)

        x = layers.Add()([x, shortcut])
        x = layers.ReLU()(x)

        return x

    def build_weather_branch(self):
        """Build MLP branch for weather data processing"""
        weather_input = layers.Input(shape=(self.weather_features,), name='weather_input')

        x = layers.Dense(128, activation='relu')(weather_input)
        x = layers.Dropout(0.3)(x)
        x = layers.Dense(64, activation='relu')(x)
        x = layers.Dropout(0.3)(x)
        weather_features = layers.Dense(64, activation='relu', name='weather_features')(x)

        return weather_input, weather_features

    def build_fusion_model(self):
        """Build complete fusion model"""
        # Build branches
        img_input, img_features = self.build_image_branch()
        weather_input, weather_features = self.build_weather_branch()

        # Fusion mechanism - Cross-modal attention
        # Attention from weather to image
        weather_att = layers.Dense(256)(weather_features)
        weather_att = layers.Softmax()(weather_att)
        img_attended = layers.Multiply()([img_features, weather_att])

        # Attention from image to weather
        img_att = layers.Dense(64)(img_features)
        img_att = layers.Softmax()(img_att)
        weather_attended = layers.Multiply()([weather_features, img_att])

        # Concatenate attended features
        fused_features = layers.Concatenate()([img_attended, weather_attended])

        # Final classification layers
        x = layers.Dense(512, activation='relu')(fused_features)
        x = layers.Dropout(0.5)(x)
        x = layers.Dense(256, activation='relu')(x)
        x = layers.Dropout(0.3)(x)

        output = layers.Dense(self.num_classes, activation='softmax', name='pest_prediction')(x)

        # Create model
        self.model = keras.Model(
            inputs=[img_input, weather_input],
            outputs=output,
            name='wheat_pest_weather_fusion'
        )

        return self.model

class WheatPestPipeline:
    """Complete pipeline for wheat pest detection with weather fusion"""

    def __init__(self, data_dir="wheat_pest_data"):
        self.data_dir = Path(data_dir)
        self.dataset_builder = WheatPestDatasetBuilder(data_dir)
        self.model_builder = None
        self.model = None

    def setup_dataset(self):
        """Setup complete dataset"""
        logger.info("Setting up wheat pest dataset...")

        # Create sample dataset
        metadata_df = self.dataset_builder.download_sample_dataset()

        # Build weather dataset
        weather_df = self.dataset_builder.build_weather_dataset(metadata_df)

        # Create synthetic images
        self.dataset_builder.create_synthetic_images(metadata_df)

        logger.info("Dataset setup complete!")
        return metadata_df, weather_df

    def load_and_preprocess_data(self):
        """Load and preprocess data for training"""
        logger.info("Loading and preprocessing data...")

        # Load metadata
        metadata_df = pd.read_csv(self.data_dir / "wheat_pest_metadata.csv")
        weather_df = pd.read_csv(self.data_dir / "weather_data.csv")

        # Merge datasets
        data_df = metadata_df.merge(weather_df, on='image_id', how='inner')

        # Prepare data generators
        train_gen, val_gen, test_gen = self.create_data_generators(data_df)

        return train_gen, val_gen, test_gen

    def create_data_generators(self, data_df):
        """Create data generators for training"""
        # Split data
        train_df = data_df[data_df['image_id'].apply(lambda x: self.get_image_split(x) == 'train')]
        val_df = data_df[data_df['image_id'].apply(lambda x: self.get_image_split(x) == 'val')]
        test_df = data_df[data_df['image_id'].apply(lambda x: self.get_image_split(x) == 'test')]

        # Create generators
        train_gen = self.create_generator(train_df, 'train', batch_size=32, shuffle=True)
        val_gen = self.create_generator(val_df, 'val', batch_size=32, shuffle=False)
        test_gen = self.create_generator(test_df, 'test', batch_size=32, shuffle=False)

        return train_gen, val_gen, test_gen

    def get_image_split(self, image_id):
        """Determine which split an image belongs to"""
        for split in ['train', 'val', 'test']:
            for pest_class in self.dataset_builder.wheat_pests:
                img_path = self.data_dir / "images" / split / pest_class / image_id
                if img_path.exists():
                    return split
        return 'train'  # Default

    def create_generator(self, df, split, batch_size=32, shuffle=True):
        """Create data generator"""
        def generator():
            indices = list(range(len(df)))
            if shuffle:
                np.random.shuffle(indices)

            for i in range(0, len(df), batch_size):
                batch_indices = indices[i:i+batch_size]
                batch_df = df.iloc[batch_indices]

                images = []
                weather_data = []
                labels = []

                for _, row in batch_df.iterrows():
                    # Load image
                    img_path = self.data_dir / "images" / split / row['pest_class'] / row['image_id']
                    if img_path.exists():
                        img = cv2.imread(str(img_path))
                        img = cv2.resize(img, self.dataset_builder.img_size)
                        img = img.astype(np.float32) / 255.0
                        images.append(img)

                        # Weather features
                        weather_features = [
                            row['temperature_max'], row['temperature_min'],
                            row['humidity'], row['rainfall'],
                            row['wind_speed'], row['soil_moisture']
                        ]
                        weather_data.append(weather_features)

                        # Label
                        labels.append(self.dataset_builder.wheat_pests.index(row['pest_class']))

                if images:  # Only yield if we have data
                    yield (
                        [np.array(images), np.array(weather_data)],
                        tf.keras.utils.to_categorical(labels, len(self.dataset_builder.wheat_pests))
                    )

        return generator

    def build_and_train_model(self, train_gen, val_gen, epochs=50):
        """Build and train the fusion model"""
        logger.info("Building fusion model...")

        # Initialize model builder
        self.model_builder = WeatherImageFusionModel(
            num_classes=len(self.dataset_builder.wheat_pests),
            img_size=self.dataset_builder.img_size
        )

        # Build model
        self.model = self.model_builder.build_fusion_model()

        # Compile model
        self.model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=0.001),
            loss='categorical_crossentropy',
            metrics=['accuracy', 'top_3_accuracy']
        )

        logger.info("Model architecture:")
        self.model.summary()

        # Callbacks
        callbacks = [
            keras.callbacks.ModelCheckpoint(
                str(self.data_dir / "models" / "best_model.h5"),
                monitor='val_accuracy',
                save_best_only=True,
                verbose=1
            ),
            keras.callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=5,
                verbose=1
            ),
            keras.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=10,
                verbose=1,
                restore_best_weights=True
            )
        ]

        # Convert generators to tf.data.Dataset
        train_dataset = tf.data.Dataset.from_generator(
            train_gen,
            output_signature=(
                [
                    tf.TensorSpec(shape=(None, *self.dataset_builder.img_size, 3), dtype=tf.float32),
                    tf.TensorSpec(shape=(None, 6), dtype=tf.float32)
                ],
                tf.TensorSpec(shape=(None, len(self.dataset_builder.wheat_pests)), dtype=tf.float32)
            )
        )

        val_dataset = tf.data.Dataset.from_generator(
            val_gen,
            output_signature=(
                [
                    tf.TensorSpec(shape=(None, *self.dataset_builder.img_size, 3), dtype=tf.float32),
                    tf.TensorSpec(shape=(None, 6), dtype=tf.float32)
                ],
                tf.TensorSpec(shape=(None, len(self.dataset_builder.wheat_pests)), dtype=tf.float32)
            )
        )

        # Train model
        logger.info("Starting training...")
        history = self.model.fit(
            train_dataset,
            validation_data=val_dataset,
            epochs=epochs,
            callbacks=callbacks,
            verbose=1
        )

        return history

    def evaluate_model(self, test_gen):
        """Evaluate the trained model"""
        logger.info("Evaluating model...")

        test_dataset = tf.data.Dataset.from_generator(
            test_gen,
            output_signature=(
                [
                    tf.TensorSpec(shape=(None, *self.dataset_builder.img_size, 3), dtype=tf.float32),
                    tf.TensorSpec(shape=(None, 6), dtype=tf.float32)
                ],
                tf.TensorSpec(shape=(None, len(self.dataset_builder.wheat_pests)), dtype=tf.float32)
            )
        )

        results = self.model.evaluate(test_dataset, verbose=1)

        logger.info("Test Results:")
        for name, value in zip(self.model.metrics_names, results):
            logger.info(f"{name}: {value:.4f}")

        return results

    def predict_pest(self, image_path, weather_data):
        """Predict pest from image and weather data"""
        # Load and preprocess image
        img = cv2.imread(str(image_path))
        img = cv2.resize(img, self.dataset_builder.img_size)
        img = img.astype(np.float32) / 255.0
        img = np.expand_dims(img, axis=0)

        # Prepare weather data
        weather_features = np.array([weather_data]).astype(np.float32)

        # Predict
        predictions = self.model.predict([img, weather_features])
        predicted_class_idx = np.argmax(predictions[0])
        confidence = predictions[0][predicted_class_idx]

        predicted_class = self.dataset_builder.wheat_pests[predicted_class_idx]

        return predicted_class, confidence, predictions[0]

def main():
    """Main function to run the complete pipeline"""
    logger.info("Starting Wheat Pest Detection Pipeline...")

    # Initialize pipeline
    pipeline = WheatPestPipeline()

    # Setup dataset
    metadata_df, weather_df = pipeline.setup_dataset()

    # Load and preprocess data
    train_gen, val_gen, test_gen = pipeline.load_and_preprocess_data()

    # Build and train model
    history = pipeline.build_and_train_model(train_gen, val_gen, epochs=20)

    # Evaluate model
    results = pipeline.evaluate_model(test_gen)

    # Example prediction
    logger.info("Pipeline setup complete!")
    logger.info("Dataset statistics:")
    logger.info(f"Total samples: {len(metadata_df)}")
    logger.info(f"Pest classes: {pipeline.dataset_builder.wheat_pests}")

if __name__ == "__main__":
    main()



TypeError: `output_signature` must contain objects that are subclass of `tf.TypeSpec` but found <class 'list'> which is not.

In [None]:
!pip install opencv-python seaborn -q

import logging
logging.basicConfig(level=logging.INFO)

import os
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import cv2
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
import requests
import json
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class WheatPestDatasetBuilder:
    """Class to build and manage wheat pest dataset with weather integration"""

    def __init__(self, data_dir="wheat_pest_data", img_size=(224, 224)):
        self.data_dir = Path(data_dir)
        self.img_size = img_size
        self.weather_scaler = StandardScaler()
        self.label_encoder = LabelEncoder()

        # Common wheat pests in India
        self.wheat_pests = [
            'healthy',
            'wheat_rust',
            'aphids',
            'stem_borer',
            'armyworm',
            'bacterial_blight',
            'powdery_mildew',
            'leaf_spot'
        ]

        # Create directory structure
        self.setup_directories()

    def setup_directories(self):
        """Setup directory structure for the dataset"""
        directories = [
            self.data_dir / "images" / "train",
            self.data_dir / "images" / "val",
            self.data_dir / "images" / "test",
            self.data_dir / "weather",
            self.data_dir / "processed",
            self.data_dir / "models"
        ]

        for dir_path in directories:
            dir_path.mkdir(parents=True, exist_ok=True)

        # Create pest class subdirectories
        for split in ['train', 'val', 'test']:
            for pest in self.wheat_pests:
                (self.data_dir / "images" / split / pest).mkdir(exist_ok=True)

    def download_sample_dataset(self):
        """Download and organize sample wheat pest images"""
        logger.info("Setting up sample dataset structure...")

        # Create sample metadata
        sample_data = []
        districts = ['Haryana_Karnal', 'Punjab_Ludhiana', 'UP_Meerut', 'MP_Indore']

        for i in range(1000):  # Sample 1000 entries
            pest_class = np.random.choice(self.wheat_pests)
            district = np.random.choice(districts)
            date = datetime.now() - timedelta(days=np.random.randint(0, 365))

            sample_data.append({
                'image_id': f'wheat_{i:04d}.jpg',
                'pest_class': pest_class,
                'district': district,
                'state': district.split('_')[0],
                'location': district.split('_')[1],
                'date': date.strftime('%Y-%m-%d'),
                'severity': np.random.choice(['mild', 'moderate', 'severe']),
                'crop_stage': np.random.choice(['seedling', 'tillering', 'heading', 'maturity'])
            })

        # Save metadata
        df = pd.DataFrame(sample_data)
        df.to_csv(self.data_dir / "wheat_pest_metadata.csv", index=False)
        logger.info(f"Created metadata for {len(df)} samples")

        return df

    def fetch_weather_data(self, district, date, api_key=None):
        """Fetch weather data for given district and date"""
        # Mock weather data (replace with actual IMD API calls)
        np.random.seed(hash(f"{district}_{date}") % 2**32)

        weather_data = {
            'date': date,
            'district': district,
            'temperature_max': np.random.normal(28, 5),
            'temperature_min': np.random.normal(15, 3),
            'humidity': np.random.normal(65, 15),
            'rainfall': np.random.exponential(2),
            'wind_speed': np.random.normal(8, 3),
            'soil_moisture': np.random.normal(45, 10)
        }

        return weather_data

    def build_weather_dataset(self, metadata_df):
        """Build weather dataset corresponding to pest images"""
        weather_data = []

        logger.info("Fetching weather data...")
        for _, row in metadata_df.iterrows():
            weather = self.fetch_weather_data(row['district'], row['date'])
            weather['image_id'] = row['image_id']
            weather_data.append(weather)

        weather_df = pd.DataFrame(weather_data)
        weather_df.to_csv(self.data_dir / "weather_data.csv", index=False)

        return weather_df

    def create_synthetic_images(self, metadata_df):
        """Create synthetic wheat pest images for demonstration"""
        logger.info("Creating synthetic images...")

        for _, row in metadata_df.iterrows():
            # Create synthetic image based on pest class
            img = self.generate_synthetic_pest_image(row['pest_class'])

            # Determine split (80% train, 10% val, 10% test)
            rand_val = np.random.random()
            if rand_val < 0.8:
                split = 'train'
            elif rand_val < 0.9:
                split = 'val'
            else:
                split = 'test'

            # Save image
            img_path = self.data_dir / "images" / split / row['pest_class'] / row['image_id']
            cv2.imwrite(str(img_path), img)

    def generate_synthetic_pest_image(self, pest_class):
        """Generate synthetic pest image based on class"""
        # Create base wheat crop image
        img = np.random.randint(50, 200, (*self.img_size, 3), dtype=np.uint8)

        # Add pest-specific patterns
        if pest_class == 'wheat_rust':
            # Add rust-like orange/brown spots
            for _ in range(np.random.randint(5, 15)):
                center = (np.random.randint(0, self.img_size[0]), np.random.randint(0, self.img_size[1]))
                radius = np.random.randint(5, 20)
                cv2.circle(img, center, radius, (30, 100, 200), -1)  # Orange-brown color

        elif pest_class == 'aphids':
            # Add small green/black dots
            for _ in range(np.random.randint(10, 30)):
                center = (np.random.randint(0, self.img_size[0]), np.random.randint(0, self.img_size[1]))
                cv2.circle(img, center, 2, (50, 50, 200), -1)

        elif pest_class == 'powdery_mildew':
            # Add white powdery patches
            for _ in range(np.random.randint(3, 8)):
                center = (np.random.randint(0, self.img_size[0]), np.random.randint(0, self.img_size[1]))
                axes = (np.random.randint(15, 30), np.random.randint(10, 25))
                cv2.ellipse(img, center, axes, 0, 0, 360, (200, 200, 200), -1)

        # Add some noise and blur for realism
        img = cv2.GaussianBlur(img, (3, 3), 0)
        noise = np.random.randint(-20, 20, img.shape, dtype=np.int16)
        img = np.clip(img.astype(np.int16) + noise, 0, 255).astype(np.uint8)

        return img

class WeatherImageFusionModel:
    """MSFNet-inspired model for fusing wheat pest images with weather data"""

    def __init__(self, num_classes, img_size=(224, 224), weather_features=6):
        self.num_classes = num_classes
        self.img_size = img_size
        self.weather_features = weather_features
        self.model = None

    def build_image_branch(self):
        """Build CNN branch for image processing"""
        img_input = layers.Input(shape=(*self.img_size, 3), name='image_input')

        # CNN backbone (similar to ResNet blocks)
        x = layers.Conv2D(64, (7, 7), strides=2, padding='same')(img_input)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
        x = layers.MaxPooling2D((3, 3), strides=2, padding='same')(x)

        # Residual blocks
        for filters in [64, 128, 256, 512]:
            x = self.residual_block(x, filters)
            x = self.residual_block(x, filters)
            if filters < 512:
                x = layers.MaxPooling2D((2, 2))(x)

        # Global pooling
        x = layers.GlobalAveragePooling2D()(x)
        img_features = layers.Dense(256, activation='relu', name='img_features')(x)

        return img_input, img_features

    def residual_block(self, x, filters):
        """Residual block for CNN"""
        shortcut = x

        x = layers.Conv2D(filters, (3, 3), padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)

        x = layers.Conv2D(filters, (3, 3), padding='same')(x)
        x = layers.BatchNormalization()(x)

        # Adjust shortcut if needed
        if shortcut.shape[-1] != filters:
            shortcut = layers.Conv2D(filters, (1, 1))(shortcut)

        x = layers.Add()([x, shortcut])
        x = layers.ReLU()(x)

        return x

    def build_weather_branch(self):
        """Build MLP branch for weather data processing"""
        weather_input = layers.Input(shape=(self.weather_features,), name='weather_input')

        x = layers.Dense(128, activation='relu')(weather_input)
        x = layers.Dropout(0.3)(x)
        x = layers.Dense(64, activation='relu')(x)
        x = layers.Dropout(0.3)(x)
        weather_features = layers.Dense(64, activation='relu', name='weather_features')(x)

        return weather_input, weather_features

    def build_fusion_model(self):
        """Build complete fusion model"""
        # Build branches
        img_input, img_features = self.build_image_branch()
        weather_input, weather_features = self.build_weather_branch()

        # Fusion mechanism - Cross-modal attention
        # Attention from weather to image
        weather_att = layers.Dense(256)(weather_features)
        weather_att = layers.Softmax()(weather_att)
        img_attended = layers.Multiply()([img_features, weather_att])

        # Attention from image to weather
        img_att = layers.Dense(64)(img_features)
        img_att = layers.Softmax()(img_att)
        weather_attended = layers.Multiply()([weather_features, img_att])

        # Concatenate attended features
        fused_features = layers.Concatenate()([img_attended, weather_attended])

        # Final classification layers
        x = layers.Dense(512, activation='relu')(fused_features)
        x = layers.Dropout(0.5)(x)
        x = layers.Dense(256, activation='relu')(x)
        x = layers.Dropout(0.3)(x)

        output = layers.Dense(self.num_classes, activation='softmax', name='pest_prediction')(x)

        # Create model
        self.model = keras.Model(
            inputs=[img_input, weather_input],
            outputs=output,
            name='wheat_pest_weather_fusion'
        )

        return self.model

class WheatPestPipeline:
    """Complete pipeline for wheat pest detection with weather fusion"""

    def __init__(self, data_dir="wheat_pest_data"):
        self.data_dir = Path(data_dir)
        self.dataset_builder = WheatPestDatasetBuilder(data_dir)
        self.model_builder = None
        self.model = None

    def setup_dataset(self):
        """Setup complete dataset"""
        logger.info("Setting up wheat pest dataset...")

        # Create sample dataset
        metadata_df = self.dataset_builder.download_sample_dataset()

        # Build weather dataset
        weather_df = self.dataset_builder.build_weather_dataset(metadata_df)

        # Create synthetic images
        self.dataset_builder.create_synthetic_images(metadata_df)

        logger.info("Dataset setup complete!")
        return metadata_df, weather_df

    def load_and_preprocess_data(self):
        """Load and preprocess data for training"""
        logger.info("Loading and preprocessing data...")

        # Load metadata
        metadata_df = pd.read_csv(self.data_dir / "wheat_pest_metadata.csv")
        weather_df = pd.read_csv(self.data_dir / "weather_data.csv")

        # Merge datasets
        data_df = metadata_df.merge(weather_df, on='image_id', how='inner')

        # Split data by directories that actually exist
        train_data = []
        val_data = []
        test_data = []

        for _, row in data_df.iterrows():
            found = False
            for split in ['train', 'val', 'test']:
                img_path = self.data_dir / "images" / split / row['pest_class'] / row['image_id']
                if img_path.exists():
                    if split == 'train':
                        train_data.append(row)
                    elif split == 'val':
                        val_data.append(row)
                    else:
                        test_data.append(row)
                    found = True
                    break

            if not found:
                train_data.append(row)  # Default to train if not found

        train_df = pd.DataFrame(train_data)
        val_df = pd.DataFrame(val_data)
        test_df = pd.DataFrame(test_data)

        logger.info(f"Train samples: {len(train_df)}, Val samples: {len(val_df)}, Test samples: {len(test_df)}")

        return train_df, val_df, test_df

    def create_data_arrays(self, df, split):
        """Create numpy arrays from dataframe"""
        images = []
        weather_data = []
        labels = []

        for _, row in df.iterrows():
            # Find image path
            img_path = None
            for s in ['train', 'val', 'test']:
                potential_path = self.data_dir / "images" / s / row['pest_class'] / row['image_id']
                if potential_path.exists():
                    img_path = potential_path
                    break

            if img_path and img_path.exists():
                # Load image
                img = cv2.imread(str(img_path))
                if img is not None:
                    img = cv2.resize(img, self.dataset_builder.img_size)
                    img = img.astype(np.float32) / 255.0
                    images.append(img)

                    # Weather features
                    weather_features = [
                        row['temperature_max'], row['temperature_min'],
                        row['humidity'], row['rainfall'],
                        row['wind_speed'], row['soil_moisture']
                    ]
                    weather_data.append(weather_features)

                    # Label
                    labels.append(self.dataset_builder.wheat_pests.index(row['pest_class']))

        if len(images) == 0:
            logger.warning(f"No valid images found for {split} split")
            return None, None, None

        return np.array(images), np.array(weather_data), np.array(labels)

    def build_and_train_model(self, train_df, val_df, epochs=20):
        """Build and train the fusion model"""
        logger.info("Building fusion model...")

        # Initialize model builder
        self.model_builder = WeatherImageFusionModel(
            num_classes=len(self.dataset_builder.wheat_pests),
            img_size=self.dataset_builder.img_size
        )

        # Build model
        self.model = self.model_builder.build_fusion_model()

        # Compile model
        self.model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=0.001),
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )

        logger.info("Model architecture:")
        self.model.summary()

        # Prepare training data
        X_train_img, X_train_weather, y_train = self.create_data_arrays(train_df, 'train')
        X_val_img, X_val_weather, y_val = self.create_data_arrays(val_df, 'val')

        if X_train_img is None or X_val_img is None:
            logger.error("Failed to load training or validation data")
            return None

        # Convert labels to categorical
        y_train_cat = tf.keras.utils.to_categorical(y_train, len(self.dataset_builder.wheat_pests))
        y_val_cat = tf.keras.utils.to_categorical(y_val, len(self.dataset_builder.wheat_pests))

        # Callbacks
        callbacks = [
            keras.callbacks.ModelCheckpoint(
                str(self.data_dir / "models" / "best_model.h5"),
                monitor='val_accuracy',
                save_best_only=True,
                verbose=1
            ),
            keras.callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=5,
                verbose=1
            ),
            keras.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=10,
                verbose=1,
                restore_best_weights=True
            )
        ]

        # Train model
        logger.info("Starting training...")
        history = self.model.fit(
            [X_train_img, X_train_weather], y_train_cat,
            validation_data=([X_val_img, X_val_weather], y_val_cat),
            epochs=epochs,
            batch_size=32,
            callbacks=callbacks,
            verbose=1
        )

        return history

    def evaluate_model(self, test_df):
        """Evaluate the trained model"""
        logger.info("Evaluating model...")

        X_test_img, X_test_weather, y_test = self.create_data_arrays(test_df, 'test')

        if X_test_img is None:
            logger.error("Failed to load test data")
            return None

        y_test_cat = tf.keras.utils.to_categorical(y_test, len(self.dataset_builder.wheat_pests))

        results = self.model.evaluate([X_test_img, X_test_weather], y_test_cat, verbose=1)

        logger.info("Test Results:")
        for name, value in zip(self.model.metrics_names, results):
            logger.info(f"{name}: {value:.4f}")

        return results

    def predict_pest(self, image_path, weather_data):
        """Predict pest from image and weather data"""
        # Load and preprocess image
        img = cv2.imread(str(image_path))
        img = cv2.resize(img, self.dataset_builder.img_size)
        img = img.astype(np.float32) / 255.0
        img = np.expand_dims(img, axis=0)

        # Prepare weather data
        weather_features = np.array([weather_data]).astype(np.float32)

        # Predict
        predictions = self.model.predict([img, weather_features])
        predicted_class_idx = np.argmax(predictions[0])
        confidence = predictions[0][predicted_class_idx]

        predicted_class = self.dataset_builder.wheat_pests[predicted_class_idx]

        return predicted_class, confidence, predictions[0]

    def plot_training_history(self, history):
        """Plot training history"""
        if history is None:
            return

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

        # Plot training & validation accuracy
        ax1.plot(history.history['accuracy'], label='Training Accuracy')
        ax1.plot(history.history['val_accuracy'], label='Validation Accuracy')
        ax1.set_title('Model Accuracy')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Accuracy')
        ax1.legend()

        # Plot training & validation loss
        ax2.plot(history.history['loss'], label='Training Loss')
        ax2.plot(history.history['val_loss'], label='Validation Loss')
        ax2.set_title('Model Loss')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Loss')
        ax2.legend()

        plt.tight_layout()
        plt.show()

def main():
    """Main function to run the complete pipeline"""
    logger.info("Starting Wheat Pest Detection Pipeline...")

    # Initialize pipeline
    pipeline = WheatPestPipeline()

    # Setup dataset
    metadata_df, weather_df = pipeline.setup_dataset()

    # Load and preprocess data
    train_df, val_df, test_df = pipeline.load_and_preprocess_data()

    # Build and train model
    history = pipeline.build_and_train_model(train_df, val_df, epochs=10)

    # Plot training history
    pipeline.plot_training_history(history)

    # Evaluate model
    if len(test_df) > 0:
        results = pipeline.evaluate_model(test_df)
    else:
        logger.warning("No test data available for evaluation")

    # Display dataset statistics
    logger.info("Pipeline setup complete!")
    logger.info("Dataset statistics:")
    logger.info(f"Total samples: {len(metadata_df)}")
    logger.info(f"Pest classes: {pipeline.dataset_builder.wheat_pests}")
    logger.info(f"Training samples: {len(train_df)}")
    logger.info(f"Validation samples: {len(val_df)}")
    logger.info(f"Test samples: {len(test_df)}")

    return pipeline

# Run the main function
pipeline = main()

Epoch 1/10
[1m10/27[0m [32m━━━━━━━[0m[37m━━━━━━━━━━━━━[0m [1m2:34[0m 9s/step - accuracy: 0.2173 - loss: 1.9746