Skip to content

Commit

Permalink
add phi-3
Browse files Browse the repository at this point in the history
  • Loading branch information
ZHEQIUSHUI committed Apr 25, 2024
1 parent 612e940 commit ad6231d
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 9 deletions.
5 changes: 4 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ endfunction()
build_exec(main src/main.cpp)
# build_exec(main_qwen src/main_qwen.cpp)

file(GLOB RUN_SCRIPT "${CMAKE_SOURCE_DIR}/scripts/*")
file(GLOB RUN_SCRIPT "${CMAKE_SOURCE_DIR}/scripts/*.py" "${CMAKE_SOURCE_DIR}/scripts/*.sh")
install(FILES ${RUN_SCRIPT} DESTINATION bin/)

file(GLOB LLAMA3_TOKENIZER "${CMAKE_SOURCE_DIR}/scripts/llama3_tokenizer/*")
install(FILES ${LLAMA3_TOKENIZER} DESTINATION bin/llama3_tokenizer/)

# add_executable(fp32_to_bf16 tools/fp32_to_bf16.cpp)
13 changes: 13 additions & 0 deletions scripts/run_phi3_mini.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/bin/sh
# Run the Phi-3-mini (int8) demo.
#
# Usage: ./run_phi3_mini.sh "your prompt here"
#
# Feeds the prompt ($1) to ./main using the per-layer axmodels
# (llama_l0..llama_l31), the post-processing axmodel, and the raw
# bfloat16 embedding table. tokenizer_type 3 selects the Phi-3
# sentencepiece tokenizer (see src/runner/Tokenizer).
./main \
--template_filename_axmodel "phi3-int8/llama_l%d.axmodel" \
--axmodel_num 32 \
--tokenizer_type 3 \
--bos 1 --eos 0 \
--filename_tokenizer_model tokenizer.model \
--filename_post_axmodel phi3-int8/llama_post.axmodel \
--filename_tokens_embed phi3-int8/model.embed_tokens.weight.bfloat16.bin \
--tokens_embed_num 32064 \
--tokens_embed_size 3072 \
--live_print 1 \
--continue 1 \
--prompt "$1"
5 changes: 4 additions & 1 deletion src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ std::string prompt_complete(std::string prompt, TokenizerType tokenizer_type)
oss_prompt << "<|user|>\n"
<< prompt << "</s><|assistant|>\n";
break;
case TKT_Phi3:
oss_prompt << prompt << " ";
break;
case TKT_Qwen:
oss_prompt << "<|im_start|>system\nYou are a helpful assistant.<|im_end|>";
oss_prompt << "\n<|im_start|>user\n"
Expand All @@ -52,7 +55,7 @@ int main(int argc, char *argv[])
cmd.add<std::string>("prompt", 'p', "prompt", true, prompt);
cmd.add<std::string>("template_filename_axmodel", 0, "axmodel path template", false, attr.template_filename_axmodel);
cmd.add<std::string>("filename_post_axmodel", 0, "post axmodel path", false, attr.filename_post_axmodel);
cmd.add<int>("tokenizer_type", 0, "tokenizer type 0:LLaMa 1:Qwen 2:HTTP", false, attr.tokenizer_type);
cmd.add<int>("tokenizer_type", 0, "tokenizer type 0:LLaMa 1:Qwen 2:HTTP 3:Phi3", false, attr.tokenizer_type);
cmd.add<std::string>("filename_tokenizer_model", 0, "tokenizer model path", false, attr.filename_tokenizer_model);
cmd.add<std::string>("filename_tokens_embed", 0, "tokens embed path", false, attr.filename_tokens_embed);

Expand Down
24 changes: 17 additions & 7 deletions src/runner/LLM.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ class LLM
return true;
}

LLMAttrType* getAttr()
LLMAttrType *getAttr()
{
return &_attr;
}
Expand Down Expand Up @@ -239,6 +239,7 @@ class LLM
int len_of_input = token_ids.size();
timer t_cost;
// print token_ids
// printf("%s\n", input_str.c_str());
// for (size_t i = 0; i < token_ids.size(); i++)
// {
// printf("%d ", token_ids[i]);
Expand Down Expand Up @@ -345,7 +346,22 @@ class LLM
}
}
next_token = max_index;

if (tokenizer->isEnd(max_index))
{
if (cached_token.size())
{
float t_cost_ms = t_cost.cost();
float token_per_sec = token_ids.size() / (t_cost_ms / 1000);
auto tmp_out = tokenizer->Decode(cached_token);
_attr.runing_callback(cached_token.data(), cached_token.size(), tmp_out.c_str(), token_per_sec, _attr.reserve);
cached_token.clear();
}
b_hit_eos = true;
break;
}
token_ids.push_back(max_index);

if (_attr.runing_callback)
{
cached_token.push_back(max_index);
Expand All @@ -358,12 +374,6 @@ class LLM
cached_token.clear();
}
}

