In [None]:
import dendron
from dendron.actions.causal_lm_action import CausalLMActionConfig, CausalLMAction
from dendron.controls import Sequence, Fallback
from dendron import NodeStatus

import torch
from piper import PiperVoice
import numpy as np
import sounddevice as sd

In [None]:
class GetTextInput(dendron.ActionNode):
    """
    Action node that reads one line of text from the console and records
    it as the newest user turn of the conversation.

    PRE:
        blackboard["chat_history"] holds a list (it is appended to)
    POST:
        blackboard[latest_human_input_key] = the text that was typed
        blackboard["chat_history"] has the new user message appended
        blackboard["in"] = that same chat_history list (not a copy)
    """
    def __init__(self, latest_human_input_key = "latest_human_input"):
        super().__init__("get_text_input")
        self.latest_human_input_key = latest_human_input_key

    def tick(self):
        """Prompt for input, update the blackboard and always report SUCCESS."""
        key = self.latest_human_input_key

        # input() blocks until the user presses Enter.
        self.blackboard[key] = input("Human: ")

        user_turn = {"role": "GPT4 Correct User", "content" : self.blackboard[key]}
        history = self.blackboard["chat_history"]
        history.append(user_turn)

        # The history list itself is handed over as the model input.
        self.blackboard["in"] = history

        return NodeStatus.SUCCESS

In [None]:
class MoreToSay(dendron.ConditionNode):
    """
    Condition node that reports whether any text is still waiting to be spoken.

    PRE:
        blackboard[speech_input_key] should be set
    POST:
        None (the blackboard is only read)
    """
    def __init__(self, speech_input_key="speech_in"):
        super().__init__("more_to_say")
        self.speech_input_key = speech_input_key

    def tick(self):
        """Return SUCCESS unless the pending-speech entry equals the empty list."""
        pending = self.blackboard[self.speech_input_key]
        has_pending = pending != []
        return dendron.NodeStatus.SUCCESS if has_pending else dendron.NodeStatus.FAILURE

In [None]:
class TTSAction(dendron.ActionNode):
    """
    Action node that turns the next queued utterance into a Piper audio stream.

    PRE:
        blackboard[speech_input_key] is a non-empty list of strings
    POST:
        the oldest entry has been removed from blackboard[speech_input_key]
        blackboard["speech_out"] = the stream returned by
            PiperVoice.synthesize_stream_raw for that entry

    Args:
        name: node name passed to the dendron base class.
        speech_input_key: blackboard key of the pending-speech list. Defaults
            to "speech_in", the same default MoreToSay uses.
        model_path: path of the Piper ONNX voice model.
        config_path: path of the matching Piper JSON config.
    """
    def __init__(self, name, speech_input_key="speech_in",
                 model_path="en_US-danny-low.onnx",
                 config_path="en_US-danny-low.onnx.json"):
        super().__init__(name)
        self.speech_input_key = speech_input_key
        self.voice = PiperVoice.load(model_path, config_path=config_path, use_cuda=False)

    def tick(self):
        """
        Dequeue one utterance and start synthesising it.

        Returns FAILURE (after printing the error) if anything inside the
        try block raises, e.g. the queue is empty; SUCCESS otherwise.
        """
        try:
            # pop(0) keeps the queue first-in/first-out so that several
            # queued utterances are spoken in the order they were added.
            input_text = self.blackboard[self.speech_input_key].pop(0)
            # NOTE(review): the result is iterated later by the post-tick
            # playback hook; if the stream is lazy, synthesis errors would
            # surface there rather than in this try block -- confirm.
            self.blackboard["speech_out"] = self.voice.synthesize_stream_raw("\t" + input_text, sentence_silence=0.1)
        except Exception as e:
            print("Speech generation exception: ", e)
            return dendron.NodeStatus.FAILURE

        return dendron.NodeStatus.SUCCESS

def play_speech(self):
    """
    Post-tick hook: play every audio chunk found in blackboard["speech_out"].

    Each chunk is a bytes-like buffer interpreted as signed 16-bit PCM. The
    samples are scaled to float32 in [-1.0, 1.0) before being handed to
    sounddevice, and playback of one chunk finishes before the next starts.

    Args:
        self: the node this hook is attached to; only its blackboard is used.
    """
    audio_stream = self.blackboard["speech_out"]
    for chunk in audio_stream:
        pcm = np.frombuffer(chunk, dtype=np.int16)
        # Signed PCM is already centred on zero, so only a scale is needed.
        # Converting to float first also avoids integer overflow: subtracting
        # 32768 from an int16 array raises OverflowError on NumPy >= 2 and
        # produced a -0.5 DC offset on older versions.
        samples = pcm.astype(np.float32) / 32768.0
        # 16 kHz playback rate -- presumably the rate of the loaded Piper
        # voice; TODO confirm against the voice config.
        sd.play(samples, 16000)
        sd.wait()

