Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add tokenizer_name #2

Merged
merged 5 commits on
Oct 1, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions qlora/qlora.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default="EleutherAI/pythia-12b"
)
+ tokenizer_name: Optional[str] = field(
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+ )
+ variant: Optional[str] = field(
+ default=None
+ )
trust_remote_code: Optional[bool] = field(
default=False,
metadata={"help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."}
Expand Down Expand Up @@ -313,6 +319,7 @@ def get_accelerate_model(args, checkpoint_dir):
cache_dir=args.cache_dir,
load_in_4bit=args.bits == 4,
load_in_8bit=args.bits == 8,
+ variant=args.variant,
device_map=device_map,
max_memory=max_memory,
quantization_config=BitsAndBytesConfig(
Expand Down Expand Up @@ -345,7 +352,7 @@ def get_accelerate_model(args, checkpoint_dir):

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(
- args.model_name_or_path,
+ args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
cache_dir=args.cache_dir,
padding_side="right",
use_fast=False, # Fast tokenizer giving issues.
Expand All @@ -359,7 +366,8 @@ def get_accelerate_model(args, checkpoint_dir):
tokenizer=tokenizer,
model=model,
)
- if 'llama' in args.model_name_or_path or isinstance(tokenizer, LlamaTokenizer):
+ # if 'llama' in args.model_name_or_path or isinstance(tokenizer, LlamaTokenizer):
+ if ('llama' in args.model_name_or_path or isinstance(tokenizer, LlamaTokenizer)) and args.tokenizer_name != "novelai/nerdstash-tokenizer-v1":
# LLaMA tokenizer may not have correct special tokens set.
# Check and add them if missing to prevent them from being parsed into different tokens.
# Note that these are present in the vocabulary.
Expand Down