In [None]:
# ==== Hybrid LangGraph: agents for intelligence, Python for execution ==========
# What this does:
# 1) schema  -> read both DB schemas (Python)
# 2) mapplan -> LLM proposes mapping JSON (fallback to your heuristic)
# 3) plan    -> deterministic plan from mapping (Python)
# 4) sqlgen  -> deterministic SQL from plan (Python)
# 5) exec    -> execute SQL (Python; captures errors)
# 6) validate-> counts + checksum (Python)
# 7) remediate (optional, 1 retry) -> LLM patches mapping if validation/execution failed
# 8) narrate -> LLM explains result in human language


# Imports, keys, DB names

# Loads std libs, LangGraph, and LangChain’s OpenAI wrapper

# Reads OPENAI_API_KEY from env; if missing, the flow still works using deterministic fallbacks

# Sets MODEL = "gpt-4o-mini"

# Sets SRC_DB and TGT_DB file names


!pip -q install langgraph==0.2.39 langchain-openai==0.2.5 --upgrade

import os, re, json, uuid, sqlite3, pandas as pd
from typing import TypedDict, Optional, List, Dict, Any, Tuple
from langgraph.graph import StateGraph, END

# Resolve the OpenAI key: prefer a Colab secret when google.colab is importable,
# otherwise fall back to the OPENAI_API_KEY environment variable. With no key at
# all, USE_LLM ends up False.
import os

try:
    from google.colab import userdata
except ImportError:
    # google.colab is not installed here, so there is no secrets store to query.
    userdata = None

OPENAI_API_KEY = None
if userdata is not None:
    try:
        OPENAI_API_KEY = userdata.get('OPENAI_API_KEY')
    except Exception:
        # Best effort: any failure reading the secret counts as "no secret" so
        # the environment variable still gets a chance below.
        OPENAI_API_KEY = None
if not OPENAI_API_KEY:
    OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

USE_LLM = bool(OPENAI_API_KEY)
MODEL = "gpt-4o-mini"

SRC_DB, TGT_DB = "source.db", "target.db"


# ---- Optional clean slate on every run (comment out if using UPSERT) ---------
def reset_targets(db_path: str):
    """Delete every row from the known target tables in ``db_path``.

    Best-effort clean slate: a table that cannot be cleared (for example
    because it does not exist yet) is skipped instead of raising.

    Args:
        db_path: Path of the SQLite database holding the target tables.
    """
    con = sqlite3.connect(db_path)
    try:
        # `with con` only scopes the transaction (commit on success, rollback on
        # error); it does not close the connection, hence the `finally` below.
        with con:
            for t in ["dim_customer", "fact_order"]:
                try:
                    con.execute(f"DELETE FROM {t}")
                except sqlite3.Error:
                    # Nothing to clear for this table; keep going with the next.
                    pass
    finally:
        con.close()
# NOTE(review): redundant at this point. The setup code that follows deletes
# and recreates both database files, so nothing cleared here survives.
reset_targets(TGT_DB)

# Your exact database setup (source.db, target.db)

import os, sqlite3

# Rebuild both demo databases from scratch so every run starts from the same state.
SRC_DB, TGT_DB = "source.db", "target.db"
for f in [SRC_DB, TGT_DB]:
    try: os.remove(f)
    except FileNotFoundError: pass

# Source database: raw `customers` and `orders` tables with three seed rows each.
src_con = sqlite3.connect(SRC_DB)
try:
    src_cur = src_con.cursor()
    src_cur.executescript("""
CREATE TABLE customers (
  id INTEGER PRIMARY KEY,
  fname TEXT,
  lname TEXT,
  email TEXT,
  signup_date TEXT,
  active_flag INTEGER
);
CREATE TABLE orders (
  order_id INTEGER PRIMARY KEY,
  customer_id INTEGER,
  amount_cents INTEGER,
  order_ts TEXT
);
""")
    src_cur.executemany("INSERT INTO customers VALUES (?,?,?,?,?,?)", [
        (1,'Amina','Alavi','amina@example.com','2023-01-15',1),
        (2,'Sam','Lee','sam.lee@example.com','2022-11-01',1),
        (3,'Dana','Kim','dana.kim@example.com','2021-05-20',0),
    ])
    src_cur.executemany("INSERT INTO orders VALUES (?,?,?,?)", [
        (101,1,2599,'2024-09-01 12:30:00'),
        (102,1,1099,'2024-09-15 09:00:00'),
        (103,2, 500,'2024-09-21 18:10:00'),
    ])
    src_con.commit()
finally:
    # Release the file handle even when seeding fails part-way.
    src_con.close()

