# RAG Agent

In [1]:
from typing import List, Dict, Tuple
import numpy as np
from sentence_transformers import SentenceTransformer, util
import json
from pathlib import Path
from IPython.display import Image




## Load knowledge base

In [2]:
def prepare_knowledge_base(file_paths):
    """Flatten JSON mushroom files into one text document per mushroom.

    Each file is expected to map a mushroom name to a mapping of
    ``label -> value`` fields. Every mushroom becomes a single string whose
    first line is the name, followed by one ``"label: value"`` line per field.
    Documents are returned in file order, then in each file's key order.
    """
    def _render(name, fields):
        # Name first, then every labelled field on its own line.
        return "\n".join([name, *(f"{label}: {value}" for label, value in fields.items())])

    knowledge_base = []
    for path in file_paths:
        with open(path, 'r', encoding='utf-8') as handle:
            entries = json.load(handle)
        knowledge_base.extend(_render(name, fields) for name, fields in entries.items())
    return knowledge_base

In [3]:
# JSON knowledge-base sources, read by prepare_knowledge_base (presumably each
# maps a mushroom name to a dict of labelled fields — verify against the files).
# NOTE(review): retrieval maps embedding rows to documents by position, so the
# order of these files must match the order used when embeddings.npy was built.
file_paths = [
    "../App/Knowledge_base/wild_food_uk.json",
    "../App/Knowledge_base/mushroom_world.json",
    "../App/Knowledge_base/wikipedia.json",
]

knowledge_base = prepare_knowledge_base(file_paths)

## Load Embeddings

In [4]:
# Precomputed document embeddings: one row per entry of `knowledge_base`, in the
# same order (the agent looks documents up by embedding row index).
# NOTE(review): the commented-out "Create Embeddings" cell saves to
# embeddings_all.npy, while this loads embeddings.npy — confirm they are the
# same data, otherwise rows will not line up with the documents.
embeddings = np.load("../App/Knowledge_base/embeddings.npy")

## Create Embeddings

In [5]:
# model = SentenceTransformer('all-MiniLM-L6-v2')
# embeddings = model.encode(knowledge_base)
#
# np.save("../App/Knowledge_base/embeddings_all.npy", embeddings)

## Agent

In [17]:
import google.genai as genai
from google.genai.types import GenerateContentConfig

def cosine_similarity(a, b):
    """Return the cosine similarity between vectors ``a`` and ``b``.

    If either vector has zero length the similarity is undefined; 0.0 is
    returned instead of the NaN (and RuntimeWarning) that dividing by a zero
    norm would produce.
    """
    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    if denominator == 0:
        return 0.0
    return np.dot(a, b) / denominator

