In [None]:
from chromadb import PersistentClient
import numpy as np
from sklearn.manifold import TSNE
import plotly.graph_objects as go
import random


vectorstore_path = "vector_db/sustainability_index"

DB_NAME = r"vector_db\sustainability_index"

# Use the exact collection name from the [SUCCESS] log
COLLECTION_NAME = "esg_audit_demo"


In [None]:
## VISUALIZE THE VECTORSTORE COZ WHY NOT ##

# ---- CONFIG (important for 512MB RAM) ----
MAX_POINTS = 800   # hard cap, safe for t-SNE
TSNE_PERPLEXITY = 30

# ---- LOAD VECTORSTORE ----
chroma = PersistentClient(path=DB_NAME)
collection = chroma.get_or_create_collection(COLLECTION_NAME)

result = collection.get(include=["embeddings", "documents", "metadatas"])

embeddings = result["embeddings"]
documents = result["documents"]
metadatas = result["metadatas"]

total_points = len(embeddings)
print(f"Total vectors in DB: {total_points}")

# ---- SAMPLE (VERY IMPORTANT) ----
if total_points > MAX_POINTS:
    idx = random.sample(range(total_points), MAX_POINTS)
    vectors = np.array([embeddings[i] for i in idx])
    documents = [documents[i] for i in idx]
    metadatas = [metadatas[i] for i in idx]
else:
    vectors = np.array(embeddings)

print(f"Visualizing {len(vectors)} vectors")

# ---- COLOR MAPPING (normalize source_type) ----
def normalize_type(t):
    if t is None:
        return "unknown"
    if "table" in t:
        return "tabular"
    if "narrative" in t:
        return "narrative"
    return "other"

color_map = {
    "tabular": "red",
    "narrative": "blue",
    "policy": "green",
    "other": "gray",
    "unknown": "gray",
}

doc_types = [normalize_type(m.get("source_type")) for m in metadatas]
colors = [color_map[t] for t in doc_types]

# ---- t-SNE (safe params) ----
tsne = TSNE(
    n_components=2,
    perplexity=min(TSNE_PERPLEXITY, len(vectors) - 1),
    learning_rate=200,
    n_iter=1000,
    random_state=42,
)

reduced_vectors = tsne.fit_transform(vectors)

# ---- PLOT ----
fig = go.Figure(
    data=[
        go.Scatter(
            x=reduced_vectors[:, 0],
            y=reduced_vectors[:, 1],
            mode="markers",
            marker=dict(size=6, color=colors, opacity=0.75),
            text=[
                f"""
                <b>Type:</b> {m.get('source_type')}<br>
                <b>Pillar:</b> {m.get('audit_pillar')}<br>
                <b>Topic:</b> {m.get('primary_topic')}<br>
                <b>Page:</b> {m.get('page_number')}
                """
                for m in metadatas
            ],
            hoverinfo="text",
        )
    ]
)

fig.update_layout(
    title="2D Chroma Vector Store Visualization (t-SNE)",
    width=850,
    height=600,
    margin=dict(r=20, b=20, l=20, t=50),
)

fig.show()


In [2]:
from chromadb import PersistentClient
import numpy as np
from collections import Counter

chroma = PersistentClient(path=DB_NAME)
collection = chroma.get_collection(COLLECTION_NAME)

count = collection.count()
print(f"[OK] Vectorstore contains {count} records")

assert count > 0, "Vectorstore is empty"


[OK] Vectorstore contains 1477 records


In [3]:
SAMPLE_SIZE = min(20, count)

result = collection.get(
    limit=SAMPLE_SIZE,
    include=["embeddings", "documents", "metadatas"]
)

vectors = np.array(result["embeddings"])
documents = result["documents"]
metadatas = result["metadatas"]


In [4]:
print("\n--- BASIC SHAPE CHECKS ---")
print("Embeddings shape:", vectors.shape)
print("Documents count:", len(documents))
print("Metadatas count:", len(metadatas))

assert len(documents) == len(metadatas) == vectors.shape[0], \
    "Mismatch between embeddings, documents, and metadata"



--- BASIC SHAPE CHECKS ---
Embeddings shape: (20, 1536)
Documents count: 20
Metadatas count: 20


In [5]:
REQUIRED_METADATA_KEYS = {
    "source_type",   # tabular | narrative | policy
    "page",          # int
}

print("\n--- METADATA VALIDATION ---")

