# In [None]:
from threading import Thread
import json
from stmpy import Machine, Driver
import paho.mqtt.client as mqtt
from utils import BROKER, PORT, TOPIC, QUEUE_TOPIC, HELP_TOPIC
from stmpy import Driver, Machine
from threading import Thread
import paho.mqtt.client as mqtt
import ipywidgets as widgets
from IPython.display import display
import random


class MQTT_Client_1:
    """MQTT front end for a ``Teacher``: forwards server replies to its state machine.

    The client starts out subscribed to the base ``TOPIC``. After a session has been
    created or joined it leaves ``TOPIC`` and subscribes to that session's queue and
    help topics instead.
    """

    def __init__(self, teacher):
        self.teacher: Teacher = teacher
        self.count = 0
        self.client = mqtt.Client()
        self.client.on_connect = self.on_connect
        self.client.on_message = self.on_message
        # Must be assigned by the caller before messages arrive; on_message sends
        # state-machine triggers through it.
        self.stm_driver: Driver = None

    def on_connect(self, client, userdata, flags, rc):
        """Called upon connecting"""
        print(f"on_connect(): {mqtt.connack_string(rc)}")

    def on_message(self, client, userdata, msg):
        """Handle an incoming MQTT message (paho ``on_message`` callback).

        Malformed payloads are reported and ignored instead of raising. An exception
        escaping this callback would propagate out of paho's network loop and stop it.
        """
        message = self._decode_payload(msg.payload)
        if message is None:
            return

        print(f"on_message(): topic: {msg.topic}, msg: {message['msg']}")

        # Determine purpose of message and handle appropriately.
        kind = message["msg"]
        if kind in ("session_created", "session_joined"):
            print(f"Received '{kind}'")
            self._enter_session(message)

        elif kind == "session_join_failed":
            # Still fall back to idle when the server omits the reason.
            self.teacher.error = message.get("error_message", "Unknown error")
            self.stm_driver.send("session_join_failed", "teacher")

    @staticmethod
    def _decode_payload(payload):
        """Return *payload* decoded as a JSON object with a 'msg' key, or None after a warning."""
        try:
            message = json.loads(payload.decode('utf-8'))
        except (UnicodeDecodeError, json.decoder.JSONDecodeError):
            print(f"=====\nWARNING: Received message with incorrect formating:\n{payload}\nIgnoring message...\n=====")
            return None

        # Valid JSON may still be a list/number/string; only objects are handled.
        if not isinstance(message, dict) or "msg" not in message:
            print(f"=====\nWARNING: Json object does not contain the key 'msg':\n{message}\nIgnoring message...\n=====")
            return None
        return message

    def _enter_session(self, message):
        """Store the session codes from *message*, switch topics and notify the state machine."""
        # Read every field first so a malformed reply leaves the teacher untouched.
        try:
            session_id = message["session_id"]
            ta_code = message["ta_code"]
            student_code = message["student_code"]
        except KeyError as err:
            print(f"=====\nWARNING: Session message is missing the key {err}:\n{message}\nIgnoring message...\n=====")
            return

        self.teacher.session_id = session_id
        self.teacher.ta_code = ta_code
        self.teacher.student_code = student_code

        # Subscribe to new topics:
        self.subscribe_session_topics()

        # Both 'session_created' and 'session_joined' fire the same trigger; the
        # state machine only defines a 'session_created' transition out of 'wait'.
        # TODO: Determine if there are any differences between session_joined and session_created.
        self.stm_driver.send("session_created", "teacher")

    def subscribe_session_topics(self) -> None:
        """Leave the base TOPIC and subscribe to the current session's queue and help topics."""
        self.client.unsubscribe(TOPIC)

        self.client.subscribe(f"{TOPIC}/{self.teacher.session_id}/{QUEUE_TOPIC}")
        self.client.subscribe(f"{TOPIC}/{self.teacher.session_id}/{HELP_TOPIC}")

    def start(self, broker, port):
        """Connect to *broker*:*port*, subscribe to TOPIC and run the network loop.

        ``loop_forever`` blocks, so it runs on a separate (non-daemon) thread and this
        method returns immediately. A KeyboardInterrupt is never raised inside this
        call, so there is nothing to catch here.
        """
        print("Connecting to {}:{}".format(broker, port))
        self.client.connect(broker, port)
        self.client.subscribe(TOPIC)

        thread = Thread(target=self.client.loop_forever)
        thread.start()