# Target database: empty `dim_customer` and `fact_order` tables.
tgt_con = sqlite3.connect(TGT_DB)
try:
    tgt_cur = tgt_con.cursor()
    tgt_cur.executescript("""
CREATE TABLE dim_customer (
  customer_id INTEGER PRIMARY KEY,
  full_name TEXT,
  email TEXT,
  signup_date TEXT,
  is_active INTEGER
);
CREATE TABLE fact_order (
  order_id INTEGER PRIMARY KEY,
  customer_id INTEGER,
  amount_usd REAL,
  order_date TEXT
);
""")
    tgt_con.commit()
finally:
    tgt_con.close()
print("Databases created.")


Databases created.


In [None]:
# Hybrid LangGraph migration: agents for mapping/remediation/narration, Python for execution

!pip -q install langgraph==0.2.39 langchain-openai==0.2.5 --upgrade

import os, re, json, uuid, sqlite3, pandas as pd
from typing import TypedDict, Optional, List, Dict, Any
from langgraph.graph import StateGraph, END

# Resolve the API key. The environment variable wins; otherwise keep a key that
# an earlier cell already resolved (e.g. from Colab secrets) instead of
# overwriting it with None, which would switch every LLM step off.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") or globals().get("OPENAI_API_KEY")
USE_LLM = bool(OPENAI_API_KEY)
MODEL = "gpt-4o-mini"

SRC_DB, TGT_DB = "source.db", "target.db"

# ------------------ Helpers (self-contained) ------------------
# sqlite_schema_tool(db_path): introspects SQLite tables and columns into a Python dict. This is how we show the LLM what exists
def sqlite_schema_tool(db_path: str) -> Dict[str, Any]:
    """Introspect the user tables of a SQLite database.

    Args:
        db_path: Path of the SQLite database file.

    Returns:
        ``{"db": db_path, "tables": {name: {"ddl": str, "columns": [...]}}}``
        where each column entry has the keys ``name``, ``type``, ``notnull``,
        ``default`` and ``pk``. Tables whose name starts with ``sqlite_`` are
        left out.
    """
    out = {"db": db_path, "tables": {}}
    con = sqlite3.connect(db_path)
    try:
        cur = con.cursor()
        cur.execute("SELECT name, sql FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")
        for name, ddl in cur.fetchall():
            # The table name is spliced into a string literal, so double any
            # embedded single quote to keep the PRAGMA statement valid.
            quoted = name.replace("'", "''")
            cur.execute(f"PRAGMA table_info('{quoted}')")
            cols = [
                {"name": col, "type": ctype, "notnull": bool(notnull), "default": dflt, "pk": bool(pk)}
                for _, col, ctype, notnull, dflt, pk in cur.fetchall()
            ]
            out["tables"][name] = {"ddl": ddl, "columns": cols}
    finally:
        # A connection used as a context manager is not closed on exit, so
        # close it explicitly.
        con.close()
    return out

# infer_mapping_with_derivations(src_schema, tgt_schema): your deterministic mapping for customers→dim_customer
# and orders→fact_order with expressions for derived columns. This acts as a fallback if the LLM isn’t available or goes off-spec
def infer_mapping_with_derivations(src_schema: Dict[str,Any], tgt_schema: Dict[str,Any]) -> Dict[str, Any]:
    """Return the fixed customers/orders mapping used as the deterministic fallback.

    Args:
        src_schema: Source schema description (not consulted by this heuristic).
        tgt_schema: Target schema description; its ``"tables"`` entry decides
            whether each target still has to be created.

    Returns:
        ``{"proposed_mapping": {source_table: {"target", "columns",
        "create_target"}}}``.
    """
    def column(src, tgt, expr):
        # One source-to-target column rule.
        return {"src": src, "tgt": tgt, "expr": expr}

    layout = [
        ("customers", "dim_customer", [
            column("id", "customer_id", "id"),
            column("fname||' '||lname", "full_name", "fname || ' ' || lname"),
            column("email", "email", "email"),
            column("signup_date", "signup_date", "signup_date"),
            column("active_flag", "is_active", "active_flag"),
        ]),
        ("orders", "fact_order", [
            column("order_id", "order_id", "order_id"),
            column("customer_id", "customer_id", "customer_id"),
            column("amount_cents", "amount_usd", "CAST(amount_cents AS REAL) / 100.0"),
            column("order_ts", "order_date", "substr(order_ts, 1, 10)"),
        ]),
    ]

    existing_targets = tgt_schema["tables"]
    proposed = {}
    for source_table, target_table, columns in layout:
        proposed[source_table] = {
            "target": target_table,
            "columns": columns,
            # Flag targets that are absent from the target schema.
            "create_target": target_table not in existing_targets,
        }
    return {"proposed_mapping": proposed}

