In [None]:
#from telegram_toolchain.data.database import get_conn
from telegram_data_models import Message, Chat, MessageTextContent, Queue
from dotenv import load_dotenv
load_dotenv()   # loads .env from cwd (or parents)
load_dotenv("../../credentials/credentials.env")
from sqlalchemy import select, func, case, create_engine
from tqdm.auto import tqdm  # works in both notebooks & terminals
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.dataset as ds
import pandas as pd
import os
import time
from pathlib import Path
import json
import glob
%pip install duckdb
import duckdb
import shutil

In [None]:
# Database setup
#
# Connection settings are read from the environment (populated by the
# load_dotenv() calls in the import cell).
from urllib.parse import quote_plus

db_user = os.environ.get("DB_USER")
db_pass = os.environ.get("DB_PASSWORD")
db_host = os.environ.get("DB_HOST")
db_port = os.environ.get("DB_PORT")
db_name = os.environ.get("DB_NAME")

# Fail fast with a readable message instead of silently embedding the text
# "None" in the URL and failing later inside create_engine()/connect().
_missing_db_settings = [
    env_name
    for env_name, env_value in (
        ("DB_USER", db_user),
        ("DB_HOST", db_host),
        ("DB_PORT", db_port),
        ("DB_NAME", db_name),
    )
    if env_value is None
]
if _missing_db_settings:
    raise RuntimeError(
        "Missing database settings in environment: " + ", ".join(_missing_db_settings)
    )

# User name and password are percent-encoded so characters such as '@', '/'
# or ':' cannot corrupt the URL. The password part is left out entirely when
# DB_PASSWORD is not set.
_db_credentials = quote_plus(db_user)
if db_pass is not None:
    _db_credentials += ":" + quote_plus(db_pass)

db_url = f"postgresql+psycopg2://{_db_credentials}@{db_host}:{db_port}/{db_name}"

# Dask can't work with ORM models
# -> keep handles to the underlying Core Table objects instead.
message_table = Message.__table__
chat_table = Chat.__table__
queue_table = Queue.__table__

In [None]:
# NOTE(review): this cell repeats the db_url / table assignments of the
# preceding setup cell; consider deleting one of the two.
from urllib.parse import quote_plus

# Percent-encode the credentials so characters such as '@', '/' or ':' in the
# user name or password cannot corrupt the URL. The password part is left out
# when no password was configured.
_db_credentials = quote_plus(str(db_user))
if db_pass is not None:
    _db_credentials += ":" + quote_plus(db_pass)

db_url = f"postgresql+psycopg2://{_db_credentials}@{db_host}:{db_port}/{db_name}"

# Dask can't work with ORM models
# -> keep handles to the underlying Core Table objects instead.
message_table = Message.__table__
chat_table = Chat.__table__
queue_table = Queue.__table__

In [None]:
# Engine shared by the streaming export below.
engine_options = {
    # Test pooled connections before handing them out (good for long streaming jobs).
    "pool_pre_ping": True,
    "future": True,
}
engine = create_engine(db_url, **engine_options)

In [None]:
# Per-chat language table. The rest of the notebook expects `chat_id` to be a
# regular column, so move it out of the index when it was saved that way.
df_lang = pd.read_parquet("../../data/chat_languages.parquet")

chat_id_only_in_index = df_lang.index.name == "chat_id" and "chat_id" not in df_lang.columns
if chat_id_only_in_index:
    df_lang = df_lang.reset_index()

In [None]:
# ------------------- 1) Detailed edge SQL statement -------------------------
#
# One output row per forwarded message:
#   src     forward_from_chat_id if present, else forward_from_id
#   dst     chat_id (the chat the message was posted in)
#   sender  sender_chat_id if present, else from_user_id
#   ts      forward_date truncated to whole seconds
#   src_is_chat / sender_is_chat
#           1 when the id came from the *_chat_id column, 0 otherwise

msg_cols = message_table.c

# Reusable id expressions; each prefers the chat column and falls back to the
# user column. Both inputs can be NULL at once, so NULL results are filtered
# out in the statement below.
src_expr = func.coalesce(msg_cols.forward_from_chat_id, msg_cols.forward_from_id)
sender_expr = func.coalesce(msg_cols.sender_chat_id, msg_cols.from_user_id)

