In [None]:
import math
import random
from collections import deque
from collections import namedtuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define transition tuple
# Immutable record for one environment step; field order matches the
# positional arguments accepted by PrioritizedReplayBuffer.push.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'reward', 'next_state', 'done'],
)

class NoisyLinear(nn.Module):
    """Linear layer with learnable, factorised Gaussian parameter noise.

    In training mode the effective parameters are ``mu + sigma * epsilon``;
    in eval mode only the ``mu`` parameters are used. The ``epsilon`` tensors
    are non-trainable buffers, resampled by :meth:`reset_noise`.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        std_init: scale used to initialise the ``sigma`` parameters.
    """

    def __init__(self, in_features, out_features, std_init=0.4):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.std_init = std_init

        weight_shape = (out_features, in_features)
        self.weight_mu = nn.Parameter(torch.empty(weight_shape))
        self.weight_sigma = nn.Parameter(torch.empty(weight_shape))
        self.register_buffer('weight_epsilon', torch.empty(weight_shape))

        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_sigma = nn.Parameter(torch.empty(out_features))
        self.register_buffer('bias_epsilon', torch.empty(out_features))

        self.reset_parameters()
        self.reset_noise()

    def reset_parameters(self):
        """Draw ``mu`` from U(-1/sqrt(in), 1/sqrt(in)) and set ``sigma`` to constants."""
        bound = 1 / math.sqrt(self.in_features)
        with torch.no_grad():
            self.weight_mu.uniform_(-bound, bound)
            self.weight_sigma.fill_(self.std_init / math.sqrt(self.in_features))
            self.bias_mu.uniform_(-bound, bound)
            self.bias_sigma.fill_(self.std_init / math.sqrt(self.out_features))

    def _scale_noise(self, size):
        """Return ``size`` standard-normal draws mapped through sign(x) * sqrt(|x|)."""
        noise = torch.randn(size)
        return torch.sign(noise) * torch.sqrt(torch.abs(noise))

    def reset_noise(self):
        """Resample the factorised noise: weight noise is outer(eps_out, eps_in)."""
        eps_in = self._scale_noise(self.in_features)
        eps_out = self._scale_noise(self.out_features)
        self.weight_epsilon.copy_(torch.outer(eps_out, eps_in))
        self.bias_epsilon.copy_(eps_out)

    def forward(self, x):
        weight, bias = self.weight_mu, self.bias_mu
        if self.training:
            # Perturb the mean parameters only while training.
            weight = weight + self.weight_sigma * self.weight_epsilon
            bias = bias + self.bias_sigma * self.bias_epsilon
        return F.linear(x, weight, bias)

