In [None]:
import os
import sqlite3
import pandas as pd

def csvs_to_sqlite(csv_dir, sqlite_path="output.db"):
    """
    Load every ``*.csv`` file in a directory into a SQLite database.

    Each file becomes one table. For filenames made of at least three
    ``__``-separated parts (e.g. ``beataml2__patient__20250924071940.csv``)
    the table name is the second part (``patient``); otherwise the filename
    without its extension is used. An existing table with the same name is
    replaced.

    Args:
        csv_dir: Directory to scan (not recursive).
        sqlite_path: SQLite file to open, created if it does not exist.

    Raises:
        Whatever ``os.listdir``, ``pandas.read_csv`` or ``DataFrame.to_sql``
        raise; the connection is closed before the error propagates.
    """
    # Connect (or create) SQLite database
    conn = sqlite3.connect(sqlite_path)
    try:
        # Sorted so that, if two files map to the same table name, the one
        # that wins the "replace" is the same on every run and platform.
        for file in sorted(os.listdir(csv_dir)):
            if not file.endswith(".csv"):
                continue

            # Extract table name from filename: keep middle part
            # e.g. beataml2__patient__20250924071940.csv -> patient
            parts = file.split("__")
            if len(parts) >= 3:
                table_name = parts[1]
            else:
                table_name = os.path.splitext(file)[0]  # fallback to filename

            print(f"Loading file: {file} into table: {table_name}")

            # Read CSV
            df = pd.read_csv(os.path.join(csv_dir, file))

            # Write DataFrame to SQLite
            df.to_sql(table_name, conn, if_exists="replace", index=False)

        conn.commit()
    finally:
        # Always release the database handle, even when a file fails to load.
        conn.close()
    print(f"All CSV files from {csv_dir} are now stored in {sqlite_path}")


# Runs when the cell executes: loads the CSV files under BeatAML/Tables into
# the SQLite file BeatAML/BeatAML.db.
csvs_to_sqlite("BeatAML/Tables", sqlite_path="BeatAML/BeatAML.db")

In [None]:
import sqlite3
import pandas as pd
import concurrent.futures
import numpy as np
from openai import OpenAI

client = OpenAI()

