# Dimensionality Reduction of Top-Level Filing Tags (Excluding 'A')
This notebook:

1. Connects to the project database (MiniLM embeddings).
2. Pulls embeddings for files whose assigned tag is a *top-level* filing tag (no parent) excluding the 'A' tag.
3. Balances samples per tag (up to `MAX_PER_TAG`).
4. Applies L2 normalization, PCA (pre-dimension reduction), then UMAP to 2D.
5. Produces an interactive Plotly scatter with distinct colors per tag.

If database credentials are missing or the query returns no rows, the notebook will print a clear message and skip plotting.


In [11]:
# cell 1: imports & setup
import os
import random
import numpy as np
import pandas as pd

from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize
import umap.umap_ as umap
import plotly.express as px

from sqlalchemy.orm import sessionmaker

# Your existing db models
from db.models import FileEmbedding, FileTagLabel, FilingTag, File
from db.db import get_db_engine

# ---- Config ----
MODEL_NAME = "all-MiniLM-L6-v2"
MAX_PER_TAG = 300
RANDOM_SEED = 42
PCA_PRECOMP = 50

UMAP_PARAMS = dict(
    n_neighbors=15,
    min_dist=0.1,
    metric='cosine'
)

engine = get_db_engine()
Session = sessionmaker(bind=engine)
session = Session()


2025-09-11 19:24:18,775 db.db INFO Creating database engine


In [12]:
def fetch_embeddings_tags(engine_model_name, split=None, parent_tag=None, max_per_tag=MAX_PER_TAG, top_level_exclude_A: bool = False):
    """
    Fetch MiniLM embeddings and associated tags / filenames from the database.

    Returns a DataFrame with columns: file_id, tag, filename, embedding (np.array).

    Filtering modes (mutually aware):
        top_level_exclude_A=True  -> only FilingTag rows where parent_label IS NULL AND tag != 'A'.
        parent_tag="X"            -> restrict to descendants whose parent_label == parent_tag (ignored if top_level_exclude_A True).
        split="train|val|test"    -> restrict by FileTagLabel.split.

    Notes:
        * Correct joins use file hash (File.hash) because FileEmbedding & FileTagLabel are keyed on file_hash.
        * File itself has no filename column; we take the (first) FileLocation.filename via LEFT OUTER JOIN.
        * Sampling is done client‑side per tag after full retrieval; for very large tables consider a SQL sampling strategy.
    """
    # Base query assembling required pieces.
    # We pick the first filename by using MIN(FileLocation.filename) aggregation in a subquery to avoid duplicate rows per location.
    from sqlalchemy import func
    from sqlalchemy.orm import aliased

    # Subquery to get a representative filename per file_id (min over filenames)
    from db.models import FileEmbedding as FE, File as F, FileTagLabel as FTL, FilingTag as FT, FileLocation as FL

    filename_subq = (
        session.query(FL.file_id.label('fl_file_id'), func.min(FL.filename).label('rep_filename'))
        .group_by(FL.file_id)
        .subquery()
    )

    q = (
        session.query(
            F.id.label('file_id'),
            FE.minilm_emb.label('embedding'),
            FTL.tag.label('tag'),
            filename_subq.c.rep_filename.label('filename')
        )
        .join(FE, FE.file_hash == F.hash)          # FileEmbedding -> File via hash
        .join(FTL, F.hash == FTL.file_hash)        # Tag labels via hash
        .join(FT, FT.label == FTL.tag)             # FilingTag for hierarchy filters
        .outerjoin(filename_subq, filename_subq.c.fl_file_id == F.id)  # representative filename
        .filter(FE.minilm_model == engine_model_name)
    )

    if top_level_exclude_A:
        q = q.filter(FT.parent_label.is_(None)).filter(FTL.tag != 'A')
    elif parent_tag is not None:
        # Children of specific parent
        q = q.filter(FT.parent_label == parent_tag)

    if split is not None:
        q = q.filter(FTL.split == split)

    rows = q.all()
    if not rows:
        return pd.DataFrame([])

    # Group and sample per tag
    tag2rows = {}
    for file_id, emb_vec, tag, filename in rows:
        if emb_vec is None:
            continue
        # pgvector returns a list-like already; ensure numpy array float
        emb_np = np.array(emb_vec, dtype=float)
        tag2rows.setdefault(tag, []).append((file_id, emb_np, filename or ""))

    random.seed(RANDOM_SEED)
    records = []
    for tag, items in tag2rows.items():
        if not items:
            continue
        if len(items) > max_per_tag:
            items = random.sample(items, max_per_tag)
        for fid, emb_arr, fname in items:
            records.append({
                'file_id': fid,
                'tag': tag,
                'filename': fname,
                'embedding': emb_arr
            })

    return pd.DataFrame(records)


