In [None]:
import dendron
from dendron.actions.causal_lm_action import CausalLMActionConfig, CausalLMAction
from dendron.conditions.completion_condition import CompletionConditionConfig, CompletionCondition
from dendron.controls import Sequence, Fallback
from dendron import NodeStatus

import torch
from piper import PiperVoice
from spacy.lang.en import English 

import numpy as np
import sounddevice as sd

In [None]:
class MoreToSay(dendron.ConditionNode):
    """
    Condition that checks whether there is still queued text waiting to be spoken.

    PRE:
        blackboard[speech_input_key] should be set
    POST:
        blackboard["turn_done"] = False if there is queued text, True otherwise
    """
    def __init__(self, speech_input_key="speech_in"):
        super().__init__("more_to_say")
        self.speech_input_key = speech_input_key

    def tick(self):
        # Anything other than an empty list counts as "more to say".
        has_pending = self.blackboard[self.speech_input_key] != []
        self.blackboard["turn_done"] = not has_pending
        if has_pending:
            return dendron.NodeStatus.SUCCESS
        return dendron.NodeStatus.FAILURE

class TimeToThink(dendron.ConditionNode):
    """
    Condition that succeeds when the human input differs from the one seen on
    the previous tick of this node.

    The very first tick always fails, because there is no previously seen
    input to compare against yet.

    PRE:
        blackboard[human_input_key] should be set
    POST:
        The blackboard is not modified; the node only remembers the input it
        just saw for comparison on the next tick.
    """
    def __init__(self, human_input_key = "latest_human_input"):
        super().__init__("time_to_think")
        self.human_input_key = human_input_key
        self.last_human_input = None

    def tick(self):
        current = self.blackboard[self.human_input_key]
        # Swap in the new value while keeping the old one for the comparison.
        previous, self.last_human_input = self.last_human_input, current
        changed = previous is not None and current != previous
        return NodeStatus.SUCCESS if changed else NodeStatus.FAILURE

class GetTextInput(dendron.ActionNode):
    """
    Action that prompts the human for a line of text and records it as the
    next user message of the chat.

    PRE:
        blackboard["chat_history"] is a list of chat messages
    POST:
        blackboard[latest_human_input_key] = input
        blackboard["chat_history"] has the new user message appended
        blackboard["in"] = chat (the same list object as the chat history)
    """
    def __init__(self, latest_human_input_key = "latest_human_input"):
        super().__init__("get_text_input")
        self.latest_human_input_key = latest_human_input_key

    def tick(self):
        key = self.latest_human_input_key
        self.blackboard[key] = input("Human: ")

        # The content is read back from the blackboard rather than reused from
        # the raw input, so the message holds whatever the blackboard stored.
        message = {"role": "GPT4 Correct User", "content" : self.blackboard[key]}
        history = self.blackboard["chat_history"]
        history.append(message)
        self.blackboard["in"] = history

        return NodeStatus.SUCCESS

In [None]:
class TTSAction(dendron.ActionNode):
    """
    Action node that converts queued text into raw audio with a Piper voice.

    PRE:
        blackboard["speech_in"] is a list of strings to speak
    POST (on SUCCESS):
        blackboard["speech_out"] = list of raw audio buffers, one per
            utterance that produced any audio
        blackboard["speech_in"] = []

    Args:
        name: Name of the node in the tree.
        model_path: Path to the Piper ONNX voice model.
        config_path: Path to the model's JSON config. Defaults to
            ``model_path + ".json"``.
        use_cuda: Forwarded to ``PiperVoice.load``.
    """
    def __init__(self, name, model_path="en_US-danny-low.onnx", config_path=None, use_cuda=False):
        super().__init__(name)
        if config_path is None:
            config_path = f"{model_path}.json"
        self.voice = PiperVoice.load(model_path, config_path=config_path, use_cuda=use_cuda)

    def tick(self):
        """
        Synthesize every queued utterance and clear the queue.

        Returns:
            NodeStatus.SUCCESS when all utterances were synthesized, or
            NodeStatus.FAILURE if synthesis raised (the queue is then left
            untouched).
        """
        try:
            pending = self.blackboard["speech_in"]
            utterances = []
            for text in pending:
                # The synthesis stream may yield several chunks for one piece
                # of text (one per sentence, per Piper's API - see refs).
                # Joining all of them keeps every sentence; taking only the
                # first chunk would silently drop the rest of the utterance.
                audio = b"".join(self.voice.synthesize_stream_raw(text, sentence_silence=0.1))
                if audio:
                    utterances.append(audio)
            self.blackboard["speech_out"] = utterances
            self.blackboard["speech_in"] = []
        except Exception as e:
            # Best effort: report the problem and let the tree react to FAILURE.
            print("Speech generation exception: ", e)
            return dendron.NodeStatus.FAILURE

        return dendron.NodeStatus.SUCCESS