# 0/1 tags recording which of the two columns supplied the id.
src_is_chat_flag = case((msg_cols.forward_from_chat_id.isnot(None), 1), else_=0)
sender_is_chat_flag = case((msg_cols.sender_chat_id.isnot(None), 1), else_=0)

edge_columns = [
    msg_cols.id.label("msg_id"),
    src_expr.label("src"),
    msg_cols.chat_id.label("dst"),
    sender_expr.label("sender"),
    func.date_trunc("second", msg_cols.forward_date).label("ts"),
    src_is_chat_flag.label("src_is_chat"),
    sender_is_chat_flag.label("sender_is_chat"),
]

stmt_detailed_edges = select(*edge_columns).where(
    # Keep only forwarded messages.
    msg_cols.forward_date.isnot(None),
    # src and sender must never be NULL
    # (required so pandas can safely cast them to uint32).
    src_expr.isnot(None),
    sender_expr.isnot(None),
)

# ------------------- 2) Lang map from df_lang -------------------------------

df_lang = df_lang.copy()

# Validate before narrowing: astype("uint32") would silently wrap ids that are
# negative or larger than 2**32 - 1, and a wrapped id would then be matched
# against the wrong chat.
_lang_chat_ids = df_lang["chat_id"].astype("int64")
if not _lang_chat_ids.between(0, 2**32 - 1).all():
    raise ValueError(
        "df_lang['chat_id'] contains values outside the uint32 range "
        f"(min={_lang_chat_ids.min()}, max={_lang_chat_ids.max()})"
    )
df_lang["chat_id"] = _lang_chat_ids.astype("uint32")

# Series.map() needs a uniquely indexed lookup; check here so the problem
# surfaces before the long-running export starts rather than inside it.
if df_lang["chat_id"].duplicated().any():
    raise ValueError("df_lang contains duplicate chat_id values; cannot build lang_map")

lang_map = df_lang.set_index("chat_id")["lang"]  # chat_id -> lang (string/NA)

# ------------------- 3) Parquet output schema -------------------------------

# (column name, Arrow type) pairs, in the order the columns are written.
edge_fields = [
    ("msg_id", pa.uint64()),
    ("src", pa.uint32()),
    ("dst", pa.uint32()),
    ("sender", pa.uint32()),
    # second precision
    ("ts", pa.timestamp("s")),
    # 1 if src came from forward_from_chat_id else 0
    ("src_is_chat", pa.uint8()),
    # 1 if sender came from sender_chat_id else 0
    ("sender_is_chat", pa.uint8()),
    # 1 if src is a chat that has a language in lang_map, else 0
    ("primary", pa.uint8()),
    # language label, derived per chunk in the export loop
    ("lang", pa.string()),
]
edge_schema = pa.schema(edge_fields)

# Directory that receives the numbered part files plus the checkpoint file.
out_dir = Path("../../data/timed_edge_list")
out_dir.mkdir(parents=True, exist_ok=True)

# The checkpoint holds the highest msg_id already exported; start from 0 when
# there is no checkpoint yet.
last_id_path = out_dir / "_last_id.txt"
if last_id_path.exists():
    last_id = int(last_id_path.read_text())
else:
    last_id = 0
print(f"Resuming from msg_id > {last_id}", flush=True)

# Apply the resume filter; ordering by id keeps the checkpoint monotonic.
stmt_run = (
    stmt_detailed_edges
    .where(message_table.c.id > last_id)
    .order_by(message_table.c.id)
)

In [None]:
# Rows fetched per pandas chunk; each chunk becomes one parquet part file.
chunksize = 1_000_000

# Append-only checkpoint log: one JSON record per written part.
manifest_path = out_dir.joinpath("_manifest.jsonl")


def next_part_id():
    """Return the number to use for the next ``part-*.parquet`` file in ``out_dir``.

    The result is one more than the highest number found in the existing part
    file names, or 1 when the directory holds no part files yet.
    """
    highest = 0
    for part_file in out_dir.glob("part-*.parquet"):
        # File stems look like "part-000123"; the number follows the first dash.
        number = int(part_file.stem.split("-")[1])
        if number > highest:
            highest = number
    return highest + 1


part_id = next_part_id()
print(f"Resuming at part_id={part_id}", flush=True)

# Target dtypes for the integer columns, and the largest value each can hold.
edge_dtypes = {
    "msg_id": "uint64",
    "src": "uint32",
    "dst": "uint32",
    "sender": "uint32",
    "src_is_chat": "uint8",
    "sender_is_chat": "uint8",
}
uint_max = {"uint8": 2**8 - 1, "uint32": 2**32 - 1, "uint64": 2**64 - 1}


