In [1]:
import os
from datetime import datetime
from dotenv import load_dotenv
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_cohere import CohereEmbeddings
from langchain_core.messages import AIMessage
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langgraph.graph import StateGraph, END, MessageGraph
from langgraph.prebuilt.tool_node import ToolNode, tools_condition
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import AnyMessage, add_messages
import uuid
from IPython.display import Image, display
import sqlite3
import json
from contextlib import contextmanager
from typing import Any, Iterator, List, Optional, Tuple

# Set up environment
load_dotenv()

# Initialize API keys and models
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
ASTRA_DB_APPLICATION_TOKEN = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
ASTRA_DB_API_ENDPOINT = os.getenv("ASTRA_DB_API_ENDPOINT")
ASTRA_DB_KEYSPACE = os.getenv("ASTRA_DB_KEYSPACE")

GPT3 = 'gpt-3.5-turbo'
COHERE = 'embed-english-light-v3.0'

embeddings = CohereEmbeddings(model=COHERE)
llm = ChatOpenAI(model=GPT3, temperature=0)

# Custom memory class implementation
class CohereEmbeddings:
    def __init__(self, model='embed-english-light-v3.0'):
        self.client = cohere.Client(COHERE_API_KEY)  # Replace with your Cohere API key
        self.model = model

    def embed(self, texts):
        response = self.client.embed(texts=texts, model=self.model)
        return response.embeddings

class ConversationSummaryBufferMemory:
    def __init__(self, embeddings_model, max_token_limit=100):
        self.embeddings_model = embeddings_model
        self.max_token_limit = max_token_limit
        self.buffer = []
        self.summaries = []
        self.token_count = 0

    def save_context(self, user_input, ai_output):
        interaction = f"Human: {user_input} AI: {ai_output}"
        tokens = interaction.split()
        self.token_count += len(tokens)

        self.buffer.append(interaction)

        if self.token_count > self.max_token_limit:
            self.summarize_buffer()

    def summarize_buffer(self):
        summary = " ".join(self.buffer)
        self.summaries.append(summary)
        self.buffer = []
        self.token_count = 0

    def embed_summary(self):
        return self.embeddings_model.embed(self.summaries)

    def retrieve_relevant_summary(self, query):
        query_embedding = self.embeddings_model.embed([query])
        summary_embeddings = self.embed_summary()
        similarities = cosine_similarity(query_embedding, summary_embeddings)

        most_similar_index = np.argmax(similarities)
        return self.summaries[most_similar_index]

    def load_memory_variables(self):
        history = "\n".join(self.buffer + self.summaries)
        return {'history': history}

memory = ConversationSummaryBufferMemory(embeddings_model=embeddings, max_token_limit=40)

# Define custom tools
@tool
def magic_function(input: int) -> int:
    return input + 2

@tool
def fetch_weather_data(lat: float, lon: float) -> str:
    url = f"https://api.met.no/weatherapi/locationforecast/2.0/compact?lat={lat}&lon={lon}"
    headers = {"User-Agent": "MyTestApp/0.1"}
    response = requests.get(url, headers=headers)
    if response.status_code == 200:
        return json.dumps(response.json(), indent=4)
    else:
        return f"Error fetching data: {response.status_code}"

@tool
def fetch_electricity_prices(year: str, month: str, day: str, region: str) -> str:
    url = f"https://www.hvakosterstrommen.no/api/v1/prices/{year}/{month}-{day}_{region}.json"
    response = requests.get(url)
    if response.status_code == 200:
        return json.dumps(response.json(), indent=4)
    else:
        return f"Error fetching data: {response.status_code}"

@tool
def web_search(input: str) -> str:
    web_search_tool = TavilySearchResults()
    docs = web_search_tool.invoke({"query": input})
    return docs

custom_tools = [magic_function, web_search, fetch_electricity_prices, fetch_weather_data]

# Initialize SQL database
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///test.db")

# Initialize SQL tools
from langchain_community.agent_toolkits import SQLDatabaseToolkit
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
sql_tools = toolkit.get_tools()

tools = custom_tools + sql_tools

