Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
andrefs committed Mar 6, 2024
1 parent 6eb61f9 commit a02bada
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
DatasetScores,
loadDatasetScores,
} from "../dataset-adapters/collection";
import { ExperimentData } from ".";
import { ExperimentData, TrialsResult } from ".";
// Identifier and human-readable summary of this experiment; both are
// re-exposed on the experiment object assembled at the bottom of this module.
const name = "compare-prompts";
const description = "Compare the results obtained with different prompts";

Expand Down Expand Up @@ -88,7 +88,7 @@ async function runTrialsModel(
return results;
}

async function runTrials(trials: number) {
async function runTrials(trials: number): TrialsRes[] {

Check failure on line 91 in src/lib/experiments/compare-prompts.ts

View workflow job for this annotation

GitHub Actions / build (18.x)

Cannot find name 'TrialsRes'. Did you mean 'TrialsResult'?

Check failure on line 91 in src/lib/experiments/compare-prompts.ts

View workflow job for this annotation

GitHub Actions / build (18.x)

The return type of an async function or method must be the global Promise<T> type. Did you mean to write 'Promise<TrialsRes[]>'?

Check failure on line 91 in src/lib/experiments/compare-prompts.ts

View workflow job for this annotation

GitHub Actions / build (18.x)

A function whose declared type is neither 'undefined', 'void', nor 'any' must return a value.

Check failure on line 91 in src/lib/experiments/compare-prompts.ts

View workflow job for this annotation

GitHub Actions / build (21.x)

Cannot find name 'TrialsRes'. Did you mean 'TrialsResult'?

Check failure on line 91 in src/lib/experiments/compare-prompts.ts

View workflow job for this annotation

GitHub Actions / build (21.x)

The return type of an async function or method must be the global Promise<T> type. Did you mean to write 'Promise<TrialsRes[]>'?

Check failure on line 91 in src/lib/experiments/compare-prompts.ts

View workflow job for this annotation

GitHub Actions / build (21.x)

A function whose declared type is neither 'undefined', 'void', nor 'any' must return a value.
const datasetIds = ["ws353", "simlex999"];
const datasets: { [key: string]: DatasetScores } = {};
for (const dsId of datasetIds) {
Expand All @@ -101,27 +101,39 @@ async function runTrials(trials: number) {
)}.`
);

const res: TrialsResult[] = [];
for (const modelId in models) {
for (const promptId in prompts) {
for (const dsId in datasets) {
const results = await runTrialsModel(
const trialsRes = await runTrialsModel(
trials,
models[modelId as keyof typeof models],
dsId,
promptId
);
const res: ExperimentData = {
res.push({
variables: {
modelId,
promptId,
dsId: "",
dsId,
},
results: {
trial: results,
},
};
logger.info(`Results: ${JSON.stringify(res)}`);
data: trialsRes,
});
}
}
}
logger.info(`Results: ${JSON.stringify(res)}`);
}

/**
 * Validation hook for the compare-prompts experiment.
 *
 * NOTE(review): placeholder — the body is empty, so `trialsRes` is ignored and
 * the returned promise resolves to `undefined`. Validation logic is still to
 * be written.
 *
 * @param trialsRes - Trial results to validate (currently unused).
 */
async function validate(trialsRes: TrialsResult[]) {}

/**
 * The compare-prompts experiment, bundled as a single object: its name and
 * description, the prompts it uses, the result schema, and the `runTrials` /
 * `validate` entry points. This object is the module's default export.
 */
const ComparePromptsExperiment = {
name,
description,
prompts,
schema: resultSchema,
runTrials,
validate,
};

export default ComparePromptsExperiment;
70 changes: 44 additions & 26 deletions src/lib/experiments/experiment.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ class Experiment {
trials: number,
ds: DatasetProfile,
model: Model
) => Promise<string[]>;
) => Promise<TrialsResult>;
validateTrial: (
ds: DatasetProfile,
data: string
) => Promise<ValidationResult>;
validate: (
ds: DatasetProfile,
data: string[]
tr: TrialsResult
) => Promise<{
validation: ValidationResult[];
aggregated: AggregatedValidationResult;
Expand All @@ -50,7 +50,7 @@ class Experiment {
schema: any, // eslint-disable-line @typescript-eslint/no-explicit-any
ds: DatasetProfile,
model: Model
) => Promise<TrialResult>,
) => Promise<ModelResponse>,
validateTrial: (
ds: DatasetProfile,
data: string
Expand Down Expand Up @@ -83,16 +83,23 @@ class Experiment {
: ""
);
}
return results;
return {
variables: {
dsId: ds.id,
modelId: model.modelId,
prompt: prompt,
},
data: results,
};
};
this.validateTrial = validateTrial;
this.validate = async function (
this: Experiment,
ds: DatasetProfile,
data: string[]
tr: TrialsResult
) {
const trialValidationResults = await Promise.all(
data.map(d => this.validateTrial(ds, d))
tr.data.map(d => this.validateTrial(ds, d))
);
return {
validation: trialValidationResults,
Expand All @@ -106,8 +113,8 @@ class Experiment {
model: Model,
traceId?: number
): Promise<ExperimentData> {
const results = await this.runTrials(trials, ds, model);
const { validation, aggregated } = await this.validate(ds, results);
const trialsRes = await this.runTrials(trials, ds, model);
const { validation, aggregated } = await this.validate(ds, trialsRes);

const expData: ExperimentData = {
meta: {
Expand All @@ -121,7 +128,7 @@ class Experiment {
modelId: model.modelId,
},
results: {
raw: results,
raw: trialsRes.data,
validation,
aggregated,
},
Expand Down Expand Up @@ -157,23 +164,34 @@ export async function saveExperimentData(data: ExperimentData) {
await fs.writeFile(filename, json);
}

/**
 * The independent variables a set of trials was run with: which dataset,
 * which model and (optionally) which prompt. Shared by `ExperimentData` and
 * `TrialsResult` so both describe their inputs the same way.
 */
export interface ExperimentVariables {
// Identifier of the dataset the trials were run against.
dsId: string;
// Identifier of the model that produced the responses.
modelId: string;
// Optional; presumably the prompt text itself — TODO confirm against callers.
prompt?: string;
// Optional; presumably the key of a named prompt, for experiments that
// choose among several prompts — TODO confirm against callers.
promptId?: string;
}

/**
 * Descriptive metadata about an experiment run, stored under
 * `ExperimentData.meta`.
 */
export interface ExperimentMeta {
// Presumably the experiment's name — TODO confirm.
name: string;
// Numeric identifier of the run; NOTE(review): how it is allocated is not
// visible here — confirm.
traceId: number;
// Untyped on purpose (lint rule disabled); presumably the schema that model
// responses are expected to follow — TODO confirm.
schema: any; // eslint-disable-line @typescript-eslint/no-explicit-any
}

/**
 * The outcome of an experiment run, stored under `ExperimentData.results`:
 * the raw trial outputs plus their validation.
 */
export interface ExperimentResults {
// Raw output strings of the trials (the `data` of the `TrialsResult`).
raw: string[];
// Validation results for the trials, as produced by `Experiment.validate`.
validation: ValidationResult[];
// Summary computed over all validation results.
aggregated: AggregatedValidationResult;
}

export interface ExperimentData {
variables: {
dsId: string;
modelId: string;
prompt?: string;
promptId?: string;
};
meta: {
name: string;
traceId: number;
schema: any; // eslint-disable-line @typescript-eslint/no-explicit-any
};
results: {
raw: string[];
validation: ValidationResult[];
aggregated: AggregatedValidationResult;
};
variables: ExperimentVariables;
meta: ExperimentMeta;
results: ExperimentResults;
}

/**
 * The output of running a batch of trials for one combination of variables,
 * before any validation has been applied.
 */
export interface TrialsResult {
// The dataset/model/prompt combination the trials were run with.
variables: ExperimentVariables;
// Raw trial outputs as strings; each element is validated individually.
data: string[];
}

export interface AggregatedValidationResult {
Expand All @@ -183,7 +201,7 @@ export interface AggregatedValidationResult {
};
}

export type TrialResult = {
export type ModelResponse = {
type: "openai";
data: OpenAI.Chat.Completions.ChatCompletion;
};
Expand Down
6 changes: 3 additions & 3 deletions src/lib/models/model.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import { OpenAIModelParams } from "./openai";
import { TrialResult } from "../experiments";
import { ModelResponse } from "../experiments";

export class Model {
modelId: string;
makeRequest: (prompt: string, params: ModelParams) => Promise<TrialResult>;
makeRequest: (prompt: string, params: ModelParams) => Promise<ModelResponse>;

constructor(
modelId: string,
makeRequest: (prompt: string, params: ModelParams) => Promise<TrialResult>
makeRequest: (prompt: string, params: ModelParams) => Promise<ModelResponse>
) {
this.modelId = modelId;
this.makeRequest = makeRequest;
Expand Down

0 comments on commit a02bada

Please sign in to comment.