In [None]:
import dendron
from dendron.actions.causal_lm_action import CausalLMActionConfig, CausalLMAction

In [None]:
import torch
from transformers import BarkModel, BarkProcessor
from optimum.bettertransformer import BetterTransformer
import sounddevice as sd

In [None]:
class TTSAction(dendron.ActionNode):
    """Behavior-tree action node that turns text into speech with Bark.

    Each tick reads the text stored under ``blackboard["speech_in"]``, runs
    Bark on it, and writes the ``generate(...)`` output (moved to CPU and
    converted to a NumPy array) to ``blackboard["speech_out"]``.

    Args:
        name: Node name, forwarded to ``dendron.ActionNode``.
        model_name: Hugging Face checkpoint used for both the processor and
            the model. Defaults to ``"suno/bark-small"``.
        voice_preset: Bark speaker preset passed to the processor on every
            tick. Defaults to ``"v2/en_speaker_9"``.
        device: Device the model is moved to before CPU offload is enabled,
            and the device the processor outputs are sent to on each tick.
            Defaults to ``"cuda"``. NOTE(review): ``enable_cpu_offload`` is a
            GPU offload feature, so non-CUDA values are presumably
            unsupported here — confirm before relying on them.
    """

    def __init__(self, name, model_name="suno/bark-small",
                 voice_preset="v2/en_speaker_9", device="cuda"):
        super().__init__(name)
        # Per-instance settings read by tick() (formerly hard-coded there).
        self.voice_preset = voice_preset
        self.tts_device = device
        self.processor = BarkProcessor.from_pretrained(model_name)
        self.model = BarkModel.from_pretrained(model_name).to(device)
        # Replace attention with BetterTransformer kernels; the untransformed
        # model is not kept around.
        self.model = BetterTransformer.transform(self.model, keep_original_model=False)
        # Per the Transformers Bark docs, this offloads sub-models to CPU and
        # moves each to the GPU only while it runs, reducing GPU memory use.
        self.model.enable_cpu_offload()

    def tick(self):
        """Synthesize ``blackboard["speech_in"]`` into ``blackboard["speech_out"]``.

        Returns:
            ``dendron.NodeStatus.SUCCESS`` when audio was generated, or
            ``dendron.NodeStatus.FAILURE`` if any step raised; the exception
            is printed instead of being propagated.
        """
        try:
            input_text = self.blackboard["speech_in"]
            inputs = self.processor(text=input_text, voice_preset=self.voice_preset, return_tensors="pt").to(self.tts_device)
            self.blackboard["speech_out"] = self.model.generate(**inputs).cpu().numpy()
        except Exception as e:
            # Best-effort: report the problem and fail this node rather than
            # raising out of the tree.
            print("Speech generation exception: ", e)
            return dendron.NodeStatus.FAILURE

        return dendron.NodeStatus.SUCCESS

In [None]:
def play_speech(self):
    """Post-tick hook for the TTS node: play the first generated waveform.

    Sends ``self.blackboard["speech_out"][0]`` to the default audio output at
    ``self.model.generation_config.sample_rate`` and then blocks until
    playback has finished.
    """
    waveform = self.blackboard["speech_out"][0]
    rate = self.model.generation_config.sample_rate
    sd.play(waveform, rate)
    # Block so the tree does not move on while audio is still playing.
    sd.wait()

In [None]:
# Build the Bark TTS node (its constructor downloads/loads the model) and
# register play_speech as its post-tick hook so generated audio is played.
speech_node = TTSAction("speech_node")
speech_node.add_post_tick(play_speech)

In [None]:
# Chat model configuration, grouped as: which checkpoint, how it is loaded,
# and how replies are sampled.
chat_behavior_cfg = CausalLMActionConfig(
    # model
    model_name='openchat/openchat_3.5',
    # loading options
    load_in_4bit=True,
    use_flash_attn_2=True,
    # generation options
    do_sample=True,
    top_p=0.95,
    max_new_tokens=128,
)

chat_node = CausalLMAction('chat_node', chat_behavior_cfg)

