## Image Classification

### Flattening the test images into a single folder

In [None]:
import shutil
import os

# Gather every test image into one flat directory. Each copy is renamed
# "<parent folder>_<file name>" so identically named files coming from
# different class folders do not overwrite one another.
flat_test_folder = "flat_test_images/"
os.makedirs(flat_test_folder, exist_ok=True)

source_folder = "data/processed/test/"
for current_dir, _subdirs, filenames in os.walk(source_folder):
    # Parent folder name (the class label) used as the file-name prefix.
    # NOTE: for files sitting directly in source_folder this is "" because
    # the path passed to os.walk ends with "/".
    prefix = os.path.basename(current_dir)
    for filename in filenames:
        if not filename.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue
        shutil.copy2(
            os.path.join(current_dir, filename),
            os.path.join(flat_test_folder, f"{prefix}_{filename}"),
        )

print(f"Copied images to: {flat_test_folder}")


Copied images to: flat_test_images/


## INFERENCE


In [None]:
# plant_detector.py - Core Detection Engine
import threading

import cv2
import numpy as np
import tensorflow as tf

class PlantDiseaseDetector:
    """Classify plant-leaf images with a TensorFlow Lite model.

    Output index ``i`` of the model is interpreted as ``CLASS_NAMES[i]``;
    ``FRIENDLY_NAMES`` maps each raw label to a display string.
    """

    CLASS_NAMES = [
        'Pepper__bell___Bacterial_spot',
        'Pepper__bell___healthy',
        'Potato___Early_blight',
        'Potato___Late_blight',
        'Potato___healthy',
        'Tomato__Target_Spot',
        'Tomato__Tomato_YellowLeaf__Curl_Virus',
        'Tomato__Tomato_mosaic_virus',
        'Tomato_healthy'
    ]

    # User-friendly names
    FRIENDLY_NAMES = {
        'Pepper__bell___Bacterial_spot': 'Pepper - Bacterial Spot',
        'Pepper__bell___healthy': 'Pepper - Healthy',
        'Potato___Early_blight': 'Potato - Early Blight',
        'Potato___Late_blight': 'Potato - Late Blight',
        'Potato___healthy': 'Potato - Healthy',
        'Tomato__Target_Spot': 'Tomato - Target Spot',
        'Tomato__Tomato_YellowLeaf__Curl_Virus': 'Tomato - Yellow Leaf Curl Virus',
        'Tomato__Tomato_mosaic_virus': 'Tomato - Mosaic Virus',
        'Tomato_healthy': 'Tomato - Healthy'
    }

    # Target size handed to cv2.resize, which takes (width, height).
    INPUT_SIZE = (224, 224)

    def __init__(self, model_path="Models/tflite_conversion/best_model.tflite"):
        """Load the TFLite model stored at ``model_path``."""
        self.model = self._load_model(model_path)

    def _load_model(self, model_path):
        """Create and allocate the interpreter and cache its tensor details.

        Sets ``self.interpreter``, ``self.input_details`` and
        ``self.output_details``, and returns the interpreter.
        """
        self.interpreter = tf.lite.Interpreter(model_path)
        self.interpreter.allocate_tensors()
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()
        print(f"Model loaded: {model_path}")
        return self.interpreter

    def preprocess_image(self, image):
        """Convert a supported image input into a one-image model batch.

        Args:
            image: One of
                * ``str`` -- path to an image file (read with cv2.imread);
                * ``bytes`` / ``bytearray`` -- an encoded image (JPEG, PNG, ...);
                * a PIL Image (any object with a ``convert`` method);
                * a numpy array -- 3 channels are treated as RGB, 4 channels
                  as RGBA, 2-D or single-channel arrays as grayscale.

        Returns:
            float32 array of shape (1, 224, 224, 3) in RGB channel order for
            the supported layouts above. Pixel values are not rescaled.

        Raises:
            ValueError: if a path cannot be read or bytes cannot be decoded.
            TypeError: if ``image`` is of an unsupported type.
        """
        return self._to_model_input(self._to_bgr(image))

    def _to_bgr(self, image):
        """Decode or convert any supported input into an OpenCV-style BGR array."""
        if isinstance(image, str):  # File path
            img = cv2.imread(image)
            if img is None:
                raise ValueError(f"Could not read image from path: {image}")
            return img

        if isinstance(image, (bytes, bytearray)):  # Encoded image bytes
            img = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
            if img is None:
                raise ValueError("Could not decode image from bytes")
            return img

        if hasattr(image, 'convert'):  # PIL Image
            return cv2.cvtColor(np.array(image.convert('RGB')), cv2.COLOR_RGB2BGR)

        if isinstance(image, np.ndarray):
            # Grayscale must be expanded to 3 channels; otherwise the later
            # BGR->RGB conversion fails inside OpenCV.
            if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
                return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
            if image.ndim == 3 and image.shape[2] == 4:
                return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
            if image.ndim == 3 and image.shape[2] == 3:
                # Arrays are assumed to be RGB (e.g. from PIL/matplotlib).
                return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            return image.copy()

        raise TypeError("Unsupported image type. Use file path, bytes, PIL Image, or numpy array")

    def _to_model_input(self, bgr_img):
        """Resize a BGR image, convert it to RGB float32 and add a batch axis."""
        img = cv2.resize(bgr_img, self.INPUT_SIZE)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # model expects RGB
        return np.expand_dims(img.astype(np.float32), axis=0)

    def _run_inference(self, batch):
        """Feed ``batch`` to input 0, run the interpreter, return row 0 of output 0."""
        self.interpreter.set_tensor(self.input_details[0]["index"], batch)
        self.interpreter.invoke()
        return self.interpreter.get_tensor(self.output_details[0]["index"])[0]

    def _friendly_name(self, class_name):
        """Return the display name for ``class_name`` (the raw label if unmapped)."""
        return self.FRIENDLY_NAMES.get(class_name, class_name)

    def predict(self, image):
        """Classify a single image and return a detailed result dict.

        Args:
            image: Any input accepted by ``preprocess_image``.

        Returns:
            dict with keys 'success', 'predicted_class', 'friendly_name',
            'confidence', 'confidence_level', 'plant', 'condition', 'advice',
            'top3_predictions' (list of {'class', 'friendly_name', 'confidence'}
            in descending score order) and 'all_predictions'
            (label -> score for every CLASS_NAMES entry).

        Raises:
            ValueError / TypeError: propagated from ``preprocess_image``.
        """
        predictions = self._run_inference(self.preprocess_image(image))

        predicted_idx = predictions.argmax()
        confidence = float(predictions[predicted_idx])
        predicted_class = self.CLASS_NAMES[predicted_idx]
        plant, condition = self._parse_class_name(predicted_class)

        # argsort is ascending: take the last three indices, highest first.
        top3_indices = predictions.argsort()[-3:][::-1]

        return {
            'success': True,
            'predicted_class': predicted_class,
            'friendly_name': self._friendly_name(predicted_class),
            'confidence': confidence,
            'confidence_level': self._get_confidence_level(confidence),
            'plant': plant,
            'condition': condition,
            'advice': self._get_advice(confidence, predicted_class),
            'top3_predictions': [
                {
                    'class': self.CLASS_NAMES[i],
                    'friendly_name': self._friendly_name(self.CLASS_NAMES[i]),
                    'confidence': float(predictions[i]),
                }
                for i in top3_indices
            ],
            'all_predictions': {
                name: float(predictions[i])
                for i, name in enumerate(self.CLASS_NAMES)
            },
        }

    def predict_batch(self, images):
        """Run ``predict`` on each image, never aborting the whole batch.

        Returns:
            list with one entry per input: the ``predict`` result on success,
            or ``{'success': False, 'error': <message>}`` if that image failed.
        """
        results = []
        for image in images:
            try:
                results.append(self.predict(image))
            except Exception as e:  # deliberate per-item boundary
                results.append({
                    'success': False,
                    'error': str(e)
                })
        return results

    def predict_from_webcam_frame(self, frame):
        """
        Special method for webcam frames (optimized for speed).

        ``frame`` is treated as a 3-channel BGR array (OpenCV capture order);
        unlike ``preprocess_image`` there is no input-type detection.

        Returns:
            dict with 'predicted_class', 'friendly_name', 'confidence' and
            'confidence_level' for the top prediction only.
        """
        predictions = self._run_inference(self._to_model_input(frame))

        # Top prediction only (for speed)
        predicted_idx = predictions.argmax()
        confidence = float(predictions[predicted_idx])
        predicted_class = self.CLASS_NAMES[predicted_idx]

        return {
            'predicted_class': predicted_class,
            'friendly_name': self._friendly_name(predicted_class),
            'confidence': confidence,
            'confidence_level': self._get_confidence_level(confidence)
        }

    def _get_confidence_level(self, confidence):
        """Bucket a score: >= 0.8 "high", >= 0.6 "medium", otherwise "low"."""
        if confidence >= 0.8:
            return "high"
        elif confidence >= 0.6:
            return "medium"
        else:
            return "low"

    def _parse_class_name(self, class_name):
        """Split a dataset label into ``(plant, condition)``.

        Two label conventions occur in CLASS_NAMES:

        * ``Plant___Condition`` (e.g. ``Pepper__bell___Bacterial_spot``):
          split on the triple underscore; ``__`` inside the plant part becomes
          a space, ``_`` inside the condition becomes a space.
        * ``Plant_Condition`` / ``Plant__Condition`` (the Tomato labels, e.g.
          ``Tomato__Tomato_mosaic_virus``): split at the first run of
          underscores and drop a repeated ``Plant_`` prefix from the condition.

        Returns ``("Unknown", class_name)`` when no plant/condition split exists.
        """
        if '___' in class_name:
            parts = class_name.split('___')
            plant = parts[0].replace('__', ' ').strip()
            # Collapse a doubled plant name such as "Tomato Tomato".
            if plant.startswith('Tomato Tomato'):
                plant = plant.replace('Tomato Tomato', 'Tomato')
            condition = parts[-1].replace('_', ' ').strip()
            return plant, condition

        plant, _, rest = class_name.partition('_')
        rest = rest.lstrip('_')
        # Labels like "Tomato__Tomato_mosaic_virus" repeat the plant name.
        if rest.startswith(plant + '_'):
            rest = rest[len(plant) + 1:]
        # "__" inside the condition would otherwise leave a double space.
        condition = ' '.join(rest.replace('_', ' ').split())
        if not plant or not condition:
            return "Unknown", class_name
        return plant, condition

    def _get_advice(self, confidence, predicted_class):
        """Return a user-facing advice string based on the class and score."""
        if "healthy" in predicted_class.lower():
            return "Plant appears healthy. Continue regular monitoring."

        if confidence >= 0.8:
            return "High confidence diagnosis. Consider appropriate treatment."
        elif confidence >= 0.6:
            return "Moderate confidence. Monitor closely and consider retesting."
        else:
            return "Low confidence. Please verify with expert or take clearer photos."

    def get_class_info(self):
        """Describe every supported class.

        Returns:
            list of dicts, one per CLASS_NAMES entry in the same order, with
            keys 'class_name', 'friendly_name', 'plant', 'condition' and
            'is_healthy'.
        """
        classes_info = []
        for class_name in self.CLASS_NAMES:
            plant, condition = self._parse_class_name(class_name)
            classes_info.append({
                'class_name': class_name,
                'friendly_name': self._friendly_name(class_name),
                'plant': plant,
                'condition': condition,
                'is_healthy': 'healthy' in class_name.lower()
            })
        return classes_info