def _cast_unsigned(df_chunk, dtypes):
    """Cast the columns in ``dtypes`` to unsigned ints, refusing lossy casts.

    ``astype`` silently wraps values that do not fit the target type, which
    would write wrong ids to disk. Raise ``ValueError`` instead when a column
    holds NULLs or values outside the target range.
    """
    for column, dtype in dtypes.items():
        values = df_chunk[column]
        if values.isna().any():
            raise ValueError(f"Column {column!r} contains NULLs; cannot cast to {dtype}")
        lowest, highest = int(values.min()), int(values.max())
        if lowest < 0 or highest > uint_max[dtype]:
            raise ValueError(
                f"Column {column!r} does not fit {dtype}: min={lowest}, max={highest}"
            )
    return df_chunk.astype(dtypes)


def _add_lang_columns(df_chunk, lang_map):
    """Add the ``primary`` and ``lang`` columns to ``df_chunk`` and return it.

    ``primary`` is 1 when the source is a chat with a known language.
    ``lang`` is:
      * the shared language, for primary edges whose src and dst languages match;
      * the literal string "NA", for all other primary edges;
      * the destination chat's language (missing if unknown), for non-primary edges.
    """
    src_lang = df_chunk["src"].map(lang_map)
    dst_lang = df_chunk["dst"].map(lang_map)
    primary = ((df_chunk["src_is_chat"] == 1) & src_lang.notna()).astype("uint8")

    lang_col = pd.Series(pd.NA, index=df_chunk.index, dtype="string")
    same_lang = src_lang.eq(dst_lang) & src_lang.notna()
    mask_primary_same = (primary == 1) & same_lang
    lang_col[mask_primary_same] = src_lang[mask_primary_same].astype("string")
    mask_primary_other = (primary == 1) & ~mask_primary_same
    lang_col[mask_primary_other] = "NA"
    mask_not_primary = primary == 0
    lang_col[mask_not_primary] = dst_lang[mask_not_primary].astype("string")

    df_chunk["primary"] = primary
    df_chunk["lang"] = lang_col
    return df_chunk


total_rows = 0
t0 = time.time()  # job start time
last_report = t0

# Start the size tracker from the parts already on disk so "dir_GB" reflects
# the whole directory after a resume.
total_bytes = sum(p.stat().st_size for p in out_dir.glob("part-*.parquet"))
size_gb = total_bytes / (1024**3)

# NOTE(review): a part file is written before the checkpoint is advanced, so a
# crash between those two steps makes the next run export that chunk again.
with tqdm(unit="rows", dynamic_ncols=True) as pbar:
    with engine.connect().execution_options(stream_results=True) as conn:
        for df_chunk in pd.read_sql(stmt_run, conn, chunksize=chunksize):
            # read_sql can yield an empty frame when the query returns no rows
            # (e.g. resuming with nothing new). Skip it: otherwise an empty
            # part would be written and int(NaN) below would raise.
            if df_chunk.empty:
                continue

            df_chunk = _cast_unsigned(df_chunk, edge_dtypes)
            df_chunk["ts"] = pd.to_datetime(df_chunk["ts"], errors="coerce").dt.floor("s")
            df_chunk = _add_lang_columns(df_chunk, lang_map)

            table = pa.Table.from_pandas(df_chunk, schema=edge_schema, preserve_index=False)

            # Write to a hidden temp name first, then rename, so a part file
            # with its final name is always complete.
            part_path = out_dir / f"part-{part_id:06d}.parquet"
            tmp_path = out_dir / f".part-{part_id:06d}.parquet.tmp"

            pq.write_table(table, tmp_path, compression="zstd", use_dictionary=True)
            os.replace(tmp_path, part_path)

            # update size tracker
            total_bytes += part_path.stat().st_size
            size_gb = total_bytes / (1024**3)

            # Checkpoint after the successful write. The same temp-then-rename
            # step keeps the checkpoint from being left empty or truncated.
            chunk_last_id = int(df_chunk["msg_id"].max())
            checkpoint_tmp = last_id_path.with_name(last_id_path.name + ".tmp")
            checkpoint_tmp.write_text(str(chunk_last_id))
            os.replace(checkpoint_tmp, last_id_path)

            # manifest
            rec = {
                "part": part_id,
                "file": part_path.name,
                "rows": int(table.num_rows),
                "written_at_unix": time.time(),
            }
            with open(manifest_path, "a", encoding="utf-8") as f:
                f.write(json.dumps(rec) + "\n")

            part_id += 1

            # progress
            n = table.num_rows
            total_rows += n
            pbar.update(n)

            # Refresh the summary at most every 30 seconds.
            now = time.time()
            if now - last_report >= 30:
                elapsed = now - t0  # measured since job start
                rps = total_rows / elapsed if elapsed > 0 else 0.0
                pbar.set_postfix(
                    {
                        "parts": part_id - 1,
                        "rows_M": f"{total_rows/1e6:.1f}",
                        "rows/s": f"{rps:,.0f}",
                        "dir_GB": f"{size_gb:.2f}",
                    }
                )
                last_report = now

