In [None]:
import numpy as np
import heapq
from collections import deque
import matplotlib.pyplot as plt


class Event:
    """A timestamped simulation event that sorts by its firing time."""

    def __init__(self, event_type, time, queue_id=None, job_id=None):
        """Store the event kind, its firing time and the optional queue/job it refers to."""
        self.event_type, self.time = event_type, time
        self.queue_id, self.job_id = queue_id, job_id

    def __lt__(self, other):
        """Return True when this event fires strictly before ``other``.

        Only ``time`` takes part in the comparison, which is all ``heapq``
        needs to pop the earliest event first.
        """
        return other.time > self.time


class Queue:
    """FIFO buffer of ``(job_id, arrival_time)`` pairs plus running statistics."""

    def __init__(self, arrival_rate, ini_jobs):
        """Create an empty queue; ``ini_jobs`` is only stored, not enqueued here."""
        # Static configuration.
        self.arrival_rate = arrival_rate
        self.ini_jobs = ini_jobs
        # Waiting jobs, oldest on the left.
        self.queue = deque()
        # Job counters and accumulated waiting time.
        self.total_jobs = 0
        self.completed_jobs = 0
        self.total_wait_time = 0
        # State of the time-weighted queue-length integral.
        self.last_time = 0
        self.last_queue_length = 0
        self.total_queue_length_time = 0

    def add_job(self, arrival_time):
        """Append a job stamped with ``arrival_time`` and return its queue-local id."""
        new_id = self.total_jobs
        self.total_jobs = new_id + 1
        self.queue.append((new_id, arrival_time))
        return new_id

    def process_job(self, current_time):
        """Remove the oldest job and account for the time it waited.

        Returns ``(job_id, wait_time, arrival_time)``, or ``(None, 0, None)``
        when the queue is empty.
        """
        try:
            job_id, arrival_time = self.queue.popleft()
        except IndexError:
            # Nothing is waiting.
            return None, 0, None
        waited = current_time - arrival_time
        self.total_wait_time += waited
        self.completed_jobs += 1
        return job_id, waited, arrival_time

    def update_queue_stats(self, current_time):
        """Advance the queue-length integral to ``current_time``.

        The previously sampled length is weighted by the elapsed time, then the
        current length is sampled for the next interval.
        """
        elapsed = current_time - self.last_time
        self.total_queue_length_time += elapsed * self.last_queue_length
        self.last_time = current_time
        self.last_queue_length = len(self.queue)


class Server:
    """The single server: its rates, what it is working on and its busy time."""

    def __init__(self, service_rate, switch_time):
        """Start idle at time 0 with no job assigned."""
        self.service_rate = service_rate
        self.switch_time = switch_time

        # One of 'idle', 'serving' or 'switching'.
        self.state = 'idle'

        # Job currently held by the server (all None while idle).
        self.current_queue = None
        self.current_job_id = None
        self.current_job_arrival_time = None

        # Busy-time accounting.
        self.total_busy_time = 0.0
        self.last_state_change_time = 0.0

    def update_utilization(self, current_time):
        """Close the interval since the last update and add it to the busy time if not idle.

        Serving and switching both count as busy. Call this before changing
        ``state`` so the elapsed interval is attributed to the old state.
        """
        is_busy = self.state == 'serving' or self.state == 'switching'
        if is_busy:
            self.total_busy_time += current_time - self.last_state_change_time
        self.last_state_change_time = current_time

