Skip to content

Commit

Permalink
pass prompt to eval trial
Browse files Browse the repository at this point in the history
  • Loading branch information
andrefs committed May 3, 2024
1 parent 61c5661 commit 0b85837
Show file tree
Hide file tree
Showing 14 changed files with 168 additions and 82 deletions.
30 changes: 22 additions & 8 deletions src/lib/experiments/compare-mc30/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ import {
import logger from "../../logger";
import { DsPartition } from "src/lib/dataset-partitions/DsPartition";
import query, { QueryResponse } from "./query";
import { MultiDatasetScores, TrialResult, Usages } from "../experiment/types";
import {
MultiDatasetScores,
Prompt,
TrialResult,
Usages,
} from "../experiment/types";
import { renderTable } from "console-table-printer";
import { addUsage } from "../experiment/aux";

Expand Down Expand Up @@ -142,9 +147,15 @@ const getPairs = (scores: MultiDatasetScores) => {
const name = "compare-mc30";
const description =
"Compare the scores of multiple AI models with the scores from multiple human annotations of the MC30 pair set.";
const genPrompt = (pairs: string[][]) =>
'Please rate the similarity of the following pairs of words on a scale of 0 to 4, where 0 means "completely unrelated" and 4 means "very similar". Fractional values are allowed.\n\n' +
pairs.map(([w1, w2]) => `${w1} ${w2}`).join("\n");
// Build the MC30 similarity prompt: fixed instructions followed by one
// "word1 word2" line per pair. Returns a structured Prompt object.
const genPrompt = (pairs: [string, string][]): Prompt => {
  const instructions =
    'Please rate the similarity of the following pairs of words on a scale of 0 to 4, where 0 means "completely unrelated" and 4 means "very similar". Fractional values are allowed.';
  const pairLines = pairs.map(pair => pair.join(" "));
  return {
    id: "compare-mc30-prompt",
    language: "en",
    type: "similarity",
    pairs,
    text: [instructions, pairLines.join("\n")].join("\n\n"),
  };
};

async function tryResponse(model: Model, prompt: string, params: ModelTool) {
let result;
Expand Down Expand Up @@ -183,7 +194,7 @@ async function tryResponse(model: Model, prompt: string, params: ModelTool) {

async function getResponse(
model: Model,
prompt: string,
prompt: Prompt,
tool: ModelTool,
maxRetries: number = 3
) {
Expand All @@ -193,13 +204,14 @@ async function getResponse(
logger.info(` attempt #${failedAttempts.length + 1}`);
const { result: attemptResult, usage } = await tryResponse(
model,
prompt,
prompt.text,
tool
);
addUsage(totalUsage, usage);
if (attemptResult instanceof ValidData) {
logger.info(` ✅ attempt #${failedAttempts.length + 1} succeeded.`);
const res: TrialResult<QueryResponse> = {
prompt,
totalTries: failedAttempts.length + 1,
failedAttempts,
ok: true,
Expand All @@ -217,6 +229,7 @@ async function getResponse(
}

const res: TrialResult<QueryResponse> = {
prompt,
totalTries: failedAttempts.length,
usage: totalUsage,
failedAttempts,
Expand All @@ -226,19 +239,20 @@ async function getResponse(
}

/** Run a single trial of the experiment, with a single model */
async function runTrialModel(model: Model, prompt: string, maxRetries = 3) {
async function runTrialModel(model: Model, prompt: Prompt, maxRetries = 3) {
  // Tool definition handed to the model so it returns structured scores.
  const scoringTool: ModelTool = {
    name: "evaluate_scores",
    description: "Evaluate the word similarity scores.",
    schema: query.toolSchema,
  };

  logger.debug(`Prompt (${prompt.id}): ${prompt.text}`);
  // Delegate to getResponse, which handles retries and usage accounting.
  return await getResponse(model, prompt, scoringTool, maxRetries);
}

/** Run multiple trials of the experiment, with a single model */
async function runTrials(trials: number, model: Model, prompt: string) {
async function runTrials(trials: number, model: Model, prompt: Prompt) {
const totalUsage: Usages = {};
logger.info(
`Running experiment ${name} ${trials} times on model ${model.id}.`
Expand Down
34 changes: 20 additions & 14 deletions src/lib/experiments/compare-prompts/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,16 @@ describe("comparePrompts", () => {
results: {
raw: [
{
scores: [
{ words: ["testWord1", "testWord2"], score: 0.5 },
{ words: ["testWord3", "testWord4"], score: 0.9 },
{ words: ["testWord5", "testWord6"], score: 0.9 },
{ words: ["testWord7", "testWord8"], score: 0.9 },
{ words: ["testWord9", "testWord10"], score: 0.9 },
],
prompt: {} as Prompt,
data: {
scores: [
{ words: ["testWord1", "testWord2"], score: 0.5 },
{ words: ["testWord3", "testWord4"], score: 0.9 },
{ words: ["testWord5", "testWord6"], score: 0.9 },
{ words: ["testWord7", "testWord8"], score: 0.9 },
{ words: ["testWord9", "testWord10"], score: 0.9 },
],
},
},
],
},
Expand Down Expand Up @@ -83,13 +86,16 @@ describe("comparePrompts", () => {
results: {
raw: [
{
scores: [
{ words: ["testWord1", "testWord2"], score: 0.5 },
{ words: ["testWord3", "testWord4"], score: 0.9 },
{ words: ["testWord5", "testWord6"], score: 0.9 },
{ words: ["testWord7", "testWord8"], score: 0.9 },
{ words: ["testWord9", "testWord10"], score: 0.9 },
],
prompt: {} as Prompt,
data: {
scores: [
{ words: ["testWord1", "testWord2"], score: 0.5 },
{ words: ["testWord3", "testWord4"], score: 0.9 },
{ words: ["testWord5", "testWord6"], score: 0.9 },
{ words: ["testWord7", "testWord8"], score: 0.9 },
{ words: ["testWord9", "testWord10"], score: 0.9 },
],
},
},
],
},
Expand Down
41 changes: 24 additions & 17 deletions src/lib/experiments/compare-prompts/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import {
NoData,
ValidData,
} from "src/lib/evaluation";
import { ExpScore, PairScoreList, Usages } from "../experiment/types";
import { ExpScore, PairScoreList, Prompt, Usages } from "../experiment/types";
import query from "./query";
export const name = "compare-prompts";
const description = "Compare the results obtained with different prompts";
Expand Down Expand Up @@ -75,7 +75,7 @@ async function tryResponse(model: Model, prompt: string, params: ModelTool) {

async function getResponse(
model: Model,
prompt: string,
prompt: Prompt,
tool: ModelTool,
maxRetries = 3
) {
Expand All @@ -85,13 +85,14 @@ async function getResponse(
logger.info(` 💪 attempt #${failedAttempts.length + 1}`);
const { result: attemptResult, usage } = await tryResponse(
model,
prompt,
prompt.text,
tool
);
addUsage(totalUsage, usage);
if (attemptResult instanceof ValidData) {
logger.info(` ✅ attempt #${failedAttempts.length + 1} succeeded.`);
const res: TrialResult<ExpTypes["Data"]> = {
prompt,
totalTries: failedAttempts.length + 1,
failedAttempts,
ok: true,
Expand All @@ -109,6 +110,7 @@ async function getResponse(
}

const res: TrialResult<ExpTypes["Data"]> = {
prompt,
totalTries: failedAttempts.length,
usage: totalUsage,
failedAttempts,
Expand All @@ -117,37 +119,42 @@ async function getResponse(
return res;
}

async function runTrial(vars: ExpVarsFixedPrompt, maxRetries = 3) {
async function runTrial(vars: ExpVars, maxRetries = 3) {
const tool = {
name: "evaluate_scores",
description: "Evaluate the word similarity or relatedness scores",
schema: query.toolSchema,
};
const prompt =
"generate" in vars.prompt ? vars.prompt.generate(vars) : vars.prompt;
logger.debug(`Prompt (${prompt.id}): ${prompt.text}`);

const res = await getResponse(vars.model, vars.prompt.text, tool, maxRetries);
const res = await getResponse(vars.model, prompt, tool, maxRetries);
return res;
}

async function runTrials(
vars: ExpVarsFixedPrompt,
vars: ExpVars,
trials: number
): Promise<TrialsResultData<ExpTypes["Data"]>> {
logger.info(
`Running experiment ${name} ${trials} times on model ${vars.model.id}.`
);
logger.debug(`Prompt (${vars.prompt.id}): ${vars.prompt.text}`);

const results: ExpTypes["Data"][] = [];
const results = [];
for (let i = 0; i < trials; i++) {
logger.info(` ⚔️ trial #${i + 1} of ${trials}`);
const res = await runTrial(vars);
if (res.ok && res.result) {
results.push(res.result.data);
results.push({
data: res.result.data,
prompt: res.prompt,
});
}
}
return {
variables: vars,
data: results,
trials: results,
};
}

Expand All @@ -157,10 +164,10 @@ async function perform(
traceId: number,
folder: string
) {
const prompt =
"generate" in vars.prompt ? vars.prompt.generate(vars) : vars.prompt;
const varsFixedPrompt = { ...vars, prompt } as ExpVarsFixedPrompt;
const trialsRes = await runTrials(varsFixedPrompt, trials);
//const prompt =
// "generate" in vars.prompt ? vars.prompt.generate(vars) : vars.prompt;
//const varsFixedPrompt = { ...vars, prompt } as ExpVarsFixedPrompt;
const trialsRes = await runTrials(vars, trials);

const expData: ExperimentData<ExpTypes> = {
meta: {
Expand All @@ -169,9 +176,9 @@ async function perform(
traceId,
queryData: query,
},
variables: varsFixedPrompt,
variables: vars,
results: {
raw: trialsRes.data,
raw: trialsRes.trials,
},
};

Expand Down Expand Up @@ -273,7 +280,7 @@ function expEvalScores(exps: ExperimentData<ExpTypes>[]): ExpScore[] {
);

const rawResults: PairScoreList[] = exp.results.raw.map(r => {
return r.scores as PairScoreList;
return r.data.scores as PairScoreList;
});
const corr = evalScores(lcPairs, exp.variables.dpart, rawResults);
res.push({
Expand Down
17 changes: 13 additions & 4 deletions src/lib/experiments/experiment/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ export default class Experiment<T extends GenericExpTypes> {
) => Promise<TrialsResultData<T["Data"]>>;
evaluateTrial: (
dpart: DsPartition,
prompt: Prompt,
got: T["Data"]
) => Promise<EvaluationResult<T["Data"], T["Evaluation"]>>;
evaluate: (exp: ExperimentData<T>) => Promise<{
Expand Down Expand Up @@ -128,6 +129,7 @@ export default class Experiment<T extends GenericExpTypes> {
) => Promise<TrialResult<T["Data"]>>,
evaluateTrial: (
dpart: DsPartition,
prompt: Prompt,
got: T["Data"]
) => Promise<EvaluationResult<T["Data"], T["Evaluation"]>>,
expDataToExpScore?: (
Expand Down Expand Up @@ -164,6 +166,7 @@ export default class Experiment<T extends GenericExpTypes> {
if (attemptResult instanceof ValidData) {
logger.info(` ✅ attempt #${faCount} succeeded.`);
const res: TrialResult<T["Data"]> = {
prompt: vars.prompt,
totalTries: failedAttempts.length + 1,
failedAttempts,
ok: true,
Expand All @@ -190,6 +193,7 @@ export default class Experiment<T extends GenericExpTypes> {
}

const res: TrialResult<T["Data"]> = {
prompt: vars.prompt,
totalTries: failedAttempts.length,
usage: totalUsage,
failedAttempts,
Expand Down Expand Up @@ -262,13 +266,16 @@ export default class Experiment<T extends GenericExpTypes> {
);
addUsage(totalUsage, res.usage);
if (res.ok) {
results.push(res.result!.data); // TODO: handle failed attempts
results.push({
data: res.result!.data,
prompt: res.prompt,
}); // TODO: handle failed attempts
}
}
return {
variables: vars,
usage: totalUsage,
data: results,
trials: results,
};
};
this.evaluateTrial = evaluateTrial;
Expand All @@ -277,7 +284,9 @@ export default class Experiment<T extends GenericExpTypes> {
exp: ExperimentData<T>
) {
const trialEvaluationResults = await Promise.all(
exp.results.raw.map(d => this.evaluateTrial(exp.variables.dpart, d))
exp.results.raw.map(d =>
this.evaluateTrial(exp.variables.dpart, d.prompt, d.data)
)
);
return {
evaluation: trialEvaluationResults,
Expand All @@ -303,7 +312,7 @@ export default class Experiment<T extends GenericExpTypes> {
variables: vars,
usage: trialsRes.usage,
results: {
raw: trialsRes.data,
raw: trialsRes.trials,
},
};
const { evaluation, aggregated } = await this.evaluate(expData);
Expand Down
11 changes: 9 additions & 2 deletions src/lib/experiments/experiment/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ export interface ExpMeta<T extends GenericExpTypes> {

export interface ExpResults<DataType, ExpectedType> {
/** Raw results from the trials */
raw: DataType[];
raw: {
data: DataType;
prompt: Prompt;
}[];
/** Evaluation results for each trial */
evaluation?: EvaluationResult<DataType, ExpectedType>[];
/** Aggregated evaluation results */
Expand All @@ -89,6 +92,7 @@ export interface ExperimentData<T extends GenericExpTypes> {
}

export interface TrialResult<DataType> {
prompt: Prompt;
totalTries: number;
failedAttempts: ValidationResult<DataType>[];
ok: boolean;
Expand All @@ -99,7 +103,10 @@ export interface TrialResult<DataType> {
export interface TrialsResultData<DataType> {
variables: ExpVars;
usage?: Usages;
data: DataType[];
trials: {
data: DataType;
prompt: Prompt;
}[];
}

export interface AggregatedEvaluationResult {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { describe, expect, it } from "vitest";
import { createMockDsPart, createMockModel } from "../mocks";
import dsNameFromDsSample from ".";
import { ExpVarsFixedPrompt, PromptGenerator } from "../..";
import { ExpVarsFixedPrompt, Prompt, PromptGenerator } from "../..";
import { DsPartition } from "../../../dataset-partitions/DsPartition";

describe("dsNameFromDsSample", () => {
Expand Down Expand Up @@ -36,6 +36,7 @@ describe("dsNameFromDsSample", () => {
const mockDsPartition = createMockDsPart();
const result = await dsNameFromDsSample.evaluateTrial(
createMockDsPart(),
{} as Prompt,
{
name: mockDsPartition.dataset.metadata.name,
year: "2021",
Expand All @@ -49,6 +50,7 @@ describe("dsNameFromDsSample", () => {
const mockDsPartition = createMockDsPart();
const result = await dsNameFromDsSample.evaluateTrial(
createMockDsPart(),
{} as Prompt,
{
name: mockDsPartition.dataset.metadata.name,
year: "2021",
Expand Down
Loading

0 comments on commit 0b85837

Please sign in to comment.