In [4]:
from threading import Thread
import json
import random
import paho.mqtt.client as mqtt
from utils import BROKER, PORT, TOPIC, JOIN_TOPIC, HELP_TOPIC, QUEUE_TOPIC, UPDATE_TOPIC, PROGRESS_TOPIC
from threading import Thread
import json
import datetime

class SessionInstance:
    """
    Session server implementation:
    Handles a single lab session.

    Tracks the connected teachers and student groups, the help queue and the
    per-group progress, and publishes progress/queue changes through its own
    MQTT_Session_Server.
    """

    def __init__(self, id: int, ta_code: int, student_code: int):
        """Store the session id and entry codes and create the session's MQTT wrapper."""
        self.id: int = id
        self.ta_code: int = ta_code
        self.student_code: int = student_code
        self.mqtt_client = MQTT_Session_Server(self, self.id)

        self.queue: list[str] = []  # e.g. ["Arne", "Peder", "Jonny"]; groups that have requested help
        self.teachers: list[str] = []  # Connected teachers
        self.students: list[str] = []  # Connected students
        self.progress: dict = {}  # e.g. {"Arne": 3, "Peder": 1}; connected groups and their progress
        # TODO: students and progress should be the same object, but whatever

    def student_leave(self, group_name: str) -> None:
        """
        Called when a student leaves the session.

        Removes the group from the student list, the progress table and (if
        present) the help queue, publishes the resulting updates, and then
        checks whether the session is now empty. Unknown groups are ignored.
        """
        if group_name not in self.students:
            return

        self.students.remove(group_name)

        # pop() instead of del: a group without a progress entry must not
        # abort the rest of the clean-up.
        self.progress.pop(group_name, None)
        self.mqtt_client.send_progress_update()

        if group_name in self.queue:
            self.queue.remove(group_name)
            self.mqtt_client.send_queue_update()

        self.assert_session_alive()

    def teacher_leave(self, ta_name: str) -> None:
        """Called when a teacher leaves the session. Unknown teachers are ignored."""
        if ta_name not in self.teachers:
            return

        self.teachers.remove(ta_name)
        self.assert_session_alive()

    def assert_session_alive(self) -> None:
        """If there are no teachers or students, report that the session can end."""
        if self.teachers or self.students:
            return

        print("Session is empty and can end.")
        # TODO: Kill the session.
        # NOTE: the former `del self` here only unbound the local name and
        # never destroyed the session, so it was dropped.

    def add_progress(self, group_name: str) -> None:
        """Add group to progress-dict (starting at 1) and publish the update."""
        if group_name in self.progress:
            return

        self.progress[group_name] = 1
        self.mqtt_client.send_progress_update()

    def increment_progress(self, group_name: str, question: int) -> None:
        """
        Increment the progress of a group.

        The stored progress becomes `question + 1`. Groups that are not in
        the progress table are ignored.
        """
        if group_name not in self.progress:
            return

        self.progress[group_name] = question + 1
        self.mqtt_client.send_progress_update()

    def add_queue(self, student_name: str, question: str) -> None:
        """Add student to queue unless already queued. `question` is currently unused."""
        if student_name not in self.queue:
            self.queue.append(student_name)

    def pop_queue(self, student_name: str) -> None:
        """Remove student from queue; does nothing if the student is not queued."""
        # Guarded so that a repeated or stale "provide_help" message cannot
        # raise ValueError inside the MQTT callback.
        if student_name in self.queue:
            self.queue.remove(student_name)

    def __str__(self) -> str:
        return f"Session: {self.id}, TA-code: {self.ta_code}, Student-code: {self.student_code}"
 

