In [19]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3 import PPO, DQN, A2C
from stable_baselines3.common.vec_env import DummyVecEnv
from gymnasium import spaces
import torch as th

with open('c_states_dict_v4.pkl', 'rb') as f:
    c_states = pickle.load(f)

def my_norm(vec):
    vec = vec-np.mean(vec)
    vec = vec/np.std(vec)
    return vec

# for snr in c_states.keys():
#     data = c_states[snr]
#     for episode in range(data.shape[0]):
        
print(c_states.keys())
print(c_states[5].shape)

dict_keys([5, 2.5, 0, -2.5, -5])
(2500, 18, 7, 120)


In [11]:


class FrequencyDetectionEnv(gym.Env):
    """Custom Environment for Frequency Objective Response Detection that follows gymnasium interface"""

    def __init__(self, data):
        super(FrequencyDetectionEnv, self).__init__()

        self.data = data
        self.episodes = data.shape[0]-1
        self.frequencies = data.shape[1]-1
        self.measures = data.shape[2]
        self.windows = data.shape[3]-1
        self.current_episode = 0
        self.current_window = 2
        self.current_frequency = 0

        # Action space: 0 (no detection) or 1 (detection)
        self.action_space = spaces.Discrete(2)

        # Observation space: 5 measures for each of the 4 frequencies (1 focus + 3 noise)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(4*self.measures,), dtype=np.float32)

        # Initialize the internal state
        self.state = None
        self.false_positives = 0
        self.true_positives = 0
        self.fp_rate = 0
        self.tp_rate = 0   
        self.tpr_hist = []
        self.fpr_hist = []
        self.latest_result = {'tp_rate':np.nan,'fp_rate':np.nan}

    def reset(self, seed=None, options=None):
        """Resets the environment to an initial state and returns the initial observation."""
        super().reset(seed=seed)
        
        if  self.current_episode< self.episodes:
            self.current_episode +=1
        else:
            self.current_episode = np.random.randint(0, self.episodes)

        self.current_window = 2
        self.current_frequency = 0
        self.false_positives = 0
        self.true_positives = 0
        self.fp_rate = 0
        self.tp_rate = 0        
         
        self.tpr_hist = []
        self.fpr_hist = []
        
        # Initialize the state
        self.state = self._get_state(self.current_frequency, self.current_window)
        return self.state, self.latest_result

    def _get_state(self, frequency, window, episode= 0):
        """Helper function to get the state at a given frequency and window."""
         
        if episode != 0:
            self.current_episode = episode

        focus_measure = self.data[self.current_episode, frequency, :, window]
        
        # Select 3 random noise frequencies, excluding the current focus frequency
        noise_indices = np.setdiff1d(np.arange(9, 17), frequency - 9 if frequency >= 9 else [])
        noise_frequencies = np.random.choice(noise_indices, 3, replace=False)
        noise_measures = [self.data[self.current_episode, nf, :, window] for nf in noise_frequencies]
         
        measures = np.concatenate(([focus_measure], noise_measures), axis=0).flatten()

        return np.array(measures, dtype=np.float32)

    def step(self, action, episode = -1):
        """Executes one time step within the environment."""
        # Returns:
        # tuple (observation, reward, terminated, truncated, info).
        # Return type:
        # Tuple[Tuple | Dict[str, Any] | ndarray | int, float, bool, bool, Dict]
        
        if episode != -1:
            self.current_episode = episode

        reward = 0
        kp_fp = 1
        fp_des = 0.05
        
        should_detect = self.current_frequency <= 8
        if action == 1:
            if should_detect:
                self.true_positives += 1
            else:
                self.false_positives += 1

        self.current_frequency += 1
        window_done = self.current_frequency >= self.frequencies
        epsiode_terminated = self.current_window >= self.windows
        info = {'tp_rate':np.nan,'fp_rate':np.nan}

        if window_done:
            self.state = self._get_state(self.current_frequency-1, self.current_window)

            current_ep_prog = 100
            
            self.tp_rate = 100*self.true_positives / (self.frequencies//2+1)
            self.fp_rate = 100*self.false_positives / (self.frequencies//2+1)
                
            reward = +( -(((100-fp_des)**2)/100 +(fp_des**2)/-100) +((self.fp_rate)**2)/(-100)+ ((self.tp_rate)**2)/(100) )*current_ep_prog/100
            
            self.tpr_hist.append(self.tp_rate)
            self.fpr_hist.append(self.fp_rate)

            if not(epsiode_terminated):
                self.current_window += 1
                self.current_frequency = 0
                self.true_positives = 0
                self.false_positives = 0
                    
        if epsiode_terminated:
            self.state = self._get_state(self.current_frequency, self.current_window)
            current_ep_prog =100

            self.tp_rate = 100*self.true_positives / (self.current_frequency+1)
            self.fp_rate = 100*self.false_positives / (self.current_frequency+1)

            reward = +( -(((100-fp_des)**2)/100+(fp_des**2)/-100)
                    +(kp_fp*(np.round(np.mean(self.fpr_hist),4))**2)/(-100)
                    +((np.round(np.mean(self.tpr_hist),4))**2)/(100))*current_ep_prog/100
           
            truncated = False

            info = {'tp_rate':np.round(np.mean(self.tpr_hist),2),'fp_rate':np.round(np.mean(self.fpr_hist),2)}
            self.latest_result = info

            self.tpr_hist = []
            self.fpr_hist = [] 
            self.current_window = 2
            self.current_frequency = 0
            self.true_positives = 0
            self.false_positives = 0

        else:
            self.state = self._get_state(self.current_frequency, self.current_window)

            should_detect = self.current_frequency <= 8
            tp = 0; fp = 0; fn =0; tn = 0
            if action == 1:
                if should_detect:
                    tp = 1
                else:
                    fp = 1
            else:
                if should_detect:
                    fn = 1
                else:
                    tn = 1

            reward = tp-kp_fp*fp-fn+tn

            truncated = False

        return self.state, reward, epsiode_terminated, truncated, info
    def render(self, mode='human'):
        """Render the environment for visualization purposes."""
        if mode=='human':
            print(f'Window: {self.current_window}, Frequency: {self.current_frequency},TP: {self.tp_rate}, FP: {self.fp_rate}')

    def close(self):
        """Clean up any resources used by the environment."""
        pass

In [3]:
class TEST_FrequencyDetectionEnv(gym.Env):
    """Custom Environment for Frequency Objective Response Detection that follows gymnasium interface"""

    def __init__(self, data):
        super(TEST_FrequencyDetectionEnv, self).__init__()

        self.data = data
        self.episodes = data.shape[0]-1
        self.frequencies = data.shape[1]-1
        self.measures = data.shape[2]
        self.windows = data.shape[3]-1
        self.current_episode = 0
        self.current_window = 2
        self.current_frequency = 0

        # Action space: 0 (no detection) or 1 (detection)
        self.action_space = spaces.Discrete(2)

        # Observation space: 5 measures for each of the 4 frequencies (1 focus + 3 noise)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(4*self.measures,), dtype=np.float32)

        # Initialize the internal state
        self.state = None
        self.false_positives = 0
        self.true_positives = 0
        self.fp_rate = 0
        self.tp_rate = 0   
        self.tpr_hist = []
        self.fpr_hist = []  
        self.latest_result = {'tp_rate':np.nan,'fp_rate':np.nan}   

    def reset(self, seed=None, options=None):
        """Resets the environment to an initial state and returns the initial observation."""
        super().reset(seed=seed)
        
        if  self.current_episode< self.episodes:
            self.current_episode +=1
        else:
            self.current_episode = np.random.randint(0, self.episodes)

        self.current_window = 2
        self.current_frequency = 0
        self.false_positives = 0
        self.true_positives = 0
        self.fp_rate = 0
        self.tp_rate = 0        
         
        self.tpr_hist = []
        self.fpr_hist = []
        
        # Initialize the state
        self.state = self._get_state(self.current_frequency, self.current_window)
        return self.state, self.latest_result

    def _get_state(self, frequency, window, episode= 0):
        """Helper function to get the state at a given frequency and window."""
         
        if episode != 0:
            self.current_episode = episode
         
        # print(frequency)
        focus_measure = self.data[self.current_episode, frequency, :, window]
        
        # Select 3 random noise frequencies, excluding the current focus frequency
        noise_indices = np.setdiff1d(np.arange(9, 17), frequency - 9 if frequency >= 9 else [])
        noise_frequencies = np.random.choice(noise_indices, 3, replace=False)
        noise_measures = [self.data[self.current_episode, nf, :, window] for nf in noise_frequencies]
         
        measures = np.concatenate(([focus_measure], noise_measures), axis=0).flatten()
       
        return np.array(measures, dtype=np.float32)

    def step(self, action, episode = -1):
        """Executes one time step within the environment."""
        # Returns:
        # tuple (observation, reward, terminated, truncated, info).
        # Return type:
        # Tuple[Tuple | Dict[str, Any] | ndarray | int, float, bool, bool, Dict]

        if episode != -1:
            self.current_episode = episode

        reward = 0
        fp_des = 0.05

        should_detect = self.current_frequency <= 8
        if action == 1:
            if should_detect:
                self.true_positives += 1
            else:
                self.false_positives += 1

        self.current_frequency += 1
        window_done = self.current_frequency >= self.frequencies
        epsiode_terminated = self.current_window >= self.windows
        info = {'tp_rate':np.nan,'fp_rate':np.nan}
        if window_done:
            self.state = self._get_state(self.current_frequency-1, self.current_window)

            should_detect = self.current_frequency <= 8
            tp = 0; fp = 0; fn =0; tn = 0
            if action == 1:
                if should_detect:
                    tp = 0
                else:
                    fp = 1
            else:
                if should_detect:
                    fn = 0
                else:
                    tn = 1

            reward = tp-fp-fn+tn


            self.tp_rate = 100*self.true_positives / (self.frequencies//2+1)
            self.fp_rate = 100*self.false_positives / (self.frequencies//2+1)

            self.tpr_hist.append(self.tp_rate)
            self.fpr_hist.append(self.fp_rate)

            if not(epsiode_terminated):
                self.current_window += 1
                self.current_frequency = 0
                self.true_positives = 0
                self.false_positives = 0
                    
        if epsiode_terminated:
            self.state = self._get_state(self.current_frequency, self.current_window)

            self.tp_rate = 100*self.true_positives / (self.current_frequency+1)
            self.fp_rate = 100*self.false_positives / (self.current_frequency+1)

            should_detect = self.current_frequency <= 8
            tp = 0; fp = 0; fn =0; tn = 0
            if action == 1:
                if should_detect:
                    tp = 0
                else:
                    fp = 1
            else:
                if should_detect:
                    fn = 0
                else:
                    tn = 1

            reward = tp-fp-fn+tn
           
            truncated = False

            info = {'tp_rate':np.round(np.mean(self.tpr_hist),2),'fp_rate':np.round(np.mean(self.fpr_hist),2)}
            self.latest_result = info

            self.tpr_hist = []
            self.fpr_hist = [] 
            self.current_window = 2
            self.current_frequency = 0
            self.true_positives = 0
            self.false_positives = 0

        else:
            self.state = self._get_state(self.current_frequency, self.current_window)

            should_detect = self.current_frequency <= 8
            tp = 0; fp = 0; fn =0; tn = 0
            if action == 1:
                if should_detect:
                    tp = 0
                else:
                    fp = 1
            else:
                if should_detect:
                    fn = 0
                else:
                    tn = 1

            reward = tp-fp-fn+tn

            truncated = False

        reward = tp-fp-fn+tn

        return self.state, reward, epsiode_terminated, truncated, info

    def render(self, mode='human'):
        """Render the environment for visualization purposes."""
        if mode =='human':
            print(f'Window: {self.current_window}, Frequency: {self.current_frequency},TP: {self.tp_rate}, FP: {self.fp_rate}')

    def close(self):
        """Clean up any resources used by the environment."""
        pass


In [75]:
snr = 0
data = c_states[snr]
timesteps = (120-2)*(18-1)*100
print(f'Total decisions: {timesteps}')
print(data.shape[2]+1)

Total decisions: 200600
8


In [64]:
# Train PPO model
# policy_kwargs = dict(activation_fn=th.nn.ReLU,
#                      net_arch=dict(data.shape[2]+1, vf=[data.shape[2]+1])
#                      )
# net_arch=dict(pi=[25,25], vf=[25,25]))
env = FrequencyDetectionEnv(data)
model_ppo = PPO('MlpPolicy', env, verbose=1#, learning_rate=5*1e-4,gamma=0.9,
            #exploration_fraction= 0.1, exploration_initial_eps = 1, exploration_final_eps = 0.05
            ,policy_kwargs=dict(activation_fn=th.nn.ReLU,
                     net_arch=dict(pi =[data.shape[2]+1], vf=[data.shape[2]+1]))
                )
model_ppo.learn(total_timesteps=timesteps)
model_ppo.save('mini_ppo_snr5.zip')
_, latest = env.reset()
print(latest)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1.99e+03 |
|    ep_rew_mean     | -94.6    |
| time/              |          |
|    fps             | 1005     |
|    iterations      | 1        |
|    time_elapsed    | 2        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 1.99e+03    |
|    ep_rew_mean          | -64         |
| time/                   |             |
|    fps                  | 841         |
|    iterations           | 2           |
|    time_elapsed         | 4           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.010760534 |
|    clip_fraction        | 0.0954      |
|    clip_range           | 0.2         |
|    entropy_loss   

KeyboardInterrupt: 

In [76]:
# Train DQN model
env = FrequencyDetectionEnv(data)
model_dqn = DQN('MlpPolicy', env, verbose=1#, learning_rate=5*1e-4,gamma=0.9,
            #exploration_fraction= 0.1, exploration_initial_eps = 1, exploration_final_eps = 0.05
            ,policy_kwargs=dict(activation_fn=th.nn.ReLU,net_arch=[data.shape[2]+1])
                )
model_dqn.learn(total_timesteps=timesteps)
model_dqn.save('mini_dqn_snr0.zip')
_, latest = env.reset()
print(latest)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.99e+03 |
|    ep_rew_mean      | -69.6    |
|    exploration_rate | 0.623    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 932      |
|    time_elapsed     | 8        |
|    total_timesteps  | 7960     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.992    |
|    n_updates        | 1964     |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.99e+03 |
|    ep_rew_mean      | 116      |
|    exploration_rate | 0.246    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 910      |
|    time_elapsed     | 17       |
|    total_timesteps  | 15920    |
| train/              |        

In [12]:
env = FrequencyDetectionEnv(data)
model_a2c = A2C('MlpPolicy', env, verbose=1#, learning_rate=5*1e-4,gamma=0.9,
            #exploration_fraction= 0.1, exploration_initial_eps = 1, exploration_final_eps = 0.05
            ,policy_kwargs=dict(activation_fn=th.nn.ReLU,net_arch=[data.shape[2]+1])
            )
model_a2c.learn(total_timesteps=timesteps)
model_a2c.save('mini_a2c_snr5.zip')
_, latest = env.reset()
print(latest)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
------------------------------------
| time/                 |          |
|    fps                | 583      |
|    iterations         | 100      |
|    time_elapsed       | 0        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -0.678   |
|    explained_variance | -0.211   |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | -0.209   |
|    value_loss         | 0.706    |
------------------------------------
------------------------------------
| time/                 |          |
|    fps                | 580      |
|    iterations         | 200      |
|    time_elapsed       | 1        |
|    total_timesteps    | 1000     |
| train/                |          |
|    entropy_loss       | -0.666   |
|    explained_variance | 0.284    |
|    learning_rate      | 0.0007   |
|    n_updates    

# EVAL 2 (exp data, filtered)

In [72]:
class EVAL_FrequencyDetectionEnv(gym.Env):
    """Custom Environment for Frequency Objective Response Detection that follows gymnasium interface"""

    def __init__(self, data):
        super(EVAL_FrequencyDetectionEnv, self).__init__()

        self.data = data
        self.episodes =  data.shape[2]-1
        self.frequencies = data.shape[0]-1
        self.measures = data.shape[1]
        self.windows = data.shape[2]-1
        self.current_episode = 0
        self.current_window = 2
        self.current_frequency = 0

        # Action space: 0 (no detection) or 1 (detection)
        self.action_space = spaces.Discrete(2)

        # Observation space: 5 measures for each of the 4 frequencies (1 focus + 3 noise)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(4*self.measures,), dtype=np.float32)

        # Initialize the internal state
        self.state = None
        self.false_positives = 0
        self.true_positives = 0
        self.fp_rate = 0
        self.tp_rate = 0   
        self.tpr_hist = []
        self.fpr_hist = []  
        self.latest_result = {'tp_rate':np.nan,'fp_rate':np.nan}   

    def reset(self, seed=None, options=None):
        """Resets the environment to an initial state and returns the initial observation."""
        super().reset(seed=seed)
        
        if  self.current_episode< self.episodes:
            self.current_episode +=1
        else:
            self.current_episode = np.random.randint(0, self.episodes)

        self.current_window = 2
        self.current_frequency = 0
        self.false_positives = 0
        self.true_positives = 0
        self.fp_rate = 0
        self.tp_rate = 0        
         
        self.tpr_hist = []
        self.fpr_hist = []
        
        # Initialize the state
        self.state = self._get_state(self.current_frequency, self.current_window)
        return self.state, self.latest_result

    def _get_state(self, frequency, window, episode= 0):
        """Helper function to get the state at a given frequency and window."""
         
        if episode != 0:
            self.current_episode = episode
         
        # print(frequency)
        focus_measure = self.data[frequency, :, window]
        
        # Select 3 random noise frequencies, excluding the current focus frequency
        noise_indices = np.setdiff1d(np.arange(9, 17), frequency - 9 if frequency >= 9 else [])
        noise_frequencies = np.random.choice(noise_indices, 3, replace=False)
        noise_measures = [self.data[nf, :, window] for nf in noise_frequencies]
         
        # measures = my_norm(np.concatenate(([focus_measure], noise_measures), axis=0).flatten())
        measures = np.concatenate(([focus_measure], noise_measures), axis=0).flatten()
       
        return np.array(measures, dtype=np.float32)

    def step(self, action, episode = -1):
        """Executes one time step within the environment."""
        # Returns:
        # tuple (observation, reward, terminated, truncated, info).
        # Return type:
        # Tuple[Tuple | Dict[str, Any] | ndarray | int, float, bool, bool, Dict]

        if episode != -1:
            self.current_episode = episode

        reward = 0
        fp_des = 0.05

        should_detect = self.current_frequency <= 8
        if action == 1:
            if should_detect:
                self.true_positives += 1
            else:
                self.false_positives += 1

        self.current_frequency += 1
        window_done = self.current_frequency >= self.frequencies
        epsiode_terminated = self.current_window >= self.windows
        info = {'tp_rate':np.nan,'fp_rate':np.nan}
        if window_done:
            self.state = self._get_state(self.current_frequency-1, self.current_window)

            should_detect = self.current_frequency <= 8
            tp = 0; fp = 0; fn =0; tn = 0
            if action == 1:
                if should_detect:
                    tp = 0
                else:
                    fp = 1
            else:
                if should_detect:
                    fn = 0
                else:
                    tn = 1

            reward = tp-fp-fn+tn


            self.tp_rate = 100*self.true_positives / (self.frequencies//2+1)
            self.fp_rate = 100*self.false_positives / (self.frequencies//2+1)

            self.tpr_hist.append(self.tp_rate)
            self.fpr_hist.append(self.fp_rate)

            if not(epsiode_terminated):
                self.current_window += 1
                self.current_frequency = 0
                self.true_positives = 0
                self.false_positives = 0
                    
        if epsiode_terminated:
            self.state = self._get_state(self.current_frequency, self.current_window)

            self.tp_rate = 100*self.true_positives / (self.current_frequency+1)
            self.fp_rate = 100*self.false_positives / (self.current_frequency+1)

            should_detect = self.current_frequency <= 8
            tp = 0; fp = 0; fn =0; tn = 0
            if action == 1:
                if should_detect:
                    tp = 0
                else:
                    fp = 1
            else:
                if should_detect:
                    fn = 0
                else:
                    tn = 1

            reward = tp-fp-fn+tn
           
            truncated = False

            info = {'tp_rate':np.round(np.mean(self.tpr_hist),2),'fp_rate':np.round(np.mean(self.fpr_hist),2), 
                    'tp_hist:':self.tpr_hist,'fp_hist:':self.fpr_hist}
            
            # print(info)
            self.latest_result = info

            self.tpr_hist = []
            self.fpr_hist = [] 
            self.current_window = 2
            self.current_frequency = 0
            self.true_positives = 0
            self.false_positives = 0

        else:
            self.state = self._get_state(self.current_frequency, self.current_window)

            should_detect = self.current_frequency <= 8
            tp = 0; fp = 0; fn =0; tn = 0
            if action == 1:
                if should_detect:
                    tp = 0
                else:
                    fp = 1
            else:
                if should_detect:
                    fn = 0
                else:
                    tn = 1

            reward = tp-fp-fn+tn

            truncated = False

        reward = tp-fp-fn+tn

        return self.state, reward, epsiode_terminated, truncated, info

    def render(self, mode='human'):
        """Render the environment for visualization purposes."""
        if mode =='human':
            print(f'Window: {self.current_window}, Frequency: {self.current_frequency},TP: {self.tp_rate}, FP: {self.fp_rate}')

    def close(self):
        """Clean up any resources used by the environment."""
        pass

In [14]:
with open('c_states_dict_exp_filt.pkl', 'rb') as f:
    c_states_exp_filt = pickle.load(f)

print(c_states_exp_filt.keys())
print(c_states_exp_filt[1,1].shape)

dict_keys([(1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (2, 1), (2, 2), (2, 3), (2, 4), (2, 5), (3, 1), (3, 2), (3, 3), (3, 4), (3, 5), (4, 1), (4, 2), (4, 3), (4, 4), (4, 5), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (6, 1), (6, 2), (6, 3), (6, 4), (6, 5), (7, 1), (7, 2), (7, 3), (7, 4), (7, 5), (8, 1), (8, 2), (8, 3), (8, 4), (8, 5), (9, 1), (9, 2), (9, 3), (9, 4), (9, 5), (10, 1), (10, 2), (10, 3), (10, 4), (10, 5), (11, 1), (11, 2), (11, 3), (11, 4), (11, 5)])
(18, 7, 58)


In [33]:
print((data.shape[2]+1)*(data.shape[0]))

1062


In [None]:

for ivol in range(1,12):
    for iint in range(1,6):
        c_states_exp_filt[ivol,iint] = np.nan_to_num(c_states_exp_filt[ivol,iint], nan=0, posinf=10, neginf=-10)

# Validate the trained model
intensidades = ['70','60','50','40','30']

for iint in range(1,6):
    print(f'Intensidade: {intensidades[iint-1]} dB')
    for ivol in range(1,12):
        
        data = c_states_exp_filt[ivol,iint] 
        print(data.shape)
        env = EVAL_FrequencyDetectionEnv(data)
        obs, info = env.reset()
        # model = PPO.load('ppo_snr0-5.zip', env=env)
        # model = DQN.load('dqn_snr0-5.zip', env=env)
        
        # model = DQN.load('dqn_snr0_306000steps_308eps.zip', env=env)

        model = PPO.load('mini_ppo_snr5.zip', env=env)

        # model.learn(total_timesteps=(data.shape[2]-1-5)*(data.shape[0]-1))
        # model.learn(total_timesteps=1)
        # obs, info = env.reset()
        # print(info)
        
        for test_episode in range((data.shape[2]-2)*(data.shape[0]-1)):
            action, _states = model.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, info = env.step(action)

            if terminated:
                obs, info = env.reset()
                print(info)
            env.close()

In [78]:

for ivol in range(1,12):
    for iint in range(1,6):
        c_states_exp_filt[ivol,iint] = np.nan_to_num(c_states_exp_filt[ivol,iint], nan=0, posinf=10, neginf=-10)

# Validate the trained model
intensidades = ['70','60','50','40','30']
hist_tp = []
hist_fp = [] 

for iint in range(1,2):
    print(f'Intensidade: {intensidades[iint-1]} dB')
    for ivol in range(1,12):
        
        data = c_states_exp_filt[ivol,iint] 
        env = EVAL_FrequencyDetectionEnv(data)
        obs, info = env.reset()
        model = DQN.load('mini_dqn_snr5.zip', env=env)

        for test_episode in range((data.shape[2]-1)*(data.shape[0]-1)):
            action, _states = model.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, info = env.step(action)

            if terminated:
                obs, info = env.reset()
                hist_tp.append(info['tp_rate'])
                hist_fp.append(info['fp_rate'])
                # print(info)
            env.close()

print(f'TPR:{np.mean(hist_tp)}')
print(f'FPR:{np.mean(hist_fp)}')

Intensidade: 70 dB
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
TPR:85.56545454545453
FPR:30.101818181818178


In [17]:

for ivol in range(1,12):
    for iint in range(1,6):
        c_states_exp_filt[ivol,iint] = np.nan_to_num(c_states_exp_filt[ivol,iint], nan=0, posinf=10, neginf=-10)

# Validate the trained model
intensidades = ['70','60','50','40','30']

for iint in range(1,6):
    print(f'Intensidade: {intensidades[iint-1]} dB')
    for ivol in range(1,12):
        
        data = c_states_exp_filt[ivol,iint] 
        env = EVAL_FrequencyDetectionEnv(data)
        obs, info = env.reset()
        model = A2C.load('mini_a2c_snr5.zip', env=env)

        for test_episode in range((data.shape[2]-1)*(data.shape[0]-1)):
            action, _states = model.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, info = env.step(action)

            if terminated:
                obs, info = env.reset()
                print(info)
            env.close()

Intensidade: 70 dB
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
{'tp_rate': 98.79, 'fp_rate': 23.03}
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
{'tp_rate': 99.19, 'fp_rate': 36.36}
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
{'tp_rate': 98.99, 'fp_rate': 43.03}
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
{'tp_rate': 98.79, 'fp_rate': 25.05}
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
{'tp_rate': 98.79, 'fp_rate': 23.43}
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
{'tp_rate': 98.18, 'fp_rate': 46.26}
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
{'tp_rate': 97.58, 'fp_rate': 48.08}
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
{'tp_rate': 99.39, 'fp_rate': 23.03}
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a Dummy

KeyboardInterrupt: 