diff --git a/apps/roam/src/components/DiscourseContextOverlay.tsx b/apps/roam/src/components/DiscourseContextOverlay.tsx index a9b9284e5..eb34ff73f 100644 --- a/apps/roam/src/components/DiscourseContextOverlay.tsx +++ b/apps/roam/src/components/DiscourseContextOverlay.tsx @@ -5,6 +5,7 @@ import { Position, Tooltip, ControlGroup, + Spinner, } from "@blueprintjs/core"; import React, { useCallback, useEffect, useMemo, useState } from "react"; import ReactDOM from "react-dom"; @@ -27,6 +28,7 @@ import getAllPageNames from "roamjs-components/queries/getAllPageNames"; import { Result } from "roamjs-components/types/query-builder"; import createBlock from "roamjs-components/writes/createBlock"; import { getBlockUidFromTarget } from "roamjs-components/dom"; +import { SuggestedNode, RelationDetails } from "~/utils/hyde"; type DiscourseData = { results: Awaited>; @@ -102,9 +104,14 @@ const DiscourseContextOverlay = ({ const [results, setResults] = useState([]); const [refs, setRefs] = useState(0); const [score, setScore] = useState(0); + const [isSearchingHyde, setIsSearchingHyde] = useState(false); + const [hydeFilteredNodes, setHydeFilteredNodes] = useState( + [], + ); const discourseNode = useMemo(() => findDiscourseNode(tagUid), [tagUid]); const relations = useMemo(() => getDiscourseRelations(), []); + const allNodes = useMemo(() => getDiscourseNodes(), []); const getInfo = useCallback( () => @@ -138,13 +145,58 @@ const DiscourseContextOverlay = ({ getInfo(); }, [refresh, getInfo]); - // Suggestive Mode - const validTypes = useMemo(() => { + const validRelations = useMemo(() => { if (!discourseNode) return []; const selfType = discourseNode.type; - const validRelations = relations.filter((relation) => - [relation.source, relation.destination].includes(selfType), + + return relations.filter( + (relation) => + relation.source === selfType || relation.destination === selfType, ); + }, [relations, discourseNode]); + + const uniqueRelationTypeTriplets = useMemo(() => { + if (!discourseNode) return []; + const relatedNodeType = discourseNode.type; + + return validRelations.flatMap((relation) => { + const isSelfSource = relation.source === relatedNodeType; + const isSelfDestination = relation.destination === relatedNodeType; + + let targetNodeType: string; + let currentRelationLabel: string; + + if (isSelfSource) { + targetNodeType = relation.destination; + currentRelationLabel = relation.label; + } else if (isSelfDestination) { + targetNodeType = relation.source; + currentRelationLabel = relation.complement; + } else { + return []; + } + + const identifiedTargetNode = allNodes.find( + (node) => node.type === targetNodeType, + ); + + if (!identifiedTargetNode) { + return []; + } + + const mappedItem: RelationDetails = { + relationLabel: currentRelationLabel, + relatedNodeText: identifiedTargetNode.text, + relatedNodeFormat: identifiedTargetNode.format, + }; + return [mappedItem]; + }); + }, [validRelations, discourseNode, allNodes]); + + const validTypes = useMemo(() => { + if (!discourseNode) return []; + const selfType = discourseNode.type; + const hasSelfRelation = validRelations.some( (relation) => relation.source === selfType && relation.destination === selfType, @@ -158,9 +210,8 @@ const DiscourseContextOverlay = ({ ), ); return hasSelfRelation ? types : types.filter((type) => type !== selfType); - }, [discourseNode, relations]); + }, [discourseNode, validRelations]); - const [suggestedNodes, setSuggestedNodes] = useState([]); const [currentPageInput, setCurrentPageInput] = useState(""); const [selectedPage, setSelectedPage] = useState(null); const allPages = useMemo(() => getAllPageNames(), []); @@ -170,7 +221,7 @@ const DiscourseContextOverlay = ({ useEffect(() => { if (!selectedPage) { - setSuggestedNodes([]); + setHydeFilteredNodes([]); return; } const nodesOnPage = getAllReferencesOnPage(selectedPage); @@ -178,25 +229,82 @@ const DiscourseContextOverlay = ({ .map((n) => { const node = findDiscourseNode(n.uid); if (!node || node.backedBy === "default") return null; + if (!validTypes.includes(node.type)) return null; return { uid: n.uid, text: n.text, type: node.type, }; }) - .filter((node) => node !== null) - .filter((node) => validTypes.includes(node.type)) - .filter((node) => !results.some((r) => Object.values(r.results).some((result) => result.uid === node.uid))); + .filter((node): node is SuggestedNode => node !== null) + .filter( + (node) => + !results.some((r) => + Object.values(r.results).some((result) => result.uid === node.uid), + ), + ); - setSuggestedNodes(nodes); - }, [selectedPage, discourseNode, relations]); + if (nodes.length > 0 && uniqueRelationTypeTriplets.length > 0) { + const performSearch = async () => { + setIsSearchingHyde(true); + setHydeFilteredNodes([]); + try { + const candidateNodesForHyde = nodes.map((node) => ({ + uid: node.uid, + text: node.text, + type: node.type, + })); + + // TODO: Remove this once the HyDE search is working + const foundNodes: SuggestedNode[] = + await tempFindSimilarNodesUsingHyde({ + candidateNodes: candidateNodesForHyde, + currentNodeText: tag, + relationDetails: uniqueRelationTypeTriplets, + }); + + // TODO: Uncomment this once the HyDE search is working + // const foundNodes: SuggestedNode[] = await findSimilarNodesUsingHyde({ + // candidateNodes: candidateNodesForHyde, + // currentNodeText: tag, + // relationDetails: uniqueRelationTypeTriplets, + // }); + + setHydeFilteredNodes(foundNodes); + } catch (error) { + console.error( + "Error during HyDE search operation in useEffect:", + error, + ); + setHydeFilteredNodes([]); + } finally { + setIsSearchingHyde(false); + } + }; + performSearch(); + } + }, [selectedPage, results, validTypes, tag, uniqueRelationTypeTriplets]); + + // TODO: Remove this once the HyDE search is working + const tempFindSimilarNodesUsingHyde = async ({ + candidateNodes, + currentNodeText, + relationDetails, + }: { + candidateNodes: SuggestedNode[]; + currentNodeText: string; + relationDetails: RelationDetails[]; + }): Promise => { + console.log("running stub for hyde search", candidateNodes); + return candidateNodes; + }; - const handleCreateBlock = async (node: { uid: string; text: string }) => { + const handleCreateBlock = async (node: SuggestedNode) => { await createBlock({ parentUid: blockUid, node: { text: `[[${node.text}]]` }, }); - setSuggestedNodes(suggestedNodes.filter((n) => n.uid !== node.uid)); + setHydeFilteredNodes(hydeFilteredNodes.filter((n) => n.uid !== node.uid)); }; return ( @@ -246,21 +354,25 @@ const DiscourseContextOverlay = ({

Suggested Relationships

+ {isSearchingHyde && ( + + )}
    - {suggestedNodes.length > 0 ? ( - suggestedNodes.map((node) => ( -
  • - {node.text} -
  • - )) - ) : ( -
  • No relations found
  • + {!isSearchingHyde && hydeFilteredNodes.length > 0 + ? hydeFilteredNodes.map((node) => ( +
  • + {node.text} +
  • + )) + : null} + {!isSearchingHyde && hydeFilteredNodes.length === 0 && ( +
  • No relevant relations found.
  • )}
diff --git a/apps/roam/src/utils/hyde.ts b/apps/roam/src/utils/hyde.ts index 1e24635e9..babfb4e6f 100644 --- a/apps/roam/src/utils/hyde.ts +++ b/apps/roam/src/utils/hyde.ts @@ -17,7 +17,6 @@ export type EmbeddingVectorType = number[]; export type CandidateNodeWithEmbedding = Result & { type: string; - embedding: EmbeddingVectorType; }; export type SuggestedNode = Result & { @@ -308,9 +307,7 @@ const rankNodes = ({ maxScores.forEach((score, uid) => { const fullNode = nodeMap.get(uid); if (fullNode) { - const { embedding, ...restOfNode } = fullNode; - const suggestedNodeObject: SuggestedNode = restOfNode; - combinedResults.push({ node: suggestedNodeObject, score }); + combinedResults.push({ node: fullNode, score }); } });