class QNetwork:
    """Two-hidden-layer ReLU MLP used to approximate Q(X, a, b).

    The network maps a 1-D feature vector to a scalar and is trained one
    sample at a time with Adam on the loss ``0.5 * (q - target)**2``.
    """

    # Parameter names, in the order used by ``params_vector``.
    _PARAM_NAMES = ("W1", "b1", "W2", "b2", "W3", "b3")

    def __init__(self, input_dim, hidden_dims=(64, 32), seed=0, beta1=0.9, beta2=0.999, epsilon=1e-8):
        """Initialise weights from N(0, 0.1^2), biases at zero and empty Adam state.

        Args:
            input_dim: length of the feature vector.
            hidden_dims: sizes of the two hidden layers.
            seed: seed of the private random generator used for the weights.
            beta1, beta2, epsilon: Adam hyper-parameters.
        """
        self.rng = np.random.default_rng(seed)
        self.W1 = self.rng.normal(scale=0.1, size=(input_dim, hidden_dims[0]))
        self.b1 = np.zeros(hidden_dims[0])
        self.W2 = self.rng.normal(scale=0.1, size=(hidden_dims[0], hidden_dims[1]))
        self.b2 = np.zeros(hidden_dims[1])
        self.W3 = self.rng.normal(scale=0.1, size=(hidden_dims[1], 1))
        self.b3 = np.zeros(1)

        # Adam parameters
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon

        # First (m_*) and second (v_*) moment estimates, one pair per parameter.
        for name in self._PARAM_NAMES:
            param = getattr(self, name)
            setattr(self, f"m_{name}", np.zeros_like(param))
            setattr(self, f"v_{name}", np.zeros_like(param))

        # Time step counter for bias correction
        self.t = 0

    def forward(self, x):
        """Run the network on ``x``.

        Returns:
            ``(out, cache)`` where ``out`` is the scalar prediction as a Python
            float and ``cache`` is ``(x, z1, h1, z2, h2)`` for back-propagation.
        """
        z1 = x @ self.W1 + self.b1
        h1 = np.maximum(0.0, z1)
        z2 = h1 @ self.W2 + self.b2
        h2 = np.maximum(0.0, z2)
        # The output layer yields a size-1 array.  Calling float() on it is
        # deprecated since NumPy 1.25, so extract the scalar explicitly.
        out = float((h2 @ self.W3 + self.b3).item())
        cache = (x, z1, h1, z2, h2)
        return out, cache

    def predict(self, x):
        """Return only the scalar prediction for ``x``."""
        out, _ = self.forward(x)
        return out

    def params_vector(self):
        """Return a fresh 1-D array with all parameters (W1, b1, W2, b2, W3, b3) concatenated."""
        return np.concatenate([getattr(self, name).ravel() for name in self._PARAM_NAMES])

    def train(self, x, target, eta, clip=None):
        """Take one Adam step towards ``target`` for the single sample ``x``.

        Args:
            x: 1-D feature vector.
            target: scalar regression target.
            eta: Adam step size.
            clip: optional non-negative bound; when given, every element of
                each parameter update is limited to ``[-clip, clip]``.
                ``None`` (the default) applies the raw Adam step.

        Returns:
            ``(delta, q_pred)`` where ``q_pred`` is the prediction *before* the
            update and ``delta = target - q_pred``.
        """
        self.t += 1  # Increment time step
        q_pred, (x, z1, h1, z2, h2) = self.forward(x)
        error = q_pred - target  # dL/dq for 0.5*(q-target)^2

        # Back-propagate through the two ReLU layers.  All gradients are
        # computed before any parameter is modified.
        grad_z2 = (self.W3.ravel() * error) * (z2 > 0)
        grad_z1 = (grad_z2 @ self.W2.T) * (z1 > 0)
        grads = {
            "W1": np.outer(x, grad_z1),
            "b1": grad_z1,
            "W2": np.outer(h1, grad_z2),
            "b2": grad_z2,
            "W3": np.outer(h2, error),
            "b3": np.array([error]),
        }

        # Bias-correction denominators are shared by every parameter.
        m_correction = 1.0 - self.beta1 ** self.t
        v_correction = 1.0 - self.beta2 ** self.t

        for name, grad in grads.items():
            m = self.beta1 * getattr(self, f"m_{name}") + (1 - self.beta1) * grad
            v = self.beta2 * getattr(self, f"v_{name}") + (1 - self.beta2) * grad ** 2
            setattr(self, f"m_{name}", m)
            setattr(self, f"v_{name}", v)

            step = eta * (m / m_correction) / (np.sqrt(v / v_correction) + self.epsilon)
            if clip is not None:
                step = np.clip(step, -clip, clip)

            # In-place update keeps the identity of the parameter arrays.
            param = getattr(self, name)
            param -= step

        delta = target - q_pred
        return delta, q_pred


