Skip to content

Commit

Permalink
Merge pull request #268 from aashsach/bedrock-mistral
Browse files Browse the repository at this point in the history
Bedrock mistral
  • Loading branch information
VisargD committed Mar 29, 2024
2 parents 4878580 + 8d92689 commit a37b8bd
Show file tree
Hide file tree
Showing 3 changed files with 330 additions and 0 deletions.
159 changes: 159 additions & 0 deletions src/providers/bedrock/chatComplete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import {
BedrockLlamaStreamChunk,
BedrockTitanCompleteResponse,
BedrockTitanStreamChunk,
BedrockMistralCompleteResponse,
BedrocMistralStreamChunk,
} from './complete';
import { BedrockErrorResponse } from './embed';

Expand Down Expand Up @@ -268,6 +270,57 @@ export const BedrockLLamaChatCompleteConfig: ProviderConfig = {
},
};

export const BedrockMistralChatCompleteConfig: ProviderConfig = {
messages: {
param: 'prompt',
required: true,
transform: (params: Params) => {
let prompt: string = '';
if (!!params.messages) {
let messages: Message[] = params.messages;
messages.forEach((msg, index) => {
if (index === 0 && msg.role === 'system') {
prompt += `system: ${messages}\n`;
} else if (msg.role == 'user') {
prompt += `user: ${msg.content}\n`;
} else if (msg.role == 'assistant') {
prompt += `assistant: ${msg.content}\n`;
} else {
prompt += `${msg.role}: ${msg.content}\n`;
}
});
prompt += 'Assistant:';
}
return prompt;
},
},
max_tokens: {
param: 'max_tokens',
default: 20,
min: 1,
},
temperature: {
param: 'temperature',
default: 0.75,
min: 0,
max: 5,
},
top_p: {
param: 'top_p',
default: 0.75,
min: 0,
max: 1,
},
top_k: {
param: 'top_k',
default: 0,
max: 200,
},
stop: {
param: 'stop',
},
};

const transformTitanGenerationConfig = (params: Params) => {
const generationConfig: Record<string, any> = {};
if (params['temperature']) {
Expand Down Expand Up @@ -908,3 +961,109 @@ export const BedrockCohereChatCompleteStreamChunkTransform: (
],
})}\n\n`;
};

export const BedrockMistralChatCompleteResponseTransform: (
response: BedrockMistralCompleteResponse | BedrockErrorResponse,
responseStatus: number,
responseHeaders: Headers
) => ChatCompletionResponse | ErrorResponse = (
response,
responseStatus,
responseHeaders
) => {
if (responseStatus !== 200) {
const errorResposne = BedrockErrorResponseTransform(
response as BedrockErrorResponse
);
if (errorResposne) return errorResposne;
}

if ('outputs' in response) {
const prompt_tokens =
Number(responseHeaders.get('X-Amzn-Bedrock-Input-Token-Count')) || 0;
const completion_tokens =
Number(responseHeaders.get('X-Amzn-Bedrock-Output-Token-Count')) || 0;
return {
id: Date.now().toString(),
object: 'chat.completion',
created: Math.floor(Date.now() / 1000),
model: '',
provider: BEDROCK,
choices: [
{
index: 0,
message: {
role: 'assistant',
content: response.outputs[0].text,
},
finish_reason: response.outputs[0].stop_reason,
},
],
usage: {
prompt_tokens: prompt_tokens,
completion_tokens: completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
},
};
}

return generateInvalidProviderResponseError(response, BEDROCK);
};

export const BedrockMistralChatCompleteStreamChunkTransform: (
response: string,
fallbackId: string
) => string | string[] = (responseChunk, fallbackId) => {
let chunk = responseChunk.trim();
chunk = chunk.replace(/^data: /, '');
chunk = chunk.trim();
const parsedChunk: BedrocMistralStreamChunk = JSON.parse(chunk);

// discard the last cohere chunk as it sends the whole response combined.
if (parsedChunk.outputs[0].stop_reason) {
return [
`data: ${JSON.stringify({
id: fallbackId,
object: 'chat.completion.chunk',
created: Math.floor(Date.now() / 1000),
model: '',
provider: BEDROCK,
choices: [
{
index: 0,
delta: {},
finish_reason: parsedChunk.outputs[0].stop_reason,
},
],
usage: {
prompt_tokens:
parsedChunk['amazon-bedrock-invocationMetrics'].inputTokenCount,
completion_tokens:
parsedChunk['amazon-bedrock-invocationMetrics'].outputTokenCount,
total_tokens:
parsedChunk['amazon-bedrock-invocationMetrics'].inputTokenCount +
parsedChunk['amazon-bedrock-invocationMetrics'].outputTokenCount,
},
})}\n\n`,
`data: [DONE]\n\n`,
];
}

return `data: ${JSON.stringify({
id: fallbackId,
object: 'chat.completion.chunk',
created: Math.floor(Date.now() / 1000),
model: '',
provider: BEDROCK,
choices: [
{
index: 0,
delta: {
role: 'assistant',
content: parsedChunk.outputs[0].text,
},
finish_reason: null,
},
],
})}\n\n`;
};
152 changes: 152 additions & 0 deletions src/providers/bedrock/complete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,37 @@ export const BedrockLLamaCompleteConfig: ProviderConfig = {
},
};

export const BedrockMistralCompleteConfig: ProviderConfig = {
prompt: {
param: 'prompt',
required: true,
},
max_tokens: {
param: 'max_tokens',
default: 20,
min: 1,
},
temperature: {
param: 'temperature',
default: 0.75,
min: 0,
max: 5,
},
top_p: {
param: 'top_p',
default: 0.75,
min: 0,
max: 1,
},
top_k: {
param: 'top_k',
default: 0,
},
stop: {
param: 'stop',
},
};

const transformTitanGenerationConfig = (params: Params) => {
const generationConfig: Record<string, any> = {};
if (params['temperature']) {
Expand Down Expand Up @@ -754,3 +785,124 @@ export const BedrockCohereCompleteStreamChunkTransform: (
],
})}\n\n`;
};