class Server:
    """
    Main server implementation:
    Has its own mqtt-client which is always active. This server is only used for creation and joining
    of sessions, after which clients are passed on to SessionInstances.

    `sessions`, `ta_code_list` and `student_code_list` are parallel lists:
    the codes at index i belong to the session at index i.
    """

    # Access codes are drawn from this inclusive range.
    _CODE_MIN = 0
    _CODE_MAX = 99

    def __init__(self):
        self.main_mqtt_client: mqtt.Client = None  # set by the caller before any request is handled
        self.sessions: list[SessionInstance] = []

        # Entry codes:
        self.student_code_list: list[int] = []
        self.ta_code_list: list[int] = []

    def _publish_join(self, payload: dict) -> None:
        """Publish `payload` as indented JSON on the join topic."""
        self.main_mqtt_client.publish(f"{TOPIC}/{JOIN_TOPIC}", json.dumps(payload, indent=4))

    def _reject_join(self, error_message: str, name_key: str, name) -> None:
        """Publish a session_join_failed message addressed to `name` under `name_key`."""
        self._publish_join({"msg": "session_join_failed", "error_message": error_message, name_key: name})

    @staticmethod
    def _parse_code(code):
        """Return `code` as an int, or None if it cannot be converted."""
        try:
            return int(code)
        except (TypeError, ValueError):
            return None

    def create_session(self, ta_name) -> None:
        """Generate codes, create new mqtt client, and send success back to teacher."""
        try:
            ta_code = self.generate_access_code()
            # Exclude the fresh TA code: it is not registered yet, so without
            # this the student code could come out identical to it.
            student_code = self.generate_access_code(exclude=(ta_code,))
        except RuntimeError as error:
            # Nothing has been registered yet, so the parallel lists stay in sync.
            print(f"=====\nWARNING: Could not create session for {ta_name}: {error}\n=====")
            return

        # Add codes to global list.
        self.student_code_list.append(student_code)
        self.ta_code_list.append(ta_code)

        # Create instance of session and add to global list.
        instance = SessionInstance(len(self.sessions), ta_code, student_code)
        instance.teachers.append(ta_name)
        self.sessions.append(instance)

        # Start session mqtt server.
        instance.mqtt_client.start(BROKER, PORT)
        instance.mqtt_client.send_queue_update()
        instance.mqtt_client.send_progress_update()

        # Inform teacher that session has been created.
        self._publish_join({
            "msg": "session_created",
            "session_id": instance.id,
            "ta_code": ta_code,
            "student_code": student_code,
            "ta_name": ta_name,
        })

    def join_session_ta(self, code, ta_name) -> None:
        """Ensure ta code is correct and send success back to client."""
        # If code is wrong (or not a number at all) inform teacher:
        # TODO: Some sort of identification is necessary to avoid race conditions when two users
        # try to join at the same time. For students we have group name, need similar thing here.
        parsed = self._parse_code(code)
        if parsed is None or parsed not in self.ta_code_list:
            shown = code if parsed is None else parsed
            self._reject_join(f"{shown} is an incorrect code.", "ta_name", ta_name)
            return
        code = parsed

        # Retrieve session instance and sanity check
        instance = self.sessions[self.ta_code_list.index(code)]
        assert instance.ta_code == code or instance.student_code == code, f"{code} does not belong to {instance}"

        # If the teacher name is taken, inform teacher.
        if ta_name in instance.teachers:
            self._reject_join(f"{ta_name} is already taken.", "ta_name", ta_name)
            return

        instance.teachers.append(ta_name)

        # Inform teacher that session has been joined and of session_id.
        self._publish_join({
            "msg": "session_joined",
            "session_id": instance.id,
            "ta_code": instance.ta_code,
            "student_code": instance.student_code,
            "ta_name": ta_name,
        })

    def join_session_student(self, code, group_name) -> None:
        """Ensure student code is correct and send success back to client."""
        # If code is wrong (or not a number at all) inform student:
        parsed = self._parse_code(code)
        if parsed is None or parsed not in self.student_code_list:
            shown = code if parsed is None else parsed
            self._reject_join(f"{shown} is an incorrect code.", "group_name", group_name)
            return
        code = parsed

        # Retrieve session instance and sanity check
        instance = self.sessions[self.student_code_list.index(code)]
        assert instance.ta_code == code or instance.student_code == code, f"{code} does not belong to {instance}"

        # If group name is taken, inform students.
        if group_name in instance.students:
            self._reject_join(f"{group_name} is already taken.", "group_name", group_name)
            return

        instance.students.append(group_name)
        instance.add_progress(group_name)

        # Inform student that session has been joined and of session_id.
        self._publish_join({"msg": "session_joined", "session_id": instance.id, "group_name": group_name})

    def generate_access_code(self, exclude=()) -> int:
        """
        Generate random student/ta code.

        Returns a code in [_CODE_MIN, _CODE_MAX] that is neither registered in
        the student/ta code lists nor contained in `exclude`.

        Raises:
            RuntimeError: if every code in the range is already taken
                (previously this case looped forever).
        """
        taken = set(self.student_code_list) | set(self.ta_code_list) | set(exclude)
        available = [code for code in range(self._CODE_MIN, self._CODE_MAX + 1) if code not in taken]
        if not available:
            raise RuntimeError("no free access codes left")
        return random.choice(available)