# propose_plan_tool(src_schema, tgt_schema, mapping): turns a mapping into a step list. For each source table:
#  [CREATE_TABLE if needed, TRANSFER, VALIDATE_COUNTS, CHECKSUM]. Returns {"plan_id","steps","mapping"}
def propose_plan_tool(src_schema: Dict[str,Any], tgt_schema: Dict[str,Any], mapping: Dict[str,Any]) -> Dict[str, Any]:
    """Expand a mapping into an ordered list of migration steps.

    Args:
        src_schema: Source schema description (not consulted here).
        tgt_schema: Target schema description; ``"tables"`` lists existing targets.
        mapping: Either the bare mapping or a ``{"proposed_mapping": ...}`` wrapper.

    Returns:
        ``{"plan_id": <fresh uuid4 string>, "steps": [...], "mapping": <bare mapping>}``.
        Each source table contributes an optional CREATE_TABLE step followed by
        TRANSFER, VALIDATE_COUNTS and CHECKSUM.
    """
    from uuid import uuid4

    table_map = mapping.get("proposed_mapping", mapping)
    plan_steps = []
    for source_name, spec in table_map.items():
        target_name = spec.get("target")
        needs_create = spec.get("create_target", False) or target_name not in tgt_schema["tables"]
        if needs_create:
            plan_steps.append({"action": "CREATE_TABLE", "source_table": source_name, "target_table": target_name})
        plan_steps.extend([
            {"action": "TRANSFER", "source_table": source_name, "target_table": target_name},
            {"action": "VALIDATE_COUNTS", "source_table": source_name, "target_table": target_name},
            {"action": "CHECKSUM", "source_table": source_name, "target_table": target_name, "columns": "*"},
        ])
    return {"plan_id": str(uuid4()), "steps": plan_steps, "mapping": table_map}

# generate_sql_tool(src_db, tgt_db, step, src_schema, mapping):
# For CREATE_TABLE: builds a target DDL from the mapping and simple type rules
# For TRANSFER: emits an ATTACH/INSERT/SELECT/DETACH script that pulls from src and inserts into tgt with the mapping exprs

def generate_sql_tool(src_db: str, tgt_db: str, step: dict, src_schema: dict, mapping: dict) -> dict:
    """Render the SQL for one plan step.

    Args:
        src_db: Path of the source database; attached as ``src`` for TRANSFER.
        tgt_db: Path of the target database (not used when rendering).
        step: Plan step with ``action`` and, for CREATE_TABLE/TRANSFER,
            ``source_table`` and ``target_table``.
        src_schema: Source schema; supplies source column types for CREATE_TABLE.
        mapping: Bare mapping ``{source_table: {"columns": [...]}}``.

    Returns:
        ``{"sql": <str>}`` for CREATE_TABLE and TRANSFER, ``{"sql": None}`` for
        every other action.

    Raises:
        KeyError: If the step's source table is missing from ``mapping`` (or,
            for CREATE_TABLE, from ``src_schema``), or a column rule lacks ``tgt``.
    """
    action = step["action"]
    if action == "CREATE_TABLE":
        s_table = step["source_table"]; t_table = step["target_table"]
        col_map = mapping[s_table]["columns"]
        src_types = {c["name"]: (c["type"] or "TEXT") for c in src_schema["tables"][s_table]["columns"]}

        def guess_type(tgt_col, src_expr):
            # Name-based rules first, then the type of a plain source column,
            # and TEXT for anything derived from an expression.
            if tgt_col.lower().endswith("_id") or tgt_col in ("order_id","customer_id","is_active"):
                return "INTEGER"
            if tgt_col == "amount_usd":
                return "REAL"
            return src_types.get(src_expr, "TEXT")

        # Only the two known targets get a primary key.
        pk = "customer_id" if t_table == "dim_customer" else ("order_id" if t_table == "fact_order" else None)
        col_defs = []
        for cm in col_map:
            tgt_col = cm["tgt"]
            typ = guess_type(tgt_col, cm.get("src") or cm.get("expr",""))
            suffix = " PRIMARY KEY" if pk and tgt_col == pk else ""
            col_defs.append(f"{tgt_col} {typ}{suffix}")
        ddl = f"CREATE TABLE IF NOT EXISTS {t_table} (\n  " + ",\n  ".join(col_defs) + "\n)"
        return {"sql": ddl}

    if action == "TRANSFER":
        s_table = step["source_table"]; t_table = step["target_table"]
        col_map = mapping[s_table]["columns"]
        tgt_list = ", ".join(cm["tgt"] for cm in col_map)
        select_list = ", ".join(cm.get("expr") or cm["src"] for cm in col_map)
        # The path sits inside a single-quoted SQL literal, so double any
        # embedded quote; otherwise a path such as "it's.db" breaks the script.
        attach_path = src_db.replace("'", "''")
        sql = f"""
ATTACH DATABASE '{attach_path}' AS src;
INSERT INTO {t_table} ({tgt_list})
SELECT {select_list} FROM src.{s_table};
DETACH DATABASE src;
"""
        return {"sql": sql}

    # VALIDATE_COUNTS, CHECKSUM and unknown actions have no SQL to run.
    return {"sql": None}

