Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions apps/memos-local-plugin/core/injection/scheduler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import type { SessionId, EpisodeId } from "../../agent-contract/dto.js";
import type { IntentDecision, TurnRelation } from "../session/types.js";

export type InjectionScenarioId =
| "CHITCHAT"
| "META"
| "MEMORY_PROBE"
| "NEW_TASK"
| "FOLLOW_UP"
| "TASK"
| "UNKNOWN_SAFE";

export interface SchedulerContext {
userText: string;
sessionId: SessionId;
episodeId: EpisodeId;
intent: IntentDecision;
relation?: TurnRelation | "bootstrap" | "lightweight_memory";
}

export interface RetrievePlan {
scenarioId: InjectionScenarioId;
entry: "turn_start" | "turn_start_skip";
wantTier1: boolean;
wantTier2: boolean;
wantTier3: boolean;
prepend: boolean;
}

export function scheduleInjection(ctx: SchedulerContext): RetrievePlan {
const { intent, relation } = ctx;

if (intent.kind === "chitchat" && intent.confidence >= 0.6) {
return skipPlan("CHITCHAT");
}

if (intent.kind === "chitchat") {
return retrievePlan("UNKNOWN_SAFE", { tier1: true, tier2: true, tier3: true });
}

if (intent.kind === "meta") {
return skipPlan("META");
}

if (intent.kind === "memory_probe") {
return retrievePlan("MEMORY_PROBE", intent.retrieval);
}

if (relation === "new_task") {
return retrievePlan("NEW_TASK", intent.retrieval);
}

if (relation === "revision" || relation === "follow_up" || relation === "unknown") {
return retrievePlan("FOLLOW_UP", intent.retrieval);
}

if (intent.kind === "unknown") {
return retrievePlan("UNKNOWN_SAFE", { tier1: true, tier2: true, tier3: true });
}

return retrievePlan("TASK", intent.retrieval);
}

function skipPlan(scenarioId: Extract<InjectionScenarioId, "CHITCHAT" | "META">): RetrievePlan {
return {
scenarioId,
entry: "turn_start_skip",
wantTier1: false,
wantTier2: false,
wantTier3: false,
prepend: false,
};
}

function retrievePlan(
scenarioId: Exclude<InjectionScenarioId, "CHITCHAT" | "META">,
retrieval: IntentDecision["retrieval"],
): RetrievePlan {
return {
scenarioId,
entry: "turn_start",
wantTier1: retrieval.tier1,
wantTier2: retrieval.tier2,
wantTier3: retrieval.tier3,
prepend: true,
};
}
4 changes: 4 additions & 0 deletions apps/memos-local-plugin/core/pipeline/memory-core.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4612,6 +4612,8 @@ function findLatestPersistedModelStatus(
}

