In [None]:
# Reload imported modules automatically when their source changes, so edits to
# helper modules (e.g. db_connectors, prompt_formatters) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

## DB Setup

This notebook assumes you already have a PostgreSQL database running and reachable with the connection settings in the next cell.

In [None]:
import os

# Connection settings for the target PostgreSQL database.
# Each value can be overridden with the standard libpq environment variable
# (PGDATABASE, PGUSER, PGPASSWORD, PGHOST, PGPORT) so real credentials never
# have to be written into the notebook. The literals below are placeholders
# used only when the corresponding variable is unset.
DATABASE = os.environ.get("PGDATABASE", "database")
USER = os.environ.get("PGUSER", "postgres")
PASSWORD = os.environ.get("PGPASSWORD", "password")
HOST = os.environ.get("PGHOST", "localhost")
# Environment values are strings; the connector is given an int port as before.
PORT = int(os.environ.get("PGPORT", "5432"))
TABLES = []  # list of tables to load or [] to load all tables

In [None]:
from db_connectors import PostgresConnector
from prompt_formatters import RajkumarFormatter

# Open a connection to Postgres using the settings from the config cell.
postgres_connector = PostgresConnector(
    user=USER,
    password=PASSWORD,
    dbname=DATABASE,
    host=HOST,
    port=PORT,
)
postgres_connector.connect()

# An empty TABLES list means "load every table the connector reports";
# the discovered names are appended to TABLES itself.
if not TABLES:
    TABLES.extend(postgres_connector.get_tables())

print(f"Loading tables: {TABLES}")

# Fetch one schema per selected table and build the prompt formatter from them.
db_schema = list(map(postgres_connector.get_schema, TABLES))
formatter = RajkumarFormatter(db_schema)
formatter = RajkumarFormatter(db_schema)

## Model Setup

In a separate terminal (or a `screen`/`tmux` session), first install [Manifest](https://github.com/HazyResearch/manifest):
```bash
pip install manifest-ml\[all\]
```

Then start a Manifest server that serves the NSQL model:
```bash
python3 -m manifest.api.app \
    --model_type huggingface \
    --model_generation_type text-generation \
    --model_name_or_path NumbersStation/nsql-350M \
    --device 0
```

If the server starts successfully, you should see output like:
```bash
* Running on http://127.0.0.1:5000
```

In [None]:
from manifest import Manifest

# Address of the Manifest server started in the step above
# (the URL it prints as "Running on ...").
MANIFEST_URL = "http://127.0.0.1:5000"

manifest_client = Manifest(client_name="huggingface", client_connection=MANIFEST_URL)


def get_sql(instruction: str, max_tokens: int = 300) -> str:
    """Translate a natural-language instruction into SQL using the served model.

    Args:
        instruction: The question or request to convert into SQL.
        max_tokens: Generation limit forwarded to ``manifest_client.run``.

    Returns:
        The result of ``formatter.format_model_output`` applied to the raw
        model output.
    """
    # The formatter was built from db_schema, so the prompt presumably embeds
    # the table definitions alongside the instruction — confirm in
    # prompt_formatters.RajkumarFormatter.
    prompt = formatter.format_prompt(instruction)
    raw_output = manifest_client.run(prompt, max_tokens=max_tokens)
    return formatter.format_model_output(raw_output)

In [None]:
# Ask the model to turn a simple natural-language question into SQL,
# then show the generated query as plain text.
sql = get_sql(instruction="Number of rows in table?")
print(sql)

In [None]:
# Execute the generated SQL and display the result. Leaving it as the cell's
# last expression lets Jupyter use its rich (HTML table) display instead of
# print()'s plain text; the result is presumably a pandas DataFrame, per the
# method name.
# NOTE(review): this runs model-generated SQL verbatim — point the connector
# at a read-only role if the database holds data you care about.
query_result = postgres_connector.run_sql_as_df(sql)
query_result