In [18]:
from vanna.vannadb import VannaDB_VectorStore
import os
from dotenv import load_dotenv
from openai import OpenAI
from vanna.base import VannaBase

In [19]:
# Load credentials from a local .env file into the process environment.
load_dotenv()

# The variable names read here are lowercase; .env must use the same spelling.
vanna_api_key = os.getenv("vanna_api_key")
vanna_model_name = "vannagetsomecoffee"  # model name from https://vanna.ai/account/profile
krutrim_api_key = os.getenv("krutrim_api_key")

# Fail fast: a missing key would otherwise surface much later as an opaque
# authentication error (or, if stringified, as the literal key "None").
_missing_env = [
    env_name
    for env_name, env_value in (
        ("vanna_api_key", vanna_api_key),
        ("krutrim_api_key", krutrim_api_key),
    )
    if not env_value
]
if _missing_env:
    raise RuntimeError(
        "Missing required environment variable(s): "
        + ", ".join(_missing_env)
        + " (define them in .env or the shell environment)"
    )

In [20]:
class KrutrimAI(VannaBase):
    """Vanna LLM backend that calls Krutrim Cloud through the OpenAI-compatible client.

    Required ``config`` keys:
        api_key: Krutrim API key (must be non-empty).
        model:   Name of a model available in Krutrim Cloud.

    Optional ``config`` keys (defaults reproduce the previously hard-coded values):
        base_url:            API endpoint, default ``https://cloud.olakrutrim.com/v1``.
        max_response_tokens: ``max_tokens`` sent with each completion, default 5000.
        logit_bias:          token-id -> bias mapping, default ``{2435: -100, 640: -100}``.

    Raises:
        ValueError: if ``config`` is None, lacks a non-empty api_key, or lacks a model.
    """

    DEFAULT_BASE_URL = "https://cloud.olakrutrim.com/v1"
    DEFAULT_MAX_RESPONSE_TOKENS = 5000
    # NOTE(review): a bias of -100 effectively bans these token ids. Which tokens
    # 2435 and 640 are depends on the model's tokenizer; confirm they are still
    # the intended ones for the configured model.
    DEFAULT_LOGIT_BIAS = {2435: -100, 640: -100}

    def __init__(self, config=None):
        if config is None:
            raise ValueError(
                "For using Krutrim, config must be provided with including an api_key and a model"
            )

        # Reject a present-but-empty key (None / "") as well as a missing one, so a
        # misconfigured environment fails here instead of at the first API call.
        if not config.get("api_key"):
            raise ValueError("config must contain a Krutrim api_key")

        if "model" not in config:
            raise ValueError("config must contain a model made available in krutrim cloud")

        # NOTE(review): VannaBase.__init__ is not called here. When this class is
        # used on its own (not mixed with another Vanna base), confirm no state
        # set by that initializer is required.
        self.client = OpenAI(
            api_key=config["api_key"],
            base_url=config.get("base_url", self.DEFAULT_BASE_URL),
        )
        self.model = config["model"]
        # Deliberately not named `max_tokens`, to avoid clashing with any
        # same-named setting on VannaBase (unverified here).
        self.max_response_tokens = config.get(
            "max_response_tokens", self.DEFAULT_MAX_RESPONSE_TOKENS
        )
        # Copy so per-instance changes never mutate the shared class default.
        self.logit_bias = dict(config.get("logit_bias", self.DEFAULT_LOGIT_BIAS))

    def system_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat message with role "system"."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat message with role "user"."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat message with role "assistant"."""
        return {"role": "assistant", "content": message}

    def generate_sql(self, question: str, **kwargs) -> str:
        """Generate SQL via VannaBase, then rewrite every backslash-underscore to "_".

        The rewrite undoes escaped underscores in the model output (presumably
        markdown-style escaping of identifiers such as ``Customer_Id``).
        """
        sql = super().generate_sql(question, **kwargs)
        # Leave non-string results (e.g. None) untouched instead of raising.
        if not isinstance(sql, str):
            return sql
        return sql.replace("\\_", "_")

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send ``prompt`` (chat messages) to Krutrim and return the reply text.

        ``kwargs`` are accepted for interface compatibility but not used.
        """
        chat_response = self.client.chat.completions.create(
            model=self.model,
            messages=prompt,
            logit_bias=self.logit_bias,
            max_tokens=self.max_response_tokens,
        )
        return chat_response.choices[0].message.content

