This notebook creates reproducible train/test FileCollections for parent-level k-NN experiments.
It uses a dataclass to embed a descriptive, structured provenance into each FileCollection.

In [1]:

import logging
import random
from dataclasses import dataclass, asdict, field
from datetime import datetime
from typing import List, Optional, Dict, Any, Iterable, Set

from sqlalchemy import and_, exists
from sqlalchemy.orm import Session

from db import get_db_engine
from db.models import (
    FilingTag, File, FileContent, FileCollection, FileCollectionMember, FileTagLabel
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("knn.splitter")


In [2]:
# %% [markdown]
# ## Configuration
# Adjust these knobs before running the build.

# %%
@dataclass
class CollectionProvenance:
    purpose: str
    parents: List[str]
    strategy: str                       # e.g., "stratified-by-child-capped"
    embedding_column: str               # "minilm_emb" | "mpnet_emb"
    split_ratio: float                  # train portion, e.g., 0.8
    per_child_cap: int
    filetype_filter: Optional[List[str]] = field(default_factory=list)  # e.g., ["pdf","docx","xlsx"] or []
    include_descendants: bool = True
    random_seed: Optional[int] = None
    created_utc: str = field(default_factory=lambda: datetime.utcnow().isoformat() + "Z")

    def to_metadata(self) -> Dict[str, Any]:
        return {
            "provenance": asdict(self),
            "counts": {}  # you can fill this after sampling (totals per parent/child)
        }

    def to_description(self) -> str:
        ft = ", ".join(self.filetype_filter) if self.filetype_filter else "all"
        train_pct = int(self.split_ratio * 100)
        test_pct = 100 - train_pct
        return f"""# Collection purpose
{self.purpose}

# Creation details
- Created (UTC): {self.created_utc}
- Scope (parents): {", ".join(self.parents)}
- Strategy: {self.strategy}
- Embedding column: {self.embedding_column}
- Train/Test split: {train_pct}/{test_pct}
- Per-child cap: {self.per_child_cap}
- Descendants included: {self.include_descendants}
- Filetypes: {ft}
- Random seed: {self.random_seed if self.random_seed is not None else "None"}

# Notes
- Files were sampled per child tag (capped) then split into train/test.
- Use this collection as the neighbor pool (train) or evaluation set (test).
"""


In [3]:
# %% [markdown]
# ## Configuration
# Adjust these knobs before running the build.

# %%
PARENTS = ['B', 'C', 'D', 'E', 'F', 'G', 'H']   # exclude 'A'
EMBED_COL = "minilm_emb"                        # or "mpnet_emb"
TRAIN_RATIO = 0.80
PER_CHILD_CAP = 50
USE_ALL_DESCENDANTS = True                      # True: children + grandchildren; False: direct children only
RANDOM_SEED = 42                                # set None for non-reproducible
FILETYPE_FILTER: Optional[List[str]] = []       # e.g., ["pdf","docx","xlsx"]; empty -> all
CREATE_PARENT_LEVEL_COLLECTIONS = True          # also aggregate child files into *_parent_train/test

PURPOSE_TEXT = (
    "Parent-level k-NN experiment (B–H). These collections provide reproducible "
    "train/test splits stratified by child tags with per-child caps."
)
STRATEGY_TEXT = "stratified-by-child-capped"


In [4]:
# %% [markdown]
# ## DB helpers

# %%
def list_direct_children(session: Session, parent_label: str) -> List[FilingTag]:
    return session.query(FilingTag).filter(FilingTag.parent_label == parent_label).all()

def list_all_descendants(session: Session, root_label: str) -> List[FilingTag]:
    """Breadth-first traversal to list all descendants under a parent (children, grandchildren, ...)."""
    descendants, queue = [], [root_label]
    seen: Set[str] = set([root_label])
    while queue:
        parent = queue.pop(0)
        kids = session.query(FilingTag).filter(FilingTag.parent_label == parent).all()
        for k in kids:
            if k.label not in seen:
                seen.add(k.label)
                descendants.append(k)
                queue.append(k.label)
    return descendants

def files_with_tag_and_embedding(
    session: Session,
    tag_label: str,
    embed_col: str = EMBED_COL,
    filetype_filter: Optional[List[str]] = None
) -> List[File]:
    """
    Files that have FileTagLabel.tag == tag_label and a non-null embedding.
    Optional filetype filtering by File.extension (case-insensitive, no dot).
    """
    emb_col = getattr(FileContent, embed_col)
    q = (
        session.query(File)
        .join(File.content)               # File -> FileContent
        .filter(emb_col.isnot(None))
        .filter(
            exists().where(
                and_(FileTagLabel.file_id == File.id, FileTagLabel.tag == tag_label)
            )
        )
    )
    if filetype_filter:
        lowered = [ext.lower().lstrip(".") for ext in filetype_filter]
        q = q.filter(File.extension.isnot(None)).filter(
            File.extension.in_(lowered)
        )
    return q.all()

def add_members(session: Session, collection: FileCollection, files: Iterable[File], role: str):
    """
    Insert members, skipping those already present to allow safe re-runs.
    """
    existing_ids = {
        fid for (fid,) in session.query(FileCollectionMember.file_id)
        .filter(FileCollectionMember.collection_id == collection.id)
        .all()
    }
    new_members = [
        FileCollectionMember(collection_id=collection.id, file_id=f.id, role=role)
        for f in files if f.id not in existing_ids
    ]
    if new_members:
        session.add_all(new_members)
        session.commit()
    logger.info(f"Added {len(new_members)} files to '{collection.name}' as role='{role}'")


In [None]:
# %% [markdown]
# ## FileCollection creation with provenance
# Creates or updates a collection, appending history to description/metadata if it already exists.

# %%
def create_or_get_collection_with_description(
    session: Session,
    name: str,
    prov: CollectionProvenance
) -> FileCollection:
    description_md = prov.to_description()
    metadata_obj = prov.to_metadata()

    coll = session.query(FileCollection).filter(FileCollection.name == name).first()
    if not coll:
        coll = FileCollection(
            name=name,
            description=description_md,
            metadata=metadata_obj
        )
        session.add(coll)
        session.commit()
        logger.info(f"Created FileCollection: {name}")
        return coll

    # Append a timestamped update if it already exists
    existing_desc = coll.description or ""
    timestamp = datetime.utcnow().strftime("%Y-%m-%d %H:%M:%SZ")
    coll.description = (
        existing_desc.rstrip()
        + f"\n\n---\n\n**Update {timestamp} (UTC):**\n"
        + description_md.strip()
    )

    meta = coll.meta or {}
    history = meta.get("history", [])
    history.append(metadata_obj.get("provenance", {}))
    meta["history"] = history
    coll.meta = meta
    session.commit()
    logger.info(f"Updated FileCollection: {name} (appended provenance history)")
    return coll


In [6]:
# %% [markdown]
# ## Build the splits
# - Optionally aggregates all child files into {PARENT}_parent_train / {PARENT}_parent_test
# - Always creates per-child collections {CHILD}_train / {CHILD}_test

# %%

def build_collections():
    if RANDOM_SEED is not None:
        random.seed(RANDOM_SEED)

    engine = get_db_engine()
    with Session(engine) as session:
        # Base provenance (shared across collections in this run)
        base_prov = CollectionProvenance(
            purpose=PURPOSE_TEXT,
            parents=PARENTS,
            strategy=STRATEGY_TEXT,
            embedding_column=EMBED_COL,
            split_ratio=TRAIN_RATIO,
            per_child_cap=PER_CHILD_CAP,
            filetype_filter=FILETYPE_FILTER or [],
            include_descendants=USE_ALL_DESCENDANTS,
            random_seed=RANDOM_SEED
        )

        parent_train = {}
        parent_test = {}
        if CREATE_PARENT_LEVEL_COLLECTIONS:
            for p in PARENTS:
                parent_train[p] = create_or_get_collection_with_description(
                    session,
                    name=f"{p}_parent_train",
                    prov=base_prov
                )
                parent_test[p] = create_or_get_collection_with_description(
                    session,
                    name=f"{p}_parent_test",
                    prov=base_prov
                )

        for parent in PARENTS:
            # children or descendants
            children = (
                list_all_descendants(session, parent)
                if USE_ALL_DESCENDANTS
                else list_direct_children(session, parent)
            )

            if not children:
                logger.warning(f"No children found under parent {parent}")
                continue

            logger.info(f"Parent {parent}: {len(children)} child/descendant tags to sample from")

            total_added_train = 0
            total_added_test = 0

            for child in children:
                # fetch files for this child with chosen embedding + optional filetype filter
                files = files_with_tag_and_embedding(
                    session, child.label, embed_col=EMBED_COL, filetype_filter=FILETYPE_FILTER
                )
                if not files:
                    continue

                # cap per-child
                cap = min(PER_CHILD_CAP, len(files))
                sampled = random.sample(files, cap) if len(files) > cap else list(files)

                # shuffle & split
                random.shuffle(sampled)
                split_idx = int(len(sampled) * TRAIN_RATIO)
                train_files, test_files = sampled[:split_idx], sampled[split_idx:]

                # per-child collections
                child_train_name = f"{child.label}_train"
                child_test_name  = f"{child.label}_test"

                child_prov = CollectionProvenance(
                    purpose=f"Training/Test for child tag {child.label} under parent {parent}.",
                    parents=[parent],
                    strategy=STRATEGY_TEXT,
                    embedding_column=EMBED_COL,
                    split_ratio=TRAIN_RATIO,
                    per_child_cap=PER_CHILD_CAP,
                    filetype_filter=FILETYPE_FILTER or [],
                    include_descendants=False,  # this collection is specific to a single child tag
                    random_seed=RANDOM_SEED
                )

                child_train_coll = create_or_get_collection_with_description(
                    session, child_train_name, child_prov
                )
                child_test_coll = create_or_get_collection_with_description(
                    session, child_test_name, child_prov
                )

                add_members(session, child_train_coll, train_files, role="train")
                add_members(session, child_test_coll, test_files, role="test")

                # also aggregate into parent-level collections
                if CREATE_PARENT_LEVEL_COLLECTIONS:
                    add_members(session, parent_train[parent], train_files, role="train")
                    add_members(session, parent_test[parent], test_files, role="test")

                total_added_train += len(train_files)
                total_added_test  += len(test_files)

            logger.info(
                f"Parent {parent}: added {total_added_train} train and {total_added_test} test files "
                f"across its children."
            )

        logger.info("✅ Done creating train/test FileCollections.")


In [None]:
# %% [markdown]
# ## Optional: update per-collection counts in metadata
# After build, you can enrich each collection's `metadata["counts"]` with totals for easy audit later.

# %%
def refresh_collection_counts():
    engine = get_db_engine()
    with Session(engine) as session:
        all_colls = session.query(FileCollection).all()
        for coll in all_colls:
            # count members by role
            q = session.query(FileCollectionMember).filter(
                FileCollectionMember.collection_id == coll.id
            )
            train_ct = q.filter(FileCollectionMember.role == "train").count()
            test_ct  = q.filter(FileCollectionMember.role == "test").count()
            val_ct   = q.filter(FileCollectionMember.role == "val").count()

            meta = coll.meta or {}
            meta.setdefault("counts", {})
            meta["counts"].update({"train": train_ct, "test": test_ct, "val": val_ct})
            coll.meta = meta
        session.commit()
    logger.info("Updated counts in collection metadata.")

In [8]:
# %% [markdown]
# ## Execute the build

# %%
build_collections()
refresh_collection_counts()  # optional, but nice to keep metadata in sync


2025-10-01 12:37:59,993 db.db INFO Creating database engine
  created_utc: str = field(default_factory=lambda: datetime.utcnow().isoformat() + "Z")
2025-10-01 12:38:01,712 knn.splitter INFO Created FileCollection: B_parent_train
2025-10-01 12:38:01,721 knn.splitter INFO Created FileCollection: B_parent_test
2025-10-01 12:38:01,728 knn.splitter INFO Created FileCollection: C_parent_train
2025-10-01 12:38:01,738 knn.splitter INFO Created FileCollection: C_parent_test
2025-10-01 12:38:01,745 knn.splitter INFO Created FileCollection: D_parent_train
2025-10-01 12:38:01,756 knn.splitter INFO Created FileCollection: D_parent_test
2025-10-01 12:38:01,764 knn.splitter INFO Created FileCollection: E_parent_train
2025-10-01 12:38:01,775 knn.splitter INFO Created FileCollection: E_parent_test
2025-10-01 12:38:01,785 knn.splitter INFO Created FileCollection: F_parent_train
2025-10-01 12:38:01,797 knn.splitter INFO Created FileCollection: F_parent_test
2025-10-01 12:38:01,808 knn.splitter INFO Creat

AttributeError: 'MetaData' object has no attribute 'setdefault'

In [None]:
# %% [markdown]
# ## Sanity checks
# Inspect a couple of collections to ensure descriptions and metadata look right.

# %%
engine = get_db_engine()
with Session(engine) as session:
    names = [f"{p}_parent_train" for p in PARENTS] + [f"{p}_parent_test" for p in PARENTS]
    sample = (
        session.query(FileCollection)
        .filter(FileCollection.name.in_(names))
        .order_by(FileCollection.name)
        .all()
    )
    for c in sample:
        print("—"*80)
        print(c.name)
        print("Counts:", c.metadata.get("counts", {}))
        print(c.description.splitlines()[0])  # first line just to confirm MD populated
