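# Mushroom RAG Agent

A Gemini chat agent that answers mushroom-identification questions with retrieval-augmented generation over a local JSON knowledge base.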

In [2]:
import json
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from sentence_transformers import SentenceTransformer




## Load knowledge base

In [3]:
def prepare_knowledge_base(file_paths):
    """Flatten mushroom JSON files into a list of plain-text documents.

    Each file is read as a JSON object mapping a mushroom name to an object of
    ``label -> value`` fields. Every mushroom becomes one document: its name on
    the first line, followed by one ``"label: value"`` line per field.

    Args:
        file_paths: Iterable of paths to UTF-8 encoded JSON files.

    Returns:
        List of document strings, ordered by file and then by entry.
    """
    def _render(name, fields):
        # Name first, then one line per field, in the file's key order.
        return "\n".join([name, *(f"{label}: {value}" for label, value in fields.items())])

    documents = []
    for path in file_paths:
        with open(path, 'r', encoding='utf-8') as handle:
            entries = json.load(handle)
        documents += [_render(name, fields) for name, fields in entries.items()]
    return documents

In [4]:
# The JSON sources that together form the mushroom knowledge base.
file_paths = [
    f"../Knowledge_base/{source}.json"
    for source in ("mushroom_guide", "mushroom_world", "wikipedia")
]

knowledge_base = prepare_knowledge_base(file_paths)

## Load Embeddings

In [5]:
# Cached document embeddings: row i must correspond to knowledge_base[i].
# NOTE(review): the cache lives under ../App/Knowledge_base while the JSON
# sources are read from ../Knowledge_base - confirm the two stay in sync.
embeddings = np.load("../App/Knowledge_base/embeddings_all.npy")

## Create Embeddings

In [4]:
# Re-encoding the whole knowledge base is slow, and unconditionally doing it
# here would overwrite the cache loaded by the previous cell on every run.
# Only rebuild when explicitly requested, or when the cached embeddings no
# longer have one row per document.
REBUILD_EMBEDDINGS = False

if REBUILD_EMBEDDINGS or len(embeddings) != len(knowledge_base):
    model = SentenceTransformer('all-MiniLM-L6-v2')
    embeddings = model.encode(knowledge_base)
    np.save("../App/Knowledge_base/embeddings_all.npy", embeddings)

## Agent

In [24]:
import google.genai as genai
from google.genai.types import GenerateContentConfig

def cosine_similarity(a, b):
    """Return the cosine similarity of two 1-D vectors as a Python float.

    A zero-length vector has no direction, so instead of dividing by zero
    (which yields ``nan`` and a RuntimeWarning) the similarity is 0.0.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    if denominator == 0:
        return 0.0
    return float(np.dot(a, b) / denominator)

class MushroomRAGAgent:
    """Retrieval-augmented Gemini chat agent that answers mushroom questions.

    Documents in ``knowledge_base`` are embedded with a SentenceTransformer
    model. Each turn retrieves the most similar documents by cosine
    similarity, builds a prompt around them and sends it to a Gemini chat
    session, which keeps the conversation history.
    """

    # Primary predictions below this confidence get an
    # "EXPERT VERIFICATION REQUIRED" warning in the identification prompt.
    LOW_CONFIDENCE_THRESHOLD = 0.90

    def __init__(self, api_key: str, knowledge_base: List[str],
                 model_name: str, embedding_model: str,
                 embeddings: Optional[np.ndarray] = None):
        """Create the chat session and the document embedding index.

        Args:
            api_key: Gemini API key passed to ``genai.Client``.
            knowledge_base: Document texts to retrieve from.
            model_name: Gemini model name used for the chat session.
            embedding_model: SentenceTransformer model name used to embed
                queries (and the documents when ``embeddings`` is None).
            embeddings: Optional pre-computed document embeddings; row ``i``
                must belong to ``knowledge_base[i]``.

        Raises:
            ValueError: If ``embeddings`` does not have exactly one row per
                document (retrieval would otherwise return wrong documents
                or fail with IndexError).
        """
        if embeddings is not None and len(embeddings) != len(knowledge_base):
            raise ValueError(
                f"embeddings has {len(embeddings)} rows but knowledge_base has "
                f"{len(knowledge_base)} documents; rebuild the embeddings."
            )

        self.system_instructions = """You are an expert mycologist - a mushroom specialist.

            CRITICAL RULES:
            1. Answer ONLY based on the provided MUSHROOM KNOWLEDGE BASE documents
            2. If information is not in the provided documents, clearly state: "This information is not available in my knowledge base"
            3. For POISONOUS/TOXIC mushrooms, emphasize warnings STRONGLY with emojis (🚨⚠️💀)
            4. Never make up information - only use what's in the provided documents
            5. If the species is not specified in question answer about primary species or alternative species

            OUTPUT FORMAT:
            - For the FIRST identification query: Use the structured format below
            - For follow-up questions: Respond conversationally and concisely


            STRUCTURED FORMAT (first query only):
            Name: [Common Name] ([Scientific Name])
            Confidence: [Confidence]

            Safety Status:
            ✅ EDIBLE / ⚠️ POISONOUS / 💀 DEADLY POISONOUS
            [Brief safety note]

            Key Identification Features:
            - Cap: [description]
            - Stem: [description]
            - Gills/Pores: [description]
            - Distinctive traits: [unique features]

            Location, Habitat & Season:
            - Geographic range: [location]
            - Habitat: [where it grows]
            - Season: [when it appears]

            Look-alikes:
            - [Any dangerous similar species]

            Alternative predictions:
            - [All alternative predictions]

            Keep all descriptions very concise - only few words per point.
        """

        # Configure Gemini with the google-genai client.
        self.client = genai.Client(api_key=api_key)
        self.model_name = model_name

        # Chat session with automatic history management.
        self.chat = self._new_chat()

        # Knowledge base and the semantic embedding model for queries.
        self.knowledge_base = knowledge_base
        self.embedding_model = SentenceTransformer(embedding_model)

        # Encode the documents only when no pre-computed embeddings were given.
        if embeddings is None:
            embeddings = self.embedding_model.encode(knowledge_base)
        self.doc_embeddings = embeddings

        # Most recent successful identification (set by initialize_from_predictions).
        self.current_identification = None

        # Documents retrieved for the identified species; re-attached to every
        # follow-up question. None until an identification has been made.
        self.first_retrieved_docs = None

    def _new_chat(self):
        """Create a Gemini chat session configured with the system instructions."""
        return self.client.chats.create(
            model=self.model_name,
            config=GenerateContentConfig(system_instruction=self.system_instructions),
        )

    def _retrieve_relevant_docs(self, query: str, top_k: int = 3) -> List[Dict[str, Any]]:
        """Return up to ``top_k`` documents most similar to ``query``.

        Similarity is the cosine similarity between the query embedding and
        each document embedding. Only documents with strictly positive
        similarity are returned, ordered from most to least similar.

        Returns:
            List of dicts with keys ``"document"`` (text), ``"similarity"``
            (float) and ``"index"`` (position in ``self.knowledge_base``).
        """
        if top_k <= 0:
            # Guard: argsort(...)[-0:] would otherwise select *every* document.
            return []

        doc_matrix = np.asarray(self.doc_embeddings)
        if doc_matrix.size == 0:
            return []

        query_embedding = np.asarray(self.embedding_model.encode(query))

        # Vectorised cosine similarity over all documents at once; pairs with
        # a zero-norm vector score 0.0 instead of producing nan.
        dots = doc_matrix @ query_embedding
        norms = np.linalg.norm(doc_matrix, axis=1) * np.linalg.norm(query_embedding)
        similarities = np.divide(dots, norms, out=np.zeros(dots.shape, dtype=float),
                                 where=norms > 0)

        top_indices = np.argsort(similarities)[-top_k:][::-1]
        return [
            {
                "document": self.knowledge_base[idx],
                "similarity": float(similarities[idx]),
                "index": int(idx),
            }
            for idx in top_indices
            if similarities[idx] > 0
        ]

    @staticmethod
    def _format_documents(relevant_docs: List[Dict[str, Any]]) -> List[str]:
        """Render retrieved documents as numbered prompt sections."""
        parts = []
        for i, doc_data in enumerate(relevant_docs, 1):
            parts.append(f"\n--- Document {i} (relevance: {doc_data['similarity']:.3f}) ---")
            parts.append(doc_data["document"])
            parts.append("---\n")
        return parts

    def _build_context(self, relevant_docs: List[Dict], query: str) -> str:
        """Assemble the prompt for a follow-up question.

        The prompt contains the system instructions, the numbered retrieved
        documents and the user's question, joined with newlines.
        """
        context_parts = [
            f"{self.system_instructions}\n",
            "\n=== MUSHROOM KNOWLEDGE BASE (Retrieved Documents) ===\n",
        ]
        context_parts.extend(self._format_documents(relevant_docs))
        context_parts.append("\n=== USER QUESTION ===")
        context_parts.append(f"{query}\n")
        context_parts.append("\nYour response (following all rules):")
        return "\n".join(context_parts)

    def send_message(self, user_message: str, top_k: int = 3, verbose: bool = False) -> str:
        """Answer a follow-up question using retrieval plus the chat history.

        Args:
            user_message: The user's question.
            top_k: Number of documents to retrieve for this question.
            verbose: Print the retrieved documents and prompt statistics.

        Returns:
            The model's reply text, or ``"Error: ..."`` if the API call fails.
        """
        relevant_docs = self._retrieve_relevant_docs(user_message, top_k=top_k)

        # Re-attach the documents retrieved for the identified species so vague
        # follow-ups ("is it safe to eat?") stay grounded in them. Indices are
        # tracked as we go so no document is included twice; ``or []`` allows
        # questions before any identification has been made.
        seen_indices = {doc['index'] for doc in relevant_docs}
        for prev_doc in self.first_retrieved_docs or []:
            if prev_doc['index'] not in seen_indices:
                relevant_docs.append(prev_doc)
                seen_indices.add(prev_doc['index'])

        # AUGMENTATION
        context = self._build_context(relevant_docs, user_message)

        if verbose:
            print(f"\n🔍 Found {len(relevant_docs)} relevant documents:")
            for doc_data in relevant_docs:
                doc_preview = doc_data['document'][:80].replace('\n', ' ')
                print(f"  - {doc_preview}... (score: {doc_data['similarity']:.3f})")

            print(f"\n📝 Context length: {len(context)} characters")
            history = self.chat.get_history()
            print(f"💬 Chat history length: {len(history)} messages")

        # GENERATION (the chat session keeps the history). API failures are
        # reported as text so an interactive session keeps running.
        try:
            response = self.chat.send_message(message=context)
            return response.text
        except Exception as e:
            return f"Error: {str(e)}"

    def clear_history(self):
        """Start a fresh chat session and forget the current identification."""
        # Recreate with the same system-instruction config as __init__.
        self.chat = self._new_chat()
        self.first_retrieved_docs = []
        self.current_identification = None
        print("✓ Conversation history cleared")

    def get_history(self) -> List:
        """Return the message history of the current chat session."""
        return self.chat.get_history()

    def initialize_from_predictions(self, predictions: Dict[str, float], verbose: bool = False) -> str:
        """Produce the identification card for classifier predictions.

        Args:
            predictions: Mapping of species name to confidence in [0, 1].
            verbose: Print a preview of the retrieved documents.

        Returns:
            The model's reply text, or ``"Error: ..."`` if the API call fails.

        Raises:
            ValueError: If ``predictions`` is empty.
        """
        if not predictions:
            raise ValueError("predictions must contain at least one species")

        # Highest confidence first; element 0 is the primary prediction.
        sorted_predictions = sorted(predictions.items(), key=lambda item: item[1], reverse=True)

        # Retrieve documents for every predicted species. A document matched by
        # several species is kept once, at its highest-ranked position.
        relevant_docs = []
        seen_indices = set()
        for species, _confidence in sorted_predictions:
            for doc in self._retrieve_relevant_docs(species, top_k=2):
                if doc["index"] not in seen_indices:
                    seen_indices.add(doc["index"])
                    relevant_docs.append(doc)

        self.first_retrieved_docs = relevant_docs

        if verbose:
            print(f"\n🔍 Retrieved documents for all candidates:")
            for doc_data in relevant_docs[:5]:
                doc_preview = doc_data['document'][:80].replace('\n', ' ')
                print(f"  - {doc_preview}... (score: {doc_data['similarity']:.3f})")

        context = self._build_identification_context(relevant_docs, sorted_predictions)

        try:
            response = self.chat.send_message(message=context)

            # Only record the identification once the model has answered.
            self.current_identification = {
                'primary': sorted_predictions[0],
                'alternatives': sorted_predictions[1:],
            }

            return response.text
        except Exception as e:
            return f"Error: {str(e)}"

    def _build_identification_context(self, relevant_docs: List[Dict],
                                      predictions: List[Tuple[str, float]]) -> str:
        """Assemble the prompt for the first identification turn.

        Args:
            relevant_docs: Retrieved documents for all predicted species.
            predictions: ``(species, confidence)`` pairs sorted by descending
                confidence; must contain at least one element.
        """
        primary_species, primary_conf = predictions[0]
        alternatives = predictions[1:]

        context_parts = [
            f"{self.system_instructions}\n",
            "\n=== COMPUTER VISION IDENTIFICATION RESULTS ===\n",
            f"PRIMARY PREDICTION: {primary_species} (Confidence: {primary_conf:.2%})\n",
        ]

        if alternatives:
            context_parts.append("ALTERNATIVE PREDICTIONS:")
            context_parts.extend(f"{species} (Confidence: {conf:.2%})" for species, conf in alternatives)

        context_parts.append("\n=== MUSHROOM KNOWLEDGE BASE ===\n")
        context_parts.extend(self._format_documents(relevant_docs))

        context_parts.append("\n=== TASK ===")
        context_parts.append(
            f"Provide a detailed identification card for {primary_species} "
            f"(the primary prediction with {primary_conf:.2%} confidence).\n"
        )

        if primary_conf < self.LOW_CONFIDENCE_THRESHOLD:
            context_parts.append(
                f"""Tell the user this warning:

                ⚠️ Confidence: {primary_conf:.2%} - EXPERT VERIFICATION REQUIRED

                🔍 For better identification, please provide additional photos:
                - Cap underside (gills/pores)
                - Full stem with base
                - Growing habitat and surroundings

                Clear photos from multiple angles help distinguish similar species."""
            )

        if alternatives:
            alternative_names = ", ".join(species for species, _ in alternatives)
            context_parts.append(
                "\n⚠️ IMPORTANT: Also check if any alternative predictions "
                f"({alternative_names}) are dangerous species. "
                "If so, add a warning section at the very beginning before identification card and inform that this alternative species is in predictions.\n"
            )

        return "\n".join(context_parts)

## Load API key

In [25]:
import os
from dotenv import load_dotenv

# Pull variables from a local .env file into the process environment, then
# read the Gemini key from there so it never appears in the notebook itself.
load_dotenv()
API_KEY = os.environ.get("API_KEY")

## Initialize the agent

In [26]:
# Pass the pre-computed `embeddings` so the agent does not re-encode the whole
# knowledge base when it is constructed.
agent = MushroomRAGAgent(
    api_key=API_KEY,
    model_name="gemini-2.5-flash-lite",
    embedding_model="all-MiniLM-L6-v2",
    knowledge_base=knowledge_base,
    embeddings=embeddings,
)

## Testing

In [27]:
# Example classifier outputs (species -> confidence) for exercising the agent.
# Select one scenario instead of assigning `predictions` and then silently
# overwriting it.
PREDICTION_SCENARIOS = {
    "edible_boletes": {
        "Boletus edulis": 0.97,
        "Boletus aereus": 0.02,
        "Boletus reticulatus": 0.01,
    },
    "amanita_mix": {
        "Amanita muscaria": 0.88,
        "Amanita rubescens": 0.06,  # Edible
        "Amanita pantherina": 0.01,  # TOXIC
    },
    "deadly_lookalike": {
        "Agaricus campestris": 0.88,
        "Amanita phalloides": 0.07,  # DEADLY
        "Agaricus arvensis": 0.01,
    },
}

predictions = PREDICTION_SCENARIOS["deadly_lookalike"]

In [28]:
# Start from a clean conversation so the identification card is the first
# turn; follow-up questions are asked in the cells below.
agent.clear_history()

response = agent.initialize_from_predictions(predictions, verbose=False)
print(response)
print("\n--------------------------------------------------\n")

✓ Conversation history cleared
Amanita phalloides is a DEADLY POISONOUS mushroom. 🚨⚠️💀 If any of your alternative predictions include a deadly poisonous species, it is crucial to exercise extreme caution.

Name: Meadow Mushroom (Agaricus campestris)
Confidence: 88.00% - EXPERT VERIFICATION REQUIRED

🔍 For better identification, please provide additional photos:
- Cap underside (gills/pores)
- Full stem with base
- Growing habitat and surroundings

Clear photos from multiple angles help distinguish similar species.

Safety Status:
✅ EDIBLE
This mushroom should be cooked before consumption.

Key Identification Features:
- Cap: White, sometimes with fine scales, hemispherical to flat
- Stem: Short, cylindric, white with an indistinct ring
- Gills/Pores: Free, pale pink to dark purple-brown
- Distinctive traits: Flesh is thick and white, not changing when sliced. Often found in fairy rings.

Location, Habitat & Season:
- Geographic range: Europe, North America, North Afric

In [29]:
# Follow-up question in the same chat session.
question = "Is it safe to eat?"
print(question, end="\n\n")
response = agent.send_message(question, verbose=True)
print(response, "\n--------------------------------------------------\n", sep="\n")

Is it safe to eat?


🔍 Found 9 relevant documents:
  - Gyromitra esculenta Introduction: Gyromitra esculenta  is an ascomycete fungus f... (score: 0.324)
  - Amanita phalloides Introduction: Amanita phalloides ( AM-ə-NY-tə fə-LOY-deez), c... (score: 0.307)
  - Fuligo septica Scientific name: Fuligo septica Common name: Scrambled Egg Slime ... (score: 0.280)
  - Agaricus campestris Scientific name: Agaricus campestris Common name: Meadow Mus... (score: 0.624)
  - Agaricus campestris Description: This used to be a very common mushroom but is b... (score: 0.618)
  - Amanita phalloides Scientific name: Amanita phalloides Common name: Death Cap Am... (score: 0.652)
  - Amanita phalloides Description: An innocuous looking mushroom that is among the ... (score: 0.604)
  - Agaricus arvensis Description: A great mushroom with a rich, strong taste and as... (score: 0.592)
  - Agaricus arvensis Scientific name: Agaricus arvensis Common name: Horse Mushroom... (score: 0.584)

📝 Context le

In [30]:
# Follow-up question in the same chat session.
question = "In what season can we find this mushroom?"
print(question, end="\n\n")
response = agent.send_message(question, verbose=True)
print(response, "\n--------------------------------------------------\n", sep="\n")

In what season can we find this mushroom?


🔍 Found 9 relevant documents:
  - Cerioporus squamosus Description: The largest capped mushroom in the UK starting... (score: 0.574)
  - Hericium erinaceus Description: A Europe-wide but rare and unique looking mushro... (score: 0.551)
  - Agrocybe praecox Description: A common mushroom in Spring and early Summer but c... (score: 0.544)
  - Agaricus campestris Scientific name: Agaricus campestris Common name: Meadow Mus... (score: 0.624)
  - Agaricus campestris Description: This used to be a very common mushroom but is b... (score: 0.618)
  - Amanita phalloides Scientific name: Amanita phalloides Common name: Death Cap Am... (score: 0.652)
  - Amanita phalloides Description: An innocuous looking mushroom that is among the ... (score: 0.604)
  - Agaricus arvensis Description: A great mushroom with a rich, strong taste and as... (score: 0.592)
  - Agaricus arvensis Scientific name: Agaricus arvensis Common name: Horse Mushroom... (score: 0.5

In [31]:
# Follow-up question in the same chat session.
question = "Tell me more about alternative predictions"
print(question, end="\n\n")
response = agent.send_message(question, verbose=True)
print(response, "\n--------------------------------------------------\n", sep="\n")

Tell me more about alternative predictions


🔍 Found 9 relevant documents:
  - Morchella importuna Description: This popular mushroom is a choice edible but mu... (score: 0.107)
  - Lactarius blennius Description: A common, slimy Milkcap which has a mycorrhizal ... (score: 0.102)
  - Morchella semilibera Description: Not as good a find as a true Morel, but a tast... (score: 0.098)
  - Agaricus campestris Scientific name: Agaricus campestris Common name: Meadow Mus... (score: 0.624)
  - Agaricus campestris Description: This used to be a very common mushroom but is b... (score: 0.618)
  - Amanita phalloides Scientific name: Amanita phalloides Common name: Death Cap Am... (score: 0.652)
  - Amanita phalloides Description: An innocuous looking mushroom that is among the ... (score: 0.604)
  - Agaricus arvensis Description: A great mushroom with a rich, strong taste and as... (score: 0.592)
  - Agaricus arvensis Scientific name: Agaricus arvensis Common name: Horse Mushroom... (score: 0.

In [32]:
# Follow-up question in the same chat session.
question = "Tell me more about lookalikes"
print(question, end="\n\n")
response = agent.send_message(question, verbose=True)
print(response, "\n--------------------------------------------------\n", sep="\n")

Tell me more about lookalikes


🔍 Found 9 relevant documents:
  - Gymnopus fusipes Description: A cluster of mushrooms which don’t particularly ma... (score: 0.212)
  - Hygrocybe / Cuphophyllus flavipes Description: A fairly common Waxcap of nutrien... (score: 0.206)
  - Lactarius turpis Scientific name: Lactarius turpis Common name: Ugly Milkcap Fam... (score: 0.204)
  - Agaricus campestris Scientific name: Agaricus campestris Common name: Meadow Mus... (score: 0.624)
  - Agaricus campestris Description: This used to be a very common mushroom but is b... (score: 0.618)
  - Amanita phalloides Scientific name: Amanita phalloides Common name: Death Cap Am... (score: 0.652)
  - Amanita phalloides Description: An innocuous looking mushroom that is among the ... (score: 0.604)
  - Agaricus arvensis Description: A great mushroom with a rich, strong taste and as... (score: 0.592)
  - Agaricus arvensis Scientific name: Agaricus arvensis Common name: Horse Mushroom... (score: 0.584)

📝 

In [33]:
# Follow-up question in the same chat session.
question = "The best recipe for this mushroom"
print(question, end="\n\n")
response = agent.send_message(question, verbose=True)
print(response, "\n--------------------------------------------------\n", sep="\n")

The best recipe for this mushroom


🔍 Found 9 relevant documents:
  - Tricholoma murrillianum Introduction: Tricholoma murrillianum is a species of mu... (score: 0.525)
  - Tricholomopsis rutilans Description: A beautiful looking mushroom that can grow ... (score: 0.500)
  - Tricholoma columbetta Description: The Tricholoma genus is large, varied in colo... (score: 0.499)
  - Agaricus campestris Scientific name: Agaricus campestris Common name: Meadow Mus... (score: 0.624)
  - Agaricus campestris Description: This used to be a very common mushroom but is b... (score: 0.618)
  - Amanita phalloides Scientific name: Amanita phalloides Common name: Death Cap Am... (score: 0.652)
  - Amanita phalloides Description: An innocuous looking mushroom that is among the ... (score: 0.604)
  - Agaricus arvensis Description: A great mushroom with a rich, strong taste and as... (score: 0.592)
  - Agaricus arvensis Scientific name: Agaricus arvensis Common name: Horse Mushroom... (score: 0.584)

📝

In [34]:
# Follow-up question in the same chat session.
question = "Now tell me about Boletus edulis"
print(question, end="\n\n")
response = agent.send_message(question, verbose=True)
print(response, "\n--------------------------------------------------\n", sep="\n")

Now tell me about Boletus edulis


🔍 Found 9 relevant documents:
  - Boletus edulis Scientific name: Boletus edulis Common name: King Bolete Family: ... (score: 0.508)
  - Boletus reticulatus Introduction: Boletus reticulatus (alternately known as Bole... (score: 0.465)
  - Boletus edulis Description: One of the stars of the mushroom world. Easily recog... (score: 0.446)
  - Agaricus campestris Scientific name: Agaricus campestris Common name: Meadow Mus... (score: 0.624)
  - Agaricus campestris Description: This used to be a very common mushroom but is b... (score: 0.618)
  - Amanita phalloides Scientific name: Amanita phalloides Common name: Death Cap Am... (score: 0.652)
  - Amanita phalloides Description: An innocuous looking mushroom that is among the ... (score: 0.604)
  - Agaricus arvensis Description: A great mushroom with a rich, strong taste and as... (score: 0.592)
  - Agaricus arvensis Scientific name: Agaricus arvensis Common name: Horse Mushroom... (score: 0.584)

📝

In [35]:
# Follow-up question in the same chat session.
question = "Tell me about its lookalikes"
print(question, end="\n\n")
response = agent.send_message(question, verbose=True)
print(response, "\n--------------------------------------------------\n", sep="\n")

Tell me about its lookalikes


🔍 Found 9 relevant documents:
  - Lactarius turpis Scientific name: Lactarius turpis Common name: Ugly Milkcap Fam... (score: 0.274)
  - Gymnopus fusipes Description: A cluster of mushrooms which don’t particularly ma... (score: 0.270)
  - Hypholoma fasciculare Introduction: Hypholoma fasciculare, commonly known as the... (score: 0.268)
  - Agaricus campestris Scientific name: Agaricus campestris Common name: Meadow Mus... (score: 0.624)
  - Agaricus campestris Description: This used to be a very common mushroom but is b... (score: 0.618)
  - Amanita phalloides Scientific name: Amanita phalloides Common name: Death Cap Am... (score: 0.652)
  - Amanita phalloides Description: An innocuous looking mushroom that is among the ... (score: 0.604)
  - Agaricus arvensis Description: A great mushroom with a rich, strong taste and as... (score: 0.592)
  - Agaricus arvensis Scientific name: Agaricus arvensis Common name: Horse Mushroom... (score: 0.584)

📝 C