class MushroomRAGAgent:
    """Retrieval-augmented Gemini chat agent that explains mushroom identifications.

    Documents in ``knowledge_base`` are ranked against a query by cosine
    similarity of their sentence embeddings. The best matches are inserted into
    the prompt sent to a Gemini chat session, which keeps the conversation
    history between calls.
    """

    # Primary-prediction confidence below which the user is asked for more photos.
    LOW_CONFIDENCE_THRESHOLD = 0.90

    # File extensions listed as selectable example images.
    IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png"})

    def __init__(self, api_key: str, knowledge_base: List[str],
                 model_name: str, embedding_model: str, embeddings: np.ndarray = None,
                 image_dir: str = "../App/example_images"):
        """Create the chat session and prepare document embeddings.

        Args:
            api_key: Gemini API key passed to ``genai.Client``.
            knowledge_base: Documents to retrieve from.
            model_name: Gemini model used for the chat session.
            embedding_model: SentenceTransformer model name used to embed
                queries (and the documents when ``embeddings`` is None).
            embeddings: Optional precomputed document embeddings; row ``i``
                must belong to ``knowledge_base[i]``.
            image_dir: Directory whose sub-directories contain example images
                offered to the model.

        Raises:
            ValueError: If the number of embeddings differs from the number of
                documents (e.g. a stale embeddings file).
        """

        self.system_instructions = """You are an expert mycologist - a mushroom specialist.

            CRITICAL RULES:
            1. Answer ONLY based on the provided MUSHROOM KNOWLEDGE BASE documents
            2. If information is not in the provided documents, clearly state: "This information is not available in my knowledge base"
            3. For POISONOUS/TOXIC mushrooms, emphasize warnings STRONGLY with emojis (🚨⚠️💀)
            4. Never make up information - only use what's in the provided documents
            5. If the species is not specified in question answer about primary species or alternative species
            6. Only use images from the PROVIDED IMAGE PATHS
            7. Max 2 images per response
            8. If the user asks for more images, please select a different one than before

            OUTPUT FORMAT:
            - For the FIRST identification query: Use the structured format below
            - For follow-up questions: Respond conversationally and concisely

            Image usage format: <img src="{IMAGE_PATH}"/>

            STRUCTURED FORMAT (first query only):
            Name: [Common Name] ([Scientific Name])
            Confidence: [Confidence]

            Safety Status:
            ✅ EDIBLE / ⚠️ POISONOUS / 💀 DEADLY POISONOUS
            [Brief safety note]

            Key Identification Features:
            - Cap: [description]
            - Stem: [description]
            - Gills/Pores: [description]
            - Distinctive traits: [unique features]

            Location, Habitat & Season:
            - Geographic range: [location]
            - Habitat: [where it grows]
            - Season: [when it appears]

            Look-alikes:
            - [Any dangerous similar species]

            Alternative predictions:
            - [All alternative predictions]

            Image of primary prediction.

            Keep all descriptions very concise - only few words per point.
        """

        self.client = genai.Client(api_key=api_key)
        self.model_name = model_name

        # The chat object keeps the conversation history between send_message calls.
        self.chat = self._new_chat()

        self.knowledge_base = knowledge_base
        self.image_dir = image_dir

        # Semantic embedding model used for queries (and documents if needed).
        self.embedding_model = SentenceTransformer(embedding_model)

        if embeddings is None:
            self.doc_embeddings = self.embedding_model.encode(knowledge_base)
        else:
            self.doc_embeddings = embeddings

        # Retrieval maps embedding row i to knowledge_base[i]; a mismatch would
        # silently return the wrong documents or fail with an IndexError later.
        if len(self.doc_embeddings) != len(knowledge_base):
            raise ValueError(
                f"Got {len(self.doc_embeddings)} embeddings for "
                f"{len(knowledge_base)} documents; they must match one-to-one."
            )

        # Result of the last successful identification.
        self.current_identification = None

        # Documents retrieved by the last identification; re-used as context
        # for follow-up questions. Starts empty so send_message works before
        # any identification has been made.
        self.first_retrieved_docs = []

    def _new_chat(self):
        """Create a chat session configured with the system instructions."""
        return self.client.chats.create(
            model=self.model_name,
            config=GenerateContentConfig(system_instruction=self.system_instructions),
        )

    def _retrieve_relevant_docs(self, query: str, top_k: int = 3) -> List[Dict]:
        """Return up to ``top_k`` documents most similar to ``query``.

        Each result is a dict with keys ``"document"``, ``"similarity"`` and
        ``"index"``, ordered by decreasing similarity. Documents whose
        similarity is not positive are dropped.
        """
        query_embedding = self.embedding_model.encode(query)
        return self._rank_by_similarity(query_embedding, top_k)

    def _rank_by_similarity(self, query_embedding, top_k: int) -> List[Dict]:
        """Rank the stored document embeddings against ``query_embedding``.

        Cosine similarity for all documents is computed with one matrix-vector
        product; zero-length vectors score 0.0 rather than NaN. Returns an empty
        list when ``top_k <= 0`` or there are no documents.
        """
        if top_k <= 0 or len(self.doc_embeddings) == 0:
            return []

        doc_matrix = np.asarray(self.doc_embeddings)
        query = np.asarray(query_embedding).ravel()

        norms = np.linalg.norm(doc_matrix, axis=1) * np.linalg.norm(query)
        with np.errstate(divide="ignore", invalid="ignore"):
            similarities = np.where(norms > 0, (doc_matrix @ query) / norms, 0.0)

        # Highest similarity first; the stable sort keeps document order on ties.
        top_indices = np.argsort(-similarities, kind="stable")[:top_k]

        results = []
        for idx in top_indices:
            score = float(similarities[idx])
            if score > 0:
                results.append({
                    "document": self.knowledge_base[idx],
                    "similarity": score,
                    "index": int(idx),
                })
        return results

    @staticmethod
    def _merge_unique_docs(*doc_lists: List[Dict]) -> List[Dict]:
        """Concatenate retrieval results, keeping the first occurrence of each document index."""
        merged = []
        seen = set()
        for docs in doc_lists:
            for doc in docs:
                if doc["index"] not in seen:
                    seen.add(doc["index"])
                    merged.append(doc)
        return merged

    @staticmethod
    def _format_documents(relevant_docs: List[Dict]) -> List[str]:
        """Render retrieved documents as numbered prompt sections with their relevance."""
        parts = []
        for i, doc_data in enumerate(relevant_docs, 1):
            parts.append(f"\n--- Document {i} (relevance: {doc_data['similarity']:.3f}) ---")
            parts.append(doc_data["document"])
            parts.append("---\n")
        return parts

    def _build_context(self, relevant_docs: List[Dict], query: str) -> str:
        """Assemble the follow-up prompt: instructions, retrieved documents, question."""
        context_parts = [
            f"{self.system_instructions}\n",
            "\n=== MUSHROOM KNOWLEDGE BASE (Retrieved Documents) ===\n",
        ]
        context_parts.extend(self._format_documents(relevant_docs))
        context_parts.append("\n=== USER QUESTION ===")
        context_parts.append(f"{query}\n")
        context_parts.append("\nYour response (following all rules):")
        return "\n".join(context_parts)

    def send_message(self, user_message: str, top_k: int = 3, verbose: bool = False) -> str:
        """Answer a follow-up question with retrieval plus the chat history.

        Documents retrieved for ``user_message`` are combined (without
        duplicates) with the documents from the last identification, so
        follow-ups keep the identification context.

        Returns:
            The model's reply text, or ``"Error: ..."`` if the API call fails.
        """
        relevant_docs = self._merge_unique_docs(
            self._retrieve_relevant_docs(user_message, top_k=top_k),
            self.first_retrieved_docs or [],
        )

        context = self._build_context(relevant_docs, user_message)

        if verbose:
            print(f"\n🔍 Found {len(relevant_docs)} relevant documents:")
            for doc_data in relevant_docs:
                doc_preview = doc_data['document'][:80].replace('\n', ' ')
                print(f"  - {doc_preview}... (score: {doc_data['similarity']:.3f})")

            print(f"\n📝 Context length: {len(context)} characters")
            history = self.chat.get_history()
            print(f"💬 Chat history length: {len(history)} messages")

        # The chat session appends this turn to its history automatically.
        try:
            response = self.chat.send_message(message=context)
            return response.text
        except Exception as e:  # notebook-facing API: report the error instead of raising
            return f"Error: {e}"

    def clear_history(self):
        """Start a fresh chat session (keeping the system instructions) and forget the last identification."""
        self.chat = self._new_chat()
        self.first_retrieved_docs = []
        self.current_identification = None
        print("✓ Conversation history cleared")

    def get_history(self) -> List:
        """Return the messages exchanged in the current chat session."""
        return self.chat.get_history()

    def initialize_from_predictions(self, predictions: Dict[str, float], verbose: bool = False) -> str:
        """Generate an identification card from classifier predictions.

        Args:
            predictions: Mapping of species name to classifier confidence.
            verbose: Print the retrieved documents.

        Returns:
            The model's reply text, or ``"Error: ..."`` if the API call fails.

        Raises:
            ValueError: If ``predictions`` is empty.
        """
        if not predictions:
            raise ValueError("predictions must contain at least one species")

        sorted_predictions = sorted(predictions.items(), key=lambda item: item[1], reverse=True)

        # Two documents per candidate; a document matched by several candidates is kept once.
        relevant_docs = self._merge_unique_docs(
            *(self._retrieve_relevant_docs(species, top_k=2) for species, _ in sorted_predictions)
        )

        # Remembered so follow-up questions keep the identification context.
        self.first_retrieved_docs = relevant_docs

        if verbose:
            print("\n🔍 Retrieved documents for all candidates:")
            for doc_data in relevant_docs:
                doc_preview = doc_data['document'][:80].replace('\n', ' ')
                print(f"  - {doc_preview}... (score: {doc_data['similarity']:.3f})")

        context = self._build_identification_context(relevant_docs, sorted_predictions)

        try:
            response = self.chat.send_message(message=context)

            self.current_identification = {
                'primary': sorted_predictions[0],
                'alternatives': sorted_predictions[1:],
            }

            return response.text
        except Exception as e:  # notebook-facing API: report the error instead of raising
            return f"Error: {e}"

    def _build_image_context(self, base_dir):
        """List example images as ``- <species dir>:`` headers followed by ``  - <path>`` lines.

        Directories and files are sorted so the prompt is deterministic. Returns
        an empty string when ``base_dir`` is not a directory, so a missing image
        folder does not prevent identification.
        """
        base = Path(base_dir)
        if not base.is_dir():
            return ""

        lines = []
        for species_dir in sorted(p for p in base.iterdir() if p.is_dir()):
            lines.append(f"- {species_dir.name}:")
            for image in sorted(species_dir.iterdir()):
                if image.suffix.lower() in self.IMAGE_SUFFIXES:
                    lines.append(f"  - {image.as_posix()}")

        return "\n".join(lines)

    def _build_identification_context(self, relevant_docs: List[Dict],
                                      predictions: List[Tuple[str, float]]) -> str:
        """Assemble the first-turn prompt for an identification.

        ``predictions`` must be non-empty and sorted by decreasing confidence;
        its first entry is treated as the primary prediction.
        """
        primary_species, primary_conf = predictions[0]
        alternatives = predictions[1:]

        context_parts = [
            f"{self.system_instructions}\n",
            "\n=== COMPUTER VISION IDENTIFICATION RESULTS ===\n",
            f"PRIMARY PREDICTION: {primary_species} (Confidence: {primary_conf:.2%})\n",
        ]

        if alternatives:
            context_parts.append("ALTERNATIVE PREDICTIONS:")
            context_parts.extend(f"{species} (Confidence: {conf:.2%})" for species, conf in alternatives)

        context_parts.append("\n=== PROVIDED IMAGE PATHS ===\n")
        context_parts.append(self._build_image_context(self.image_dir))

        context_parts.append("\n=== MUSHROOM KNOWLEDGE BASE ===\n")
        context_parts.extend(self._format_documents(relevant_docs))

        context_parts.append("\n=== TASK ===")
        context_parts.append(
            f"Provide a detailed identification card for {primary_species} "
            f"(the primary prediction with {primary_conf:.2%} confidence).\n"
        )

        if primary_conf < self.LOW_CONFIDENCE_THRESHOLD:
            context_parts.append(
                f"""Tell the user this warning:

                ⚠️ Confidence: {primary_conf:.2%} - EXPERT VERIFICATION REQUIRED

                🔍 For better identification, please provide additional photos:
                - Cap underside (gills/pores)
                - Full stem with base
                - Growing habitat and surroundings

                Clear photos from multiple angles help distinguish similar species."""
            )

        if alternatives:
            context_parts.append(
                "\n⚠️ IMPORTANT: Also check if any alternative predictions "
                f"({', '.join(species for species, _ in alternatives)}) are dangerous species. "
                "If so, add a warning section at the very beginning before identification card and inform that this alternative species is in predictions.\n"
            )

        return "\n".join(context_parts)