def play_speech(self):
    """
    Play every audio buffer in blackboard["speech_out"], one after another.

    Each buffer is interpreted as signed 16-bit PCM and scaled to floating
    point samples in [-1.0, 1.0) before playback. Playback blocks until each
    utterance has finished.

    Args:
        self: The node this function is attached to; only its ``blackboard``
            is used.
    """
    # NOTE(review): the playback rate is hard-coded; presumably it matches the
    # sample rate of the voice model used to produce the audio - confirm if
    # the model changes.
    sample_rate = 16000

    for raw in self.blackboard["speech_out"]:
        pcm = np.frombuffer(raw, dtype=np.int16)
        if pcm.size == 0:
            # Nothing to play for this utterance.
            continue
        # Signed int16 spans [-32768, 32767]; convert to float first so the
        # arithmetic cannot overflow the int16 dtype, then scale by 32768.
        samples = pcm.astype(np.float32) / 32768.0
        sd.play(samples, sample_rate)
        sd.wait()

In [None]:
class SentenceSplitter(dendron.ActionNode):
    """
    Action that replaces the most recently queued text with its individual
    sentences when that text is longer than 64 characters.

    PRE:
        blackboard[in_key] is a non-empty list of strings
    POST:
        The last entry of blackboard[in_key] has been removed and either put
        back unchanged (64 characters or fewer) or replaced by its stripped,
        non-empty sentences in order.
    """
    def __init__(self, in_key="speech_in"):
        super().__init__("sentence_splitter")
        self.in_key = in_key
        self.splitter = English()
        self.splitter.add_pipe("sentencizer")

    def tick(self):
        queue = self.blackboard[self.in_key]
        text = queue.pop()

        # Short text is passed through as a single utterance.
        if len(text) <= 64:
            queue.append(text)
            return NodeStatus.SUCCESS

        for sentence in self.splitter(text).sents:
            cleaned = str(sentence).strip()
            if cleaned:
                queue.append(cleaned)
        return NodeStatus.SUCCESS

In [None]:
# Text-to-speech node. play_speech is registered through add_post_tick so the
# synthesized audio gets played (presumably after each tick of the node -
# confirm against dendron's hook semantics).
speech_node = TTSAction("speech_node")
speech_node.add_post_tick(play_speech)

# Speaking branch: the MoreToSay condition comes first, followed by the TTS node.
speech_seq = Sequence("speech_seq", [
    MoreToSay(),
    speech_node
])

In [None]:
# Configuration of the chat language model: 4-bit loading, at most 256 new
# tokens per reply, sampling with top_p=0.95, flash attention 2 enabled, and
# the openchat-3.5-0106 checkpoint.
chat_behavior_cfg = CausalLMActionConfig(load_in_4bit=True,
                                         max_new_tokens=256,
                                         do_sample=True,
                                         top_p=0.95,
                                         use_flash_attn_2=True,
                                         model_name='openchat/openchat-3.5-0106')

# The node that generates the assistant's replies.
chat_node = CausalLMAction('chat_node', chat_behavior_cfg)

def chat_to_str(self, chat):
    """
    Render a chat into a single prompt string.

    Uses the chat template of the node's tokenizer, without tokenizing, and
    with the generation prompt appended.

    Args:
        self: The node this function is attached to; its ``tokenizer`` is used.
        chat: The chat messages to render.

    Returns:
        Whatever the tokenizer's ``apply_chat_template`` returns for the chat.
    """
    render = self.tokenizer.apply_chat_template
    return render(chat, tokenize=False, add_generation_prompt=True)

def str_to_chat(self, str):
    """
    Extract the assistant's latest reply from generated text and append it to
    the chat stored on the blackboard.

    The reply is everything after the last occurrence of the
    "GPT4 Correct Assistant:" marker. If the marker does not occur at all, the
    whole text is taken as the reply instead of an arbitrary slice of it.

    Args:
        self: The node this function is attached to; its ``blackboard`` and
            ``input_key`` are used.
        str: The generated text. (The name shadows the builtin but is kept for
            compatibility with existing callers.)

    Returns:
        The chat list from blackboard[self.input_key], mutated in place with
        the new assistant message appended.
    """
    marker = "GPT4 Correct Assistant:"
    idx = str.rfind(marker)
    # rfind returns -1 when the marker is missing; slicing from
    # -1 + len(marker) would silently drop the start of the text.
    response = str[idx + len(marker):] if idx != -1 else str
    chat = self.blackboard[self.input_key]
    chat.append({"role" : "GPT4 Correct Assistant", "content" : response})
    return chat

def set_next_speech(self):
    """
    Queue the content of the newest chat message for speech.

    Reads the last message of blackboard["out"] and appends its "content" to
    the list at blackboard["speech_in"].
    """
    latest_message = self.blackboard["out"][-1]
    self.blackboard["speech_in"].append(latest_message["content"])

# Wire the chat node: chat -> prompt string on the way in, generated text ->
# chat on the way out, and queue the newest reply for speech via a post-tick hook.
chat_node.set_input_processor(chat_to_str)
chat_node.set_output_processor(str_to_chat)
chat_node.add_post_tick(set_next_speech)

# Thinking branch: check for new human input, generate a reply, then split
# the reply into sentences for the speech queue.
thought_seq = Sequence("thought_seq", [
    TimeToThink(),
    chat_node,
    SentenceSplitter()
])