def reduce_dims(df, embedding_col='embedding'):
    """
    Given df with an embedding column (np.array), do normalization, PCA, UMAP.
    The necessary dataframe columns are: 'file_id', 'tag', 'filename', embedding_col.

    Returns df augmented by 'umap_x','umap_y' (and optionally PCA columns).
    """
    if df.empty:
        return df
    X = np.vstack(df[embedding_col].values)
    X_norm = normalize(X, norm='l2')

    pca = PCA(n_components=min(PCA_PRECOMP, X_norm.shape[1]), random_state=RANDOM_SEED)
    X_pca = pca.fit_transform(X_norm)
    print("Explained variance ratio (first few components):", pca.explained_variance_ratio_[:10])

    reducer = umap.UMAP(**UMAP_PARAMS, random_state=RANDOM_SEED)
    X_umap = reducer.fit_transform(X_pca)

    df2 = df.copy()
    for i in range(min(5, X_pca.shape[1])):
        df2[f'pca_{i}'] = X_pca[:, i]
    df2['umap_x'] = X_umap[:, 0]
    df2['umap_y'] = X_umap[:, 1]
    return df2


In [13]:
df_raw = fetch_embeddings_tags(MODEL_NAME, split='train', parent_tag=None, max_per_tag=MAX_PER_TAG, top_level_exclude_A=True)
if df_raw.empty:
    print("No rows returned for the selected criteria (top-level tags excluding 'A').")
else:
    print("Number of points fetched:", len(df_raw))
    print("Tag counts:\n", df_raw['tag'].value_counts())

df_vis = reduce_dims(df_raw)

Number of points fetched: 2100
Tag counts:
 tag
B    300
G    300
H    300
E    300
D    300
F    300
C    300
Name: count, dtype: int64
Explained variance ratio (first few components): [0.08548264 0.04174839 0.03548706 0.03101033 0.02861427 0.02512823
 0.0223553  0.01926928 0.01753276 0.01604993]
Explained variance ratio (first few components): [0.08548264 0.04174839 0.03548706 0.03101033 0.02861427 0.02512823
 0.0223553  0.01926928 0.01753276 0.01604993]



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



# Top-Level Tag Embedding Visualization (Excluding 'A')
This notebook:

1. Loads MiniLM embeddings for files whose assigned tag is a *top-level* filing tag (no parent) except 'A'.
2. Samples up to `MAX_PER_TAG` items per tag for balance.
3. L2 normalizes embeddings, does PCA (pre-reduction) then UMAP to 2D.
4. Plots an interactive scatter with distinct colors per tag.

If no rows are returned (e.g., env vars / DB not set or tags missing), the notebook will report and skip plotting.

Adjustable knobs near the top: `MAX_PER_TAG`, `PCA_PRECOMP`, and `UMAP_PARAMS`.


In [14]:
if df_vis.empty:
    print("No data available for visualization. Check database connection, model name, or tag filters.")
else:
    # Ensure deterministic color mapping (sorted tags)
    unique_tags = sorted(df_vis['tag'].unique())
    color_discrete_sequence = px.colors.qualitative.Alphabet
    # If more tags than palette length, cycle
    palette = [color_discrete_sequence[i % len(color_discrete_sequence)] for i in range(len(unique_tags))]
    tag_color_map = {t: palette[i] for i, t in enumerate(unique_tags)}

    fig = px.scatter(
        df_vis,
        x='umap_x',
        y='umap_y',
        color='tag',
        hover_data=['file_id', 'filename', 'tag'],
        title=f"UMAP of MiniLM embeddings (model={MODEL_NAME}) – Top-Level Tags (excl. 'A')",
        width=1000,
        height=800,
        opacity=0.75,
        color_discrete_map=tag_color_map,
    )
    fig.update_layout(legend_title_text="Tag", legend=dict(itemsizing='trace', bordercolor='rgba(0,0,0,0.2)', borderwidth=1))
    fig.update_traces(marker=dict(size=7, line=dict(width=0)))
    fig.show()
    print("Visualization complete.")

Visualization complete.
