### Weaviate Graph Visualization (yFiles for Jupyter)

Run cells in order:
- Step 1: Configure Weaviate connection
- Step 2: Load collections and build directed graph
- Step 3: Render graph with yFiles

Notes:
- Uses GraphQL to fetch objects and references from Weaviate
- Creates a directed graph (DiGraph) to show relationship directions
- Automatically assigns colors to different node classes

In [1]:
# Step 1 — Weaviate connection configuration
import os
import json
import requests
import networkx as nx
from typing import Dict, Any, List, Tuple

WEAVIATE_URL = os.environ.get("WEAVIATE_URL", "http://localhost:8090")
WEAVIATE_API_KEY = os.environ.get("WEAVIATE_API_KEY", "")

# Data loading controls
BATCH_SIZE = 1000  # Number of objects to fetch per batch (for pagination)
INCLUDE_PROPERTIES: List[str] = []  # empty => include all
EXCLUDE_PROPERTIES: List[str] = ["_additional", "id"]

session = requests.Session()
if WEAVIATE_API_KEY:
    session.headers.update({"Authorization": f"Bearer {WEAVIATE_API_KEY}"})
session.headers.update({"Content-Type": "application/json"})

print(f"Using Weaviate at: {WEAVIATE_URL}")

Using Weaviate at: http://localhost:8090


In [2]:
# Step 2 — Fetch objects and references from Weaviate using v4 Python API (with vector support)
from typing import Dict, Any, List, Tuple, Optional
import weaviate
from weaviate.classes.query import QueryReference
from pathlib import Path
from dotenv import load_dotenv

# Load environment from backend/.env
env_path = Path.cwd().parent / '.env'
load_dotenv(dotenv_path=env_path)

TENANT = os.environ.get("WEAVIATE_TENANT", "")
print("Tenant:", TENANT or "<none>")


def pick_properties(props: Dict[str, Any]) -> Dict[str, Any]:
    if not props:
        return {}
    if INCLUDE_PROPERTIES:
        picked = {k: v for k, v in props.items() if k in INCLUDE_PROPERTIES}
    else:
        picked = {k: v for k, v in props.items() if k not in EXCLUDE_PROPERTIES}
    return picked


def build_graph_from_weaviate_v4(batch_size: int = 1000) -> Tuple[nx.DiGraph, Dict[str, Dict[str, Any]]]:
    """Build a directed graph from Weaviate using v4 Python API with vector support"""
    
    # Connect using v4 API
    client = weaviate.connect_to_local(port=8090, grpc_port=50051)
    
    try:
        graph = nx.DiGraph()
        node_props: Dict[str, Dict[str, Any]] = {}
        
        # Get schema to discover reference properties
        schema_resp = client.get_meta()
        schema = schema_resp if schema_resp else {}
        
        # Get all collections
        collections = client.collections.list_all()
        
        for class_name in collections.keys():
            print(f"Fetching all objects from class '{class_name}'...")
            collection = client.collections.get(class_name)
            
            # Get collection config to find reference properties
            try:
                config = collection.config.get()
                ref_props = []
                if config.references:
                    ref_props = [ref.name for ref in config.references]
                    print(f"  Reference properties: {ref_props}")
            except Exception as e:
                print(f"  Could not get config: {e}")
                ref_props = []
            
            # Fetch all objects with pagination
            offset = 0
            total_count = 0
            
            while True:
                # Build return_references if we have ref properties
                return_refs = None
                if ref_props:
                    # Return all reference properties
                    return_refs = [QueryReference(link_on=ref_name) for ref_name in ref_props]
                
                results = collection.query.fetch_objects(
                    limit=batch_size,
                    offset=offset,
                    include_vector=True,  # ✅ Include vector data
                    return_references=return_refs  # ✅ Include references
                )
                
                if not results.objects:
                    break
                
                for obj in results.objects:
                    uid = str(obj.uuid)
                    props = dict(obj.properties)
                    
                    # ✅ Add vector preview if vector exists
                    if obj.vector and obj.vector.get('default'):
                        vec = obj.vector['default']
                        preview = "[" + ", ".join(f"{x:.3f}" for x in vec[:5]) + ", ...]"
                        props["vector_preview"] = preview
                        props["vector_dimension"] = len(vec)
                    
                    props["class"] = class_name
                    
                    # ✅ Handle references properly
                    if obj.references:
                        for ref_name, ref_objs in obj.references.items():
                            if ref_objs and ref_objs.objects:
                                # Convert to format compatible with visualization
                                props[ref_name] = [{"_additional": {"id": str(r.uuid)}} for r in ref_objs.objects]
                    
                    sel = pick_properties(props)
                    sel["class"] = class_name
                    node_props[uid] = sel
                    graph.add_node(uid, **sel)
                    total_count += 1
                
                # Check if we've fetched all objects
                if len(results.objects) < batch_size:
                    break
                
                offset += batch_size
            
            print(f"  Fetched {total_count} objects from '{class_name}'")
        
        # Create edges based on references
        edge_count = 0
        for uid, props in node_props.items():
            for key, value in list(props.items()):
                if isinstance(value, list) and value and isinstance(value[0], dict) and "_additional" in value[0]:
                    for ref in value:
                        target_id = ref.get("_additional", {}).get("id")
                        if target_id:
                            graph.add_edge(uid, target_id, relation=key)
                            edge_count += 1
        
        print(f"Created {edge_count} edges from references")
        return graph, node_props
    
    finally:
        client.close()


