# In [None]:
import os
import json
import importlib
import importlib.util
import pprint
import sys
from pathlib import Path

from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError

# In [None]:
def load_schema_from_file(schema_file_path):
    """
    Load schema from a Python file.

    The file is executed as a throwaway module (named ``schema_module``) and
    its module-level ``schema`` attribute is returned. Because the file is
    executed, only load schema files you trust.

    Args:
        schema_file_path (str): Path to the schema Python file

    Returns:
        dict: Schema dictionary or None if error (file missing, not a
        loadable Python source file, raises on import, or has no ``schema``)
    """
    try:
        spec = importlib.util.spec_from_file_location("schema_module", schema_file_path)
        if spec is None or spec.loader is None:
            # spec_from_file_location returns None when no loader matches the
            # path's suffix (e.g. a .txt file); report that clearly instead of
            # failing later with an opaque AttributeError on None.
            raise ImportError(f"no module loader available for {schema_file_path}")
        schema_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(schema_module)
        return schema_module.schema
    except Exception as e:
        print(f"Error loading schema from {schema_file_path}: {e}")
        return None

# In [None]:
def extract_sample_data(db_path, table_name, limit=5):
    """
    Extract sample data from a table, prioritizing rows with no null values.

    Rows in which every column is non-NULL are returned first. If the table
    has fewer than ``limit`` such rows, the result is topped up with rows
    that contain at least one NULL, so fully populated rows are never
    dropped in favour of sparser ones.

    Args:
        db_path (str): Path to the SQLite database file
        table_name (str): Name of the table
        limit (int): Maximum number of rows to extract (default: 5)

    Returns:
        list: List of dictionaries (column name -> value) representing rows.
        Empty list if the table does not exist / has no columns or a
        database error occurs.
    """

    def quote(identifier):
        # Quote as an SQL identifier (doubling embedded double quotes) so
        # names with spaces, reserved words or quotes still yield valid SQL.
        return '"' + str(identifier).replace('"', '""') + '"'

    quoted_table = quote(table_name)
    engine = None
    try:
        engine = create_engine(f'sqlite:///{db_path}')

        with engine.connect() as conn:
            try:
                # Read the column names on this same connection rather than
                # opening a second engine; index 1 of a PRAGMA table_info row
                # is the column name.
                pragma_rows = conn.execute(
                    text(f"PRAGMA table_info({quoted_table})")
                ).fetchall()
                column_names = [row[1] for row in pragma_rows]
                if not column_names:
                    print(f"Error querying table {table_name}: table not found or has no columns")
                    return []

                all_present = " AND ".join(
                    f"{quote(col)} IS NOT NULL" for col in column_names
                )

                # First, rows with no null values.
                result = conn.execute(
                    text(f"SELECT * FROM {quoted_table} WHERE {all_present} LIMIT :limit"),
                    {"limit": limit},
                )
                columns = list(result.keys())
                rows = list(result.fetchall())

                # Not enough complete rows: keep the ones we have and fill the
                # remainder from the complement. `IS NOT NULL` never evaluates
                # to NULL, so NOT (...) selects exactly the rows with at least
                # one NULL and no row can appear twice.
                if len(rows) < limit:
                    result = conn.execute(
                        text(f"SELECT * FROM {quoted_table} WHERE NOT ({all_present}) LIMIT :limit"),
                        {"limit": limit - len(rows)},
                    )
                    rows.extend(result.fetchall())

                return [dict(zip(columns, row)) for row in rows]

            except Exception as e:
                print(f"Error querying table {table_name}: {e}")
                return []

    except SQLAlchemyError as e:
        print(f"Database error: {e}")
        return []
    finally:
        if engine is not None:
            # Release pooled connections so the database file is not held open.
            engine.dispose()

# In [None]:
def get_table_columns(db_path, table_name):
    """
    Get column names for a table.

    Args:
        db_path (str): Path to the SQLite database file
        table_name (str): Name of the table

    Returns:
        list: Column names in the order PRAGMA table_info reports them; an
        empty list if the table does not exist or any error occurs.
    """
    # PRAGMA arguments cannot be bound parameters, so quote the name as an
    # identifier (doubling embedded double quotes) to support names with
    # spaces, reserved words or quotes.
    quoted_table = '"' + str(table_name).replace('"', '""') + '"'
    engine = None
    try:
        engine = create_engine(f'sqlite:///{db_path}')
        with engine.connect() as conn:
            result = conn.execute(text(f"PRAGMA table_info({quoted_table})"))
            return [row[1] for row in result.fetchall()]  # Column name is at index 1
    except Exception as e:
        print(f"Error getting columns for {table_name}: {e}")
        return []
    finally:
        if engine is not None:
            # Release pooled connections so the database file is not held open.
            engine.dispose()