class MQTT_Session_Server:
    """
    Session server handles the communication within a single lab session. Handles communications
    in the <TOPIC>/<session_id>/# topics.
    """

    def __init__(self, session, id):
        """Create the paho client for `session` and register the callbacks (no connection yet)."""
        self.session: SessionInstance = session
        self.session_id: int = id
        self.count = 0
        self.client = mqtt.Client()
        self.client.on_connect = self.on_connect
        self.client.on_message = self.on_message

    def on_connect(self, client, userdata, flags, rc):
        """Log the broker's connection result."""
        print(f"on_connect(): {mqtt.connack_string(rc)}")

    def on_message(self, client, userdata, msg):
        """
        Decode an incoming message and dispatch it by topic.

        Payloads that are not UTF-8 encoded JSON objects containing a 'msg'
        key are logged and ignored.
        """
        # Decode Json-message and ignore non-json formatted messages.
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        try:
            message = json.loads(msg.payload.decode('utf-8'))
        except ValueError:
            print(f"=====\nWARNING: Received message with incorrect formating:\n{msg.payload}\nIgnoring message...\n=====")
            return

        # Valid JSON may still be a list/number/string, which has no keys.
        if not isinstance(message, dict) or "msg" not in message:
            print(f"=====\nWARNING: Json object does not contain the key 'msg':\n{message}\nIgnoring message...\n=====")
            return

        print(f"Session {self.session_id}: on_message(): topic: {msg.topic}, msg: {message['msg']}")

        if msg.topic == f"{TOPIC}/{self.session_id}/{HELP_TOPIC}":
            self.help_handler(message)
        elif msg.topic == f"{TOPIC}/{self.session_id}/{UPDATE_TOPIC}":
            self.update_handler(message)

    def start(self, broker, port):
        """Connect to the broker, subscribe to all of this session's topics and run the network loop in a thread."""
        print("Connecting to {}:{}".format(broker, port))

        self.client.connect(broker, port)
        self.client.subscribe(f"{TOPIC}/{self.session_id}/#")

        try:
            thread = Thread(target=self.client.loop_forever)
            thread.start()
        except KeyboardInterrupt:
            print("Interrupted")
            self.client.disconnect()

    def send_queue_update(self):
        """
        Should be called every time the queue is changed to update subscribers about
        the new queue. This message should be retained.
        """
        queue_update = {"msg": "queue_update", "queue": self.session.queue}
        self.client.publish(f"{TOPIC}/{self.session_id}/{QUEUE_TOPIC}", json.dumps(queue_update, indent=4), retain=True)

    def send_progress_update(self):
        """
        Should be called every time the progress of each group is updated. This message
        should be retained.
        """
        progress_update = {"msg": "progress_update", "progress": self.session.progress}
        self.client.publish(f"{TOPIC}/{self.session_id}/{PROGRESS_TOPIC}", json.dumps(progress_update, indent=4), retain=True)

    @staticmethod
    def _warn_malformed(message: dict, reason: str) -> None:
        """Log a message that is ignored because a required field is missing or invalid."""
        print(f"=====\nWARNING: {reason}:\n{message}\nIgnoring message...\n=====")

    def help_handler(self, message: dict):
        """
        Handle messages in the help-topic.

        'request_help' queues the group, 'provide_help' removes it from the
        queue; both publish the new queue. Messages without a 'group_name'
        are logged and ignored.
        """
        print(message)
        kind = message["msg"]
        if kind not in ("request_help", "provide_help"):
            return

        if "group_name" not in message:
            self._warn_malformed(message, "Json object does not contain the key 'group_name'")
            return

        if kind == "request_help":
            # The question text is optional here; the queue only stores the group name.
            self.session.add_queue(message["group_name"], message.get("question"))
            self.send_queue_update()
        else:
            self.session.pop_queue(message["group_name"])
            self.send_queue_update()
            # TODO: Do we want to inform the group that they are about to receive help?

    def update_handler(self, message: dict):
        """
        Handle message in the update-topic.

        'task_finished' updates the group's progress and appends a line to
        ./task_stats/<session_id>_task_stats.txt; 'leave_session' removes the
        sending group or teacher from the session.
        """
        if message["msg"] == "task_finished":
            try:
                group_name = message["group_name"]
                question = int(message["question"])
            except (KeyError, TypeError, ValueError):
                self._warn_malformed(message, "task_finished needs 'group_name' and an integer 'question'")
                return

            self.session.increment_progress(group_name, question)

            # Make sure the stats directory exists so the append cannot fail
            # with FileNotFoundError on a fresh checkout.
            import os
            os.makedirs("./task_stats", exist_ok=True)
            with open(f"./task_stats/{self.session_id}_task_stats.txt", "a") as stats_file:
                stats_file.write(f"{message['msg']}, {message['question']}, {message['group_name']}, {datetime.datetime.now()}\n")

        if message["msg"] == "leave_session":
            # A group_name marks a student request, otherwise a ta_name marks a teacher.
            if "group_name" in message:
                self.session.student_leave(message["group_name"])
            elif "ta_name" in message:
                self.session.teacher_leave(message["ta_name"])
                
            