class PollingSystem:
    """Discrete-event simulation of a single-server polling system with an attack/defence game.

    Jobs arrive at several queues with exponential inter-arrival times and the
    server always takes the head of the longest queue.  After every
    service-completion event an attacker (action ``a``) and a defender
    (action ``b``) each pick 0 or 1; an undefended attack while the server is
    serving preempts the running job and forces a switch-over.  A neural
    Q-function is trained online on a minimax temporal-difference target.
    """

    def __init__(self, queue_nums, arrival_rates, ini_jobs_list,
                 service_rate, switch_time, simulation_time):
        """Build the queues, the server and the Q-network, then schedule the first arrivals.

        Args:
            queue_nums: number of queues.
            arrival_rates: per-queue arrival rate (a rate <= 0 disables arrivals).
            ini_jobs_list: per-queue number of jobs present at time 0.
            service_rate: rate of the exponential service time.
            switch_time: fixed duration of a switch-over.
            simulation_time: time horizon of the simulation.
        """
        self.queue_nums = queue_nums
        self.queues = [Queue(arrival_rates[i], ini_jobs_list[i]) for i in range(queue_nums)]
        self.server = Server(service_rate, switch_time)
        self.event_queue = []
        self.current_time = 0.0
        self.simulation_time = simulation_time
        self.job_counter = 0

        # Bookkeeping for the job in service: the wait time already credited
        # to its queue (so a preemption can undo it) and the time of the one
        # completion event that is still valid.
        self._current_job_wait_time = 0.0
        self._pending_completion_time = None

        # === AMQ / RL hyper-parameters ===
        self.gamma = 0.9
        self.eta0 = 0.01
        self.step_counter = 0
        self.cost_attack = 8.0
        self.cost_defend = 6.0
        self.switch_cost_val = 1.0
        self.delta_clip = 50.0
        self.step_clip = 1e3
        self.train_interval = 1

        # === nonlinear Q approximator ===
        # Each queue contributes 5 features: [1, x_i + d_i, (x_i + d_i)^2, a, b]
        self.input_dim = 5 * queue_nums
        self.q_net = QNetwork(self.input_dim, hidden_dims=(128, 64), seed=42)

        # NOTE: every rl_history record keeps its own copy of the full
        # parameter vector, so memory grows with (updates x parameters).
        self.rl_history = []
        self.queue_length_records = []

        self.schedule_initial_events()

    def schedule_initial_events(self):
        """Seed the initial backlog, snapshot it and schedule one arrival per queue."""
        for q in self.queues:
            for _ in range(q.ini_jobs):
                q.add_job(self.current_time)
                self.job_counter += 1
            # Make the initial backlog count towards the time-averaged length.
            q.update_queue_stats(self.current_time)
        # Snapshot after seeding so the first record shows the real initial state.
        self.record_queue_lengths()
        for queue_id in range(len(self.queues)):
            self.schedule_arrival(queue_id)

    def record_queue_lengths(self):
        """Append ``(current_time, [queue lengths])`` to ``queue_length_records``."""
        current_lengths = [len(q.queue) for q in self.queues]
        self.queue_length_records.append((self.current_time, current_lengths))

    def schedule_event(self, event):
        """Push ``event`` onto the time-ordered event heap."""
        heapq.heappush(self.event_queue, event)

    def schedule_arrival(self, queue_id):
        """Schedule the next arrival for ``queue_id`` if it falls within the horizon."""
        rate = self.queues[queue_id].arrival_rate
        if rate <= 0:
            return
        inter_arrival = np.random.exponential(1.0 / rate)
        event_time = self.current_time + inter_arrival
        if event_time <= self.simulation_time:
            self.schedule_event(Event('arrival', event_time, queue_id))

    def schedule_service_completion(self, queue_id):
        """Draw a service time for the current job and schedule its completion."""
        service_time = np.random.exponential(1.0 / self.server.service_rate)
        event_time = self.current_time + service_time
        # Remember which completion event is live; any other one is stale.
        self._pending_completion_time = event_time
        event = Event('service_completion', event_time, queue_id, self.server.current_job_id)
        self.schedule_event(event)

    def schedule_switch_completion(self, queue_id):
        """Schedule the end of a switch-over towards ``queue_id``."""
        event_time = self.current_time + self.server.switch_time
        self.schedule_event(Event('switch_completion', event_time, queue_id))

    def get_longest_queue(self):
        """Return the index of the longest non-empty queue (lowest index on ties), else None."""
        lengths = [len(queue.queue) for queue in self.queues]
        if not lengths:
            return None
        longest = max(lengths)
        return lengths.index(longest) if longest > 0 else None

    def get_shortest_queue(self):
        """Return the index of the shortest queue (lowest index on ties), or None without queues."""
        lengths = [len(queue.queue) for queue in self.queues]
        if not lengths:
            return None
        return lengths.index(min(lengths))

    def compute_reward(self, X, a, b, ns=1.0):
        """Return Jain's fairness index of ``X`` minus switch and attack cost plus defence term.

        The fairness term is ``(sum x)^2 / (n * sum x^2)`` and is 0 when all
        queues are empty.
        """
        n = len(X)
        sx = float(sum(X))
        sxx = float(sum(x * x for x in X))
        if sxx > 0.0 and n > 0:
            fairness = (sx * sx) / (n * sxx)
        else:
            fairness = 0.0

        switch_cost = self.switch_cost_val * float(ns)
        cost_a = self.cost_attack * float(a)
        cost_b = self.cost_defend * float(b)

        return fairness - switch_cost - cost_a + cost_b

    def features(self, X, a, b):
        """Build the feature vector ``[1, x_i + d_i, (x_i + d_i)^2, a, b]`` for every queue.

        ``d_i`` is 1 for the longest queue(s) when ``(a, b) == (1, 0)`` and for
        the shortest queue(s) otherwise; it is 0 for all other queues.
        """
        n = self.queue_nums
        xi = np.asarray(X, dtype=float)
        if n == 0:
            return np.zeros(0, dtype=float)

        undefended_attack = (a == 1 and b == 0)
        bumped_length = np.max(xi) if undefended_attack else np.min(xi)

        feats = []
        for i in range(n):
            val = xi[i] + (1.0 if xi[i] == bumped_length else 0.0)
            feats.extend([
                1.0,
                val,
                val * val,
                float(a),
                float(b),
            ])
        return np.array(feats, dtype=float)

    def q_value(self, X, a, b):
        """Return the network's Q estimate for state ``X`` and actions ``(a, b)``."""
        phi = self.features(X, a, b)
        return float(self.q_net.predict(phi))

    def solve_theta_minimax(self, Q_next):
        """Grid-search the defender's mixed strategy minimising the attacker's best response.

        ``Q_next`` maps ``(a, b)`` to a Q-value.  Returns ``(c_star, theta)``
        where ``theta[b]`` is the probability of defender action ``b`` and
        ``c_star`` the resulting min-max value over a 201-point grid.
        """
        t = np.linspace(0.0, 1.0, 201)
        f0 = (1.0 - t) * Q_next[(0, 0)] + t * Q_next[(0, 1)]
        f1 = (1.0 - t) * Q_next[(1, 0)] + t * Q_next[(1, 1)]
        worst_case = np.maximum(f0, f1)
        # argmin returns the first minimiser, i.e. the smallest such t.
        k = int(np.argmin(worst_case))
        best_t = float(t[k])
        return float(worst_case[k]), {0: 1.0 - best_t, 1: best_t}

    def run_simulation(self):
        """Process events in time order up to ``simulation_time`` and finalise statistics."""
        while self.event_queue:
            event = heapq.heappop(self.event_queue)
            # Check the horizon before advancing so the clock never overruns it.
            if event.time > self.simulation_time:
                break
            self.current_time = event.time

            if event.event_type == 'arrival':
                self.process_arrival(event.queue_id)
            elif event.event_type == 'service_completion':
                self.process_service_completion(event.queue_id, event.job_id)
                self.maybe_attack_defend()
            elif event.event_type == 'switch_completion':
                self.process_switch_completion(event.queue_id)

        for queue in self.queues:
            queue.update_queue_stats(self.simulation_time)
        self.server.update_utilization(self.simulation_time)
        # Statistics are accumulated up to the horizon, so report against it.
        self.current_time = self.simulation_time

    def sample_actions_from_behavior(self, X):
        """Sample ``(a, b)`` from the behaviour policy driven by the total backlog.

        The attack probability is ``exp(-sum(X) / 2)``; the defence probability
        is its complement, or 0.5 when the system is empty.
        """
        l1 = float(sum(X))
        p_attack = np.exp(-l1 / 2.0)
        a = 1 if np.random.rand() < p_attack else 0

        if l1 != 0.0:
            p_defend = 1.0 - np.exp(-l1 / 2.0)
        else:
            p_defend = 0.5
        b = 1 if np.random.rand() < p_defend else 0
        return a, b

    def maybe_attack_defend(self):
        """Sample actions, apply a successful attack and do one minimax TD update."""
        X_k = [len(q.queue) for q in self.queues]
        a, b = self.sample_actions_from_behavior(X_k)
        phi = self.features(X_k, a, b)

        attack_success = (a == 1 and b == 0 and self.server.state == 'serving')

        R_k1 = self.compute_reward(X_k, a, b, ns=1.0)

        if attack_success:
            self.perform_attack()

        X_k1 = [len(q.queue) for q in self.queues]

        Q_next = {(aa, bb): self.q_value(X_k1, aa, bb)
                  for aa in (0, 1) for bb in (0, 1)}
        c_star, theta = self.solve_theta_minimax(Q_next)

        q_curr = float(self.q_net.predict(phi))
        raw_target = R_k1 + self.gamma * c_star
        delta_raw = raw_target - q_curr
        if not np.isfinite(delta_raw):
            delta_raw = 0.0
        # Clip the TD error so a single outlier cannot blow up the target.
        delta_k = float(np.clip(delta_raw, -self.delta_clip, self.delta_clip))
        target = q_curr + delta_k

        self.step_counter += 1
        K = 50000.0
        eta_k = self.eta0 / (1.0 + self.step_counter / K)

        if (self.step_counter % self.train_interval) == 0:
            _, q_before = self.q_net.train(phi, target, eta=eta_k, clip=self.step_clip)
            q_after = self.q_net.predict(phi)
        else:
            q_before = q_after = q_curr

        self.rl_history.append({
            "time": self.current_time,
            "X_k": X_k,
            "a": a,
            "b": b,
            "X_k1": X_k1,
            "R_k1": R_k1,
            "delta_raw": float(delta_raw),
            "delta_clipped": float(delta_k),
            "c_star": float(c_star),
            "theta": theta,
            "eta": eta_k,
            "q_before": float(q_before),
            "q_after": float(q_after),
            "params": self.q_net.params_vector(),
        })

    def perform_attack(self):
        """Preempt the running job and send the server on a switch-over to the shortest queue.

        Nothing happens when the shortest queue is the one being served.  The
        preempted job returns to the head of its queue.
        """
        target_queue = self.get_shortest_queue()
        if target_queue is None or target_queue == self.server.current_queue:
            return
        server = self.server
        if server.current_job_id is not None:
            source = self.queues[server.current_queue]
            source.queue.appendleft(
                (server.current_job_id, server.current_job_arrival_time)
            )
            # The job was counted when it was dequeued; undo that so it is
            # not counted twice when it is dequeued again.
            source.completed_jobs -= 1
            source.total_wait_time -= self._current_job_wait_time
            source.update_queue_stats(self.current_time)
        server.update_utilization(self.current_time)
        server.state = 'switching'
        server.last_state_change_time = self.current_time
        server.current_job_id = None
        server.current_job_arrival_time = None
        # The completion event of the preempted job is now stale.
        self._pending_completion_time = None
        self.schedule_switch_completion(target_queue)

    def process_arrival(self, queue_id):
        """Enqueue a new job, schedule the next arrival and wake an idle server."""
        queue = self.queues[queue_id]
        queue.add_job(self.current_time)
        # Sample the length after the change so the next interval is weighted correctly.
        queue.update_queue_stats(self.current_time)
        self.job_counter += 1
        self.schedule_arrival(queue_id)
        if self.server.state == 'idle':
            self.start_next_service()

    def process_service_completion(self, queue_id, job_id):
        """Finish the job in service and start the next one.

        Events that do not match the job currently in service (left over from
        a preempted job) are ignored.  The finished job was already removed
        from its queue when service started, so no further job is dequeued
        here.
        """
        server = self.server
        if (server.state != 'serving'
                or server.current_job_id != job_id
                or server.current_queue != queue_id
                or self.current_time != self._pending_completion_time):
            return
        server.update_utilization(self.current_time)
        server.state = 'idle'
        server.current_job_id = None
        server.current_job_arrival_time = None
        self._pending_completion_time = None
        self.start_next_service()

    def process_switch_completion(self, queue_id):
        """End a switch-over: the server becomes idle at ``queue_id`` and looks for work."""
        self.server.update_utilization(self.current_time)
        self.server.state = 'idle'
        self.server.current_queue = queue_id
        self.start_next_service()

    def start_next_service(self):
        """Begin serving the head of the longest queue, or go idle if all are empty."""
        next_queue = self.get_longest_queue()
        job_id = None
        if next_queue is not None:
            queue = self.queues[next_queue]
            job_id, wait_time, arrival_time = queue.process_job(self.current_time)
        if job_id is None:
            self.server.state = 'idle'
            self.server.current_queue = None
            self.server.current_job_id = None
            self.server.current_job_arrival_time = None
            return
        queue.update_queue_stats(self.current_time)
        self._current_job_wait_time = wait_time
        self.server.update_utilization(self.current_time)
        self.server.state = 'serving'
        self.server.current_queue = next_queue
        self.server.current_job_id = job_id
        self.server.current_job_arrival_time = arrival_time
        self.schedule_service_completion(next_queue)

    def print_statistics(self):
        """Print per-queue job counts, average wait and length, and server utilisation."""
        print("Simulation finished at time:", self.current_time)
        total_completed = sum(q.completed_jobs for q in self.queues)
        print("Total completed jobs:", total_completed)
        for i, queue in enumerate(self.queues):
            avg_wait = (queue.total_wait_time / queue.completed_jobs) if queue.completed_jobs > 0 else 0
            avg_queue_length = (queue.total_queue_length_time / self.current_time) if self.current_time > 0 else 0
            print(f"Queue {i}:")
            print("  Completed jobs:", queue.completed_jobs)
            print("  Average wait time:", avg_wait)
            print("  Average queue length:", avg_queue_length)
        util = (self.server.total_busy_time / self.current_time) if self.current_time > 0 else 0
        print("Server utilization:", util)

    def plot_w_evolution(self, max_dims=10, max_steps=10000):
        """Plot how the Q-network parameters approach their last recorded value.

        Shows the absolute difference between each recorded parameter vector
        and the last one in the window.  Only the first ``max_steps`` updates
        are used (all of them if fewer were recorded).
        """
        if not self.rl_history:
            print("No RL updates recorded.")
            return

        total_steps = len(self.rl_history)
        use_steps = min(max_steps, total_steps)

        # 0, 1, ..., use_steps-1
        steps = np.arange(use_steps)
        param_mat = np.array([self.rl_history[i]["params"] for i in range(use_steps)])

        # Absolute difference to the last parameter vector in the window,
        # which is used as the reference ("converged") value.
        final_params = param_mat[-1]
        diff_param_mat = np.abs(param_mat - final_params)

        # Figure 1: Euclidean norm of the difference.
        norms = np.linalg.norm(diff_param_mat, axis=1)
        plt.figure(figsize=(10, 5))
        plt.plot(steps, norms)
        plt.xlabel("Update step")
        plt.ylabel("||theta - theta_final||_2")
        plt.title(f"Norm of Q-network parameters difference (first {use_steps} steps)")
        plt.grid(True)
        plt.tight_layout()
        plt.show()

        # Figure 2: difference of the first few components.
        d = diff_param_mat.shape[1]
        dims = min(max_dims, d)
        plt.figure(figsize=(10, 6))
        for i in range(dims):
            plt.plot(steps, diff_param_mat[:, i], label=f"theta[{i}] difference")
        plt.xlabel("Update step")
        plt.ylabel("Difference of theta components")
        plt.title(f"First {dims} components of parameters difference (first {use_steps} steps)")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()



