This notebook creates reproducible train/test FileCollections for parent-level k-NN experiments.
It uses a dataclass to embed a descriptive, structured provenance into each FileCollection.

In [None]:
# %%
import json
import logging, random
from dataclasses import dataclass, asdict, field
from datetime import datetime
from typing import List, Optional, Dict, Any, Iterable, Set

from sqlalchemy import and_, exists
from sqlalchemy.orm import Session

from db import get_db_engine
from db.models import (
    FilingTag, File, FileContent, FileCollection, FileCollectionMember, FileTagLabel
)
from knn.knn import CollectionProvenance


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("knn.simple_split")


In [None]:
# %% [markdown]
# ## Configuration
# Adjust these knobs before running the build.

# %%
# Root tag labels; the tags underneath each of these are sampled for the split.
PARENTS = ["B","C","D","E","F","G","H"]
# Name of the FileContent attribute that must be non-null for a file to qualify.
EMBED_COL = "minilm_emb"
# Fraction of each child's sampled files that goes to train (the cut index is truncated).
TRAIN_RATIO = 0.8
# Upper bound on the number of files sampled per child tag.
PER_CHILD_CAP = 50
USE_ALL_DESCENDANTS = True # True: descendants at every depth; False: direct children only
RANDOM_SEED = 42 # seeds the global `random` module; set None for non-reproducible
# File extensions to keep (leading dots and case are normalised); empty or None disables the filter.
FILETYPE_FILTER: Optional[List[str]] = []

# Free-text fields handed to CollectionProvenance as `purpose` and `strategy`.
PURPOSE_TEXT = "Simple global train/test split for parent-level k-NN experiments."
STRATEGY_TEXT = "stratified-by-child-capped"



In [None]:
# %% [markdown]
# ## DB helpers

# %%
def list_all_descendants(session: Session, root_label: str) -> List[FilingTag]:
    """Return every FilingTag below ``root_label``, breadth-first.

    The hierarchy is walked by repeatedly querying for tags whose
    ``parent_label`` equals the label currently being expanded. The root tag
    itself is not part of the result. Each label is expanded at most once, so
    a cycle in the parent links cannot loop forever.

    Args:
        session: Open SQLAlchemy session used for the lookups.
        root_label: Label of the tag whose descendants are wanted.

    Returns:
        Descendant tags, nearest generation first.
    """
    # Function-scope import keeps this cell self-contained. deque pops from
    # the left in O(1); list.pop(0) shifts every remaining element.
    from collections import deque

    descendants: List[FilingTag] = []
    seen: Set[str] = {root_label}
    queue = deque([root_label])
    while queue:
        parent = queue.popleft()
        kids = session.query(FilingTag).filter(FilingTag.parent_label == parent).all()
        for k in kids:
            if k.label not in seen:
                seen.add(k.label)
                descendants.append(k)
                queue.append(k.label)
    return descendants

def files_with_tag_and_embedding(session: Session, tag_label: str) -> List[File]:
    """Fetch the files labelled ``tag_label`` that have a non-null embedding.

    The embedding column is looked up on ``FileContent`` by the name stored in
    ``EMBED_COL``. When ``FILETYPE_FILTER`` is non-empty the result is further
    restricted to files whose extension appears in it; filter entries are
    lower-cased and stripped of leading dots before the comparison.
    """
    embedding = getattr(FileContent, EMBED_COL)

    # Correlated EXISTS: the file carries at least one label row with this tag.
    has_tag = exists().where(
        and_(FileTagLabel.file_id == File.id, FileTagLabel.tag == tag_label)
    )

    query = session.query(File).join(File.content)
    query = query.filter(embedding.isnot(None)).filter(has_tag)

    if FILETYPE_FILTER:
        wanted_exts = [ext.lower().lstrip(".") for ext in FILETYPE_FILTER]
        query = query.filter(File.extension.in_(wanted_exts))

    return query.all()

def create_or_get_collection(session: Session, name: str, prov: CollectionProvenance) -> FileCollection:
    """Return the FileCollection called ``name``, creating it when absent.

    A newly created collection takes its description and metadata from
    ``prov`` and is committed straight away. An existing collection is
    returned as-is; ``prov`` is not applied to it.
    """
    existing = session.query(FileCollection).filter(FileCollection.name == name).first()
    if existing:
        return existing

    created = FileCollection(
        name=name,
        description=prov.to_description(),
        meta=prov.to_metadata(),
    )
    session.add(created)
    session.commit()
    logger.info(f"Created FileCollection {name}")
    return created