# execute_sql_tool(db_path, sql): runs executescript against the target DB so multi-statement SQL works
def execute_sql_tool(db_path: str, sql: str) -> dict:
    """Run a (possibly multi-statement) SQL script against ``db_path``.

    Args:
        db_path: Path of the SQLite database to run against.
        sql: Script handed to ``executescript``.

    Returns:
        ``{"ok": True}`` when the script completed.

    Raises:
        sqlite3.Error: Whatever SQLite reports for a failing statement.
    """
    con = sqlite3.connect(db_path)
    try:
        # `with con` commits on success and rolls back on error.
        with con:
            con.cursor().executescript(sql)
    finally:
        # The context manager above does not close the connection. Closing it
        # here also releases a database the script attached but never detached
        # because a statement in between failed.
        con.close()
    return {"ok": True}

# validate_counts_tool and checksum_tool: basic validation primitives used later
def validate_counts_tool(src_db: str, tgt_db: str, s_table: str, t_table: str) -> Dict[str, Any]:
    """Compare the row count of a source table with that of its target table.

    Args:
        src_db: Path of the source database.
        tgt_db: Path of the target database.
        s_table: Source table name.
        t_table: Target table name.

    Returns:
        ``{"src_count": int, "tgt_count": int | None, "match": bool}``.
        ``tgt_count`` is ``None`` (and ``match`` False) when the target table
        cannot be counted, e.g. because it does not exist.
    """
    cs = sqlite3.connect(src_db)
    try:
        ct = sqlite3.connect(tgt_db)
        try:
            src_cnt = pd.read_sql_query(f"SELECT COUNT(*) AS n FROM {s_table}", cs)["n"].iloc[0]
            try:
                tgt_cnt = pd.read_sql_query(f"SELECT COUNT(*) AS n FROM {t_table}", ct)["n"].iloc[0]
            except Exception:
                # An uncountable target is reported as a mismatch, not raised.
                tgt_cnt = None
        finally:
            ct.close()
    finally:
        # sqlite3 connections are not closed by `with`, so close both explicitly.
        cs.close()
    return {"src_count": int(src_cnt), "tgt_count": None if tgt_cnt is None else int(tgt_cnt),
            "match": tgt_cnt is not None and int(src_cnt)==int(tgt_cnt)}

def checksum_tool(db_path: str, table: str, cols: Optional[List[str]]=None) -> Dict[str, Any]:
    """Compute a length-based checksum over the rows of ``table``.

    Each row is rendered as its column values joined by ``|`` (NULL shown as
    the text ``NULL``); the checksum is the sum of those string lengths.

    Args:
        db_path: Path of the SQLite database.
        table: Table to summarise.
        cols: Column names or SQL expressions to include. ``None``, ``"*"`` or
            ``["*"]`` means every column reported by ``PRAGMA table_info``.

    Returns:
        ``{"columns": <list used>, "len_sum": int}``; ``len_sum`` is 0 for an
        empty table.
    """
    con = sqlite3.connect(db_path)
    try:
        if cols is None or cols=="*" or cols==["*"]:
            cols = pd.read_sql_query(f"PRAGMA table_info('{table}')", con)["name"].tolist()
        expr = " || '|' || ".join(f"IFNULL({c},'NULL')" for c in cols)
        q = f"SELECT SUM(LENGTH({expr})) AS len_sum FROM {table}"
        raw = pd.read_sql_query(q, con)["len_sum"].iloc[0]
    finally:
        # One connection for both queries, closed explicitly because `with`
        # on a sqlite3 connection does not close it.
        con.close()
    # SUM over zero rows is NULL, which pandas may surface as None or NaN;
    # `raw or 0` lets NaN through and int(NaN) raises.
    len_sum = 0 if pd.isna(raw) else int(raw)
    return {"columns": cols, "len_sum": len_sum}

# Optional: clean target tables so reruns don’t hit PK collisions
def reset_targets(db_path: str):
    """Delete every row from the known target tables in ``db_path``.

    Best-effort clean slate so reruns do not hit primary-key collisions: a
    table that cannot be cleared (for example because it does not exist yet)
    is skipped instead of raising.

    Args:
        db_path: Path of the SQLite database holding the target tables.
    """
    con = sqlite3.connect(db_path)
    try:
        # `with con` only scopes the transaction (commit on success, rollback on
        # error); it does not close the connection, hence the `finally` below.
        with con:
            for t in ["dim_customer", "fact_order"]:
                try:
                    con.execute(f"DELETE FROM {t}")
                except sqlite3.Error:
                    # Nothing to clear for this table; keep going with the next.
                    pass
    finally:
        con.close()

