# Infer Relationships from PostgreSQL Sample Data

> **Features**:
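> - Connects to your PostgreSQL database
> - Samples rows from every table in the `public` schema
> - Infers candidate **FK relationships** with **confidence scores**
> - Visualizes them as an **ER diagram** using **Mermaid** (paste the printed code into mermaid.live to explore it interactively)
> - Exports results to **CSV** and **SQL** (`ALTER TABLE ... ADD CONSTRAINT` statements)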

In [None]:
# Install Dependencies (Run Once)
!pip install pandas sqlalchemy psycopg2-binary ipywidgets tqdm matplotlib seaborn plotly

In [None]:
# CELL 1: Imports and Setup
import html
import re
import warnings
from collections import namedtuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import seaborn as sns
from sqlalchemy import create_engine, text
from tqdm.notebook import tqdm

# For interactive widgets
import ipywidgets as widgets
from IPython.display import display, HTML, Markdown

warnings.filterwarnings("ignore")

# One inferred foreign-key candidate: child_table.child_col is believed to
# reference parent_table.parent_col. `confidence` is the score used for ranking,
# `reason` a human-readable explanation and `sample_overlap` the number of
# shared distinct values.
Relationship = namedtuple(
    'Relationship',
    'parent_table parent_col child_table child_col confidence reason sample_overlap',
)

print("Setup complete!")

In [None]:
# CELL 2: Database Connection Widget
# Connection string for the database to analyse (edit before running Cell 3).
db_url_widget = widgets.Text(
    description='DB URL:',
    value='postgresql://username:password@localhost:5432/yourdb',
    placeholder='Enter PostgreSQL URL',
    layout={'width': '100%'},
)

# Knobs read by Cell 3: rows sampled per table, and the thresholds a column
# pair must meet to be reported as a candidate foreign key.
sample_size_widget = widgets.IntSlider(
    description='Sample Size:', value=5000, min=1000, max=20000, step=1000,
)
min_overlap_widget = widgets.FloatSlider(
    description='Min Overlap:', value=0.8, min=0.5, max=1.0, step=0.05,
)
min_uniqueness_widget = widgets.FloatSlider(
    description='Min Uniqueness:', value=0.9, min=0.7, max=1.0, step=0.05,
)

display(db_url_widget, sample_size_widget, min_overlap_widget, min_uniqueness_widget)

In [None]:
# CELL 3: Load Data & Infer Relationships
def infer_fk_from_db(db_url, sample_size, min_overlap, min_uniqueness):
    engine = create_engine(db_url)
    
    # Get tables
    tables_df = pd.read_sql("SELECT tablename FROM pg_tables WHERE schemaname='public'", engine)
    tables = tables_df['tablename'].tolist()
    
    print(f"Found {len(tables)} tables. Sampling {sample_size} rows each...")
    
    data = {}
    for table in tqdm(tables, desc="Sampling tables"):
        try:
            query = text(f"SELECT * FROM {table} LIMIT :limit")
            data[table] = pd.read_sql(query, engine, params={"limit": sample_size})
        except Exception as e:
            print(f"Skipping {table}: {e}")
    
    candidates = []
    print("Inferring relationships...")
    
    for child_table, child_df in tqdm(data.items(), desc="Checking child tables"):
        for child_col in child_df.columns:
            if child_df[child_col].dtype.kind not in 'Oib':
                continue
            child_vals = set(child_df[child_col].dropna().astype(str))
            if len(child_vals) == 0:
                continue
                
            for parent_table, parent_df in data.items():
                if child_table == parent_table:
                    continue
                for parent_col in parent_df.columns:
                    if parent_df[parent_col].dtype.kind not in 'Oib':
                        continue
                    parent_vals = set(parent_df[parent_col].dropna().astype(str))
                    if len(parent_vals) == 0:
                        continue

                    overlap_count = len(child_vals & parent_vals)
                    overlap_ratio = overlap_count / len(child_vals)
                    uniqueness = len(parent_vals) / len(parent_df)

                    if overlap_ratio >= min_overlap and uniqueness >= min_uniqueness:
                        confidence = overlap_ratio * uniqueness
                        reason = f"{overlap_ratio:.1%} values exist in parent, {uniqueness:.1%} unique"
                        candidates.append(Relationship(
                            parent_table, parent_col,
                            child_table, child_col,
                            confidence, reason, overlap_count
                        ))

    candidates.sort(key=lambda x: x.confidence, reverse=True)
    return candidates, data

# Run inference.
# Reset the results first so later cells see empty (not stale or undefined)
# results when this run fails.
relationships, sampled_data = [], {}
inference_error = None
try:
    relationships, sampled_data = infer_fk_from_db(
        db_url=db_url_widget.value,
        sample_size=sample_size_widget.value,
        min_overlap=min_overlap_widget.value,
        min_uniqueness=min_uniqueness_widget.value
    )
except Exception as e:
    inference_error = e
    display(Markdown(f"**Error**: {e}"))

# Explicit columns keep the schema even when nothing was found
# (pd.DataFrame([]) would have no columns at all).
rel_df = pd.DataFrame(relationships, columns=list(Relationship._fields))
if inference_error is None:
    display(Markdown(f"### Found **{len(rel_df)}** candidate relationships"))
    display(rel_df.head(10))

In [None]:
# CELL 4: Visualize Top Relationships
# Horizontal bar chart of the (up to) 15 highest-confidence candidates,
# labelled "parent.col → child.col".
if 'rel_df' in locals() and len(rel_df) > 0:
    top_n = min(15, len(rel_df))
    top_rels = rel_df.head(top_n)
    bar_labels = top_rels.apply(
        lambda row: f"{row.parent_table}.{row.parent_col} → {row.child_table}.{row.child_col}",
        axis=1,
    )

    plt.figure(figsize=(12, 8))
    sns.barplot(data=top_rels, y=bar_labels, x='confidence', palette='viridis')
    plt.title(f"Top {top_n} Inferred Relationships by Confidence")
    plt.xlabel("Confidence Score")
    plt.ylabel("")
    plt.xlim(0, 1)
    plt.show()

In [None]:
# CELL 5: Interactive Mermaid ER Diagram
# Characters outside word characters and '-' are not valid in a bare Mermaid entity name.
_MERMAID_UNSAFE = re.compile(r"[^\w-]")


def _mermaid_entity(name):
    """Return *name* as a Mermaid-safe entity identifier (unsafe characters -> '_')."""
    return _MERMAID_UNSAFE.sub("_", str(name))


def generate_mermaid(rels, threshold=0.7):
    """Render relationships as Mermaid ``erDiagram`` source.

    Args:
        rels: Iterable of records with ``parent_table``, ``parent_col``,
            ``child_table``, ``child_col`` and ``confidence`` attributes,
            expected in descending-confidence order.
        threshold: Relationships with ``confidence`` below this are omitted.

    Returns:
        The diagram text. Only the first qualifying relationship per
        (parent table, child table) pair is drawn. Entities are declared
        implicitly by the relationship lines.
    """
    lines = ["erDiagram"]
    seen = set()
    for r in rels:
        if r.confidence < threshold:
            continue
        # A tuple key: joining names with "_" made ("a_b", "c") and ("a", "b_c") collide.
        key = (r.parent_table, r.child_table)
        if key in seen:
            continue
        seen.add(key)
        parent = _mermaid_entity(r.parent_table)
        child = _mermaid_entity(r.child_table)
        card = "||--o{" if r.confidence > 0.9 else "||--|{"
        # The label is a double-quoted Mermaid string, so it must not contain '"'.
        label = f"{r.parent_col} → {r.child_col} [{r.confidence:.0%}]".replace('"', "'")
        lines.append(f'    {parent} {card} {child} : "{label}"')
    return "\n".join(lines)

# Build the diagram from Cell 3's candidates (empty if Cell 3 has not run).
mermaid_code = generate_mermaid(globals().get('relationships', []), threshold=0.75)
print("Copy & paste into https://mermaid.live to view interactively")
# Print the diagram source itself so there is something to copy.
print(mermaid_code)

# Render inline. mermaid's startOnLoad only triggers on the page's initial load,
# which has long passed when a notebook output is inserted. Instead, load the
# library on demand and call mermaid.run() explicitly. The diagram text is
# HTML-escaped because table/column names are embedded in the markup.
display(HTML(f"""
<div class="mermaid">
{html.escape(mermaid_code)}
</div>
<script>
(function () {{
  function render() {{
    mermaid.initialize({{startOnLoad: false}});
    mermaid.run();
  }}
  if (window.mermaid) {{
    render();
  }} else {{
    var script = document.createElement("script");
    script.src = "https://cdn.jsdelivr.net/npm/mermaid@10/dist/mermaid.min.js";
    script.onload = render;
    document.head.appendChild(script);
  }}
}})();
</script>
"""))

In [None]:
# CELL 6: Export Results
# Messages printed by the export callback are captured in this Output widget.
output = widgets.Output()
export_btn = widgets.Button(description="Export to CSV + SQL")

def on_export(b):
    with output:
        if 'rel_df' not in locals():
            print("No data to export")
            return
            
        rel_df.to_csv("inferred_relationships.csv", index=False)
        print("Saved: inferred_relationships.csv")
        
        # Generate ALTER TABLE statements
        sql_lines = []
        for _, r in rel_df.iterrows():
            if r.confidence > 0.9:
                sql_lines.append(f"-- Confidence: {r.confidence:.1%}")
                sql_lines.append(
                    f"ALTER TABLE {r.child_table} ADD CONSTRAINT fk_{r.child_table}_{r.child_col} "
                    f"FOREIGN KEY ({r.child_col}) REFERENCES {r.parent_table}({r.parent_col});"
                )
                sql_lines.append("")
        
        with open("add_foreign_keys.sql", "w") as f:
            f.write("\n".join(sql_lines))
        print("Saved: add_foreign_keys.sql (high confidence only)")

# Register the export callback, then show the button together with the
# Output widget that receives its messages.
export_btn.on_click(on_export)
display(export_btn, output)

In [None]:
# CELL 7: Explore Sample Data (Optional)
if 'sampled_data' in locals() and sampled_data:
    table_selector = widgets.Dropdown(options=list(sampled_data.keys()), description='Table:')
    # Widget callbacks run outside the cell's output context, so whatever they
    # display must go into an Output widget to be visible.
    sample_view = widgets.Output()
    display(table_selector, sample_view)

    def show_sample(change):
        """Replace the preview with the first 10 sampled rows of the selected table."""
        if change['new']:
            with sample_view:
                sample_view.clear_output(wait=True)
                display(Markdown(f"### Sample from `{change['new']}`"))
                display(sampled_data[change['new']].head(10))

    table_selector.observe(show_sample, names='value')
    # The dropdown starts on its first option without firing a change event,
    # so render that initial table explicitly.
    show_sample({'new': table_selector.value})