Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions packages/core/src/Components/DataSourceLookup.class.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ export class DataSourceLookup extends Component {
// Need to reserve 30 characters for the prefixed unique id
'string.max': `The length of the 'namespace' name must be 50 characters or fewer.`,
}),
scoreThreshold: Joi.number().optional().label('Score Threshold'),
includeScore: Joi.boolean().optional().label('Include Score'),
});
constructor() {
super();
Expand All @@ -53,6 +55,9 @@ export class DataSourceLookup extends Component {
const postprocess = config.data?.postprocess || false;
const includeMetadata = config.data?.includeMetadata || false;

const scoreThreshold = config.data?.scoreThreshold || 0.001; // Use low score (0.001) to return most results for backward compatibility
const includeScore = config.data?.includeScore || false;

const _input = typeof input.Query === 'string' ? input.Query : JSON.stringify(input.Query);

const topK = Math.max(config.data?.topK || 50, 50);
Expand All @@ -64,27 +69,40 @@ export class DataSourceLookup extends Component {
throw new Error(`Namespace ${namespace} does not exist`);
}

let results: string[] | { content: string; metadata: any }[];
let results: string[] | { content: string; metadata: any; score?: number }[];
let _error;
try {
const response = await vectorDbConnector
.requester(AccessCandidate.team(teamId))
.search(namespace, _input, { topK, includeMetadata: true });

results = response.slice(0, config.data.topK).map((result) => ({
content: result.text,
metadata: result.metadata,
score: result.score, // use a very low score to return
}));

if (includeMetadata) {
// only show user-level metadata
results = results.map((result) => ({
results = results.filter((result) => result.score >= scoreThreshold);

// Transform results based on inclusion flags
results = results.map((result) => {
const transformedResult: any = {
content: result.content,
//* legacy user-specific metadata key [result.metadata?.metadata]),
metadata: this.parseMetadata(result.metadata || result.metadata?.metadata),
}));
} else {
results = results.map((result) => result.content);
}
};

if (includeMetadata) {
// legacy user-specific metadata key [result.metadata?.metadata]
transformedResult.metadata = this.parseMetadata(result.metadata || result.metadata?.metadata);
}

if (includeScore) {
transformedResult.score = result.score;
}

// If neither metadata nor score is included, return just the content string
return includeMetadata || includeScore ? transformedResult : result.content;
});

debugOutput += `[Results] \nLoaded ${results.length} results from namespace: ${namespace}\n\n`;
} catch (error) {
_error = error.toString();
Expand Down Expand Up @@ -112,6 +130,7 @@ export class DataSourceLookup extends Component {

const totalLength = JSON.stringify(results).length;
debugOutput += `[Total Length] \n${totalLength}\n\n`;

return {
Results: results,
_error,
Expand Down
3 changes: 2 additions & 1 deletion packages/core/src/Components/ServerlessCode.class.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import Joi from 'joi';
import { ConnectorService } from '@sre/Core/ConnectorsService';
import { AWSCredentials, AWSRegionConfig } from '@sre/types/AWS.types';
import { calculateExecutionCost, generateCodeFromLegacyComponent, getLambdaCredentials, reportUsage } from '@sre/helpers/AWSLambdaCode.helper';
import { AccessCandidate } from '@sre/Security/AccessControl/AccessCandidate.class';

export class ServerlessCode extends Component {

Expand Down Expand Up @@ -96,7 +97,7 @@ export class ServerlessCode extends Component {
const cost = calculateExecutionCost(executionTime);
if (!codeCredentials.isUserProvidedKeys) {
const accountConnector = ConnectorService.getAccountConnector();
const agentTeam = await accountConnector.getCandidateTeam(agent.id);
const agentTeam = await accountConnector.getCandidateTeam(AccessCandidate.agent(agent.id));
reportUsage({ cost, agentId: agent.id, teamId: agentTeam });
}

Expand Down