# Define the primary assistant prompt
primary_assistant_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a helpful assistant with three tools: (1) web search, (2) weather data fetch, and (3) electricity price fetch. Use web search for current events and queries needing real-time information. Use the weather data fetch tool for any weather-related questions. Use the electricity price fetch tool for queries related to electricity prices. For database-related questions, follow the SQL database interaction instructions. Otherwise, answer directly. Current time: {time}."),
        ("placeholder", "{messages}"),
        ("system", """
         ROLE:
         You are an agent designed to interact with various data sources using specific tools. You have access to tools for fetching weather data, electricity prices, and interacting with a SQL database.
         GOAL:
         Given an input question, determine which tool to use based on the nature of the question, fetch the necessary data, and return the answer.
         INSTRUCTIONS:
         - Only use the below tools for the specified operations.
         - Only use the information returned by the below tools to construct your final answer.
         - For weather-related questions, use the weather data fetch tool.
         - For queries about electricity prices, use the electricity price fetch tool.
         - For current events or real-time information, use the web search tool.
         - For database-related questions, follow the SQL database interaction instructions.
         """),
        ("system", """
         SQL DATABASE INTERACTION INSTRUCTIONS:
         - Only use the below tools for the following operations.
         - Only use the information returned by the below tools to construct your final answer.
         - To start you should ALWAYS look at the tables in the database to see what you can query. Do NOT skip this step.
         - Then you should query the schema of the most relevant tables.
         - Write your query based upon the schema of the tables. You MUST double check your query before executing it.
         - Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
         - You can order the results by a relevant column to return the most interesting examples in the database.
         - Never query for all the columns from a specific table, only ask for the relevant columns given the question.
         - If you get an error while executing a query, rewrite the query and try again.
         - If the query returns a result, use check_result tool to check the query result.
         - If the query result is empty, think about the table schema, rewrite the query, and try again.
         - DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
         """)
    ]
).partial(time=datetime.now())

# Define LLM chain
assistant_runnable = primary_assistant_prompt | llm.bind_tools(tools)

# Define the State type
class State(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]

# Assistant class
from langchain_core.runnables import RunnableLambda, RunnableConfig, Runnable

class Assistant:
    def __init__(self, runnable: Runnable):
        self.runnable = runnable

    def __call__(self, state: State, config: RunnableConfig):
        while True:
            result = self.runnable.invoke(state)
            if not result.tool_calls and (
                not result.content
                or isinstance(result.content, list)
                and not result.content[0].get("text")
            ):
                messages = state["messages"] + [("user", "Respond with a real output.")]
                state = {**state, "messages": messages}
            else:
                break
        return {"messages": result}

# ToolNode with fallback
def create_tool_node_with_fallback(tools: list) -> dict:
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )

# Handle tool error
def handle_tool_error(state) -> dict:
    error = state.get("error")
    tool_calls = state["messages"][-1].tool_calls
    return {
        "messages": [
            ToolMessage(
                content=f"Error: {repr(error)}\n please fix your mistakes.",
                tool_call_id=tc["id"]
            ) for tc in tool_calls
        ]
    }

assistant = Assistant(assistant_runnable)

# Build the graph
builder = StateGraph(State)
builder.add_node("assistant", RunnableLambda(assistant))
builder.add_node("tools", create_tool_node_with_fallback(tools))

# Define edges
builder.set_entry_point("assistant")
builder.add_conditional_edges(
    "assistant",
    tools_condition,
    {"tools": "tools", END: END},
)
builder.add_edge("tools", "assistant")

# Compile the graph
graph = builder.compile()

# Display the graph
try:
    display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
except Exception as e:
    print(f"Failed to display graph: {e}")

# Test the assistant
questions = [
    "What is magic_function(3)",
    "Do you think the electricity prices yesterday was a biproduct of the weather and temperature yesterday, the location is Lillehammer?",
]

_printed = set()
thread_id = str(uuid.uuid4())

config = {
    "configurable": {
        "thread_id": thread_id,
    }
}

events = graph.stream(
    {"messages": [("user", questions[1])]}, config, stream_mode="values"
)

for event in events:
    _print_event(event, _printed)


ValidationError: 1 validation error for CohereEmbeddings
__root__
  Did not find cohere_api_key, please add an environment variable `COHERE_API_KEY` which contains it, or pass `cohere_api_key` as a named parameter. (type=value_error)