In [None]:
import base64
import json
import os
from typing import Dict, Any

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from dotenv import load_dotenv
from langchain.chat_models import init_chat_model
from langchain_core.messages import AIMessage, HumanMessage


 Initialize `gemini_vision_model` with the API key read from the `GEMINI_API_KEY` environment variable.

In [None]:
# Load environment variables via python-dotenv before looking up the key.
load_dotenv()
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")

# Module-level chat model handle; classify_plant_image() sends its requests
# through this object.
gemini_vision_model = init_chat_model(
    "gemini-2.0-flash",
    model_provider="google_genai",
    api_key=GEMINI_API_KEY,
)

In [None]:
def classify_plant_image(
    image_base64: str, plant_type: list[str], disease_type: list[str]
) -> Dict[str, Any]:
    """
    Classify a plant image by plant type and disease with the vision model.

    Builds a prompt listing the allowed plant and disease classes, sends it
    together with the image to ``gemini_vision_model``, and parses the JSON
    object found in the reply.

    Args:
        image_base64: Base64 encoded image data (embedded as a JPEG data URL)
        plant_type: List of possible plant class names
        disease_type: List of possible disease class names

    Returns:
        Dict with the keys:
            plant_type: value returned by the model ("" if missing)
            disease_type: value returned by the model ("none" if missing)
            is_healthy: bool, accepted from either a JSON boolean or a
                "true"/"false" string (False if missing)
            confidence: value returned by the model, passed through
                unchanged ("0" if missing)
            explanation: value returned by the model ("" if missing)

    Raises:
        json.JSONDecodeError: If the reply does not contain valid JSON
        ValueError: If the decoded JSON is not an object
        Exception: Any error raised by the model call is propagated
    """
    plant_classes_str = ", ".join(plant_type)
    disease_classes_str = ", ".join(disease_type)

    # The disease classes must be listed explicitly; otherwise the model is
    # asked to pick from "provided disease class names" it has never seen.
    classification_prompt = f"""
    You are an expert plant classifier. Given the image, classify it into one of the following classes: {plant_classes_str}.
    Also identify the disease shown in the image, choosing from the following disease classes: {disease_classes_str}. Use `none` if the plant is healthy.

    Return your response in JSON format as:
    {{
        "plant_type": "One of the provided class names",
        "disease_type": "One of the provided disease class names or `none` if healthy",
        "is_healthy": "true/false",
        "confidence": "0-1",
        "explanation": "Brief explanation of why this class was chosen"
    }}
    """

    # Create message with image
    message = HumanMessage(
        content=[
            {"type": "text", "text": classification_prompt},
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
            },
        ]
    )

    response = gemini_vision_model.invoke([message])

    # Message content may be a plain string or a list of string/dict blocks;
    # flatten the list form to text before any string processing.
    response_content = response.content
    if isinstance(response_content, list):
        text_parts = []
        for part in response_content:
            if isinstance(part, str):
                text_parts.append(part)
            elif isinstance(part, dict):
                text_parts.append(str(part.get("text", "")))
        response_content = "".join(text_parts)

    response_text = response_content.strip()

    # Clean up response if it has markdown code blocks
    if "```json" in response_text:
        response_text = response_text.split("```json")[1].split("```")[0].strip()
    elif "```" in response_text:
        response_text = response_text.split("```")[1].split("```")[0].strip()

    classification_data = json.loads(response_text)
    if not isinstance(classification_data, dict):
        raise ValueError(
            "Expected a JSON object from the model, got "
            f"{type(classification_data).__name__}"
        )

    # The model may answer with a real JSON boolean or with a "true"/"false"
    # string; calling .lower() on a bool would raise AttributeError.
    raw_is_healthy = classification_data.get("is_healthy", False)
    if isinstance(raw_is_healthy, str):
        is_healthy = raw_is_healthy.strip().lower() == "true"
    else:
        is_healthy = bool(raw_is_healthy)

    return {
        "plant_type": classification_data.get("plant_type", ""),
        "disease_type": classification_data.get("disease_type", "none"),
        "is_healthy": is_healthy,
        "confidence": classification_data.get("confidence", "0"),
        "explanation": classification_data.get("explanation", ""),
    }


def get_image_encoding(image_path: str) -> str:
    """
    Read an image file and return its contents as a base64 string.

    Args:
        image_path: Path of the image file to read

    Returns:
        The file's bytes, base64 encoded and decoded to a UTF-8 ``str``

    Raises:
        OSError: If the file cannot be opened or read
    """
    # Binary mode: the raw bytes are encoded as-is, no text decoding.
    with open(image_path, "rb") as img_file:
        image_data = img_file.read()
    return base64.b64encode(image_data).decode("utf-8")