In [1]:
import json
from typing import List, Dict, Tuple

In [2]:
dev_tables: List[Dict] = json.load(open('bird_raw_data/dev_tables.json', 'r'))
dev_tables[0]

{'db_id': 'debit_card_specializing',
 'table_names_original': ['customers',
  'gasstations',
  'products',
  'transactions_1k',
  'yearmonth'],
 'table_names': ['customers',
  'gasstations',
  'products',
  'transactions_1k',
  'yearmonth'],
 'column_names_original': [[-1, '*'],
  [0, 'CustomerID'],
  [0, 'Segment'],
  [0, 'Currency'],
  [1, 'GasStationID'],
  [1, 'ChainID'],
  [1, 'Country'],
  [1, 'Segment'],
  [2, 'ProductID'],
  [2, 'Description'],
  [3, 'TransactionID'],
  [3, 'Date'],
  [3, 'Time'],
  [3, 'CustomerID'],
  [3, 'CardID'],
  [3, 'GasStationID'],
  [3, 'ProductID'],
  [3, 'Amount'],
  [3, 'Price'],
  [4, 'CustomerID'],
  [4, 'Date'],
  [4, 'Consumption']],
 'column_names': [[-1, '*'],
  [0, 'CustomerID'],
  [0, 'client segment'],
  [0, 'Currency'],
  [1, 'Gas Station ID'],
  [1, 'Chain ID'],
  [1, 'Country'],
  [1, 'chain segment'],
  [2, 'Product ID'],
  [2, 'Description'],
  [3, 'Transaction ID'],
  [3, 'Date'],
  [3, 'Time'],
  [3, 'Customer ID'],
  [3, 'Card ID']

In [3]:
dev: List[Dict] = json.load(open('bird_raw_data/dev.json', 'r'))
dev[0]

{'question_id': 0,
 'db_id': 'california_schools',
 'question': 'What is the highest eligible free rate for K-12 students in the schools in Alameda County?',
 'evidence': 'Eligible free rate for K-12 = `FRPM Count (K-12)` / `Enrollment (K-12)`',
 'SQL': "SELECT `FRPM Count (K-12)` / `Enrollment (K-12)` FROM frpm WHERE `County Name` = 'Alameda' ORDER BY (CAST(`FRPM Count (K-12)` AS REAL) / `Enrollment (K-12)`) DESC LIMIT 1",
 'difficulty': 'simple'}

In [4]:
def convert_schema(input_schema):
    # Extract basic information
    db_id = input_schema["db_id"]
    
    # Initialize result structure
    result = {
        "db_id": db_id,
        "tables": []
    }
    
    # Create mapping for easier access
    table_idx_to_name = {}
    for i, table_name in enumerate(input_schema["table_names"]):
        table_idx_to_name[i] = table_name
    
    # Create mapping for column index to (table_idx, column_name)
    column_info_map = {}
    for col_idx, (col_table_idx, col_name_original) in enumerate(input_schema["column_names_original"]):
        if col_idx > 0:  # Skip the special "*" column
            column_info_map[col_idx] = (col_table_idx, col_name_original)
    
    # Process primary keys (can be single indices or composite keys as lists)
    primary_keys = {}
    for pk_item in input_schema["primary_keys"]:
        if isinstance(pk_item, list):
            # This is a composite primary key
            for pk_idx in pk_item:
                table_idx = column_info_map[pk_idx][0]
                col_name = column_info_map[pk_idx][1]
                if table_idx not in primary_keys:
                    primary_keys[table_idx] = []
                primary_keys[table_idx].append(col_name)
        else:
            # This is a single column primary key
            pk_idx = pk_item
            table_idx = column_info_map[pk_idx][0]
            col_name = column_info_map[pk_idx][1]
            if table_idx not in primary_keys:
                primary_keys[table_idx] = []
            primary_keys[table_idx].append(col_name)
    
    # Process foreign keys
    foreign_keys = {}
    for fk_pair in input_schema["foreign_keys"]:
        from_idx, to_idx = fk_pair
        from_table_idx = column_info_map[from_idx][0]
        from_col_name = column_info_map[from_idx][1]
        to_table_idx = column_info_map[to_idx][0]
        to_col_name = column_info_map[to_idx][1]
        
        if from_table_idx not in foreign_keys:
            foreign_keys[from_table_idx] = []
        
        foreign_keys[from_table_idx].append({
            "from_column": from_col_name,
            "to_table": table_idx_to_name[to_table_idx],
            "to_column": to_col_name
        })
    
    # Process each table
    for table_idx, table_name in enumerate(input_schema["table_names"]):
        table_info = {
            "table_name": table_name,
            "columns": [],
            "primary_keys": primary_keys.get(table_idx, []),
            "foreign_keys": foreign_keys.get(table_idx, [])
        }
        
        # Process columns for this table
        for col_idx, (col_table_idx, col_name_original) in enumerate(input_schema["column_names_original"]):
            # Skip the special "*" column
            if col_idx == 0:
                continue
                
            # Check if this column belongs to current table
            if col_table_idx == table_idx:
                # Find corresponding column type
                # Note: column_types[0] corresponds to column_names[1] due to "*" at index 0
                col_type = input_schema["column_types"][col_idx - 1]
                
                column_info = {
                    "column_name": col_name_original,
                    "column_type": col_type
                }
                
                table_info["columns"].append(column_info)
        
        result["tables"].append(table_info)
    
    return result

In [5]:
convert_schema(dev_tables[0])

{'db_id': 'debit_card_specializing',
 'tables': [{'table_name': 'customers',
   'columns': [{'column_name': 'CustomerID', 'column_type': 'text'},
    {'column_name': 'Segment', 'column_type': 'integer'},
    {'column_name': 'Currency', 'column_type': 'text'}],
   'primary_keys': ['CustomerID'],
   'foreign_keys': []},
  {'table_name': 'gasstations',
   'columns': [{'column_name': 'GasStationID', 'column_type': 'text'},
    {'column_name': 'ChainID', 'column_type': 'integer'},
    {'column_name': 'Country', 'column_type': 'integer'},
    {'column_name': 'Segment', 'column_type': 'text'}],
   'primary_keys': ['GasStationID'],
   'foreign_keys': []},
  {'table_name': 'products',
   'columns': [{'column_name': 'ProductID', 'column_type': 'text'},
    {'column_name': 'Description', 'column_type': 'integer'}],
   'primary_keys': ['ProductID'],
   'foreign_keys': []},
  {'table_name': 'transactions_1k',
   'columns': [{'column_name': 'TransactionID', 'column_type': 'text'},
    {'column_name'

In [6]:
dataset = []
for item in dev:
    db_id = item['db_id']
    for db in dev_tables:
        if db['db_id'] == db_id:
            dataset.append({
                'db_id': db_id,
                'context': f"Database schema: {json.dumps(convert_schema(db))}",
                'question': f"Based on the database schema, you *MUST* write a *SQL query* for the following question: {item['question']}  (Hint: {item['evidence']})",
                'target': item['SQL']
            })
            break
len(dataset)

1534

In [8]:
# shuffle the dataset
import random
random.seed(42)
random.shuffle(dataset)


with open('data/bird_train_700.json', 'w') as f:
    json.dump(dataset[:700], f, indent=4)
with open('data/bird_val_300.json', 'w') as f:
    json.dump(dataset[700:1000], f, indent=4)
with open('data/bird_test_534.json', 'w') as f:
    json.dump(dataset[1000:], f, indent=4)

with open('data/bird_train_700.jsonl', 'w') as f:
    for entry in dataset[:700]:
        f.write(json.dumps(entry) + '\n')
with open('data/bird_val_300.jsonl', 'w') as f:
    for entry in dataset[700:1000]:
        f.write(json.dumps(entry) + '\n')
with open('data/bird_test_534.jsonl', 'w') as f:
    for entry in dataset[1000:]:
        f.write(json.dumps(entry) + '\n')