From f8ef1072d210328b0231dae0c15f50d7fce8bba1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Santos?= Date: Wed, 6 Mar 2024 22:32:14 +0000 Subject: [PATCH] some refactoring and kickoff comparePrompts --- src/lib/dataset-adapters/collection.ts | 59 +++++++++--- src/lib/experiments/compare-mc30.ts | 24 ++--- src/lib/experiments/comparePrompts.ts | 127 +++++++++++++++++++++++++ src/lib/experiments/experiment.ts | 77 +++++++++------ src/scripts/dsNameFromDsSample.ts | 12 +-- src/scripts/dsPaperFromDsName.ts | 12 +-- src/scripts/dsSampleFromDsName.ts | 12 +-- src/scripts/dsSampleFromDsSample.ts | 12 +-- src/scripts/dsValuesExactMatches.ts | 12 +-- 9 files changed, 262 insertions(+), 85 deletions(-) create mode 100644 src/lib/experiments/comparePrompts.ts diff --git a/src/lib/dataset-adapters/collection.ts b/src/lib/dataset-adapters/collection.ts index b346f77..75ee91c 100644 --- a/src/lib/dataset-adapters/collection.ts +++ b/src/lib/dataset-adapters/collection.ts @@ -1,35 +1,66 @@ -import * as ds from "punuy-datasets"; +import { DatasetProfile } from "../types"; -interface DatasetColPairs { +export interface MultiDatasetScores { [w1: string]: { [w2: string]: { - [dataset: string]: number[]; + [dataset: string]: number; }; }; } -const col = {} as DatasetColPairs; +export interface DatasetScores { + [w1: string]: { + [w2: string]: number; + }; +} -for (const dataset in ds) { - const d = ds[dataset as keyof typeof ds]; +export async function loadDatasetScores(dsId: string) { + const d = (await import(`punuy-datasets/datasets/${dsId}`)) as DatasetProfile; + const res = {} as DatasetScores; for (const part of d.partitions) { for (const row of part.data) { const w1 = row.term1.toLowerCase(); const w2 = row.term2.toLowerCase(); - - col[w1] = col[w1] || {}; - col[w1][w2] = col[w1][w2] || {}; if ("value" in row && row.value !== undefined) { - col[w1][w2][dataset] = [row.value]; + res[w1] = res[w1] || {}; + res[w1][w2] = row.value; continue; } if ("values" in row && Array.isArray(row.values)) { - col[w1][w2][dataset] = row.values.filter( - v => v !== undefined && v !== null - ) as number[]; + const vals = row.values.filter(v => typeof v === "number") as number[]; + res[w1] = res[w1] || {}; + res[w1][w2] = vals.reduce((a, b) => a + b, 0) / vals.length; } } } + return res; } -export default col; +export async function loadAllDatasetScores() { + const ds = (await import("punuy-datasets")).default; + const res = {} as MultiDatasetScores; + + for (const dataset in ds) { + const d = ds[dataset as keyof typeof ds]; + for (const part of d.partitions) { + for (const row of part.data) { + const w1 = row.term1.toLowerCase(); + const w2 = row.term2.toLowerCase(); + + res[w1] = res[w1] || {}; + res[w1][w2] = res[w1][w2] || {}; + if ("value" in row && row.value !== undefined) { + res[w1][w2][dataset] = row.value; + continue; + } + if ("values" in row && Array.isArray(row.values)) { + const vals = row.values.filter( + v => typeof v === "number" + ) as number[]; + res[w1][w2][dataset] = vals.reduce((a, b) => a + b, 0) / vals.length; + } + } + } + } + return res; +} diff --git a/src/lib/experiments/compare-mc30.ts b/src/lib/experiments/compare-mc30.ts index 54fe727..7a16b39 100644 --- a/src/lib/experiments/compare-mc30.ts +++ b/src/lib/experiments/compare-mc30.ts @@ -9,21 +9,14 @@ import oldFs from "fs"; import { Model, ModelIds, gpt35turbo, gpt4, gpt4turbo } from "../models"; import { JsonSyntaxError } from "../validation"; import logger from "../logger"; - -interface DatasetScores { - [term1: string]: { - [term2: string]: { - [dataset: string]: number; - }; - }; -} +import { MultiDatasetScores } from "../dataset-adapters/collection"; type ModelsResults = { [key in ModelIds]: string[]; }; export const loadDatasetScores = async () => { - const pairs: DatasetScores = {}; + const pairs: MultiDatasetScores = {}; for (const part of mc30.partitions) { for (const entry of part.data) { @@ -117,7 +110,7 @@ export const loadDatasetScores = async () => { return pairs; }; -const getPairs = (scores: DatasetScores) => { +const getPairs = (scores: MultiDatasetScores) => { const pairs: [string, string][] = []; for (const term1 in scores) { @@ -154,6 +147,7 @@ const resultSchema = { }, }; +/** Run a single trial of the experiment, with a single model */ async function runTrialModel(model: Model, prompt: string) { const f = { name: "evaluate_scores", @@ -165,6 +159,7 @@ async function runTrialModel(model: Model, prompt: string) { return res; } +/** Run multiple trials of the experiment, with a single model */ async function runTrialsModel(trials: number, model: Model, prompt: string) { logger.info(` model ${model.modelId}.`); logger.debug(`Prompt: ${prompt}`); @@ -181,6 +176,7 @@ async function runTrialsModel(trials: number, model: Model, prompt: string) { return results; } +/** Run multiple trials of the experiment, with multiple models */ async function runTrials(trials: number) { const scores = await loadDatasetScores(); const pairs = getPairs(scores); @@ -247,7 +243,7 @@ function unzipResults(results: MC30Results) { async function validate( modelsRes: ModelsResults, - humanScores: DatasetScores, + humanScores: MultiDatasetScores, trials: number ) { try { @@ -366,7 +362,11 @@ function calcCorrelation(data: number[][]) { return corrMatrix; } -function mergeResults(modelsRes: ModelsResults, humanScores: DatasetScores) { +/** Merge the results from the models and the human scores */ +function mergeResults( + modelsRes: ModelsResults, + humanScores: MultiDatasetScores +) { const res = {} as MC30Results; try { diff --git a/src/lib/experiments/comparePrompts.ts b/src/lib/experiments/comparePrompts.ts new file mode 100644 index 0000000..8a22388 --- /dev/null +++ b/src/lib/experiments/comparePrompts.ts @@ -0,0 +1,127 @@ +import { ws353 } from "punuy-datasets"; +import { Model, gpt4, gpt4turbo, gpt35turbo } from "../models"; +import logger from "../logger"; +import { + DatasetScores, + loadDatasetScores, +} from "../dataset-adapters/collection"; +import { ExperimentData } from "."; +const name = "compare-prompts"; +const description = "Compare the results obtained with different prompts"; + +interface Prompts { + [key: string]: { + type: "relatedness" | "similarity"; + text: string; + }; +} +const prompts: Prompts = { + simplest: { + type: "relatedness", + text: "Indicate how strongly the words in each pair are related in meaning using integers from 1 to 5, where 1 means very unrelated and 5 means very related.", + }, + simpleScale: { + type: "relatedness", + text: "Indicate how strongly the words in each pair are related in meaning using integers from 1 to 5, where the scale means: 1 - not at all related, 2 - vaguely related, 3 - indirectly related, 4 - strongly related, 5 - inseparably related.", + }, + adaptedWs353: { + type: "relatedness", + text: 'Hello, we kindly ask you to assist us in a psycholinguistic experiment, aimed at estimating the semantic relatedness of various words in the English language. The purpose of this experiment is to assign semantic relatedness scores to pairs of words, so that machine learning algorithms can be subsequently trained and adjusted using human-assigned scores. Below is a list of pairs of words. For each pair, please assign a numerical relatedness score between 1 and 5 (1 = words are totally unrelated, 5 = words are VERY closely related). By definition, the relatedness of the word to itself should be 5. You may assign fractional scores (for example, 3.5). When estimating relatedness of antonyms, consider them "related" (i.e., belonging to the same domain or representing features of the same concept), rather than "unrelated". Thank you for your assistance!', + }, + simlex999: { + type: "similarity", + text: "Two words are synonyms if they have very similar meanings. Synonyms represent the same type or category of thing. Here are some examples of synonym pairs: cup/mug, glasses/spectacles, envy/jealousy. In practice, word pairs that are not exactly synonymous may still be very similar. Here are some very similar pairs - we could say they are nearly synonyms: alligator/crocodile, love / affection, frog/toad. In contrast, although the following word pairs are related, they are not very similar. The words represent entirely different types of thing:car/tyre, car/motorway, car/crash, In this survey, you are asked to compare word pairs and to rate how similar they are by moving a slider. Remember, things that are related are not necessarily similar. If you are ever unsure, think back to the examples of synonymous pairs (glasses/spectacles), and consider how close the words are (or are not) to being synonymous. There is no right answer to these questions. It is perfectly reasonable to use your intuition or gut feeling as a native English speaker, especially when you are asked to rate word pairs that you think are not similar at all.", + }, +}; + +const models = { + gpt35turbo, + gpt4, + gpt4turbo, +}; + +const resultSchema = { + type: "object", + properties: { + scores: { + type: "array", + items: { + type: "object", + properties: { + words: { type: "array", items: { type: "string" } }, + score: { type: "string" }, + }, + }, + }, + }, +}; + +async function runTrialModel(model: Model, dsId: string, promptId: string) { + const f = { + name: "evaluate_scores", + description: "Evaluate the word similarity or relatedness scores", + parameters: resultSchema, + }; + const res = await model.makeRequest(prompts[promptId].text, { function: f }); + return res; +} + +async function runTrialsModel( + trials: number, + model: Model, + dsId: string, + promptId: string +) { + const results = []; + logger.info(` model ${model.modelId}.`); + logger.debug(`Prompt ID: ${promptId}`); + + for (let i = 0; i < trials; i++) { + logger.info(` trial #${i + 1} of ${trials}`); + const res = await runTrialModel(model, dsId, prompts[promptId].text); + results.push( + res.type === "openai" + ? res.data.choices[0].message.tool_calls?.[0].function.arguments || "" + : "" + ); + } + return results; +} + +async function runTrials(trials: number) { + const datasetIds = ["ws353", "simlex999"]; + const datasets: { [key: string]: DatasetScores } = {}; + for (const dsId of datasetIds) { + datasets[dsId] = await loadDatasetScores(dsId); + } + + logger.info( + `Running experiment ${name} with ${trials} trials on models [gpt35turbo, gpt4, gpt4turbo], datasets ${datasetIds} and prompts ${Object.keys( + prompts + )}.` + ); + + for (const modelId in models) { + for (const promptId in prompts) { + for (const dsId in datasets) { + const results = await runTrialsModel( + trials, + models[modelId as keyof typeof models], + dsId, + promptId + ); + const res: ExperimentData = { + variables: { + modelId, + promptId, + dsId: "", + }, + results: { + trial: results, + }, + }; + logger.info(`Results: ${JSON.stringify(res)}`); + } + } + } +} diff --git a/src/lib/experiments/experiment.ts b/src/lib/experiments/experiment.ts index 93915cd..6f9dfb0 100644 --- a/src/lib/experiments/experiment.ts +++ b/src/lib/experiments/experiment.ts @@ -29,8 +29,8 @@ class Experiment { ds: DatasetProfile, data: string[] ) => Promise<{ - trialResults: ValidationResult[]; - combinedResult: AggregatedValidationResult; + validation: ValidationResult[]; + aggregated: AggregatedValidationResult; }>; perform: ( this: Experiment, @@ -72,7 +72,7 @@ class Experiment { ); logger.debug(`Prompt: ${prompt}`); - const results = []; + const results: string[] = []; for (let i = 0; i < trials; i++) { logger.info(` trial #${i + 1} of ${trials}`); const res = await runTrial(prompt, this.schema, ds, model); @@ -91,12 +91,12 @@ class Experiment { ds: DatasetProfile, data: string[] ) { - const trialResults = await Promise.all( + const trialValidationResults = await Promise.all( data.map(d => this.validateTrial(ds, d)) ); return { - trialResults, - combinedResult: await combineValidations(trialResults), + validation: trialValidationResults, + aggregated: await combineValidations(trialValidationResults), }; }; this.perform = async function ( @@ -107,17 +107,24 @@ class Experiment { traceId?: number ): Promise { const results = await this.runTrials(trials, ds, model); - const { trialResults, combinedResult } = await this.validate(ds, results); + const { validation, aggregated } = await this.validate(ds, results); - const expData = { - name: this.name, - traceId: traceId ?? Date.now(), - prompt: this.genPrompt(ds), - schema: this.schema, - dsId: ds.id, - modelId: model.modelId, - trialResults, - combinedResult, + const expData: ExperimentData = { + meta: { + name: this.name, + traceId: traceId ?? Date.now(), + schema: this.schema, + }, + variables: { + prompt: this.genPrompt(ds), + dsId: ds.id, + modelId: model.modelId, + }, + results: { + raw: results, + validation, + aggregated, + }, }; await saveExperimentData(expData); @@ -127,16 +134,20 @@ class Experiment { } export async function saveExperimentData(data: ExperimentData) { - const ts = data.traceId; - const dsId = data.dsId; - const expName = data.name; - const modelId = data.modelId; + const ts = data.meta.traceId; + const dsId = data.variables.dsId; + const expName = data.meta.name; + const modelId = data.variables.modelId; const rootFolder = "./results"; const filename = `${rootFolder}/${ts}_${expName}_${dsId}_${modelId}.json`; const json = JSON.stringify(data, null, 2); logger.info( - `Saving experiment ${data.name} ${data.trialResults.length} times on model ${data.modelId} to ${filename}.` + `Saving experiment ${data.meta.name} which ran ${ + data.results.raw.length + } times on ${JSON.stringify(data.variables)} with traceId ${ + data.meta.traceId + } to ${filename}.` ); if (!oldFs.existsSync(rootFolder)) { @@ -147,14 +158,22 @@ export async function saveExperimentData(data: ExperimentData) { } export interface ExperimentData { - name: string; - traceId: number; - prompt: string; - schema: any; // eslint-disable-line @typescript-eslint/no-explicit-any - dsId: string; - modelId: string; - trialResults: ValidationResult[]; - combinedResult: AggregatedValidationResult; + variables: { + dsId: string; + modelId: string; + prompt?: string; + promptId?: string; + }; + meta: { + name: string; + traceId: number; + schema: any; // eslint-disable-line @typescript-eslint/no-explicit-any + }; + results: { + raw: string[]; + validation: ValidationResult[]; + aggregated: AggregatedValidationResult; + }; } export interface AggregatedValidationResult { diff --git a/src/scripts/dsNameFromDsSample.ts b/src/scripts/dsNameFromDsSample.ts index b6b2214..ed05e4b 100644 --- a/src/scripts/dsNameFromDsSample.ts +++ b/src/scripts/dsNameFromDsSample.ts @@ -29,16 +29,16 @@ const nameFromSample = async (ds: DatasetProfile) => { ); logger.info( - { ...gpt35turbo_res.combinedResult.resultTypes }, - `gpt35turbo_res ${gpt35turbo_res.combinedResult.avg}` + { ...gpt35turbo_res.results.aggregated.resultTypes }, + `gpt35turbo_res ${gpt35turbo_res.results.aggregated.avg}` ); logger.info( - { ...gpt4_res.combinedResult.resultTypes }, - `gpt4_res ${gpt4_res.combinedResult.avg}` + { ...gpt4_res.results.aggregated.resultTypes }, + `gpt4_res ${gpt4_res.results.aggregated.avg}` ); logger.info( - { ...gpt4turbo_res.combinedResult.resultTypes }, - `gpt4turbo_res ${gpt4turbo_res.combinedResult.avg}` + { ...gpt4turbo_res.results.aggregated.resultTypes }, + `gpt4turbo_res ${gpt4turbo_res.results.aggregated.avg}` ); }; diff --git a/src/scripts/dsPaperFromDsName.ts b/src/scripts/dsPaperFromDsName.ts index cff4533..dc6485d 100644 --- a/src/scripts/dsPaperFromDsName.ts +++ b/src/scripts/dsPaperFromDsName.ts @@ -24,16 +24,16 @@ const paperFromName = async (ds: DatasetProfile) => { ); logger.info( - { ...gpt35turbo_res.combinedResult.resultTypes }, - `gpt35turbo_res ${gpt35turbo_res.combinedResult.avg}` + { ...gpt35turbo_res.results.aggregated.resultTypes }, + `gpt35turbo_res ${gpt35turbo_res.results.aggregated.avg}` ); logger.info( - { ...gpt4_res.combinedResult.resultTypes }, - `gpt4_res ${gpt4_res.combinedResult.avg}` + { ...gpt4_res.results.aggregated.resultTypes }, + `gpt4_res ${gpt4_res.results.aggregated.avg}` ); logger.info( - { ...gpt4turbo_res.combinedResult.resultTypes }, - `gpt4turbo_res ${gpt4turbo_res.combinedResult.avg}` + { ...gpt4turbo_res.results.aggregated.resultTypes }, + `gpt4turbo_res ${gpt4turbo_res.results.aggregated.avg}` ); }; diff --git a/src/scripts/dsSampleFromDsName.ts b/src/scripts/dsSampleFromDsName.ts index 5573d7d..8895f99 100644 --- a/src/scripts/dsSampleFromDsName.ts +++ b/src/scripts/dsSampleFromDsName.ts @@ -29,16 +29,16 @@ const sampleFromName = async (ds: DatasetProfile) => { ); logger.info( - { ...gpt35turbo_res.combinedResult.resultTypes }, - `gpt35turbo_res ${gpt35turbo_res.combinedResult.avg}` + { ...gpt35turbo_res.results.aggregated.resultTypes }, + `gpt35turbo_res ${gpt35turbo_res.results.aggregated.avg}` ); logger.info( - { ...gpt4_res.combinedResult.resultTypes }, - `gpt4_res ${gpt4_res.combinedResult.avg}` + { ...gpt4_res.results.aggregated.resultTypes }, + `gpt4_res ${gpt4_res.results.aggregated.avg}` ); logger.info( - { ...gpt4turbo_res.combinedResult.resultTypes }, - `gpt4turbo_res ${gpt4turbo_res.combinedResult.avg}` + { ...gpt4turbo_res.results.aggregated.resultTypes }, + `gpt4turbo_res ${gpt4turbo_res.results.aggregated.avg}` ); }; diff --git a/src/scripts/dsSampleFromDsSample.ts b/src/scripts/dsSampleFromDsSample.ts index b65c638..6834565 100644 --- a/src/scripts/dsSampleFromDsSample.ts +++ b/src/scripts/dsSampleFromDsSample.ts @@ -29,16 +29,16 @@ const sampleFromSample = async (ds: DatasetProfile) => { ); logger.info( - { ...gpt35turbo_res.combinedResult.resultTypes }, - `gpt35turbo_res ${gpt35turbo_res.combinedResult.avg}` + { ...gpt35turbo_res.results.aggregated.resultTypes }, + `gpt35turbo_res ${gpt35turbo_res.results.aggregated.avg}` ); logger.info( - { ...gpt4_res.combinedResult.resultTypes }, - `gpt4_res ${gpt4_res.combinedResult.avg}` + { ...gpt4_res.results.aggregated.resultTypes }, + `gpt4_res ${gpt4_res.results.aggregated.avg}` ); logger.info( - { ...gpt4turbo_res.combinedResult.resultTypes }, - `gpt4turbo_res ${gpt4turbo_res.combinedResult.avg}` + { ...gpt4turbo_res.results.aggregated.resultTypes }, + `gpt4turbo_res ${gpt4turbo_res.results.aggregated.avg}` ); }; diff --git a/src/scripts/dsValuesExactMatches.ts b/src/scripts/dsValuesExactMatches.ts index c021d0a..2c6355b 100644 --- a/src/scripts/dsValuesExactMatches.ts +++ b/src/scripts/dsValuesExactMatches.ts @@ -29,16 +29,16 @@ const valuesExactMatch = async (ds: DatasetProfile) => { ); logger.info( - { ...gpt35turbo_res.combinedResult.resultTypes }, - `gpt35turbo_res: ${gpt35turbo_res.combinedResult.avg.toFixed(2)}%` + { ...gpt35turbo_res.results.aggregated.resultTypes }, + `gpt35turbo_res: ${gpt35turbo_res.results.aggregated.avg.toFixed(2)}%` ); logger.info( - { ...gpt4_res.combinedResult.resultTypes }, - `gpt4_res: ${gpt4_res.combinedResult.avg.toFixed(2)}%` + { ...gpt4_res.results.aggregated.resultTypes }, + `gpt4_res: ${gpt4_res.results.aggregated.avg.toFixed(2)}%` ); logger.info( - { ...gpt4turbo_res.combinedResult.resultTypes }, - `gpt4turbo_res: ${gpt4turbo_res.combinedResult.avg.toFixed(2)}%` + { ...gpt4turbo_res.results.aggregated.resultTypes }, + `gpt4turbo_res: ${gpt4turbo_res.results.aggregated.avg.toFixed(2)}%` ); };