This notebook creates reproducible train/test FileCollections for parent-level k-NN experiments.
It uses a dataclass to embed a descriptive, structured provenance record into each FileCollection.

In [None]:
# %% [markdown]
# # KNN Evaluation Pipeline — Abstract Integration
# This notebook demonstrates how to use the new abstract classes
# (SplitSelectionStrategy, NeighborFilterStrategy, LabelingStrategy)
# to manage reproducible, pluggable KNN evaluation runs.

# %%
import random
import logging
from collections import Counter
from sqlalchemy.orm import Session
from sqlalchemy import exists, and_
from db import get_db_engine
from db.models import File, FileCollection, FilingTag, FileCollectionMember, FileTagLabel, FileContent
from knn.evaluation import (
    SplitSelectionStrategy,
    NeighborFilterStrategy,
    LabelingStrategy,
    KNNRun,
)
from knn.base import KNNCollectionProvenance  # updated import path
from typing import List, Iterable, Set

# Emit INFO-level messages so the collection-building steps below are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("knn.evaluation.run")

# constants (for demo)
# Name of the FileContent attribute holding the embedding (looked up via getattr).
EMBED_COL = "minilm_emb"
# Fraction of each child's sampled files assigned to the train split.
TRAIN_RATIO = 0.8
# Seed handed to the split strategy so the sampling can be repeated.
RANDOM_SEED = 42
# Root tag labels whose descendant tags are sampled.
PARENTS = ["B", "C", "D", "E", "F", "G", "H"]
# Maximum number of files sampled per descendant (child) tag.
PER_CHILD_CAP = 50

In [None]:

def list_all_descendants(session: Session, root_label: str) -> List[FilingTag]:
    """
    Return every FilingTag below ``root_label``, in breadth-first order.

    The root itself is not included. Children of each tag are visited in
    ascending ``label`` order so the returned sequence is the same on every
    run; without an ORDER BY the database is free to return rows in any
    order, which would change the order of the result.

    Args:
        session: Open SQLAlchemy session used to query FilingTag rows.
        root_label: Label of the tag whose descendants are wanted.

    Returns:
        List of FilingTag rows, each label appearing at most once.
    """
    # Local import keeps the cell self-contained; deque gives O(1) pops from
    # the left, unlike list.pop(0).
    from collections import deque

    descendants: List[FilingTag] = []
    seen: Set[str] = {root_label}  # also guards against cycles in the tag graph
    queue = deque([root_label])
    while queue:
        parent = queue.popleft()
        kids = (
            session.query(FilingTag)
            .filter(FilingTag.parent_label == parent)
            .order_by(FilingTag.label)
            .all()
        )
        for kid in kids:
            if kid.label in seen:
                continue
            seen.add(kid.label)
            descendants.append(kid)
            queue.append(kid.label)
    return descendants

def files_with_tag_and_embedding(session: Session, tag_label: str) -> List[File]:
    """
    Return the files labelled ``tag_label`` that have a non-null embedding.

    The embedding column is the FileContent attribute named by ``EMBED_COL``.
    Rows are ordered by ``File.id`` so that callers which sample from the
    result with a fixed random seed see the same input order on every run
    (an unordered query may return rows in any order).

    Args:
        session: Open SQLAlchemy session.
        tag_label: Tag value matched against ``FileTagLabel.tag``.

    Returns:
        List of File rows, ascending by id.
    """
    emb_col = getattr(FileContent, EMBED_COL)
    # EXISTS (rather than a join on FileTagLabel) yields one row per file even
    # if several label rows match.
    has_tag = exists().where(
        and_(FileTagLabel.file_id == File.id, FileTagLabel.tag == tag_label)
    )
    q = (
        session.query(File)
        .join(File.content)
        .filter(emb_col.isnot(None))
        .filter(has_tag)
        .order_by(File.id)
    )
    return q.all()