# In [None]:
def add_samples_to_schema(schema_folder="schemas", database_folder="database"):
    """
    Find schema files, extract sample data, and add samples to schema.

    For every ``*.py`` file in ``schema_folder`` (processed in sorted order),
    the module-level ``schema`` dict is loaded and the matching
    ``<database_folder>/<name>.db`` is located (a trailing ``_db`` on the
    schema file's stem is dropped). Up to 5 sample rows are then pulled from
    each table, and the schema file is rewritten in place as
    ``schema = {table: {"columns": ..., "sample": [...]}}``.

    Args:
        schema_folder (str): Folder containing schema Python files
        database_folder (str): Folder containing database files

    Returns:
        dict: Updated schemas with sample data, keyed by database name
    """

    # Sorted so processing order (and output) does not depend on the OS.
    schema_files = sorted(Path(schema_folder).glob("*.py"))

    if not schema_files:
        print(f"No schema files found in {schema_folder}")
        return {}

    updated_schemas = {}

    for schema_file in schema_files:
        print(f"\nProcessing {schema_file.name}...")

        # Load schema from file
        schema = load_schema_from_file(str(schema_file))
        if not schema:
            continue

        # Extract database name from schema file name
        # Remove .py extension and _db suffix if present
        db_name = schema_file.stem
        if db_name.endswith('_db'):
            db_name = db_name[:-3]  # Remove '_db'

        # Look for corresponding database file
        db_path = Path(database_folder) / f"{db_name}.db"

        if not db_path.exists():
            print(f"Database file not found: {db_path}")
            continue

        print(f"Found database: {db_path}")

        # Add sample data to each table in schema
        updated_schema = {}
        for table_name, table_def in schema.items():
            print(f"  Extracting samples from table: {table_name}")

            # A file this function already rewrote stores each table as
            # {"columns": ..., "sample": ...}; unwrap it so re-running
            # refreshes the samples instead of nesting another layer.
            # (Assumes an original table definition never consists of exactly
            # these two keys.)
            if isinstance(table_def, dict) and set(table_def) == {"columns", "sample"}:
                columns = table_def["columns"]
            else:
                columns = table_def

            updated_schema[table_name] = {
                "columns": columns,
                "sample": extract_sample_data(str(db_path), table_name, limit=5),
            }

        # Round-trip through JSON so values JSON cannot represent are
        # stringified (default=str), then emit a *Python* literal: raw
        # json.dumps output contains null/true/false, which are not valid
        # Python and would make the rewritten file fail to load next time.
        plain = json.loads(json.dumps(updated_schema, default=str))
        # Build the full text before touching the file so a serialization
        # failure cannot leave the schema file truncated.
        content = (
            f"# Database schema for {db_name}.db\n"
            "# Generated automatically with sample data\n\n"
            f"schema = {pprint.pformat(plain, sort_dicts=False)}\n"
        )
        output_file = schema_file
        output_file.write_text(content, encoding="utf-8")

        updated_schemas[db_name] = updated_schema
        print(f"  Updated schema saved to: {output_file}")

    return updated_schemas

# In [None]:
def process_single_schema(schema_file_path, database_folder="database"):
    """
    Process a single schema file and add sample data.

    Loads the module-level ``schema`` dict from ``schema_file_path`` and
    locates ``<database_folder>/<name>.db`` (a trailing ``_db`` on the schema
    file's stem is dropped). It then pulls up to 5 sample rows per table and
    rewrites the schema file in place as
    ``schema = {table: {"columns": ..., "sample": [...]}}``.

    Args:
        schema_file_path (str): Path to the schema file
        database_folder (str): Folder containing database files

    Returns:
        dict: Updated schema with sample data, or None if the schema file or
        database is missing or the schema could not be loaded
    """

    schema_file = Path(schema_file_path)

    if not schema_file.exists():
        print(f"Schema file not found: {schema_file_path}")
        return None

    print(f"Processing {schema_file.name}...")

    # Load schema from file
    schema = load_schema_from_file(str(schema_file))
    if not schema:
        return None

    # Extract database name from schema file name
    db_name = schema_file.stem
    if db_name.endswith('_db'):
        db_name = db_name[:-3]  # Remove '_db'

    # Look for corresponding database file
    db_path = Path(database_folder) / f"{db_name}.db"

    if not db_path.exists():
        print(f"Database file not found: {db_path}")
        return None

    print(f"Found database: {db_path}")

    # Add sample data to each table in schema
    updated_schema = {}
    for table_name, table_def in schema.items():
        print(f"  Extracting samples from table: {table_name}")

        # A file that was already rewritten stores each table as
        # {"columns": ..., "sample": ...}; unwrap it so re-running refreshes
        # the samples instead of nesting another layer. (Assumes an original
        # table definition never consists of exactly these two keys.)
        if isinstance(table_def, dict) and set(table_def) == {"columns", "sample"}:
            columns = table_def["columns"]
        else:
            columns = table_def

        updated_schema[table_name] = {
            "columns": columns,
            "sample": extract_sample_data(str(db_path), table_name, limit=5),
        }

    # Round-trip through JSON so values JSON cannot represent are stringified
    # (default=str), then emit a *Python* literal: raw json.dumps output
    # contains null/true/false, which are not valid Python and would make the
    # rewritten file fail to load next time.
    plain = json.loads(json.dumps(updated_schema, default=str))
    # Build the full text before touching the file so a serialization failure
    # cannot leave the schema file truncated.
    content = (
        f"# Database schema for {db_name}.db\n"
        "# Generated automatically with sample data\n\n"
        f"schema = {pprint.pformat(plain, sort_dicts=False)}\n"
    )
    schema_file.write_text(content, encoding="utf-8")

    print(f"Updated schema saved to: {schema_file}")
    return updated_schema

# In [None]:
# Example usage: enrich every schema under ./schemas with samples taken from
# the matching database under ./database, then print a per-table summary.
if __name__ == "__main__":
    # For a single file instead, call:
    #   process_single_schema("schemas/chinook.py", "database")
    results = add_samples_to_schema("schemas", "database")

    print(f"\nProcessed {len(results)} schemas:")
    for name, tables in results.items():
        print(f"  {name}: {len(tables)} tables")
        for tbl, info in tables.items():
            n_samples = len(info.get('sample', []))
            print(f"    - {tbl}: {n_samples} sample rows")