In [2]:
import os
import sqlite3
import json
from typing import List, Dict
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import faiss
from sentence_transformers import SentenceTransformer

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(


In [None]:
# --- Notebook configuration -------------------------------------------------
# All paths are relative to the notebook's working directory.

# Keras model file passed to load_model() for the image classifier.
MODEL_PATH = "model/best_model.h5"
# SQLite database file opened by the knowledge-base lookup functions
# (they query its `disease_kb` table).
SQLITE_DB = "crop_diseases_rag.db"
# Serialized FAISS index read with faiss.read_index().
FAISS_INDEX_PATH = "crop_diseases.faiss"
# SentenceTransformer model name used to embed text queries.
# NOTE(review): presumably the same model that built the FAISS index — confirm.
EMBED_MODEL_NAME = "all-MiniLM-L6-v2"
# Default number of nearest neighbours requested by retrieve_by_text().
TOP_K = 3

In [4]:
# Class labels, in the same order as used during training.
# predict_disease() returns LABELS[argmax(model output)], so the position of
# each entry matters: do not reorder, insert, or remove entries without
# retraining the model.
# NOTE(review): the Sugarcane entries use a single underscore and a space
# ("Sugarcane_Bacterial Blight") while every other crop uses "___" — presumably
# this mirrors the training folder names / the `label` column in the DB;
# confirm before normalizing.
LABELS = [
    "Corn___Common_Rust",
    "Corn___Gray_Leaf_Spot",
    "Corn___Healthy",
    "Corn___Northern_Leaf_Blight",
    "Potato___Early_Blight",
    "Potato___Healthy",
    "Potato___Late_Blight",
    "Rice___Brown_Spot",
    "Rice___Healthy",
    "Rice___Leaf_Blast",
    "Rice___Neck_Blast",
    "Sugarcane_Bacterial Blight",
    "Sugarcane_Healthy",
    "Sugarcane_Red Rot",
    "Wheat___Brown_Rust",
    "Wheat___Healthy",
    "Wheat___Yellow_Rust"
]


In [5]:
# Load every heavyweight artifact once, up front; the functions defined in
# later cells read `model`, `embedder`, `index` and `sqlite_ids` as globals.

print("Loading MobileNetV2 model...")
# compile=False: the model is loaded without compiling it (no optimizer/loss
# setup), which is all that predict() needs.
model = load_model(MODEL_PATH, compile=False)

print("Loading embedding model...")
embedder = SentenceTransformer(EMBED_MODEL_NAME)

print("Loading FAISS index...")
index = faiss.read_index(FAISS_INDEX_PATH)

# Load FAISS metadata: maps a FAISS row position to a `disease_kb` row id.
# `json` is already imported in the notebook's import cell, so it is not
# re-imported here.
# NOTE(review): retrieve_by_text() indexes this object with str(position), so
# the JSON file is presumably an object keyed by "0", "1", ... — confirm.
with open("faiss_metadata.json", "r", encoding="utf-8") as f:
    sqlite_ids = json.load(f)

Loading MobileNetV2 model...
Loading embedding model...


Downloading .gitattributes: 1.23kB [00:00, 1.23MB/s]
Downloading config.json: 100%|██████████| 190/190 [00:00<?, ?B/s] 
Downloading README.md: 10.5kB [00:00, 7.74MB/s]
Downloading config.json: 100%|██████████| 612/612 [00:00<?, ?B/s] 
Downloading (…)ce_transformers.json: 100%|██████████| 116/116 [00:00<?, ?B/s] 
Downloading data_config.json: 39.3kB [00:00, 37.3MB/s]
Downloading model.safetensors: 100%|██████████| 90.9M/90.9M [00:16<00:00, 5.67MB/s]
Downloading model.onnx: 100%|██████████| 90.4M/90.4M [00:19<00:00, 4.75MB/s]
Downloading model_O1.onnx: 100%|██████████| 90.4M/90.4M [00:15<00:00, 5.76MB/s]
Downloading model_O2.onnx: 100%|██████████| 90.3M/90.3M [00:15<00:00, 5.74MB/s]
Downloading model_O3.onnx: 100%|██████████| 90.3M/90.3M [00:15<00:00, 5.72MB/s]
Downloading model_O4.onnx: 100%|██████████| 45.2M/45.2M [00:08<00:00, 5.29MB/s]
Downloading model_qint8_arm64.onnx: 100%|██████████| 23.0M/23.0M [00:08<00:00, 2.87MB/s]
Downloading (…)el_qint8_avx512.onnx: 100%|██████████| 23.0M/2

Loading FAISS index...


In [None]:
def retrieve_by_text(query: str, top_k:int = TOP_K) -> List[Dict]:
    """Embed ``query`` and return the closest knowledge-base entries.

    The query is encoded with the global ``embedder``, L2-normalized in place,
    and searched against the global FAISS ``index``. Each hit position is
    translated to a ``disease_kb`` row id via the global ``sqlite_ids`` mapping
    and the row is read from ``SQLITE_DB``.

    Args:
        query: Free-text query to embed.
        top_k: Number of neighbours to request from the index.

    Returns:
        A list of dicts, in the order the index returned the hits, with keys
        ``label``, ``crop``, ``disease``, ``why_en``, ``precautions_en``,
        ``remedies_en``, ``why_hi``, ``precautions_hi``, ``remedies_hi`` and
        ``score`` (the raw value returned by ``index.search``). Hits whose
        position is out of range, or whose row is missing from the database,
        are skipped, so the list may be shorter than ``top_k``.
    """
    q_emb = embedder.encode([query], convert_to_numpy=True).astype('float32')
    # NOTE(review): normalizing only makes the scores cosine similarities if
    # the index was built from normalized vectors with an inner-product
    # metric — confirm against the index-building code.
    faiss.normalize_L2(q_emb)
    D, I = index.search(q_emb, top_k)

    # Drop positions that cannot be mapped before touching the database
    # (negative ids are what the index reports when it has no neighbour).
    hits = [
        (score, idx)
        for score, idx in zip(D[0].tolist(), I[0].tolist())
        if 0 <= idx < len(sqlite_ids)
    ]
    if not hits:
        return []

    results = []
    # One connection for the whole batch instead of one per hit, and always
    # closed even if a query raises.
    conn = sqlite3.connect(SQLITE_DB)
    try:
        cur = conn.cursor()
        for score, idx in hits:
            # The metadata mapping is keyed by the position as a string.
            sqlite_id = sqlite_ids[str(idx)]
            cur.execute("""
                SELECT label, crop, disease, why_en, precautions_en, remedies_en, 
                       why_hi, precautions_hi, remedies_hi 
                FROM disease_kb WHERE id=?""", (sqlite_id,))
            row = cur.fetchone()
            if not row:
                continue
            results.append({
                "label": row[0],
                "crop": row[1],
                "disease": row[2],
                "why_en": row[3],
                "precautions_en": row[4],
                "remedies_en": row[5],
                # The Hindi columns were already being selected but were
                # dropped; expose them so callers can show both languages.
                "why_hi": row[6],
                "precautions_hi": row[7],
                "remedies_hi": row[8],
                "score": float(score)
            })
    finally:
        conn.close()
    return results

def get_info_by_label(label:str, top_k:int=1) -> Dict:
    """Look up knowledge-base information for a predicted class label.

    First tries an exact match on ``disease_kb.label``. If no row matches,
    falls back to an embedding search via ``retrieve_by_text`` using the label
    with its underscores replaced by spaces, and uses the best hit.

    Args:
        label: Class label, e.g. an entry of ``LABELS``.
        top_k: Number of neighbours requested in the fallback search; only the
            first result is used.

    Returns:
        On success, a dict with ``label``, ``crop``, ``disease``, an
        ``english`` and a ``hindi`` sub-dict (each with ``why``,
        ``precautions``, ``remedies``) and ``source`` (``"local_db"`` or
        ``"embedding_retrieval"``); the fallback result also carries
        ``score``. If nothing is found, ``{"error": "not found"}``.
    """
    # First try exact lookup
    conn = sqlite3.connect(SQLITE_DB)
    try:
        cur = conn.cursor()
        # The Hindi columns are fetched as well: the notebook's result cell
        # reads info["hindi"], which previously raised KeyError.
        cur.execute("""
            SELECT label, crop, disease, why_en, precautions_en, remedies_en,
                   why_hi, precautions_hi, remedies_hi
            FROM disease_kb WHERE label=?""", (label,))
        row = cur.fetchone()
    finally:
        # Close the connection even if the query raises.
        conn.close()
    if row:
        return {
            "label": row[0],
            "crop": row[1],
            "disease": row[2],
            "english": {
                "why": row[3],
                "precautions": row[4],
                "remedies": row[5]
            },
            "hindi": {
                "why": row[6],
                "precautions": row[7],
                "remedies": row[8]
            },
            "source": "local_db"
        }
    # Fallback to embedding search
    fallback_query = label.replace("___"," ").replace("_"," ")
    retrieved = retrieve_by_text(fallback_query, top_k=top_k)
    if not retrieved:
        return {"error": "not found"}
    top = retrieved[0]
    return {
        "label": top["label"],
        "crop": top["crop"],
        "disease": top["disease"],
        "english": {
            "why": top["why_en"], 
            "precautions": top["precautions_en"], 
            "remedies": top["remedies_en"]
        },
        # .get(): stays safe if the retrieval result does not carry the Hindi
        # fields; the values are then None.
        "hindi": {
            "why": top.get("why_hi"),
            "precautions": top.get("precautions_hi"),
            "remedies": top.get("remedies_hi")
        },
        "score": top["score"],
        "source": "embedding_retrieval"
    }

In [None]:
import cv2
import numpy as np
from tensorflow.keras.preprocessing import image

def predict_disease(img_path):
    """Classify the leaf image at ``img_path`` and return its label.

    The image must first pass three rule-based sanity checks (leaf-like
    colours, enough edges, non-uniform pixels). Only then is it resized to
    224x224, scaled to [0, 1] and fed to the global ``model``.

    Returns:
        The entry of ``LABELS`` with the highest model output, or one of two
        message strings: ``"Invalid image path ❌"`` when OpenCV cannot read
        the file, ``"Not a valid plant image ❌"`` when any sanity check fails.
    """

    def is_leaf_colored(img, threshold=0.25):
        # Share of pixels that look green, yellow or brown must exceed threshold.
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype("float") / 255.0
        red = rgb[:, :, 0]
        green = rgb[:, :, 1]
        blue = rgb[:, :, 2]
        is_green = (green > red) & (green > blue)
        is_yellow = (red > 0.3) & (green > 0.3) & (blue < 0.4) & (abs(red - green) < 0.2)
        is_brown = (red > 0.4) & (green > 0.2) & (blue < 0.3)
        leafy = is_green | is_yellow | is_brown
        pixel_count = rgb.shape[0] * rgb.shape[1]
        return np.sum(leafy) / pixel_count > threshold

    def has_leaf_texture(img, min_ratio=0.005):
        # Share of Canny edge pixels must exceed min_ratio.
        edge_map = cv2.Canny(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), 100, 200)
        pixel_count = img.shape[0] * img.shape[1]
        return np.sum(edge_map > 0) / pixel_count > min_ratio

    def background_check(img, var_threshold=0.001):
        # Mean per-channel variance must reach var_threshold (rejects flat images).
        scaled = img.astype("float") / 255.0
        per_channel_var = np.var(scaled, axis=(0, 1))
        return np.mean(per_channel_var) >= var_threshold

    # Validate plant image first
    bgr = cv2.imread(img_path)
    if bgr is None:
        return "Invalid image path ❌"

    # Hard-rule validation: checks run in this order and stop at the first failure.
    validators = (is_leaf_colored, has_leaf_texture, background_check)
    if not all(check(bgr) for check in validators):
        return "Not a valid plant image ❌"

    # Preprocess for model
    pil_img = image.load_img(img_path, target_size=(224, 224))
    batch = np.expand_dims(image.img_to_array(pil_img) / 255.0, axis=0)

    # Predict disease
    scores = model.predict(batch)
    return LABELS[np.argmax(scores)]