class PrioritizedReplayBuffer:
    """Proportional prioritized replay buffer with n-step return aggregation.

    Raw transitions are first collected in a sliding window of length
    ``n_step``. Once the window is full, each push stores the oldest state and
    action of the window together with the discounted n-step reward, the
    bootstrapping ``next_state`` and the ``done`` flag. Entry ``i`` is sampled
    with probability ``p_i ** alpha / sum_k p_k ** alpha``. Each sample carries
    the importance-sampling weight ``(N * P(i)) ** -beta``, normalised by the
    batch maximum.

    Args:
        capacity: maximum number of stored (n-step) transitions; once full the
            oldest slot is overwritten (ring buffer).
        alpha: priority exponent.
        beta: initial importance-sampling exponent, annealed towards 1.
        beta_increment: amount added to ``beta`` after every ``sample`` call.
        n_step: length of the return window.
        gamma: discount factor used inside the n-step return.
    """

    def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, n_step=3, gamma=0.99):
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.beta_increment = beta_increment
        # One priority per buffer slot; slots not yet filled stay at 0.
        self.priorities = np.zeros((capacity,), dtype=np.float32)
        self.buffer = []
        self.position = 0  # next slot to write (ring-buffer cursor)
        # Added to every updated priority so a zero loss still leaves a
        # non-zero sampling probability.
        self.eps = 1e-5

        # N-step learning: window of the most recent raw transitions.
        self.n_step = n_step
        self.gamma = gamma
        self.n_step_buffer = deque(maxlen=n_step)

    def _get_n_step_info(self):
        """Fold the n-step window into ``(reward, next_state, done)``.

        The window is walked backwards, so a ``done`` inside it truncates the
        return, and that step's ``next_state``/``done`` become the target. For
        the same reason the window does not need clearing between episodes:
        transitions that straddle an episode boundary still yield correctly
        truncated returns.
        """
        reward, next_state, done = self.n_step_buffer[-1][-3:]

        for transition in reversed(list(self.n_step_buffer)[:-1]):
            r, n_s, d = transition[-3:]

            reward = r + self.gamma * reward * (1 - d)
            next_state = n_s if d else next_state
            done = d if d else done

        return reward, next_state, done

    def push(self, *args):
        """Add one raw transition ``(state, action, reward, next_state, done)``.

        Nothing reaches the main buffer until the n-step window is full. After
        that, every call stores one aggregated transition starting at the
        oldest state in the window, with the current maximum priority.
        """
        self.n_step_buffer.append(Transition(*args))

        if len(self.n_step_buffer) < self.n_step:
            return

        reward, next_state, done = self._get_n_step_info()
        state, action = self.n_step_buffer[0][:2]
        transition = Transition(state, action, reward, next_state, done)

        # New entries get the highest priority seen so far.
        max_priority = self.priorities.max() if self.buffer else 1.0

        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self.position] = transition

        self.priorities[self.position] = max_priority
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Sample a batch in proportion to ``priority ** alpha``.

        When fewer than ``batch_size`` entries are stored, every stored entry
        is returned exactly once, in slot order.

        Returns:
            ``(samples, indices, weights)``:
            - ``samples`` is a list of stored transitions.
            - ``indices`` is an integer array of buffer slots, to be passed
              back to :meth:`update_priorities`.
            - ``weights`` is a float tensor of importance-sampling weights.

        Raises:
            ValueError: if the buffer is empty.
        """
        size = len(self.buffer)
        if size == 0:
            raise ValueError("cannot sample from an empty replay buffer")

        # Needed for the importance-sampling weights in both branches, not
        # only for prioritized sampling.
        probabilities = self.priorities[:size] ** self.alpha
        probabilities /= probabilities.sum()

        if size < batch_size:
            indices = np.arange(size)
        else:
            indices = np.random.choice(size, batch_size, p=probabilities)

        weights = (size * probabilities[indices]) ** (-self.beta)
        weights /= weights.max()

        self.beta = min(1.0, self.beta + self.beta_increment)

        samples = [self.buffer[idx] for idx in indices]

        return samples, indices, torch.FloatTensor(weights)

    def update_priorities(self, indices, priorities):
        """Overwrite the priorities of the given slots with ``priority + eps``."""
        # Plain loop on purpose: ``indices`` may contain duplicates (sampling
        # is with replacement), and NumPy fancy assignment does not guarantee
        # which duplicate wins.
        for idx, priority in zip(indices, priorities):
            self.priorities[idx] = priority + self.eps

    def __len__(self):
        return len(self.buffer)

class RainbowDQN(nn.Module):
    """Dueling, categorical (C51-style) Q-network with noisy dueling heads.

    ``forward`` returns, for every action, a probability distribution over
    ``atom_size`` fixed return values (``self.support``) spaced evenly
    between ``v_min`` and ``v_max``.

    Args:
        state_dim: length of the flat state vector.
        action_dim: number of discrete actions.
        atom_size: number of support atoms.
        v_min: smallest support value.
        v_max: largest support value.
    """

    def __init__(self, state_dim, action_dim, atom_size=51, v_min=-10, v_max=10):
        super(RainbowDQN, self).__init__()

        # Read by forward(); previously never assigned (AttributeError).
        self.action_dim = action_dim

        # Distributional RL parameters
        self.atom_size = atom_size
        self.v_min = v_min
        self.v_max = v_max
        # Buffer so it follows .to(device); non-persistent so state_dict is
        # unchanged and existing checkpoints still load.
        self.register_buffer(
            'support', torch.linspace(v_min, v_max, atom_size), persistent=False
        )

        # Common layers
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, 128)

        # Dueling architecture
        self.value_stream = nn.Sequential(
            NoisyLinear(128, 64),
            nn.ReLU(),
            NoisyLinear(64, atom_size)
        )

        self.advantage_stream = nn.Sequential(
            NoisyLinear(128, 64),
            nn.ReLU(),
            NoisyLinear(64, action_dim * atom_size)
        )

    def forward(self, x):
        """Return per-action distributions of shape (batch, action_dim, atom_size)."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        value = self.value_stream(x).view(-1, 1, self.atom_size)
        advantage = self.advantage_stream(x).view(-1, self.action_dim, self.atom_size)

        # Dueling combination is applied per atom on the logits...
        q_atoms = value + advantage - advantage.mean(dim=1, keepdim=True)

        # ...then each action's logits are normalised over the atoms.
        return F.softmax(q_atoms, dim=2)

    def reset_noise(self):
        """Resample the noise of every NoisyLinear submodule."""
        for module in self.modules():
            if isinstance(module, NoisyLinear):
                module.reset_noise()

    def act(self, state):
        """Return the index of the action with the highest expected return.

        ``state`` is a single (unbatched) state; a batch dimension is added
        here. Noise is used if the module is in training mode.
        """
        with torch.no_grad():
            state = torch.as_tensor(
                state, dtype=torch.float32, device=self.support.device
            ).unsqueeze(0)
            q_dist = self.forward(state)
            q_values = (q_dist * self.support).sum(dim=2)
            return q_values.max(1)[1].item()

