In [22]:
import yaml

config = yaml.safe_load(open('config.yml'))

In [23]:
from moda.models import load_model

router_model, router_tokenizer = load_model(config["router"]["name"], config["router"]["load_in_8bit"])
router_terminators = [
    router_tokenizer.eos_token_id,
    router_tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

socket_model, socket_tokenizer = load_model(config["socket"]["name"], config["socket"]["load_in_8bit"])
socket_terminators = [
    socket_tokenizer.eos_token_id,
    socket_tokenizer.convert_tokens_to_ids("<|im_end|>"),
    socket_tokenizer.convert_tokens_to_ids("<|endoftext|>")
]

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [24]:
from moda.conversation import Conversation
from moda.functions import get_functions_metadata

functions_metadata = get_functions_metadata(config['functions'])
conversation = Conversation(functions=functions_metadata)

In [25]:
request = "Почему трава зелёная"
# request = "What is EBITDA?"

conversation.add_user_message(request)

conversation.messages

[{'role': 'system',
  'content': 'You are a helpful assistant with access to the following functions: \n [{\'type\': \'function\', \'function\': {\'name\': \'IlyaGusev/saiga2_7b_lora\', \'description\': \'This is the Saiga adapter trained on multiple russian datasets.\\nUse this adapter to answer questions on Russian.\\n\', \'parameters\': {\'type\': \'object\', \'properties\': {\'question\': {\'type\': \'string\', \'description\': \'Text of question which user asked\'}}, \'required\': [\'question\']}}}, {\'type\': \'function\', \'function\': {\'name\': \'briefai/LongShort-Llama-2-7B\', \'description\': \'Is a large language model fine-tuned on earnings call documents to extract financial KPIs from the earnings call documents.\\nUse this adapter when users need to analyze financial data.\\n\', \'parameters\': {\'type\': \'object\', \'properties\': {\'question\': {\'type\': \'string\', \'description\': \'Text of question which user asked\'}}, \'required\': [\'question\']}}}]\n\nTo use t

In [26]:
from moda.models import get_input_ids, trigger_model

router_input_ids = get_input_ids(router_tokenizer, conversation.messages).to(router_model.device)
print(router_input_ids)

router_result = trigger_model(router_input_ids, router_terminators, router_model, router_tokenizer)
print(router_result)

conversation.add_bot_message(router_result)

tensor([[128000, 128006,   9125, 128007,    271,   2675,    527,    264,  11190,
          18328,    449,   2680,    311,    279,   2768,   5865,     25,    720,
          62208,   1337,   1232,    364,   1723,    518,    364,   1723,   1232,
           5473,    609,   1232,    364,     40,  97199,     38,    817,     85,
           2754,     64,  16960,     17,     62,     22,     65,    918,   6347,
            518,    364,   4789,   1232,    364,   2028,    374,    279,  16233,
          16960,  13253,  16572,    389,   5361,  64245,  30525,   7255,     77,
          10464,    420,  13253,    311,   4320,   4860,    389,   8690,   7255,
             77,    518,    364,  14105,   1232,   5473,   1337,   1232,    364,
           1735,    518,    364,  13495,   1232,   5473,   7998,   1232,   5473,
           1337,   1232,    364,    928,    518,    364,   4789,   1232,    364,
           1199,    315,   3488,    902,   1217,   4691,   8439,   2186,    364,
           6413,   1232,   2

In [27]:
from moda.functions import has_function_call, extract_function_call_from_string
from moda.models import get_adapter, load_adapter

function_call = None
# If response has function call
if has_function_call(router_result):
    # Extract name of adapter from functino call
    function_call = extract_function_call_from_string(router_result)
    print(function_call)

    # Get details about adapter from library
    adapter = get_adapter(function_call["name"], config["functions"])
    # print(adapter)

    # Assign adapter to socket model
    adapter_model, adapter_tokenizer = load_adapter(socket_model, adapter["name"])
    # print(adapter_model, adapter_tokenizer)

    # Prepare messages to adapter model
    messages = [
        {"role": "system", "content": adapter["prompt"]},
        {"role": "user", "content": conversation.get_last_message()},
    ]

    # Send user question to model
    socket_input_ids = socket_tokenizer.apply_chat_template(
        messages,
        chat_template=adapter['chat_template'],
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(socket_model.device)
    print(socket_input_ids)

    socket_terminators = [
        socket_tokenizer.eos_token_id,
        socket_tokenizer.convert_tokens_to_ids("</s>"),
        socket_tokenizer.convert_tokens_to_ids("<|endoftext|>")
    ]

    socket_result = trigger_model(socket_input_ids, socket_terminators, adapter_model, adapter_tokenizer)
    print(socket_result)

{'name': 'IlyaGusev/saiga2_7b_lora'}
tensor([[    1,  5205,    13,  3492,   263,  8444, 20255,   363, 22862,  5155,
           373, 10637, 29889,    13,     2,    13,     1,  1792,    13, 20093,
          1093,  1805, 12550,   846, 24263, 13459,  3162,     2,    13,     1,
           465, 22137,    13]], device='cuda:0')
Зелёная трава - это символ весны, пробуждения природы, зеленого роста. Зелёный цвет ассоциируется с радостью, благополучием, здоровьем.