# Empty the target tables before the migration runs so the TRANSFER inserts
# cannot collide with primary keys left over from an earlier run.
reset_targets(TGT_DB)



# ------------------ Tiny LLM utilities ------------------
# llm_json(prompt, default): calls the model and tries to parse the first JSON object
# in the reply. If there’s no key or bad JSON, returns default
def llm_json(prompt: str, default: Any) -> Any:
    """Ask the model for JSON and return the first JSON object in its reply.

    Args:
        prompt: Prompt sent to the chat model.
        default: Value returned when no LLM is configured or the reply holds
            no parseable JSON object.

    Returns:
        The decoded object (a dict), or ``default``.

    Note:
        Errors raised by the model call itself are not caught here.
    """
    if not USE_LLM:
        return default
    from langchain_openai import ChatOpenAI
    llm = ChatOpenAI(model=MODEL, temperature=0.2, openai_api_key=OPENAI_API_KEY)
    txt = llm.invoke(prompt).content
    if not isinstance(txt, str):
        return default
    start = txt.find("{")
    if start < 0:
        return default
    try:
        # raw_decode stops at the end of the first complete JSON value, so text
        # after the object (even text containing more braces) no longer breaks
        # parsing the way a first-"{"-to-last-"}" regex match did.
        obj, _end = json.JSONDecoder().raw_decode(txt, start)
    except (ValueError, RecursionError):
        return default
    return obj

# llm_text(prompt): calls the model and returns the text
def llm_text(prompt: str) -> str:
    """Return the model's plain-text reply to ``prompt``.

    When no LLM is configured (``USE_LLM`` is False) a fixed placeholder
    string is returned instead of calling the model.
    """
    if USE_LLM:
        from langchain_openai import ChatOpenAI
        chat = ChatOpenAI(model=MODEL, temperature=0.2, openai_api_key=OPENAI_API_KEY)
        reply = chat.invoke(prompt)
        return reply.content
    return "(narration skipped: no OPENAI_API_KEY in environment)"

In [None]:


# ------------------ Graph state and nodes ------------------
# The state type (MigState)
# This is the working memory the graph passes through nodes
# Keys you will look at:
# mapping: the mapping currently in use
# plan: the plan with steps
# sql_blocks: the generated SQLs that were actually executed
# exec_log and exec_errors: what ran and any errors
# validations: count and checksum results
# narration: agent-written summary at the end
# attempt: retry counter
# Nodes (this is where the work happens)

class MigState(TypedDict):
    """Working state that every graph node receives, updates and returns."""
    plan_id: str                     # Unique ID for this run (a uuid4 string in the initial state)
    attempt: int                     # Number of remediation passes so far; incremented by node_remediate
    src_schema: Dict[str, Any]       # Source database schema as returned by sqlite_schema_tool
    tgt_schema: Dict[str, Any]       # Target database schema as returned by sqlite_schema_tool
    mapping: Dict[str, Any]          # Source-table -> {"target", "columns"} mapping currently in use
    plan: Dict[str, Any]             # Plan from propose_plan_tool: {"plan_id", "steps", "mapping"}
    sql_blocks: List[Dict[str, Any]] # SQL generated for the plan steps that have SQL
    exec_log: List[Dict[str, Any]]   # Steps node_execute ran successfully on its latest pass
    exec_errors: List[str]           # Error messages node_execute captured on its latest pass
    validations: List[Dict[str, Any]]# Row-count and checksum reports from node_validate
    narration: Optional[str]         # Summary text written by node_narrate

# 🧱 node_schema reads the structure (schema) of both the source and target databases.
# It saves what tables and columns exist in each one into MigState.
# Later steps use this info to plan how data should be transferred.
def node_schema(state: MigState) -> MigState:
    """Store the introspected source and target schemas in the state."""
    for state_key, db_path in (("src_schema", SRC_DB), ("tgt_schema", TGT_DB)):
        state[state_key] = sqlite_schema_tool(db_path)
    return state