def create_or_get_collection(session, name: str, prov: KNNCollectionProvenance, overwrite: bool = False):
    """
    Retrieve or create the FileCollection called ``name``.

    Behaviour:
      - No collection with that name: create it with the description and
        metadata rendered from ``prov``, commit, and return it.
      - Collection exists, has members, and ``overwrite`` is False: return it
        unchanged.
      - Collection exists and is empty, or ``overwrite`` is True: replace its
        description and metadata with those from ``prov``, commit, and return
        it. Existing member rows are never deleted by this function.

    Args:
        session: Open SQLAlchemy session.
        name: Collection name to look up or create.
        prov: Provenance object providing ``to_description()`` and
            ``to_metadata()``.
        overwrite: Force a metadata refresh even when members already exist.

    Returns:
        The existing or newly created FileCollection.
    """
    from db.models import FileCollection, FileCollectionMember

    collection = session.query(FileCollection).filter(FileCollection.name == name).first()

    # ---- Create a new one ----
    if collection is None:
        new_coll = FileCollection(
            name=name,
            description=prov.to_description(),
            meta=prov.to_metadata()
        )
        session.add(new_coll)
        session.commit()
        logger.info(f"Created new FileCollection '{name}' (id={new_coll.id}).")
        return new_coll

    # Check if it already has members
    has_members = (
        session.query(FileCollectionMember)
        .filter(FileCollectionMember.collection_id == collection.id)
        .first()
        is not None
    )

    if has_members and not overwrite:
        logger.info(f"Reusing existing FileCollection '{name}' (id={collection.id}) with existing members.")
        return collection

    # Reaching here means overwrite was requested or the collection is empty.
    # An empty collection is about to be populated by the caller, so its
    # provenance must describe the current run rather than a stale one.
    if overwrite:
        logger.info(f"Overwriting existing FileCollection '{name}' (id={collection.id}) metadata.")
    else:
        logger.info(f"Refreshing metadata of empty FileCollection '{name}' (id={collection.id}).")
    collection.description = prov.to_description()
    collection.meta = prov.to_metadata()
    session.commit()
    return collection

def add_members(session: Session, collection: FileCollection, files: Iterable[File], role: str, batch_size: int = 1000) -> int:
    """
    Attach files to a collection under the given role and report how many
    membership rows were created.

    Steps:
      - collapse the input to one File per id (first occurrence wins; falsy
        entries and entries without an id are ignored)
      - drop ids that already have a member row for this collection + role
      - add the remainder in slices of ``batch_size``, committing each slice

    Returns 0 when there is nothing new to add.
    """
    # Collapse duplicates, remembering the first File seen for each id.
    first_seen = {}
    for candidate in files:
        if not candidate:
            continue
        if getattr(candidate, "id", None) is None:
            continue
        first_seen.setdefault(candidate.id, candidate)

    if not first_seen:
        logger.info(f"No candidates to add for {collection.name} ({role}).")
        return 0

    # Ids that are already members of this collection with this role.
    member_rows = (
        session.query(FileCollectionMember.file_id)
        .filter(
            FileCollectionMember.collection_id == collection.id,
            FileCollectionMember.role == role
        )
        .all()
    )
    already_present = set()
    for (member_file_id,) in member_rows:
        already_present.add(member_file_id)

    pending = [f for f in first_seen.values() if f.id not in already_present]
    if not pending:
        logger.info(f"All {role} members already present in {collection.name}.")
        return 0

    # Insert slice by slice so no single flush grows too large.
    added = 0
    for offset in range(0, len(pending), batch_size):
        batch = pending[offset:offset + batch_size]
        new_rows = []
        for f in batch:
            new_rows.append(
                FileCollectionMember(collection_id=collection.id, file_id=f.id, role=role)
            )
        session.add_all(new_rows)
        session.commit()
        added += len(batch)

    logger.info(f"Added {added} {role} files to {collection.name}")
    return added