# Helper functions for easy integration


def create_detector(model_path="Models/tflite_conversion/best_model.tflite"):
    """Create a new PlantDiseaseDetector; the model is loaded on every call.

    Args:
        model_path: Path to the ``.tflite`` model file.
    """
    return PlantDiseaseDetector(model_path)


# Detectors reused by predict_image / predict_webcam_frame, keyed by model path.
# Storage is per-thread: calls on one tf.lite.Interpreter must not overlap
# (see the Interpreter.invoke docs), so an interpreter is never shared
# between threads.
_thread_detectors = threading.local()


def _get_cached_detector(model_path=None):
    """Return this thread's detector for model_path, loading it on first use.

    A falsy ``model_path`` selects create_detector()'s default model.
    """
    cache = getattr(_thread_detectors, "by_path", None)
    if cache is None:
        cache = _thread_detectors.by_path = {}
    key = model_path or None
    if key not in cache:
        cache[key] = PlantDiseaseDetector(model_path) if model_path else create_detector()
    return cache[key]


def predict_image(image, model_path=None):
    """Classify one image using a cached detector.

    Args:
        image: Any input accepted by PlantDiseaseDetector.preprocess_image.
        model_path: Optional model path; falsy means the default model.

    Returns:
        The result dict produced by PlantDiseaseDetector.predict.
    """
    return _get_cached_detector(model_path).predict(image)


