diff --git a/qlora/qlora.py b/qlora/qlora.py
index 30080f3..262114a 100644
--- a/qlora/qlora.py
+++ b/qlora/qlora.py
@@ -80,6 +80,12 @@ class ModelArguments:
     model_name_or_path: Optional[str] = field(
         default="EleutherAI/pythia-12b"
     )
+    tokenizer_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+    )
+    variant: Optional[str] = field(
+        default=None
+    )
     trust_remote_code: Optional[bool] = field(
         default=False,
         metadata={"help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."}
@@ -313,6 +319,7 @@ def get_accelerate_model(args, checkpoint_dir):
         cache_dir=args.cache_dir,
         load_in_4bit=args.bits == 4,
         load_in_8bit=args.bits == 8,
+        variant=args.variant,
         device_map=device_map,
         max_memory=max_memory,
         quantization_config=BitsAndBytesConfig(
@@ -345,7 +352,7 @@ def get_accelerate_model(args, checkpoint_dir):
 
     # Tokenizer
     tokenizer = AutoTokenizer.from_pretrained(
-        args.model_name_or_path,
+        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
         cache_dir=args.cache_dir,
         padding_side="right",
         use_fast=False, # Fast tokenizer giving issues.
@@ -359,7 +366,8 @@ def get_accelerate_model(args, checkpoint_dir):
             tokenizer=tokenizer,
             model=model,
         )
-    if 'llama' in args.model_name_or_path or isinstance(tokenizer, LlamaTokenizer):
+    # if 'llama' in args.model_name_or_path or isinstance(tokenizer, LlamaTokenizer):
+    if ('llama' in args.model_name_or_path or isinstance(tokenizer, LlamaTokenizer)) and args.tokenizer_name != "novelai/nerdstash-tokenizer-v1":
         # LLaMA tokenizer may not have correct special tokens set.
         # Check and add them if missing to prevent them from being parsed into different tokens.
         # Note that these are present in the vocabulary.
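
A minimal standalone sketch (not part of the patch) of how the new arguments behave: it mirrors the fields the diff adds to ModelArguments in a throwaway dataclass, parses them with HfArgumentParser the way qlora.py does, and applies the same tokenizer-override fallback. The "fp16" variant value is an assumed example, not something the patch sets.

# sketch.py -- illustration only; names mirror, but are not, the patched qlora.py
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class ModelArguments:
    # Stand-in copy of the fields touched by the patch above.
    model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-12b")
    tokenizer_name: Optional[str] = field(default=None)
    variant: Optional[str] = field(default=None)

parser = HfArgumentParser(ModelArguments)
(args,) = parser.parse_args_into_dataclasses(
    ["--tokenizer_name", "novelai/nerdstash-tokenizer-v1", "--variant", "fp16"]
)
# Same fallback the patch introduces: prefer the explicit tokenizer override,
# otherwise load the tokenizer that ships with the model.
print(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path)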