class RandomStratifiedSplit(SplitSelectionStrategy):
    """
    Split strategy that samples files per descendant tag and divides each
    sample into train and test portions.

    For every tag below each label in ``parents``, up to ``per_child_cap``
    files (with an embedding) are drawn, shuffled, and cut at
    ``train_ratio``. The per-tag pieces are concatenated into global train and
    test id lists, which are then de-duplicated and made disjoint (an id that
    landed in train is removed from test).
    """

    @property
    def name(self) -> str:
        return "random_stratified_split"

    @property
    def description(self) -> str:
        return "Randomly sample up to PER_CHILD_CAP files per child tag, split 80/20 into global train/test sets, ensuring disjointness."

    def __init__(self, parents, per_child_cap, train_ratio, random_seed=None):
        """
        Args:
            parents: Iterable of root tag labels whose descendants are sampled.
            per_child_cap: Maximum number of files drawn per descendant tag.
            train_ratio: Fraction of each tag's sample that goes to train.
            random_seed: Seed for the sampling RNG; ``None`` selects 42.
        """
        self.parents = parents
        self.per_child_cap = per_child_cap
        self.train_ratio = train_ratio
        # Compare against None explicitly: `random_seed or 42` would silently
        # replace a legitimate seed of 0 with 42.
        self.random_seed = 42 if random_seed is None else random_seed

    def select_files(self, session: Session):
        """
        Select train and test file ids.

        Args:
            session: Open SQLAlchemy session used for the tag and file queries.

        Returns:
            ``(train_ids, test_ids)``: two lists of file ids with no duplicates
            and no id in common.
        """
        # A private generator produces the same sequence as seeding the
        # module-level one, without reseeding the process-wide RNG as a side
        # effect.
        rng = random.Random(self.random_seed)
        train_ids, test_ids = [], []
        for parent in self.parents:
            for child in list_all_descendants(session, parent):
                files = files_with_tag_and_embedding(session, child.label)
                if not files:
                    continue
                if len(files) > self.per_child_cap:
                    sampled = rng.sample(files, self.per_child_cap)
                else:
                    sampled = list(files)
                rng.shuffle(sampled)
                split_idx = int(len(sampled) * self.train_ratio)
                train_ids.extend(f.id for f in sampled[:split_idx])
                test_ids.extend(f.id for f in sampled[split_idx:])

        # Deduplicate (keeping first-seen order) and enforce disjointness.
        # A file tagged with several children can be drawn more than once.
        train_ids = list(dict.fromkeys(train_ids))
        train_id_set = set(train_ids)  # built once, not once per test id
        test_ids = [fid for fid in dict.fromkeys(test_ids) if fid not in train_id_set]
        return train_ids, test_ids
    
class NoOpNeighborFilter(NeighborFilterStrategy):
    """Neighbor filter that lets every candidate through."""

    @property
    def name(self):
        """Short identifier of this filter strategy."""
        return "no_op_filter"

    @property
    def description(self):
        """Human-readable summary of what the filter does."""
        return "Does not filter any neighbors; uses all candidates."

    def filter(self, candidate_ids, test_file_hash):
        """Return ``candidate_ids`` as given; ``test_file_hash`` is not consulted."""
        return candidate_ids
    
class MajorityVoteLabeler(LabelingStrategy):
    """Labeling strategy that picks the most frequent tag among the neighbors."""

    @property
    def name(self):
        """Short identifier of this labeling strategy."""
        return "majority_vote"

    @property
    def description(self):
        """Human-readable summary of the labeling rule."""
        return "Infers label by majority vote among top-k neighbors."

    def infer_label(self, neighbor_tags, neighbor_scores):
        """
        Return the tag occurring most often in ``neighbor_tags``.

        ``neighbor_scores`` is accepted but not used. Returns ``None`` when
        ``neighbor_tags`` is empty.
        """
        if neighbor_tags:
            winner, _votes = Counter(neighbor_tags).most_common(1)[0]
            return winner
        return None
    