## Load API KEY

In [18]:
import os
from dotenv import load_dotenv

# Load variables from a .env file into the process environment.
load_dotenv()
# Gemini API key; os.getenv returns None when API_KEY is not set.
API_KEY = os.getenv("API_KEY")

## Initialization of agent

In [19]:
# Pass the precomputed embeddings so the agent does not re-encode every document.
agent = MushroomRAGAgent(
    api_key=API_KEY,
    knowledge_base=knowledge_base,
    model_name="gemini-2.5-flash-lite",
    embedding_model="all-MiniLM-L6-v2",
    embeddings=embeddings
)

## TESTING

In [20]:
# predictions = {
#     "Boletus edulis": 0.97,
#     "Boletus aereus": 0.02,
#     "Boletus reticulatus": 0.01
# }
# Example classifier output: species name -> confidence (a fraction; the agent
# formats it as a percentage and treats the highest value as the primary prediction).
predictions = {
    "Amanita muscaria": 0.88,
    "Amanita rubescens": 0.06,  # Edible
    "Amanita pantherina": 0.01  # TOXIC
}
# predictions = {
#     "Agaricus campestris": 0.88,
#     "Amanita phalloides": 0.07,  # DEADLY
#     "Agaricus arvensis": 0.01
# }

In [21]:
# Reset the chat, then ask for the identification card for `predictions`.
agent.clear_history()