export interface BedrocMistralStreamChunk {
outputs: {
text: string;
stop_reason: string | null;
}[];
'amazon-bedrock-invocationMetrics': {
inputTokenCount: number;
outputTokenCount: number;
invocationLatency: number;
firstByteLatency: number;
};
}

export const BedrockMistralCompleteStreamChunkTransform: (
response: string,
fallbackId: string
) => string | string[] = (responseChunk, fallbackId) => {
let chunk = responseChunk.trim();
chunk = chunk.trim();
const parsedChunk: BedrocMistralStreamChunk = JSON.parse(chunk);

if (parsedChunk.outputs[0].stop_reason) {
return [
`data: ${JSON.stringify({
id: fallbackId,
object: 'text_completion',
created: Math.floor(Date.now() / 1000),
model: '',
provider: BEDROCK,
choices: [
{
text: parsedChunk.outputs[0].text,
index: 0,
logprobs: null,
finish_reason: parsedChunk.outputs[0].stop_reason,
},
],
usage: {
prompt_tokens:
parsedChunk['amazon-bedrock-invocationMetrics'].inputTokenCount,
completion_tokens:
parsedChunk['amazon-bedrock-invocationMetrics'].outputTokenCount,
total_tokens:
parsedChunk['amazon-bedrock-invocationMetrics'].inputTokenCount +
parsedChunk['amazon-bedrock-invocationMetrics'].outputTokenCount,
},
})}\n\n`,
`data: [DONE]\n\n`,
];
}

return `data: ${JSON.stringify({
id: fallbackId,
object: 'text_completion',
created: Math.floor(Date.now() / 1000),
model: '',
provider: BEDROCK,
choices: [
{
text: parsedChunk.outputs[0].text,
index: 0,
logprobs: null,
finish_reason: null,
},
],
})}\n\n`;
};

export interface BedrockMistralCompleteResponse {
outputs: {
text: string;
stop_reason: string;
}[];
}

export const BedrockMistralCompleteResponseTransform: (
response: BedrockMistralCompleteResponse | BedrockErrorResponse,
responseStatus: number,
responseHeaders: Headers
) => CompletionResponse | ErrorResponse = (
response,
responseStatus,
responseHeaders
) => {
if (responseStatus !== 200) {
const errorResponse = BedrockErrorResponseTransform(
response as BedrockErrorResponse
);
if (errorResponse) return errorResponse;
}

if ('outputs' in response) {
const prompt_tokens =
Number(responseHeaders.get('X-Amzn-Bedrock-Input-Token-Count')) || 0;
const completion_tokens =
Number(responseHeaders.get('X-Amzn-Bedrock-Output-Token-Count')) || 0;
return {
id: Date.now().toString(),
object: 'text_completion',
created: Math.floor(Date.now() / 1000),
model: '',
provider: BEDROCK,
choices: [
{
text: response.outputs[0].text,
index: 0,
logprobs: null,
finish_reason: response.outputs[0].stop_reason,
},
],
usage: {
prompt_tokens: prompt_tokens,
completion_tokens: completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
},
};
}

return generateInvalidProviderResponseError(response, BEDROCK);
};
19 changes: 19 additions & 0 deletions src/providers/bedrock/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ import {
BedrockTitanChatCompleteResponseTransform,
BedrockTitanChatCompleteStreamChunkTransform,
BedrockTitanChatompleteConfig,
BedrockMistralChatCompleteConfig,
BedrockMistralChatCompleteResponseTransform,
BedrockMistralChatCompleteStreamChunkTransform,
} from './chatComplete';
import {
BedrockAI21CompleteConfig,
Expand All @@ -30,6 +33,9 @@ import {
BedrockLLamaCompleteConfig,
BedrockLlamaCompleteResponseTransform,
BedrockLlamaCompleteStreamChunkTransform,
BedrockMistralCompleteConfig,
BedrockMistralCompleteResponseTransform,
BedrockMistralCompleteStreamChunkTransform,
BedrockTitanCompleteConfig,
BedrockTitanCompleteResponseTransform,
BedrockTitanCompleteStreamChunkTransform,
Expand Down Expand Up @@ -91,6 +97,19 @@ const BedrockConfig: ProviderConfigs = {
chatComplete: BedrockLlamaChatCompleteResponseTransform,
},
};
case 'mistral':
return {
complete: BedrockMistralCompleteConfig,
chatComplete: BedrockMistralChatCompleteConfig,
api: BedrockAPIConfig,
responseTransforms: {
'stream-complete': BedrockMistralCompleteStreamChunkTransform,
complete: BedrockMistralCompleteResponseTransform,
'stream-chatComplete':
BedrockMistralChatCompleteStreamChunkTransform,
chatComplete: BedrockMistralChatCompleteResponseTransform,
},
};
case 'amazon':
return {
complete: BedrockTitanCompleteConfig,
Expand Down

0 comments on commit a37b8bd

Please sign in to comment.