In [None]:
def build_collections_from_strategy(session: Session, strategy: SplitSelectionStrategy):
    """
    Populate train/test FileCollections from a SplitSelectionStrategy.

    The strategy picks the file ids. A provenance object built with
    ``strategy.assemble_provenance_info`` supplies the collection names
    (``{split_strategy}_{embedding_column}_train`` and ``..._test``) and is
    handed to ``create_or_get_collection`` for the description and metadata.

    Provenance parameters are taken from the strategy instance when it exposes
    them (``parents``, ``train_ratio``, ``per_child_cap``, ``random_seed``) and
    fall back to the notebook constants otherwise, so the recorded provenance
    describes the split that was actually performed even if the strategy was
    constructed with values other than the constants.

    Args:
        session: Open SQLAlchemy session.
        strategy: Strategy providing ``name``, ``select_files(session)`` and
            ``assemble_provenance_info(...)``.

    Returns:
        ``(train_collection, test_collection)``.
    """
    logger.info(f"Running split strategy: {strategy.name}")
    train_ids, test_ids = strategy.select_files(session)

    # Use the base class method with all required parameters
    prov = strategy.assemble_provenance_info(
        parents=getattr(strategy, "parents", PARENTS),
        embedding_col=EMBED_COL,
        split_ratio=getattr(strategy, "train_ratio", TRAIN_RATIO),
        per_child_cap=getattr(strategy, "per_child_cap", PER_CHILD_CAP),
        random_seed=getattr(strategy, "random_seed", RANDOM_SEED)
    )

    base_name = f"{prov.split_strategy}_{prov.embedding_column}"
    train_coll = create_or_get_collection(session=session,
                                          name=f"{base_name}_train",
                                          prov=prov)
    test_coll = create_or_get_collection(session=session,
                                         name=f"{base_name}_test",
                                         prov=prov)

    # Get File objects
    train_files = session.query(File).filter(File.id.in_(train_ids)).all()
    test_files = session.query(File).filter(File.id.in_(test_ids)).all()

    add_members(session, train_coll, train_files, role="train")
    add_members(session, test_coll, test_files, role="test")

    logger.info(f"Train: {len(train_files)}, Test: {len(test_files)}")

    return train_coll, test_coll

In [None]:
# Driver cell: build the train/test collections and describe a baseline KNN run.
engine = get_db_engine()
with Session(engine) as session:
    # Strategy objects for the three pluggable stages (split, filter, label).
    split_strategy = RandomStratifiedSplit(
        parents=PARENTS,
        per_child_cap=PER_CHILD_CAP,
        train_ratio=TRAIN_RATIO,
        random_seed=RANDOM_SEED,
    )
    neighbor_filter = NoOpNeighborFilter()
    labeler = MajorityVoteLabeler()

    # Creates (or reuses) the two collections and adds the selected files.
    train_coll, test_coll = build_collections_from_strategy(session, split_strategy)

    # NOTE(review): knn_run is only constructed here; this cell neither adds it
    # to the session nor commits it - confirm whether persistence is intended.
    knn_run = KNNRun(
        k=5,
        name="baseline_random_split",
        description="Baseline KNN evaluation using random stratified split and majority voting.",
        training_collection=train_coll,
        test_collection=test_coll,
    )

    # In this cell neighbor_filter and labeler are used only for this summary.
    print(f"KNNRun created: {knn_run.name}")
    print(f"Split: {split_strategy.name}, Filter: {neighbor_filter.name}, Labeler: {labeler.name}")

The code below retrieves k-NN predictions and analyzes their performance.

In [1]:

import json
import numpy as np
from sqlalchemy.orm import Session
from db.db import get_db_engine
from db.models import FileCollection, FileCollectionMember, File, FileContent, FileTagLabel
from utils import cosine_similarity_batch

