Test model through CLI #1554

Merged · 1 commit · Feb 13, 2023
8 changes: 8 additions & 0 deletions model/model_training/README.md
@@ -85,6 +85,14 @@ To train using trlx try:
python trainer_rl.py --configs defaults_rlhf
```

## Test your model

You can interactively test your model like this:

```bash
python tools/model_cli.py --model_path <saved_path/huggingface>
```
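
The CLI also exposes the generation parameters defined in `tools/model_cli.py`; the values shown here are the script's defaults:

```bash
python tools/model_cli.py --model_path <saved_path/huggingface> \
    --max_new_tokens 200 --top_k 40 --temperature 1.0
```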

## Model

Normally you should be able to add new models in `configs/config.yml`
6 changes: 3 additions & 3 deletions model/model_training/models/__init__.py
@@ -25,12 +25,12 @@ def freeze_top_n_layers(model, target_layers):
     return model


-def get_specific_model(model_name, cache_dir, quantization, seq2seqmodel):
+def get_specific_model(model_name, seq2seqmodel=False, cache_dir=".cache", **kwargs):
     # encoder-decoder support for Flan-T5 like models
     # for now, we can use an argument but in the future,
     # we can automate this
     if seq2seqmodel:
-        model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=cache_dir)
+        model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=cache_dir, **kwargs)
     else:
-        model = transformers.AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
+        model = transformers.AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir, **kwargs)
     return model
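
With the new signature, any extra keyword arguments flow straight through to `from_pretrained`. A minimal sketch of a call site, mirroring how the CLI below loads a model (the model path placeholder and argument values are illustrative):

```python
import torch

from models import get_specific_model

# Extra kwargs such as 8-bit loading and automatic device placement are now
# forwarded to transformers' from_pretrained (values are illustrative):
model = get_specific_model(
    "<saved_path/huggingface>",
    seq2seqmodel=False,
    load_in_8bit=True,
    device_map="auto",
    torch_dtype=torch.float16,
)
```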
77 changes: 77 additions & 0 deletions model/model_training/tools/model_cli.py
@@ -0,0 +1,77 @@
import argparse
import os
import sys
import time

import torch
import transformers

if __name__ == "__main__":
    import warnings

    warnings.filterwarnings("ignore")

    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

    from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair
    from models import get_specific_model
    from utils import _strtobool

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--max_new_tokens", type=int, default=200)
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--do-sample", type=_strtobool, default=True)
    args = parser.parse_args()

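    # Load the model in 8-bit with automatic device placement so large
    # checkpoints fit in limited GPU memory.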
    model = get_specific_model(
        args.model_path,
        load_in_8bit=True,
        device_map="auto",
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        offload_state_dict=True,
    )

    model.gradient_checkpointing_enable()  # reduce number of stored activations
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_path)

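    # Resolve the id of the "Question" special token; it is passed as
    # eos_token_id below so generation stops before the model writes the
    # next user turn.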
    human_token_id = tokenizer.additional_special_tokens_ids[
        tokenizer.additional_special_tokens.index(QA_SPECIAL_TOKENS["Question"])
    ]

    print('Type "quit" to exit')
    print("Press Control + C to restart conversation (spam to exit)")

    conversation_history = []

    while True:
        try:
            user_input = input("User: ")
            if user_input == "quit":
                break

            conversation_history.append(user_input)

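            # Render the running conversation into a single prompt string
            # and tokenize it.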
            batch = tokenizer.encode("".join(format_pair(conversation_history)), return_tensors="pt")

            with torch.cuda.amp.autocast():
                out = model.generate(
                    input_ids=batch.to(model.device),
                    max_new_tokens=args.max_new_tokens,
                    do_sample=args.do_sample,
                    top_k=args.top_k,
                    temperature=args.temperature,
                    eos_token_id=human_token_id,
                    pad_token_id=tokenizer.eos_token_id,
                )

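            # Everything after the last "Answer" special token is the newly
            # generated reply.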
            response = tokenizer.decode(out[0]).split(QA_SPECIAL_TOKENS["Answer"])[-1]
            print(f"Bot: {response}")
            conversation_history.append(response)
        except KeyboardInterrupt:
            conversation_history = []
            print("Conversation restarted")
            time.sleep(1)
            continue