Mixtral cache
artitw committed Feb 11, 2024
1 parent 84f4b31 commit 5e95d04
Showing 2 changed files with 21 additions and 19 deletions.
setup.py (4 changes: 1 addition & 3 deletions)

@@ -5,7 +5,7 @@

 setuptools.setup(
     name="text2text",
-    version="1.4.1",
+    version="1.4.2",
     author="artitw",
     author_email="artitw@gmail.com",
     description="Text2Text: Crosslingual NLP/G toolkit",
@@ -21,7 +21,6 @@
     keywords='multilingual crosslingual gpt chatgpt bert natural language processing nlp nlg text generation gpt question answer answering information retrieval tfidf tf-idf bm25 search index summary summarizer summarization tokenizer tokenization translation backtranslation data augmentation science machine learning colab embedding levenshtein sub-word edit distance conversational dialog chatbot mixtral',
     install_requires=[
         'accelerate',
-        'auto-gptq',
         'bitsandbytes',
         'peft',
         'faiss-cpu',
@@ -38,7 +37,6 @@
         'tqdm',
         'transformers',
         'hqq',
-        'tqdm',
         'huggingface_hub'
     ],
 )
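
Besides the version bump to 1.4.2, this diff drops the auto-gptq requirement and removes a duplicate tqdm entry. A small sanity check using the standard-library importlib.metadata, assuming the release is installed in an ordinary pip environment:

# Assumes `pip install text2text==1.4.2` has already run in this environment.
from importlib.metadata import requires, version

print(version("text2text"))  # expected: 1.4.2
deps = requires("text2text") or []
print(any(d.startswith("auto-gptq") for d in deps))  # expected: False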
text2text/assistant.py (36 changes: 20 additions & 16 deletions)

@@ -81,13 +81,17 @@ def chat_completion(self, messages=[{"role": "user", "content": "hello"}], strea

     input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
 
-    input_string = tokenizer.apply_chat_template(messages, tokenize=False)
-
-    past_key_values = cache.get(input_string, None)
-    if past_key_values:
-      seq_len = input_ids.size(1) + past_key_values[0][0][0].size(1)
-      attention_mask = torch.ones([1, seq_len - 1], dtype=torch.int, device=device)
-    else:
-      attention_mask = None
+    attention_mask = None
+    past_key_values = None
+    for i in range(1,len(messages)):
+      past_input_string = tokenizer.apply_chat_template(messages[:-i], tokenize=False)
+      past_key_values = cache.get(past_input_string, None)
+      if past_key_values:
+        seq_len = input_ids.size(1) + past_key_values[0][0][0].size(1)
+        attention_mask = torch.ones([1, seq_len - 1], dtype=torch.int, device=device)
+        break
+
+    if attention_mask == None:
+      attention_mask = torch.ones_like(input_ids)
 
     results = model.generate(
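
The hunk above replaces an exact-match lookup keyed on the full rendered conversation with a longest-prefix search: the loop drops the i most recent messages until it finds a conversation prefix whose past_key_values were cached by an earlier call. A minimal, torch-free sketch of the idea; render() is a hypothetical stand-in for tokenizer.apply_chat_template(..., tokenize=False), and a string stands in for the cached key/value tensors:

def render(messages):
  # Stand-in for the chat template: serialize messages in order.
  return "".join(f"<{m['role']}>{m['content']}</{m['role']}>" for m in messages)

def lookup_longest_prefix(cache, messages):
  # Mirror the loop in the diff: drop the i most recent messages and
  # return the first (longest) cached prefix found, else None.
  for i in range(1, len(messages)):
    hit = cache.get(render(messages[:-i]))
    if hit is not None:
      return hit
  return None

cache = {}
history = [{"role": "user", "content": "hello"}]
assert lookup_longest_prefix(cache, history) is None  # first turn: no hit

# After a completion, the conversation including the reply is cached
# (the second hunk below stores under exactly that rendered string).
history.append({"role": "assistant", "content": "hi there"})
cache[render(history)] = "kv-for-2-messages"

history.append({"role": "user", "content": "how are you?"})
print(lookup_longest_prefix(cache, history))  # kv-for-2-messages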
@@ -104,18 +108,18 @@
       output_hidden_states=False,
     )
 
-    cache[input_string] = results["past_key_values"]
-
-    results = tokenizer.batch_decode(**results)[0]
-
-    return {
+    output_string = tokenizer.batch_decode(**results)[0]
+    input_string = tokenizer.apply_chat_template(messages, tokenize=False)
+    messages.append({
       "role": "assistant",
-      "content": _clean_output(input_string, results)
-    }
+      "content": _clean_output(input_string, output_string)
+    })
+    cache_string = tokenizer.apply_chat_template(messages, tokenize=False)
+    self.__class__.cache[cache_string] = results["past_key_values"]
 
-    return results
+    return messages[-1]
 
   def transform(self, input_lines, src_lang='en', **kwargs):
-    return self.chat_completion([{"role": "user", "content": input_lines}])
+    return self.chat_completion([{"role": "user", "content": input_lines}])["content"]
 
   completion = transform
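
With these changes, chat_completion appends the assistant reply to messages and returns that message dict, while transform (aliased as completion) unwraps it to the content string. A hypothetical usage sketch, assuming text2text >= 1.4.2 with the model weights it downloads on first use:

import text2text as t2t

asst = t2t.Assistant()

# chat_completion now returns the appended assistant message dict.
reply = asst.chat_completion([{"role": "user", "content": "hello"}])
print(reply["role"], reply["content"])

# transform / completion unwrap it to just the content string.
print(asst.transform("hello"))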