def get_embeddings_and_tags(session: Session, collection_name: str, embed_col: str = "minilm_emb"):
    """
    Fetch ``(file_id, hash, embedding, tag)`` rows for the members of the
    collection named ``collection_name``.

    ``embed_col`` names the FileContent attribute to read the embedding from.
    FileTagLabel is joined directly, so a file carrying several tag labels
    produces one row per label.
    """
    embedding = getattr(FileContent, embed_col)

    query = session.query(File.id, File.hash, embedding, FileTagLabel.tag)
    # File -> membership -> collection, then content (by hash) and tag labels.
    query = query.join(FileCollectionMember, File.id == FileCollectionMember.file_id)
    query = query.join(FileCollection, FileCollectionMember.collection_id == FileCollection.id)
    query = query.join(FileContent, File.hash == FileContent.file_hash)
    query = query.join(FileTagLabel, File.id == FileTagLabel.file_id)
    query = query.filter(FileCollection.name == collection_name)
    return query.all()

def evaluate_knn(
    k: int = 5,
    embed_col: str = "minilm_emb",
    train_collection: str = "global_train",
    test_collection: str = "global_test",
):
    """
    Classify every test row by majority vote over its ``k`` most similar
    training rows and return one result record per test row.

    Args:
        k: Number of neighbors to vote; must be at least 1.
        embed_col: Name of the FileContent attribute holding the embedding.
        train_collection: Name of the FileCollection supplying neighbors.
        test_collection: Name of the FileCollection to classify.

    Returns:
        List of dicts with keys ``test_file_id``, ``test_hash``, ``true_tag``,
        ``topk_tags``, ``topk_scores``, ``pred_tag``, ``is_correct`` and
        ``neighbor_tag_counts``. Empty when the test collection yields no rows.

    Raises:
        ValueError: If ``k < 1`` or the training collection yields no rows.
    """
    # Local import: the import list of this cell does not bring Counter in.
    from collections import Counter

    # k == 0 would make the slice below `[-0:]`, i.e. *every* training row.
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    engine = get_db_engine()
    with Session(engine) as session:
        train = get_embeddings_and_tags(session, train_collection, embed_col)
        test = get_embeddings_and_tags(session, test_collection, embed_col)

    if not train:
        # np.stack([]) would raise a ValueError too, but without naming the cause.
        raise ValueError(f"No training rows found in collection '{train_collection}'.")

    # convert to arrays for speed
    train_vecs = np.stack([np.array(row[2]) for row in train])
    train_tags = [row[3] for row in train]

    results = []
    for fid, fhash, femb, true_tag in test:
        sims = cosine_similarity_batch(np.array(femb), train_vecs)  # shape (N_train,)
        topk_idx = np.argsort(sims)[-k:][::-1]  # top k highest cosine similarities
        topk_tags = [train_tags[i] for i in topk_idx]
        topk_scores = [float(sims[i]) for i in topk_idx]

        # Majority vote. Counter breaks ties in favour of the tag encountered
        # first, i.e. the tag of the most similar neighbor among the tied
        # ones; iterating a set here would make ties depend on hash order.
        tag_counts = Counter(topk_tags)
        pred_tag = tag_counts.most_common(1)[0][0]

        results.append({
            "test_file_id": fid,
            "test_hash": fhash,
            "true_tag": true_tag,
            "topk_tags": topk_tags,
            "topk_scores": topk_scores,
            "pred_tag": pred_tag,
            "is_correct": pred_tag == true_tag,
            "neighbor_tag_counts": dict(tag_counts)
        })

    return results


# Example paths: r"N:\PPDO\Records" (Windows) or "/mnt/records" (Linux).


In [2]:
results = evaluate_knn(k=5)

# Persist the per-test-row results as a JSON document.
with open("knn_results.json", "w") as out_file:
    out_file.write(json.dumps(results))




2025-10-07 16:16:49,673 db.db INFO Creating database engine
