# In [None]:
import json
import os
import requests
from typing import List
from fastapi import FastAPI
from pydantic import BaseModel, Field
from pydantic_core import from_json
import uvicorn
from starlette.responses import JSONResponse
from gen_ai_hub.proxy.langchain.openai import ChatOpenAI
from gen_ai_hub.proxy.langchain.openai import OpenAIEmbeddings

from langchain.chains import RetrievalQA
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores.hanavector import HanaDB
from hdbcli import dbapi
import configparser


# In [None]:

# SAP AI Core settings, exported as environment variables.
# These are placeholders: each assignment unconditionally overwrites any value
# already present in the process environment, so replace them before running.
# NOTE(review): presumably read by the gen_ai_hub proxy clients created
# below -- confirm against the gen_ai_hub documentation.
os.environ["AICORE_CLIENT_ID"] = "client_id"
os.environ["AICORE_CLIENT_SECRET"] = "client_secret"
os.environ["AICORE_AUTH_URL"] = "auth_url"
os.environ["AICORE_BASE_URL"] = "base_url"
os.environ["AICORE_RESOURCE_GROUP"]  = "dev"

# NOTE(review): `config` is created but never read in the visible code.
config = configparser.ConfigParser()

# Connection to the HANA database that backs the vector store.
# Address and credentials are placeholders and must be replaced.
# NOTE(review): sslValidateCertificate=False turns off certificate validation;
# presumably intended for development only -- confirm before production use.
connection = dbapi.connect(
    address= 'hana_db_url',
    port= '443',
    user='hana_db_user',
    password='hana_db_password',
    autocommit=True,
    sslValidateCertificate=False
)

# Deployment identifiers (placeholders) for the embedding model and the LLM.
EMBEDDING_DEPLOYMENT_ID = 'id_of_embedded_model_instance_in_sap_ai_core'
LLM_DEPLOYMENT_ID = 'id_of_llm_model_instance_in_sap_ai_core'

# Embedding client and the HANA vector store queried by the /callback endpoint.
embeddings = OpenAIEmbeddings(deployment_id=EMBEDDING_DEPLOYMENT_ID)
db = HanaDB(
    embedding=embeddings, 
    connection=connection, 
    table_name="VECTOR_TABLE_NAME", 
    vector_column="VEC_VECTOR"
)

# Chat model bound to the LLM deployment.
# NOTE(review): `chat_llm` is not referenced anywhere else in the visible code.
chat_llm = ChatOpenAI(deployment_id=LLM_DEPLOYMENT_ID)

# FastAPI application that the route handlers below register on.
app = FastAPI()

# OAuth client-credentials settings used by get_access_token().
# Placeholders: replace with the API's token URL, client id and client secret.
TOKEN_URL = "api_token_url"
CLIENT_ID = "api_client_id"
CLIENT_SECRET = "api_client_secret"


class BaseRequest(BaseModel):
    """Identifier fields shared by every request model in this service.

    All other request models in this file inherit from this class, so each
    incoming body must carry these four string fields.
    """

    # NOTE(review): the meaning of each identifier is inferred from its name
    # only; confirm against the calling platform's documentation.
    tenantId: str
    agentId: str
    chatId: str
    toolId: str


class MetadataRequest(BaseRequest):
    """Request body of the ``/metadata`` endpoint; adds no fields to BaseRequest."""

    pass


class CallbackRequest(BaseRequest):
    """Request body of the ``/callback`` endpoint, sent when the tool is invoked."""

    # Tool input serialized as a JSON string.
    # NOTE(review): presumably follows the TicketDataInput schema advertised
    # by /metadata -- confirm with the caller.
    toolInput: str


class Config(BaseModel):
    """A single configuration entry expressed as a name/value pair of strings."""

    name: str
    value: str


class ConfigChangedRequest(BaseRequest):
    """Request body of the ``/configChanged`` endpoint."""

    # Configuration entries carried by the request.
    config: List[Config]


class ResourceChangedRequest(ConfigChangedRequest):
    """Request body of the ``/resourcesChanged`` endpoint.

    Extends ConfigChangedRequest with two lists of resource strings.
    """

    # NOTE(review): the content of these strings (ids, URLs, paths, ...) is
    # not visible in this file -- confirm with the caller.
    addedOrUpdatedResources: List[str]
    deletedResources: List[str]


class TicketDataInput(BaseModel):
    """Input schema of the tool.

    The JSON schema generated from this model is returned by the
    ``/metadata`` endpoint under the ``schema`` key.
    """

    inputText: str = Field(description="Input ticket text")


class ErrorDetails(BaseModel):
    """Error payload holding a single message string."""

    message: str


