diff --git a/model/model_training/README.md b/model/model_training/README.md
index 96781c4f4d..5823f513c5 100644
--- a/model/model_training/README.md
+++ b/model/model_training/README.md
@@ -85,6 +85,14 @@ To train using trlx try:
 python trainer_rl.py --configs defaults_rlhf
 ```
 
+## Test your model
+
+You can interactively test your model like this:
+
+```bash
+python tools/model_cli.py --model_path <model_path>
+```
+
 ## Model
 
 Normally you should be able to add new models in `configs/config.yml`
diff --git a/model/model_training/models/__init__.py b/model/model_training/models/__init__.py
index 510a273892..79e285f876 100644
--- a/model/model_training/models/__init__.py
+++ b/model/model_training/models/__init__.py
@@ -25,12 +25,12 @@ def freeze_top_n_layers(model, target_layers):
     return model
 
 
-def get_specific_model(model_name, cache_dir, quantization, seq2seqmodel):
+def get_specific_model(model_name, seq2seqmodel=False, cache_dir=".cache", **kwargs):
     # encoder-decoder support for Flan-T5 like models
     # for now, we can use an argument but in the future,
     # we can automate this
     if seq2seqmodel:
-        model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=cache_dir)
+        model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=cache_dir, **kwargs)
     else:
-        model = transformers.AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
+        model = transformers.AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir, **kwargs)
     return model
diff --git a/model/model_training/tools/model_cli.py b/model/model_training/tools/model_cli.py
new file mode 100644
index 0000000000..597f33dea2
--- /dev/null
+++ b/model/model_training/tools/model_cli.py
@@ -0,0 +1,77 @@
+import argparse
+import os
+import sys
+import time
+
+import torch
+import transformers
+
+if __name__ == "__main__":
+    import warnings
+
+    warnings.filterwarnings("ignore")
+
+    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+    from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair
+    from models import get_specific_model
+    from utils import _strtobool
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_path", type=str, required=True)
+    parser.add_argument("--max_new_tokens", type=int, default=200)
+    parser.add_argument("--top_k", type=int, default=40)
+    parser.add_argument("--temperature", type=float, default=1.0)
+    parser.add_argument("--do-sample", type=_strtobool, default=True)
+    args = parser.parse_args()
+
+    model = get_specific_model(
+        args.model_path,
+        load_in_8bit=True,
+        device_map="auto",
+        low_cpu_mem_usage=True,
+        torch_dtype=torch.float16,
+        offload_state_dict=True,
+    )
+
+    model.gradient_checkpointing_enable()  # reduce number of stored activations
+    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_path)
+
+    human_token_id = tokenizer.additional_special_tokens_ids[
+        tokenizer.additional_special_tokens.index(QA_SPECIAL_TOKENS["Question"])
+    ]
+
+    print('Type "quit" to exit')
+    print("Press Control + C to restart conversation (spam to exit)")
+
+    conversation_history = []
+
+    while True:
+        try:
+            user_input = input("User: ")
+            if user_input == "quit":
+                break
+
+            conversation_history.append(user_input)
+
+            batch = tokenizer.encode("".join(format_pair(conversation_history)), return_tensors="pt")
+
+            with torch.cuda.amp.autocast():
+                out = model.generate(
+                    input_ids=batch.to(model.device),
+                    max_new_tokens=args.max_new_tokens,
+                    do_sample=args.do_sample,
+                    top_k=args.top_k,
+                    temperature=args.temperature,
+                    eos_token_id=human_token_id,
+                    pad_token_id=tokenizer.eos_token_id,
+                )
+
+            response = tokenizer.decode(out[0]).split(QA_SPECIAL_TOKENS["Answer"])[-1]
+            print(f"Bot: {response}")
+            conversation_history.append(response)
+        except KeyboardInterrupt:
+            conversation_history = []
+            print("Conversation restarted")
+            time.sleep(1)
+            continue
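With the `get_specific_model` signature shown above, extra keyword arguments are forwarded directly to `from_pretrained`, which is how `model_cli.py` requests 8-bit loading. Below is a minimal sketch of such a call; the model id is purely illustrative and not part of the patch.

```python
# Sketch only: exercises the keyword-argument form of get_specific_model.
# "EleutherAI/pythia-1.4b" is an illustrative checkpoint, not one required here.
import torch

from models import get_specific_model

model = get_specific_model(
    "EleutherAI/pythia-1.4b",
    seq2seqmodel=False,       # set True for Flan-T5 style encoder-decoder models
    cache_dir=".cache",
    # everything below travels through **kwargs into from_pretrained
    load_in_8bit=True,
    device_map="auto",
    torch_dtype=torch.float16,
)
```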