In [11]:
#dqn.py 

import torch
import torch.nn as nn
import torch.nn.functional as F # Often used for activation functions

class DQNetwork(nn.Module):
    """
    Feed-forward Q-value approximator that stands in for a tabular Q-function.

    The model is a stack of ``Linear -> LayerNorm -> ReLU`` stages (one stage
    per entry of ``hidden_layer_list``) followed by a final ``Linear`` head
    that emits one value per discrete action.

    Args:
        num_devices: Number of devices. Determines the default layer sizes:
            ``4 * num_devices`` inputs and ``3 ** num_devices`` outputs.
        hidden_layer_list: Width of each hidden layer. ``None`` selects
            ``[20, 20]``; an empty list produces a single linear layer.
        _input_features_override: When not ``None``, used as the input width
            instead of ``4 * num_devices``.
        _output_features_override: When not ``None``, used as the output width
            instead of ``3 ** num_devices``.

    Raises:
        ValueError: If a hidden layer width is zero or negative.
    """
    def __init__(self, num_devices: int, hidden_layer_list: list[int] = None,
                 _input_features_override: int = None, # For RiskAverseDQN
                 _output_features_override: int = None # For RiskAverseDQN
                 ):
        super().__init__()
        self.num_devices = num_devices

        # Copy the caller's list so later mutation outside does not leak in.
        self.hidden_layer_list = (
            [20, 20] if hidden_layer_list is None else list(hidden_layer_list)
        )
        self.input_features = (
            4 * self.num_devices
            if _input_features_override is None
            else _input_features_override
        )
        self.output_features = (
            3 ** self.num_devices
            if _output_features_override is None
            else _output_features_override
        )

        stages = []
        fan_in = self.input_features
        for width in self.hidden_layer_list:
            if width <= 0:
                raise ValueError("Number of nodes in a hidden layer must be positive.")
            stages.extend((nn.Linear(fan_in, width), nn.LayerNorm(width), nn.ReLU()))
            fan_in = width

        # Output head: one Q-value per action, no activation.
        stages.append(nn.Linear(fan_in, self.output_features))
        self.network = nn.Sequential(*stages)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """
        Run ``state`` through the network.

        Args:
            state (torch.Tensor): Input whose last dimension must equal
                ``self.input_features``.

        Returns:
            torch.Tensor: Network output; last dimension is
                ``self.output_features``.

        Raises:
            ValueError: If the last dimension of ``state`` does not match
                ``self.input_features``.
        """
        n_features = state.shape[-1]
        if n_features != self.input_features:
            raise ValueError(
                f"Input tensor last dimension ({n_features}) "
                f"does not match expected input features ({self.input_features})."
            )
        return self.network(state)

# --- Example Usage ---
if __name__ == '__main__':
    def _show_forward_pass(model, device_count, batch):
        """Print the model summary, push `batch` through it and print the shapes."""
        print(model)
        print(f"  Input features: {model.input_features} (4 * {device_count})")
        print(f"  Output features: {model.output_features} (3^{device_count})")
        result = model(batch)
        print(f"  Dummy input shape: {batch.shape}")
        print(f"  Output shape: {result.shape}")
        print("-" * 30)
        return result

    # Scenario 1: default hidden layers; batch of 5 states with 4*2 = 8 features
    num_dev = 2
    model1 = DQNetwork(num_devices=num_dev)
    print(f"Model 1 (num_devices={num_dev}, hidden_layers=default):")
    dummy_input1 = torch.randn(5, 4 * num_dev)
    output1 = _show_forward_pass(model1, num_dev, dummy_input1)

    # Scenario 2: custom hidden layers; batch of 10 states
    num_dev = 3
    custom_hidden = [64, 32, 16]
    model2 = DQNetwork(num_devices=num_dev, hidden_layer_list=custom_hidden)
    print(f"Model 2 (num_devices={num_dev}, hidden_layers={custom_hidden}):")
    dummy_input2 = torch.randn(10, 4 * num_dev)
    output2 = _show_forward_pass(model2, num_dev, dummy_input2)

    # Scenario 3: no hidden layers (empty list); batch of 2 states
    num_dev = 1
    model3 = DQNetwork(num_devices=num_dev, hidden_layer_list=[])
    print(f"Model 3 (num_devices={num_dev}, hidden_layers=[]):")
    dummy_input3 = torch.randn(2, 4 * num_dev)
    output3 = _show_forward_pass(model3, num_dev, dummy_input3)

    # Wrong input width: forward() is expected to reject it with ValueError
    try:
        num_dev_test = 2
        model_test = DQNetwork(num_devices=num_dev_test)
        wrong_input = torch.randn(1, 4 * num_dev_test + 1)  # one feature too many
        model_test(wrong_input)
    except ValueError as e:
        print(f"Caught expected error for wrong input: {e}")

