# Gemini - Get Relevant Tables

In [2]:
%pip install google cbsodata

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 23.3.2 -> 24.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [16]:
from google.generativeai import GenerativeModel
import google.generativeai as genai
import json
from typing import List, Dict, TypedDict
from ast import literal_eval
import cbsodata as cbs

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class TableInfo(TypedDict):
    id: str
    description: str

In [4]:
def create_table_selector_prompt(tables: List[TableInfo], query: str) -> str:
    return f"""
    Given a user's query about data and a list of available tables, identify which tables would be most relevant for answering the query.
    
    Available tables:
    {json.dumps(tables, indent=2)}
    
    User query: {query}
    
    Select the table IDs that would be most relevant for answering this query. Consider:
    1. Direct matches between query topics and table descriptions
    2. Implicit relationships that might be needed to fully answer the query
    3. Select ALL tables that might contribute to a complete answer
    
    Ask clarifying questions if not enough information is available to complete the request.
    
    Provide your response as a single list of strings of table IDs. Nothing else.
    """

In [5]:
def get_relevant_ids_with_gemini(
  model: GenerativeModel,
  tables: List[TableInfo],
  query: str
) -> List[str]:
  prompt = create_table_selector_prompt(tables, query)
  response = model.generate_content(prompt)
  return literal_eval(response.text) 
  

# CBS - Get Data From Gemini Response

## Get Table Descriptions and IDs

In [6]:
def extract_table_info(table: Dict) -> TableInfo:
    return {
        "id": 
            table["Identifier"],
        "description": 
            table["Title"].replace('\n', ' ').replace('\r', ' ') + " " +
            table["Summary"].replace('\n', ' ').replace('\r', ' ') + " " +
            table["ShortDescription"].replace('\n', ' ').replace('\r', ' ')
    }

In [7]:
def get_english_table_info(json_filename: str) -> List[TableInfo]:
  with open(json_filename, 'r') as file:
    english_tables = json.load(file)
  
  return [
    extract_table_info(table) for table in english_tables
  ]

In [None]:
table_info = get_english_table_info('english_tables.json')

print(len(table_info))
table_info[0]

1010