In [None]:
# One conversational turn, with the options listed in priority order: speak
# queued text, otherwise think about new input, otherwise ask the human for input.
conversation_turn = Fallback("conversation_turn", [
    speech_seq,
    thought_seq,
    GetTextInput()
])

In [None]:
# Configuration of the farewell classifier. Its input key is
# "farewell_test_in", which is where farewell_pretick writes the prompt.
farewell_classifier_cfg = CompletionConditionConfig(
    input_key = "farewell_test_in",
    load_in_4bit=True,
    model_name='mlabonne/Monarch-7B',
    use_flash_attn_2=True
)

# Condition node used to decide whether the human has ended the conversation.
farewell_classification_node = CompletionCondition("farewell_classifier", farewell_classifier_cfg)

def farewell_success_fn(completion):
    """
    Map the classifier's completion to a node status.

    Returns SUCCESS only when the completion is exactly "yes" (meaning the
    conversation is done), and FAILURE for anything else.
    """
    return NodeStatus.SUCCESS if completion == "yes" else NodeStatus.FAILURE

def farewell_pretick(self):
    """
    Build the farewell-classification prompt from the latest human input.

    Reads blackboard["latest_human_input"], treats a missing value (either
    ``None`` or the string "None") as blank input, and writes the prompt,
    rendered with the tokenizer's chat template, to blackboard[self.input_key].

    Args:
        self: The node this function is attached to; its ``blackboard``,
            ``input_key`` and ``tokenizer`` are used.
    """
    last_input = self.blackboard["latest_human_input"]
    # The initial value may be a real None or its string form; either way the
    # human has not said anything yet, so the prompt should quote nothing.
    if last_input is None or last_input == "None":
        last_input = ""
    chat = [{"role": "user", "content": f"""The last thing the human said was "{last_input}". Has the user said something to indicate that the conversation is over? Only say yes if the user has said something. Ignore blank inputs."""}]
    self.blackboard[self.input_key] = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

# Build the classification prompt via the pre-tick hook.
farewell_classification_node.add_pre_tick(farewell_pretick)

def farewell_posttick(self):
    """
    Flag the conversation as finished when the node's status is SUCCESS.

    Sets blackboard["all_done"] to True in that case and leaves the
    blackboard untouched otherwise.
    """
    if self.status != NodeStatus.SUCCESS:
        return
    self.blackboard["all_done"] = True

# Record the end of the conversation on the blackboard via the post-tick hook.
farewell_classification_node.add_post_tick(farewell_posttick)

In [None]:
class SayGoodbye(dendron.ActionNode):
    """
    Action that queues a farewell once the conversation has been flagged as over.

    PRE:
        blackboard["all_done"] and blackboard["speech_in"] should be set
    POST:
        "Goodbye!" is appended to blackboard["speech_in"] if
        blackboard["all_done"] is truthy
    """
    def __init__(self):
        super().__init__("say_goodbye")

    def tick(self):
        """
        Returns:
            NodeStatus.SUCCESS when the farewell was queued, otherwise
            NodeStatus.FAILURE (previously this path fell through and
            returned None, which is not a node status).
        """
        if not self.blackboard["all_done"]:
            return NodeStatus.FAILURE

        self.blackboard["speech_in"].append("Goodbye!")
        return NodeStatus.SUCCESS

In [None]:
# Farewell branch: classify the latest input, queue "Goodbye!", then speak it.
# Note that speech_node is the same TTS node instance used in speech_seq.
goodbye_test = Sequence("goodbye_test", [
    farewell_classification_node,
    SayGoodbye(), 
    speech_node
])

In [None]:
# Root of the tree: the farewell branch is listed first, then a normal
# conversation turn.
root_node = Fallback("converse", [
    goodbye_test,
    conversation_turn,
])

tree = dendron.BehaviorTree("chat_tree", root_node)

In [None]:
# Initial state: no chat messages and nothing queued for speech.
tree.blackboard["chat_history"] = []
tree.blackboard["speech_in"] = []

# Register the latest human input as a string-typed blackboard entry.
tree.blackboard.register_entry(dendron.blackboard.BlackboardEntryMetadata(
    key = "latest_human_input",
    description = "The last thing the human said.",
    type_constructor = str
))
# NOTE(review): farewell_pretick compares this value against the string
# "None", which suggests the stored None may be read back as a string -
# confirm against the blackboard's type_constructor handling.
tree.blackboard["latest_human_input"] = None

# Inputs for the farewell classifier: the candidate completions and the
# function that maps a completion to a status (presumably read by the
# CompletionCondition node under these keys - verify against dendron).
tree.blackboard["completions_in"] = ["yes", "no"]
tree.blackboard["success_fn"] = farewell_success_fn
# Loop-termination flag; set to True by farewell_posttick.
tree.blackboard["all_done"] = False

In [None]:
# Keep ticking the tree until the farewell classifier flags the conversation
# as finished (blackboard["all_done"] becomes True).
while not tree.blackboard["all_done"]:
    tree.tick_once()