diff --git a/apps/roam/src/components/canvas/DiscourseRelationShape/DiscourseRelationTool.tsx b/apps/roam/src/components/canvas/DiscourseRelationShape/DiscourseRelationTool.tsx index 3e809406b..8c068c2bc 100644 --- a/apps/roam/src/components/canvas/DiscourseRelationShape/DiscourseRelationTool.tsx +++ b/apps/roam/src/components/canvas/DiscourseRelationShape/DiscourseRelationTool.tsx @@ -363,14 +363,17 @@ export const createAllRelationShapeTools = ( ); const relation = discourseContext.relations[name].find( - (r) => r.source === target?.type, + (r) => r.source === target?.type || r.destination === target?.type, ); if (relation) { this.shapeType = relation.id; } else { - const acceptableTypes = discourseContext.relations[name].map( - (r) => discourseContext.nodes[r.source].text, - ); + const acceptableTypes = discourseContext.relations[name] + .flatMap((r) => [ + discourseContext.nodes[r.source]?.text, + discourseContext.nodes[r.destination]?.text, + ]) + .filter(Boolean); const uniqueTypes = [...new Set(acceptableTypes)]; this.cancelAndWarn( `Starting node must be one of ${uniqueTypes.join(", ")}`, diff --git a/apps/roam/src/components/canvas/DiscourseRelationShape/DiscourseRelationUtil.tsx b/apps/roam/src/components/canvas/DiscourseRelationShape/DiscourseRelationUtil.tsx index 5e40a22ad..46fd38643 100644 --- a/apps/roam/src/components/canvas/DiscourseRelationShape/DiscourseRelationUtil.tsx +++ b/apps/roam/src/components/canvas/DiscourseRelationShape/DiscourseRelationUtil.tsx @@ -65,6 +65,7 @@ import { } from "./helpers"; import { createReifiedRelation } from "~/utils/createReifiedBlock"; import { getStoredRelationsEnabled } from "~/utils/storedRelations"; +import type { DiscourseRelation } from "~/utils/getDiscourseRelations"; import { discourseContext, isPageUid } from "~/components/canvas/Tldraw"; import getPageUidByPageTitle from "roamjs-components/queries/getPageUidByPageTitle"; @@ -624,41 +625,77 @@ export const createAllRelationShapeUtils = ( const relations = Object.values(discourseContext.relations).flat(); const relation = relations.find((r) => r.id === arrow.type); if (!relation) return; - const possibleTargets = discourseContext.relations[relation.label] - .filter((r) => r.source === relation.source) - .map((r) => r.destination); - if (!possibleTargets.includes(target.type)) { - const uniqueTargets = [...new Set(possibleTargets)]; - const uniqueTargetTexts = uniqueTargets.map( - (t) => discourseContext.nodes[t].text, + const sourceNodeType = source.type; + const targetNodeType = target.type; + + // Check all relations with the same label for a match + const { + isDirect, + isReverse, + matchingRelation: foundRelation, + } = this.checkConnectionTypeAcrossLabel( + relation.label, + sourceNodeType, + targetNodeType, + ); + const matchingRelation = foundRelation ?? relation; + + if (!isDirect && !isReverse) { + const validTargets = this.getValidTargetTypes( + relation.label, + sourceNodeType, + ); + const uniqueTargetTexts = validTargets.map( + (t) => discourseContext.nodes[t]?.text || t, ); return deleteAndWarn( `Target node must be of type ${uniqueTargetTexts.join(", ")}`, ); } - if (arrow.type !== target.type) { - editor.updateShapes([{ id: arrow.id, type: target.type }]); + + // If we found a matching relation with a different ID, switch to it + if (matchingRelation.id !== arrow.type) { + // Get bindings before updating the shape type + const existingBindings = editor.getBindingsFromShape( + arrow, + arrow.type, + ); + // Update the shape type + editor.updateShapes([{ id: arrow.id, type: matchingRelation.id }]); + // Update bindings to use the new relation type + for (const binding of existingBindings) { + editor.updateBinding({ + ...binding, + type: matchingRelation.id, + }); + } } if (getStoredRelationsEnabled()) { const sourceAsDNS = asDiscourseNodeShape(source, editor); const targetAsDNS = asDiscourseNodeShape(target, editor); - if (sourceAsDNS && targetAsDNS) + if (sourceAsDNS && targetAsDNS) { + const isOriginal = isDirect; await createReifiedRelation({ - sourceUid: sourceAsDNS.props.uid, - destinationUid: targetAsDNS.props.uid, - relationBlockUid: relation.id, + sourceUid: isOriginal + ? sourceAsDNS.props.uid + : targetAsDNS.props.uid, + destinationUid: isOriginal + ? targetAsDNS.props.uid + : sourceAsDNS.props.uid, + relationBlockUid: matchingRelation.id, }); - else { + } else { void internalError({ error: "attempt to create a relation between non discourse nodes", type: "Canvas create relation", }); } } else { - const { triples, label: relationLabel } = relation; - const isOriginal = arrow.props.text === relationLabel; + const { triples } = matchingRelation; + const isOriginal = isDirect; + const newTriples = triples .map((t) => { if (/is a/i.test(t[1])) { @@ -845,6 +882,33 @@ export const createAllRelationShapeUtils = ( return update; } + // Validate target node type compatibility before creating binding + if ( + target.type !== "arrow" && + otherBinding && + target.id !== otherBinding.toId && + (!currentBinding || target.id !== currentBinding.toId) + ) { + const sourceNodeId = otherBinding.toId; + const sourceNode = this.editor.getShape(sourceNodeId); + const targetNodeType = target.type; + const sourceNodeType = sourceNode?.type; + + if (sourceNodeType && targetNodeType && shape.type) { + const isValidConnection = this.isValidNodeConnection( + sourceNodeType, + targetNodeType, + shape.type, + ); + + if (!isValidConnection) { + removeArrowBinding(this.editor, shape, handleId); + update.props![handleId] = { x: handle.x, y: handle.y }; + return update; + } + } + } + // we've got a target! the handle is being dragged over a shape, bind to it const targetGeometry = this.editor.getShapeGeometry(target); @@ -921,6 +985,42 @@ export const createAllRelationShapeUtils = ( this.editor.setHintingShapes([target.id]); const newBindings = getArrowBindings(this.editor, shape); + + // Check if both ends are bound and determine the correct text based on direction + if (newBindings.start && newBindings.end) { + const relations = Object.values(discourseContext.relations).flat(); + const relation = relations.find((r) => r.id === shape.type); + + if (relation) { + const startNode = this.editor.getShape(newBindings.start.toId); + const endNode = this.editor.getShape(newBindings.end.toId); + + if (startNode && endNode) { + const startNodeType = startNode.type; + const endNodeType = endNode.type; + + const { isReverse, matchingRelation } = + this.checkConnectionTypeAcrossLabel( + relation.label, + startNodeType, + endNodeType, + ); + + const effectiveRelation = matchingRelation ?? relation; + + const newText = + isReverse && effectiveRelation.complement + ? effectiveRelation.complement + : effectiveRelation.label; + + if (shape.props.text !== newText) { + update.props = update.props || {}; + update.props.text = newText; + } + } + } + } + if ( newBindings.start && newBindings.end && @@ -1601,6 +1701,79 @@ export class BaseDiscourseRelationUtil extends ShapeUtil ]; } + checkConnectionType( + relation: { source: string; destination: string }, + sourceNodeType: string, + targetNodeType: string, + ): { isDirect: boolean; isReverse: boolean } { + const isDirect = + sourceNodeType === relation.source && + targetNodeType === relation.destination; + + const isReverse = + sourceNodeType === relation.destination && + targetNodeType === relation.source; + + return { isDirect, isReverse }; + } + + checkConnectionTypeAcrossLabel( + label: string, + sourceNodeType: string, + targetNodeType: string, + ): { + isDirect: boolean; + isReverse: boolean; + matchingRelation: DiscourseRelation | null; + } { + const relationsWithLabel = discourseContext.relations[label]; + if (!relationsWithLabel) { + return { isDirect: false, isReverse: false, matchingRelation: null }; + } + + for (const rel of relationsWithLabel) { + const { isDirect, isReverse } = this.checkConnectionType( + rel, + sourceNodeType, + targetNodeType, + ); + if (isDirect || isReverse) { + return { isDirect, isReverse, matchingRelation: rel }; + } + } + + return { isDirect: false, isReverse: false, matchingRelation: null }; + } + + getValidTargetTypes(label: string, sourceNodeType: string): string[] { + const relationsWithLabel = discourseContext.relations[label]; + if (!relationsWithLabel) return []; + + const targets = new Set(); + for (const rel of relationsWithLabel) { + if (rel.source === sourceNodeType) targets.add(rel.destination); + if (rel.destination === sourceNodeType) targets.add(rel.source); + } + return [...targets]; + } + + isValidNodeConnection( + sourceNodeType: string, + targetNodeType: string, + relationId: string, + ): boolean { + const relations = Object.values(discourseContext.relations).flat(); + const relation = relations.find((r) => r.id === relationId); + if (!relation) return false; + + const { isDirect, isReverse } = this.checkConnectionTypeAcrossLabel( + relation.label, + sourceNodeType, + targetNodeType, + ); + return isDirect || isReverse; + } + component(shape: DiscourseRelationShape) { // eslint-disable-next-line react-hooks/rules-of-hooks // const theme = useDefaultColorTheme();