def chat_to_str(self, chat):
    """Input processor for the chat node.

    Renders ``chat`` (a list of ``{"role": ..., "content": ...}`` dicts) into a
    single prompt string using the tokenizer's chat template. The result is not
    tokenized, and the assistant generation prompt is appended so the model
    continues as the assistant.
    """
    template_options = {"tokenize": False, "add_generation_prompt": True}
    prompt = self.tokenizer.apply_chat_template(chat, **template_options)
    return prompt

def str_to_chat(self, str):
    """Output processor for the chat node: fold the decoded text into the chat.

    The assistant's reply is taken as everything after the LAST occurrence of
    ``"GPT4 Correct Assistant:"`` in the decoded text. If that marker is absent,
    the whole decoded text is used as the reply. The reply is appended as a new
    message to the list stored at ``self.blackboard[self.input_key]``, and that
    same list is returned.

    The append is deliberately in place. The driver loop keeps its own reference
    to the chat list it puts on the blackboard (presumably under this node's
    ``input_key`` — verify), so assistant turns accumulate in the conversation
    history.

    Args:
        str: Decoded model output (the parameter name shadows the builtin; it is
            kept for interface compatibility).

    Returns:
        The updated chat list.
    """
    key = "GPT4 Correct Assistant:"
    idx = str.rfind(key)
    # rfind returns -1 when the marker is missing; slicing from -1 + len(key)
    # would silently drop the reply's first characters, so keep all the text.
    response = str if idx == -1 else str[idx + len(key):]
    chat = self.blackboard[self.input_key]
    # NOTE(review): roles here and in the driver loop embed the OpenChat
    # "GPT4 Correct" prefix. If the tokenizer's chat template already prepends
    # that prefix to the role, plain "assistant"/"user" roles would be
    # expected — confirm against the openchat_3.5 tokenizer config.
    chat.append({"role" : "GPT4 Correct Assistant", "content" : response})
    return chat

# Install the prompt formatter (chat list -> prompt string) and the reply
# parser (decoded text -> updated chat list) on the chat node.
chat_node.set_input_processor(chat_to_str)
chat_node.set_output_processor(str_to_chat)

In [None]:
def set_next_speech(self):
    """Post-tick hook for the chat node: hand the latest reply to the TTS node.

    Reads the ``"content"`` of the last message under ``blackboard["out"]`` and
    stores it, prefixed with a single space, under ``blackboard["speech_in"]``.
    """
    latest_message = self.blackboard["out"][-1]
    self.blackboard["speech_in"] = " " + latest_message["content"]

# Register set_next_speech as the chat node's post-tick hook so each reply is
# copied to "speech_in" for the TTS node.
chat_node.add_post_tick(set_next_speech)

In [None]:
# Root control node: a Sequence over [chat_node, speech_node], i.e. think
# first, then talk. (Sequence presumably stops at the first failing child —
# confirm with the dendron docs.)
root_node = dendron.controls.Sequence("think_then_talk", [
    chat_node,
    speech_node
])

# Wrap the root in a BehaviorTree; the driver loop below reads and writes its
# blackboard.
tree = dendron.BehaviorTree("talker_tree", root_node)

In [None]:
chat = []

# Conversation driver. Each user line is appended to `chat`, the tree runs one
# think-then-talk pass, and the newest message under "out" is printed. The loop
# ends when that message contains "Goodbye", or cleanly when the user ends input
# (Ctrl-D / EOF) or interrupts the kernel while it waits for input.
while True:
    try:
        input_str = input("Input: ")
    except (EOFError, KeyboardInterrupt):
        # Give the user a way out other than a traceback.
        break
    chat.append({"role": "GPT4 Correct User", "content" : input_str})
    # The same list object is stored every turn. The chat node's output
    # processor appends the assistant reply to it in place, so history
    # accumulates here.
    tree.blackboard["in"] = chat
    # NOTE(review): the tick result is ignored. If the chat node fails, "out"
    # may be stale or missing on the first turn — consider checking it.
    tree.tick_once()
    reply = tree.blackboard["out"][-1]["content"]
    print("Output: ", reply)
    if "Goodbye" in reply:
        break