In [1]:
from ipynb.fs.full.sim import Simulator
from ipynb.fs.full.traffic import Frame
from ipynb.fs.full.logger import Logger
from ipynb.fs.full.topology import Switch, BasicSwitch, IngressPort, EgressPort
from collections import deque
import heapq

In [2]:
class ATSGroup:
    """Token-bucket and group state for one Asynchronous Traffic Shaping group.

    All streams mapped to the same group share this bucket and the group
    eligibility time.

    Args:
        cir_bps: Committed Information Rate in bits per second.
        cbs_bytes: Committed Burst Size in bytes.
        max_residence_time: Maximum residence time for frames of this group,
            in the same time unit as the derived durations (seconds, since
            rates are given in bit/s).
    """

    def __init__(self, cir_bps: float, cbs_bytes: int, max_residence_time: float):
        self.cir = cir_bps      # Committed Information Rate (bit/s)
        self.cbs = cbs_bytes    # Committed Burst Size (bytes)
        self.mrt = max_residence_time
        # Seconds needed to refill the bucket from empty to CBS at rate CIR.
        self.empty_to_full_duration = (cbs_bytes * 8) / cir_bps

        # Time at which the bucket was empty. Placing it exactly one refill
        # period before t=0 makes the bucket start full regardless of the
        # configured CIR/CBS (a fixed -1.0 only worked when the refill
        # period was at most one second).
        self.bucket_empty_time = -self.empty_to_full_duration
        # Eligibility time most recently assigned to a frame of this group.
        self.group_eligibility_time = 0.0

In [3]:
class ATSSwitch(Switch):
    """Switch that shapes traffic with Asynchronous Traffic Shaping (ATS).

    Each arriving frame is mapped to an ATS group and assigned an eligibility
    time by that group's token bucket. It is then placed in a per-egress-port
    queue ordered by eligibility time. An egress port transmits at most one
    frame at a time, and only once the head-of-line frame is eligible.
    """

    def __init__(self, name: str, simulator: Simulator, logger: Logger, routing_table: dict, stream_to_group: dict, ats_groups: dict):
        """
        routing_table: stream_id -> list[egress_port_id]
        stream_to_group: stream_id -> ats_group_id
        ats_groups: ats_group_id -> ATSGroup
        """
        # NOTE(review): self.sim, self.name, self.logger and self.egress_ports
        # are presumably set up by Switch.__init__ / add_egress_port — confirm.
        super().__init__(name, simulator, logger)

        self.routing_table = routing_table
        self.stream_to_group = stream_to_group
        self.ats_groups = ats_groups
        self.queues = {}      # egress_port_id -> heap of (eligibility_time, seq, Frame)
        self.sending = {}     # egress_port_id -> True while a frame is being transmitted
        self._seq = 0         # tie-breaker for identical eligibility times (keeps arrival order)

    def add_egress_port(self, port_id: str, link):
        """Register an egress port and create its (empty) eligibility queue."""
        port = super().add_egress_port(port_id, link)
        self.queues[port_id] = []
        self.sending[port_id] = False
        return port

    def process(self, frame: Frame, ingress_port: IngressPort):
        """Assign an ATS eligibility time to *frame* and enqueue it on each of its egress ports.

        The frame is dropped (and the drop logged) when its stream has no ATS
        group or no route. It is also dropped when its eligibility time would
        exceed arrival time + the group's max residence time. A dropped frame
        leaves the group's token-bucket state untouched.
        """
        now = self.sim.time
        arrival_time = now

        # Group check
        group_id = self.stream_to_group.get(frame.stream_id)
        if group_id is None:
            self.logger.log(now, self.name, "ATS_DROP_NO_GROUP", frame, None)
            return
        group: ATSGroup = self.ats_groups[group_id]

        # Routing
        out_ports = self.routing_table.get(frame.stream_id)
        if not out_ports:
            self.logger.log(now, self.name, "ATS_DROP_NO_ROUTE", frame, None)
            return

        # Compute eligibility (token-bucket algorithm; frame.size is assumed
        # to be in bytes, matching the *8 conversion used for CBS).
        length_recovery_duration = (frame.size * 8) / group.cir
        scheduler_eligibility_time = group.bucket_empty_time + length_recovery_duration
        bucket_full_time = group.bucket_empty_time + group.empty_to_full_duration
        eligibility_time = max(arrival_time, group.group_eligibility_time, scheduler_eligibility_time)

        # MRT check: a frame may not wait in this switch longer than the
        # group's max residence time. Checked before committing any state so
        # that a discarded frame does not consume tokens.
        if eligibility_time > arrival_time + group.mrt:
            self.logger.log(now, self.name, "ATS_DROP_MRT_VIOLATION", frame, None)
            return

        # Commit the new group / bucket state for the accepted frame.
        group.group_eligibility_time = eligibility_time
        if eligibility_time < bucket_full_time:
            group.bucket_empty_time = scheduler_eligibility_time
        else:
            # The bucket sat full from bucket_full_time until eligibility_time;
            # tokens beyond CBS are lost, so shift the empty time forward.
            group.bucket_empty_time = scheduler_eligibility_time + eligibility_time - bucket_full_time

        # Enqueue per egress port (the same frame object is shared by all ports).
        for port_id in out_ports:
            self._seq += 1
            heapq.heappush(self.queues[port_id], (eligibility_time, self._seq, frame))

            if not self.sending[port_id]:
                self._try_send(port_id)

    def _try_send(self, port_id: str):
        """Start transmitting on *port_id* if the port is idle and its head frame is eligible.

        If the head frame is not yet eligible, a retry is scheduled at its
        eligibility time. Calls made while a transmission is in progress do
        nothing; ``_on_tx_done`` retries once the port is idle again.
        """
        # A wake-up scheduled earlier can fire while another frame is still
        # being transmitted. Without this guard it would start an overlapping
        # transmission, or clear ``sending`` mid-transmission. Stale wake-ups
        # are therefore harmless no-ops.
        if self.sending[port_id]:
            return

        queue = self.queues[port_id]
        if not queue:
            return

        eligibility, _, frame = queue[0]
        if self.sim.time < eligibility:
            self.sim.schedule(eligibility - self.sim.time, self._try_send, port_id)
            return

        heapq.heappop(queue)
        self.sending[port_id] = True

        # NOTE(review): send() is assumed to return the absolute simulation
        # time at which transmission completes — confirm in EgressPort.
        finish_time = self.egress_ports[port_id].send(frame)
        self.sim.schedule(finish_time - self.sim.time, self._on_tx_done, port_id)

    def _on_tx_done(self, port_id: str):
        """Mark *port_id* idle after a transmission and try to send the next frame."""
        self.sending[port_id] = False
        self._try_send(port_id)