# 🗺️ node_mapplan figures out how tables and columns from the source match up with the target.
# It asks the LLM for a mapping and falls back to the deterministic mapping when no usable answer comes back.
# The step list (CREATE_TABLE, TRANSFER, VALIDATE_COUNTS, CHECKSUM) is built afterwards by node_plan.
def node_mapplan(state: MigState) -> MigState:
    """Ask the LLM for a source-to-target mapping, falling back to the heuristic.

    The LLM answer is kept only when it has the shape the SQL generator relies
    on; anything else is replaced by the deterministic fallback so a malformed
    reply cannot surface later as a KeyError.

    Args:
        state: Graph state; reads ``src_schema`` and ``tgt_schema``.

    Returns:
        The same state with ``mapping`` set.
    """
    def _usable(candidate) -> bool:
        # Shape check mirroring what generate_sql_tool reads: a known source
        # table, a target name, and column rules with "tgt" plus "expr" or "src".
        if not isinstance(candidate, dict) or not candidate:
            return False
        known_sources = state["src_schema"].get("tables", {})
        for source_table, spec in candidate.items():
            if source_table not in known_sources or not isinstance(spec, dict):
                return False
            if not isinstance(spec.get("target"), str) or not spec["target"]:
                return False
            columns = spec.get("columns")
            if not isinstance(columns, list) or not columns:
                return False
            for rule in columns:
                if not isinstance(rule, dict) or not rule.get("tgt"):
                    return False
                if not (rule.get("expr") or rule.get("src")):
                    return False
        return True

    fallback = infer_mapping_with_derivations(state["src_schema"], state["tgt_schema"])["proposed_mapping"]
    prompt = f"""
You are a data migration agent. Propose mapping JSON in this structure:
{{ "<source_table>": {{ "target":"<target_table>", "columns":[{{"src":"...", "tgt":"...", "expr":"..."}}] }} }}
Use valid SQLite expressions. Include derived fields where needed.
Source schema:
{json.dumps(state["src_schema"], indent=2)}
Target schema:
{json.dumps(state["tgt_schema"], indent=2)}
"""
    proposed = llm_json(prompt, default=fallback)
    # Tolerate a reply wrapped as {"proposed_mapping": ...}, the shape that
    # infer_mapping_with_derivations itself returns.
    if isinstance(proposed, dict) and isinstance(proposed.get("proposed_mapping"), dict):
        proposed = proposed["proposed_mapping"]
    state["mapping"] = proposed if _usable(proposed) else fallback
    return state

def node_plan(state: MigState) -> MigState:
    """Build the step plan for the current mapping and store it under ``plan``."""
    wrapped_mapping = {"proposed_mapping": state["mapping"]}
    state["plan"] = propose_plan_tool(state["src_schema"], state["tgt_schema"], wrapped_mapping)
    return state

# 🧪 node_sqlgen generates the actual SQL commands for each step in the plan.
# Example: it creates CREATE TABLE statements and INSERT ... SELECT queries for data transfer.
# These SQL statements will later be executed to perform the migration.
def node_sqlgen(state: MigState) -> MigState:
    """Generate SQL for every plan step that has any and store it in ``sql_blocks``.

    Steps are numbered from 1 in plan order; steps whose generator returns no
    SQL are left out of the result.
    """
    plan = state["plan"]
    generated = []
    for position, plan_step in enumerate(plan["steps"], 1):
        rendered = generate_sql_tool(SRC_DB, TGT_DB, plan_step, state["src_schema"], plan["mapping"])
        if not rendered.get("sql"):
            continue
        generated.append({
            "i": position,
            "action": plan_step["action"],
            "sql": rendered["sql"],
            "source_table": plan_step.get("source_table"),
            "target_table": plan_step.get("target_table"),
        })
    state["sql_blocks"] = generated
    return state

# ⚙️ node_execute runs the CREATE_TABLE and TRANSFER statements created earlier against the target database.
# Each success is recorded in exec_log and each failure is captured as a message in exec_errors.
# Row-count and checksum validation happens afterwards, in node_validate.
def node_execute(state: MigState) -> MigState:
    """Run the generated CREATE_TABLE and TRANSFER SQL against the target database.

    Successes are recorded in ``exec_log`` and failures in ``exec_errors``;
    both lists describe this pass only. On a retry pass, a target table that
    the previous pass already loaded is emptied before its TRANSFER runs
    again, so the re-insert does not fail on the rows inserted earlier.

    Args:
        state: Graph state; reads ``sql_blocks``, ``attempt`` and the previous
            ``exec_log``.

    Returns:
        The same state with ``exec_log`` and ``exec_errors`` replaced.
    """
    # Targets loaded by the previous pass. Always empty on the first pass.
    already_loaded = set()
    if state.get("attempt", 0) > 0:
        already_loaded = {
            entry.get("target_table")
            for entry in (state.get("exec_log") or [])
            if entry.get("action") == "TRANSFER" and entry.get("target_table")
        }

    state["exec_log"] = []
    state["exec_errors"] = []
    for b in state["sql_blocks"]:
        if b["action"] not in ("CREATE_TABLE","TRANSFER"):
            continue
        try:
            if b["action"] == "TRANSFER" and b.get("target_table") in already_loaded:
                # Drop the rows from the earlier pass; without this the retry
                # hits "UNIQUE constraint failed" on every primary key.
                execute_sql_tool(TGT_DB, f"DELETE FROM {b['target_table']};")
            execute_sql_tool(TGT_DB, b["sql"])
            # target_table is recorded so a later retry knows what was loaded.
            state["exec_log"].append({"i": b["i"], "action": b["action"],
                                      "target_table": b.get("target_table")})
        except Exception as e:
            state["exec_errors"].append(f"Step {b['i']} {b['action']} error: {e}")
    return state