response = agent.initialize_from_predictions(predictions, verbose=False)
print(response)
print("\n--------------------------------------------------\n")
#
# question = "Are they safe to eat?"
# print(f"{question}\n")
# response = agent.send_message(question, verbose=False)
# print(response)
# print("\n--------------------------------------------------\n")
#
# question = "It is March right now, it is possible that it is this mushroom?"
# print(f"{question}\n")
# response = agent.send_message(question, verbose=False)
# print(response)
# print("\n--------------------------------------------------\n")
#
# question = "Tell me the best recipe with this mushroom"
# print(f"{question}\n")
# response = agent.send_message(question, verbose=False)
# print(response)
# print("\n--------------------------------------------------\n")

✓ Conversation history cleared
It appears one of the alternative predictions, *Amanita pantherina* (Panther Amanita), is considered poisonous. 🚨⚠️💀

⚠️ Confidence: 88.00% - EXPERT VERIFICATION REQUIRED

🔍 For better identification, please provide additional photos:
- Cap underside (gills/pores)
- Full stem with base
- Growing habitat and surroundings

Clear photos from multiple angles help distinguish similar species.

Name: Fly Agaric (Amanita muscaria)
Confidence: 88.00%

Safety Status:
⚠️ POISONOUS
This mushroom is considered toxic and can cause unpleasant or dangerous effects.