if __name__ == "__main__":
    # Scenario: three queues with equal arrival rates and an unbalanced
    # initial backlog, one server with a fixed switch-over time.
    queue_nums = 3
    arrival_rates = [5.0, 5.0, 5.0]
    ini_jobs_list = [0, 20, 18]
    service_rate = 25.0
    switch_time = 2.0
    simulation_time = 10000.0

    system = PollingSystem(
        queue_nums,
        arrival_rates,
        ini_jobs_list,
        service_rate,
        switch_time,
        simulation_time,
    )

    # Run the simulation (the Q-network is trained on the fly) and report.
    system.run_simulation()
    system.print_statistics()

    # Inspect the learned Q-function at the final state.
    final_state = [len(q.queue) for q in system.queues]
    print("\nFinal Q-network parameter norm:")
    print(np.linalg.norm(system.q_net.params_vector()))
    print("Final state (queue lengths):", final_state)
    for a, b in [(0, 0), (0, 1), (1, 0), (1, 1)]:
        q_val = system.q_value(final_state, a, b)
        print(f"Q(X_final, a={a}, b={b}) = {q_val:.4f}")

    # Show the tail of the training log, if any update was recorded.
    if system.rl_history:
        print("\nLast 5 updates:")
        for rec in system.rl_history[-5:]:
            print(f"time={rec['time']:.2f}, delta={rec['delta_clipped']}, q_before={rec['q_before']}, q_after={rec['q_after']}")
            print("param_norm:", np.linalg.norm(rec["params"]))
            print("-" * 40)

    system.plot_w_evolution(max_dims=10)

  out = float(h2 @ self.W3 + self.b3)