In [8]:
# Demo: classify a single test image and show the predicted label.
# NOTE(review): this is an absolute path on the author's machine; the cell
# will return the "Invalid image path" message anywhere else. Point it at a
# local image before running.
test_image = r"E:\PROJECTS\Krishi-Sahayak\krishi-model\SplitData\test\Corn___Common_Rust\image (4).JPG"
# `predicted_label` is read by the next cell to look up disease information.
predicted_label = predict_disease(test_image)
print(f"Predicted Label: {predicted_label}")

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step
Predicted Label: Corn___Common_Rust


In [9]:
# Look up and print knowledge-base information for the predicted label.
# predict_disease() returns a message string instead of a label when image
# validation fails; such a value must not be sent to the knowledge base,
# where the embedding fallback would return an unrelated disease.
if predicted_label in LABELS:
    info = get_info_by_label(predicted_label)
else:
    info = {"error": predicted_label}

print("\n=== Result ===")
if "error" in info:
    # Either the prediction was rejected or no knowledge-base entry exists.
    print("No information available:", info["error"])
else:
    print("Crop:", info["crop"])
    print("Disease:", info["disease"])
    print("\n--- English ---")
    print("Why:", info["english"]["why"])
    print("Precautions:", info["english"]["precautions"])
    print("Remedies:", info["english"]["remedies"])
    # The Hindi section is printed only when the lookup provides it, so a
    # result without a "hindi" entry no longer raises KeyError.
    hindi = info.get("hindi")
    if hindi:
        print("\n--- Hindi ---")
        print("क्यों:", hindi["why"])
        print("सावधानियाँ:", hindi["precautions"])
        print("उपचार:", hindi["remedies"])


=== Result ===
Crop: Corn
Disease: Common Rust

--- English ---
Why: Caused by the fungus Puccinia sorghi; develops in moderate temperatures (16–25°C) and high humidity; spreads by wind-borne spores.
Precautions: Grow rust-resistant hybrids; rotate with non-host crops; destroy volunteer corn and grassy weeds.
Remedies: Apply foliar fungicides (strobilurins or triazoles) when conditions favor disease; monitor regularly.

--- Hindi ---
क्यों: यह रोग Puccinia sorghi कवक के कारण होता है; 16–25°C तापमान और उच्च आर्द्रता में बढ़ता है; हवा से फैलने वाले स्पोर्स के द्वारा फैलता है।
सावधानियाँ: रस्ट-प्रतिरोधी किस्में उगाएँ; नॉन-होस्ट फसलों के साथ रोटेशन करें; खेत में जंगली मक्का और घास नष्ट करें।
उपचार: जब बीमारी के अनुकूल मौसम हो तो पत्ती-छिड़काव के लिए स्ट्रोबिल्यूरिन या ट्रायाज़ोल समूह के फंगीसाइड का उपयोग करें; नियमित निगरानी रखें।
