This notebook creates reproducible train/test FileCollections for parent-level k-NN experiments.
It uses a dataclass to embed a descriptive, structured provenance into each FileCollection.

In [1]:
# %% [markdown]
# # KNN Evaluation Pipeline — Abstract Integration
# This notebook demonstrates how to use the new abstract classes
# (SplitSelectionStrategy, NeighborFilterStrategy, LabelingStrategy)
# to manage reproducible, pluggable KNN evaluation runs.

# %%
import random
import logging
from collections import Counter
from sqlalchemy.orm import Session
from sqlalchemy import exists, and_, func
from db import get_db_engine
from db.models import File, FileCollection, FilingTag, FileCollectionMember, FileTagLabel, FileContent
from knn.evaluation import (
    SplitSelectionStrategy,
    NeighborFilterStrategy,
    LabelingStrategy,
    KNNRun,
)
from knn.evaluation import KNNCollectionProvenance  # updated import path
from typing import List, Iterable, Set

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("knn.evaluation.run")

# constants (for demo)
EMBED_COL = "minilm_emb"
TRAIN_RATIO = 0.8
RANDOM_SEED = 42
PARENTS = ["B", "C", "D", "E", "F", "G", "H"]
PER_CHILD_CAP = 50

In [2]:

def list_all_descendants(session: Session, root_label: str) -> List[FilingTag]:
    descendants, queue = [], [root_label]
    seen: Set[str] = set([root_label])
    while queue:
        parent = queue.pop(0)
        kids = session.query(FilingTag).filter(FilingTag.parent_label == parent).all()
        for k in kids:
            if k.label not in seen:
                seen.add(k.label)
                descendants.append(k)
                queue.append(k.label)
    return descendants

def files_with_tag_and_embedding(session: Session, tag_label: str) -> List[File]:
    emb_col = getattr(FileContent, EMBED_COL)
    q = (
        session.query(File)
        .join(File.content)
        .filter(emb_col.isnot(None))
        .filter(
            exists().where(
                and_(FileTagLabel.file_id == File.id, FileTagLabel.tag == tag_label)
            )
        )
    )
    return q.all()

def create_or_get_collection(session, name: str, prov: KNNCollectionProvenance, overwrite: bool = False):
    """
    Retrieve or create a FileCollection with the given name.
    If it already exists and has members, returns it unchanged.
    If it exists but is empty or overwrite=True, updates its metadata and description.
    """
    from db.models import FileCollection, FileCollectionMember

    collection = session.query(FileCollection).filter(FileCollection.name == name).first()

    if collection:
        # Check if it already has members
        has_members = (
            session.query(FileCollectionMember)
            .filter(FileCollectionMember.collection_id == collection.id)
            .first()
            is not None
        )

        if has_members and not overwrite:
            logger.info(f"Reusing existing FileCollection '{name}' (id={collection.id}) with existing members.")
            return collection

        # Optional: update metadata if explicitly overwriting
        if overwrite:
            logger.info(f"Overwriting existing FileCollection '{name}' (id={collection.id}) metadata.")
            collection.description = prov.to_description()
            collection.meta = prov.to_metadata()
            session.commit()
        return collection

    # ---- Create a new one ----
    new_coll = FileCollection(
        name=name,
        description=prov.to_description(),
        meta=prov.to_metadata()
    )
    session.add(new_coll)
    session.commit()
    logger.info(f"Created new FileCollection '{name}' (id={new_coll.id}).")
    return new_coll

