feat: OpenAIGeneration model for embedder (#9474)
jordanh committed Feb 29, 2024
1 parent 9c44e23 commit 807e347
Showing 7 changed files with 112 additions and 26 deletions.
2 changes: 0 additions & 2 deletions packages/embedder/ai_models/AbstractModel.ts
@@ -11,7 +11,6 @@ export interface GenerationModelConfig extends ModelConfig {}

export abstract class AbstractModel {
  public readonly url?: string
-  public modelInstance: any

  constructor(config: ModelConfig) {
    this.url = this.normalizeUrl(config.url)
@@ -57,7 +56,6 @@ export interface GenerationOptions {
  temperature?: number
  topK?: number
  topP?: number
-  truncate?: boolean
}

export abstract class AbstractGenerationModel extends AbstractModel {
9 changes: 6 additions & 3 deletions packages/embedder/ai_models/ModelManager.ts
@@ -7,6 +7,7 @@ import {
  GenerationModelConfig,
  ModelConfig
} from './AbstractModel'
+import OpenAIGeneration from './OpenAIGeneration'
import TextEmbeddingsInference from './TextEmbeddingsInference'
import TextGenerationInference from './TextGenerationInference'

@@ -16,7 +17,7 @@ interface ModelManagerConfig {
}

export type EmbeddingsModelType = 'text-embeddings-inference'
-export type GenerationModelType = 'text-generation-inference'
+export type GenerationModelType = 'openai' | 'text-generation-inference'

export class ModelManager {
  embeddingModels: AbstractEmbeddingsModel[]
@@ -80,9 +81,11 @@ export class ModelManager {
    const [modelType, _] = modelConfig.model.split(':') as [GenerationModelType, string]

    switch (modelType) {
+      case 'openai': {
+        return new OpenAIGeneration(modelConfig)
+      }
      case 'text-generation-inference': {
-        const generator = new TextGenerationInference(modelConfig)
-        return generator
+        return new TextGenerationInference(modelConfig)
      }
      default:
        throw new Error(`unsupported summarization model '${modelType}'`)
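The switch above routes on the prefix of the colon-delimited model string. A minimal sketch of a generation model config that would take the new OpenAI branch (the exact fields of ModelManagerConfig are not visible in this diff, so the shape below is an assumption based on GenerationModelConfig):

// Hypothetical config: the part before the colon selects the class, the part
// after it names the concrete model (validated by each class's constructModelParams).
const generationModelConfig: GenerationModelConfig = {
  model: 'openai:gpt-3.5-turbo-0125'
}
// Passed through the switch above, this yields new OpenAIGeneration(generationModelConfig).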
94 changes: 94 additions & 0 deletions packages/embedder/ai_models/OpenAIGeneration.ts
@@ -0,0 +1,94 @@
import OpenAI from 'openai'
import {
  AbstractGenerationModel,
  GenerationModelConfig,
  GenerationModelParams,
  GenerationOptions
} from './AbstractModel'

const MAX_REQUEST_TIME_S = 3 * 60

export type ModelId = 'gpt-3.5-turbo-0125' | 'gpt-4-turbo-preview'

type OpenAIGenerationOptions = Omit<GenerationOptions, 'topK'>

const modelIdDefinitions: Record<ModelId, GenerationModelParams> = {
  'gpt-3.5-turbo-0125': {
    maxInputTokens: 4096
  },
  'gpt-4-turbo-preview': {
    maxInputTokens: 128000
  }
}

function isValidModelId(object: any): object is ModelId {
  return Object.keys(modelIdDefinitions).includes(object)
}

export class OpenAIGeneration extends AbstractGenerationModel {
  private openAIApi: OpenAI | null
  private modelId: ModelId

  constructor(config: GenerationModelConfig) {
    super(config)
    if (!process.env.OPEN_AI_API_KEY) {
      this.openAIApi = null
      return
    }
    this.openAIApi = new OpenAI({
      apiKey: process.env.OPEN_AI_API_KEY,
      organization: process.env.OPEN_AI_ORG_ID
    })
  }

  async summarize(content: string, options: OpenAIGenerationOptions) {
    if (!this.openAIApi) {
      const eMsg = 'OpenAI is not configured'
      console.log('OpenAIGenerationSummarizer.summarize(): ', eMsg)
      throw new Error(eMsg)
    }
    const {maxNewTokens: max_tokens = 512, seed, stop, temperature = 0.8, topP: top_p} = options
    const prompt = `Create a brief, one-paragraph summary of the following: ${content}`

    try {
      const response = await this.openAIApi.chat.completions.create({
        frequency_penalty: 0,
        max_tokens,
        messages: [
          {
            role: 'user',
            content: prompt
          }
        ],
        model: this.modelId,
        presence_penalty: 0,
        temperature,
        seed,
        stop,
        top_p
      })
      const maybeSummary = response.choices[0]?.message?.content?.trim()
      if (!maybeSummary) throw new Error('OpenAI returned empty summary')
      return maybeSummary
    } catch (e) {
      console.log('OpenAIGenerationSummarizer.summarize(): ', e)
      throw e
    }
  }
  protected constructModelParams(config: GenerationModelConfig): GenerationModelParams {
    const modelConfigStringSplit = config.model.split(':')
    if (modelConfigStringSplit.length != 2) {
      throw new Error('OpenAIGeneration model string must be colon-delimited and len 2')
    }

    const maybeModelId = modelConfigStringSplit[1]
    if (!isValidModelId(maybeModelId))
      throw new Error(`OpenAIGeneration model id unknown: ${maybeModelId}`)

    this.modelId = maybeModelId

    return modelIdDefinitions[maybeModelId]
  }
}

export default OpenAIGeneration
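
A minimal usage sketch for the new class, assuming OPEN_AI_API_KEY (and optionally OPEN_AI_ORG_ID) is set in the environment; someLongText is a placeholder, and the 'openai:<modelId>' format is enforced by constructModelParams:

// Hypothetical direct use; in the embedder the instance is created by ModelManager.
const openAI = new OpenAIGeneration({model: 'openai:gpt-4-turbo-preview'})
const summary = await openAI.summarize(someLongText, {maxNewTokens: 256})
// summarize() throws if no API key was configured or OpenAI returned an empty summary.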
2 changes: 1 addition & 1 deletion packages/embedder/ai_models/TextEmbeddingsInference.ts
@@ -63,7 +63,7 @@ export class TextEmbeddingsInference extends AbstractEmbeddingsModel {
    if (!this.url) throw new Error('TextGenerationInferenceSummarizer model requires url')
    const maybeModelId = modelConfigStringSplit[1]
    if (!isValidModelId(maybeModelId))
-      throw new Error(`TextGenerationInference model subtype unknown: ${maybeModelId}`)
+      throw new Error(`TextGenerationInference model id unknown: ${maybeModelId}`)
    return modelIdDefinitions[maybeModelId]
  }
}
26 changes: 9 additions & 17 deletions packages/embedder/ai_models/TextGenerationInference.ts
@@ -25,24 +25,16 @@ export class TextGenerationInference extends AbstractGenerationModel {
    super(config)
  }

-  public async summarize(content: string, options: GenerationOptions) {
-    const {
-      maxNewTokens: max_new_tokens = 512,
-      seed,
-      stop,
-      temperature = 0.8,
-      topP,
-      topK,
-      truncate
-    } = options
+  async summarize(content: string, options: GenerationOptions) {
+    const {maxNewTokens: max_new_tokens = 512, seed, stop, temperature = 0.8, topP, topK} = options
    const parameters = {
      max_new_tokens,
      seed,
      stop,
      temperature,
      topP,
      topK,
-      truncate
+      truncate: true
    }
    const prompt = `Create a brief, one-paragraph summary of the following: ${content}`
    const fetchOptions = {
@@ -59,27 +51,27 @@
    }

    try {
-      // console.log(`TextGenerationInterface.summarize(): summarizing from ${this.url}/generate`)
+      // console.log(`TextGenerationInference.summarize(): summarizing from ${this.url}/generate`)
      const res = await fetchWithRetry(`${this.url}/generate`, fetchOptions)
      const json = await res.json()
      if (!json || !json.generated_text)
-        throw new Error('TextGenerationInterface.summarize(): malformed response')
+        throw new Error('TextGenerationInference.summarize(): malformed response')
      return json.generated_text as string
    } catch (e) {
-      console.log('TextGenerationInterfaceSummarizer.summarize(): timeout')
+      console.log('TextGenerationInferenceSummarizer.summarize(): timeout')
      throw e
    }
  }
  protected constructModelParams(config: GenerationModelConfig): GenerationModelParams {
    const modelConfigStringSplit = config.model.split(':')
    if (modelConfigStringSplit.length != 2) {
-      throw new Error('TextGenerationInterface model string must be colon-delimited and len 2')
+      throw new Error('TextGenerationInference model string must be colon-delimited and len 2')
    }

-    if (!this.url) throw new Error('TextGenerationInterfaceSummarizer model requires url')
+    if (!this.url) throw new Error('TextGenerationInferenceSummarizer model requires url')
    const maybeModelId = modelConfigStringSplit[1]
    if (!isValidModelId(maybeModelId))
-      throw new Error(`TextGenerationInterface model subtype unknown: ${maybeModelId}`)
+      throw new Error(`TextGenerationInference model id unknown: ${maybeModelId}`)
    return modelIdDefinitions[maybeModelId]
  }
}
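Beyond the Interface-to-Inference renames in the strings above, one behavioral change hides in this file: truncate is no longer a caller-supplied option (it was removed from GenerationOptions in AbstractModel.ts) and is now always sent as true, so the inference server trims over-long prompts itself. The construction of fetchOptions is elided by the diff; a hedged sketch of what it plausibly contains, assuming text-generation-inference's usual {inputs, parameters} request shape:

// Assumed serialization; the diff does not show this part of fetchOptions.
const fetchOptions = {
  method: 'POST',
  headers: {'Content-Type': 'application/json'},
  body: JSON.stringify({inputs: prompt, parameters})
}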
3 changes: 1 addition & 2 deletions packages/embedder/embedder.ts
@@ -162,9 +162,8 @@ const dequeueAndEmbedUntilEmpty = async (modelManager: ModelManager) => {
    try {
      const generator = modelManager.generationModels[0] // use 1st generator
      if (!generator) throw new Error(`Generator unavailable`)
-      const summarizeOptions = {maxInputTokens, truncate: true}
      console.log(`embedder: ...summarizing ${itemKey} for ${modelTable}`)
-      embedText = await generator.summarize(fullText, summarizeOptions)
+      embedText = await generator.summarize(fullText, {maxNewTokens: maxInputTokens})
    } catch (e) {
      await updateJobState(jobQueueId, 'failed', {
        stateMessage: `unable to summarize long embed text: ${e}`
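Note the call site now caps the summary with maxNewTokens set to the embedding model's maxInputTokens, presumably so the generated summary is guaranteed to fit the embedder's context window, while truncation of the generator's own input is handled inside the generation model rather than per call.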
2 changes: 1 addition & 1 deletion packages/embedder/indexing/embeddingsTablesOps.ts
@@ -52,7 +52,7 @@ export async function selectMetaToQueue(
    .where(({eb, not, or, and, exists, selectFrom}) =>
      and([
        or([
-          not(eb('em.models', '<@', sql`ARRAY[${sql.ref('model')}]::varchar[]` as any) as any),
+          not(eb('em.models', '@>', sql`ARRAY[${sql.ref('model')}]::varchar[]` as any) as any),
          eb('em.models' as any, 'is', null)
        ]),
        not(
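This one-character swap fixes the queue filter's logic. In Postgres, a <@ b tests whether a is contained by b, while a @> b tests whether a contains b; selectMetaToQueue evidently wants rows whose models array does not yet contain the target model, i.e. NOT (em.models @> ARRAY[model]). A minimal in-memory analogue of the corrected predicate (field names assumed to mirror the em.models varchar[] column):

// Rows qualify when models is null or does not yet include the target model,
// matching NOT (em.models @> ARRAY[model]) OR em.models IS NULL above.
const needsModel = (models: string[] | null, model: string): boolean =>
  models === null || !models.includes(model)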