Model 1 (num_devices=2, hidden_layers=default):
DQNetwork(
  (network): Sequential(
    (0): Linear(in_features=8, out_features=20, bias=True)
    (1): LayerNorm((20,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=20, out_features=20, bias=True)
    (4): LayerNorm((20,), eps=1e-05, elementwise_affine=True)
    (5): ReLU()
    (6): Linear(in_features=20, out_features=9, bias=True)
  )
)
  Input features: 8 (4 * 2)
  Output features: 9 (3^2)
  Dummy input shape: torch.Size([5, 8])
  Output shape: torch.Size([5, 9])
------------------------------
Model 2 (num_devices=3, hidden_layers=[64, 32, 16]):
DQNetwork(
  (network): Sequential(
    (0): Linear(in_features=12, out_features=64, bias=True)
    (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=32, bias=True)
    (4): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    (5): ReLU()
    (6): Linear(in_features=32, out_features=16, b

In [12]:
#riskaverseqlearning.py

import random
import numpy as np
from collections import defaultdict
import itertools
import random
import math

class RiskAverseQLearning:
    def __init__(self, K, N, M, I, Ts, d, St=1000):
        self.num_devices = K
        self.num_sub6 = N
        self.num_mmWave = M
        self.num_Qtable = I
        self.frame_duration = Ts
        self.packet_size = d
        
        self.cold_start = St
        
        self.exploration_rate = 0.999 # eps
        self.decay_factor = 0.995 # lambda
        self.discount_factor = 0.9 # gamma
        
        self.risk_control = 0.5 # lambda_p
        self.utility_func_param = -0.5 # beta
        
        self.PLR_req = 0.1 # phi_max
        self.Lk = [6] * K
        
        self.known_average_rate = [[self.packet_size / min(1,Ts) * self.Lk[_],self.packet_size / min(1,Ts) * self.Lk[_]] for _ in range(K)]
        self.Success = [[0, 0] for _ in range(K)]
        self.Alloc = [[0, 0] for _ in range(K)]
        self.PLR = [[0.0, 0.0] for _ in range(K)]
        self.PSR = [1.0 for _ in range(K)]
        self.Q_table = [defaultdict(lambda:defaultdict(lambda:0.0)) for _ in range(I)]
        self.Count = [defaultdict(lambda:defaultdict(int)) for _ in range(I)]
        self.cur_state = None
        self.init_state()
    
    
        self.CC = defaultdict(int)
    
    def init_state(self):
        """
        Initialize to a random state
        """
        state = []
        for k in range(self.num_devices):
            a = random.choice(range(self.Lk[k]+1))
            b = random.choice(range(self.Lk[k]+1))
            state.extend([
                random.choice([0,1]),
                random.choice([0,1]),
                a,
                self.Lk[k]-a
            ])
        self.cur_state = tuple(state)
    
    def update_state(self):
        """
        Update the current state based on PLR and Success
        """
        state = []
        for k in range(self.num_devices):
            state.extend([
                int(self.PLR[k][0] <= self.PLR_req),
                int(self.PLR[k][1] <= self.PLR_req),
                self.Success[k][0],
                self.Success[k][1]
                ])
        self.cur_state = tuple(state)
    
    def get_random_action_tuple(self):
        """
        Get a random {x}^k tuple, x in (0,1,2)
        """
        return tuple(random.choices(range(3), k=self.num_devices))
    
    def get_action_tuple(self, Q_hat_index):
        """
        Get an action when choose Q^H = Q_hat_index
        """
        cur_state = self.cur_state
        
        # compute Q_hat explicitly
        Q_bar = defaultdict(lambda:defaultdict(lambda:0.0))
        for i in range(self.num_Qtable):
            if cur_state not in self.Q_table[i]:
                continue
            q = self.Q_table[i][cur_state]
            for a in q:
                Q_bar[cur_state][a] += q[a]
                
        Q_hat = defaultdict(lambda:defaultdict(lambda:0.0))
        #mx = 0
        for a in itertools.product(range(3), repeat=self.num_devices):
            # variance
            Q = Q_bar[cur_state][a] / self.num_Qtable
            for i in range(self.num_Qtable):
                qval = 0
                if cur_state in self.Q_table[i] and a in self.Q_table[i][cur_state]:
                    qval = self.Q_table[i][cur_state][a]
                Q_hat[cur_state][a] += (qval - Q)**2
                            
            # Q_hat
            q = 0
            if cur_state in self.Q_table[Q_hat_index] and a in self.Q_table[Q_hat_index][cur_state]:
                q = self.Q_table[Q_hat_index][cur_state][a]
            V = q - self.risk_control / max(1,(self.num_Qtable-1)) * Q_hat[cur_state][a]
            Q_hat[cur_state][a] = V
            #mx = max(mx, V)
        a, v = self.get_max_action(Q_hat, cur_state)
        #assert(mx >= v)
        #if self.CC[cur_state] > 0:
        #    print(cur_state, a, v, mx)
        return a
    
    def receive_reward(self, reward, cur_frame, sample_achievable):
        """
        Receive r(s,a), update to s'
        """
        
        # reward = array of [success sub6, success mmWave]
        self.Success = reward
        
        total_reward = 0

        # update PLR
        for k in range(self.num_devices):
            for i in range(2):
                last_plr = self.PLR[k][i] * (cur_frame-1)
                new_plr = 0
                if self.Alloc[k][i] != 0:
                    new_plr = 1 - self.Success[k][i] / self.Alloc[k][i]
                self.PLR[k][i] = (last_plr + new_plr) / cur_frame
        #print("PLR each device: ", [sum(x)/2 for x in self.PLR])
        #print("PLR each device: ", self.PLR)
        
        # update PSR
        for k in range(self.num_devices):
            last_psr = self.PSR[k] * (cur_frame-1)
            new_psr = 1
            if sum(self.Alloc[k]) != 0:
                new_psr = sum(self.Success[k]) / sum(self.Alloc[k])
            self.PSR[k] = (last_psr + new_psr) /  cur_frame
            
            total_reward += self.PSR[k]
            #total_reward += new_psr
            
            total_reward -= 1 - int(self.PLR[k][0] <= self.PLR_req)
            total_reward -= 1 - int(self.PLR[k][1] <= self.PLR_req)
        #print("PSR each device: ", self.PSR)
        
        # update known average rate
        for k in range(self.num_devices):
            for i in range(2):
                alpha = 0.7
                old_rate = self.known_average_rate[k][i] * alpha
                new_rate = sample_achievable[k][i] * (1-alpha)
                self.known_average_rate[k][i] = (old_rate + new_rate) 
        
        return total_reward
    
    def map_action(self, action_chosen):
        """
        Given tuple {x}^k, x in (0,1,2), map to [ <mm_i, sub6_i> ]
        """
        for i, action in enumerate(action_chosen):
            if action == 0:
                self.Alloc[i][0] = max(1, min(int(self.known_average_rate[i][0] * self.frame_duration / self.packet_size), self.Lk[i]))
                self.Alloc[i][1] = 0
            elif action == 1:
                self.Alloc[i][0] = 0
                self.Alloc[i][1] = max(1, min(int(self.known_average_rate[i][1] * self.frame_duration / self.packet_size), self.Lk[i]))
            else:
                self.Alloc[i][1] = max(1, min(int(self.known_average_rate[i][1] * self.frame_duration / self.packet_size), self.Lk[i]))
                self.Alloc[i][0] = max(1, min(int(self.known_average_rate[i][0] * self.frame_duration / self.packet_size), self.Lk[i] - self.Alloc[i][1]))
        return self.Alloc
    
    def get_max_action(self, Q, s):
        """
        Given Q and s, return [a, v] such that v = Q[s][a] max
        """
        if s not in Q:
            return [self.get_random_action_tuple(), 0]
        state = Q[s]
        curMaxNegAction = [None, -1e9]
        curMaxPosAction = [None, -1e9]
        negActionCount = 0
        for action in itertools.product(range(3), repeat=self.num_devices):
            if action not in state:
                continue
            v = state[action]
            if v < 0:
                if v > curMaxNegAction[1]:
                    curMaxNegAction = [action, v]
                negActionCount += 1
            else:
                if v > curMaxPosAction[1]:
                    curMaxPosAction = [action, v]
        if curMaxPosAction[1] > -1e9:
            return curMaxPosAction
        if negActionCount == 3**self.num_devices:
            return curMaxNegAction
        aList = []
        for action in itertools.product(range(3), repeat=self.num_devices):
            if action not in state:
                aList.append(action)
        return [random.choice(aList), 0]
    
    def get_current_action(self, cur_frame):
        """
        Env ask for best action
        """
        Q_hat_chosen = random.choice(range(self.num_Qtable))
        if cur_frame >= self.cold_start:
            if self.exploration_rate >= 0.01:
                self.exploration_rate *= self.decay_factor
        r1 = random.random()
        if r1 < self.exploration_rate:
            action_chosen = self.get_random_action_tuple()
        else:
            action_chosen = self.get_action_tuple(Q_hat_chosen)
            
        A = sum(k in (0, 2) for k in action_chosen)
        B = sum(k in (1, 2) for k in action_chosen)
            
        return action_chosen, (A <= self.num_sub6 and B <= self.num_mmWave)
        
    def calc_utility(self, x):
        return -math.exp(self.utility_func_param * x) 
        
    def update_to_new_state(self, reward, cur_frame, action, sample_achievable_rate):
        """
        Apply the environment's feedback for ``action``: compute the reward,
        advance ``cur_state`` and update a random subset of the Q-tables.

        Args:
            reward: Per-device success counts, forwarded to ``receive_reward``.
            cur_frame: Frame counter, forwarded to ``receive_reward``.
            action: The joint action that was executed in the previous state.
            sample_achievable_rate: Rate samples, forwarded to ``receive_reward``.

        Returns:
            The scalar reward computed by ``receive_reward``.
        """
        # s(t): must be captured before update_state() overwrites cur_state
        old_state = self.cur_state
        
        rew = self.receive_reward(reward, cur_frame, sample_achievable_rate)
        self.update_state()
        
        # s(t+1)
        new_state = self.cur_state
        
        # One Poisson draw per table (numpy's default rate is used); a table is
        # updated only when its draw equals exactly 1.
        # NOTE(review): draws of 2 or more also skip the update - confirm that
        # `!= 1` rather than `< 1` is the intended mask.
        msk = np.random.poisson(size=self.num_Qtable)
        for i, v in enumerate(msk):
            if v != 1:
                continue
            # Read with membership tests so the defaultdicts gain no entries.
            oldQ = 0
            if old_state in self.Q_table[i] and action in self.Q_table[i][old_state]:
                oldQ = self.Q_table[i][old_state][action]
            # Step size: 1 when this (state, action) has no recorded visit in
            # table i, otherwise 1 / visits. The counter is incremented only
            # after the update below, so the first two updates both use 1.
            oldA = 1
            if old_state in self.Count[i] and action in self.Count[i][old_state]:
                oldA = 1 / self.Count[i][old_state][action]
            # find max Action in current Q (max[a] Q(s(t+1), *))
            max_QA = self.get_max_action(self.Q_table[i], new_state)[1]
            # NOTE(review): calc_utility(0) is -1, so with x0 = -10 the
            # increment is not zero at a zero TD error - confirm -10 is intended.
            x0 = -10
            newQ = oldQ + oldA * (self.calc_utility(rew + self.discount_factor * max_QA - oldQ) - x0) # eq (21)
            
            self.Count[i][old_state][action] += 1 # line 14+15
            self.Q_table[i][old_state][action] = newQ
        return rew

In [13]:
# deepqlearing_simple.py
import random
import numpy as np
from collections import deque, namedtuple
import itertools
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Replay-memory record for one step: state, index of the action taken,
# resulting state and the reward received.
Transition = namedtuple('Transition', ['state', 'action_index', 'next_state', 'reward'])

class DeepQLearning:
    def __init__(self, K, N, M, I, Ts, d,
                 replay_memory_capacity=500,
                 batch_size=64,
                 gamma=0.9, # discount_factor
                 eps_start=0.999, # exploration_rate start
                 eps_end=0.05,
                 eps_decay=0.995, # decay_factor for exploration_rate
                 target_update_freq=100, # C from image (steps)
                 learning_rate=1e-4,
                 dqn_hidden_layers=None, # Pass to DQNetwork, e.g., [128, 64]
                 St = 1500,
                 ):
        """
        Build the DQN agent: bookkeeping, policy/target networks, optimizer
        and replay memory.

        Args:
            K: Number of devices.
            N: Stored as ``num_sub6``; bound checked in ``get_current_action``.
            M: Stored as ``num_mmWave``; bound checked in ``get_current_action``.
            I: Accepted for signature compatibility; not used.
            Ts: Frame duration.
            d: Packet size.
            replay_memory_capacity: Maximum number of stored transitions.
            batch_size: Minibatch size for one optimization step.
            gamma: Discount factor.
            eps_start: Initial exploration rate.
            eps_end: Exploration rate at which decay stops.
            eps_decay: Multiplicative decay applied to the exploration rate.
            target_update_freq: Steps between target-network refreshes.
            learning_rate: Learning rate handed to AdamW.
            dqn_hidden_layers: Hidden layer widths forwarded to ``DQNetwork``.
            St: Frame index from which the exploration rate starts decaying.
        """
        # --- Problem dimensions and frame parameters ---
        self.num_devices = K
        self.num_sub6 = N
        self.num_mmWave = M
        self.frame_duration = Ts
        self.packet_size = d
        self.cold_start = St

        # --- Epsilon-greedy schedule and discounting ---
        self.exploration_rate = eps_start
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.decay_factor = eps_decay
        self.discount_factor = gamma

        # --- Per-device bookkeeping; pairs are laid out as [sub6, mmWave] ---
        self.PLR_req = 0.1
        self.Lk = [6] * K

        initial_rates = []
        for k in range(K):
            rate = self.packet_size / min(1, Ts) * self.Lk[k]
            initial_rates.append([rate, rate])
        self.known_average_rate = initial_rates

        self.Success = [[0, 0] for _ in range(K)]
        self.Alloc = [[0, 0] for _ in range(K)]
        self.PLR = [[0.0, 0.0] for _ in range(K)]
        self.PSR = [1.0 for _ in range(K)]

        self.cur_state = None
        self.init_state()

        # --- Networks ---
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        def build_network():
            return DQNetwork(num_devices=K, hidden_layer_list=dqn_hidden_layers).to(self.device)

        self.policy_net = build_network()
        self.target_net = build_network()
        # The target network starts as a copy of the policy network and is
        # only ever used for inference.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        # --- Training machinery ---
        self.optimizer = optim.AdamW(self.policy_net.parameters(), lr=learning_rate, amsgrad=True)
        self.replay_memory = deque(maxlen=replay_memory_capacity)
        self.BATCH_SIZE = batch_size
        self.TARGET_UPDATE_FREQUENCY = target_update_freq
        self.total_steps_done = 0

        # --- Joint action <-> network output index ---
        # A joint action is a K-tuple over {0, 1, 2}:
        # 0 = sub6 only, 1 = mmWave only, 2 = both.
        self._action_tuples_list = list(itertools.product(range(3), repeat=self.num_devices))
        self._action_to_index_map = {
            joint_action: position
            for position, joint_action in enumerate(self._action_tuples_list)
        }
        self._index_to_action_map = dict(enumerate(self._action_tuples_list))

    def _state_to_tensor(self, state_tuple):
        """Converts a state tuple to a PyTorch tensor for network input."""
        if state_tuple is None:
            return None
        # Ensure the state is flat list of numbers before converting to tensor
        return torch.tensor(list(state_tuple), dtype=torch.float32, device=self.device).unsqueeze(0)

    def init_state(self):
        """Initialize to a random state (tuple). Kept from RiskAverseQLearning."""
        state_list = []
        for k_idx in range(self.num_devices):
            # Example state components:
            # - PLR requirement met for sub6 (0 or 1)
            # - PLR requirement met for mmWave (0 or 1)
            # - Number of successful transmissions on sub6 in last step
            # - Number of successful transmissions on mmWave in last step
            # For init, we can use random values or typical starting values.
            # The original used random choices for 'a' and 'b' for success counts.
            # Let's use Lk[k_idx] as a reference for success counts init.
            # Assuming Success components are counts, not rates for state.
            sub6_success_init = random.choice(range(self.Lk[k_idx] + 1))
            mmwave_success_init = self.Lk[k_idx] - sub6_success_init # Ensure sum is Lk for this example part
            
            state_list.extend([
                random.choice([0, 1]),  # Mock PLR sub6 met
                random.choice([0, 1]),  # Mock PLR mmWave met
                sub6_success_init,      # Mock last success sub6 (count)
                mmwave_success_init     # Mock last success mmWave (count)
            ])
        self.cur_state = tuple(state_list)

    def update_state(self):
        """Update the current state (tuple) based on PLR and Success. Kept from RiskAverseQLearning."""
        state_list = []
        for k_idx in range(self.num_devices):
            state_list.extend([
                int(self.PLR[k_idx][0] <= self.PLR_req),  # PLR sub6 requirement met
                int(self.PLR[k_idx][1] <= self.PLR_req),  # PLR mmWave requirement met
                self.Success[k_idx][0],                   # Actual successful packets sub6
                self.Success[k_idx][1]                    # Actual successful packets mmWave
            ])
        self.cur_state = tuple(state_list)

    def get_random_action_tuple(self):
        """Get a random action tuple. Kept from RiskAverseQLearning."""
        return tuple(random.choices(range(3), k=self.num_devices))

    def get_action_tuple(self, state_tensor_for_net):
        """
        Selects an action using the policy network based on the current state tensor.
        Args:
            state_tensor_for_net (torch.Tensor): The current state as a tensor [1, num_features].
        Returns:
            tuple: The action tuple selected by the policy network.
        """
        with torch.no_grad(): # No gradient needed for action selection
            # policy_net outputs Q-values for all 3^K actions
            q_values = self.policy_net(state_tensor_for_net)
            # Select action with the highest Q-value
            action_index = q_values.max(1)[1].item() # .max(1) returns (values, indices)
        return self._index_to_action_map[action_index]

    def receive_reward(self, env_reward_signal, current_frame_number, sample_achievable_rates):
        """
        Calculate scalar reward based on environment feedback and update internal metrics (PLR, PSR, known_average_rate).
        Kept similar to RiskAverseQLearning.
        Args:
            env_reward_signal (list of lists): [[success_sub6_dev0, success_mmwave_dev0], ...]
            current_frame_number (int): The current frame number (1-indexed).
            sample_achievable_rates (list of lists): [[rate_sub6_dev0, rate_mmwave_dev0], ...]
        Returns:
            float: The calculated scalar reward.
        """
        self.Success = env_reward_signal # Update success counts based on environment feedback
        
        total_scalar_reward = 0.0

        # Update PLR for each device and each band
        for k_idx in range(self.num_devices):
            for band_idx in range(2): # 0 for sub6, 1 for mmWave
                # Calculate sum of past PLR values to maintain running average
                # current_frame_number is 1-indexed. For frame 1, (current_frame_number-1) is 0.
                sum_past_plr = self.PLR[k_idx][band_idx] * (current_frame_number - 1)
                
                current_plr_value = 0.0
                if self.Alloc[k_idx][band_idx] > 0:
                    current_plr_value = 1.0 - (self.Success[k_idx][band_idx] / self.Alloc[k_idx][band_idx])
                # If Alloc is 0, PLR is 0 (no packets sent, so no packets lost)
                
                self.PLR[k_idx][band_idx] = (sum_past_plr + current_plr_value) / current_frame_number
        
        # Update PSR (Packet Success Rate) and calculate part of the reward
        for k_idx in range(self.num_devices):
            sum_past_psr = self.PSR[k_idx] * (current_frame_number - 1)
            
            current_psr_value = 1.0 # Default PSR is 1 (e.g. if no packets allocated)
            if sum(self.Alloc[k_idx]) > 0:
                current_psr_value = sum(self.Success[k_idx]) / sum(self.Alloc[k_idx])
            
            self.PSR[k_idx] = (sum_past_psr + current_psr_value) / current_frame_number
            
            total_scalar_reward += current_psr_value
            
            # Penalize if PLR requirements are not met
            total_scalar_reward -= (1 - int(self.PLR[k_idx][0] <= self.PLR_req)) # Penalty for sub6 PLR miss
            total_scalar_reward -= (1 - int(self.PLR[k_idx][1] <= self.PLR_req)) # Penalty for mmWave PLR miss
        
        # Update known average achievable rate
        for k_idx in range(self.num_devices):
            for band_idx in range(2):
                A = 0.7
                sum_past_rates = self.known_average_rate[k_idx][band_idx] * A
                current_rate_sample = sample_achievable_rates[k_idx][band_idx] * (1.0-A)
                self.known_average_rate[k_idx][band_idx] = (sum_past_rates + current_rate_sample)
        
        return total_scalar_reward

    def map_action(self, action_chosen_tuple):
        """
        Maps the chosen action tuple to resource allocations (self.Alloc).
        Kept from RiskAverseQLearning.
        Args:
            action_chosen_tuple (tuple): The action chosen, e.g., (0, 1, 2) for K devices.
        Returns:
            list of lists: The updated self.Alloc.
        """
        for dev_idx, action_type in enumerate(action_chosen_tuple):
            dev_lk = self.Lk[dev_idx] # Max packets for this device
            # Estimated packets based on known average rate, frame duration, and packet size
            est_packets_sub6 = int(self.known_average_rate[dev_idx][0] * self.frame_duration / self.packet_size)
            est_packets_mmwave = int(self.known_average_rate[dev_idx][1] * self.frame_duration / self.packet_size)

            if action_type == 0: # Sub6 only
                self.Alloc[dev_idx][0] = max(0, min(est_packets_sub6, dev_lk))
                self.Alloc[dev_idx][1] = 0
            elif action_type == 1: # mmWave only
                self.Alloc[dev_idx][0] = 0
                self.Alloc[dev_idx][1] = max(0, min(est_packets_mmwave, dev_lk))
            else: # action_type == 2, Both (prioritize mmWave up to Lk-1, then sub6)
                # Allocate to mmWave, reserving at least 1 slot for sub6 if Lk > 0
                alloc_mmwave_limit = dev_lk - 1 if dev_lk > 0 else 0
                self.Alloc[dev_idx][1] = max(0, min(est_packets_mmwave, alloc_mmwave_limit))
                
                remaining_capacity_for_sub6 = dev_lk - self.Alloc[dev_idx][1]
                self.Alloc[dev_idx][0] = max(0, min(est_packets_sub6, remaining_capacity_for_sub6))
        return self.Alloc

    def get_current_action(self, cur_frame):
        """
        Selects an action using epsilon-greedy strategy:
        With probability epsilon, selects a random action.
        Otherwise, selects the best action according to the policy network.
        Also returns a flag indicating if the action respects resource constraints.
        """
        # Decay exploration rate
        #self.exploration_rate = self.eps_end + \
        #                        (self.eps_start - self.eps_end) * \
        #                        math.exp(-1. * self.total_steps_done * (1./(1/self.decay_factor)) ) # Using typical exponential decay related to decay_factor interpreation

        # Alternative simpler multiplicative decay:
        if cur_frame >= self.cold_start:
            if self.exploration_rate > self.eps_end:
                self.exploration_rate *= self.decay_factor


        rand_sample = random.random()
        if rand_sample < self.exploration_rate:
            action_chosen_tuple = self.get_random_action_tuple()
            # print(f"Step {self.total_steps_done}: RANDOM action! Eps: {self.exploration_rate:.3f}") # For debugging
        else:
            current_state_tensor = self._state_to_tensor(self.cur_state)
            action_chosen_tuple = self.get_action_tuple(current_state_tensor)
            # print(f"Step {self.total_steps_done}: DQN action. Eps: {self.exploration_rate:.3f}") # For debugging
            
        # Check constraints (A: num devices on sub6, B: num devices on mmWave)
        # Action type 0 (sub6), 1 (mmWave), 2 (both)
        num_active_sub6 = sum(1 for k_act_type in action_chosen_tuple if k_act_type in (0, 2))
        num_active_mmwave = sum(1 for k_act_type in action_chosen_tuple if k_act_type in (1, 2))
            
        constraints_met = (num_active_sub6 <= self.num_sub6 and num_active_mmwave <= self.num_mmWave)
        
        return action_chosen_tuple, constraints_met

    def _optimize_model(self):
        """Performs a single step of optimization on the policy network using a batch from replay memory."""
        if len(self.replay_memory) < self.BATCH_SIZE:
            return None # Not enough samples in memory to form a batch

        # Sample a random minibatch of transitions from the replay memory
        transitions = random.sample(self.replay_memory, self.BATCH_SIZE)
        # Convert batch-array of Transitions to Transition of batch-arrays.
        batch = Transition(*zip(*transitions))

        # Concatenate batch elements for PyTorch processing
        # batch.state, batch.next_state are tuples of tensors ([1, num_features])
        state_batch = torch.cat(batch.state)
        # batch.action_index is a tuple of integer action indices
        action_batch = torch.tensor(batch.action_index, device=self.device, dtype=torch.long).unsqueeze(1)
        # batch.reward is a tuple of scalar reward tensors ([1])
        reward_batch = torch.cat(batch.reward)
        next_state_batch = torch.cat(batch.next_state)
        
        # Compute Q(s_t, a_t)
        current_q_values = self.policy_net(state_batch).gather(1, action_batch)

        # Compute V(s_{t+1}) = max_{a'} Q_target(s_{t+1}, a')
        with torch.no_grad():
            next_state_max_q_values = self.target_net(next_state_batch).max(1)[0]
        
        # Compute the expected Q-values (y_j = r_j + gamma * V(s_{t+1}))
        expected_q_values = reward_batch + (self.discount_factor * next_state_max_q_values)

        loss = F.smooth_l1_loss(current_q_values, expected_q_values.unsqueeze(1))
        # loss = F.mse_loss(current_q_values, expected_q_values.unsqueeze(1))

        # Optimize the policy network
        self.optimizer.zero_grad()
        loss.backward()
        # Optional: Gradient clipping to stabilize training
        torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 100)
        self.optimizer.step()
        
        return loss.item() # Return loss value for monitoring

    def update_to_new_state(self, env_reward_signal, current_frame_number, action_taken_tuple, sample_achievable_rates):
        """
        Apply the environment's feedback for the executed action.

        Steps, in order: compute the scalar reward (which also updates the
        PLR/PSR/rate bookkeeping), advance ``cur_state``, store the transition
        in replay memory, run one optimization step, and refresh the target
        network every ``TARGET_UPDATE_FREQUENCY`` steps.

        Args:
            env_reward_signal (list of lists): Feedback from env (success counts).
            current_frame_number (int): Current frame/timestep number.
            action_taken_tuple (tuple): The action that was executed; must be a
                key of the action-to-index map.
            sample_achievable_rates (list of lists): Observed achievable rates.
        Returns:
            float: The scalar reward received for the transition.
        """
        # s_t has to be captured before the bookkeeping below changes it.
        previous_tensor = self._state_to_tensor(self.cur_state)

        # r_t
        scalar_reward = self.receive_reward(env_reward_signal, current_frame_number, sample_achievable_rates)
        reward_tensor = torch.tensor([scalar_reward], device=self.device, dtype=torch.float32)

        # s_{t+1}
        self.update_state()
        next_tensor = self._state_to_tensor(self.cur_state)

        # The network works with action indices, not tuples.
        action_index = self._action_to_index_map[action_taken_tuple]

        self.replay_memory.append(
            Transition(
                state=previous_tensor,
                action_index=action_index,
                next_state=next_tensor,
                reward=reward_tensor,
            )
        )
        self.total_steps_done += 1

        # A no-op until the memory holds at least one full batch.
        self._optimize_model()

        # Every C steps copy the policy weights into the target network.
        if self.total_steps_done % self.TARGET_UPDATE_FREQUENCY == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

        return scalar_reward

In [14]:
#deepqlearning_riskaverse.py

import random
import numpy as np
from collections import deque, namedtuple
import itertools
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# One replay-memory record for the risk-averse agent. Fields, in storage order:
#   s_orig_tuple        - original environment state tuple before the step (s_t)
#   current_eta_idx     - index of the discrete eta level attached to s_t (η_t)
#   action_orig_idx     - index of the original environment action taken (a_t)
#   chosen_next_eta_idx - index of the eta level picked as part of the action (η_{t+1})
#   reward_raw          - raw scalar reward r_t from the environment
#   next_s_orig_tuple   - original environment state tuple after the step (s_{t+1})
RiskTransition = namedtuple(
    'RiskTransition',
    [
        's_orig_tuple',
        'current_eta_idx',
        'action_orig_idx',
        'chosen_next_eta_idx',
        'reward_raw',
        'next_s_orig_tuple',
    ],
)

class RiskAverseDeepQLearning:
    def __init__(self, K, N, M, I_unused, Ts, d,
                 # --- Standard DQN params ---
                 replay_memory_capacity=5000,
                 batch_size=64,
                 gamma=0.9, # discount_factor
                 eps_start=0.999,
                 eps_end=0.05,
                 eps_decay=0.995,
                 target_update_freq=100, # C from image (steps)
                 learning_rate=1e-4,
                 dqn_hidden_layers=None,
                 St = 1500,
                 # --- Risk-Averse Specific Params ---
                 lambda_risk=0.5,  # λ_t in the paper (balancing expectation and CVaR)
                 alpha_cvar=0.05,   # α_t, confidence level for CVaR (e.g., 0.05 for 95% CVaR)
                 num_eta_levels=20, # D, number of discrete levels for η
                 eta_min_val=-20.0, # Lower bound for η discretization (e.g., min expected reward)
                 eta_max_val=20.0   # Upper bound for η discretization (e.g., max expected reward)
                 ):

        # Parameters from the original signature
        self.num_devices = K
        self.num_sub6 = N
        self.num_mmWave = M
        self.frame_duration = Ts
        self.packet_size = d
        self.cold_start = St
        
        # Q-learning and agent parameters
        self.exploration_rate = eps_start
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.decay_factor = eps_decay
        self.discount_factor = gamma # GAMMA

        # State and reward tracking (original parts)
        self.PLR_req = 0.1
        self.Lk = [6] * K
        self.known_average_rate =  [[self.packet_size / min(1,Ts) * self.Lk[_],self.packet_size / min(1,Ts) * self.Lk[_]] for _ in range(K)]
        self.Success = [[0, 0] for _ in range(K)]
        self.Alloc = [[0, 0] for _ in range(K)]
        self.PLR = [[0.0, 0.0] for _ in range(K)]
        self.PSR = [1.0 for _ in range(K)]
        
        self.cur_state_orig_tuple = None # Stores only the original environment state part

        # --- Risk-Averse Specific Attributes ---
        self.lambda_risk = lambda_risk
        self.alpha_cvar = alpha_cvar
        if not (0 < self.alpha_cvar <= 1):
            raise ValueError("alpha_cvar must be in (0, 1]")
        if not (0 <= self.lambda_risk <= 1):
            raise ValueError("lambda_risk must be in [0, 1]")

        self.num_eta_levels = num_eta_levels
        self.eta_min_val = eta_min_val
        self.eta_max_val = eta_max_val
        if self.num_eta_levels > 1:
            self.eta_discrete_values = torch.linspace(self.eta_min_val, self.eta_max_val, self.num_eta_levels)
        elif self.num_eta_levels == 1:
             self.eta_discrete_values = torch.tensor([(self.eta_min_val + self.eta_max_val) / 2.0])
        else:
            raise ValueError("num_eta_levels must be at least 1.")
        
        self.current_eta_idx = self.num_eta_levels // 2
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device} for RiskAverseDQN")
        self.eta_discrete_values = self.eta_discrete_values.to(self.device)


        # DQN specific attributes
        # The input to DQNetwork will be: original_state_features + 1 (for current_eta_value)
        # The output of DQNetwork will be: num_original_actions * num_eta_levels
        self.num_original_actions = 3 ** self.num_devices
        
        # DQNetwork input_features = 4 * K (original state) + 1 (current eta value)
        # DQNetwork output_features = (3^K actions) * num_eta_levels (for choosing next_eta_idx)
        policy_input_features = 4 * self.num_devices + 1 
        policy_output_features = self.num_original_actions * self.num_eta_levels

        self.policy_net = DQNetwork(num_devices=self.num_devices, # This argument might be misleading for DQNetwork
                                    hidden_layer_list=dqn_hidden_layers,
                                    # Override input/output for risk-averse case
                                    _input_features_override=policy_input_features,
                                    _output_features_override=policy_output_features
                                    ).to(self.device)
        self.target_net = DQNetwork(num_devices=self.num_devices,
                                    hidden_layer_list=dqn_hidden_layers,
                                    _input_features_override=policy_input_features,
                                    _output_features_override=policy_output_features
                                    ).to(self.device)

        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.AdamW(self.policy_net.parameters(), lr=learning_rate, amsgrad=True)
        self.replay_memory = deque(maxlen=replay_memory_capacity)
        self.BATCH_SIZE = batch_size
        self.TARGET_UPDATE_FREQUENCY = target_update_freq
        self.total_steps_done = 0

        self._original_action_tuples_list = list(itertools.product(range(3), repeat=self.num_devices))
        self._original_action_to_index_map = {action_tuple: i for i, action_tuple in enumerate(self._original_action_tuples_list)}
        self._original_index_to_action_map = {i: action_tuple for i, action_tuple in enumerate(self._original_action_tuples_list)}
        
        self.init_state() # Initializes self.cur_state_orig_tuple

    def _get_current_eta_value(self):
        return self.eta_discrete_values[self.current_eta_idx]

    def _state_to_tensor(self, state_orig_tuple, eta_value_tensor):
        """Converts original state tuple and eta_value tensor to a combined PyTorch tensor."""
        if state_orig_tuple is None: return None
        s_tensor = torch.tensor(list(state_orig_tuple), dtype=torch.float32, device=self.device)
        # Ensure eta_value_tensor is [1] or compatible for cat
        if eta_value_tensor.ndim == 0:
            eta_value_tensor = eta_value_tensor.unsqueeze(0)
        return torch.cat((s_tensor, eta_value_tensor), dim=0).unsqueeze(0) # batch dim

    # --- Functions to keep similar to DeepQLearning, but adapt for state/action ---

    def init_state(self):
        """Initialize original state part. Eta is handled by self.current_eta_idx."""
        state_list = []
        for k_idx in range(self.num_devices):
            sub6_success_init = random.choice(range(self.Lk[k_idx] + 1))
            mmwave_success_init = self.Lk[k_idx] - sub6_success_init
            state_list.extend([
                random.choice([0, 1]), random.choice([0, 1]),
                sub6_success_init, mmwave_success_init
            ])
        self.cur_state_orig_tuple = tuple(state_list)
        # Also reset current_eta_idx for a new episode/init
        self.current_eta_idx = self.num_eta_levels // 2


    def update_state(self):
        """Update original state part. Eta is updated based on action."""
        state_list = []
        for k_idx in range(self.num_devices):
            state_list.extend([
                int(self.PLR[k_idx][0] <= self.PLR_req),
                int(self.PLR[k_idx][1] <= self.PLR_req),
                self.Success[k_idx][0], self.Success[k_idx][1]
            ])
        self.cur_state_orig_tuple = tuple(state_list)

    def get_random_action_tuple(self):
        """Returns (random_original_action_tuple, random_next_eta_idx)"""
        random_orig_action_tuple = tuple(random.choices(range(3), k=self.num_devices))
        random_next_eta_idx = random.randrange(self.num_eta_levels)
        return random_orig_action_tuple, random_next_eta_idx

    def get_action_tuple(self, current_full_state_tensor):
        """
        Selects (original_action_tuple, next_eta_idx) using the policy network.
        Args:
            current_full_state_tensor (torch.Tensor): Combined (original_state, current_eta_value).
        Returns:
            tuple: (chosen_original_action_tuple, chosen_next_eta_idx)
        """
        with torch.no_grad():
            q_values_all_composite_actions = self.policy_net(current_full_state_tensor) # Shape [1, num_orig_actions * num_eta_levels]
            
            # Find the index of the max Q-value (this is a flat index)
            composite_action_flat_idx = q_values_all_composite_actions.argmax(dim=1).item()
            
            # Convert flat index back to (original_action_idx, next_eta_idx)
            chosen_original_action_idx = composite_action_flat_idx // self.num_eta_levels
            chosen_next_eta_idx = composite_action_flat_idx % self.num_eta_levels
            
            chosen_original_action_tuple = self._original_index_to_action_map[chosen_original_action_idx]
        return chosen_original_action_tuple, chosen_next_eta_idx

    # receive_reward and map_action remain identical to DeepQLearning as they deal with original rewards/actions
    def receive_reward(self, env_reward_signal, current_frame_number, sample_achievable_rates):
        self.Success = env_reward_signal
        total_scalar_raw_reward = 0.0
        for k_idx in range(self.num_devices):
            for band_idx in range(2):
                sum_past_plr = self.PLR[k_idx][band_idx] * (current_frame_number - 1)
                current_plr_value = 0.0
                if self.Alloc[k_idx][band_idx] > 0:
                    current_plr_value = 1.0 - (self.Success[k_idx][band_idx] / self.Alloc[k_idx][band_idx])
                self.PLR[k_idx][band_idx] = (sum_past_plr + current_plr_value) / current_frame_number
        for k_idx in range(self.num_devices):
            sum_past_psr = self.PSR[k_idx] * (current_frame_number - 1)
            current_psr_value = 1.0
            if sum(self.Alloc[k_idx]) > 0:
                current_psr_value = sum(self.Success[k_idx]) / sum(self.Alloc[k_idx])
            self.PSR[k_idx] = (sum_past_psr + current_psr_value) / current_frame_number
            total_scalar_raw_reward += current_psr_value
            total_scalar_raw_reward -= (1 - int(self.PLR[k_idx][0] <= self.PLR_req))
            total_scalar_raw_reward -= (1 - int(self.PLR[k_idx][1] <= self.PLR_req))
        for k_idx in range(self.num_devices):
            for band_idx in range(2):
                A = 0.7
                sum_past_rates = self.known_average_rate[k_idx][band_idx] * A
                current_rate_sample = sample_achievable_rates[k_idx][band_idx] * (1.0-A)
                self.known_average_rate[k_idx][band_idx] = (sum_past_rates + current_rate_sample)
        return total_scalar_raw_reward


    def map_action(self, original_action_chosen_tuple):
        # Uses the original action part to determine allocations
        for dev_idx, action_type in enumerate(original_action_chosen_tuple):
            dev_lk = self.Lk[dev_idx]
            est_packets_sub6 = int(self.known_average_rate[dev_idx][0] * self.frame_duration / self.packet_size)
            est_packets_mmwave = int(self.known_average_rate[dev_idx][1] * self.frame_duration / self.packet_size)
            if action_type == 0:
                self.Alloc[dev_idx][0] = max(0, min(est_packets_sub6, dev_lk))
                self.Alloc[dev_idx][1] = 0
            elif action_type == 1:
                self.Alloc[dev_idx][0] = 0
                self.Alloc[dev_idx][1] = max(0, min(est_packets_mmwave, dev_lk))
            else:
                alloc_mmwave_limit = dev_lk - 1 if dev_lk > 0 else 0
                self.Alloc[dev_idx][1] = max(0, min(est_packets_mmwave, alloc_mmwave_limit))
                remaining_capacity_for_sub6 = dev_lk - self.Alloc[dev_idx][1]
                self.Alloc[dev_idx][0] = max(0, min(est_packets_sub6, remaining_capacity_for_sub6))
        return self.Alloc

    def get_current_action(self, cur_frame=0):
        """Selects (original_action_tuple, next_eta_idx) using epsilon-greedy."""
        #self.exploration_rate = self.eps_end + \
        #                        (self.eps_start - self.eps_end) * \
        #                        math.exp(-1. * self.total_steps_done * (1./(1/self.decay_factor)) )
        if cur_frame >= self.cold_start:
            if self.exploration_rate >= self.eps_end:
                self.exploration_rate *= self.decay_factor
        
        rand_sample = random.random()
        if rand_sample < self.exploration_rate:
            chosen_orig_action_tuple, chosen_next_eta_idx = self.get_random_action_tuple()
        else:
            current_eta_val_tensor = self._get_current_eta_value().detach() # Ensure it's a scalar tensor
            current_full_state_tensor = self._state_to_tensor(self.cur_state_orig_tuple, current_eta_val_tensor)
            chosen_orig_action_tuple, chosen_next_eta_idx = self.get_action_tuple(current_full_state_tensor)
            
        # Constraint check based on original action part
        num_active_sub6 = sum(1 for k_act_type in chosen_orig_action_tuple if k_act_type in (0, 2))
        num_active_mmwave = sum(1 for k_act_type in chosen_orig_action_tuple if k_act_type in (1, 2))
        constraints_met = (num_active_sub6 <= self.num_sub6 and num_active_mmwave <= self.num_mmWave)
        
        return chosen_orig_action_tuple, chosen_next_eta_idx, constraints_met

    def _optimize_model(self):
        if len(self.replay_memory) < self.BATCH_SIZE:
            return None

        transitions = random.sample(self.replay_memory, self.BATCH_SIZE)
        batch = RiskTransition(*zip(*transitions))

        # Unpack batch
        s_orig_batch_tensor = torch.tensor(batch.s_orig_tuple, dtype=torch.float32, device=self.device)
        current_eta_idx_batch = torch.tensor(batch.current_eta_idx, device=self.device, dtype=torch.long)
        action_orig_idx_batch = torch.tensor(batch.action_orig_idx, device=self.device, dtype=torch.long)
        chosen_next_eta_idx_batch = torch.tensor(batch.chosen_next_eta_idx, device=self.device, dtype=torch.long)
        reward_raw_batch = torch.tensor(batch.reward_raw, device=self.device, dtype=torch.float32)
        next_s_orig_batch_tensor = torch.tensor(batch.next_s_orig_tuple, dtype=torch.float32, device=self.device)
        
        # Get current eta values (η_j) from the discrete values tensor
        current_eta_val_batch = self.eta_discrete_values[current_eta_idx_batch] # Shape [BATCH_SIZE]
        
        # Get next eta values chosen as action (η_{j+1})
        chosen_next_eta_val_batch = self.eta_discrete_values[chosen_next_eta_idx_batch] # Shape [BATCH_SIZE]

        # Prepare policy_net input: Concatenate original state tensor and current eta value
        # Ensure eta_values have an extra dimension for concatenation if needed
        policy_net_input_batch = torch.cat(
            (s_orig_batch_tensor, current_eta_val_batch.unsqueeze(1)), dim=1
        ) # Shape [BATCH_SIZE, num_orig_features + 1]

        # Q( (s_j, η_j), (a_j, η_{j+1}) )
        # Q_values for all (orig_action, next_eta) pairs for each state in batch
        all_q_for_current_state = self.policy_net(policy_net_input_batch) # Shape [BATCH_SIZE, num_orig_actions * num_eta_levels]
        
        # Construct composite action index for gathering
        composite_action_indices = action_orig_idx_batch * self.num_eta_levels + chosen_next_eta_idx_batch
        current_q_values = all_q_for_current_state.gather(1, composite_action_indices.unsqueeze(1)).squeeze(1) # Shape [BATCH_SIZE]

        # 1. Risk-adjusted immediate reward part: -(λ/α)[η_j - r_j]_+ + (1-λ)r_j
        term1_cvar_part = -(self.lambda_risk / self.alpha_cvar) * F.relu(current_eta_val_batch - reward_raw_batch)
        term1_exp_part = (1 - self.lambda_risk) * reward_raw_batch
        risk_adjusted_immediate_reward = term1_cvar_part + term1_exp_part # Shape [BATCH_SIZE]

        # 2. Middle term: γ * λ * η_{j+1}
        term2_gamma_lambda_eta_next = self.discount_factor * self.lambda_risk * chosen_next_eta_val_batch # Shape [BATCH_SIZE]
        
        # 3. Future part: γ * max_{a', η''} Q_target(s_{j+1}, η_{j+1}, a', η''; θ_target)
        target_net_input_batch = torch.cat(
            (next_s_orig_batch_tensor, chosen_next_eta_val_batch.unsqueeze(1)), dim=1
        ) # Shape [BATCH_SIZE, num_orig_features + 1]
        
        with torch.no_grad():
            q_target_all_composite_actions = self.target_net(target_net_input_batch) # Shape [BATCH_SIZE, num_orig_actions * num_eta_levels]
            max_q_target_next = q_target_all_composite_actions.max(dim=1)[0] # Max over all (a', η'') pairs. Shape [BATCH_SIZE]
        
        term3_gamma_max_q_target = self.discount_factor * max_q_target_next # Shape [BATCH_SIZE]
        
        # Combine for y_j
        expected_q_values = risk_adjusted_immediate_reward + term2_gamma_lambda_eta_next + term3_gamma_max_q_target

        # Compute loss
        loss = F.smooth_l1_loss(current_q_values, expected_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 100)
        self.optimizer.step()
        
        return loss.item()

    def update_to_new_state(self, env_reward_signal, current_frame_number,
                            action_taken_orig_tuple, chosen_next_eta_idx, # Action now has two parts
                            sample_achievable_rates):
        """
        Consume the outcome of one frame: update metrics and state, store the
        transition, train, and periodically sync the target network.

        Args:
            env_reward_signal (list of lists): Per-device delivered packet counts.
            current_frame_number (int): Frame counter forwarded to receive_reward.
            action_taken_orig_tuple (tuple): Original action that was executed.
            chosen_next_eta_idx (int): Eta grid index chosen together with that action.
            sample_achievable_rates (list of lists): Observed achievable rates.
        Returns:
            float: The raw scalar reward of the transition, for external monitoring.
        """
        # Snapshot (s_t, η_t) before anything is mutated.
        state_before = self.cur_state_orig_tuple
        eta_idx_before = self.current_eta_idx

        # Update the running metrics, then derive s_{t+1} from them.
        reward = self.receive_reward(env_reward_signal, current_frame_number, sample_achievable_rates)
        self.update_state()

        record = RiskTransition(
            s_orig_tuple=state_before,
            current_eta_idx=eta_idx_before,
            action_orig_idx=self._original_action_to_index_map[action_taken_orig_tuple],
            chosen_next_eta_idx=chosen_next_eta_idx,
            reward_raw=reward,
            next_s_orig_tuple=self.cur_state_orig_tuple,
        )
        self.replay_memory.append(record)

        # The eta chosen as part of the action becomes the eta of the new state.
        self.current_eta_idx = chosen_next_eta_idx

        self.total_steps_done += 1
        self._optimize_model()

        # Hard-copy the policy weights into the target network every
        # TARGET_UPDATE_FREQUENCY steps.
        if self.total_steps_done % self.TARGET_UPDATE_FREQUENCY == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

        return reward

In [15]:
import numpy as np
import json
import random
import os
from datetime import datetime
import torch #sanity check

# Matplotlib for plotting
import matplotlib.pyplot as plt
from IPython.display import clear_output # For live updates in notebooks

# ---- Constants ----
AP_COORD = (0, 0)        # Access Point coordinates
POINT_SIZE = 5           # Grid alignment unit of the original GUI; not used by the simulation itself
PACKET_SIZE = 8000       # bits
FRAME_DURATION = 1e-3    # seconds (1 ms)
USE_RISK_AVERSE = False  # True selects RiskAverseQLearning
USE_DEEP_RISK = False    # Only read when USE_RISK_AVERSE is False: True selects RiskAverseDeepQLearning, False DeepQLearning
SCALE = 1.0              # Position scale (1.0: 1 block = 5m)

# Suffix for output filenames, derived from the selected agent.
if USE_RISK_AVERSE:
    SUFFIX_STR = "RA"
elif USE_DEEP_RISK:
    SUFFIX_STR = "DEEPRA"
else:
    SUFFIX_STR = "DEEP"

# ---- Network Simulation Core ----

class NetworkSimulator:
    """
    Simulates a 5G network with an AP, devices, and blockages,
    using a Q-learning agent for resource allocation.
    """
    def __init__(self, json_config_path=None):
        """
        Initializes the simulator.
        
        Args:
            json_config_path (str, optional): Path to a JSON file containing
                                               layout and simulation settings.
        """
        # Data structures to hold device and blockage information
        self.devices_data = [] 
        self.blockages_data = [] 
        
        # Global simulation configuration parameters
        self.global_config = self._default_config()

        # Load configuration from JSON if provided
        if json_config_path:
            self.load_config_from_json(json_config_path)

        # Simulation state variables
        self.cur_frame = 0
        self.distances = [] 
        self.nblocks = [] 
        self.plr = [] 
        self.RewList = [] 
        self.MetricList = [] 
        self.runnRew = 0 
        self.qlearn = None 
        self.choice_counts = [] 

        # Data lists for Matplotlib plotting (stores (frame, value) tuples)
        self.plot_data_metric = []
        self.plot_data_reward = []
        self.plot_data_plr_per_device = [] 
        self.plot_data_sub6_rate_per_device = []
        self.plot_data_mmwave_rate_per_device = []

        # Frequency for live plot updates (every N frames)
        self.live_plot_update_interval = 500 

    def _default_config(self):
        """Returns a dictionary of default global simulation parameters."""
        return {
            "pwr": 5,             # Tx power (dBm)
            "noise": -169,        # Noise power (dBm/Hz)
            "bw_sub6": 100,       # Bandwidth Sub-6 (MHz)
            "bw_mm": 1000,        # Bandwidth mmWave (MHz)
            "num_subchannel": 4,  # Number of Sub-6GHz subchannels
            "num_beam": 4,        # Number of mmWave beams
            "nframes": 10000,     # Number of frames to simulate
        }

    def load_config_from_json(self, path):
        """
        Loads network layout (devices, blockages) and global simulation settings
        from a specified JSON file.
        """
        if not os.path.exists(path):
            print(f"Error: JSON config file not found at {path}")
            return False

        with open(path, "r") as f:
            data = json.load(f)

        # Clear existing data before loading new
        self.devices_data.clear()
        self.blockages_data.clear()

        # Restore devices data
        for d in data.get("devices", []):
            self.devices_data.append({
                'pos': tuple(d["pos"]),
                'sub6_packets': d.get('sub6_packets', 0), 
                'mmwave_packets': d.get('mmwave_packets', 0), 
                'history': {"sub6_success": [], "mmwave_success": []} # Initialize history
            })
        # Restore blockages data
        for b in data.get("blockages", []):
            self.blockages_data.append({'pos': tuple(b["pos"])})

        # Restore global config settings
        s = data.get("settings", {})
        for key, value in s.items():
            if key in self.global_config: 
                self.global_config[key] = value
        
        print(f"Loaded configuration from {path}")
        print(f"  Number of Devices: {len(self.devices_data)}")
        print(f"  Number of Blockages: {len(self.blockages_data)}")
        print(f"  Simulation Frames: {self.global_config['nframes']}")
        return True

    def save_json(self, path):
        """
        Saves the current layout (devices, blockages) and simulation settings
        to a JSON file.
        """
        data = {
            "devices": [
                {"pos": list(d['pos']), "sub6_packets": d['sub6_packets'], "mmwave_packets": d['mmwave_packets']}
                for d in self.devices_data
            ],
            "blockages": [
                {"pos": list(b['pos'])}
                for b in self.blockages_data
            ],
            "settings": self.global_config
        }
        try:
            with open(path, "w") as f:
                json.dump(data, f, indent=2)
            print(f"Saved current configuration to {path}")
            return True
        except Exception as e:
            print(f"Error saving JSON to {path}: {e}")
            return False

    def setup_simulation(self):
        """
        Prepare a fresh run: compute per-device distance and blockage count,
        reset all tracking/plot containers, and create the learning agent
        selected by USE_RISK_AVERSE / USE_DEEP_RISK.
        """
        print("\nSetting up simulation environment...")
        device_count = len(self.devices_data)
        self.cur_frame = 0
        self.distances = [0] * device_count
        self.nblocks = [0] * device_count

        ap_x, ap_y = AP_COORD[0], AP_COORD[1]
        for dev_idx, device in enumerate(self.devices_data):
            dev_pos = [coord * SCALE for coord in device['pos']]

            # Vector AP -> device and its squared length.
            to_dev_x = dev_pos[0] - ap_x
            to_dev_y = dev_pos[1] - ap_y
            dev_range_sq = to_dev_x**2 + to_dev_y**2
            # A device sitting exactly on the AP gets a tiny positive distance.
            self.distances[dev_idx] = (dev_range_sq ** 0.50) if dev_range_sq > 0 else 1e-6

            for blockage in self.blockages_data:
                blk_pos = [coord * SCALE for coord in blockage['pos']]
                to_blk_x = blk_pos[0] - ap_x
                to_blk_y = blk_pos[1] - ap_y

                # The blockage must lie on the line through AP and device.
                cross = to_blk_y * (dev_pos[0] - blk_pos[0]) - \
                        to_blk_x * (dev_pos[1] - blk_pos[1])
                if abs(cross) > 1e-6:
                    continue

                # It counts only when it points the same way as the device, is
                # no farther from the AP than the device, and is not on the AP.
                along = to_blk_x * to_dev_x + to_blk_y * to_dev_y
                blk_range_sq = to_blk_x**2 + to_blk_y**2
                if along >= 0 and 0 < blk_range_sq <= dev_range_sq:
                    self.nblocks[dev_idx] += 1

        # Tracking state for the new run.
        self.plr = [0] * device_count
        self.RewList = []
        self.MetricList = []
        self.runnRew = 0

        # Plot series from earlier runs are discarded.
        self.plot_data_metric = []
        self.plot_data_reward = []
        self.plot_data_plr_per_device = [[] for _ in self.devices_data]
        self.plot_data_sub6_rate_per_device = [[] for _ in self.devices_data]
        self.plot_data_mmwave_rate_per_device = [[] for _ in self.devices_data]
        # Per device, one counter for each action code (0, 1, 2).
        self.choice_counts = [[0, 0, 0] for _ in self.devices_data]

        # Agent selection.
        num_devices = device_count
        num_subchannels = self.global_config['num_subchannel']
        num_beams = self.global_config['num_beam']

        if USE_RISK_AVERSE:
            self.qlearn = RiskAverseQLearning(num_devices, num_subchannels, num_beams, 4, FRAME_DURATION, PACKET_SIZE)
        elif USE_DEEP_RISK:
            self.qlearn = RiskAverseDeepQLearning(
                num_devices, num_subchannels, num_beams, 4,
                FRAME_DURATION, PACKET_SIZE,
                replay_memory_capacity = 2000,
                batch_size = 128,
                dqn_hidden_layers=[20, 20, 20],
                St = 1000,
                num_eta_levels=20,
                eta_min_val=-num_devices * 2,
                eta_max_val=num_devices
            )
        else:
            self.qlearn = DeepQLearning(
                num_devices, num_subchannels, num_beams, 4,
                FRAME_DURATION, PACKET_SIZE,
                replay_memory_capacity = 2000,
                batch_size = 128,
                dqn_hidden_layers=[20, 20, 20],
                St = 1000
            )

        if not USE_RISK_AVERSE:
            # Sanity check for the deep agents: they must sit on the expected device.
            expected_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            assert(self.qlearn.device == expected_device)
        print("Simulation setup complete.")

    def _db_to_linear(self, db):
        """Converts decibel (dB) value to linear scale."""
        return 10 ** (db / 10)

    def _dbm_to_linear(self, dbm):
        """Converts dBm value to Watts (linear scale)."""
        return self._db_to_linear(dbm - 30)

    def _tx_beam_gain(self, theta, eps=0.1):
        """Calculates a simplified transmit beam gain."""
        theta = max(abs(theta), 1e-6) 
        return (2 * np.pi - (2 * np.pi - theta) * eps) / theta

    def simulation_step(self):
        """
        Executes one frame of the network simulation.
        """
        num_devices = len(self.devices_data)
        if num_devices == 0:
            return 

        achievable_rate = [(0, 0)] * num_devices 

        P_tx_dbm = self.global_config['pwr']
        P_tx_W = self._dbm_to_linear(P_tx_dbm)
        
        noise_density_dbm_hz = self.global_config['noise']
        noise_density_W_hz = self._dbm_to_linear(noise_density_dbm_hz)

        W_sub_MHz = self.global_config['bw_sub6']
        W_sub_Hz = W_sub_MHz * 1e6
        N_subchannels = max(1, self.global_config['num_subchannel'])
        W_sub_per_channel_Hz = W_sub_Hz / N_subchannels
        P_tx_sub_W = P_tx_W / N_subchannels 
        
        W_mm_MHz = self.global_config['bw_mm']
        W_mm_Hz = W_mm_MHz * 1e6
        P_tx_mm_W = P_tx_W 
        
        mmWave_tx_beamwidth_rad = 0.1
        mmWave_tx_sidelobe_gain = 0.1 
        mmWave_rx_gain_lin = 1 

        for i, dev in enumerate(self.devices_data):
            d = self.distances[i] 

            # --- Sub-6GHz Channel Model ---
            PL_sub_db = 38.5 + 30 * np.log10(d) 
            PL_sub_lin = self._db_to_linear(-PL_sub_db)
            h_small_scale_sub_power = np.random.rayleigh(scale=1.0)**2 
            h_comb_sub_power = h_small_scale_sub_power * PL_sub_lin 
            noise_power_sub = noise_density_W_hz * W_sub_per_channel_Hz
            gamma_sub = (P_tx_sub_W * h_comb_sub_power) / noise_power_sub
            rate_sub_bps = W_sub_per_channel_Hz * np.log2(1 + gamma_sub)

            # --- mmWave Channel Model ---
            is_blocked = self.nblocks[i] > 0
            if is_blocked: 
                shadowing_nlos_db = np.random.normal(0, 8.7) 
                PL_mm_db = 72 + 29.2 * np.log10(d) + shadowing_nlos_db
            else: 
                shadowing_los_db = np.random.normal(0, 5.8) 
                PL_mm_db = 61.4 + 20 * np.log10(d) + shadowing_los_db
            PL_mm_lin = self._db_to_linear(-PL_mm_db)
            h_small_scale_mm_power = np.random.rayleigh(scale=1.0)**2
            G_tx_lin = self._tx_beam_gain(mmWave_tx_beamwidth_rad, mmWave_tx_sidelobe_gain) 
            G_rx_lin = mmWave_rx_gain_lin 
            h_comb_mm_power = G_tx_lin * h_small_scale_mm_power * PL_mm_lin * G_rx_lin 
            noise_power_mm = noise_density_W_hz * W_mm_Hz
            gamma_mm = (P_tx_mm_W * h_comb_mm_power) / noise_power_mm
            rate_mm_bps = W_mm_Hz * np.log2(1 + gamma_mm)

            achievable_rate[i] = (max(0, rate_sub_bps), max(0, rate_mm_bps))
            
        achievable_packets = [
            list(map(lambda x: int(x * FRAME_DURATION / PACKET_SIZE), A))
            for A in achievable_rate
        ]
        
        q_learn_act = None
        eta_idx = None
        safe = True 
        
        if not USE_RISK_AVERSE and USE_DEEP_RISK:
            q_learn_act, eta_idx, safe = self.qlearn.get_current_action(self.cur_frame)
        else:
            q_learn_act, safe = self.qlearn.get_current_action(self.cur_frame)
        
        for d_idx, action_choice in enumerate(q_learn_act):
            self.choice_counts[d_idx][action_choice] += 1
        
        action = self.qlearn.map_action(q_learn_act) 

        success = [[0, 0] for _ in range(num_devices)] 
        if safe: 
            success = [[min(act, achi) for act, achi in zip(a_proposed, a_achievable)]
                       for a_proposed, a_achievable in zip(action, achievable_packets)]
        
        reward = 0
        if not USE_RISK_AVERSE and USE_DEEP_RISK:
            reward = self.qlearn.update_to_new_state(success, self.cur_frame, q_learn_act, eta_idx, achievable_rate)
        else:
            reward = self.qlearn.update_to_new_state(success, self.cur_frame, q_learn_act, achievable_rate)
        
        metric = 0 
        for k in range(num_devices):
            total_sent = sum(action[k])
            total_received = sum(success[k])
            
            cur_psr = 0 
            if total_sent > 0:
                cur_psr = total_received / total_sent
            cur_plr = 1 - cur_psr 
            
            old_plr_sum = self.plr[k] * (self.cur_frame - 1) if self.cur_frame > 1 else 0
            self.plr[k] = (old_plr_sum + cur_plr) / self.cur_frame if self.cur_frame > 0 else cur_plr
            
            metric += self.qlearn.PLR_req - self.plr[k] 
        metric /= num_devices 
        
        self.RewList.append(reward) 
        self.MetricList.append(metric) 
        self.runnRew += reward 

        # Collect data for final and live plotting
        self._collect_plot_data(metric, reward, self.plr, achievable_rate)

        # Print periodic status updates AND update live plots
        if self.cur_frame % self.live_plot_update_interval == 0 or self.cur_frame == 1:
            clear_output(wait=True) # Clear previous output in the notebook cell
            print(f"--- Live Update: Frame {self.cur_frame}/{self.global_config['nframes']} ---")
            print(f"  Current Reward: {reward:.4f} | Avg Reward (overall): {self.runnRew / max(1, self.cur_frame) / max(1, num_devices):.4f}")
            print(f"  Current Metric (ΔP): {metric:.4f}")
            print(f"  Avg PLR per device: {[f'{p:.4f}' for p in self.plr]}")
            
            self._update_live_plots()

    def _collect_plot_data(self, metric, reward, current_plr_list, current_achievable_rate_list):
        """
        Internal method to collect simulation data points for later plotting.
        """
        self.plot_data_metric.append((self.cur_frame, metric))
        self.plot_data_reward.append((self.cur_frame, self.runnRew / max(1, self.cur_frame) / max(1, len(self.devices_data))))
        
        if not self.plot_data_plr_per_device: 
            self.plot_data_plr_per_device = [[] for _ in self.devices_data]
        if not self.plot_data_sub6_rate_per_device or len(self.plot_data_sub6_rate_per_device) != len(self.devices_data):
            self.plot_data_sub6_rate_per_device = [[] for _ in self.devices_data]
        if not self.plot_data_mmwave_rate_per_device or len(self.plot_data_mmwave_rate_per_device) != len(self.devices_data):
            self.plot_data_mmwave_rate_per_device = [[] for _ in self.devices_data]

        for i, plr_val in enumerate(current_plr_list):
            self.plot_data_plr_per_device[i].append((self.cur_frame, plr_val))
        for i, (sub6_rate, mmwave_rate) in enumerate(current_achievable_rate_list):
            self.plot_data_sub6_rate_per_device[i].append((self.cur_frame, sub6_rate / 1e6)) # Convert to Mbps for plotting
            self.plot_data_mmwave_rate_per_device[i].append((self.cur_frame, mmwave_rate / 1e6)) # Convert to Mbps for plotting
    
    def _update_live_plots(self):
        """
        Generates and displays live Matplotlib plots directly in the notebook cell.
        Called periodically during the simulation.
        """
        # Apply a nice Matplotlib style
        plt.style.use('seaborn-v0_8-darkgrid')

        # --- Plot 1: Metric (Delta P) ---
        plt.figure(figsize=(10, 3))
        if self.plot_data_metric:
            frames, metrics = zip(*self.plot_data_metric)
            plt.plot(frames, metrics, label="Metric (ΔP)", color='blue')
        plt.title(f"Metric (ΔP) - Frame {self.cur_frame}")
        plt.xlabel("Frame")
        plt.ylabel("Metric (ΔP)")
        plt.legend()
        plt.tight_layout()
        plt.show()
        plt.close() # Close figure to free memory

        # --- Plot 2: Average Reward per Device ---
        plt.figure(figsize=(10, 3))
        if self.plot_data_reward:
            frames, rewards = zip(*self.plot_data_reward)
            plt.plot(frames, rewards, label="Average Reward per Device", color='green')
        plt.title(f"Average Reward per Device - Frame {self.cur_frame}")
        plt.xlabel("Frame")
        plt.ylabel("Average Reward")
        plt.legend()
        plt.tight_layout()
        plt.show()
        plt.close()

        # --- Plot 3: Packet Loss Rate (PLR) per Device ---
        plt.figure(figsize=(10, 3))
        for i, device_plr_data in enumerate(self.plot_data_plr_per_device):
            if device_plr_data:
                frames, plrs = zip(*device_plr_data)
                plt.plot(frames, plrs, label=f"Device {i+1} PLR")
        
        if self.qlearn and hasattr(self.qlearn, 'PLR_req'):
            plt.axhline(y=self.qlearn.PLR_req, color='dimgray', linestyle='--', linewidth=1.2, label='PLR Max (Target)')
        else: 
            plt.axhline(y=0.1, color='dimgray', linestyle='--', linewidth=1.2, label='Default PLR Max (0.1)')

        plt.title(f"Packet Loss Rate (PLR) - Frame {self.cur_frame}")
        plt.xlabel("Frame")
        plt.ylabel("PLR")
        plt.legend()
        plt.tight_layout()
        plt.show()
        plt.close()

        # --- Plot 4: Current Choice Distribution Bar Graph ---
        plt.figure(figsize=(10, 3))
        num_devices = len(self.devices_data)
        
        if num_devices > 0:
            ind = np.arange(num_devices) 
            width = 0.2 
            
            final_choice_counts_array = np.array(self.choice_counts)
            total_choices_per_device = final_choice_counts_array.sum(axis=1) 
            
            normalized_counts = np.zeros_like(final_choice_counts_array, dtype=float)
            for i in range(num_devices):
                if total_choices_per_device[i] > 0:
                    normalized_counts[i, :] = final_choice_counts_array[i, :] / total_choices_per_device[i]
            
            plt.bar(ind - width, normalized_counts[:, 0], width, label="Sub-6GHz", color='skyblue')
            plt.bar(ind, normalized_counts[:, 1], width, label="mmWave", color='lightcoral')
            plt.bar(ind + width, normalized_counts[:, 2], width, label="Idle", color='lightgreen')

            plt.title(f"Choice Distribution - Frame {self.cur_frame}")
            plt.xlabel("Device")
            plt.ylabel("Ratio of Choices")
            plt.xticks(ind, [f"Device {i+1}" for i in range(num_devices)])
            plt.legend()
            plt.ylim(0, 1) 
            plt.tight_layout()
            plt.show()
            plt.close()

        # --- Plot 5: Achievable Rate (Sub-6GHz) per Device ---
        plt.figure(figsize=(10, 3))
        for i, device_sub6_rate_data in enumerate(self.plot_data_sub6_rate_per_device):
            if device_sub6_rate_data:
                frames, rates = zip(*device_sub6_rate_data)
                plt.plot(frames, rates, label=f"Device {i+1}")
        
        plt.title(f"Achievable Rate (Sub-6GHz) - Frame {self.cur_frame}")
        plt.xlabel("Frame")
        plt.ylabel("Rate (Mbps)")
        plt.legend()
        plt.tight_layout()
        plt.show()
        plt.close()

        plt.figure(figsize=(10, 3))
        for i, device_mmwave_rate_data in enumerate(self.plot_data_mmwave_rate_per_device):
            if device_mmwave_rate_data:
                frames, rates = zip(*device_mmwave_rate_data)
                plt.plot(frames, rates, label=f"Device {i+1}")
        
        plt.title(f"Achievable Rate (mmWave) - Frame {self.cur_frame}")
        plt.xlabel("Frame")
        plt.ylabel("Rate (Mbps)")
        plt.legend()
        plt.tight_layout()
        plt.show()
        plt.close()

    def run_simulation(self):
        """
        Starts and runs the entire simulation process.
        """
        if not self.devices_data:
            print("No devices configured. Please load a JSON layout or manually add device data.")
            return

        self.setup_simulation() 
        
        nframes = self.global_config['nframes']
        print(f"\nStarting simulation for {nframes} frames...")

        for frame in range(nframes):
            self.cur_frame = frame + 1 
            self.simulation_step() 
        
        clear_output(wait=True) # Clear final live plot to show only saved plots and final log
        print("\nSimulation finished.")
        self.generate_and_save_outputs() 

    def generate_and_save_outputs(self):
        """
        Generates Matplotlib plots of simulation metrics and saves them as images.
        Also saves a comprehensive text log file. This is for the *final* plots.
        """
        timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = "/kaggle/working/output/"
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        
        plt.style.use('seaborn-v0_8-darkgrid') 

        # --- Plot 1: Metric (Delta P) over Frames ---
        plt.figure(figsize=(10, 6))
        if self.plot_data_metric:
            frames, metrics = zip(*self.plot_data_metric)
            plt.plot(frames, metrics, label="Metric (ΔP)", color='blue')
        plt.title("Final Metric (ΔP) over Frames")
        plt.xlabel("Frame")
        plt.ylabel("Metric (ΔP)")
        plt.legend()
        plt.tight_layout()
        metric_filename = os.path.join(output_dir, f"{timestamp_str}_metric_{SUFFIX_STR}.png")
        plt.savefig(metric_filename)
        plt.close() 
        print(f"Saved figure: {metric_filename}")

        # --- Plot 2: Average Reward per Device over Frames ---
        plt.figure(figsize=(10, 6))
        if self.plot_data_reward:
            frames, rewards = zip(*self.plot_data_reward)
            plt.plot(frames, rewards, label="Average Reward per Device", color='green')
        plt.title("Final Average Reward per Device over Frames")
        plt.xlabel("Frame")
        plt.ylabel("Average Reward")
        plt.legend()
        plt.tight_layout()
        reward_filename = os.path.join(output_dir, f"{timestamp_str}_reward_{SUFFIX_STR}.png")
        plt.savefig(reward_filename)
        plt.close()
        print(f"Saved figure: {reward_filename}")

        # --- Plot 3: Packet Loss Rate (PLR) per Device over Frames ---
        plt.figure(figsize=(10, 6))
        for i, device_plr_data in enumerate(self.plot_data_plr_per_device):
            if device_plr_data:
                frames, plrs = zip(*device_plr_data)
                plt.plot(frames, plrs, label=f"Device {i+1} PLR")
        
        if self.qlearn and hasattr(self.qlearn, 'PLR_req'):
            plt.axhline(y=self.qlearn.PLR_req, color='dimgray', linestyle='--', linewidth=1.2, label='PLR Max (Target)')
        else: 
            plt.axhline(y=0.1, color='dimgray', linestyle='--', linewidth=1.2, label='Default PLR Max (0.1)')

        plt.title("Final Packet Loss Rate (PLR) per Device")
        plt.xlabel("Frame")
        plt.ylabel("PLR")
        plt.legend()
        plt.tight_layout()
        plr_filename = os.path.join(output_dir, f"{timestamp_str}_plr_{SUFFIX_STR}.png")
        plt.savefig(plr_filename)
        plt.close()
        print(f"Saved figure: {plr_filename}")

        # --- Plot 4: Final Choice Distribution Bar Graph ---
        plt.figure(figsize=(10, 6))
        num_devices = len(self.devices_data)
        
        if num_devices > 0:
            ind = np.arange(num_devices) 
            width = 0.2 
            
            final_choice_counts_array = np.array(self.choice_counts)
            total_choices_per_device = final_choice_counts_array.sum(axis=1) 
            
            normalized_counts = np.zeros_like(final_choice_counts_array, dtype=float)
            for i in range(num_devices):
                if total_choices_per_device[i] > 0:
                    normalized_counts[i, :] = final_choice_counts_array[i, :] / total_choices_per_device[i]
            
            plt.bar(ind - width, normalized_counts[:, 0], width, label="Sub-6GHz", color='skyblue')
            plt.bar(ind, normalized_counts[:, 1], width, label="mmWave", color='lightcoral')
            plt.bar(ind + width, normalized_counts[:, 2], width, label="Idle", color='lightgreen')

            plt.title("Final Choice Distribution (Ratio of Total Choices)")
            plt.xlabel("Device")
            plt.ylabel("Ratio of Choices")
            plt.xticks(ind, [f"Device {i+1}" for i in range(num_devices)])
            plt.legend()
            plt.ylim(0, 1) 
            plt.tight_layout()
            choice_filename = os.path.join(output_dir, f"{timestamp_str}_choice_distribution_{SUFFIX_STR}.png")
            plt.savefig(choice_filename)
            plt.close()
            print(f"Saved figure: {choice_filename}")
        else:
            print("Skipping Choice Distribution plot: No devices found.")

        # --- Plot 5: Final Achievable Rate (Sub-6GHz) per Device ---
        plt.figure(figsize=(10, 6))
        for i, device_sub6_rate_data in enumerate(self.plot_data_sub6_rate_per_device):
            if device_sub6_rate_data:
                frames, rates = zip(*device_sub6_rate_data)
                plt.plot(frames, rates, label=f"Device {i+1} Sub-6GHz Rate")
        
        plt.title("Final Achievable Rate (Sub-6GHz) per Device")
        plt.xlabel("Frame")
        plt.ylabel("Rate (Mbps)")
        plt.legend()
        plt.tight_layout()
        sub6_rate_filename = os.path.join(output_dir, f"{timestamp_str}_sub6_rate_{SUFFIX_STR}.png")
        plt.savefig(sub6_rate_filename)
        plt.close()
        print(f"Saved figure: {sub6_rate_filename}")

        # --- Plot 6: Final Achievable Rate (mmWave) per Device ---
        plt.figure(figsize=(10, 6))
        for i, device_mmwave_rate_data in enumerate(self.plot_data_mmwave_rate_per_device):
            if device_mmwave_rate_data:
                frames, rates = zip(*device_mmwave_rate_data)
                plt.plot(frames, rates, label=f"Device {i+1} mmWave Rate")
        
        plt.title("Final Achievable Rate (mmWave) per Device")
        plt.xlabel("Frame")
        plt.ylabel("Rate (Mbps)")
        plt.legend()
        plt.tight_layout()
        mmwave_rate_filename = os.path.join(output_dir, f"{timestamp_str}_mmwave_rate_{SUFFIX_STR}.png")
        plt.savefig(mmwave_rate_filename)
        plt.close()
        print(f"Saved figure: {mmwave_rate_filename}")

        # --- Save Simulation Log File ---
        log_filepath = os.path.join(output_dir, f'{timestamp_str}_log_{SUFFIX_STR}.txt')
        with open(log_filepath, 'w+') as f:
            f.write(f'--- Simulation Log ---\n')
            f.write(f'Timestamp: {timestamp_str}\n')
            f.write(f'Q-Learning Model: {"RiskAverseQLearning" if USE_RISK_AVERSE else "RiskAverseDeepQLearning" if USE_DEEP_RISK else "DeepQLearning"}\n')
            f.write(f'Number of Devices: {len(self.devices_data)}\n')
            f.write(f'Number of Blockages: {len(self.blockages_data)}\n')
            f.write(f'Simulated Frames: {self.global_config["nframes"]}\n')
            f.write(f'\n--- Global Configuration ---\n')
            for k, v in self.global_config.items():
                f.write(f'{k}: {v}\n')
            f.write(f'\n--- Device Details ---\n')
            for i, dev in enumerate(self.devices_data):
                f.write(f'Device {i+1}: Position={dev["pos"]}, Distance={self.distances[i]:.2f}m, Blockages={self.nblocks[i]}\n')
            f.write(f'\n--- Final Simulation Results ---\n')
            
            final_avg_plr = sum(self.plr) / len(self.devices_data) if len(self.devices_data) > 0 else 0
            f.write(f'Final Average PLR across all devices: {final_avg_plr:.4f}\n')
            f.write(f'PLR for each device: [ {" ".join(f"{x:.4f}" for x in self.plr)} ]\n')
            
            avg_success = (len(self.devices_data) - sum(self.plr)) / len(self.devices_data) if len(self.devices_data) > 0 else 0
            f.write(f'Average Success Rate across all devices: {avg_success:.4f}\n')
            
            final_avg_reward_per_frame_per_device = self.plot_data_reward[-1][1] if self.plot_data_reward else 0
            f.write(f'Final Average Reward per Device per Frame: {final_avg_reward_per_frame_per_device:.4f}\n')
            
            final_metric = self.MetricList[-1] if self.MetricList else 0
            f.write(f'Final Metric (ΔP): {final_metric:.4f}\n')

            f.write(f'\n--- Achievable Rates (Mbps) ---\n')
            if len(self.devices_data) > 0:
                avg_sub6_rates = [np.mean([rate for _, rate in dev_data]) for dev_data in self.plot_data_sub6_rate_per_device if dev_data]
                avg_mmwave_rates = [np.mean([rate for _, rate in dev_data]) for dev_data in self.plot_data_mmwave_rate_per_device if dev_data]
                
                f.write(f'Average Sub-6GHz Rates per device: [ {" ".join(f"{x:.2f}" for x in avg_sub6_rates)} ] Mbps\n')
                f.write(f'Average mmWave Rates per device: [ {" ".join(f"{x:.2f}" for x in avg_mmwave_rates)} ] Mbps\n')
                
                overall_avg_sub6_rate = np.mean(avg_sub6_rates) if avg_sub6_rates else 0
                overall_avg_mmwave_rate = np.mean(avg_mmwave_rates) if avg_mmwave_rates else 0
                f.write(f'Overall Average Sub-6GHz Rate: {overall_avg_sub6_rate:.2f} Mbps\n')
                f.write(f'Overall Average mmWave Rate: {overall_avg_mmwave_rate:.2f} Mbps\n')
            else:
                f.write(f'No device data available for rate calculations.\n')
            
            f.write(f'\n--- Data per Frame (Metric | Reward) ---\n')
            for f_idx in range(len(self.MetricList)):
                f.write(f'Frame {f_idx+1}: Metric={self.MetricList[f_idx]:.4f} | Reward={self.RewList[f_idx]:.4f}\n')
                if self.plot_data_sub6_rate_per_device and self.plot_data_mmwave_rate_per_device:
                    f.write(f'  Rates (Mbps): ')
                    for d_idx in range(len(self.devices_data)):
                        sub6_r = self.plot_data_sub6_rate_per_device[d_idx][f_idx][1] if f_idx < len(self.plot_data_sub6_rate_per_device[d_idx]) else 0
                        mmwave_r = self.plot_data_mmwave_rate_per_device[d_idx][f_idx][1] if f_idx < len(self.plot_data_mmwave_rate_per_device[d_idx]) else 0
                        f.write(f'Device {d_idx+1}(S6:{sub6_r:.2f}, MM:{mmwave_r:.2f}) ')
                    f.write('\n')
        print(f"Saved log: {log_filepath}")

In [16]:
# GPU check: prints True when PyTorch reports a usable CUDA device, False otherwise.
print(torch.cuda.is_available())

True


In [17]:
# JSON scenario definitions; both strings are written to /kaggle/working/input/ at the end of this cell.

# one_json: 3 devices and 1 blockage, each with a two-element "pos" list, plus a
# "settings" object (4 sub-channels, 4 beams, 10000 frames).
# NOTE(review): "bw_sub6"/"bw_mm" presumably end up in the simulator's
# global_config, where they are read as MHz — confirm against the config loader.
one_json="""{
  "devices": [
    {
      "pos": [
        0,
        -20
      ]
    },
    {
      "pos": [
        20,
        0
      ]
    },
    {
      "pos": [
        -60,
        60
      ]
    }
  ],
  "blockages": [
    {
      "pos": [
        10,
        0
      ]
    }
  ],
  "settings": {
    "pwr": 5.0,
    "noise": -169.0,
    "bw_sub6": 100.0,
    "bw_mm": 1000.0,
    "num_subchannel": 4,
    "num_beam": 4,
    "nframes": 10000
  }
}"""

# two_json: larger scenario with 10 devices and 2 blockages; its "settings" match
# one_json except for 16 sub-channels and 16 beams.
two_json="""{
  "devices": [
    {
      "pos": [
        0,
        -20
      ]
    },
    {
      "pos": [
        20,
        0
      ]
    },
    {
      "pos": [
        -60,
        60
      ]
    },
    {
      "pos": [
        -40,
        -40
      ]
    },
    {
      "pos": [
        15,
        60
      ]
    },
    {
      "pos": [
        -20,
        20
      ]
    },
    {
      "pos": [
        -40,
        -10
      ]
    },
    {
      "pos": [
        55,
        -55
      ]
    },
    {
      "pos": [
        50,
        0
      ]
    },
    {
      "pos": [
        50,
        40
      ]
    }
  ],
  "blockages": [
    {
      "pos": [
        10,
        0
      ]
    },
    {
      "pos": [
        -10,
        10
      ]
    }
  ],
  "settings": {
    "pwr": 5.0,
    "noise": -169.0,
    "bw_sub6": 100.0,
    "bw_mm": 1000.0,
    "num_subchannel": 16,
    "num_beam": 16,
    "nframes": 10000
  }
}"""

# Write both scenarios to disk so the simulator can load them by path.
input_dir = "/kaggle/working/input/"
# exist_ok=True keeps the cell re-runnable and avoids the exists()/makedirs() race.
os.makedirs(input_dir, exist_ok=True)
for file_name, payload in (("1.json", one_json), ("2.json", two_json)):
    with open(os.path.join(input_dir, file_name), "w", encoding="utf-8") as f:
        f.write(payload)

In [18]:


# Scenario file to simulate (2.json, i.e. the two_json layout written earlier).
json_file_path = "/kaggle/working/input/2.json"

# Number of independent repetitions; each one builds a fresh simulator from the
# same configuration file and saves its own timestamped outputs.
NUM_RUNS = 5

# A named loop variable is used instead of `_`, because assigning `_` at the
# top level of a notebook overwrites IPython's last-output variable.
for run_index in range(NUM_RUNS):
    simulator = NetworkSimulator(json_config_path=json_file_path)
    simulator.run_simulation()


Simulation finished.
Saved figure: /kaggle/working/output/20250620_065915_metric_DEEP.png
Saved figure: /kaggle/working/output/20250620_065915_reward_DEEP.png
Saved figure: /kaggle/working/output/20250620_065915_plr_DEEP.png
Saved figure: /kaggle/working/output/20250620_065915_choice_distribution_DEEP.png
Saved figure: /kaggle/working/output/20250620_065915_sub6_rate_DEEP.png
Saved figure: /kaggle/working/output/20250620_065915_mmwave_rate_DEEP.png
Saved log: /kaggle/working/output/20250620_065915_log_DEEP.txt


In [19]:
# Save output: bundle /kaggle/working/output into /kaggle/working/output.zip.
# NOTE(review): the "updating:" lines in the recorded output suggest output.zip
# already existed, so entries from earlier runs are kept — confirm that is intended.
!zip -r /kaggle/working/output.zip /kaggle/working/output



updating: kaggle/working/output/ (stored 0%)
updating: kaggle/working/output/20250620_055646_sub6_rate_DEEPRA.png (deflated 2%)
updating: kaggle/working/output/20250620_060650_log_DEEPRA.txt (deflated 77%)
updating: kaggle/working/output/20250620_061152_metric_DEEPRA.png (deflated 17%)
updating: kaggle/working/output/20250620_060650_mmwave_rate_DEEPRA.png (deflated 4%)
updating: kaggle/working/output/20250620_060153_plr_DEEPRA.png (deflated 9%)
updating: kaggle/working/output/20250620_061152_sub6_rate_DEEPRA.png (deflated 3%)
updating: kaggle/working/output/20250620_061152_choice_distribution_DEEPRA.png (deflated 22%)
updating: kaggle/working/output/20250620_060650_sub6_rate_DEEPRA.png (deflated 3%)
updating: kaggle/working/output/20250620_060153_choice_distribution_DEEPRA.png (deflated 22%)
updating: kaggle/working/output/20250620_060153_reward_DEEPRA.png (deflated 12%)
updating: kaggle/working/output/20250620_060153_mmwave_rate_DEEPRA.png (deflated 4%)
updating: kaggle/working/output

In [20]:
# NOTE(review): scratch cell — the original comment here was a run of numbers with
# no recoverable meaning. Presumably a manual time/keep-alive marker; confirm or delete.
print("23:17")

23:17
