Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 116 additions & 6 deletions apps/roam/src/components/CreateNodeDialog.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import getDiscourseNodes, {
import { getNewDiscourseNodeText } from "~/utils/formatUtils";
import MenuItemSelect from "roamjs-components/components/MenuItemSelect";
import createBlock from "roamjs-components/writes/createBlock";
import { findSimilarNodes, SuggestedNode } from "~/utils/hyde";
import { Spinner } from "@blueprintjs/core";

export type CreateNodeDialogProps = {
onClose: () => void;
Expand All @@ -37,6 +39,12 @@ const CreateNodeDialog = ({
// Node type currently chosen for the new node; starts at `defaultNodeType`.
const [selectedType, setSelectedType] =
useState<DiscourseNode>(defaultNodeType);
// True while `onCreate` is running.
const [loading, setLoading] = useState(false);
// `filtered` list returned by `findSimilarNodes`; rendered under
// "Possible duplicates (LLM Filtered)".
const [filteredSuggestions, setFilteredSuggestions] = useState<
SuggestedNode[]
>([]);
// `raw` list returned by `findSimilarNodes`; rendered under
// "Possible duplicates (Semantic)".
const [rawSuggestions, setRawSuggestions] = useState<SuggestedNode[]>([]);
// Drives the "Fetching possible duplicates..." spinner.
const [suggestionsLoading, setSuggestionsLoading] = useState(false);
// Result of `getNewDiscourseNodeText` for the current title and node type;
// this is the text that is sent to the duplicate search.
const [formattedTitle, setFormattedTitle] = useState("");
// Ref for an <input> element — presumably the title field; its usage is not
// visible in this hunk, TODO confirm.
const inputRef = useRef<HTMLInputElement>(null);

useEffect(() => {
Expand All @@ -45,16 +53,55 @@ const CreateNodeDialog = ({
}
}, []);

// Keeps `formattedTitle` in sync with the typed title, the selected node type
// and the source block. An empty (whitespace-only) title clears it right away;
// otherwise the formatted text is resolved asynchronously and applied only if
// this effect run has not been superseded in the meantime.
useEffect(() => {
  let superseded = false;

  const refreshFormattedTitle = async () => {
    const trimmedTitle = title.trim();
    if (trimmedTitle.length === 0) {
      setFormattedTitle("");
      return;
    }

    const formatted = await getNewDiscourseNodeText({
      text: trimmedTitle,
      nodeType: selectedType.type,
      blockUid: sourceBlockUid,
    });
    if (superseded) return;
    setFormattedTitle(formatted || "");
  };

  void refreshFormattedTitle();

  return () => {
    superseded = true;
  };
}, [title, selectedType.type, sourceBlockUid]);

// Looks up possible duplicates whenever the formatted title or node type
// changes.
//
// Each run owns an `isCancelled` flag that the cleanup flips, so a slow
// response belonging to an older title can never overwrite the suggestions
// (or the loading flag) of a newer one, and nothing is written after the
// dialog unmounts.
useEffect(() => {
  let isCancelled = false;

  const fetchSuggestions = async () => {
    if (!formattedTitle.trim()) {
      setRawSuggestions([]);
      setFilteredSuggestions([]);
      // A superseded request no longer resets the flag itself, so clear it
      // here to avoid leaving the spinner stuck.
      setSuggestionsLoading(false);
      return;
    }

    setSuggestionsLoading(true);
    try {
      const { raw, filtered } = await findSimilarNodes({
        text: formattedTitle,
        nodeType: selectedType.type,
      });
      if (isCancelled) return;
      setRawSuggestions(raw);
      setFilteredSuggestions(filtered);
    } catch (error: unknown) {
      if (isCancelled) return;
      console.error("Failed to fetch duplicate suggestions:", error);
      setRawSuggestions([]);
      setFilteredSuggestions([]);
    } finally {
      // Only the latest run may turn the spinner off.
      if (!isCancelled) setSuggestionsLoading(false);
    }
  };

  void fetchSuggestions();

  return () => {
    isCancelled = true;
  };
}, [formattedTitle, selectedType.type]);

const onCreate = async () => {
if (!title.trim()) return;
setLoading(true);

const formattedTitle = await getNewDiscourseNodeText({
text: title.trim(),
nodeType: selectedType.type,
blockUid: sourceBlockUid,
});

if (!formattedTitle) {
setLoading(false);
return;
Expand Down Expand Up @@ -119,6 +166,24 @@ const CreateNodeDialog = ({
onClose();
};

// Handles a click on a suggested duplicate. When a source block exists, its
// text is replaced by a page reference to the chosen node and the initial
// title is added as its first child block. The dialog is closed in every case.
const handleSuggestionClick = async (node: SuggestedNode) => {
  if (!sourceBlockUid) {
    onClose();
    return;
  }

  await updateBlock({ uid: sourceBlockUid, text: `[[${node.text}]]` });
  await createBlock({
    parentUid: sourceBlockUid,
    order: 0,
    node: { text: initialTitle },
  });

  onClose();
};

return (
<Dialog
isOpen={true}
Expand All @@ -138,6 +203,51 @@ const CreateNodeDialog = ({
/>
</div>

{suggestionsLoading && (
<div className="flex items-center gap-2">
<Spinner size={16} />
<span>Fetching possible duplicates...</span>
</div>
)}
{rawSuggestions.length > 0 && (
<div className="flex flex-col gap-1">
<h4 className="font-bold">Possible duplicates (Semantic)</h4>
<ul className="flex flex-col gap-1">
{rawSuggestions.map((node) => (
<li key={node.uid}>
<a
onClick={() => {
void handleSuggestionClick(node);
}}
className="cursor-pointer text-blue-500 hover:underline"
>
{node.text}
</a>
</li>
))}
</ul>
</div>
)}
{filteredSuggestions.length > 0 && (
<div className="flex flex-col gap-1">
<h4 className="font-bold">Possible duplicates (LLM Filtered)</h4>
<ul className="flex flex-col gap-1">
{filteredSuggestions.map((node) => (
<li key={node.uid}>
<a
onClick={() => {
void handleSuggestionClick(node);
}}
className="cursor-pointer text-blue-500 hover:underline"
>
{node.text}
</a>
</li>
))}
</ul>
</div>
)}

<Label>
Type
<MenuItemSelect
Expand Down
154 changes: 154 additions & 0 deletions apps/roam/src/utils/hyde.ts
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,79 @@ const rankNodes = ({
return combinedResults.map((item) => item.node);
};

/**
 * Asks the LLM which of the given candidates are strong potential duplicates
 * of `originalText` and returns those candidates in the order the LLM listed
 * them.
 *
 * The reply is treated as untrusted input: surrounding Markdown code fences
 * are stripped, the payload must be a JSON array, non-string entries and
 * texts that match no candidate are ignored, and repeated texts are collapsed
 * so each candidate appears at most once.
 *
 * @param originalText Text of the node that is about to be created.
 * @param candidates Existing nodes to be judged.
 * @returns The matching candidates (possibly empty). If the request fails,
 * times out, returns a non-OK status, or the reply is not a JSON array, the
 * unfiltered `candidates` are returned instead.
 */
const filterAndRerankByLlm = async ({
  originalText,
  candidates,
}: {
  originalText: string;
  candidates: SuggestedNode[];
}): Promise<SuggestedNode[]> => {
  if (candidates.length === 0) {
    return [];
  }

  const candidateList = candidates
    .map((c, i) => `${i + 1}. "${c.text}"`)
    .join("\n");

  const userPromptContent = `Given the original text for a new node: "${originalText}".

Here is a list of existing nodes that might be duplicates:
${candidateList}

Your task is to identify which of these existing nodes are strong potential duplicates for the new node. A strong potential duplicate is one that covers the same core concepts or ideas as the new node, not just a partial or superficial match.

Please return a JSON array of strings, containing the exact text of only the nodes you've identified as strong potential duplicates. The list should be ordered from the most likely duplicate to the least likely. If you think none of the candidates are strong duplicates, return an empty JSON array.

For example, if you decide only candidates 3 and 1 are strong duplicates, and 3 is more likely, your response should be:
["text of candidate 3", "text of candidate 1"]

Only return the JSON array.`;

  const requestBody = {
    documents: [{ role: "user", content: userPromptContent }],
    passphrase: "",
    settings: {
      model: API_CONFIG.LLM.MODEL,
      maxTokens: 500,
      temperature: 0.2,
    },
  };

  try {
    const response = await fetch(API_CONFIG.LLM.URL, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
      body: JSON.stringify(requestBody),
      signal: AbortSignal.timeout(API_CONFIG.LLM.TIMEOUT_MS),
    });

    if (!response.ok) {
      await handleApiError(response, "LLM reranking");
      return candidates;
    }

    // Models frequently wrap JSON in a Markdown code fence; strip it so an
    // otherwise valid answer is not discarded.
    const responseText = (await response.text())
      .trim()
      .replace(/^```(?:json)?\s*/i, "")
      .replace(/\s*```$/, "");

    const parsed: unknown = JSON.parse(responseText);
    if (!Array.isArray(parsed)) {
      console.error("LLM reranking returned a non-array payload:", parsed);
      return candidates;
    }

    const candidatesByText = new Map(candidates.map((c) => [c.text, c]));
    const seenTexts = new Set<string>();
    const rerankedNodes: SuggestedNode[] = [];
    for (const entry of parsed) {
      if (typeof entry !== "string" || seenTexts.has(entry)) continue;
      const match = candidatesByText.get(entry);
      if (!match) continue;
      seenTexts.add(entry);
      rerankedNodes.push(match);
    }

    return rerankedNodes;
  } catch (error: unknown) {
    console.error("LLM reranking failed:", error);
    return candidates;
  }
};

export const findSimilarNodesUsingHyde = async ({
candidateNodes,
currentNodeText,
Expand Down Expand Up @@ -530,3 +603,84 @@ export const performHydeSearch = async ({
}
return [];
};

/**
 * Finds existing nodes of the given type that may duplicate `text`.
 *
 * Nodes of `nodeType` are loaded for the current space, scored against an
 * embedding of `text`, and the highest-scoring ones are then passed through
 * `filterAndRerankByLlm`.
 *
 * @param text Text of the node being created.
 * @param nodeType Discourse node type to search within.
 * @returns `raw`: the top embedding matches, best score first (at most 7);
 * `filtered`: what `filterAndRerankByLlm` returned for those matches. Both
 * lists are empty when the input is blank, no context/client is available,
 * no candidates exist, or any step throws (the error is logged).
 */
export const findSimilarNodes = async ({
  text,
  nodeType,
}: {
  text: string;
  nodeType: string;
}): Promise<{ raw: SuggestedNode[]; filtered: SuggestedNode[] }> => {
  // Upper bound on how many nodes of the type are fetched as candidates.
  const candidateFetchLimit = 10000;
  // How many top-scoring matches are surfaced and sent to the LLM filter.
  const maxSuggestions = 7;
  // Fresh object per return so callers never share array instances.
  const emptyResult = (): {
    raw: SuggestedNode[];
    filtered: SuggestedNode[];
  } => ({ raw: [], filtered: [] });

  if (!text.trim() || !nodeType) {
    return emptyResult();
  }

  try {
    const context = await getSupabaseContext();
    if (!context) return emptyResult();
    const supabase = await getLoggedInClient();
    const { spaceId } = context;
    if (!supabase) return emptyResult();

    const candidateNodesForHyde = (
      await getNodesByType({
        supabase,
        spaceId,
        fields: { content: ["source_local_id", "text"] },
        ofTypes: [nodeType],
        pagination: { limit: candidateFetchLimit },
      })
    )
      .map((c) => {
        const node = findDiscourseNode(c.Content?.source_local_id || "");
        return {
          uid: c.Content?.source_local_id || "",
          text: c.Content?.text || "",
          type: node ? node.type : "",
        };
      })
      // Drop rows that are missing an id, text, or a resolvable node type.
      .filter((n) => n.uid && n.text && n.type);

    if (candidateNodesForHyde.length === 0) {
      return emptyResult();
    }

    const queryEmbedding = await createEmbedding(text);
    const searchResults = await searchEmbeddings({
      queryEmbedding,
      indexData: candidateNodesForHyde,
    });

    const nodeMap = new Map<string, CandidateNodeWithEmbedding>(
      candidateNodesForHyde.map((node) => [node.uid, node]),
    );

    // Join search hits back to their candidate nodes, ignoring unknown uids.
    const combinedResults: { node: SuggestedNode; score: number }[] = [];
    searchResults.forEach((result) => {
      const fullNode = nodeMap.get(result.object.uid);
      if (fullNode) {
        combinedResults.push({ node: fullNode, score: result.score });
      }
    });

    combinedResults.sort((a, b) => b.score - a.score);
    const topCandidates = combinedResults
      .slice(0, maxSuggestions)
      .map((item) => item.node);

    if (topCandidates.length === 0) {
      return emptyResult();
    }

    const filteredResults = await filterAndRerankByLlm({
      originalText: text,
      candidates: topCandidates,
    });

    return { raw: topCandidates, filtered: filteredResults };
  } catch (error) {
    console.error("Error finding similar nodes:", error);
    return emptyResult();
  }
};