try:
    G_weaviate, node_properties = build_graph_from_weaviate_v4(batch_size=BATCH_SIZE)
    if G_weaviate.number_of_nodes() == 0:
        raise RuntimeError("Weaviate returned empty graph")
    
    # Check how many nodes have vectors
    nodes_with_vectors = sum(1 for uid, props in node_properties.items() if 'vector_preview' in props)
    print(f"Nodes with vectors: {nodes_with_vectors}/{len(node_properties)}")
    
except Exception as e:
    print(f"Falling back to mock data: {e}")
    import traceback
    traceback.print_exc()
    from mock_data import build_graph_from_mock
    G_weaviate, node_properties = build_graph_from_mock(
        num_persons=16, num_companies=8, num_articles=24
    )

print(f"Loaded: {G_weaviate.number_of_nodes()} nodes, {G_weaviate.number_of_edges()} edges")

Tenant: <none>
Fetching all objects from class 'Sdbody'...
  Reference properties: ['infile', 'composes', 'follows']
  Fetched 40 objects from 'Sdbody'
Fetching all objects from class 'Sddir'...
  Fetched 1 objects from 'Sddir'
Fetching all objects from class 'Sdfile'...
  Reference properties: ['indir']
  Fetched 20 objects from 'Sdfile'
Created 124 edges from references
Nodes with vectors: 61/61
Loaded: 61 nodes, 67 edges


In [4]:
# Step 3 — Render the graph with yFiles
from yfiles_jupyter_graphs import GraphWidget
import colorsys
import random

# Node label preference (customize as needed)
LABEL_PRIORITY = ["name", "title", "label", "class", "id"]
DEFAULT_COLOR = "#A8DADC"
MAX_NODES_PER_CLASS = 50  # Maximum number of nodes to display per class


def generate_distinct_colors(n: int) -> list:
    """Generate n visually distinct colors using HSV color space"""
    colors = []
    for i in range(n):
        hue = i / n
        saturation = 0.65 + (i % 3) * 0.1  # 0.65-0.85
        value = 0.85 + (i % 2) * 0.1       # 0.85-0.95
        
        rgb = colorsys.hsv_to_rgb(hue, saturation, value)
        hex_color = '#{:02x}{:02x}{:02x}'.format(
            int(rgb[0] * 255),
            int(rgb[1] * 255),
            int(rgb[2] * 255)
        )
        colors.append(hex_color)
    return colors


def auto_assign_class_colors(graph: nx.DiGraph) -> dict:
    """Automatically assign colors to all classes in the graph"""
    class_names = set()
    for node_id, node_data in graph.nodes(data=True):
        class_name = node_data.get("class")
        if class_name:
            class_names.add(class_name)
    
    class_names = sorted(class_names)  # Sort for consistency
    colors = generate_distinct_colors(len(class_names))
    return dict(zip(class_names, colors))


def choose_label(props: Dict[str, Any]) -> str:
    """Choose the best label from node properties"""
    for key in LABEL_PRIORITY:
        if key in props and props[key] not in (None, ""):
            return str(props[key])
    items = [f"{k}:{v}" for k, v in list(props.items())[:3]]
    return ", ".join(items) if items else "node"