In [None]:
# TTS leaf; play_speech is registered as a post-tick hook on it so the audio
# stored in blackboard["speech_out"] gets played.
speech_node = TTSAction("speech_node")
speech_node.add_post_tick(play_speech)

# Speech branch: the MoreToSay condition comes first, then the TTS leaf.
speech_seq = Sequence("speech_seq", [MoreToSay(), speech_node])

In [None]:
class TimeToThink(dendron.ConditionNode):
    """
    Condition node that succeeds only when the human input has changed
    since this node was last ticked.

    PRE:
        blackboard[human_input_key] should be set
    POST:
        None on the blackboard; the node remembers the value it just read
    """
    def __init__(self, human_input_key = "latest_human_input"):
        super().__init__("time_to_think")
        self.human_input_key = human_input_key
        # Starts as "", so an empty first input counts as "unchanged".
        self.last_human_input = ""

    def tick(self):
        """Return FAILURE if the input equals the remembered one, else SUCCESS."""
        current = self.blackboard[self.human_input_key]
        is_repeat = self.last_human_input == current

        # Remember the value on every tick, whatever the outcome.
        self.last_human_input = current

        return NodeStatus.FAILURE if is_repeat else NodeStatus.SUCCESS

In [None]:
# Configuration for the chat model node.
# NOTE(review): load_in_4bit and use_flash_attn_2 presumably need extra
# packages / GPU support at load time -- confirm for the target machine.
chat_behavior_cfg = CausalLMActionConfig(
    model_name='openchat/openchat-3.5-0106',
    # model loading options
    load_in_4bit=True,
    use_flash_attn_2=True,
    # generation options
    max_new_tokens=128,
    do_sample=True,
    top_p=0.95,
)

chat_node = CausalLMAction('chat_node', chat_behavior_cfg)

def chat_to_str(self, chat):
    """
    Input processor: render a chat (list of role/content messages) as text.

    Uses the node's tokenizer chat template, returns the untokenized text
    and asks the template to add the generation prompt.
    """
    prompt = self.tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
    return prompt

def str_to_chat(self, str):
    """
    Output processor: turn the model's text output into an assistant message.

    The text after the last "GPT4 Correct Assistant:" marker is taken as the
    reply, stripped of surrounding whitespace, and appended to the chat list
    stored at blackboard[self.input_key]. If the marker is absent, the whole
    text is used as the reply instead of silently dropping its first
    characters.

    Args:
        self: the node this processor is attached to.
        str: text produced by the model (the name shadows the builtin but is
            kept so existing callers are unaffected).

    Returns:
        The chat list (the same object as on the blackboard) with the new
        assistant message appended.
    """
    key = "GPT4 Correct Assistant:"
    idx = str.rfind(key)
    # rfind returns -1 when the marker is missing; slicing from
    # -1 + len(key) would then cut off the start of the text.
    response = str if idx == -1 else str[idx + len(key):]
    # Drop the space left by the ": " separator and any trailing newline.
    response = response.strip()

    chat = self.blackboard[self.input_key]
    chat.append({"role" : "GPT4 Correct Assistant", "content" : response})
    return chat

def set_next_speech(self):
    """
    Post-tick hook: queue the newest chat message for speech.

    Reads the "content" of the last message in blackboard["out"] and appends
    it to the list stored at blackboard["speech_in"].
    """
    newest_content = self.blackboard["out"][-1]["content"]
    speech_queue = self.blackboard["speech_in"]
    speech_queue.append(newest_content)

# Attach the conversion hooks to the chat node:
#   - chat_to_str renders the chat list into prompt text via the tokenizer,
#   - str_to_chat appends the model's reply to the chat list,
#   - set_next_speech queues that reply in blackboard["speech_in"].
chat_node.set_input_processor(chat_to_str)
chat_node.set_output_processor(str_to_chat)
chat_node.add_post_tick(set_next_speech)

In [None]:
# Thinking branch: the TimeToThink condition comes first, then the chat model.
thought_seq = Sequence("thought_seq", [TimeToThink(), chat_node])

In [None]:
# One conversation turn. Children are listed in priority order: speak pending
# text, otherwise answer new input, otherwise ask the human for input.
# NOTE(review): assumes Fallback tries children in order until one succeeds.
conversation_children = [speech_seq, thought_seq, GetTextInput()]
root_node = Fallback("conversation_turn", conversation_children)
del conversation_children  # keep the notebook namespace as it was

tree = dendron.BehaviorTree("talker_tree", root_node)

In [None]:
# Seed the blackboard entries the nodes read before they are first written.
tree.blackboard["chat_history"] = []
tree.blackboard["speech_in"] = []
tree.blackboard["latest_human_input"] = ""

# Tick the tree until the user stops it. Interrupting the kernel (or closing
# stdin while the input prompt is waiting) ends the conversation cleanly
# instead of leaving a traceback in the cell output.
try:
    while True:
        tree.tick_once()
except (KeyboardInterrupt, EOFError):
    print("Conversation ended.")