Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Discord bot: Pythia model #2831

Merged
merged 10 commits into from
Apr 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion discord-bots/oa-bot-js/.env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ OA_APIURL=OpenAssistant API url
INFERENCE_SERVER_API_KEY=
INFERENCE_SERVER_HOST=
REDIS_PASSWORD=
DEFAULT_MODEL=default model if user doesn't specify one
DEFAULT_MODEL=default model if user does not specify one
HUGGINGFACE_TOKEN=huggingface token
1 change: 1 addition & 0 deletions discord-bots/oa-bot-js/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"license": "Apache-2.0",
"private": true,
"dependencies": {
"@huggingface/inference": "^2.0.0",
"axios": "^1.3.5",
"chalk": "^5.2.0",
"discord.js": "^14.7.1",
Expand Down
163 changes: 104 additions & 59 deletions discord-bots/oa-bot-js/src/commands/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import {
ButtonStyle,
ButtonBuilder,
} from "discord.js";
import { createInferenceClient } from "../modules/inference/client.js";
import redis from "../modules/redis.js";
import chatFN from "../modules/chat.js";

export default {
disablePing: null,
Expand All @@ -23,10 +23,16 @@ export default {
.setName("model")
.setDescription("The model you want to use for the AI.")
.setRequired(false)
.addChoices({
name: "OA_SFT_Llama_30B",
value: "OA_SFT_Llama_30B",
})
.addChoices(
{
name: "OA_SFT_Llama_30B",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be best if the list of available models in the future would be requested at the inference-server (/models endpoint). It will be constantly changing .. you were probably irritated that so many are displayed in the staging environment .. but that's really only a "problem" in staging.

value: "OA_SFT_Llama_30B",
},
{
name: "oasst-sft-4-pythia-12b",
value: "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
}
)
)
.addStringOption((option) =>
option
Expand Down Expand Up @@ -73,71 +79,110 @@ export default {
model = interaction.options.getString("model");
preset = interaction.options.getString("preset");
}
if (!model)
model = process.env.OPEN_ASSISTANT_DEFAULT_MODEL || "OA_SFT_Llama_30B";
if (!model) {
let userModel = await redis.get(`model_${interaction.user.id}`);
if (userModel) {
model = userModel;
} else {
model = process.env.OPEN_ASSISTANT_DEFAULT_MODEL || "OA_SFT_Llama_30B";
redis.set(`model_${interaction.user.id}`, model);
}
} else {
redis.set(`model_${interaction.user.id}`, model);
}
if (!preset) preset = "k50";
// sleep for 30s

const OA = await createInferenceClient(
interaction.user.username,
interaction.user.id
);
if (model.includes("Llama")) {
try {
let chat = await redis.get(`chat_${interaction.user.id}`);
let chatId = chat ? chat.split("_")[0] : null;
let parentId = chat ? chat.split("_")[1] : null;
let { assistant_message, OA } = await chatFN(
model,
interaction.user,
message,
chatId,
parentId,
presets,
preset
);
await redis.set(
`chat_${interaction.user.id}`,
`${chatId}_${assistant_message.id}`
);

try {
let chat = await redis.get(`chat_${interaction.user.id}`);
let chatId = chat ? chat.split("_")[0] : null;
let parentId = chat ? chat.split("_")[1] : null;
if (!chatId) {
let chat = await OA.create_chat();
chatId = chat.id;
const row = new ActionRowBuilder().addComponents(
new ButtonBuilder()
.setStyle(ButtonStyle.Secondary)
.setLabel(`👍`)
.setCustomId(`vote_${assistant_message.id}_up`),
new ButtonBuilder()
.setStyle(ButtonStyle.Secondary)
.setLabel(`👎`)
.setCustomId(`vote_${assistant_message.id}_down`),
new ButtonBuilder()
.setStyle(ButtonStyle.Secondary)
.setDisabled(false)
.setLabel(
`${model.replaceAll("OpenAssistant/", "").replaceAll("_", "")}`
)
.setCustomId(`model_${assistant_message.id}`)
);
// using events
let events = await OA.stream_events({
chat_id: chatId,
message_id: assistant_message.id,
});
events.on("data", async (c) => {
/* let string = JSON.parse(c);
if (!string.queue_position) {
await commandType.reply(interaction, {
content: `${string} <a:loading:1051419341914132554>`,
components: [],
});
}*/
});
events.on("end", async (c) => {
let msg = await OA.get_message(chatId, assistant_message.id);
await commandType.reply(interaction, {
content: msg.content,
components: [row],
});
});
} catch (err: any) {
console.log(err);
// get details of the error
await commandType.reply(
interaction,
`There was an error while executing this command! ${err.message}`
);
}
let prompter_message = await OA.post_prompter_message({
chat_id: chatId,
content: message,
parent_id: parentId,
});

let assistant_message = await OA.post_assistant_message({
chat_id: chatId,
model_config_name: model,
parent_id: prompter_message.id,
sampling_parameters: presets[preset],
});
await redis.set(
`chat_${interaction.user.id}`,
`${chatId}_${assistant_message.id}`
} else {
let { assistant_message, error } = await chatFN(
model,
interaction.user,
message
);

if (error) {
await commandType.reply(
interaction,
`There was an error while executing this command! ${error}`
);
}
const row = new ActionRowBuilder().addComponents(
new ButtonBuilder()
.setStyle(ButtonStyle.Secondary)
.setLabel(`👍`)
.setCustomId(`vote_${assistant_message.id}_up`),
new ButtonBuilder()
.setStyle(ButtonStyle.Secondary)
.setLabel(`👎`)
.setCustomId(`vote_${assistant_message.id}_down`)
.setDisabled(false)
.setLabel(
`${model.replaceAll("OpenAssistant/", "").replaceAll("_", "")}`
)
.setCustomId(`model_${interaction.user.id}`)
);
// using events
let events = await OA.stream_events({
chat_id: chatId,
message_id: assistant_message.id,
await commandType.reply(interaction, {
content: assistant_message,
components: [row],
});
events.on("data", async (c) => {});
events.on("end", async (c) => {
let msg = await OA.get_message(chatId, assistant_message.id);
await commandType.reply(interaction, {
content: msg.content,
components: [row],
});
});
} catch (err: any) {
console.log(err);
// get details of the error
await commandType.reply(
interaction,
`There was an error while executing this command! ${err.message}`
);
}
},
};
Expand Down
45 changes: 45 additions & 0 deletions discord-bots/oa-bot-js/src/interactions/model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import {
SlashCommandBuilder,
ActionRowBuilder,
StringSelectMenuBuilder,
StringSelectMenuOptionBuilder,
} from "discord.js";
import { createInferenceClient } from "../modules/inference/client.js";
import redis from "../modules/redis.js";

export default {
data: {
customId: "model",
description: "Switch to another model.",
},
async execute(interaction, client, userId) {
if (interaction.user.id != userId)
return interaction.reply({
content: "You don't have permission to do this.",
ephemeral: true,
});
// model selector
let row = new ActionRowBuilder().addComponents(
new StringSelectMenuBuilder()
.setCustomId("modelselect")
.setPlaceholder("Select a model")
.setMinValues(1)
.setMaxValues(1)
.addOptions(
new StringSelectMenuOptionBuilder()
.setLabel("OA_SFT_Llama_30B")
.setDescription("Llama (default)")
.setValue("OA_SFT_Llama_30B"),
new StringSelectMenuOptionBuilder()
.setLabel("oasst-sft-4-pythia-12b")
.setDescription("Pythia")
.setValue("OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5")
)
);
await interaction.reply({
content: "Select a model.",
components: [row],
ephemeral: true,
});
},
};
28 changes: 28 additions & 0 deletions discord-bots/oa-bot-js/src/interactions/modelselect.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import {
SlashCommandBuilder,
ActionRowBuilder,
StringSelectMenuBuilder,
StringSelectMenuOptionBuilder,
} from "discord.js";
import { createInferenceClient } from "../modules/inference/client.js";
import redis from "../modules/redis.js";

export default {
data: {
customId: "modelselect",
description: "Switch to another model.",
},
async execute(interaction, client) {
// get selected value
let model = interaction.values[0];
// set model
await interaction.deferReply({
ephemeral: true,
});
redis.set(`model_${interaction.user.id}`, model);
await interaction.editReply({
content: `Model set to ${model}.`,
ephemeral: true,
});
},
};
75 changes: 75 additions & 0 deletions discord-bots/oa-bot-js/src/modules/chat.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import { createInferenceClient } from "../modules/inference/client.js";
import { HfInference } from "@huggingface/inference";
const hf = new HfInference(process.env.HUGGINGFACE_TOKEN);

export default async function chat(
model,
user,
message,
chatId?,
parentId?,
presets?,
preset?
) {
if (model.includes("Llama")) {
const OA = await createInferenceClient(user.username, user.id);
if (!chatId) {
let chat = await OA.create_chat();
chatId = chat.id;
}
let prompter_message = await OA.post_prompter_message({
chat_id: chatId,
content: message,
parent_id: parentId,
});

let assistant_message = await OA.post_assistant_message({
chat_id: chatId,
model_config_name: model,
parent_id: prompter_message.id,
sampling_parameters: presets[preset],
});
return { assistant_message, OA };
} else {
let result = await huggingface(
model,
`<|prompter|>${message}<|endoftext|>\n<|assistant|>`
);
if (result.error) {
return { error: result.error };
}
return { assistant_message: result.response };
}
}

export async function huggingface(model, input) {
try {
let oldText;
let loop = true;
while (loop) {
let response = await hf.textGeneration({
model: model,
inputs: input,
});
let answer = response.generated_text.split("<|assistant|>")[1];
if (answer == oldText) {
loop = false;
} else {
if (!oldText) {
oldText = answer;
input += answer;
} else {
oldText += answer;
input += answer;
}
}
}

return { response: oldText };
} catch (err: any) {
console.log(err);
return {
error: err.message,
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export async function initInteraction(interaction, translation, lang) {
.setFooter({ text: `${getLocaleDisplayName(lang)}` })
.setTitle("Open assistant")
.setDescription(`${translation["conversational"]}`)
.setURL("https://open-assistant.io/?ref=turing")
.setURL("https://open-assistant.io/?ref=discordbot")
.setThumbnail("https://open-assistant.io/images/logos/logo.png");

const row = new ActionRowBuilder().addComponents(
Expand All @@ -25,7 +25,7 @@ export async function initInteraction(interaction, translation, lang) {
.setLabel(translation.grab_a_task)
.setCustomId(`oa_tasks_n_${interaction.user.id}`)
.setStyle(ButtonStyle.Primary)
.setDisabled(false),
.setDisabled(true),
new ButtonBuilder()
.setLabel("Change language")
.setCustomId(`oa_lang-btn_n_${interaction.user.id}`)
Expand Down