In [1]:
import os
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

# Data organization
dataset_root = '/kaggle/input/fruit-infection-disease-dataset/sl_train'
categories = ['sl_trainBeans_Angular_LeafSpot', 'sl_trainBeans_Rust', 'sl_trainStrawberry_Angular_LeafSpot',
              'sl_trainStrawberry_Anthracnose_Fruit_Rot', 'sl_trainStrawberry_Blossom_Blight',
              'sl_trainStrawberry_Gray_Mold', 'sl_trainStrawberry_Leaf_Spot', 'sl_trainStrawberry_Powdery_Mildew_Fruit',
              'sl_trainStrawberry_Powdery_Mildew_Leaf', 'sl_trainTomato_Blight', 'sl_trainTomato_Leaf_Mold',
              'sl_trainTomato_Spider_Mites']

def create_cnn_model():
    model = Sequential()
    model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(64, 64, 3)))
    model.add(MaxPooling2D((2, 2)))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D((2, 2)))
    model.add(Flatten())
    model.add(Dense(64, activation='relu'))
    model.add(Dense(len(categories), activation='softmax'))  # Number of units should be len(categories)
    return model

# Create and compile the CNN model
model = create_cnn_model()
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Load and preprocess the training data
train_images = []
train_labels = []

for category_idx, category in enumerate(categories):
    category_folder = os.path.join(dataset_root, category)
    image_names = os.listdir(category_folder)
    for image_name in image_names:
        image_path = os.path.join(category_folder, image_name)
        image = cv2.imread(image_path)
        resized_image = cv2.resize(image, (64, 64))
        normalized_image = resized_image / 255.0
        train_images.append(normalized_image)
        label = np.zeros(len(categories))
        label[category_idx] = 1
        train_labels.append(label)

train_images = np.array(train_images)
train_labels = np.array(train_labels)

# Train the model
model.fit(train_images, train_labels, epochs=10, batch_size=32)

# Save the trained model weights
model.save_weights('trained_model_weights.h5')

# Preprocessing and disease detection
def preprocess_and_detect_disease(image_path):
    image = cv2.imread(image_path)

    # Preprocess the image
    resized_image = cv2.resize(image, (64, 64))
    normalized_image = resized_image / 255.0

    # Reshape the image to match the input shape of the CNN model
    input_image = np.expand_dims(normalized_image, axis=0)

    # Load the trained model with the saved weights
    model = create_cnn_model()
    model.load_weights('trained_model_weights.h5')

    # Predict the class probabilities using the CNN model
    class_probs = model.predict(input_image)[0]
    predicted_class_index = np.argmax(class_probs)
    predicted_class = categories[predicted_class_index]

    # Draw the predicted class label on 
    cv2.putText(image, predicted_class, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

    # Save the output image
    output_path = 'output.jpg'
    cv2.imwrite(output_path, image)
    print(f"Output image saved at: {output_path}")

# Example usage
image_path = '/kaggle/input/fruit-infection-disease-dataset/sl_train/sl_trainBeans_Rust/1619076008186_jpg.rf.74604a8bc09e89108fad740c9776b6d4.jpg'
preprocess_and_detect_disease(image_path)

caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Output image saved at: output.jpg


In [2]:
    # Save the output image
    output_path = 'output.jpg'
    cv2.imwrite(output_path, image)
    print(f"Output image saved at: {output_path}")

# Example usage
image_path = '/kaggle/input/fruit-infection-disease-dataset/sl_test/sl_testStrawberry_Gray_Mold/gray_mold123_jpg.rf.d446c7ec5827602463d1b092b00a20f9.jpg'
preprocess_and_detect_disease(image_path)

Output image saved at: output.jpg
Output image saved at: output.jpg