def limit_nodes_by_class(graph: nx.DiGraph, max_per_class: int = 200) -> nx.DiGraph:
    """
    Create a subgraph with at most max_per_class nodes per class type.
    Includes all edges connected to the selected nodes (including edges to nodes not in selection).
    
    :param graph: Original graph with all nodes
    :param max_per_class: Maximum number of nodes to keep per class
    :return: Subgraph with limited nodes but all relevant edges
    """
    # Group nodes by class
    nodes_by_class = {}
    for node_id, node_data in graph.nodes(data=True):
        class_name = node_data.get("class", "Unknown")
        if class_name not in nodes_by_class:
            nodes_by_class[class_name] = []
        nodes_by_class[class_name].append(node_id)
    
    # Select up to max_per_class nodes per class
    selected_nodes = set()
    selection_stats = {}
    
    for class_name, node_list in nodes_by_class.items():
        if len(node_list) <= max_per_class:
            # Keep all nodes of this class
            selected_nodes.update(node_list)
            selection_stats[class_name] = len(node_list)
        else:
            # Randomly select max_per_class nodes
            selected = random.sample(node_list, max_per_class)
            selected_nodes.update(selected)
            selection_stats[class_name] = max_per_class
            print(f"  Limited {class_name}: {len(node_list)} -> {max_per_class} nodes")
    
    # Find all edges connected to selected nodes
    # Include edges where either source or target is in selected_nodes
    selected_edges = []
    connected_nodes = set(selected_nodes)  # Start with selected nodes
    
    for u, v, edata in graph.edges(data=True):
        if u in selected_nodes or v in selected_nodes:
            selected_edges.append((u, v, edata))
            # Include both endpoints of the edge
            connected_nodes.add(u)
            connected_nodes.add(v)
    
    # Create subgraph with selected nodes and all connected nodes
    print(f"\nSelection summary:")
    for class_name, count in sorted(selection_stats.items()):
        print(f"  {class_name}: {count} core nodes")
    print(f"  Total core nodes: {len(selected_nodes)}")
    print(f"  Additional connected nodes: {len(connected_nodes) - len(selected_nodes)}")
    print(f"  Total nodes in subgraph: {len(connected_nodes)}")
    print(f"  Edges in subgraph: {len(selected_edges)}")
    
    # Build subgraph
    subgraph = nx.DiGraph()
    
    # Add all connected nodes (selected + neighbors)
    for node_id in connected_nodes:
        node_data = graph.nodes[node_id].copy()
        subgraph.add_node(node_id, **node_data)
    
    # Add all edges connected to selected nodes
    for u, v, edata in selected_edges:
        subgraph.add_edge(u, v, **edata)
    
    return subgraph


# Auto-assign colors to all classes
print("Analyzing graph structure...")
CLASS_COLORS = auto_assign_class_colors(G_weaviate)
print(f"Found {len(CLASS_COLORS)} distinct classes")

# Limit nodes per class for performance
print(f"\nLimiting to {MAX_NODES_PER_CLASS} nodes per class (with all connected edges)...")
G_limited = limit_nodes_by_class(G_weaviate, max_per_class=MAX_NODES_PER_CLASS)

# Build render graph with limited nodes but all visualization properties
print("\nBuilding render graph...")
G_render = nx.DiGraph()  # Use directed graph to show arrows

# Copy limited nodes: keep all Weaviate data + add visualization properties
for n, pdata in G_limited.nodes(data=True):
    node_label = choose_label(pdata)
    class_name = pdata.get("class", "")
    color = CLASS_COLORS.get(class_name, DEFAULT_COLOR)
    
    # Keep all original Weaviate data + add visualization properties
    node_attrs = {
        **pdata,              # All original Weaviate data
        "label": node_label,  # Add label for display
        "color": color        # Add color for visualization
    }
    G_render.add_node(n, **node_attrs)

# Copy all edges with attributes
for u, v, edata in G_limited.edges(data=True):
    relation = edata.get("relation", "")
    edge_attrs = {
        **edata,              # All original edge data
        "label": relation     # Add label for display
    }
    G_render.add_edge(u, v, **edge_attrs)

print(f"Graph built: {G_render.number_of_nodes()} nodes, {G_render.number_of_edges()} directed edges\n")

# Create GraphWidget
w = GraphWidget(graph=G_render)
w.node_color_mapping = 'color'
w.node_label_mapping = 'label'
w.edge_label_mapping = 'label'

# Display color legend
print("="*60)
print("Color Legend:")
print("="*60)
class_count = {}
for node_id, node_data in G_render.nodes(data=True):
    cls = node_data.get("class", "Unknown")
    class_count[cls] = class_count.get(cls, 0) + 1

for class_name in sorted(CLASS_COLORS.keys()):
    color = CLASS_COLORS[class_name]
    count = class_count.get(class_name, 0)
    print(f"  {class_name:20s}: {color}  ({count} nodes)")

# Display edge statistics
print("\n" + "="*60)
print("Edge Relationship Types:")
print("="*60)
relation_count = {}
for u, v, edata in G_render.edges(data=True):
    rel = edata.get("relation", "unknown")
    relation_count[rel] = relation_count.get(rel, 0) + 1

for rel, count in sorted(relation_count.items()):
    print(f"  {rel:20s}: {count} edges")

print("="*60)

# Display graph
w

Analyzing graph structure...
Found 3 distinct classes

Limiting to 50 nodes per class (with all connected edges)...

Selection summary:
  Sdbody: 40 core nodes
  Sddir: 1 core nodes
  Sdfile: 20 core nodes
  Total core nodes: 61
  Additional connected nodes: 0
  Total nodes in subgraph: 61
  Edges in subgraph: 67

Building render graph...
Graph built: 61 nodes, 67 directed edges

Color Legend:
  Sdbody              : #d84b4b  (40 nodes)
  Sddir               : #3cf23c  (1 nodes)
  Sdfile              : #2020d8  (20 nodes)

Edge Relationship Types:
  composes            : 7 edges
  follows             : 17 edges
  indir               : 19 edges
  infile              : 24 edges


GraphWidget(layout=Layout(height='800px', width='100%'))