def add_members(session: Session, collection: FileCollection, files: Iterable[File], role: str, batch_size: int = 1000) -> int:
    """
    Attach ``files`` to ``collection`` under ``role`` and return how many
    membership rows were inserted.

    Steps:
      - collapse the input so each file id is considered once
      - drop ids that already have a row for this collection + role
      - insert the remainder ``batch_size`` rows at a time, committing per batch
    """
    # Unique ids in first-seen order; falsy entries and entries without an id
    # are ignored.
    first_by_id = {}
    for candidate in files:
        if not candidate:
            continue
        file_id = getattr(candidate, "id", None)
        if file_id is None:
            continue
        first_by_id.setdefault(file_id, candidate)

    if not first_by_id:
        logger.info(f"No candidates to add for {collection.name} ({role}).")
        return 0

    # Ids already stored for this collection + role.
    member_rows = (
        session.query(FileCollectionMember.file_id)
        .filter(
            FileCollectionMember.collection_id == collection.id,
            FileCollectionMember.role == role,
        )
        .all()
    )
    already_present = {row[0] for row in member_rows}

    new_ids = [fid for fid in first_by_id if fid not in already_present]
    if not new_ids:
        logger.info(f"All {role} members already present in {collection.name}.")
        return 0

    # Insert in slices so no single flush carries an enormous statement.
    added = 0
    for start in range(0, len(new_ids), batch_size):
        batch = new_ids[start:start + batch_size]
        members = [
            FileCollectionMember(collection_id=collection.id, file_id=fid, role=role)
            for fid in batch
        ]
        session.add_all(members)
        session.commit()
        added += len(batch)

    logger.info(f"Added {added} {role} files to {collection.name}")
    return added


In [None]:
# %%
def build_simple_split():
    """Build the ``global_train`` and ``global_test`` FileCollections.

    For each label in ``PARENTS`` the tags underneath it are collected (all
    descendants or direct children only, per ``USE_ALL_DESCENDANTS``). For each
    of those tags, up to ``PER_CHILD_CAP`` qualifying files are sampled,
    shuffled and cut at ``int(len(sample) * TRAIN_RATIO)``: the head goes to
    train, the tail to test. Because the cut index is truncated, a tag with a
    single sampled file contributes it to test whenever ``TRAIN_RATIO < 1``.

    After all tags are processed each side is de-duplicated by file id, and any
    file present on both sides is kept in train only. Membership rows are then
    written through ``add_members``, which skips rows that already exist, so
    re-running adds to the existing collections rather than replacing them.

    Tags and files are sorted before sampling. Query results come back in no
    guaranteed order, and ``random.sample`` / ``random.shuffle`` depend on the
    order of their input, so without the sort a fixed ``RANDOM_SEED`` would not
    be enough to reproduce the split.
    """
    if RANDOM_SEED is not None:
        random.seed(RANDOM_SEED)

    engine = get_db_engine()
    with Session(engine) as session:
        prov = CollectionProvenance(
            purpose=PURPOSE_TEXT,
            parents=PARENTS,
            strategy=STRATEGY_TEXT,
            embedding_column=EMBED_COL,
            split_ratio=TRAIN_RATIO,
            per_child_cap=PER_CHILD_CAP,
            filetype_filter=FILETYPE_FILTER,
            include_descendants=USE_ALL_DESCENDANTS,
            random_seed=RANDOM_SEED,
        )

        train_coll = create_or_get_collection(session, "global_train", prov)
        test_coll = create_or_get_collection(session, "global_test", prov)

        train_files, test_files = [], []

        def direct_children(sess, parent):
            return sess.query(FilingTag).filter(FilingTag.parent_label == parent).all()

        for parent in PARENTS:
            if USE_ALL_DESCENDANTS:
                children = list_all_descendants(session, parent)
            else:
                children = direct_children(session, parent)
            # Fixed tag order => the RNG is consumed in the same sequence each run.
            children = sorted(children, key=lambda tag: tag.label)

            for child in children:
                files = files_with_tag_and_embedding(session, child.label)
                if not files:
                    continue
                # Fixed file order => the same seed picks the same files.
                files = sorted(files, key=lambda f: f.id)

                cap = min(PER_CHILD_CAP, len(files))
                sampled = random.sample(files, cap) if len(files) > cap else list(files)
                random.shuffle(sampled)

                split_idx = int(len(sampled) * TRAIN_RATIO)
                train_files.extend(sampled[:split_idx])
                test_files.extend(sampled[split_idx:])

        # ---- De-dup & disjointness enforcement ----
        # A file can carry several tags and so be picked more than once.
        def dedupe(files):
            seen = set()
            out = []
            for f in files:
                if f.id not in seen:
                    seen.add(f.id)
                    out.append(f)
            return out

        train_files = dedupe(train_files)
        test_files = dedupe(test_files)

        # Remove overlaps (prefer keeping in train).
        train_ids = {f.id for f in train_files}
        test_files = [f for f in test_files if f.id not in train_ids]

        logger.info(f"Prepared train {len(train_files)} / test {len(test_files)} (after dedupe & disjointness)")

        # add_members repeats the de-dupe and also skips rows already stored.
        add_members(session, train_coll, train_files, role="train")
        add_members(session, test_coll, test_files, role="test")


