Fix ooba (#926)
Co-authored-by: Sceuick <dev@agnai.chat>
sceuick and Sceuick committed May 5, 2024
1 parent ed91495 commit 38ce9ef
Showing 5 changed files with 102 additions and 134 deletions.
6 changes: 3 additions & 3 deletions srv/adapter/agnaistic.ts
@@ -11,7 +11,6 @@ import { handleHorde } from './horde'
 import { handleKobold } from './kobold'
 import { handleMancer } from './mancer'
 import { handleNovel } from './novel'
-import { getTextgenCompletion, handleOoba } from './ooba'
 import { handleOAI } from './openai'
 import { handleOpenRouter } from './openrouter'
 import { getThirdPartyPayload } from './payloads'
@@ -24,6 +23,7 @@ import { ModelAdapter } from './type'
 import { AIAdapter, AdapterSetting } from '/common/adapters'
 import { AppSchema } from '/common/types'
 import { parseStops } from '/common/util'
+import { getTextgenCompletion } from './dispatch'
 
 export async function getSubscriptionPreset(
   user: AppSchema.User,
@@ -198,7 +198,7 @@ export const handleAgnaistic: ModelAdapter = async function* (opts) {
 
   if (preset.service === 'kobold' && preset.thirdPartyFormat === 'llamacpp') {
     opts.gen.service = 'kobold'
-    handler = handleOoba
+    handler = handleKobold
   }
 
   if (preset.service === 'goose') {
@@ -363,7 +363,7 @@ export async function updateRegisteredSubs() {
 export const handlers: { [key in AIAdapter]: ModelAdapter } = {
   novel: handleNovel,
   kobold: handleKobold,
-  ooba: handleOoba,
+  ooba: handleKobold,
   horde: handleHorde,
   openai: handleOAI,
   scale: handleScale,
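
With handleOoba removed, the `ooba` key in the `handlers` map resolves to `handleKobold`. A minimal sketch (not part of the commit) of how a caller might resolve and drive one of these adapters; the `runAdapter` helper, its error handling, and the exact yield shapes are illustrative assumptions, not code from this repository:

// Illustrative sketch only: resolve and drive a ModelAdapter from the handlers map.
// `runAdapter` is a hypothetical helper; a real AdapterProps carries many more fields.
import { handlers } from './agnaistic'
import { AdapterProps } from './type'

export async function runAdapter(opts: AdapterProps): Promise<string> {
  const handler = handlers['ooba'] // now the same function as handlers['kobold']
  let text = ''
  for await (const chunk of handler(opts)) {
    if (typeof chunk === 'string') text = chunk // final, trimmed response
    else if ('partial' in chunk && chunk.partial) text = chunk.partial // streamed partial
    else if ('error' in chunk && chunk.error) throw new Error(String(chunk.error))
  }
  return text
}
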
56 changes: 0 additions & 56 deletions srv/adapter/ooba.ts → srv/adapter/dispatch.ts
@@ -1,61 +1,5 @@
 import needle from 'needle'
-import { normalizeUrl, sanitise, sanitiseAndTrim, trimResponseV2 } from '../api/chat/common'
-import { ModelAdapter } from './type'
-import { websocketStream } from './stream'
 import { eventGenerator } from '/common/util'
-import { getThirdPartyPayload } from './payloads'
-
-export const handleOoba: ModelAdapter = async function* (opts) {
-  const { char, members, user, prompt, log, gen } = opts
-  const body = getThirdPartyPayload(opts)
-
-  yield { prompt }
-
-  log.debug({ ...body, prompt: null }, 'Textgen payload')
-
-  log.debug(`Prompt:\n${prompt}`)
-
-  const url = gen.thirdPartyUrl || user.oobaUrl
-  const baseUrl = normalizeUrl(url)
-  const resp =
-    opts.gen.service === 'kobold' && opts.gen.thirdPartyFormat === 'llamacpp'
-      ? llamaStream(baseUrl, body)
-      : gen.streamResponse
-      ? await websocketStream({ url: baseUrl + '/api/v1/stream', body })
-      : getTextgenCompletion('Textgen', `${baseUrl}/api/v1/generate`, body, {})
-
-  let accumulated = ''
-  let result = ''
-
-  while (true) {
-    let generated = await resp.next()
-
-    // Both the streaming and non-streaming generators return a full completion and yield errors.
-    if (generated.done) {
-      break
-    }
-
-    if (generated.value.error) {
-      yield generated.value
-      return
-    }
-
-    // Only the streaming generator yields individual tokens.
-    if (generated.value.token) {
-      accumulated += generated.value.token
-      yield { partial: sanitiseAndTrim(accumulated, prompt, char, opts.characters, members) }
-    }
-
-    if (typeof generated.value === 'string') {
-      result = generated.value
-      break
-    }
-  }
-
-  const parsed = sanitise((result || accumulated).replace(prompt, ''))
-  const trimmed = trimResponseV2(parsed, opts.replyAs, members, opts.characters, ['END_OF_DIALOG'])
-  yield trimmed || parsed
-}
 
 export async function* getTextgenCompletion(
   service: string,
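
The deleted handleOoba drove its response generators by hand via `resp.next()`, accumulating streamed tokens and falling back to a full-string completion. A self-contained sketch of that consumption pattern; the `Chunk` type and `consume` helper are stand-ins for illustration, not code from this commit:

// Sketch of the generator-consumption loop the removed handleOoba used.
// `Chunk` approximates what the streaming/non-streaming generators yield.
type Chunk = { token?: string; error?: string } | string

async function consume(resp: AsyncGenerator<Chunk, unknown, unknown>): Promise<string> {
  let accumulated = ''
  let result = ''
  while (true) {
    const generated = await resp.next()
    if (generated.done) break
    if (typeof generated.value === 'string') {
      result = generated.value // non-streaming generators return the whole completion
      break
    }
    if (generated.value.error) throw new Error(generated.value.error)
    if (generated.value.token) accumulated += generated.value.token // streamed tokens
  }
  return result || accumulated
}
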
25 changes: 4 additions & 21 deletions srv/adapter/kobold.ts
@@ -4,7 +4,7 @@ import { AppLog, logger } from '../logger'
 import { normalizeUrl, sanitise, sanitiseAndTrim, trimResponseV2 } from '../api/chat/common'
 import { AdapterProps, ModelAdapter } from './type'
 import { requestStream, websocketStream } from './stream'
-import { llamaStream } from './ooba'
+import { llamaStream } from './dispatch'
 import { getStoppingStrings } from './prompt'
 import { ThirdPartyFormat } from '/common/adapters'
 import { decryptText } from '../db/util'
@@ -37,6 +37,7 @@ export const handleKobold: ModelAdapter = async function* (opts) {
   const { members, characters, prompt, mappedSettings } = opts
 
   const body =
+    opts.gen.thirdPartyFormat === 'ooba' ||
     opts.gen.thirdPartyFormat === 'mistral' ||
     opts.gen.thirdPartyFormat === 'tabby' ||
     opts.gen.thirdPartyFormat === 'aphrodite' ||
@@ -131,6 +132,7 @@ async function dispatch(opts: AdapterProps, body: any) {
     case 'llamacpp':
       return llamaStream(baseURL, body)
 
+    case 'ooba':
     case 'aphrodite':
     case 'tabby':
       const url = `${baseURL}/v1/completions`
@@ -292,8 +294,6 @@ const streamCompletion = async function* (
   })
 
   const tokens = []
-  const start = Date.now()
-  let first = 0
 
   const responses: Record<number, string> = {}
 
@@ -320,10 +320,6 @@ const streamCompletion = async function* (
     const res = data.choices ? data.choices[0] : data
     const token = 'text' in res ? res.text : res.token
 
-    if (!first) {
-      first = Date.now()
-    }
-
     /** Handle batch generations */
     if (res.index !== undefined) {
       const index = res.index
@@ -337,6 +333,7 @@
         tokens.push(token)
         yield { token: token }
       }
+
      continue
    }
 
@@ -348,20 +345,6 @@
     return
   }
 
-  const ttfb = (Date.now() - first) / 1000
-  const total = (Date.now() - start) / 1000
-  const tps = tokens.length / ttfb
-  const total_tps = tokens.length / total
-  log.info(
-    {
-      ttfb: ttfb.toFixed(1),
-      total: total.toFixed(1),
-      tps: tps.toFixed(1),
-      total_tps: total_tps.toFixed(1),
-    },
-    'Performance'
-  )
-
   const gens: string[] = []
   for (const [id, text] of Object.entries(responses)) {
     if (+id === 0) continue
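
With the added `case 'ooba':` fall-through, ooba requests share the OpenAI-compatible `/v1/completions` route already used for aphrodite and tabby. A hedged, non-streaming sketch of such a request; the payload fields mirror the ooba body built in payloads.ts, and the response shape assumes the usual OpenAI completions convention rather than any one verified backend:

import needle from 'needle'

// Minimal non-streaming POST to an OpenAI-compatible completions endpoint,
// as exposed by text-generation-webui's OpenAI extension. Illustrative only.
async function complete(baseURL: string, prompt: string): Promise<string | undefined> {
  const res = await needle(
    'post',
    `${baseURL}/v1/completions`,
    { prompt, max_tokens: 200, temperature: 0.8, stream: false },
    { json: true }
  )
  // OpenAI-style responses carry the generated text in choices[0].text.
  return res.body?.choices?.[0]?.text
}
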
147 changes: 94 additions & 53 deletions srv/adapter/payloads.ts
@@ -18,6 +18,58 @@ export function getThirdPartyPayload(opts: AdapterProps, stops: string[] = []) {
 function getBasePayload(opts: AdapterProps, stops: string[] = []) {
   const { gen, prompt } = opts
 
+  // Agnaistic
+  if (gen.service !== 'kobold') {
+    const body: any = {
+      prompt,
+      context_limit: gen.maxContextLength,
+      max_new_tokens: gen.maxTokens,
+      do_sample: gen.doSample ?? true,
+      temperature: gen.temp,
+      top_p: gen.topP,
+      typical_p: gen.typicalP || 1,
+      repetition_penalty: gen.repetitionPenalty,
+      encoder_repetition_penalty: gen.encoderRepitionPenalty,
+      repetition_penalty_range: gen.repetitionPenaltyRange,
+      frequency_penalty: gen.frequencyPenalty,
+      presence_penalty: gen.presencePenalty,
+      top_k: gen.topK,
+      min_p: gen.minP,
+      top_a: gen.topA,
+      min_length: 0,
+      no_repeat_ngram_size: 0,
+      num_beams: gen.numBeams || 1,
+      penalty_alpha: gen.penaltyAlpha,
+      length_penalty: 1,
+      early_stopping: gen.earlyStopping || false,
+      seed: -1,
+      add_bos_token: gen.addBosToken || false,
+      truncation_length: gen.maxContextLength || 2048,
+      ban_eos_token: gen.banEosToken || false,
+      skip_special_tokens: gen.skipSpecialTokens ?? true,
+      stopping_strings: getStoppingStrings(opts, stops),
+      dynamic_temperature: gen.dynatemp_range ? true : false,
+      smoothing_factor: gen.smoothingFactor,
+      tfs: gen.tailFreeSampling,
+      mirostat_mode: gen.mirostatTau ? 2 : 0,
+      mirostat_tau: gen.mirostatTau,
+      mirostat_eta: gen.mirostatLR,
+      guidance: opts.guidance,
+      placeholders: opts.placeholders,
+      lists: opts.lists,
+      previous: opts.previous,
+    }
+
+    if (gen.dynatemp_range) {
+      body.min_temp = (gen.temp ?? 1) - (gen.dynatemp_range ?? 0)
+      body.max_temp = (gen.temp ?? 1) + (gen.dynatemp_range ?? 0)
+      body.dynatemp_range = gen.dynatemp_range
+      body.temp_exponent = gen.dynatemp_exponent
+    }
+
+    return body
+  }
+
   if (gen.thirdPartyFormat === 'mistral') {
     const body = {
       messages: [{ role: 'user', content: prompt }],
@@ -65,7 +117,7 @@ function getBasePayload(opts: AdapterProps, stops: string[] = []) {
     return body
   }
 
-  if (gen.service === 'kobold' && gen.thirdPartyFormat === 'llamacpp') {
+  if (gen.thirdPartyFormat === 'llamacpp') {
     const body = {
       prompt,
       temperature: gen.temp,
@@ -90,7 +142,45 @@ function getBasePayload(opts: AdapterProps, stops: string[] = []) {
     return body
   }
 
-  if (gen.service === 'kobold' && gen.thirdPartyFormat === 'aphrodite') {
+  if (gen.thirdPartyFormat === 'ooba') {
+    const body: any = {
+      prompt,
+      temperature: gen.temp,
+      min_p: gen.minP,
+      top_k: gen.topK,
+      top_p: gen.topP,
+      top_a: gen.topA,
+      stop: getStoppingStrings(opts, stops),
+      stream: true,
+      frequency_penality: gen.frequencyPenalty,
+      presence_penalty: gen.presencePenalty,
+      repetition_penalty: gen.repetitionPenalty,
+      repetition_penalty_range: gen.repetitionPenaltyRange,
+      do_sample: gen.doSample,
+      penalty_alpha: gen.penaltyAlpha,
+      mirostat_mode: gen.mirostatTau ? 2 : 0,
+      mirostat_tau: gen.mirostatTau,
+      mirostat_eta: gen.mirostatLR,
+      typical_p: gen.typicalP,
+      tfs_z: gen.tailFreeSampling,
+      max_tokens: opts.gen.maxTokens,
+      skip_special_tokens: gen.skipSpecialTokens,
+      smoothing_factor: gen.smoothingFactor,
+      smoothing_curve: gen.smoothingCurve,
+      tfs: gen.tailFreeSampling,
+    }
+
+    if (gen.dynatemp_range) {
+      body.dynamic_temperature = true
+      body.dynatemp_low = (gen.temp ?? 1) - (gen.dynatemp_range ?? 0)
+      body.dynatemp_high = (gen.temp ?? 1) + (gen.dynatemp_range ?? 0)
+      body.dynatemp_exponent = gen.dynatemp_exponent
+    }
+
+    return body
+  }
+
+  if (gen.thirdPartyFormat === 'aphrodite') {
     const body: any = {
       model: gen.thirdPartyModel || '',
       stream: opts.kind === 'summary' ? false : gen.streamResponse ?? true,
@@ -133,7 +223,7 @@ function getBasePayload(opts: AdapterProps, stops: string[] = []) {
     return body
   }
 
-  if (gen.service === 'kobold' && gen.thirdPartyFormat === 'exllamav2') {
+  if (gen.thirdPartyFormat === 'exllamav2') {
     const body = {
       request_id: opts.requestId,
       action: 'infer',
@@ -151,7 +241,7 @@ function getBasePayload(opts: AdapterProps, stops: string[] = []) {
     return body
   }
 
-  if (gen.service === 'kobold' && gen.thirdPartyFormat === 'koboldcpp') {
+  if (gen.thirdPartyFormat === 'koboldcpp') {
     const body = {
       n: 1,
       max_context_length: gen.maxContextLength,
@@ -175,53 +265,4 @@ function getBasePayload(opts: AdapterProps, stops: string[] = []) {
     }
     return body
   }
-
-  const body: any = {
-    prompt,
-    context_limit: gen.maxContextLength,
-    max_new_tokens: gen.maxTokens,
-    do_sample: gen.doSample ?? true,
-    temperature: gen.temp,
-    top_p: gen.topP,
-    typical_p: gen.typicalP || 1,
-    repetition_penalty: gen.repetitionPenalty,
-    encoder_repetition_penalty: gen.encoderRepitionPenalty,
-    repetition_penalty_range: gen.repetitionPenaltyRange,
-    frequency_penalty: gen.frequencyPenalty,
-    presence_penalty: gen.presencePenalty,
-    top_k: gen.topK,
-    min_p: gen.minP,
-    top_a: gen.topA,
-    min_length: 0,
-    no_repeat_ngram_size: 0,
-    num_beams: gen.numBeams || 1,
-    penalty_alpha: gen.penaltyAlpha,
-    length_penalty: 1,
-    early_stopping: gen.earlyStopping || false,
-    seed: -1,
-    add_bos_token: gen.addBosToken || false,
-    truncation_length: gen.maxContextLength || 2048,
-    ban_eos_token: gen.banEosToken || false,
-    skip_special_tokens: gen.skipSpecialTokens ?? true,
-    stopping_strings: getStoppingStrings(opts, stops),
-    dynamic_temperature: gen.dynatemp_range ? true : false,
-    smoothing_factor: gen.smoothingFactor,
-    tfs: gen.tailFreeSampling,
-    mirostat_mode: gen.mirostatTau ? 2 : 0,
-    mirostat_tau: gen.mirostatTau,
-    mirostat_eta: gen.mirostatLR,
-    guidance: opts.guidance,
-    placeholders: opts.placeholders,
-    lists: opts.lists,
-    previous: opts.previous,
-  }
-
-  if (gen.dynatemp_range) {
-    body.min_temp = (gen.temp ?? 1) - (gen.dynatemp_range ?? 0)
-    body.max_temp = (gen.temp ?? 1) + (gen.dynatemp_range ?? 0)
-    body.dynatemp_range = gen.dynatemp_range
-    body.temp_exponent = gen.dynatemp_exponent
-  }
-
-  return body
 }
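
Worked example of the dynamic-temperature arithmetic shared by the new ooba branch and the relocated Agnaistic fallback: the sampling window is centered on `temp` and widened by `dynatemp_range` on each side. The values below are illustrative:

// With temp = 1.2 and dynatemp_range = 0.3, the ooba payload gains:
const temp = 1.2
const dynatemp_range = 0.3
const dynatemp_low = temp - dynatemp_range // 0.9
const dynatemp_high = temp + dynatemp_range // 1.5
// The non-kobold "Agnaistic" branch emits the same window as min_temp / max_temp.
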
2 changes: 1 addition & 1 deletion srv/api/chat/message.ts
@@ -327,7 +327,7 @@ export const generateMessageV2 = handle(async (req, res) => {
   }
 
   await releaseLock(chatId)
-  if (error || !generated.trim()) {
+  if (error) {
     return
   }
 