def _all_ok(vs: List[Dict[str,Any]]) -> bool:
    """Return True when every counts/checksum report in ``vs`` has a truthy ``match``.

    Reports of any other type are ignored; an empty list counts as OK.
    """
    verdicts = [bool(v.get("match", False)) for v in vs if v["type"] in ("counts", "checksum")]
    return all(verdicts)

def node_validate(state: MigState) -> MigState:
    """Run the plan's VALIDATE_COUNTS and CHECKSUM steps and store the reports.

    The source-side checksum is computed over the mapping's expressions and
    the target side over the mapped target columns, so both sides summarise
    the same transformed values. Comparing raw source columns with transformed
    target columns reports a mismatch even when the data moved correctly.

    Args:
        state: Graph state; reads ``plan`` (its ``steps`` and ``mapping``).

    Returns:
        The same state with ``validations`` replaced. A checksum that cannot
        be computed is reported with ``match`` False and an ``error`` message
        instead of aborting the graph.
    """
    mapping = state["plan"].get("mapping") or {}
    reports = []
    for step in state["plan"]["steps"]:
        if step["action"] == "VALIDATE_COUNTS":
            r = validate_counts_tool(SRC_DB, TGT_DB, step["source_table"], step["target_table"])
            reports.append({"type":"counts","src":step["source_table"],"tgt":step["target_table"], **r})
        if step["action"] == "CHECKSUM":
            s_table, t_table = step["source_table"], step["target_table"]
            col_map = (mapping.get(s_table) or {}).get("columns") or []
            # Same expression choice as the TRANSFER statement: "expr" first, then "src".
            src_exprs = [cm.get("expr") or cm.get("src") for cm in col_map]
            tgt_cols = [cm.get("tgt") for cm in col_map]
            if not col_map or not all(src_exprs) or not all(tgt_cols):
                # No usable column rules: fall back to every column on each side.
                src_exprs = tgt_cols = None
            report = {"type":"checksum","src":s_table,"tgt":t_table}
            try:
                s = checksum_tool(SRC_DB, s_table, src_exprs)
                t = checksum_tool(TGT_DB, t_table, tgt_cols)
                report.update({"src_len_sum": s["len_sum"], "tgt_len_sum": t["len_sum"],
                               "match": s["len_sum"]==t["len_sum"]})
            except Exception as e:
                # e.g. a missing target table or an invalid mapping expression.
                report.update({"src_len_sum": None, "tgt_len_sum": None,
                               "match": False, "error": str(e)})
            reports.append(report)
    state["validations"] = reports
    return state

def node_remediate(state: MigState) -> MigState:
    """Ask the LLM to patch the mapping after a failed run (one retry at most).

    Returns the state untouched when the retry has already been used or when
    nothing failed. Otherwise the errors, validations, current mapping and
    both schemas are sent to the LLM; a non-empty dict reply replaces the
    mapping, and ``attempt`` is incremented either way.
    """
    retry_spent = state["attempt"] >= 1
    if retry_spent:
        return state

    healthy = len(state["exec_errors"]) == 0 and _all_ok(state["validations"])
    if healthy:
        return state

    def pretty(value):
        # Indented JSON rendering used for every section of the prompt.
        return json.dumps(value, indent=2)

    prompt = f"""
Execution or validation failed. Patch the mapping so it works.
Errors:
{pretty(state["exec_errors"])}
Validations:
{pretty(state["validations"])}
Current mapping:
{pretty(state["mapping"])}
Source schema:
{pretty(state["src_schema"])}
Target schema:
{pretty(state["tgt_schema"])}
Return only a JSON mapping object: {{ "<source_table>": {{ "target":"...", "columns":[{{"src":"...","tgt":"...","expr":"..."}}] }} }}
"""
    patched = llm_json(prompt, default=state["mapping"])
    if isinstance(patched, dict) and patched:
        state["mapping"] = patched
    state["attempt"] += 1
    return state

# 📜 node_narrate writes a simple human-readable summary of everything that happened:
# how the migration went, what steps ran, if validations passed, and if the plan succeeded.
# This is helpful for debugging or reporting what the agents did.
def node_narrate(state: MigState) -> MigState:
    """Store an LLM-written summary of the run under ``narration``.

    The run is labelled OK only when there are no execution errors and every
    validation passed; otherwise FAIL. The summary JSON handed to the LLM
    holds the plan id, attempt count, execution log, errors, validations and
    that final status.
    """
    no_errors = len(state["exec_errors"]) == 0
    if no_errors and _all_ok(state["validations"]):
        status = "OK"
    else:
        status = "FAIL"

    summary = {
        "plan_id": state["plan_id"],
        "attempts": state["attempt"],
        "exec_log": state["exec_log"],
        "exec_errors": state["exec_errors"],
        "validations": state["validations"],
        "final_status": status
    }
    instructions = (
        "You are the Coordinator. Summarize the migration in <=160 words using ONLY this JSON. "
        "End with 'STATUS: OK' or 'STATUS: FAIL'.\n\n"
    )
    state["narration"] = llm_text(instructions + json.dumps(summary, indent=2))
    return state

