Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 140 additions & 28 deletions apps/roam/src/components/DiscourseContextOverlay.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
Position,
Tooltip,
ControlGroup,
Spinner,
} from "@blueprintjs/core";
import React, { useCallback, useEffect, useMemo, useState } from "react";
import ReactDOM from "react-dom";
Expand All @@ -27,6 +28,7 @@ import getAllPageNames from "roamjs-components/queries/getAllPageNames";
import { Result } from "roamjs-components/types/query-builder";
import createBlock from "roamjs-components/writes/createBlock";
import { getBlockUidFromTarget } from "roamjs-components/dom";
import { SuggestedNode, RelationDetails } from "~/utils/hyde";

type DiscourseData = {
results: Awaited<ReturnType<typeof getDiscourseContextResults>>;
Expand Down Expand Up @@ -102,9 +104,14 @@ const DiscourseContextOverlay = ({
const [results, setResults] = useState<DiscourseData["results"]>([]);
const [refs, setRefs] = useState(0);
const [score, setScore] = useState<number | string>(0);
const [isSearchingHyde, setIsSearchingHyde] = useState(false);
const [hydeFilteredNodes, setHydeFilteredNodes] = useState<SuggestedNode[]>(
[],
);

const discourseNode = useMemo(() => findDiscourseNode(tagUid), [tagUid]);
const relations = useMemo(() => getDiscourseRelations(), []);
const allNodes = useMemo(() => getDiscourseNodes(), []);

const getInfo = useCallback(
() =>
Expand Down Expand Up @@ -138,13 +145,58 @@ const DiscourseContextOverlay = ({
getInfo();
}, [refresh, getInfo]);

// Suggestive Mode
const validTypes = useMemo(() => {
const validRelations = useMemo(() => {
if (!discourseNode) return [];
const selfType = discourseNode.type;
const validRelations = relations.filter((relation) =>
[relation.source, relation.destination].includes(selfType),

return relations.filter(
(relation) =>
relation.source === selfType || relation.destination === selfType,
);
}, [relations, discourseNode]);

const uniqueRelationTypeTriplets = useMemo(() => {
if (!discourseNode) return [];
const relatedNodeType = discourseNode.type;

return validRelations.flatMap((relation) => {
const isSelfSource = relation.source === relatedNodeType;
const isSelfDestination = relation.destination === relatedNodeType;

let targetNodeType: string;
let currentRelationLabel: string;

if (isSelfSource) {
targetNodeType = relation.destination;
currentRelationLabel = relation.label;
} else if (isSelfDestination) {
targetNodeType = relation.source;
currentRelationLabel = relation.complement;
} else {
return [];
}

const identifiedTargetNode = allNodes.find(
(node) => node.type === targetNodeType,
);

if (!identifiedTargetNode) {
return [];
}

const mappedItem: RelationDetails = {
relationLabel: currentRelationLabel,
relatedNodeText: identifiedTargetNode.text,
relatedNodeFormat: identifiedTargetNode.format,
};
return [mappedItem];
});
}, [validRelations, discourseNode, allNodes]);

const validTypes = useMemo(() => {
if (!discourseNode) return [];
const selfType = discourseNode.type;

const hasSelfRelation = validRelations.some(
(relation) =>
relation.source === selfType && relation.destination === selfType,
Expand All @@ -158,9 +210,8 @@ const DiscourseContextOverlay = ({
),
);
return hasSelfRelation ? types : types.filter((type) => type !== selfType);
}, [discourseNode, relations]);
}, [discourseNode, validRelations]);

const [suggestedNodes, setSuggestedNodes] = useState<Result[]>([]);
const [currentPageInput, setCurrentPageInput] = useState("");
const [selectedPage, setSelectedPage] = useState<string | null>(null);
const allPages = useMemo(() => getAllPageNames(), []);
Expand All @@ -170,33 +221,90 @@ const DiscourseContextOverlay = ({

useEffect(() => {
if (!selectedPage) {
setSuggestedNodes([]);
setHydeFilteredNodes([]);
return;
}
const nodesOnPage = getAllReferencesOnPage(selectedPage);
const nodes = nodesOnPage
.map((n) => {
const node = findDiscourseNode(n.uid);
if (!node || node.backedBy === "default") return null;
if (!validTypes.includes(node.type)) return null;
return {
uid: n.uid,
text: n.text,
type: node.type,
};
})
.filter((node) => node !== null)
.filter((node) => validTypes.includes(node.type))
.filter((node) => !results.some((r) => Object.values(r.results).some((result) => result.uid === node.uid)));
.filter((node): node is SuggestedNode => node !== null)
.filter(
(node) =>
!results.some((r) =>
Object.values(r.results).some((result) => result.uid === node.uid),
),
);

setSuggestedNodes(nodes);
}, [selectedPage, discourseNode, relations]);
if (nodes.length > 0 && uniqueRelationTypeTriplets.length > 0) {
const performSearch = async () => {
setIsSearchingHyde(true);
setHydeFilteredNodes([]);
try {
const candidateNodesForHyde = nodes.map((node) => ({
uid: node.uid,
text: node.text,
type: node.type,
}));

// TODO: Remove this once the HyDE search is working
const foundNodes: SuggestedNode[] =
await tempFindSimilarNodesUsingHyde({
candidateNodes: candidateNodesForHyde,
currentNodeText: tag,
relationDetails: uniqueRelationTypeTriplets,
});

// TODO: Uncomment this once the HyDE search is working
// const foundNodes: SuggestedNode[] = await findSimilarNodesUsingHyde({
// candidateNodes: candidateNodesForHyde,
// currentNodeText: tag,
// relationDetails: uniqueRelationTypeTriplets,
// });

setHydeFilteredNodes(foundNodes);
} catch (error) {
console.error(
"Error during HyDE search operation in useEffect:",
error,
);
setHydeFilteredNodes([]);
} finally {
setIsSearchingHyde(false);
}
};
performSearch();
}
}, [selectedPage, results, validTypes, tag, uniqueRelationTypeTriplets]);

// TODO: Remove this once the HyDE search is working
const tempFindSimilarNodesUsingHyde = async ({
candidateNodes,
currentNodeText,
relationDetails,
}: {
candidateNodes: SuggestedNode[];
currentNodeText: string;
relationDetails: RelationDetails[];
}): Promise<SuggestedNode[]> => {
console.log("running stub for hyde search", candidateNodes);
return candidateNodes;
};

const handleCreateBlock = async (node: { uid: string; text: string }) => {
const handleCreateBlock = async (node: SuggestedNode) => {
await createBlock({
parentUid: blockUid,
node: { text: `[[${node.text}]]` },
});
setSuggestedNodes(suggestedNodes.filter((n) => n.uid !== node.uid));
setHydeFilteredNodes(hydeFilteredNodes.filter((n) => n.uid !== node.uid));
};

return (
Expand Down Expand Up @@ -246,21 +354,25 @@ const DiscourseContextOverlay = ({
<h3 className="mb-2 text-base font-semibold">
Suggested Relationships
</h3>
{isSearchingHyde && (
<Spinner size={Spinner.SIZE_SMALL} className="mb-2" />
)}
<ul className="space-y-2">
{suggestedNodes.length > 0 ? (
suggestedNodes.map((node) => (
<li key={node.uid} className="">
<span>{node.text}</span>
<Button
minimal
icon="add"
onClick={() => handleCreateBlock(node)}
className="ml-2"
/>
</li>
))
) : (
<li>No relations found</li>
{!isSearchingHyde && hydeFilteredNodes.length > 0
? hydeFilteredNodes.map((node) => (
<li key={node.uid} className="">
<span>{node.text}</span>
<Button
minimal
icon="add"
onClick={() => handleCreateBlock(node)}
className="ml-2"
/>
</li>
))
: null}
{!isSearchingHyde && hydeFilteredNodes.length === 0 && (
<li>No relevant relations found.</li>
)}
</ul>
</div>
Expand Down
5 changes: 1 addition & 4 deletions apps/roam/src/utils/hyde.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ export type EmbeddingVectorType = number[];

export type CandidateNodeWithEmbedding = Result & {
type: string;
embedding: EmbeddingVectorType;
};

export type SuggestedNode = Result & {
Expand Down Expand Up @@ -308,9 +307,7 @@ const rankNodes = ({
maxScores.forEach((score, uid) => {
const fullNode = nodeMap.get(uid);
if (fullNode) {
const { embedding, ...restOfNode } = fullNode;
const suggestedNodeObject: SuggestedNode = restOfNode;
combinedResults.push({ node: suggestedNodeObject, score });
combinedResults.push({ node: fullNode, score });
}
});

Expand Down