In [1]:
import os
from dotenv import load_dotenv

from sqlalchemy import create_engine, text
from sqlalchemy.engine.url import URL

# Load credentials from the .env file into the process environment.
load_dotenv()

REDSHIFT_HOST = os.environ.get("HOST")
# Fall back to Redshift's default port (5439) instead of crashing with
# `int(None)` when PORT is missing (or empty) in .env.
REDSHIFT_PORT = int(os.environ.get("PORT") or "5439")
REDSHIFT_DATABASE = os.environ.get("DATABASE")
REDSHIFT_DB_USER = os.environ.get("DB_USER")
REDSHIFT_PASSWORD = os.environ.get("PASSWORD")
REDSHIFT_LOGIN_URL = os.environ.get("LOGIN_URL")
REDSHIFT_CLUSTER_ID = os.environ.get("CLUSTER_IDENTIFIER")
REDSHIFT_REGION = os.environ.get("REGION")

# Create SQLAlchemy URL for the redshift_connector dialect.
url = URL.create(
    "redshift+redshift_connector",
    username=REDSHIFT_DB_USER,
    database=REDSHIFT_DATABASE,
    password=REDSHIFT_PASSWORD,
    host=REDSHIFT_HOST,
    port=REDSHIFT_PORT,
)

# IAM authentication through a browser-based SAML login. These options are
# forwarded to the DBAPI connect() call via `connect_args`; kept in one place
# so the authentication setup is defined exactly once.
REDSHIFT_CONNECT_ARGS = {
    "iam": True,
    "credentials_provider": "BrowserSamlCredentialsProvider",
    "login_url": REDSHIFT_LOGIN_URL,
    "cluster_identifier": REDSHIFT_CLUSTER_ID,
    "region": REDSHIFT_REGION,
}

# Create engine
engine = create_engine(url, connect_args=REDSHIFT_CONNECT_ARGS)

# Smoke test: confirm the connection works. Raw SQL must be wrapped in
# `text()` -- passing a plain string to Connection.execute() is deprecated in
# SQLAlchemy 1.4 (see the warning echoed in this cell's output) and rejected
# outright in SQLAlchemy 2.0.
with engine.connect() as conn:
    result = conn.execute(text("SELECT COUNT(*) FROM information_schema.tables"))
    for row in result:
        print(row)
import os
from dotenv import load_dotenv
from transformers import AutoTokenizer, pipeline
from langchain import HuggingFacePipeline
import torch

# Get Env Variables
load_dotenv()  # load the values for environment variables from the .env file

# Hugging Face model id (or local path) used for both tokenizer and model.
HF_MODEL = os.environ.get("HF_MODEL")
if not HF_MODEL:
    # Fail early with a clear message instead of an opaque error from
    # AutoTokenizer.from_pretrained(None).
    raise ValueError("HF_MODEL is not set; add it to your .env file.")

# Sampling settings. These must be given to the transformers pipeline itself:
# HuggingFacePipeline only applies `model_kwargs` when it builds the pipeline
# (from_model_id), so passing them next to a pre-built pipeline had no effect.
TEMPERATURE = 0.6
TOP_P = 0.9

model = HF_MODEL
tokenizer = AutoTokenizer.from_pretrained(model)

pipeline_ = pipeline(
    "text-generation",  # task
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    # NOTE(review): this executes Python code shipped with the model repo;
    # only enable it for models you trust.
    trust_remote_code=True,
    device_map="auto",
    max_length=4000,  # total length: prompt tokens + generated tokens
    do_sample=True,
    top_k=10,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
)

# LangChain LLM wrapper around the configured pipeline.
llm = HuggingFacePipeline(pipeline=pipeline_)

import langchain
langchain.debug = True
langchain.verbose = True

from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain, SQLDatabaseSequentialChain
from sqlalchemy.exc import ProgrammingError, DataError

# Wrap the engine created earlier instead of building a second one from `url`.
# The IAM/SAML connect args then live in one place, and the existing
# connection pool is reused rather than re-authenticating through the browser.
# NOTE(review): the chain's prompt embeds table_info for every table that
# SQLDatabase exposes; with a large schema this may overflow the model's
# max_length -- consider SQLDatabase(engine, include_tables=[...]).
db = SQLDatabase(engine)

db_chain = SQLDatabaseChain.from_llm(
    llm, db, verbose=True, use_query_checker=True
)

# Natural-language question that the chain translates into SQL.
QUESTION = "How many tables are in the database?"

try:
    response = db_chain(QUESTION)
    # Show the final answer. Previously the returned dict was discarded and the
    # answer was only visible inside the debug trace.
    print(response[db_chain.output_key])
except (ProgrammingError, ValueError, DataError) as exc:
    # LLM-generated SQL may be invalid; report the error instead of failing the cell.
    print(f"\n\n{exc}")

  result = conn.execute("SELECT COUNT(*) FROM information_schema.tables") #### THIS IS FOR TESTING


(1214,)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[32;1m[1;3m[chain/start][0m [1m[1:chain:SQLDatabaseChain] Entering Chain run with input:
[0m{
  "query": "How many tables are in the database?"
}
[32;1m[1;3m[chain/start][0m [1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] Entering Chain run with input:
[0m{
  "input": "How many tables are in the database?\nSQLQuery:",
  "top_k": "5",
  "dialect": "redshift",
  "stop": [
    "\nSQLResult:"
  ]
}
[32;1m[1;3m[llm/start][0m [1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:HuggingFacePipeline] Entering LLM run with input:
[0m{
  "prompts": [
    "Given an input question, first create a syntactically correct redshift query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most 5 results. You can order the results by a relevant column to return the most interesting examples in the database.\n\nNever query for all the columns from 