missing_metadata = []
missing_keys = []

for i, m in enumerate(metadatas):
    if not isinstance(m, dict):
        missing_metadata.append(i)
        continue

    for k in REQUIRED_METADATA_KEYS:
        if k not in m:
            missing_keys.append((i, k))

print("Chunks with non-dict metadata:", missing_metadata)
print("Chunks missing required keys:", missing_keys)

assert not missing_metadata, "Some chunks have invalid metadata objects"



--- METADATA VALIDATION ---
Chunks with non-dict metadata: []
Chunks missing required keys: [(0, 'page'), (1, 'page'), (2, 'page'), (3, 'page'), (4, 'page'), (5, 'page'), (6, 'page'), (7, 'page'), (8, 'page'), (9, 'page'), (10, 'page'), (11, 'page'), (12, 'page'), (13, 'page'), (14, 'page'), (15, 'page'), (16, 'page'), (17, 'page'), (18, 'page'), (19, 'page')]


In [6]:
print("\n--- METADATA SAMPLE ---")
for i, m in enumerate(metadatas[:5]):
    print(f"\nChunk {i}:")
    for k, v in m.items():
        print(f"  {k}: {v}")



--- METADATA SAMPLE ---

Chunk 0:
  anchor_context: General Reference
  audit_pillar: General
  primary_topic: Narrative
  page_number: 0
  bbox_top: 0.0
  search_keywords: 
  table_id: 0
  bbox_x1: 595.276
  bbox_x0: 0.0
  source_type: narrative_text
  bbox_bottom: 841.89
  chunk_index: 0

Chunk 1:
  page_number: 1
  chunk_index: 0
  bbox_x1: 595.276
  bbox_bottom: 841.89
  table_id: 0
  audit_pillar: General
  anchor_context: General Reference
  source_type: narrative_text
  search_keywords: 
  primary_topic: Narrative
  bbox_top: 0.0
  bbox_x0: 0.0

Chunk 2:
  search_keywords: 
  primary_topic: Narrative
  chunk_index: 0
  bbox_x1: 1190.55
  audit_pillar: General
  bbox_bottom: 841.89
  bbox_top: 0.0
  page_number: 2
  table_id: 0
  bbox_x0: 0.0
  source_type: narrative_text
  anchor_context: General Reference

Chunk 3:
  anchor_context: General Reference
  bbox_x1: 1190.55
  bbox_top: 0.0
  source_type: narrative_text
  bbox_bottom: 841.89
  audit_pillar: General
  table_id: 0
  s

In [7]:
print("\n--- SOURCE TYPE DISTRIBUTION ---")

source_types = [
    m.get("source_type", "UNKNOWN") for m in metadatas
]

counter = Counter(source_types)
for k, v in counter.items():
    print(f"{k}: {v}")



--- SOURCE TYPE DISTRIBUTION ---
narrative_text: 20


In [None]:
print("\n--- FILTER TEST (source_type = table_consolidated) ---")

tabular = collection.get(
    where={"source_type": "table_consolidated"},
    limit=5,
    include=["documents", "metadatas"]
)

print("Matched:", len(tabular["documents"]))
for i, m in enumerate(tabular["metadatas"]):
    print(f"Match {i} metadata:", m)



--- FILTER TEST (source_type = tabular) ---
Matched: 5
Match 0 metadata: {'audit_pillar': 'Governance', 'bbox_bottom': 153.41708055555557, 'bbox_x0': 49.88933333333333, 'source_type': 'table_consolidated', 'search_keywords': 'gender diversity, workforce gender split, male female ratios, employee gender statistics, diversity metrics, workforce composition, gender representation', 'primary_topic': 'Diversity', 'bbox_top': 112.32065882352944, 'table_id': 1, 'anchor_context': 'General Reference', 'bbox_x1': 704.001, 'chunk_index': 0, 'page_number': 8}
Match 1 metadata: {'table_id': 1, 'audit_pillar': 'Social', 'anchor_context': 'General Reference', 'chunk_index': 0, 'bbox_x1': 704.001, 'search_keywords': 'women, gender diversity, female representation, workforce diversity, employee demographics, gender ratio, women statistics', 'bbox_x0': 49.890000000000015, 'bbox_top': 66.40263750000001, 'source_type': 'table_consolidated', 'page_number': 9, 'primary_topic': 'Diversity', 'bbox_bottom': 1