def add_members(session: Session, collection: FileCollection, files: Iterable[File], role: str, batch_size: int = 1000) -> int:
    """
    Insert (collection_id, file_id, role) rows safely:
      - De-dupes incoming files by file_id
      - Skips rows that already exist in DB for this collection+role
      - Inserts in batches to avoid giant statements
    """
    # 1) de-dupe input
    unique_by_id = {}
    for f in files:
        if f and getattr(f, "id", None) is not None:
            unique_by_id.setdefault(f.id, f)
    candidates = list(unique_by_id.values())
    if not candidates:
        logger.info(f"No candidates to add for {collection.name} ({role}).")
        return 0

    # 2) fetch existing file_ids for this collection+role
    existing_ids = {
        fid for (fid,) in session.query(FileCollectionMember.file_id)
        .filter(
            FileCollectionMember.collection_id == collection.id,
            FileCollectionMember.role == role
        ).all()
    }

    # 3) choose new rows only
    to_insert = [f for f in candidates if f.id not in existing_ids]
    if not to_insert:
        logger.info(f"All {role} members already present in {collection.name}.")
        return 0

    # 4) batch insert
    added = 0
    for i in range(0, len(to_insert), batch_size):
        chunk = to_insert[i:i+batch_size]
        session.add_all([
            FileCollectionMember(collection_id=collection.id, file_id=f.id, role=role)
            for f in chunk
        ])
        session.commit()
        added += len(chunk)

    logger.info(f"Added {added} {role} files to {collection.name}")
    return added

# unsure if I'll bother using this function
def refresh_collection_counts(session, collection: FileCollection):
    from db.models import FileCollectionMember
    counts = (
        session.query(FileCollectionMember.role, func.count(FileCollectionMember.file_id))
        .filter(FileCollectionMember.collection_id == collection.id)
        .group_by(FileCollectionMember.role)
        .all()
    )
    counts_dict = {role: count for role, count in counts}
    meta = collection.metadata or {}
    if "counts" not in meta:
        meta["counts"] = {}
    meta["counts"].update(counts_dict)
    collection.metadata = meta
    session.commit()


class RandomStratifiedSplit(SplitSelectionStrategy):
    @property
    def name(self) -> str:
        return "random_stratified_split"

    @property
    def description(self) -> str:
        return "Randomly sample up to PER_CHILD_CAP files per child tag, split 80/20 into global train/test sets, ensuring disjointness."

    def __init__(self, parents, per_child_cap, train_ratio, random_seed=None):
        self.parents = parents
        self.per_child_cap = per_child_cap
        self.train_ratio = train_ratio
        self.random_seed = random_seed or 42

    def select_files(self, session: Session):
        random.seed(self.random_seed)
        train_ids, test_ids = [], []
        for parent in self.parents:
            for child in list_all_descendants(session, parent):
                files = files_with_tag_and_embedding(session, child.label)
                if not files:
                    continue
                cap = min(self.per_child_cap, len(files))
                sampled = random.sample(files, cap) if len(files) > cap else list(files)
                random.shuffle(sampled)
                split_idx = int(len(sampled) * self.train_ratio)
                train_ids += [f.id for f in sampled[:split_idx]]
                test_ids += [f.id for f in sampled[split_idx:]]

        # Deduplicate and enforce disjointness
        train_ids = list(dict.fromkeys(train_ids))
        test_ids = [fid for fid in dict.fromkeys(test_ids) if fid not in set(train_ids)]
        return train_ids, test_ids
    
class NoOpNeighborFilter(NeighborFilterStrategy):
    @property
    def name(self):
        return "no_op_filter"
    @property
    def description(self):
        return "Does not filter any neighbors; uses all candidates."
    
    def filter(self, candidate_ids, test_file_hash = None):
        return candidate_ids
    
class MajorityVoteLabeler(LabelingStrategy):
    
    @property
    def name(self):
        return "majority_vote"
    
    @property
    def description(self):
        return "Infers label by majority vote among top-k neighbors."
    
    def infer_label(self, neighbor_tags, neighbor_scores):
        if not neighbor_tags:
            return None
            
        tag_counts = Counter(neighbor_tags)
        most_common_items = tag_counts.most_common(1)
        most_common_tag, _ = most_common_items[0]
        
        return most_common_tag
    