def predict_webcam_frame(frame, model_path=None):
    """Classify one BGR webcam frame using a cached detector.

    The model is loaded once per thread and model path instead of once per
    frame.

    Args:
        frame: BGR image array, as passed to
            PlantDiseaseDetector.predict_from_webcam_frame.
        model_path: Optional model path; falsy means the default model.

    Returns:
        The dict produced by PlantDiseaseDetector.predict_from_webcam_frame.
    """
    return _get_cached_detector(model_path).predict_from_webcam_frame(frame)


# Example usage for the GUI


if __name__ == "__main__":
    # Example 1: Basic usage
    print("Example 1: Basic usage")
    detector = PlantDiseaseDetector()

    # From file path -- replace the placeholder with a real image first;
    # predict() raises ValueError when the path cannot be read.
    result = detector.predict("path/to/your/image.jpg")
    print(f"Prediction: {result['friendly_name']}")
    print(f"Confidence: {result['confidence']:.2%}")
    print(f"Advice: {result['advice']}")

    # Example 2: For Streamlit integration
    print("\nExample 2: Streamlit-ready usage")

    # In Streamlit, the GUI would use:
    """
    import streamlit as st
    from plant_detector import predict_image
    from PIL import Image
    
    uploaded_file = st.file_uploader("Upload plant image", type=["jpg", "jpeg", "png"])
    
    if uploaded_file:
        # Convert to PIL Image
        image = Image.open(uploaded_file)
        
        # Get prediction
        result = predict_image(image)
        
        # Display results
        st.image(image, caption="Uploaded Image", width=300)
        st.success(f"**Diagnosis:** {result['friendly_name']}")
        st.metric("Confidence", f"{result['confidence']:.2%}")
        st.info(f"**Advice:** {result['advice']}")
    """

    # Example 3: Webcam in Streamlit
    print("\nExample 3: Webcam usage")
    """
    # In Streamlit with streamlit-webrtc:
    import av
    import cv2
    from streamlit_webrtc import webrtc_streamer
    from plant_detector import create_detector

    detector = create_detector()  # load the model once, not once per frame

    def video_frame_callback(frame):
        img = frame.to_ndarray(format="bgr24")
        result = detector.predict_from_webcam_frame(img)

        # Add text to frame
        cv2.putText(img, result['friendly_name'], (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

        # Return the annotated image; returning the input `frame` would
        # discard the text drawn on `img`.
        return av.VideoFrame.from_ndarray(img, format="bgr24")

    webrtc_streamer(key="plant-detector", video_frame_callback=video_frame_callback)
    """