In [None]:
# Run the build. This queries the database and commits collection/membership rows.
build_simple_split()

The code from this point down retrieves k-NN predictions and analyzes their performance.

In [1]:

import json
import numpy as np
from sqlalchemy.orm import Session
from db.db import get_db_engine
from db.models import FileCollection, FileCollectionMember, File, FileContent, FileTagLabel
from utils import cosine_similarity_batch

def get_embeddings_and_tags(session: Session, collection_name: str, embed_col: str = "minilm_emb"):
    """Return ``(file_id, hash, embedding, tag)`` rows for a collection.

    Rows come from joining the collection's members to their content (matched
    on file hash) and to their tag labels. The join to ``FileTagLabel`` is an
    inner join on file id, so a file with several labels yields one row per
    label and a file with none yields no row.

    Rows whose embedding is NULL are excluded: they cannot be compared by
    similarity and would otherwise surface later as array-shape errors.

    Args:
        session: Open SQLAlchemy session.
        collection_name: ``FileCollection.name`` to read members from.
        embed_col: Name of the ``FileContent`` attribute holding the embedding.

    Returns:
        The list of result rows from the query.
    """
    emb_col = getattr(FileContent, embed_col)
    q = (
        session.query(File.id, File.hash, emb_col, FileTagLabel.tag)
        .join(FileCollectionMember, File.id == FileCollectionMember.file_id)
        .join(FileCollection, FileCollectionMember.collection_id == FileCollection.id)
        .join(FileContent, File.hash == FileContent.file_hash)
        .join(FileTagLabel, File.id == FileTagLabel.file_id)
        .filter(FileCollection.name == collection_name)
        .filter(emb_col.isnot(None))
    )
    return q.all()

def evaluate_knn(k: int = 5, embed_col: str = "minilm_emb"):
    """Classify every ``global_test`` row by k-nearest-neighbour vote.

    Each test embedding is scored against all ``global_train`` embeddings with
    ``cosine_similarity_batch``; the ``k`` highest-scoring training rows vote
    with their tags. When several tags tie for the most votes, the tag of the
    nearest neighbour among them wins, so the prediction does not depend on
    set/hash ordering.

    Args:
        k: Number of neighbours that vote. Must be at least 1. If it exceeds
            the number of training rows, all training rows vote.
        embed_col: Name of the ``FileContent`` attribute holding the embedding.

    Returns:
        One dict per test row with keys ``test_file_id``, ``test_hash``,
        ``true_tag``, ``topk_tags``, ``topk_scores``, ``pred_tag``,
        ``is_correct`` and ``neighbor_tag_counts``. Empty when there are no
        test rows.

    Raises:
        ValueError: If ``k`` is less than 1 or no training rows are found.
    """
    # Function-scope import keeps this cell self-contained.
    from collections import Counter

    # k == 0 would make the ``[-k:]`` slice below select *every* neighbour,
    # and a negative k would select an unrelated slice.
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")

    engine = get_db_engine()
    with Session(engine) as session:
        train = get_embeddings_and_tags(session, "global_train", embed_col)
        test = get_embeddings_and_tags(session, "global_test", embed_col)

    if not train:
        raise ValueError("No training rows found in 'global_train'; cannot run k-NN")

    # Stack once so each test row is scored against the whole training set in
    # a single call.
    train_vecs = np.stack([np.array(t[2]) for t in train])
    train_tags = [t[3] for t in train]

    results = []
    for fid, fhash, femb, true_tag in test:
        femb = np.array(femb)
        # Expected to yield one score per training row -- TODO confirm against
        # cosine_similarity_batch.
        sims = cosine_similarity_batch(femb, train_vecs)
        # argsort is ascending: take the last k, then reverse to best-first.
        topk_idx = np.argsort(sims)[-k:][::-1]
        topk_tags = [train_tags[i] for i in topk_idx]
        topk_scores = [float(sims[i]) for i in topk_idx]

        # Counter keeps first-seen order, and topk_tags is best-first, so
        # equal counts resolve in favour of the closest neighbour's tag.
        tag_counts = Counter(topk_tags)
        pred_tag = tag_counts.most_common(1)[0][0]

        results.append({
            "test_file_id": fid,
            "test_hash": fhash,
            "true_tag": true_tag,
            "topk_tags": topk_tags,
            "topk_scores": topk_scores,
            "pred_tag": pred_tag,
            "is_correct": pred_tag == true_tag,
            "neighbor_tag_counts": dict(tag_counts),
        })

    return results


  r"N:\PPDO\Records"  (Windows)  or  "/mnt/records" (Linux).


In [2]:
# Score every test row using its 5 nearest training neighbours.
results = evaluate_knn(k=5)

# Persist the per-row results as JSON, relative to the current working directory.
with open("knn_results.json", "w") as f:
    json.dump(results, f)




2025-10-07 16:16:49,673 db.db INFO Creating database engine