In [None]:
class PollingSystemWithNN(PollingSystem):
    """Polling system whose attack/defence actions are chosen greedily by a supplied Q-network."""

    def __init__(self, queue_nums, arrival_rates, ini_jobs_list, service_rate, switch_time, simulation_time, q_net):
        """Set up the simulation as in the base class and use ``q_net`` as the Q-function.

        ``q_net`` is an externally created network; it replaces the one built
        by the base-class constructor.
        """
        super().__init__(queue_nums, arrival_rates, ini_jobs_list, service_rate, switch_time, simulation_time)
        self.q_net = q_net  # use the network passed in by the caller

    def run_simulation_with_nn(self):
        """Run the event loop with the network choosing the attack/defence actions.

        Training continues during the run and every update is appended to
        ``rl_history``.
        """
        while self.event_queue:
            event = heapq.heappop(self.event_queue)

            # Stop at the horizon before advancing, so the clock never overruns it.
            if event.time > self.simulation_time:
                break
            self.current_time = event.time

            # Dispatch on the event type.
            if event.event_type == 'arrival':
                self.process_arrival(event.queue_id)
            elif event.event_type == 'service_completion':
                self.process_service_completion(event.queue_id, event.job_id)
                self.maybe_attack_defend_with_nn()
            elif event.event_type == 'switch_completion':
                self.process_switch_completion(event.queue_id)

        # Finalise the time-weighted statistics at the horizon.
        for queue in self.queues:
            queue.update_queue_stats(self.simulation_time)
        self.server.update_utilization(self.simulation_time)
        # Statistics are accumulated up to the horizon, so report against it.
        self.current_time = self.simulation_time

    def maybe_attack_defend_with_nn(self):
        """Choose actions with the network, apply a successful attack and do one TD update.

        The update is recorded in ``rl_history``.
        """
        X_k = [len(q.queue) for q in self.queues]  # current queue lengths are the state
        a, b = self.sample_actions_from_nn(X_k)

        attack_success = (a == 1 and b == 0 and self.server.state == 'serving')

        # The reward is computed on the state before the attack is applied.
        R_k1 = self.compute_reward(X_k, a, b, ns=1.0)

        if attack_success:
            self.perform_attack()

        # Next state X_k1.
        X_k1 = [len(q.queue) for q in self.queues]

        # Minimax value of the next state and the TD target.
        Q_next = {(aa, bb): self.q_value(X_k1, aa, bb) for aa in (0, 1) for bb in (0, 1)}
        c_star, theta = self.solve_theta_minimax(Q_next)

        q_curr = self.q_value(X_k, a, b)
        raw_target = R_k1 + self.gamma * c_star
        delta_raw = raw_target - q_curr
        # Same guard as the base-class update: a NaN/inf TD error must not
        # reach the network.
        if not np.isfinite(delta_raw):
            delta_raw = 0.0
        delta_k = float(np.clip(delta_raw, -self.delta_clip, self.delta_clip))
        target = q_curr + delta_k

        phi = self.features(X_k, a, b)

        # Decaying step size.
        self.step_counter += 1
        K = 50000.0
        eta_k = self.eta0 / (1.0 + self.step_counter / K)

        # Train only every ``train_interval`` steps.
        if (self.step_counter % self.train_interval) == 0:
            _, q_before = self.q_net.train(phi, target, eta=eta_k, clip=self.step_clip)
            q_after = self.q_net.predict(phi)
        else:
            q_before = q_after = q_curr

        # Log the update.
        self.rl_history.append({
            "time": self.current_time,
            "X_k": X_k,
            "a": a,
            "b": b,
            "X_k1": X_k1,
            "R_k1": R_k1,
            "delta_raw": float(delta_raw),
            "delta_clipped": float(delta_k),
            "c_star": float(c_star),
            "theta": theta,
            "eta": eta_k,
            "q_before": float(q_before),
            "q_after": float(q_after),
            "params": self.q_net.params_vector(),
        })

    def sample_actions_from_nn(self, X):
        """Return the joint action ``(a, b)`` with the largest predicted Q-value in state ``X``.

        Ties go to the first pair in the order (0, 0), (0, 1), (1, 0), (1, 1).
        """
        joint_actions = [(a, b) for a in (0, 1) for b in (0, 1)]
        q_values = [self.q_net.predict(self.features(X, a, b)) for a, b in joint_actions]
        return joint_actions[int(np.argmax(q_values))]


if __name__ == "__main__":
    # Reuse the network trained by the previous run.
    q_net = system.q_net

    # Same scenario as before, now driven by the trained network.
    system_with_nn = PollingSystemWithNN(
        3,                  # queue_nums
        [5.0, 5.0, 5.0],    # arrival_rates
        [0, 20, 18],        # ini_jobs_list
        25.0,               # service_rate
        2.0,                # switch_time
        10000.0,            # simulation_time
        q_net,              # externally supplied Q-network
    )

    # Run the simulation and print the results.
    system_with_nn.run_simulation_with_nn()
    system_with_nn.print_statistics()

  out = float(h2 @ self.W3 + self.b3)


Simulation finished at time: 9999.972881225436
Total completed jobs: 149880
Queue 0:
  Completed jobs: 49812
  Average wait time: 0.5575476277031185
  Average queue length: 2.8197302818008203
Queue 1:
  Completed jobs: 50203
  Average wait time: 0.5699569449639634
  Average queue length: 2.8636682920475716
Queue 2:
  Completed jobs: 49865
  Average wait time: 0.5816271182627175
  Average queue length: 2.8750754563258223
Server utilization: 0.49718766254978974


: 