Skip to content

Commit

Permalink
Add EOS_TOKEN
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Sep 23, 2023
1 parent 8038b10 commit ee375d0
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions examples/llama/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@ use plotly::{Plot, Scatter};
use tokenizers::Tokenizer;

/// Tokenizes `input` with `tokenizer` and returns the resulting token ids.
///
/// NOTE(review): this span is a rendered diff with the `+`/`-` markers stripped.
/// The single expression line directly under the signature appears to be the
/// pre-commit body that the commit removes; the three lines after it are the
/// replacement body. Only one of the two belongs in the real file.
///
/// # Panics
///
/// Both versions call `unwrap()` on the result of `tokenizer.encode`, so a
/// failed encode panics. The replacement body also calls `unwrap()` on
/// `tokenizer.token_to_id(EOS_TOKEN)`, so it panics when that lookup yields
/// no id.
fn encode(input: &str, tokenizer: &Tokenizer) -> Vec<u32> {
// Pre-commit body: return the encoded ids unchanged.
// The `true` argument is presumably `add_special_tokens` — TODO confirm.
tokenizer.encode(input, true).unwrap().get_ids().to_vec()
// Post-commit body: same encoding, copied into a mutable vector...
let mut toks = tokenizer.encode(input, true).unwrap().get_ids().to_vec();
// ...with the id that the tokenizer reports for `EOS_TOKEN` appended last.
toks.push(tokenizer.token_to_id(EOS_TOKEN).unwrap());
toks
}

/// Token string that `encode` looks up via `token_to_id` and appends to every
/// encoded line. Presumably the end-of-sequence marker for the tokenizer in
/// use — confirm that `"</s>"` exists in its vocabulary, since the lookup
/// result is unwrapped.
const EOS_TOKEN: &str = "</s>";

pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let tokenizer = common_args.tokenizer().unwrap();

Expand Down Expand Up @@ -70,14 +74,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {

let mut dataset: LLMDataset<u32> = LLMDataset::new(vec![], device);
dataset.add_line(encode(
"What is oxygen good for? Oxygen is good for breathing",
"What is oxygen good for? Oxygen is good for breathing.",
&tokenizer,
));
dataset.add_line(encode(
"Why are leaves beautiful? Leaves might be beautiful",
"Why are leaves beautiful? Leaves might be beautiful.",
&tokenizer,
));
dataset.add_line(encode("What is Kelvin? A unit of temperature", &tokenizer));
dataset.add_line(encode("What is Kelvin? A unit of temperature.", &tokenizer));

let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);

Expand Down

0 comments on commit ee375d0

Please sign in to comment.