class RainbowAgent:
    """Rainbow DQN agent combining the following components:

    - double DQN
    - dueling distributional network
    - noisy exploration
    - prioritized n-step replay

    Args:
        state_dim: length of the flat state vector.
        action_dim: number of discrete actions.
        buffer_size: replay capacity.
        batch_size: number of transitions per update.
        gamma: per-step discount factor (raised to ``n_step`` for bootstrapping).
        lr: Adam learning rate.
        target_update: number of ``update`` calls between target-network syncs.
        num_atoms: number of support atoms of the return distribution.
        v_min: lower bound of the return support.
        v_max: upper bound of the return support.
        n_step: multi-step return length.
        alpha: priority exponent for replay sampling.
        beta: initial importance-sampling exponent.
    """

    def __init__(self, state_dim, action_dim, buffer_size=100000, batch_size=64,
                 gamma=0.99, lr=0.0001, target_update=1000, num_atoms=51,
                 v_min=-10, v_max=10, n_step=3, alpha=0.5, beta=0.4):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.batch_size = batch_size
        self.gamma = gamma
        self.target_update = target_update

        # Distributional RL parameters
        self.num_atoms = num_atoms
        self.v_min = v_min
        self.v_max = v_max
        self.support = torch.linspace(v_min, v_max, num_atoms)
        self.delta_z = (v_max - v_min) / (num_atoms - 1)

        # Networks
        self.online_net = RainbowDQN(state_dim, action_dim, num_atoms, v_min, v_max)
        self.target_net = RainbowDQN(state_dim, action_dim, num_atoms, v_min, v_max)
        self.target_net.load_state_dict(self.online_net.state_dict())
        # Eval mode: NoisyLinear layers of the target use only their mean weights.
        self.target_net.eval()

        # Optimizer
        self.optimizer = optim.Adam(self.online_net.parameters(), lr=lr)

        # Replay buffer
        self.memory = PrioritizedReplayBuffer(
            buffer_size, alpha=alpha, beta=beta, n_step=n_step, gamma=gamma
        )

        # Update counter
        self.update_count = 0

    def select_action(self, state, evaluate=False):
        """Return the greedy action for ``state``.

        During training, exploration comes from the noisy layers, so no
        epsilon-greedy is needed. With ``evaluate=True`` the online network is
        temporarily switched to eval mode, so the noise-free mean weights are
        used; its previous mode is restored afterwards.
        """
        if not evaluate:
            return self.online_net.act(state)

        was_training = self.online_net.training
        self.online_net.eval()
        try:
            return self.online_net.act(state)
        finally:
            self.online_net.train(was_training)

    def store_transition(self, state, action, reward, next_state, done):
        """Store transition in replay buffer"""
        self.memory.push(state, action, reward, next_state, done)

    def _project_distribution(self, next_dist, rewards, dones, discount):
        """Project the Bellman-updated distribution onto ``self.support``.

        Args:
            next_dist: (B, num_atoms) probabilities of the bootstrapping action.
            rewards: (B, 1) (n-step) rewards.
            dones: (B, 1) episode-end flags as floats (1.0 means done).
            discount: factor applied to the support when bootstrapping.

        Returns:
            (B, num_atoms) tensor of target probabilities; each row keeps the
            total mass of the corresponding ``next_dist`` row.
        """
        batch_len = next_dist.shape[0]

        tz = rewards + (1 - dones) * discount * self.support
        tz = tz.clamp(self.v_min, self.v_max)

        b = (tz - self.v_min) / self.delta_z
        lower = b.floor().long()
        upper = b.ceil().long()

        # If b lands exactly on an atom, lower == upper and both split weights
        # below are 0, so that mass would vanish. Move one index apart so the
        # full mass lands on the exact atom (``lower == upper`` is re-evaluated
        # on purpose after the first adjustment).
        lower[(upper > 0) & (lower == upper)] -= 1
        upper[(lower < self.num_atoms - 1) & (lower == upper)] += 1

        offset = (
            torch.arange(batch_len, device=next_dist.device) * self.num_atoms
        ).unsqueeze(1).expand(batch_len, self.num_atoms)

        projected = torch.zeros_like(next_dist)
        projected.view(-1).index_add_(
            0, (lower + offset).view(-1), (next_dist * (upper.float() - b)).view(-1)
        )
        projected.view(-1).index_add_(
            0, (upper + offset).view(-1), (next_dist * (b - lower.float())).view(-1)
        )
        return projected

    def update(self):
        """Run one gradient step on a prioritized batch.

        Returns:
            The importance-weighted cross-entropy loss as a float, or ``0.0``
            when the buffer holds fewer than ``batch_size`` transitions.
        """
        if len(self.memory) < self.batch_size:
            return 0.0

        # Sample from replay buffer
        transitions, indices, weights = self.memory.sample(self.batch_size)
        batch = Transition(*zip(*transitions))

        # Stack into single ndarrays first; building a tensor from a tuple of
        # ndarrays is very slow.
        state_batch = torch.as_tensor(np.asarray(batch.state), dtype=torch.float32)
        action_batch = torch.as_tensor(np.asarray(batch.action), dtype=torch.long).unsqueeze(1)
        reward_batch = torch.as_tensor(np.asarray(batch.reward, dtype=np.float32)).unsqueeze(1)
        next_state_batch = torch.as_tensor(np.asarray(batch.next_state), dtype=torch.float32)
        done_batch = torch.as_tensor(np.asarray(batch.done, dtype=np.float32)).unsqueeze(1)

        # Distribution of the taken action: (B, num_atoms)
        current_q_dist = self.online_net(state_batch)
        current_q_dist = current_q_dist.gather(
            1, action_batch.unsqueeze(-1).expand(-1, -1, self.num_atoms)
        ).squeeze(1)

        with torch.no_grad():
            # Double DQN: the online net picks next actions...
            next_q = (self.online_net(next_state_batch) * self.support).sum(dim=2)
            next_actions = next_q.argmax(dim=1, keepdim=True)

            # ...and the target net evaluates them.
            next_q_dist = self.target_net(next_state_batch)
            next_q_dist = next_q_dist.gather(
                1, next_actions.unsqueeze(-1).expand(-1, -1, self.num_atoms)
            ).squeeze(1)

            target_q_dist = self._project_distribution(
                next_q_dist, reward_batch, done_batch,
                self.gamma ** self.memory.n_step,
            )

        # Cross-entropy to the fixed target (same gradient as the KL divergence).
        cross_entropy = -(target_q_dist * torch.log(current_q_dist + 1e-8)).sum(dim=1)
        weighted_loss = (cross_entropy * weights).mean()

        # The per-sample loss doubles as the new replay priority.
        priorities = cross_entropy.detach().cpu().numpy()

        # Optimize
        self.optimizer.zero_grad()
        weighted_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.online_net.parameters(), 10.0)
        self.optimizer.step()

        self.memory.update_priorities(indices, priorities)

        # Periodic hard copy into the target network.
        self.update_count += 1
        if self.update_count % self.target_update == 0:
            self.target_net.load_state_dict(self.online_net.state_dict())

        # New noise sample for subsequent acting and learning.
        self.online_net.reset_noise()

        return weighted_loss.item()