class Teacher:
    """Teacher-side model object driven by the 'teacher' stmpy state machine.

    The ``in_*`` methods are state entry actions that render ipywidgets controls.
    The controls' click handlers publish requests over MQTT and send triggers back
    to the state machine.
    """

    def __init__(self):
        # After creating/joining a session, these should be set.
        self.session_id: int = None
        self.ta_code: int = None
        self.student_code: int = None

        # Assigned by the setup code after construction.
        self.mqtt_client: mqtt.Client = None
        self.stm: Machine = None

        # Error text from the last failed join; displayed by in_idle.
        self.error = None

    def __str__(self) -> str:
        return f"Session: {self.session_id}, TA-code: {self.ta_code}, Student-code: {self.student_code}"

    def create_session(self, b):
        """Ask the server to create a new session.

        Args:
            b: the clicked Button (passed by ``on_click``; unused).
        """
        # Queue 'start_session' BEFORE publishing (same order as join_session).
        # Otherwise a quick 'session_created' reply could be processed while the
        # machine is still in 'idle', where no transition handles it, and the
        # machine would then sit in 'wait' forever.
        self.stm.send("start_session")

        create_session = {"msg": "create_session"}
        self.mqtt_client.publish(TOPIC, json.dumps(create_session, indent=4))

    def join_session(self, b):
        """Join an already existing session using the TA code typed into the idle view.

        Args:
            b: the clicked Button (unused).
        """
        self.stm.send("start_session")

        # ta_code_field is created by in_idle, which also creates this button.
        join_session = {"msg": "join_session", "ta_code": self.ta_code_field.value}
        self.mqtt_client.publish(TOPIC, json.dumps(join_session, indent=4))

    def in_idle(self):
        """Entry action for 'idle': show the create/join controls and the last error."""
        print(f"In idle-state: {self}")
        self.button_create = widgets.Button(description="Create Session")
        self.button_create.on_click(self.create_session)
        self.button_join = widgets.Button(description="Join Session")
        self.button_join.on_click(self.join_session)
        self.ta_code_field = widgets.Text(value='', placeholder='', description='TA code:', disabled=False)

        error_field = widgets.Text(value=self.error, placeholder='', description='Error code:', disabled=True)

        display(self.button_create, self.ta_code_field, self.button_join, error_field)

    def in_wait(self):
        """Entry action for 'wait' (awaiting the server's reply); only logs."""
        print(f"In wait-state: {self}")

    def in_lab_session_active(self):
        """Entry action for 'lab_session_active': show the Help and End buttons."""
        self.button_help = widgets.Button(description="Help")
        self.button_help.on_click(self.help)
        self.button_end = widgets.Button(description="End lab session")
        self.button_end.on_click(self.end_session)
        display(self.button_help, self.button_end)

        print(f"In lab_sesssion_active-state: {self}")

    def in_help(self):
        """Entry action for 'help': show the Finished Helping button."""
        self.button_help = widgets.Button(description="Finished Helping")
        self.button_help.on_click(self.finish_help)
        display(self.button_help)

        print(f"In help-state: {self}")

    def help(self, b):
        """Called when Help-button is pressed in lab_session_active.

        Args:
            b: the clicked Button (unused).
        """
        # Inform server that group is receiving help.
        # NOTE(review): unlike create/join, this publishes plain text (not JSON) to
        # the base TOPIC rather than a session topic; confirm the server expects it.
        self.mqtt_client.publish(TOPIC, "helping")
        self.stm.send("help_group")

    def finish_help(self, b):
        """Called when Finished Help-button is pressed in help.

        Args:
            b: the clicked Button (unused).
        """
        # Inform server that group has been helped successfully.
        # NOTE(review): plain-text payload on the base TOPIC, as in help(); confirm.
        self.mqtt_client.publish(TOPIC, "finished helping")
        self.stm.send("finish_help")

    def end_session(self, b):
        """Called when End Session-button is pressed in lab_session_active.

        Args:
            b: the clicked Button (unused).
        """
        # Only the state machine is notified; nothing is published to the server.
        self.stm.send("end_lab")
        

# ---- Teacher state machine: transitions ----------------------------------

# The machine begins in 'idle'.
t0 = dict(source='initial', target='idle')

# Session setup: request sent -> wait for the server -> active or back to idle.
t1 = dict(trigger='start_session', source='idle', target='wait')
t2 = dict(trigger='session_created', source='wait', target='lab_session_active')
t3 = dict(trigger='session_join_failed', source='wait', target='idle')

# Helping a group and returning to the active session.
t4 = dict(trigger='help_group', source='lab_session_active', target='help')
t5 = dict(trigger='finish_help', source='help', target='lab_session_active')

# Self-transitions on the active session.
t6 = dict(trigger='task_done', source='lab_session_active', target='lab_session_active')
t7 = dict(trigger='update_queue', source='lab_session_active', target='lab_session_active')

# Leaving the lab session.
# NOTE(review): 'exit' is not one of the states declared below; confirm stmpy
# treats it as intended (stmpy's terminal pseudo-state is named 'final').
t8 = dict(trigger='end_lab', source='lab_session_active', target='exit')

# ---- Teacher state machine: states (entry actions are Teacher methods) ----

idle = dict(name='idle', entry='in_idle')

wait = dict(name='wait', entry='in_wait')

lab_session_active = dict(name='lab_session_active', entry='in_lab_session_active')

# NOTE: this module-level name shadows the builtin help().
help = dict(name='help', entry='in_help')



# In [None]:
# Build the teacher model and its state machine.
teacher = Teacher()
state_machine = Machine(
    name="teacher",
    obj=teacher,
    states=[idle, wait, lab_session_active, help],
    transitions=[t0, t1, t2, t3, t4, t5, t6, t7, t8],
)
teacher.stm = state_machine

driver = Driver()
driver.add_machine(state_machine)

# Wire the MQTT client to both the model and the driver.
myclient = MQTT_Client_1(teacher)
myclient.stm_driver = driver
teacher.mqtt_client = myclient.client

# Start the state machine first, then connect and start the MQTT loop.
driver.start()
myclient.start(BROKER, PORT)