def generate_llm_description(prompt, model="gpt-4o-mini"):
    """
    Send ``prompt`` to a chat model and return its reply as stripped text.

    Uses the module-level ``client`` with a fixed system message and
    ``temperature=0.2``.

    Args:
        prompt: User message content.
        model: Chat model name passed to the API.

    Returns:
        The first choice's message content with surrounding whitespace removed.

    Raises:
        ValueError: If the first choice carries no text content.
        Exception: Any error raised by the client call itself propagates.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are an expert data documentation assistant."},
            {"role": "user", "content": prompt}
        ],
        temperature=0.2,
    )
    content = response.choices[0].message.content
    if content is None:
        # The content field is optional in the response; fail with a clear
        # message instead of an AttributeError from calling .strip() on None.
        raise ValueError("LLM response contained no text content")
    return content.strip()


def is_probably_numeric_string(series, threshold=0.9):
    """
    Report whether a column's non-null values are mostly parseable as numbers.

    Values are cast to ``str`` and parsed with ``pd.to_numeric``; the result
    is true when the fraction that parses is at least ``threshold``.
    """
    as_text = series.dropna().astype(str)
    parsed_ok = pd.to_numeric(as_text, errors="coerce").notna()
    return parsed_ok.mean() >= threshold


def generate_schema_from_db(
    db_path,
    field_dict_csv=None,
    max_unique=50,
    sample_size=5,
    max_workers=10
):
    """
    Build an LLM-described schema and a concept table for a SQLite database.

    For every table listed in ``sqlite_master`` the whole table is read into
    a DataFrame, one prompt is built for the table and one per column, and
    all prompts for that table are sent through ``generate_llm_description``
    on a thread pool.

    Args:
        db_path: Path of the SQLite database file.
        field_dict_csv: Optional CSV data dictionary. Rows are keyed on the
            ``"Table Name"`` and ``"fields "`` columns; ``"data type"`` and
            ``"description "`` are added to the column prompt when present.
        max_unique: Columns with more distinct non-null values than this are
            summarised as ``"<n> unique values"`` instead of listing them.
        sample_size: Number of distinct values shown as examples.
        max_workers: Thread-pool size for the LLM calls of one table.

    Returns:
        ``(schema, concept_df)``. ``schema`` maps table name to
        ``{"table_description": str, "fields": {column: {...}}}``. A prompt
        whose LLM call raised is recorded as the string ``"Error: <exc>"``.
        ``concept_df`` has the columns ``concept_name``, ``table_name`` and
        ``field_name`` (also when it has no rows) with one row per distinct
        value of every column that is not numeric, not numeric-looking text
        and not 100% unique.
    """
    conn = sqlite3.connect(db_path)
    try:
        return _build_schema_and_concepts(
            conn, field_dict_csv, max_unique, sample_size, max_workers
        )
    finally:
        # Release the database handle on success and on failure alike.
        conn.close()


def _quote_sqlite_identifier(name):
    """Return ``name`` as a double-quoted SQL identifier, doubling embedded quotes."""
    return '"' + str(name).replace('"', '""') + '"'


def _text_or_none(value):
    """Return ``value`` as stripped text, or ``None`` when pandas treats it as missing."""
    return None if pd.isna(value) else str(value).strip()


def _load_field_dictionary(field_dict_csv):
    """Read the data-dictionary CSV into ``{(table, field): {"dtype", "description"}}``."""
    field_dict = {}
    # The trailing spaces in "fields " and "description " are the literal
    # column headers this code has always looked up; keep them as they are.
    fd_df = pd.read_csv(field_dict_csv).dropna(subset=["Table Name", "fields "])
    for _, row in fd_df.iterrows():
        t = str(row["Table Name"]).strip()
        f = str(row["fields "]).strip()
        field_dict[(t, f)] = {
            "dtype": _text_or_none(row.get("data type", "")),
            "description": _text_or_none(row.get("description ", "")),
        }
    return field_dict


def _build_schema_and_concepts(conn, field_dict_csv, max_unique, sample_size, max_workers):
    """Do the work of ``generate_schema_from_db`` on an already open connection."""
    schema = {}
    concept_rows = []

    field_dict = _load_field_dictionary(field_dict_csv) if field_dict_csv else {}

    # Get table names
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = [t[0] for t in cursor.fetchall()]

    for table in tables:
        # Quote the identifier: table names may contain spaces, hyphens or
        # SQL keywords, which would break an unquoted SELECT.
        df = pd.read_sql(f"SELECT * FROM {_quote_sqlite_identifier(table)}", conn)
        total_rows = len(df)

        # ----------------------------
        # Per-column statistics, computed once and reused for the prompts,
        # the concept filter and the final schema entry
        # ----------------------------
        col_stats = {}
        for col in df.columns:
            col_data = df[col].dropna()
            unique_vals = col_data.unique()
            num_unique = len(unique_vals)
            if num_unique > max_unique:
                unique_info = f"{num_unique} unique values"
            else:
                unique_info = unique_vals.tolist()
            col_stats[col] = {
                "col_data": col_data,
                "dtype": str(df[col].dtype),
                "unique_vals": unique_vals,
                "unique_info": unique_info,
                "sample_vals": pd.Series(unique_vals).head(sample_size).tolist(),
                "uniqueness_pct": (num_unique / total_rows) * 100 if total_rows > 0 else 0,
            }

        # ----------------------------
        # Build richer table-level prompt
        # ----------------------------
        field_examples = [
            f"- {col} ({stats['dtype']}), e.g. {stats['sample_vals']}"
            for col, stats in col_stats.items()
        ]

        table_prompt = f"""
        You are given the following information about a database table:

        Table name: {table}

        Columns with example values:
        {chr(10).join(field_examples)}

        Please write a clear, concise (2–3 lines), human-readable and accurate description of what this table contains overall. Focus on the purpose and nature of data stored in the table, not on detailed descriptions of individual fields. Provide your response in a plain text format.
        """

        # ----------------------------
        # Collect prompts (table + fields)
        # ----------------------------
        # NOTE: the table prompt is stored under "__table__"; a column with
        # that exact name would overwrite it.
        prompts = {"__table__": table_prompt}
        candidate_concepts = []  # store candidate concept values only after filtering

        for col, stats in col_stats.items():
            dtype = stats["dtype"]
            sample_vals = stats["sample_vals"]
            unique_info = stats["unique_info"]
            uniqueness_pct = stats["uniqueness_pct"]

            # ---- Exclusion logic for concept table ----
            # Evaluated in order: numeric column, stringified numeric,
            # ID-like (every row holds a distinct value).
            exclude = (
                np.issubdtype(df[col].dtype, np.number)
                or is_probably_numeric_string(stats["col_data"])
                or uniqueness_pct == 100
            )

            if not exclude:
                # All distinct values are kept here, regardless of max_unique.
                candidate_concepts.append((table, col, stats["unique_vals"]))

            # ---- Field description prompt ----
            fd_info = field_dict.get((table, col), {})
            dict_dtype = fd_info.get("dtype")
            dict_desc = fd_info.get("description")

            extra_context = ""
            if dict_dtype or dict_desc:
                extra_context = f"""
                Additional context from data dictionary:
                - Declared data type: {dict_dtype}
                - Provided description: {dict_desc}
                """

            field_prompt = f"""
            You are given the following information about a database column:

            - Column name: {col}
            - Table: {table}
            - Data type (inferred from DB): {dtype}
            - Sample values: {sample_vals}
            - Unique values: {unique_info}
            - Uniqueness percent: {round(uniqueness_pct, 2)}
            {extra_context}

            Write a clear, concise and human-readable description of what this field likely represents. Provide your response as a plain text format.
            """
            prompts[col] = field_prompt

        # ----------------------------
        # Parallel execution of all prompts
        # ----------------------------
        results = {}
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_key = {
                executor.submit(generate_llm_description, p): k
                for k, p in prompts.items()
            }
            for future in concurrent.futures.as_completed(future_to_key):
                key = future_to_key[future]
                try:
                    results[key] = future.result()
                except Exception as e:
                    # Best effort: one failed prompt must not abort the run,
                    # so the error text is stored in place of the description.
                    results[key] = f"Error: {e}"

        # ----------------------------
        # Build schema with results
        # ----------------------------
        table_info = {
            "table_description": results["__table__"],
            "fields": {}
        }

        for col, stats in col_stats.items():
            table_info["fields"][col] = {
                "field_data_type": stats["dtype"],
                "field_description": results[col],
                "field_sample_values": stats["sample_vals"],
                "field_unique_values": stats["unique_info"],
                "field_uniqueness_percent": round(stats["uniqueness_pct"], 2)
            }

        schema[table] = table_info

        # ----------------------------
        # Add this table's concept rows (one per distinct value of each
        # column that passed the exclusion filter)
        # ----------------------------
        for table_name, col_name, values in candidate_concepts:
            for val in values:
                concept_rows.append({
                    "concept_name": str(val),
                    "table_name": table_name,
                    "field_name": col_name
                })

    # Explicit columns so that an empty result still has the expected layout.
    concept_df = pd.DataFrame(
        concept_rows, columns=["concept_name", "table_name", "field_name"]
    )
    return schema, concept_df


# Example usage:
# Runs when the cell executes. `schema` and `concept_df` are left as
# notebook-level names and are used by the cells that follow.
db_path = "BeatAML/BeatAML.db"  # path to SQLite DB
field_dict_csv = "BeatAML/BeatAML_data_dict.csv"  # path to your CSV with descriptions
schema, concept_df = generate_schema_from_db(db_path, field_dict_csv=field_dict_csv)


In [None]:
import json

# Assuming schema is already generated from your function
output_file = "BeatAML/BeatAML_schema.json"

# Written as UTF-8 with ensure_ascii=False, so non-ASCII characters are stored
# as-is rather than as \uXXXX escapes; readers must open the file as UTF-8.
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(schema, f, indent=2, ensure_ascii=False)

print(f"Schema saved to {output_file}")


In [None]:
# Add a context-qualified label of the form "<table_name>_<field_name>_<concept_name>"
# to every concept row, then write the table to CSV without the index.
concept_df["concept_with_context"] = concept_df["table_name"].astype(str) + "_" + concept_df["field_name"].astype(str) + "_" + concept_df["concept_name"].astype(str)
concept_df.to_csv("BeatAML/concept_table.csv", index=False)
# Bare expression: presumably here so the notebook displays the DataFrame as the cell output.
concept_df

In [None]:
import pandas as pd
import numpy as np
import pickle
import time
from openai import OpenAI
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

# Initialize OpenAI client
client = OpenAI()

# -------- Step 1: Load concept table --------
# Re-read the concept table from disk (rebinding `concept_df`) and take the
# "concept_with_context" column, cast to str, as the list of texts to embed.
concept_df = pd.read_csv("BeatAML/concept_table.csv")
concepts = concept_df["concept_with_context"].astype(str).tolist()

# -------- Step 2: Define embedding function --------
def get_single_embedding(text, model="text-embedding-3-small", max_retries=3, delay=2):
    """
    Embed one text via the module-level ``client``, retrying on any exception.

    Up to ``max_retries`` attempts are made. After failed attempt number
    ``k`` (0-based) the function sleeps ``delay * 2**k`` seconds before trying
    again. When the last attempt fails, the error is printed and ``None`` is
    returned; ``None`` is also returned when ``max_retries`` is not positive.
    """
    final_attempt = max_retries - 1
    attempt = 0
    while attempt < max_retries:
        try:
            return client.embeddings.create(model=model, input=text).data[0].embedding
        except Exception as e:
            if attempt >= final_attempt:
                print(f"[ERROR] Failed after {max_retries} retries for text: {text[:50]}... | Error: {e}")
                return None
            time.sleep(delay * (2 ** attempt))  # exponential backoff
        attempt += 1
    return None

# -------- Step 3: Parallel embedding generation --------
def get_embeddings_parallel(texts, model="text-embedding-3-small", max_workers=10):
    """
    Embed many texts concurrently and return them as one array.

    Each text is passed to ``get_single_embedding`` on a thread pool; results
    are stored at the index of their input text, so row ``i`` belongs to
    ``texts[i]`` regardless of completion order.

    Args:
        texts: Sequence of strings to embed.
        model: Embedding model name forwarded to ``get_single_embedding``.
        max_workers: Thread-pool size.

    Returns:
        ``np.ndarray`` with one row per text. A text whose embedding failed
        (``None`` result) gets an all-zero row. Empty input returns an empty
        1-D array.
    """
    if len(texts) == 0:
        # Nothing to embed: skip the pool and progress bar.
        return np.array([])

    embeddings = [None] * len(texts)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(get_single_embedding, text, model): i for i, text in enumerate(texts)}
        for future in tqdm(as_completed(futures), total=len(futures), desc="Generating embeddings"):
            embeddings[futures[future]] = future.result()

    # Size the zero placeholder from a real embedding so the rows stay
    # rectangular for any model; 1536 is only the fallback used when every
    # request failed.
    dim = next((len(e) for e in embeddings if e is not None), 1536)
    placeholder = np.zeros(dim)
    return np.array([e if e is not None else placeholder for e in embeddings])  # handle failed embeddings

# -------- Step 4: Generate and save embeddings --------
print(f"Generating embeddings for {len(concepts)} concepts using parallel threads...")
embeddings = get_embeddings_parallel(concepts, max_workers=10)

# The pickle stores the concept strings next to the embedding array; row i of
# "embeddings" belongs to concepts[i].
# NOTE(review): pickle files should only ever be loaded from a trusted source.
output_path = "BeatAML/concept_embeddings.pkl"
with open(output_path, "wb") as f:
    pickle.dump({"concepts": concepts, "embeddings": embeddings}, f)

print(f"✅ Embeddings saved to {output_path}")


In [None]:
import json
from typing import List, Dict, Any, Literal, Tuple
from rich import print
import os

from openai import OpenAI
client = OpenAI()

# open the schema json file
# Read explicitly as UTF-8: the schema is written with encoding="utf-8" and
# ensure_ascii=False, so relying on the platform default encoding could
# mis-decode any non-ASCII characters in the descriptions.
with open('./BeatAML/BeatAML_schema.json', 'r', encoding='utf-8') as f:
    DB_SCHEMA = json.load(f)

# helper functions to extract summary text about DB fields and embed them as vectors
def extract_table_field_texts(schema: Dict[str, Dict[str, Any]]) -> Dict[str, str]:
    """
    Extract (table.field, text) pairs from DB schema.

    The text for a field is ``"<field>__<description>__<samples>"`` where
    ``<samples>`` is ``" Example Values: v1, v2, ..."`` when the field has
    sample values and empty otherwise; the result is stripped of surrounding
    whitespace. Only field-level information is used — the table description
    is not part of the text.

    Args:
        schema: Mapping of table name to table info; fields are read from the
            ``"fields"`` key (tables without it contribute nothing), and each
            field's ``"field_description"`` and ``"field_sample_values"`` are
            used when present.

    Returns:
        Dict mapping ``"<table>.<field>"`` to its description text.
    """
    pairs = {}
    for table, table_info in schema.items():
        for field, field_info in table_info.get("fields", {}).items():
            desc = field_info.get("field_description", "")
            samples = field_info.get("field_sample_values", [])
            sample_str = f" Example Values: {', '.join(map(str, samples))}" if samples else ""
            pairs[f'{table}.{field}'] = f"{field}__{desc}__{sample_str}".strip()
    return pairs

def embed_texts(pairs: Dict[str, str], batch_size: int = 512) -> Dict[str, Dict[str, Any]]:
    """
    Embed schema table + field texts into vectors.

    Texts are sent to the embeddings endpoint of the module-level ``client``
    in batches, so a large schema does not exceed the per-request input
    limits. Within each batch the returned items are ordered by their
    ``index`` field before being matched back to their keys.

    Args:
        pairs: Mapping of ``table.field`` name to the text to embed.
        batch_size: Maximum number of texts per API request (must be >= 1).

    Returns:
        Dict in format ``{table.field: {'text': str, 'embedding': List}}``.
        Empty input returns ``{}`` without calling the API.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        RuntimeError: If the API returns a different number of embeddings
            than texts were sent.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    keys = list(pairs)
    texts = [pairs[key] for key in keys]

    embeddings = []
    for start in range(0, len(texts), batch_size):
        response = client.embeddings.create(
            model="text-embedding-3-small",
            input=texts[start:start + batch_size]
        )
        # Order by the item's own index rather than trusting list order.
        ordered = sorted(response.data, key=lambda item: item.index)
        embeddings.extend(item.embedding for item in ordered)

    if len(embeddings) != len(keys):
        raise RuntimeError(
            f"Expected {len(keys)} embeddings but received {len(embeddings)}"
        )

    return {
        key: {"text": pairs[key], "embedding": embedding}
        for key, embedding in zip(keys, embeddings)
    }


# Build one summary text per table.field from the loaded schema.
table_field_pairs = extract_table_field_texts(DB_SCHEMA)

# Get table.field embeddings along with texts
# (the mapping is passed as-is; copying it first served no purpose)
table_field_embeddings = embed_texts(table_field_pairs)

# Write embeddings dict to file
with open('BeatAML/db_table_field_embeddings.json', 'w', encoding='utf-8') as f:
    json.dump(table_field_embeddings, f, indent=2)