Key Identification Features:
- Cap: Red, with white to yellow scales.
- Stem: White, with shaggy rings and a bulbous base.
- Gills/Pores: White.
- Distinctive traits: White patches on a red cap, ring on stem, bulbous base.

Location, Habitat & Season:
- Geographic range: North America, Europe, Asia.
- Habitat: Birch woods, mixed woodland.
- Season: August to December.

Look-alikes:
- P

In [26]:
# Follow-up question; verbose=True also prints the retrieved documents,
# the prompt length and the chat history length.
question = "Can you show me images of amanita cibarius?"
print(f"{question}\n")
response = agent.send_message(question, verbose=True)
print(response)
print("\n--------------------------------------------------\n")

Can you show me images of amanita cibarius?


🔍 Found 9 relevant documents:
  - Amanita crocea Description: ThisAmanitararely seems to have remains of the veil ... (score: 0.474)
  - Amanita citrina var. citrina Description: Although this mushroom looks like a de... (score: 0.455)
  - Amanita citrina var. alba Description: Although this mushroom looks like a deadl... (score: 0.449)
  - Amanita muscaria Scientific name: Amanita muscaria Common name: Fly Amanita Fami... (score: 0.569)
  - Amanita muscaria Description: A beautiful and photogenic mushroom that it is con... (score: 0.559)
  - Amanita rubescens Description: Common and available before many other species ar... (score: 0.632)
  - Amanita rubescens Scientific name: Amanita rubescens Common name: Blushing Amani... (score: 0.604)
  - Amanita pantherina Scientific name: Amanita pantherina Common name: Panther Aman... (score: 0.617)
  - Amanita pantherina Description: An exciting find, this visually striking mushroo... (score: 0

In [None]:
# question = "Is it safe to eat?"
# print(f"{question}\n")
# response = agent.send_message(question, verbose=True)
# print(response)
# print("\n--------------------------------------------------\n")

In [None]:
# question = "In what season can we find this mushroom?"
# print(f"{question}\n")
# response = agent.send_message(question, verbose=True)
# print(response)
# print("\n--------------------------------------------------\n")

In [None]:
# question = "Tell me more about alternative predictions"
# print(f"{question}\n")
# response = agent.send_message(question, verbose=True)
# print(response)
# print("\n--------------------------------------------------\n")

In [None]:
# question = "Tell me more about lookalikes"
# print(f"{question}\n")
# response = agent.send_message(question, verbose=True)
# print(response)
# print("\n--------------------------------------------------\n")

In [None]:
# question = "The best recipe for this mushroom"
# print(f"{question}\n")
# response = agent.send_message(question, verbose=True)
# print(response)
# print("\n--------------------------------------------------\n")

In [None]:
# question = "Now tell me about Boletus edulis"
# print(f"{question}\n")
# response = agent.send_message(question, verbose=True)
# print(response)
# print("\n--------------------------------------------------\n")

In [None]:
# question = "Tell me about its lookalikes"
# print(f"{question}\n")
# response = agent.send_message(question, verbose=True)
# print(response)
# print("\n--------------------------------------------------\n")