In [None]:


# ------------------ Graph wiring ------------------

graph = StateGraph(MigState)

# Register every node under the name the edges refer to.
for _node_name, _node_fn in (
    ("get_schema", node_schema),
    ("propose_mapping", node_mapplan),
    ("build_plan", node_plan),
    ("gen_sql", node_sqlgen),
    ("run_sql", node_execute),
    ("do_validate", node_validate),
    ("remediate", node_remediate),
    ("narrate", node_narrate),
):
    graph.add_node(_node_name, _node_fn)

graph.set_entry_point("get_schema")

# Straight-line part of the flow: schema -> mapping -> plan -> SQL -> run -> validate.
_linear_path = ["get_schema", "propose_mapping", "build_plan", "gen_sql", "run_sql", "do_validate"]
for _from_node, _to_node in zip(_linear_path, _linear_path[1:]):
    graph.add_edge(_from_node, _to_node)

def after_validate(state: MigState) -> str:
    """Pick the node that follows validation.

    Returns ``"remediate"`` when the run had execution errors or failed
    validations and no retry has been used yet; ``"narrate"`` otherwise.
    """
    had_errors = len(state["exec_errors"]) > 0
    failed = had_errors or not _all_ok(state["validations"])
    if failed and state["attempt"] < 1:
        return "remediate"
    return "narrate"

# After validation, after_validate returns "remediate" or "narrate"; the dict
# maps each returned label to the node of the same name.
graph.add_conditional_edges("do_validate", after_validate, {
    "remediate": "remediate",
    "narrate": "narrate"
})
# A remediated mapping re-enters the flow at planning, so plan, SQL generation,
# execution and validation all run again.
graph.add_edge("remediate", "build_plan")
# Narration is the last node before the graph ends.
graph.add_edge("narrate", END)

# `app` is the compiled graph that the run cell invokes.
app = graph.compile()

In [None]:


# ------------------ Run once ------------------

# Starting state: a fresh run id, zero attempts, and empty placeholders that
# the graph nodes fill in as they run.
initial: MigState = MigState(
    plan_id=str(uuid.uuid4()),
    attempt=0,
    src_schema={},
    tgt_schema={},
    mapping={},
    plan={},
    sql_blocks=[],
    exec_log=[],
    exec_errors=[],
    validations=[],
    narration=None,
)

In [None]:

# Run the whole graph once and print a summary of the outcome.
final_state = app.invoke(initial)

print("Plan ID:", final_state["plan_id"])
print("Attempts:", final_state["attempt"])
print("Executed statements:", len(final_state["exec_log"]))
print("Exec errors:", final_state["exec_errors"])
print("Validations:")
for v in final_state["validations"]:
    print(" ", v)
print("\nNarration:\n", final_state["narration"])

# Show the migrated rows. The connection is closed explicitly because using a
# sqlite3 connection as a context manager does not close it.
con = sqlite3.connect(TGT_DB)
try:
    print("\n-- dim_customer")
    print(pd.read_sql_query("SELECT * FROM dim_customer ORDER BY customer_id", con))
    print("\n-- fact_order")
    print(pd.read_sql_query("SELECT * FROM fact_order ORDER BY order_id", con))
finally:
    con.close()


Plan ID: 7ec13e78-1663-4f39-83a2-074a315420cb
Attempts: 1
Executed statements: 0
Exec errors: ['Step 1 TRANSFER error: UNIQUE constraint failed: dim_customer.customer_id', 'Step 4 TRANSFER error: UNIQUE constraint failed: fact_order.order_id']
Validations:
  {'type': 'counts', 'src': 'customers', 'tgt': 'dim_customer', 'src_count': 3, 'tgt_count': 3, 'match': True}
  {'type': 'checksum', 'src': 'customers', 'tgt': 'dim_customer', 'src_len_sum': 130, 'tgt_len_sum': 130, 'match': True}
  {'type': 'counts', 'src': 'orders', 'tgt': 'fact_order', 'src_count': 3, 'tgt_count': 3, 'match': True}
  {'type': 'checksum', 'src': 'orders', 'tgt': 'fact_order', 'src_len_sum': 89, 'tgt_len_sum': 64, 'match': False}

Narration:
 (narration skipped: no OPENAI_API_KEY in environment)

-- dim_customer
   customer_id    full_name                 email signup_date  is_active
0            1  Amina Alavi     amina@example.com  2023-01-15          1
1            2      Sam Lee   sam.lee@example.com  2022-11-0