In [21]:
class CustomLLMVanna(VannaDB_VectorStore, KrutrimAI):
    """Vanna agent combining the hosted VannaDB vector store with the Krutrim LLM.

    ``config`` must satisfy KrutrimAI (keys "api_key" and "model"). It may also
    carry "vanna_model" and "vanna_api_key" to override the notebook-level
    ``vanna_model_name`` / ``vanna_api_key`` globals, which are used otherwise.
    """

    def __init__(self, config=None):
        settings = config or {}
        # Vanna model name from https://vanna.ai/account/profile
        vanna_model = settings.get("vanna_model", vanna_model_name)
        vanna_key = settings.get("vanna_api_key", vanna_api_key)
        VannaDB_VectorStore.__init__(
            self, vanna_model=vanna_model, vanna_api_key=vanna_key, config=config
        )
        KrutrimAI.__init__(self, config=config)

In [22]:
# Krutrim-hosted chat model used for SQL generation.
KRUTRIM_MODEL = "Meta-Llama-3-8B-Instruct"

# Pass the key through unchanged: wrapping it in an f-string would turn a
# missing key (None) into the literal string "None" and slip past validation.
if not krutrim_api_key:
    raise ValueError("krutrim_api_key is not set; define it in .env before creating the agent")

vn = CustomLLMVanna(config={"model": KRUTRIM_MODEL, "api_key": krutrim_api_key})

In [23]:
# Chinook sample database, resolved relative to the notebook's working directory.
CHINOOK_DB_PATH = "chinook.db"

# Fail fast with the resolved path if the file is missing. Otherwise the error
# surfaces inside the connector, and a plain sqlite3.connect would silently
# create an empty database.
if not os.path.isfile(CHINOOK_DB_PATH):
    raise FileNotFoundError(f"SQLite database not found: {os.path.abspath(CHINOOK_DB_PATH)}")

vn.connect_to_sqlite(CHINOOK_DB_PATH)

In [24]:
def askVanna(question, *, print_results=True, auto_train=False, visualize=False, vanna=None):
    """Ask a Vanna agent a natural-language question about the connected database.

    Args:
        question: Natural-language question to turn into SQL.
        print_results: Forwarded to ``ask`` (default True, as before).
        auto_train: Forwarded to ``ask`` (default False, as before).
        visualize: Forwarded to ``ask`` (default False, as before).
        vanna: Agent to query; defaults to the notebook-level ``vn``.

    Returns:
        Whatever the agent's ``ask`` returns.
    """
    agent = vn if vanna is None else vanna
    return agent.ask(
        question=question,
        print_results=print_results,
        auto_train=auto_train,
        visualize=visualize,
    )


In [25]:
# print_results=True already prints the output; the trailing ";" suppresses
# the uninformative return-value repr in the notebook.
askVanna("Which customers are from brazil?");

SQL Prompt: [{'role': 'system', 'content': 'You are a SQLite expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. \n===Tables \nCREATE TABLE "customers"\r\n(\r\n    [CustomerId] INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,\r\n    [FirstName] NVARCHAR(40)  NOT NULL,\r\n    [LastName] NVARCHAR(20)  NOT NULL,\r\n    [Company] NVARCHAR(80),\r\n    [Address] NVARCHAR(70),\r\n    [City] NVARCHAR(40),\r\n    [State] NVARCHAR(40),\r\n    [Country] NVARCHAR(40),\r\n    [PostalCode] NVARCHAR(10),\r\n    [Phone] NVARCHAR(24),\r\n    [Fax] NVARCHAR(24),\r\n    [Email] NVARCHAR(60)  NOT NULL,\r\n    [SupportRepId] INTEGER,\r\n    FOREIGN KEY ([SupportRepId]) REFERENCES "employees" ([EmployeeId]) \r\n\t\tON DELETE NO ACTION ON UPDATE NO ACTION\r\n)\n\nCREATE TABLE "invoices"\r\n(\r\n    [InvoiceId] INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,\r\n    [CustomerId] INTEGER  N

(None, None, None)