In [None]:
out_dir = Path("../../data/timed_edge_list").resolve()
parts_glob = str(out_dir / "part-*.parquet")

# Write output in the parent directory of timed_edge_list
final_path = (out_dir.parent / "edges_sorted.parquet").resolve()
# The sorted file is written under a temporary name and only renamed into
# place once verified, so a failed run never leaves a truncated
# edges_sorted.parquet or clobbers a previous good one.
tmp_final_path = final_path.with_name(final_path.name + ".tmp")

# --------- 0) Safety checks, all done before any heavy work ----------
# Ensure we are not writing inside the directory we plan to delete
if out_dir == final_path or out_dir in final_path.parents:
    raise RuntimeError(f"Refusing: final_path={final_path} is inside out_dir={out_dir}")

# Refuse to delete something too shallow like /, /home, etc.
if len(out_dir.parts) < 3:
    raise RuntimeError(f"Refusing to delete suspiciously short path: {out_dir}")

# Give a clear error when there is nothing to merge (for example when this
# cell is re-run after the directory was already cleaned up).
if not any(out_dir.glob("part-*.parquet")):
    raise RuntimeError(f"No part files match {parts_glob}. Aborting.")


def _sql_quote(value):
    """Render ``value`` as a single-quoted SQL string literal."""
    return "'" + str(value).replace("'", "''") + "'"


parts_sql = _sql_quote(parts_glob)
tmp_final_sql = _sql_quote(tmp_final_path)

# (Optional) temp dir for DuckDB spills (NOT inside out_dir, which gets deleted):
# set it with PRAGMA temp_directory on the connection below.

con = duckdb.connect()
try:
    con.execute("PRAGMA threads=8;")  # tune to your CPU
    con.execute("PRAGMA preserve_insertion_order=false;")

    # --------- 1) Count rows in all part files ----------
    src_rows = con.execute(
        f"""
        SELECT COUNT(*)::BIGINT
        FROM read_parquet({parts_sql})
    """
    ).fetchone()[0]

    if src_rows == 0:
        raise RuntimeError(f"No rows found in {parts_glob}. Aborting.")

    print(f"Found {src_rows:,} rows across partial parquet files.")

    # --------- 2) Write sorted parquet (ALL columns) to the temp file ----------
    con.execute(
        f"""
    COPY (
        SELECT *
        FROM read_parquet({parts_sql})
        ORDER BY src, sender, ts
    )
    TO {tmp_final_sql}
    (FORMAT PARQUET, COMPRESSION ZSTD);
    """
    )

    # --------- 3) Verify row count in the output ----------
    out_rows = con.execute(
        f"""
        SELECT COUNT(*)::BIGINT
        FROM read_parquet({tmp_final_sql})
    """
    ).fetchone()[0]

    print(f"Output rows: {out_rows:,}")

    if out_rows != src_rows:
        raise RuntimeError(
            f"Row-count mismatch! parts={src_rows:,} output={out_rows:,}. " f"NOT deleting {out_dir}"
        )

    print("Row counts match ✅")

    # Publish the verified file under its final name.
    os.replace(tmp_final_path, final_path)
    print(f"Wrote sorted parquet: {final_path}")
finally:
    con.close()
    # Drop a partial or unverified temp file left behind by a failure.
    if tmp_final_path.exists():
        tmp_final_path.unlink()

# --------- 4) Delete the entire timed_edge_list directory ----------
# Only reached when the sorted file was verified and moved into place.
shutil.rmtree(out_dir)
print(f"Deleted directory: {out_dir}")

print("Finalization complete.")