function retrievalStatsPayload(s: import("../retrieval/types.js").RetrievalStats): {
scenarioId?: string;
plannedTiers?: { tier1: boolean; tier2: boolean; tier3: boolean };
raw?: number;
ranked?: number;
droppedByThreshold?: number;
Expand All @@ -4629,6 +4631,8 @@ function retrievalStatsPayload(s: import("../retrieval/types.js").RetrievalStats
embedding?: import("../retrieval/types.js").RetrievalStats["embedding"];
} {
return {
scenarioId: s.scenarioId,
plannedTiers: s.plannedTiers,
raw: s.rawCandidateCount,
ranked: s.rankedCount,
droppedByThreshold: s.droppedByThresholdCount,
Expand Down
142 changes: 138 additions & 4 deletions apps/memos-local-plugin/core/pipeline/orchestrator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import {
repairRetrieve,
} from "../retrieval/retrieve.js";
import type { RetrievalResult } from "../retrieval/types.js";
import { scheduleInjection, type RetrievePlan } from "../injection/scheduler.js";

import {
buildPipelineBuses,
Expand Down Expand Up @@ -80,7 +81,7 @@ import { memoryBuffer } from "../logger/index.js";
import { onBroadcastLog } from "../logger/transports/sse-broadcast.js";
import { createEmbeddingRetryWorker, systemErrorEvent } from "../embedding/index.js";
import type { EpisodeSnapshot } from "../session/index.js";
import type { RelationDecision } from "../session/types.js";
import type { IntentDecision, RelationDecision, TurnRelation } from "../session/types.js";

// ─── Factory ──────────────────────────────────────────────────────────────

Expand Down Expand Up @@ -968,7 +969,10 @@ export function createPipeline(deps: PipelineDeps): PipelineHandle {
};
}

async function retrieveTurnStart(input: TurnInputDTO): Promise<InjectionPacket> {
async function retrieveTurnStart(
input: TurnInputDTO,
plan?: RetrievePlan,
): Promise<InjectionPacket> {
const ctx = {
reason: "turn_start" as const,
agent: input.agent,
Expand All @@ -981,7 +985,17 @@ export function createPipeline(deps: PipelineDeps): PipelineHandle {
const result: RetrievalResult = await turnStartRetrieve(
retrievalDepsFor(input.namespace),
ctx,
{ events: buses.retrieval },
{
events: buses.retrieval,
plan: plan
? {
scenarioId: plan.scenarioId,
wantTier1: plan.wantTier1,
wantTier2: plan.wantTier2,
wantTier3: plan.wantTier3,
}
: undefined,
},
);
turnStartRetrievalStats.set(result.packet.packetId, result.stats);
return result.packet;
Expand Down Expand Up @@ -1060,9 +1074,47 @@ export function createPipeline(deps: PipelineDeps): PipelineHandle {
sessionId,
episodeId: episode.id as EpisodeId,
};
const schedulerIntent = await intentForCurrentTurn({
episode,
userText: input.userText,
ts: input.ts,
});
const retrievePlan = scheduleInjection({
userText: input.userText,
sessionId,
episodeId: episode.id as EpisodeId,
intent: schedulerIntent,
relation: schedulerRelation(routing.relation),
});

try {
const packet = await retrieveTurnStart(normalized);
if (retrievePlan.entry === "turn_start_skip") {
const packet = emptyInjectionPacket(input.agent, sessionId, episode.id as EpisodeId, input.ts);
turnStartRetrievalStats.set(
packet.packetId,
skippedRetrievalStats({
agent: input.agent,
sessionId,
episodeId: episode.id as EpisodeId,
scenarioId: retrievePlan.scenarioId,
userText: input.userText,
elapsedMs: now() - t0,
}),
);
log.info("turn.started", {
agent: input.agent,
sessionId,
episodeId: episode.id,
userChars: input.userText.length,
retrievalScenario: retrievePlan.scenarioId,
retrievalSkipped: true,
retrievalTotalMs: 0,
elapsedMs: now() - t0,
});
return packet;
}

const packet = await retrieveTurnStart(normalized, retrievePlan);
// Always stamp the routed sessionId + episodeId on the packet so
// adapters can correlate the subsequent `agent_end` / `turn.end`
// call without needing a separate round-trip to the session
Expand All @@ -1079,6 +1131,7 @@ export function createPipeline(deps: PipelineDeps): PipelineHandle {
sessionId,
episodeId: episode.id,
userChars: input.userText.length,
retrievalScenario: retrievePlan.scenarioId,
retrievalTotalMs: packet.tierLatencyMs.tier1 +
packet.tierLatencyMs.tier2 +
packet.tierLatencyMs.tier3,
Expand Down Expand Up @@ -1385,6 +1438,27 @@ export function createPipeline(deps: PipelineDeps): PipelineHandle {
return typeof ts === "number" && Number.isFinite(ts) ? ts : undefined;
}

async function intentForCurrentTurn(input: {
episode: EpisodeSnapshot;
userText: string;
ts?: number;
}): Promise<IntentDecision> {
const firstTurn = input.episode.turns[0];
const isFreshEpisodeForThisTurn =
input.episode.turns.length === 1 &&
firstTurn?.role === "user" &&
firstTurn.content === input.userText &&
(input.ts == null || firstTurn.ts === input.ts);

if (isFreshEpisodeForThisTurn) {
return input.episode.intent;
}

return session.intent.classify(input.userText, {
episodeId: input.episode.id as EpisodeId,
});
}

/**
* Build richer context for the relation classifier from episode turns.
*
Expand Down Expand Up @@ -1500,6 +1574,66 @@ function emptyInjectionPacket(
};
}

function skippedRetrievalStats(input: {
agent: AgentKind;
sessionId: SessionId;
episodeId: EpisodeId;
scenarioId: string;
userText: string;
elapsedMs: number;
}): RetrievalResult["stats"] {
return {
reason: "turn_start",
scenarioId: input.scenarioId,
agent: input.agent,
sessionId: input.sessionId,
episodeId: input.episodeId,
plannedTiers: { tier1: false, tier2: false, tier3: false },
tier1Count: 0,
tier2Count: 0,
tier3Count: 0,
tier1LatencyMs: 0,
tier2LatencyMs: 0,
tier3LatencyMs: 0,
fuseLatencyMs: 0,
totalLatencyMs: Math.max(0, input.elapsedMs),
queryTokens: Math.ceil(input.userText.length / 4),
queryTags: [],
emptyPacket: true,
embedding: {
attempted: false,
ok: false,
degraded: false,
},
rawCandidateCount: 0,
droppedByThresholdCount: 0,
thresholdFloor: 0,
topRelevance: 0,
rankedCount: 0,
llmFilterOutcome: "skipped_by_scheduler",
llmFilterSufficient: true,
llmFilterKept: 0,
llmFilterDropped: 0,
channelHits: {},
};
}

function schedulerRelation(
relation: string | undefined,
): TurnRelation | "bootstrap" | "lightweight_memory" | undefined {
if (
relation === "revision" ||
relation === "follow_up" ||
relation === "new_task" ||
relation === "unknown" ||
relation === "bootstrap" ||
relation === "lightweight_memory"
) {
return relation;
}
return undefined;
}

function _assertConfigShape(
algorithm: PipelineAlgorithmConfig,
feedback: FeedbackConfig,
Expand Down
37 changes: 33 additions & 4 deletions apps/memos-local-plugin/core/retrieval/retrieve.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@ export interface RetrieveOptions {
events?: RetrievalEventBus;
/** Override `limit` default (tier totals honored when unspecified). */
limit?: number;
/** Turn-start scheduler override. V1 uses this for intent tier gating. */
plan?: RetrievePlanOverride;
}

export interface RetrievePlanOverride {
scenarioId?: string;
wantTier1?: boolean;
wantTier2?: boolean;
wantTier3?: boolean;
limit?: number;
}

// ─── Entry point: turn_start ────────────────────────────────────────────────
Expand All @@ -76,24 +86,24 @@ export async function turnStartRetrieve(
opts: RetrieveOptions = {},
): Promise<RetrievalResult> {
if (deps.config.lightweightMemory) {
return runAll(deps, ctx, opts, {
return runAll(deps, ctx, opts, applyPlanOverride({
wantTier1: false,
wantTier2: true,
wantTier3: false,
includeLowValue: false,
limit: opts.limit ?? Math.max(1, deps.config.tier2TopK),
traceOnly: true,
});
}, opts.plan));
}
return runAll(deps, ctx, opts, {
return runAll(deps, ctx, opts, applyPlanOverride({
wantTier1: true,
wantTier2: true,
wantTier3: true,
includeLowValue: deps.config.includeLowValue,
limit:
opts.limit ??
deps.config.tier1TopK + deps.config.tier2TopK + deps.config.tier3TopK,
});
}, opts.plan));
}

// ─── Entry point: tool_driven ───────────────────────────────────────────────
Expand Down Expand Up @@ -188,6 +198,7 @@ export async function repairRetrieve(
// ─── Shared pipeline ────────────────────────────────────────────────────────

interface RunPlan {
scenarioId?: string;
wantTier1: boolean;
wantTier2: boolean;
wantTier3: boolean;
Expand All @@ -196,6 +207,18 @@ interface RunPlan {
traceOnly?: boolean;
}

function applyPlanOverride(plan: RunPlan, override?: RetrievePlanOverride): RunPlan {
if (!override) return plan;
return {
...plan,
scenarioId: override.scenarioId ?? plan.scenarioId,
wantTier1: override.wantTier1 ?? plan.wantTier1,
wantTier2: override.wantTier2 ?? plan.wantTier2,
wantTier3: override.wantTier3 ?? plan.wantTier3,
limit: override.limit ?? plan.limit,
};
}

async function runAll(
deps: RetrievalDeps,
ctx: RetrievalCtx,
Expand Down Expand Up @@ -446,9 +469,15 @@ async function runAll(

const stats: RetrievalStats = {
reason: ctx.reason,
scenarioId: plan.scenarioId,
agent,
sessionId,
episodeId,
plannedTiers: {
tier1: plan.wantTier1,
tier2: plan.wantTier2,
tier3: plan.wantTier3,
},
tier1Count: tier1.length,
tier2Count: tier2.traces.length + (traceOnly ? 0 : tier2.episodes.length) + tier2Experiences.length,
tier3Count: tier3.length,
Expand Down
Loading