def train_rainbow(env, episodes=1000, state_dim=11, gamma=0.99, batch_size=64, update_frequency=4):
    """Train a RainbowAgent on ``env`` and record the summed reward per episode.

    Observations are length-``state_dim`` vectors:
    - Index 0 holds the time step divided by ``env.steps``.
    - Index ``band + 1`` holds the last entry of ``env.energy_costs[band]``
      when that sequence is non-empty, and is left at 0 otherwise.
    - The first observation of every episode carries only the time step.

    Args:
        env: spectrum environment. It must provide the following:
            - attributes ``num_bands``, ``steps``, ``energy_costs`` and
              ``total_energy``;
            - methods ``soft_reset()`` and ``_generate_spectrum_state()``;
            - ``step(time_step, action)`` returning
              ``(next_state_tuple, reward, done)``.
        episodes: number of episodes to run.
        state_dim: observation length (expected to be ``num_bands + 1``).
        gamma: discount factor handed to the agent.
        batch_size: replay batch size.
        update_frequency: environment steps between network updates, counted
            across episodes.

    Returns:
        List of episode rewards.
    """
    # Two discrete choices (detect / skip) per band.
    n_actions = env.num_bands * 2

    agent = RainbowAgent(
        state_dim=state_dim,
        action_dim=n_actions,
        batch_size=batch_size,
        gamma=gamma,
        v_min=-100,  # assumed lower bound of returns for spectral sensing
        v_max=100,   # assumed upper bound of returns
        n_step=3,
    )

    def observe(t, with_energy):
        """Build the observation vector for time step ``t``."""
        obs = np.zeros(state_dim)
        obs[0] = t / env.steps
        if with_energy:
            for band in range(env.num_bands):
                history = env.energy_costs[band]
                if history:
                    obs[band + 1] = history[-1]
        return obs

    rewards_per_episode = []
    total_steps = 0

    for episode in range(episodes):
        env.soft_reset()
        env._generate_spectrum_state()

        t = 0
        state = observe(t, with_energy=False)
        episode_reward = 0
        finished = False

        while not finished:
            action = agent.select_action(state)
            next_state_tuple, reward, finished = env.step(t, action)

            next_t = next_state_tuple[0]
            next_state = observe(next_t, with_energy=True)

            agent.store_transition(state, action, reward, next_state, finished)

            total_steps += 1
            if total_steps % update_frequency == 0:
                agent.update()

            episode_reward += reward
            state, t = next_state, next_t

        rewards_per_episode.append(episode_reward)

        if episode % 10 == 0:
            print(f"Episode {episode}, Reward: {episode_reward}, Energy Used: {env.total_energy}")

    return rewards_per_episode