In [1]:
# Generate dbdiagram.io DSL from the Northwind SQLite database
# Output is printed and also saved to data/northwind_dbdiagram.txt

import os
import sqlite3
import re
from typing import Dict, List, Set

DB_PATH = "/home/mohammed/Desktop/tech_projects/northwind_ai_workflow/data/northwind.sqlite"
OUTPUT_PATH = "/home/mohammed/Desktop/tech_projects/northwind_ai_workflow/data/northwind_dbdiagram.txt"

if not os.path.exists(DB_PATH):
    raise FileNotFoundError(f"Database not found at: {DB_PATH}")

connection = sqlite3.connect(DB_PATH)
connection.row_factory = sqlite3.Row
cursor = connection.cursor()

# Helpers

def q(identifier: str) -> str:
    return '"' + str(identifier).replace('"', '\\"') + '"'

def to_snake_case(name: str) -> str:
    name = re.sub(r"[^0-9A-Za-z]+", " ", name).strip()
    name = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", name)
    name = re.sub(r"\s+", "_", name)
    return name.lower()


def get_table_names(cur: sqlite3.Cursor) -> List[str]:
    cur.execute(
        """
        SELECT name
        FROM sqlite_master
        WHERE type = 'table'
          AND name NOT LIKE 'sqlite_%'
        ORDER BY name
        """
    )
    return [row["name"] for row in cur.fetchall()]


def get_table_columns(cur: sqlite3.Cursor, table_name: str) -> List[sqlite3.Row]:
    escaped = table_name.replace("'", "''")
    cur.execute("PRAGMA table_info('%s')" % escaped)
    return cur.fetchall()


def get_foreign_keys(cur: sqlite3.Cursor, table_name: str) -> List[sqlite3.Row]:
    escaped = table_name.replace("'", "''")
    cur.execute("PRAGMA foreign_key_list('%s')" % escaped)
    return cur.fetchall()


def get_unique_single_column_columns(cur: sqlite3.Cursor, table_name: str) -> Set[str]:
    escaped = table_name.replace("'", "''")
    cur.execute("PRAGMA index_list('%s')" % escaped)
    uniques = set()
    for idx in cur.fetchall():
        if idx["unique"] == 1:
            idx_name = idx["name"]
            idx_escaped = idx_name.replace("'", "''")
            cur.execute("PRAGMA index_info('%s')" % idx_escaped)
            cols = [r["name"] for r in cur.fetchall()]
            if len(cols) == 1:
                uniques.add(cols[0])
    return uniques


def get_primary_key_columns(columns: List[sqlite3.Row]) -> List[str]:
    # In SQLite, pk column has pk > 0; order defines composite pk ordering
    return [col["name"] for col in sorted(columns, key=lambda c: c["pk"]) if col["pk"]]


def normalize_type(sqlite_type: str) -> str:
    t = (sqlite_type or "").strip()
    return t or "text"


def format_default(dflt_value) -> str:
    if dflt_value is None:
        return ""
    return f"default: {dflt_value}"

# Build metadata

tables = get_table_names(cursor)

primary_keys: Dict[str, List[str]] = {}
columns_by_table: Dict[str, List[sqlite3.Row]] = {}
unique_cols_by_table: Dict[str, Set[str]] = {}

for tbl in tables:
    cols = get_table_columns(cursor, tbl)
    columns_by_table[tbl] = cols
    primary_keys[tbl] = get_primary_key_columns(cols)
    unique_cols_by_table[tbl] = get_unique_single_column_columns(cursor, tbl)

# Build table blocks (original Northwind table names)
lines: List[str] = []
for tbl in tables:
    lines.append(f"Table {q(tbl)} {{")
    uniques = unique_cols_by_table[tbl]
    for col in columns_by_table[tbl]:
        col_name = col["name"]
        col_type = normalize_type(col["type"])
        attrs: List[str] = []
        if col["pk"]:
            attrs.append("primary key")
        if col["notnull"]:
            attrs.append("not null")
        if col_name in uniques and not col["pk"]:
            attrs.append("unique")
        dflt = format_default(col["dflt_value"])  # may be ''
        if dflt:
            attrs.append(dflt)
        attr_str = f" [{', '.join(attrs)}]" if attrs else ""
        lines.append(f"  {q(col_name)} {col_type}{attr_str}")
    lines.append("}")

# Build relationships (original table names)
for child in tables:
    fks = get_foreign_keys(cursor, child)
    for fk in fks:
        parent = fk["table"]
        child_col = fk["from"]
        parent_col = fk["to"]
        if not parent_col:
            pk_cols = primary_keys.get(parent) or []
            parent_col = pk_cols[0] if pk_cols else "id"
        lines.append(
            f"Ref: {q(child)}.{q(child_col)} > {q(parent)}.{q(parent_col)}"
        )

dsl_text = "\n".join(lines)

print(dsl_text)
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
    f.write(dsl_text)

print(f"\nSaved DSL to: {OUTPUT_PATH}")


Table "Categories" {
  "CategoryID" INTEGER [primary key]
  "CategoryName" TEXT
  "Description" TEXT
  "Picture" BLOB
}
Table "CustomerCustomerDemo" {
  "CustomerID" TEXT [primary key, not null]
  "CustomerTypeID" TEXT [primary key, not null]
}
Table "CustomerDemographics" {
  "CustomerTypeID" TEXT [primary key, not null]
  "CustomerDesc" TEXT
}
Table "Customers" {
  "CustomerID" TEXT [primary key]
  "CompanyName" TEXT
  "ContactName" TEXT
  "ContactTitle" TEXT
  "Address" TEXT
  "City" TEXT
  "Region" TEXT
  "PostalCode" TEXT
  "Country" TEXT
  "Phone" TEXT
  "Fax" TEXT
}
Table "EmployeeTerritories" {
  "EmployeeID" INTEGER [primary key, not null]
  "TerritoryID" TEXT [primary key, not null]
}
Table "Employees" {
  "EmployeeID" INTEGER [primary key]
  "LastName" TEXT
  "FirstName" TEXT
  "Title" TEXT
  "TitleOfCourtesy" TEXT
  "BirthDate" DATE
  "HireDate" DATE
  "Address" TEXT
  "City" TEXT
  "Region" TEXT
  "PostalCode" TEXT
  "Country" TEXT
  "HomePhone" TEXT
  "Extension" TEXT
  "