if (max_index == tokenizer->GetEosID())
{
b_hit_eos = true;
break;
}
}
if (_attr.runing_callback == nullptr)
update_cqdm(&cqdm, indices, "token", "");
Expand Down
84 changes: 84 additions & 0 deletions src/runner/Tokenizer/Tokenizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,88 @@ class TokenizerLLaMa : public BaseTokenizer
}
};

// Tokenizer for Phi-3 models, backed by a sentencepiece model file.
// Wraps encoded prompts in the Phi-3 chat template using the model's
// hard-coded special-token ids:
//   32010 = <|user|>, 32007 = <|end|>, 32001 = <|assistant|>
class TokenizerPhi3 : public BaseTokenizer
{
    sentencepiece::SentencePieceProcessor sp;
    bool _b_bos, _b_eos;

private:
    /* data */
public:
    // Load the sentencepiece model from model_path.
    // b_bos / b_eos control whether BOS / EOS ids are appended by Encode().
    // Returns false (and logs) if the model cannot be loaded.
    bool Init(std::string model_path, bool b_bos = true, bool b_eos = false) override
    {
        auto ret = sp.Load(model_path);
        if (!ret.ok())
        {
            ALOGE("%s", ret.error_message());
            return false;
        }

        this->_b_bos = b_bos;
        this->_b_eos = b_eos;
        return ret.ok();
    }

    // Encode input and wrap it in the Phi-3 chat template:
    //   [BOS?] <|user|> <tokens> <|end|> <|assistant|> [EOS?]
    bool Encode(std::string input, std::vector<int> &output) override
    {
        auto ret = sp.Encode(input, &output);
        if (!ret.ok())
        {
            ALOGE("%s", ret.error_message());
            return false;
        }
        output.insert(output.begin(), 32010); //"<|user|>"
        output.push_back(32007);              //"<|end|>"
        output.push_back(32001);              //"<|assistant|>"
        if (_b_bos)
        {
            // BOS goes in front of the whole template, including <|user|>.
            output.insert(output.begin(), sp.bos_id());
        }
        if (_b_eos)
        {
            output.push_back(sp.eos_id());
        }
        return true;
    }

    std::vector<int> Encode(std::string input) override
    {
        std::vector<int> output;
        Encode(input, output);
        return output;
    }

    // Decode token ids back to text. When the first piece starts with the
    // sentencepiece word-boundary marker U+2581 ("▁", UTF-8 E2 96 81), the
    // detokenized text has dropped the leading space, so it is re-added —
    // needed when decoding token-by-token during streaming output.
    std::string Decode(const std::vector<int> input) override
    {
        sentencepiece::SentencePieceText spt;
        sp.Decode(input, &spt);
        if (spt.pieces().empty())
        {
            // Nothing decoded (e.g. empty input): no marker to inspect.
            return spt.text();
        }
        const std::string &piece = spt.pieces()[0].piece();
        // Byte-wise check replaces the old `*(unsigned short*)data == 38626`
        // (0x96E2) trick: same test on little-endian, but well-defined,
        // endian-independent, and safe for pieces shorter than two bytes.
        if (piece.size() >= 2 &&
            (unsigned char)piece[0] == 0xE2 &&
            (unsigned char)piece[1] == 0x96)
        {
            return " " + spt.text();
        }
        else
        {
            return spt.text();
        }
    }

    int GetBosID() override
    {
        return sp.bos_id();
    }

    // Phi-3 stops generation on <|end|> (32007), not the sentencepiece EOS.
    int GetEosID() override
    {
        return 32007;
    }

    // Generation ends on <|end|> or on any id in the added special-token
    // range (>= 32000, above the 0..31999 base vocabulary).
    bool isEnd(int id) override
    {
        return id == GetEosID() || id > 31999;
    }
};

class TokenizerQwen : public BaseTokenizer
{
std::shared_ptr<QwenTokenizer> sp;
Expand Down Expand Up @@ -370,6 +452,8 @@ std::shared_ptr<BaseTokenizer> CreateTokenizer(TokenizerType type)
return std::make_shared<Tokenizer_Http>();
case TKT_Qwen:
return std::make_shared<TokenizerQwen>();
case TKT_Phi3:
return std::make_shared<TokenizerPhi3>();
default:
return nullptr;
}
Expand Down
3 changes: 3 additions & 0 deletions src/runner/Tokenizer/Tokenizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ enum TokenizerType
TKT_LLaMa,
TKT_Qwen,
TKT_HTTP,
TKT_Phi3,
TKT_END
};

Expand All @@ -20,6 +21,8 @@ class BaseTokenizer
virtual std::string Decode(const std::vector<int> input) = 0;
virtual int GetBosID() = 0;
virtual int GetEosID() = 0;

virtual bool isEnd(int id) { return id == GetEosID(); }
};

std::shared_ptr<BaseTokenizer> CreateTokenizer(TokenizerType type);

0 comments on commit ad6231d

Please sign in to comment.