Skip to content

Commit

Permalink
fixes in compare prompts
Browse files Browse the repository at this point in the history
  • Loading branch information
andrefs committed May 11, 2024
1 parent 1815e9c commit 6cea7df
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 117 deletions.
64 changes: 37 additions & 27 deletions src/lib/experiments/compare-prompts/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ import query from "./query";
export const name = "compare-prompts";
const description = "Compare the results obtained with different prompts";

const validateSchema = (value: unknown): value is ExpTypes["Data"] =>
const validateSchema = (value: unknown): value is CPExpTypes["Data"] =>
Value.Check(query.responseSchema, value);
interface ExpTypes extends GenericExpTypes {
export interface CPExpTypes extends GenericExpTypes {
Data: Static<typeof query.responseSchema>;
DataSchema: typeof query.responseSchema;
Evaluation: ComparisonGroup[];
Expand All @@ -63,7 +63,7 @@ async function tryResponse(model: Model, prompt: string, params: ModelTool) {
return { result: new NoData(), usage };
}
try {
const got = JSON.parse(data) as ExpTypes["Data"];
const got = JSON.parse(data) as CPExpTypes["Data"];
if (!validateSchema(got)) {
return { result: new JsonSchemaError(data), usage };
}
Expand Down Expand Up @@ -91,7 +91,7 @@ async function getResponse(
addUsage(totalUsage, usage);
if (attemptResult instanceof ValidData) {
logger.info(` ✅ attempt #${failedAttempts.length + 1} succeeded.`);
const res: TrialResult<ExpTypes["Data"]> = {
const res: TrialResult<CPExpTypes["Data"]> = {
prompt,
totalTries: failedAttempts.length + 1,
failedAttempts,
Expand All @@ -109,7 +109,7 @@ async function getResponse(
failedAttempts.push(attemptResult);
}

const res: TrialResult<ExpTypes["Data"]> = {
const res: TrialResult<CPExpTypes["Data"]> = {
prompt,
totalTries: failedAttempts.length,
usage: totalUsage,
Expand All @@ -136,7 +136,7 @@ async function runTrial(vars: ExpVars, maxRetries = 3) {
async function runTrials(
vars: ExpVars,
trials: number
): Promise<TrialsResultData<ExpTypes["Data"]>> {
): Promise<TrialsResultData<CPExpTypes["Data"]>> {
const totalUsage: Usages = {};
logger.info(
`Running experiment ${name} ${trials} times on model ${vars.model.id}.`
Expand Down Expand Up @@ -169,7 +169,7 @@ async function perform(
) {
const trialsRes = await runTrials(vars, trials);

const expData: ExperimentData<ExpTypes> = {
const expData: ExperimentData<CPExpTypes> = {
meta: {
trials,
name,
Expand Down Expand Up @@ -275,9 +275,9 @@ async function performMulti(
* @returns The evaluated scores
* @throws {Error} If more than half of the trials failed to parse
*/
function expEvalScores(exps: ExperimentData<ExpTypes>[]): ExpScore[] {
function expEvalScores(exps: ExperimentData<CPExpTypes>[]): ExpScore[] {
const res = [];
for (const exp of exps) {
for (const [i, exp] of exps.entries()) {
for (const trial of exp.results.raw) {
const lcPairs = trial.prompt.pairs!.map(
p => [p[0].toLowerCase(), p[1].toLowerCase()] as [string, string]
Expand All @@ -286,11 +286,19 @@ function expEvalScores(exps: ExperimentData<ExpTypes>[]): ExpScore[] {
const rawResults: PairScoreList[] = exp.results.raw.map(r => {
return r.data.scores as PairScoreList;
});
const corr = evalScores(lcPairs, exp.variables.dpart, rawResults);
res.push({
variables: exp.variables,
score: corr.pcorr,
});
try {
const corr = evalScores(lcPairs, exp.variables.dpart, rawResults);
res.push({
variables: exp.variables,
score: corr!.pcorr,
});
} catch (e) {
logger.warn(
`Error calculating correlation for expVC ${i} with variables ${JSON.stringify(
getVarIds(exp.variables)
)}: ${e}`
);
}
}
}
return res;
Expand All @@ -306,21 +314,21 @@ function logExpScores(expScores: ExpScore[]) {
}
}

async function evaluate(exps: ExperimentData<ExpTypes>[]) {
async function evaluate(exps: ExperimentData<CPExpTypes>[]) {
const expScores = expEvalScores(exps);
const { varValues, varNames } = calcVarValues(exps);

logExpScores(expScores);

const comparisons: ExpTypes["Evaluation"] = [];
const comparisons: CPExpTypes["Evaluation"] = [];
for (const [i, v1] of varNames.entries()) {
for (const v2 of varNames.slice(i + 1)) {
if (varValues[v1].size === 1 && varValues[v2].size === 1) {
// No need to compare if both variables have only one value
continue;
}
//if (varValues[v1].size === 1 && varValues[v2].size === 1) {
// // No need to compare if both variables have only one value
// continue;
//}

let compGroups = [] as ExpTypes["Evaluation"];
let compGroups = [] as CPExpTypes["Evaluation"];
const fixedNames = varNames.filter(v => v !== v1 && v !== v2);

for (const expScore of expScores) {
Expand All @@ -340,12 +348,14 @@ async function evaluate(exps: ExperimentData<ExpTypes>[]) {
group.data[v1Val][v2Val] = corr;
}

// keep only groups with more than one value for each variable
compGroups = compGroups.filter(
g =>
Object.keys(g.data).length > 1 &&
Object.keys(g.data).every(k => Object.keys(g.data[k]).length > 1)
);
if (compGroups.length > 1) {
// keep only groups with more than one value for each variable
compGroups = compGroups.filter(
g =>
Object.keys(g.data).length > 1 &&
Object.keys(g.data).every(k => Object.keys(g.data[k]).length > 1)
);
}

comparisons.push(...compGroups);
}
Expand Down
Loading

0 comments on commit 6cea7df

Please sign in to comment.