class CustomErrorResponse(BaseModel):
    """Error envelope that nests an ErrorDetails under the ``error`` key.

    Used by the ``/resourcesChanged`` handler as the body of its 400 response.
    """

    error: ErrorDetails


def get_access_token(timeout: float = 30.0) -> str:
    """Fetch an OAuth access token using the client-credentials grant.

    Sends ``CLIENT_ID`` and ``CLIENT_SECRET`` as a form-encoded body to
    ``TOKEN_URL`` and returns the ``access_token`` field of the JSON reply.

    Args:
        timeout: Seconds to wait for the token endpoint. Without a timeout
            ``requests`` waits indefinitely, which would hang the caller if
            the endpoint never answers.

    Returns:
        The access token string.

    Raises:
        RuntimeError: If the endpoint answers with a status other than 200,
            or if a 200 response does not contain an ``access_token``.
        requests.RequestException: If the request fails or times out.
    """
    payload = {
        "grant_type": "client_credentials",
        "client_id": CLIENT_ID,
        "client_secret": CLIENT_SECRET,
    }
    headers = {"Content-Type": "application/x-www-form-urlencoded"}

    response = requests.post(
        TOKEN_URL, data=payload, headers=headers, timeout=timeout
    )

    if response.status_code != 200:
        raise RuntimeError(f"Failed to get access token: {response.text}")

    token = response.json().get("access_token")
    if not token:
        # Fail here rather than hand None back to a caller that would then
        # build an invalid Authorization header from it.
        raise RuntimeError(
            "Failed to get access token: response did not contain 'access_token'"
        )
    return token


@app.get("/")
async def health():
    """Health check: always answer with a constant ``ok`` status."""
    payload = {"status": "ok"}
    return payload



@app.post("/metadata")
async def on_metadata_fetched(req: MetadataRequest):
    """Describe the tool: its name, description and JSON input schema.

    The incoming request is printed as JSON; the returned ``schema`` is the
    JSON schema of TicketDataInput serialized to a string.
    """
    request_dump = json.dumps(req.model_dump())
    print("Metadata fetched", request_dump)

    input_schema = TicketDataInput.model_json_schema()
    tool_metadata = {
        "name": "fetch-ticket-information-tool",
        "description": "This tool can be used to fetch information about tickets.",
        "schema": json.dumps(input_schema),
    }
    return tool_metadata



def _extract_query_text(tool_input: str) -> str:
    """Return the text to embed from the raw ``toolInput`` string.

    When ``tool_input`` is a JSON object with a string ``inputText`` field,
    that field is returned so the JSON wrapper is not embedded along with the
    ticket text. Any other input is returned unchanged.
    """
    try:
        parsed = json.loads(tool_input)
    except ValueError:
        # Not JSON at all: treat the whole string as the query.
        return tool_input
    if isinstance(parsed, dict):
        text = parsed.get("inputText")
        if isinstance(text, str):
            return text
    return tool_input


@app.post("/callback")
async def on_tool_called(req: CallbackRequest):
    """Run a similarity search for the tool input and return the matches.

    The query text is embedded with ``embeddings`` and the ten nearest
    documents are fetched from ``db``.

    Returns:
        A dict whose ``response`` value is a JSON string: a list of
        ``[rank, text]`` pairs, with ranks starting at 1.
    """
    query_text = _extract_query_text(req.toolInput)

    # NOTE(review): both calls below run synchronously inside this coroutine.
    query_embedding = embeddings.embed_query(query_text)
    results = db.similarity_search_by_vector(query_embedding, k=10)

    ranked = [(rank, doc.page_content) for rank, doc in enumerate(results, 1)]
    return {"response": json.dumps(ranked, indent=4)}



@app.post("/resourcesChanged")
async def on_resource_changed(req: ResourceChangedRequest):
    """Reject every resource-change notification with an example error.

    The request is not inspected: the handler always answers HTTP 400 with a
    CustomErrorResponse body.
    """
    err = CustomErrorResponse(error=ErrorDetails(message="Example resource error"))
    # model_dump() is the pydantic v2 spelling used elsewhere in this file;
    # .dict() is deprecated there.
    return JSONResponse(status_code=400, content=err.model_dump())



@app.post("/configChanged")
async def on_config_changed(req: ConfigChangedRequest):
    """Acknowledge a configuration change with an empty JSON object.

    The request body is validated by FastAPI but not otherwise used.
    """
    acknowledgement = {}
    return acknowledgement


if __name__ == "__main__":
    # Serve the app on all interfaces; the port is taken from the PORT
    # environment variable and falls back to 8081 when it is unset.
    uvicorn.run(
        app,
        port=int(os.getenv("PORT", 8081)),
        host="0.0.0.0",
        log_level="info",
    )
