Skip to content

Commit

Permalink
refactor serve.cpp for better compatibility with copilot
Browse files Browse the repository at this point in the history
  • Loading branch information
ravenscroftj committed Apr 15, 2023
1 parent ef68767 commit 272e187
Showing 1 changed file with 74 additions and 56 deletions.
130 changes: 74 additions & 56 deletions examples/codegen/serve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,61 +8,12 @@
#include <boost/uuid/uuid_generators.hpp> // generators
#include <boost/uuid/uuid_io.hpp> // streaming operators etc.

int main(int argc, char** argv) {

gpt_params params;
params.model = "models/gpt-j-6B/ggml-model.bin";

if (gpt_params_parse(argc, argv, params) == false) {
return 1;
}

if (params.seed < 0) {
params.seed = time(NULL);
}

printf("%s: seed = %d\n", __func__, params.seed);

/**
* Serves an autocompletion request received through crow: takes the incoming
* crow::request and returns the crow::response to send back to the client.
*/
crow::response serve_response(gpt_params params, gptj_model &model, gpt_vocab &vocab, const crow::request& req){

crow::SimpleApp app;

gpt_vocab vocab;
gptj_model model;

int64_t t_load_us = 0;

// load the model
{
const int64_t t_start_us = ggml_time_us();

if (!gptj_model_load(params.model, model, vocab)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
return 1;
}

t_load_us = ggml_time_us() - t_start_us;
}


CROW_ROUTE(app, "/")([](){
return "Hello world";
});

CROW_ROUTE(app, "/copilot_internal/v2/token")([](){
//return "Hello world";

crow::json::wvalue response = {{"token","1"}, {"expires_at", static_cast<std::uint64_t>(2600000000)}, {"refresh_in",900}};

crow::response res;
res.code = 200;
res.set_header("Content-Type", "application/json");
res.body = response.dump();
return res;
});


CROW_ROUTE(app, "/v1/engines/codegen/completions").methods(crow::HTTPMethod::POST)
([&model, &vocab, &params](const crow::request& req) {
crow::json::rvalue data = crow::json::load(req.body);

if(!data.has("prompt") && !data.has("input_ids")){
Expand Down Expand Up @@ -97,8 +48,6 @@ int main(int argc, char** argv) {
std::string suffix = "";
float temperature = 0.6;

data["model"].s();

if(data.has("suffix")){
suffix = data["suffix"].s();
}
Expand Down Expand Up @@ -223,6 +172,75 @@ int main(int argc, char** argv) {

res.body = response.dump(); //ss.str();
return res;
}

int main(int argc, char** argv) {

gpt_params params;
params.model = "models/gpt-j-6B/ggml-model.bin";

if (gpt_params_parse(argc, argv, params) == false) {
return 1;
}

if (params.seed < 0) {
params.seed = time(NULL);
}

printf("%s: seed = %d\n", __func__, params.seed);


crow::SimpleApp app;

gpt_vocab vocab;
gptj_model model;

int64_t t_load_us = 0;

// load the model
{
const int64_t t_start_us = ggml_time_us();

if (!gptj_model_load(params.model, model, vocab)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
return 1;
}

t_load_us = ggml_time_us() - t_start_us;
}


CROW_ROUTE(app, "/")([](){
return "Hello world";
});

CROW_ROUTE(app, "/copilot_internal/v2/token")([](){
//return "Hello world";

crow::json::wvalue response = {{"token","1"}, {"expires_at", static_cast<std::uint64_t>(2600000000)}, {"refresh_in",900}};

crow::response res;
res.code = 200;
res.set_header("Content-Type", "application/json");
res.body = response.dump();
return res;
});


CROW_ROUTE(app, "/v1/completions").methods(crow::HTTPMethod::POST)
([&model, &vocab, &params](const crow::request& req) {
return serve_response(params, model, vocab, req);
});

CROW_ROUTE(app, "/v1/engines/codegen/completions").methods(crow::HTTPMethod::POST)
([&model, &vocab, &params](const crow::request& req) {
return serve_response(params, model, vocab, req);
});


CROW_ROUTE(app, "/v1/engines/copilot-codex/completions").methods(crow::HTTPMethod::POST)
([&model, &vocab, &params](const crow::request& req) {
return serve_response(params, model, vocab, req);
});

app.port(18080).multithreaded().run();
Expand Down

0 comments on commit 272e187

Please sign in to comment.