In [None]:
import yaml

# Load the notebook settings from config.yml (the cells below read its
# "router", "socket" and "functions" sections).
# The `with` block closes the file handle instead of leaking it, and the
# explicit UTF-8 encoding keeps non-ASCII text (e.g. Cyrillic prompts) intact
# regardless of the platform's default locale encoding.
with open('config.yml', encoding='utf-8') as config_file:
    config = yaml.safe_load(config_file)

In [None]:
from moda.models import load_model


def collect_terminators(tokenizer, *stop_tokens):
    """Return the token ids that should end generation for ``tokenizer``.

    Starts from ``tokenizer.eos_token_id`` and appends the id of every extra
    ``stop_tokens`` string the tokenizer actually knows. Hugging Face
    tokenizers map an unknown token to ``unk_token_id``, or to ``None`` when
    the tokenizer has no unk token. Neither is a meaningful stop id, and a
    ``None`` breaks generation, so such tokens are skipped. Duplicates (e.g.
    when eos is itself one of the stop tokens) are removed, keeping order.

    Args:
        tokenizer: a Hugging Face-style tokenizer.
        *stop_tokens: token strings that should also terminate generation.

    Returns:
        list of int token ids, eos first when it is set.
    """
    ids = [tokenizer.eos_token_id]
    unk_id = tokenizer.unk_token_id
    for token in stop_tokens:
        token_id = tokenizer.convert_tokens_to_ids(token)
        if token_id is None:
            continue  # unknown token and the tokenizer has no unk token
        if unk_id is not None and token_id == unk_id and token != tokenizer.unk_token:
            continue  # unknown token silently mapped to the unk id
        ids.append(token_id)
    # Drop a missing eos id and duplicates while preserving order.
    return list(dict.fromkeys(i for i in ids if i is not None))


# Router model: decides whether a request should go to an adapter.
router_model, router_tokenizer = load_model(config["router"]["name"], config["router"]["load_in_8bit"])
router_terminators = collect_terminators(router_tokenizer, "<|eot_id|>")

# Socket model: the base model adapters are attached to.
socket_model, socket_tokenizer = load_model(config["socket"]["name"], config["socket"]["load_in_8bit"])
socket_terminators = collect_terminators(socket_tokenizer, "<|im_end|>", "<|endoftext|>")

In [None]:
from moda.conversation import Conversation
from moda.functions import get_functions_metadata

# Build metadata for the functions listed in the config's "functions" section
# and hand it to a new Conversation (presumably so the router can see which
# adapters it may call).
functions_metadata = get_functions_metadata(
    config["functions"],
)
conversation = Conversation(
    functions=functions_metadata,
)

In [None]:
# Active prompt (Russian for "Why is grass green"); swap the comment markers
# to try the English finance question instead.
request = "Почему трава зелёная"
# request = "What is EBITDA?"

# Add the question to the dialogue. NOTE: re-running this cell presumably
# appends the message a second time; re-run the Conversation cell first to
# start from a clean history.
conversation.add_user_message(request)

# Show the message list as the cell output.
conversation.messages

In [None]:
from moda.models import get_input_ids, trigger_model

# Tokenize the conversation so far for the router, then move the ids onto the
# router model's device.
router_input_ids = get_input_ids(router_tokenizer, conversation.messages)
router_input_ids = router_input_ids.to(router_model.device)
print(router_input_ids)

# Generate the router's reply, passing `router_terminators` as the stop ids.
router_result = trigger_model(
    router_input_ids,
    router_terminators,
    router_model,
    router_tokenizer,
)
print(router_result)

# Record the reply in the dialogue history. NOTE: re-running this cell
# presumably appends another bot message; restart from the Conversation cell.
conversation.add_bot_message(router_result)

In [None]:
from moda.functions import has_function_call, extract_function_call_from_string
from moda.models import get_adapter, load_adapter

# Reset per-request outputs so that a re-run whose router reply contains no
# function call does not leave a stale answer from an earlier request.
function_call = None
socket_result = None

# Only dispatch to an adapter when the router's reply contains a function call.
if has_function_call(router_result):
    # Extract the name of the adapter from the function call.
    function_call = extract_function_call_from_string(router_result)
    print(function_call)

    # Look up the adapter's details ("name", "prompt", "chat_template") in the
    # config's "functions" section.
    adapter = get_adapter(function_call["name"], config["functions"])
    if adapter is None:
        # NOTE(review): assumes get_adapter returns None for an unknown name;
        # fail with a readable message rather than a TypeError further down.
        raise KeyError(f"Router requested unknown adapter {function_call['name']!r}")

    # Attach the adapter to the socket model.
    adapter_model, adapter_tokenizer = load_adapter(socket_model, adapter["name"])

    # Prompt for the adapter: its system prompt plus the conversation's last message.
    # NOTE(review): the previous cell appended the router's reply to
    # `conversation`. If get_last_message() returns the most recent message
    # regardless of role, this is the router output, not the user's question.
    # Confirm.
    messages = [
        {"role": "system", "content": adapter["prompt"]},
        {"role": "user", "content": conversation.get_last_message()},
    ]

    # Encode with the adapter's tokenizer (the one trigger_model receives) so
    # that prompt encoding, stop ids and decoding all use the same tokenizer.
    socket_input_ids = adapter_tokenizer.apply_chat_template(
        messages,
        chat_template=adapter['chat_template'],
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(socket_model.device)
    print(socket_input_ids)

    # Stop ids for the adapter. A token absent from the vocabulary can come
    # back as None, which is not a valid stop id, so drop those along with
    # duplicates. A local name keeps the base-model `socket_terminators` from
    # the loading cell intact instead of silently overwriting it.
    adapter_terminators = list(dict.fromkeys(
        token_id
        for token_id in (
            adapter_tokenizer.eos_token_id,
            adapter_tokenizer.convert_tokens_to_ids("</s>"),
            adapter_tokenizer.convert_tokens_to_ids("<|endoftext|>"),
        )
        if token_id is not None
    ))

    socket_result = trigger_model(socket_input_ids, adapter_terminators, adapter_model, adapter_tokenizer)
    print(socket_result)