In [3]:
def build_collections_from_strategy(session: Session, strategy: SplitSelectionStrategy):
    """
    Generic builder that uses a SplitSelectionStrategy to populate train/test FileCollections.
    The strategy name and description are stored in metadata.
    """
    logger.info(f"Running split strategy: {strategy.name}")
    train_ids, test_ids = strategy.select_files(session)

    # Use the base class method with all required parameters
    prov = strategy.assemble_provenance_info(
        parents=PARENTS,
        embedding_col=EMBED_COL,
        split_ratio=TRAIN_RATIO,
        per_child_cap=PER_CHILD_CAP,
        random_seed=RANDOM_SEED
    )

    train_coll = create_or_get_collection(session=session,
                                          name=f"{prov.split_strategy}_{prov.embedding_column}_train",
                                          prov=prov)
    test_coll  = create_or_get_collection(session=session,
                                          name=f"{prov.split_strategy}_{prov.embedding_column}_test",
                                          prov=prov)

    # Get File objects
    train_files = session.query(File).filter(File.id.in_(train_ids)).all()
    test_files  = session.query(File).filter(File.id.in_(test_ids)).all()

    add_members(session, train_coll, train_files, role="train")
    add_members(session, test_coll, test_files, role="test")

    logger.info(f"Train: {len(train_files)}, Test: {len(test_files)}")

    return train_coll, test_coll

In [None]:
engine = get_db_engine()
with Session(engine) as session:
    split_strategy = RandomStratifiedSplit(
        parents=PARENTS,
        per_child_cap=PER_CHILD_CAP,
        train_ratio=TRAIN_RATIO,
        random_seed=RANDOM_SEED,
    )
    neighbor_filter = NoOpNeighborFilter()
    labeler = MajorityVoteLabeler()

    train_coll, test_coll = build_collections_from_strategy(session, split_strategy)

    # Extract collection names while session is active
    train_collection_name = train_coll.name
    test_collection_name = test_coll.name

    knn_run = KNNRun(
        k=5,
        name="baseline_random_split",
        description="Baseline KNN evaluation using random stratified split and majority voting.",
        training_collection=train_coll,
        test_collection=test_coll,
    )

    logger.info(f"KNNRun created: {knn_run.name}")
    logger.info(f"Split: {split_strategy.name}, Filter: {neighbor_filter.name}, Labeler: {labeler.name}")
    logger.info(f"Train collection: {train_collection_name}")
    logger.info(f"Test collection: {test_collection_name}")

2025-10-08 22:00:23,633 db.db INFO Creating database engine
2025-10-08 22:00:25,420 knn.evaluation.run INFO Running split strategy: random_stratified_split
2025-10-08 22:00:30,096 knn.evaluation.run INFO Created new FileCollection 'random_stratified_split_minilm_emb_train' (id=137).
2025-10-08 22:00:30,109 knn.evaluation.run INFO Created new FileCollection 'random_stratified_split_minilm_emb_test' (id=138).
2025-10-08 22:00:33,067 knn.evaluation.run INFO Added 2280 train files to random_stratified_split_minilm_emb_train
2025-10-08 22:00:34,131 knn.evaluation.run INFO Added 564 test files to random_stratified_split_minilm_emb_test
2025-10-08 22:00:34,132 knn.evaluation.run INFO Train: 2280, Test: 564


KNNRun created: baseline_random_split
Split: random_stratified_split, Filter: no_op_filter, Labeler: majority_vote


The code from this point down is for getting knn data for the collections. And designed to possibley be run independently of the Cells above

In [5]:
import json
import numpy as np
from sqlalchemy.orm import Session
from datetime import datetime
from db.db import get_db_engine
from db.models import FileCollection, FileCollectionMember, File, FileContent, FileTagLabel
from knn.base import cosine_similarity_batch