{'id': '80783eng',
 'description': "Agriculture; crops, livestock and land use by general farm type, region Agricultural census; crops, livestock, land use and corresponding number of holdings by general farm type and region  This table contains data on land use, arable farming, horticulture, grassland, grazing livestock and housed animals, at regional level, by general farm type.  The figures in this table are derived from the agricultural census. Data collection for the agricultural census is part of a combined data collection for a.o. agricultural policy use and enforcement of the manure law.  Regional breakdown is based on the main location of the holding. Due to this the region where activities (crops, animals) are allocated may differ from the location where these activities actually occur.  The agricultural census is also used as the basis for the European Farm Structure Survey (FSS). Data from the agricultural census do not fully coincide with the FSS. In the FSS years (2000, 2

## Test Gemini on english datasets

In [17]:
# initialize the model
genai.configure(api_key="AIzaSyDKQzxYWUpb0-WBQXSVq5kWbeZiEzjm3zk")
model = genai.GenerativeModel(model_name='gemini-1.5-flash')

In [None]:
relevant_ids = get_relevant_ids_with_gemini(
  model= model,
  tables = get_english_table_info('english_tables.json'),
  query= "What are the CO2 emissions by sector?"
)

relevant_ids

['83300ENG', '85669ENG', '84917ENG', '84918ENG']

In [114]:
len(relevant_ids)

4

In [115]:
for id in relevant_ids[:10]:
  print(cbs.get_info(id)['Title'])

Emissions to air by the Dutch economy; national accounts 
Emissions of greenhouse gases according to IPCC guide-lines	
Renewable energy; consumption by energy source, technology and application
Avoided use of fossil energy and emission of CO2


# Evaluation with Spider Dataset

In [6]:
import os
import json

In [None]:
def get_test_questions(
  spider_dev_path = "C:/Users/makuz/Downloads/spider_data/spider_data/dev.json"
):

  with open(spider_dev_path, "r", encoding="utf-8") as f:
    dev_data = json.load(f)

  questions = [item["question"] for item in dev_data]
  return questions

get_test_questions()

In [None]:
def create_all_schemas_txt(
  test_database_folder_path = "C:/Users/makuz/Downloads/spider_data/spider_data/test_database"
):

  db_files = [item for item in os.listdir(test_database_folder_path) if os.path.isdir(os.path.join(test_database_folder_path, item))]

  output_path = "all_schemas.txt"

  with open(output_path, "w", encoding="utf-8") as out_file:
    for db in db_files:
      schema_path = os.path.join(test_database_folder_path, db, "schema.sql")

      if os.path.exists(schema_path):
        with open(schema_path, "r", encoding="utf-8") as f:
          content = f.read()

          # Filter out INSERT statements
          schema_lines = [
            line.strip() 
            for line in content.split('\n') 
            if line.strip() and not line.strip().upper().startswith('INSERT')
          ]

          # Write database identifier and its schema
          out_file.write(f"Database: {db}\n")
          out_file.write("=" * 50 + "\n")
          out_file.write("\n".join(schema_lines))
          out_file.write("\n\n" + "=" * 50 + "\n\n")

In [19]:
def create_gemini_prompt(questions):
    prompt_template = """You are an expert in SQL. Given a database schema and a question, generate the correct SQL query to answer the question.

I will provide you with database schemas from different databases. The schemas are separated by '=' lines.
Each schema section starts with 'Database: <database_name>'

For each question, you need to:
1. Identify which database the question is about
2. Use the correct schema to write a SQL query that answers the question
3. Ensure the query follows SQL best practices

Here are the database schemas:
{schemas}

For each of these questions, generate the SQL query:
{formatted_questions}

For each question, provide your answer in this format:
[SQL query]\t[database name]
"""

    # Read the schemas file
    with open("all_schemas.txt", "r", encoding="utf-8") as f:
        schemas = f.read()

    # Format the questions
    formatted_questions = "\n".join([f"{i+1}. {q}" for i, q in enumerate(questions)])

    # Create the final prompt
    final_prompt = prompt_template.format(
        schemas=schemas, formatted_questions=formatted_questions
    )

    return final_prompt

In [20]:
def get_gemini_preds(questions):
    prompt = create_gemini_prompt(questions)
    response = model.generate_content(prompt)
    return response.text


In [22]:
gemini_preds = get_gemini_preds(get_test_questions())
gemini_preds

"```sql\nSELECT\n  COUNT(*)\nFROM singer;\n```\tsinger\n\n```sql\nSELECT\n  SUM(1)\nFROM singer;\n```\tsinger\n\n```sql\nSELECT\n  name,\n  country,\n  age\nFROM singer\nORDER BY\n  age DESC;\n```\tsinger\n\n```sql\nSELECT\n  name,\n  country,\n  age\nFROM singer\nORDER BY\n  age DESC;\n```\tsinger\n\n```sql\nSELECT\n  AVG(age),\n  MIN(age),\n  MAX(age)\nFROM singer\nWHERE\n  country = 'France';\n```\tsinger\n\n```sql\nSELECT\n  AVG(age),\n  MIN(age),\n  MAX(age)\nFROM singer\nWHERE\n  country = 'France';\n```\tsinger\n\n```sql\nSELECT\n  T1.title,\n  T1.Song_release_year\nFROM song AS T1\nINNER JOIN singer AS T2\n  ON T1.Singer_ID = T2.Singer_ID\nORDER BY\n  T2.age\nLIMIT 1;\n```\tsinger\n\n```sql\nSELECT\n  T1.title,\n  T1.Song_release_year\nFROM song AS T1\nINNER JOIN singer AS T2\n  ON T1.Singer_ID = T2.Singer_ID\nORDER BY\n  T2.age\nLIMIT 1;\n```\tsinger\n\n```sql\nSELECT DISTINCT\n  country\nFROM singer\nWHERE\n  age > 20;\n```\tsinger\n\n```sql\nSELECT DISTINCT\n  country\nFROM 

In [29]:
print(gemini_preds)

```sql
SELECT
  COUNT(*)
FROM singer;
```	singer

```sql
SELECT
  SUM(1)
FROM singer;
```	singer

```sql
SELECT
  name,
  country,
  age
FROM singer
ORDER BY
  age DESC;
```	singer

```sql
SELECT
  name,
  country,
  age
FROM singer
ORDER BY
  age DESC;
```	singer

```sql
SELECT
  AVG(age),
  MIN(age),
  MAX(age)
FROM singer
WHERE
  country = 'France';
```	singer

```sql
SELECT
  AVG(age),
  MIN(age),
  MAX(age)
FROM singer
WHERE
  country = 'France';
```	singer

```sql
SELECT
  T1.title,
  T1.Song_release_year
FROM song AS T1
INNER JOIN singer AS T2
  ON T1.Singer_ID = T2.Singer_ID
ORDER BY
  T2.age
LIMIT 1;
```	singer

```sql
SELECT
  T1.title,
  T1.Song_release_year
FROM song AS T1
INNER JOIN singer AS T2
  ON T1.Singer_ID = T2.Singer_ID
ORDER BY
  T2.age
LIMIT 1;
```	singer

```sql
SELECT DISTINCT
  country
FROM singer
WHERE
  age > 20;
```	singer

```sql
SELECT DISTINCT
  country
FROM singer
WHERE
  age > 20;
```	singer

```sql
SELECT
  country,
  COUNT(*)
FROM singer
GROUP BY
  c

In [27]:
def format_gemini_preds(gemini_preds: str) -> str:
  # Split the predictions by newline
  lines = gemini_preds.split('\n')
  
  formatted_preds = []
  current_sql = []
  current_db = ''
  
  for line in lines:
    line = line.strip()
    
    # Skip empty lines
    if not line:
      continue
      
    # If line starts with ```sql, start collecting SQL
    if line.startswith('```sql'):
      current_sql = []
      continue
      
    # If line starts with ```, stop collecting SQL
    elif line.startswith('```'):
      # If we have collected SQL and database, format and add to results
      if current_sql and current_db:
        sql = ' '.join(current_sql)
        formatted_preds.append(f"{sql}\t{current_db}")
      continue
      
    # If line has a tab, it's the database name
    elif '\t' in line:
      current_db = line.split('\t')[1]
      continue
      
    # Otherwise it's part of the SQL query  
    else:
      current_sql.append(line)

  formatted_result = '\n'.join(formatted_preds)
      
  # Save formatted result to file
  with open('formatted_predictions.txt', 'w') as f:
    f.write(formatted_result)
      
  return formatted_result

In [28]:
format_gemini_preds(gemini_preds)

''