In [2]:
!pip install google-generativeai ipywidgets numpy





In [9]:
import google.generativeai as genai
import json
import numpy as np
from typing import List, Dict, Any
import os
from datetime import datetime
from IPython.display import display, HTML, Markdown
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual
import warnings
# Silence every warning category for the rest of the session. This keeps the
# notebook output clean, but it also hides potentially useful warnings.
warnings.simplefilter('ignore')

class MathToolsRAG:
    """Gemini-backed math assistant with tool calling and a small embedding memory.

    Each query is sent to a Gemini model together with two function declarations
    (``basic_math_operation`` and ``array_sum_operation``). Function calls returned
    by the model are executed locally. Successful tool runs are stored in an
    in-memory knowledge base together with the query's embedding, so that later
    similar queries can be given the earlier results as prompt context.
    """

    def __init__(self, api_key: str):
        """Configure the Gemini client and declare the available tools.

        Args:
            api_key: Gemini API key, passed to ``genai.configure``.
        """
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel('gemini-2.0-flash-exp')
        # NOTE(review): not used anywhere in this class -- embeddings are produced
        # via genai.embed_content in generate_embedding. Kept so any external
        # reference to the attribute keeps working.
        self.embedding_model = genai.GenerativeModel('text-embedding-004')

        # In-memory state; lost when the instance is discarded.
        self.chat_history = []
        self.knowledge_base = []
        self.embeddings_cache = {}

        # Function declarations exposed to the model. The names here must match
        # the dispatch table in process_function_call.
        self.tools = [
            genai.protos.Tool(
                function_declarations=[
                    genai.protos.FunctionDeclaration(
                        name="basic_math_operation",
                        description="Performs basic mathematical operations on two natural numbers",
                        parameters=genai.protos.Schema(
                            type=genai.protos.Type.OBJECT,
                            properties={
                                "num1": genai.protos.Schema(
                                    type=genai.protos.Type.NUMBER,
                                    description="First natural number"
                                ),
                                "num2": genai.protos.Schema(
                                    type=genai.protos.Type.NUMBER,
                                    description="Second natural number"
                                ),
                                "operation": genai.protos.Schema(
                                    type=genai.protos.Type.STRING,
                                    description="Operation to perform: addition, subtraction, multiplication, or division",
                                    enum=["addition", "subtraction", "multiplication", "division"]
                                )
                            },
                            required=["num1", "num2", "operation"]
                        )
                    ),
                    genai.protos.FunctionDeclaration(
                        name="array_sum_operation",
                        description="Creates an array with specified number of elements and calculates their sum",
                        parameters=genai.protos.Schema(
                            type=genai.protos.Type.OBJECT,
                            properties={
                                "array_size": genai.protos.Schema(
                                    type=genai.protos.Type.INTEGER,
                                    description="Number of elements in the array"
                                ),
                                "elements": genai.protos.Schema(
                                    type=genai.protos.Type.ARRAY,
                                    items=genai.protos.Schema(type=genai.protos.Type.NUMBER),
                                    description="Array elements provided by user"
                                )
                            },
                            required=["array_size", "elements"]
                        )
                    )
                ]
            )
        ]

    def generate_embedding(self, text: str) -> List[float]:
        """Return the embedding of ``text``, memoised in ``self.embeddings_cache``.

        Returns an empty list (and caches nothing) if the embedding call fails.
        """
        try:
            if text in self.embeddings_cache:
                return self.embeddings_cache[text]

            # NOTE(review): the same task type is used for stored queries and for
            # lookup queries so that one cache entry serves both; Gemini also
            # offers "retrieval_query" for the lookup side.
            result = genai.embed_content(
                model="models/text-embedding-004",
                content=text,
                task_type="retrieval_document"
            )
            embedding = result['embedding']
            self.embeddings_cache[text] = embedding
            return embedding
        except Exception as e:
            print(f"Error generating embedding: {e}")
            return []

    def add_to_knowledge_base(self, query: str, result: Dict[str, Any]):
        """Append a query/result pair, with the query's embedding, to the knowledge base."""
        knowledge_entry = {
            "query": query,
            "result": result,
            "timestamp": datetime.now().isoformat(),
            "embedding": self.generate_embedding(query)
        }
        self.knowledge_base.append(knowledge_entry)
        print(f"Added to knowledge base. Total entries: {len(self.knowledge_base)}")

    def retrieve_similar_queries(self, query: str, top_k: int = 3) -> List[Dict]:
        """Return up to ``top_k`` knowledge-base entries most similar to ``query``.

        Similarity is cosine similarity between embeddings, highest first. Some
        entries are skipped because cosine similarity is undefined for them:
        entries without an embedding, with an embedding of a different length
        than the query's, or with a zero-norm embedding.
        """
        if not self.knowledge_base:
            return []

        query_embedding = self.generate_embedding(query)
        if not query_embedding:
            return []

        # The query vector and its norm are identical for every entry; compute once.
        query_vec = np.asarray(query_embedding, dtype=float)
        query_norm = np.linalg.norm(query_vec)
        if query_norm == 0:
            return []

        similarities = []
        for entry in self.knowledge_base:
            embedding = entry.get('embedding')
            if embedding is None or len(embedding) == 0:
                continue
            entry_vec = np.asarray(embedding, dtype=float)
            if entry_vec.shape != query_vec.shape:
                continue
            entry_norm = np.linalg.norm(entry_vec)
            if entry_norm == 0:
                continue
            similarity = float(np.dot(query_vec, entry_vec) / (query_norm * entry_norm))
            similarities.append((similarity, entry))

        # Stable sort: ties keep insertion order.
        similarities.sort(key=lambda pair: pair[0], reverse=True)
        similar_queries = [entry for _, entry in similarities[:top_k]]

        if similar_queries:
            print(f"Found {len(similar_queries)} similar queries in knowledge base")

        return similar_queries

    def basic_math_operation(self, num1: float, num2: float, operation: str) -> Dict[str, Any]:
        """Apply ``operation`` to two operands coerced to natural numbers.

        Operands are converted with ``int(abs(x))``. Subtraction returns the
        absolute difference, so the result stays non-negative. Division returns a
        float. Errors are reported as ``{"error": ..., "success": False}``.
        """
        try:
            num1, num2 = int(abs(num1)), int(abs(num2))  # Ensure natural numbers

            if operation == "addition":
                result = num1 + num2
            elif operation == "subtraction":
                result = abs(num1 - num2)
            elif operation == "multiplication":
                result = num1 * num2
            elif operation == "division":
                if num2 == 0:
                    return {"error": "Division by zero is not allowed", "success": False}
                result = num1 / num2
            else:
                return {"error": f"Unknown operation: {operation}", "success": False}

            return {
                "operation": operation,
                "num1": num1,
                "num2": num2,
                "result": result,
                "success": True
            }
        except Exception as e:
            return {"error": str(e), "success": False}

    def array_sum_operation(self, array_size: int, elements: List[float]) -> Dict[str, Any]:
        """Sum ``elements`` after checking that there are exactly ``array_size`` of them."""
        try:
            if len(elements) != array_size:
                return {"error": f"Expected {array_size} elements, got {len(elements)}", "success": False}

            array_sum = sum(elements)
            return {
                "array_size": array_size,
                "elements": elements,
                "sum": array_sum,
                "success": True
            }
        except Exception as e:
            return {"error": str(e), "success": False}

    @staticmethod
    def _to_python(value: Any) -> Any:
        """Recursively convert Gemini's map/repeated containers into dicts and lists."""
        if isinstance(value, (str, bytes)):
            return value
        if hasattr(value, 'items'):
            # Mapping-like (e.g. a nested struct): keep keys AND values.
            return {key: MathToolsRAG._to_python(item) for key, item in value.items()}
        if hasattr(value, '__iter__'):
            return [MathToolsRAG._to_python(item) for item in value]
        return value

    def process_function_call(self, function_call):
        """Execute one function call returned by Gemini and return its result dict.

        Unknown function names and argument sets that do not match the target's
        signature are reported as ``{"error": ..., "success": False}`` instead of
        raising.
        """
        function_name = function_call.name
        function_args = {
            key: self._to_python(value) for key, value in function_call.args.items()
        }

        print(f"Executing tool: {function_name}")
        try:
            print(f"Arguments: {json.dumps(function_args, indent=2)}")
        except TypeError:
            # Fallback if still can't serialize
            print(f"Arguments: {function_args}")

        handlers = {
            "basic_math_operation": self.basic_math_operation,
            "array_sum_operation": self.array_sum_operation,
        }
        handler = handlers.get(function_name)
        if handler is None:
            return {"error": f"Unknown function: {function_name}", "success": False}
        try:
            return handler(**function_args)
        except TypeError as e:
            # The model supplied missing or unexpected argument names.
            return {"error": f"Invalid arguments for {function_name}: {e}", "success": False}

    @staticmethod
    def _numeric_result(function_result: Dict[str, Any]) -> float:
        """Return a tool result's numeric value, or 0 if the tool failed.

        basic_math_operation reports its value under 'result' and
        array_sum_operation under 'sum'. Looking the value up by key rather than
        by call position keeps the final sum correct whatever order the model
        issues the calls in.
        """
        if not function_result.get('success'):
            return 0
        if 'result' in function_result:
            return function_result['result']
        return function_result.get('sum', 0)

    @staticmethod
    def _build_context(similar_queries: List[Dict]) -> str:
        """Format up to two retrieved entries as a context section for the prompt."""
        if not similar_queries:
            return ""
        context = "\n\nPrevious similar interactions (for context):\n"
        for i, entry in enumerate(similar_queries[:2], 1):
            try:
                result_text = str(entry['result'])
            except Exception:
                result_text = "[Previous calculation]"
            context += f"{i}. Query: {entry['query']}\n   Result: {result_text}\n"
        return context

    def chat_with_tools(self, user_input: str) -> str:
        """Answer one query, executing any tool calls the model requests.

        The steps are:
        1. Retrieve similar past queries and add them to the prompt as context.
        2. Ask the model, with the tool declarations attached.
        3. Run every returned function call locally.
        4. Record tool runs in the knowledge base and the chat history.

        Returns:
            ``"Final Sum: <value>"`` when two or more tools were called (the sum
            of every successful call's value); the JSON of the single result when
            one tool was called; the model's text when none were; or
            ``"Error: <message>"`` if anything raised.
        """
        try:
            print(f"\nProcessing query: '{user_input}'")
            print("=" * 60)

            similar_queries = self.retrieve_similar_queries(user_input)
            context = self._build_context(similar_queries)

            # Prompt with RAG context
            enhanced_prompt = f"""
            User query: {user_input}
            
            {context}
            
            You have access to two tools:
            1. basic_math_operation: For performing math operations (addition, subtraction, multiplication, division) on two natural numbers
            2. array_sum_operation: For creating arrays and calculating their sum
            
            Instructions:
            - If the user wants to perform both operations, use both tools and provide the final sum
            - If user references "previous", "same as before", or "last time", check the context above
            - If user says "double the array" or similar, look for arrays in context and modify them
            - Always use the appropriate tools when mathematical operations are requested
            - Provide explanations for your calculations
            """

            response = self.model.generate_content(
                enhanced_prompt,
                tools=self.tools
            )

            function_names = []
            function_results = []
            for part in response.candidates[0].content.parts or []:
                if hasattr(part, 'function_call') and part.function_call:
                    function_names.append(part.function_call.name)
                    function_results.append(self.process_function_call(part.function_call))

            if not function_results:
                print(f"\nResponse: {response.text}")
                return response.text

            self.add_to_knowledge_base(user_input, {
                "function_results": function_results,
                "type": "tool_calling"
            })

            if len(function_results) >= 2:
                # Sum by result key, not by call position: the model may issue
                # the calls in any order, or call the same tool more than once.
                values = [self._numeric_result(result) for result in function_results]
                final_sum = sum(values)

                print("\n" + "=" * 60)
                print("RESULTS:")
                print("=" * 60)
                for i, (name, result) in enumerate(zip(function_names, function_results), 1):
                    print(f"Function {i} ({name}): {json.dumps(result, indent=2)}")
                print(f"Final Sum: {' + '.join(str(v) for v in values)} = {final_sum}")

                self.chat_history.append({
                    "user": user_input,
                    "function_results": function_results,
                    "final_sum": final_sum,
                    "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                })
                return f"Final Sum: {final_sum}"

            result = function_results[0]
            print("\n" + "=" * 60)
            print("RESULT:")
            print("=" * 60)
            print(f"Function Result: {json.dumps(result, indent=2)}")

            self.chat_history.append({
                "user": user_input,
                "function_results": function_results,
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            })
            return json.dumps(result, indent=2)

        except Exception as e:
            error_msg = f"Error: {str(e)}"
            print(error_msg)
            return error_msg

    def show_knowledge_base_stats(self):
        """Print entry counts and the last three stored queries (truncated to 50 chars)."""
        print("\nKNOWLEDGE BASE STATS:")
        print("=" * 40)
        print(f"Stored Queries: {len(self.knowledge_base)}")
        print(f"Cached Embeddings: {len(self.embeddings_cache)}")
        print(f"Chat History: {len(self.chat_history)}")

        if self.knowledge_base:
            print("\nRecent Queries:")
            for i, entry in enumerate(self.knowledge_base[-3:], 1):
                query = entry['query']
                shown = f"{query[:50]}..." if len(query) > 50 else query
                print(f"  {i}. {shown}")

    def show_chat_history(self):
        """Print the last five chat-history entries with their results."""
        if not self.chat_history:
            print("No chat history yet.")
            return

        print("\nCHAT HISTORY:")
        print("=" * 50)
        for i, chat in enumerate(self.chat_history[-5:], 1):
            print(f"\n{i}. [{chat['timestamp']}]")
            print(f"   User: {chat['user']}")
            if 'final_sum' in chat:
                print(f"   Result: Final Sum = {chat['final_sum']}")
            else:
                # basic_math_operation reports 'result', array_sum_operation 'sum'.
                first = chat['function_results'][0]
                print(f"   Result: {first.get('result', first.get('sum', 'N/A'))}")

    def clear_knowledge_base(self):
        """Reset the knowledge base, the embedding cache and the chat history."""
        self.knowledge_base = []
        self.embeddings_cache = {}
        self.chat_history = []
        print("Knowledge base cleared!")

def create_interactive_app():
    """Build and display the ipywidgets front end for MathToolsRAG.

    Shows an API-key field, a query box, four action buttons and an output area.
    A MathToolsRAG instance is created lazily on the first "Process Query" click.
    If the API key field is later edited, the instance is re-created with the new
    key, and its stored knowledge base, embedding cache and chat history are
    carried over.
    """
    # API Key input
    api_key_widget = widgets.Password(
        placeholder='Enter your Gemini API key',
        description='API Key:',
        style={'description_width': 'initial'}
    )

    # Query input
    query_widget = widgets.Textarea(
        placeholder='e.g., "Multiply 12 and 8, then create an array of 4 elements: [1,2,3,4] and sum them"',
        description='Query:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='100%', height='100px')
    )

    process_button = widgets.Button(
        description='Process Query',
        button_style='success',
        icon='play'
    )

    stats_button = widgets.Button(
        description='Show Stats',
        button_style='info',
        icon='info'
    )

    history_button = widgets.Button(
        description='Show History',
        button_style='warning',
        icon='history'
    )

    clear_button = widgets.Button(
        description='Clear Data',
        button_style='danger',
        icon='trash'
    )

    output = widgets.Output()

    # State shared by the callbacks below (rebound via `nonlocal`).
    rag_system = None
    active_api_key = None  # key the current rag_system was created with

    def on_process_click(b):
        nonlocal rag_system, active_api_key
        with output:
            api_key = api_key_widget.value
            if not api_key:
                print("Please enter your Gemini API key first!")
                return

            if not query_widget.value.strip():
                print("Please enter a query!")
                return

            # (Re)create the system on first use or when the key was edited;
            # otherwise a mistyped key could never be corrected.
            if rag_system is None or api_key != active_api_key:
                try:
                    new_system = MathToolsRAG(api_key)
                except Exception as e:
                    print(f"Error initializing system: {e}")
                    return
                if rag_system is not None:
                    # Keep accumulated memory; only the API client changes.
                    new_system.knowledge_base = rag_system.knowledge_base
                    new_system.embeddings_cache = rag_system.embeddings_cache
                    new_system.chat_history = rag_system.chat_history
                rag_system = new_system
                active_api_key = api_key

            rag_system.chat_with_tools(query_widget.value)

    def on_stats_click(b):
        with output:
            if rag_system:
                rag_system.show_knowledge_base_stats()
            else:
                print("Please process a query first!")

    def on_history_click(b):
        with output:
            if rag_system:
                rag_system.show_chat_history()
            else:
                print("No chat history yet!")

    def on_clear_click(b):
        with output:
            if rag_system:
                rag_system.clear_knowledge_base()
            else:
                print("Nothing to clear!")

    process_button.on_click(on_process_click)
    stats_button.on_click(on_stats_click)
    history_button.on_click(on_history_click)
    clear_button.on_click(on_clear_click)

    # Layout
    button_box = widgets.HBox([process_button, stats_button, history_button, clear_button])

    # Display interface
    display(HTML("<h1>YouData.ai Assignment</h1>"))
    display(HTML("<hr>"))
    display(api_key_widget)
    display(query_widget)
    display(button_box)
    display(output)
    

# Create and display the interactive interface. The guard stops a plain import
# of this code from rendering widgets; inside a notebook kernel __name__ is
# "__main__", so the UI still appears when the cell runs.
if __name__ == "__main__":
    create_interactive_app()

Password(description='API Key:', placeholder='Enter your Gemini API key', style=DescriptionStyle(description_w…

Textarea(value='', description='Query:', layout=Layout(height='100px', width='100%'), placeholder='e.g., "Mult…

HBox(children=(Button(button_style='success', description='Process Query', icon='play', style=ButtonStyle()), …

Output()