def get_embeddings_and_tags(session: Session, collection_name: str, embed_col: str = "minilm_emb"):
    """
    Return dict with:
      - "valid": list of {file_id, file_hash, embedding, tags}
      - "missing_content": list of file_hash with no FileContent
      - "missing_tags": list of file_hash with no FileTagLabel
    """

    # Retrieve all file IDs and hashes in the collection
    collection_files = (
        session.query(File.id, File.file_hash)
        .join(FileCollectionMember, File.id == FileCollectionMember.file_id)
        .join(FileCollection, FileCollectionMember.collection_id == FileCollection.id)
        .filter(FileCollection.name == collection_name)
        .all()
    )

    results = []
    missing_content = []
    missing_tags = []

    for fid, fhash in collection_files:
        # get embedding
        content = session.query(FileContent).filter(FileContent.file_hash == fhash).first()
        if not content or getattr(content, embed_col) is None:
            missing_content.append(fhash)
            continue

        # get tags
        tag_rows = session.query(FileTagLabel.tag).filter(FileTagLabel.file_id == fid).all()
        if not tag_rows:
            missing_tags.append(fhash)
            continue

        tags = {t for (t,) in tag_rows}
        # infer parent tag if possible (e.g., "F5" -> "F")
        for tag in list(tags):
            if len(tag) > 1 and tag[0].isalpha():
                tags.add(tag[0])

        results.append({
            "file_id": fid,
            "file_hash": fhash,
            "embedding": getattr(content, embed_col),
            "tags": sorted(tags)
        })

    return {
        "valid": results,
        "missing_content": missing_content,
        "missing_tags": missing_tags
    }

def cosine_knn_neighbors(train_valid: List[dict], test_valid: List[dict], k: int = 5):
    """
    Given valid train/test embedding-tag sets, compute top-k nearest neighbors for each test file.
    Returns list of neighbor info dicts.
    """
    train_vecs = np.stack([np.array(f["embedding"]) for f in train_valid])
    train_hashes = [f["file_hash"] for f in train_valid]
    train_tags = [f["tags"] for f in train_valid]

    results = []

    for test_file in test_valid:
        test_vec = np.array(test_file["embedding"])
        sims = cosine_similarity_batch(test_vec, train_vecs)
        topk_idx = np.argsort(sims)[-k:][::-1]

        neighbors = [
            {
                "neighbor_hash": train_hashes[i],
                "neighbor_tags": train_tags[i],
                "similarity": float(sims[i]),
            }
            for i in topk_idx
        ]

        results.append({
            "test_file_hash": test_file["file_hash"],
            "true_tags": test_file["tags"],
            "neighbors": neighbors
        })

    return results

def run_knn_evaluation(train_collection, test_collection, k: int = 5, embed_col: str = "minilm_emb", output_path=None):
    """
    Executes KNN neighbor computation for a given train/test collection pair.
    Returns a dict containing:
      - knn_results: list of neighbor data per test file
      - missing_content / missing_tags summaries
    Optionally saves results to JSON if output_path is given.
    """
    engine = get_db_engine()
    with Session(engine) as session:
        train_data = get_embeddings_and_tags(session, train_collection, embed_col)
        test_data  = get_embeddings_and_tags(session, test_collection, embed_col)

    knn_results = cosine_knn_neighbors(train_data["valid"], test_data["valid"], k=k)

    summary = {
        "train_collection": train_collection,
        "test_collection": test_collection,
        "embedding_column": embed_col,
        "k": k,
        "results": knn_results,
        "missing": {
            "train": {
                "content": train_data["missing_content"],
                "tags": train_data["missing_tags"]
            },
            "test": {
                "content": test_data["missing_content"],
                "tags": test_data["missing_tags"]
            }
        }
    }

    if output_path:
        with open(output_path, "w") as f:
            json.dump(summary, f, indent=2)
        print(f"Saved KNN results to {output_path}")

    return summary



In [7]:
# Run KNN evaluation using the collection names we extracted
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results = run_knn_evaluation(
    train_collection=train_collection_name,
    test_collection=test_collection_name,
    k=5,
    embed_col=EMBED_COL,
    output_path=f"knn_results_{timestamp}.json"
)


NameError: name 'train_collection_name' is not defined