In [None]:
# TAB 5: SEMANTIC SEARCH + GRAPH (CLEAN & FIXED)

with tabs[4]:
    # Header emoji repaired: the original literal was mis-decoded UTF-8.
    st.header("🌐 Knowledge Graph Visualization & Semantic Search (Milestone 3)")

    # Graph / rendering dependencies used by this tab.
    import tempfile
    import networkx as nx
    from pyvis.network import Network

    # Each triple is unpacked below as (subject, relation, object).
    triples = st.session_state.get("triples", [])
    if not triples:
        st.warning("No triples available. Upload CSV or run NLP first.")
    else:
        # --- Settings ---
        show_labels = st.checkbox("Show relation labels", value=False)
        top_k = st.number_input("Top-K matches", 1, 20, 5)

        # --- Build Graph ---
        # add_edge creates any missing endpoint itself, so explicit add_node
        # calls are unnecessary.
        # NOTE: nx.Graph holds one undirected edge per node pair; a later
        # triple over the same pair replaces the stored "relation".
        G = nx.Graph()
        for s, p, o in triples:
            G.add_edge(s, o, relation=p)

        # --- Summary ---
        c1, c2, c3 = st.columns(3)
        c1.metric("Nodes", G.number_of_nodes())
        c2.metric("Edges", G.number_of_edges())
        # max(1, ...) avoids dividing by zero on an empty graph.
        avg_deg = round(sum(dict(G.degree()).values()) / max(1, G.number_of_nodes()), 2)
        c3.metric("Avg degree", avg_deg)
        st.markdown("---")

        # --- Semantic Search (cached) ---
        @st.cache_resource
        def load_st_model():
            """Return a SentenceTransformer built from "all-MiniLM-L6-v2".

            The sentence_transformers import is deferred until the first call;
            the function is wrapped in ``st.cache_resource``.
            """
            import sentence_transformers

            encoder = sentence_transformers.SentenceTransformer("all-MiniLM-L6-v2")
            return encoder

        @st.cache_data
        def embed_sentences(sentences):
            """Encode *sentences* with the model from ``load_st_model``.

            Args:
                sentences: Texts to embed; the caller in this tab passes one
                    "subject relation object" string per triple.

            Returns:
                The result of ``model.encode(sentences, convert_to_numpy=True)``
                (presumably one embedding row per sentence -- confirm against
                the sentence-transformers documentation).
            """
            model = load_st_model()
            return model.encode(sentences, convert_to_numpy=True)

        # One sentence per triple, in the same order, so embedding row i
        # corresponds to triples[i].
        sentences = [f"{s} {p} {o}" for s, p, o in triples]
        # Label emoji repaired: the original literal was mis-decoded UTF-8.
        query = st.text_input("🔍 Semantic Search (e.g., 'Einstein physics')")

        # Populated only by a successful search; the graph rendering further
        # down reads these to mark matching nodes and edges.
        highlight_nodes, highlight_edges = set(), set()
        if query.strip():
            try:
                import numpy as np
                from sklearn.metrics.pairwise import cosine_similarity

                emb = embed_sentences(sentences)
                q = load_st_model().encode([query], convert_to_numpy=True)
                # Similarity of the single query row against every triple.
                scores = cosine_similarity(q, emb)[0]
                # Indices of the top_k highest-scoring triples, best first.
                idxs = np.argsort(scores)[::-1][:top_k]

                results = []
                for i in idxs:
                    s, p, o = triples[i]
                    results.append({"Subject": s, "Relation": p, "Object": o, "Score": float(scores[i])})
                    highlight_nodes.update([s, o])
                    highlight_edges.add((s, o, p))
                st.subheader("Results")
                st.dataframe(pd.DataFrame(results), use_container_width=True)
            except Exception as e:
                # UI boundary: show the failure instead of crashing the tab.
                st.error(f"Semantic search failed: {e}")

        # --- Graph Display ---
        H = G.copy()
        deg = dict(H.degree())
        # `or 1` keeps the size ratio defined even if every degree were 0.
        max_deg = max(deg.values(), default=0) or 1

        def size_for(n):
            """Return the display size for node *n*: 16 plus up to 22 scaled by degree / max degree."""
            return 16 + int(22 * (deg.get(n, 0) / max_deg))

        net = Network(height="680px", width="100%", bgcolor="#0e1117", font_color="white")
        net.barnes_hut(gravity=-40000, central_gravity=0.25, spring_length=260, spring_strength=0.02, damping=0.5)

        # Nodes touched by a search hit are drawn in the accent colour.
        for n in H.nodes():
            color = "#FF6B6B" if n in highlight_nodes else "#7FB3FF"
            net.add_node(n, label=n, title=n, color=color, size=size_for(n))

        # G is an undirected nx.Graph, so each node pair has a single edge whose
        # stored relation comes from the last triple added over that pair.
        # Matching hits on (u, v, relation) therefore misses a matched triple
        # whose relation was overwritten; match on the unordered endpoint pair
        # instead so the edge that represents the hit is always emphasised.
        hit_pairs = {frozenset((s, o)) for s, o, _ in highlight_edges}

        for u, v, data in H.edges(data=True):
            rel = data.get("relation", "")
            edge_kwargs = {"title": rel}
            if show_labels:
                edge_kwargs["label"] = rel
            if frozenset((u, v)) in hit_pairs:
                edge_kwargs["width"] = 3
            net.add_edge(u, v, **edge_kwargs)

        # Try to set options (safe for mixed pyvis versions)
        try:
            net.set_options("""
{
  "nodes": {"borderWidth": 2, "shape": "dot", "font": {"size": 16}},
  "edges": {
    "font": {"size": 12, "background": "rgba(14,17,23,0.85)"},
    "color": {"inherit": false, "opacity": 0.8},
    "smooth": {"type": "continuous"}
  },
  "physics": {
    "barnesHut": {"gravitationalConstant": -40000, "springLength": 260},
    "stabilization": {"enabled": true, "iterations": 250}
  },
  "interaction": {"hover": true, "tooltipDelay": 120, "zoomView": true, "dragView": true}
}
""")
        except Exception:
            pass  # fall back to defaults if pyvis version differs

        import os

        # pyvis writes the page to a path, so round-trip through a temp file.
        # Close our own handle first: it was never closed before (leaking a
        # descriptor per rerun), and an open handle can block reopening or
        # deleting the file by name on some platforms.
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".html")
        tmp.close()
        try:
            net.write_html(tmp.name)
            with open(tmp.name, "r", encoding="utf-8") as f:
                html = f.read()
        finally:
            # Remove the temp file even if writing or reading failed.
            os.unlink(tmp.name)
        st.components.v1.html(html, height=720, scrolling=True)

        # Show the first 25 triples as a table.
        st.markdown("**Triple Preview**")
        st.dataframe(pd.DataFrame(triples[:25], columns=["Subject","Relation","Object"]))


   