class MQTT_Main_Server:
    """
    The main server handles the creation of sessions and allows clients to join sessions.
    For that purpose it is only subscribed to the JOIN_TOPIC, to receive and send messages.
    """

    def __init__(self, server):
        """Create the paho client for `server` and register the callbacks (no connection yet)."""
        self.server: Server = server
        self.count = 0
        self.client = mqtt.Client()
        self.client.on_connect = self.on_connect
        self.client.on_message = self.on_message

    def on_connect(self, client, userdata, flags, rc):
        """Log the broker's connection result."""
        print(f"on_connect(): {mqtt.connack_string(rc)}")

    @staticmethod
    def _has_fields(message: dict, *keys: str) -> bool:
        """Return True if all `keys` are in `message`; otherwise log the first missing one."""
        for key in keys:
            if key not in message:
                print(f"=====\nWARNING: Json object does not contain the key '{key}':\n{message}\nIgnoring message...\n=====")
                return False
        return True

    def on_message(self, client, userdata, msg):
        """
        Decode an incoming message and forward create/join requests to the server.

        Payloads that are not UTF-8 encoded JSON objects containing a 'msg'
        key, and requests lacking a required field, are logged and ignored.
        """
        # Decode Json-message and ignore non-json formatted messages.
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        try:
            message: dict = json.loads(msg.payload.decode('utf-8'))
        except ValueError:
            print(f"=====\nWARNING: Received message with incorrect formating:\n{msg.payload}\nIgnoring message...\n=====")
            return

        # Valid JSON may still be a list/number/string, which has no keys.
        if not isinstance(message, dict) or "msg" not in message:
            print(f"=====\nWARNING: Json object does not contain the key 'msg':\n{message}\nIgnoring message...\n=====")
            return

        print(f"MAIN: on_message(): topic: {msg.topic}, msg: {message['msg']}")

        # Determine purpose of message and handle appropriately.
        if message["msg"] == "create_session":
            if self._has_fields(message, "ta_name"):
                self.server.create_session(message["ta_name"])
        elif message["msg"] == "join_session":
            # If group_name is part of message, then the request stems from a student, otherwise a TA.
            if "group_name" in message:
                if self._has_fields(message, "student_code"):
                    self.server.join_session_student(message["student_code"], message["group_name"])
            elif "ta_name" in message:
                if self._has_fields(message, "ta_code"):
                    self.server.join_session_ta(message["ta_code"], message["ta_name"])

    def start(self, broker, port):
        """Connect to the broker, subscribe to the join topic and run the network loop in a thread."""
        print("Connecting to {}:{}".format(broker, port))
        self.client.connect(broker, port)
        self.client.subscribe(f"{TOPIC}/{JOIN_TOPIC}")

        try:
            thread = Thread(target=self.client.loop_forever)
            thread.start()
        except KeyboardInterrupt:
            print("Interrupted")
            self.client.disconnect()
    


In [None]:


# Create the main server and the MQTT client that feeds it join/create requests.
server = Server()
myclient = MQTT_Main_Server(server)
# The server publishes its replies through the main client's underlying paho client.
server.main_mqtt_client = myclient.client

# Connect to the broker, subscribe to the join topic and run the network loop in a background thread.
myclient.start(BROKER, PORT)



In [None]:

# Debug overview of every session created so far (reads the notebook-global `server`).
print(f"There are {len(server.sessions)} session currently running:")

for session in server.sessions:
    print(f"\nSession {session.id}:")
    print(f"TA-code: {session.ta_code} | Student-code: {session.student_code}")
    print(f"Teachers: {session.teachers}")
    print(f"Students: {session.students}")
    # This line shows the progress dict; it was previously mislabelled "Students:".
    print(f"Progress: {session.progress}")

