In [None]:
#IMPORTING LIBRARIES
import numpy as np
import math
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from collections import defaultdict
import os
import pickle
from datetime import datetime




# 0) DEVICE
# Torch device shared by every network and tensor created below: the GPU when
# CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# 1) ENVIRONMENT
class DroneDispatchEnv:
    """Discrete-time environment for dispatching and charging a fleet of drones.

    Per-drone action encoding (``action_dim == 3 + H``):
        0      -> idle
        1      -> slow charge for one step (costs ``eta``)
        2      -> fast charge for one step (costs ``psi``)
        3 + h  -> fly to hub ``h`` and back, serving its demand

    Per-drone ``mode`` values: 0 idle, 1 slow charging, 2 fast charging,
    3 in flight. Actions of a drone that is in flight are ignored by ``step``.

    Args:
        H: number of hubs.
        M: number of drones.
        N: number of timesteps in an episode (``step`` reports done at t >= N).
        L: duration of one timestep, in the time unit of the charging curves
            and of ``distances / speeds``.
        zeta: state-of-charge (SOC) discretisation step, in percent.
        epsilon_min: SOC reserve added to the energy needed before a dispatch
            is considered feasible.
        breakpoints_slow, breakpoints_fast: iterables of ``(time, soc)`` pairs
            describing the piecewise-linear charging curves.
        alpha: energy use per unit payload per unit flight time.
        beta: payload-independent energy use per unit flight time.
        cap_i: per-drone payload capacity (array indexed by drone).
        lambda_h_t: Poisson demand rates, indexed ``[hub, timestep]``.
        sigma_h: reward per unit of demand served at each hub.
        distances: distance to each hub.
        speeds: speed of each drone.
        eta, psi: cost of one slow / fast charging step.
        sell_price: stored on the instance; not used by this class.
    """

    def __init__(self, H, M, N, L,
                 zeta, epsilon_min,
                 breakpoints_slow, breakpoints_fast,
                 alpha, beta, cap_i,
                 lambda_h_t, sigma_h,
                 distances, speeds,
                 eta, psi, sell_price):
        self.H, self.M, self.N, self.L = H, M, N, L
        self.zeta, self.epsilon_min = zeta, epsilon_min
        self.alpha, self.beta = alpha, beta
        # Copy array inputs so later mutation by the caller cannot leak in.
        self.cap_i = cap_i.copy()
        self.lambda_h_t = lambda_h_t.copy()
        self.sigma_h = sigma_h.copy()
        self.distances = distances.copy()
        self.speeds = speeds.copy()
        self.eta, self.psi = eta, psi
        self.sell_price = sell_price

        # Normalization constants (computed for callers; not read by this class).
        self.max_demand = self.lambda_h_t.max()
        self.reward_scale = float(self.M * self.sigma_h.max() * self.cap_i.max())

        self.bp_slow = sorted(breakpoints_slow, key=lambda x: x[0])
        self.bp_fast = sorted(breakpoints_fast, key=lambda x: x[0])
        self._build_piecewise()
        self.action_dim = 3 + H
        self.reset()

    def _build_piecewise(self):
        """Split the sorted breakpoints into time/SOC arrays and segment slopes."""
        ts, ys = zip(*self.bp_slow)
        self.t_slow_bp = np.array(ts)
        self.s_slow_bp = np.array(ys)
        self.sl_slow = ((self.s_slow_bp[1:] - self.s_slow_bp[:-1])
                        / (self.t_slow_bp[1:] - self.t_slow_bp[:-1]))
        tf, yf = zip(*self.bp_fast)
        self.t_fast_bp = np.array(tf)
        self.s_fast_bp = np.array(yf)
        self.sl_fast = ((self.s_fast_bp[1:] - self.s_fast_bp[:-1])
                        / (self.t_fast_bp[1:] - self.t_fast_bp[:-1]))

    @staticmethod
    def _time_at_soc(soc, t_bp, s_bp, slopes):
        """Invert a charging curve: time on the curve at which SOC equals *soc*.

        Inputs below the first breakpoint clamp to the curve start and inputs
        above the last breakpoint clamp to the curve end.
        """
        for i, slope in enumerate(slopes):
            if s_bp[i] <= soc <= s_bp[i + 1]:
                if slope == 0:
                    # Flat segment: avoid 0/0 and use the start of the plateau.
                    return t_bp[i]
                return t_bp[i] + (soc - s_bp[i]) / slope
        if soc < s_bp[0]:
            # Below the curve: charging starts at the beginning of the curve,
            # not at its end (which would teleport the drone to full charge).
            return t_bp[0]
        return t_bp[-1]

    @staticmethod
    def _soc_at_time(t, t_bp, s_bp, slopes):
        """Evaluate a charging curve at time *t*, clamped to its end points."""
        for i, slope in enumerate(slopes):
            if t_bp[i] <= t <= t_bp[i + 1]:
                return s_bp[i] + slope * (t - t_bp[i])
        if t < t_bp[0]:
            return s_bp[0]
        return s_bp[-1]

    def t_slow_from_soc(self, soc):
        """Time on the slow-charging curve corresponding to *soc*."""
        return self._time_at_soc(soc, self.t_slow_bp, self.s_slow_bp, self.sl_slow)

    def t_fast_from_soc(self, soc):
        """Time on the fast-charging curve corresponding to *soc*."""
        return self._time_at_soc(soc, self.t_fast_bp, self.s_fast_bp, self.sl_fast)

    def soc_slow(self, t):
        """SOC reached on the slow-charging curve at time *t*."""
        return self._soc_at_time(t, self.t_slow_bp, self.s_slow_bp, self.sl_slow)

    def soc_fast(self, t):
        """SOC reached on the fast-charging curve at time *t*."""
        return self._soc_at_time(t, self.t_fast_bp, self.s_fast_bp, self.sl_fast)

    def reset(self):
        """Start a new episode: full batteries, all drones idle, fresh demand.

        Returns:
            The normalized state vector (see ``_get_state``).
        """
        self.t = 0
        self.soc = np.full(self.M, 100.0)
        self.mode = np.zeros(self.M, dtype=int)
        self.remain = np.zeros(self.M, dtype=int)     # flight steps left
        self.target = -np.ones(self.M, dtype=int)     # hub being served, -1 if none
        self.demand = np.random.poisson(self.lambda_h_t[:, 0])
        self.dispatch_demand = np.zeros(self.M)       # payload carried per drone
        return self._get_state()

    def _get_state(self):
        """Return the flat state ``[t, soc(M), mode(M), remain(M), demand]``.

        Time, SOC, mode and remaining steps are scaled by N, 100, 3 and N.
        """
        t_norm = self.t / self.N
        soc_norm = (np.floor(self.soc / self.zeta) * self.zeta) / 100.0
        mode_norm = self.mode.astype(float) / 3.0
        rem_norm = self.remain.astype(float) / self.N
        # NOTE(review): demand is divided by a hard-coded 4, so sampled demand
        # above 4 yields values > 1; self.max_demand is computed but not used
        # here. Changing this alters observations, so it is left as-is.
        dem_norm = self.demand.astype(float) / 4
        return np.concatenate([[t_norm], soc_norm, mode_norm, rem_norm, dem_norm])

    def _get_raw_state(self):
        """Return the unnormalized state as a dict of plain lists, for logging."""
        return {
            'timestep': self.t,
            # Discretized to the zeta grid but not normalized.
            'soc': (np.floor(self.soc / self.zeta) * self.zeta).tolist(),
            'mode': self.mode.copy().tolist(),
            'remain': self.remain.copy().tolist(),
            'demand': self.demand.copy().tolist()
        }

    def _calculate_terminal_demand_potential_with_details(self, socs):
        """Estimate the value of the fleet's remaining charge at episode end.

        One demand per hub is sampled from a Poisson distribution whose rate is
        that hub's mean rate over time. Drones are then processed in index
        order; each takes the feasible hub with the highest reward and depletes
        that hub's remaining demand before the next drone chooses.

        Args:
            socs: per-drone SOC values used for the feasibility check.

        Returns:
            ``(total_potential, terminal_info)`` where ``terminal_info`` holds
            the sampled demands, per-drone assignment details, the demand left
            over and the per-hub average rates.
        """
        lambda_averages = [np.mean(self.lambda_h_t[hub, :]) for hub in range(self.H)]
        terminal_demands = {}
        for hub in range(self.H):
            terminal_demands[hub] = np.random.poisson(lambda_averages[hub])

        remaining_demand = terminal_demands.copy()
        total_potential = 0
        assignment_details = []

        # Sequential assignment from drone 0 to M-1.
        for agent_id in range(self.M):
            agent_max_potential = 0
            best_hub = -1
            best_serviceable = 0

            for hub in range(self.H):
                if remaining_demand[hub] <= 0:
                    continue
                serviceable = min(remaining_demand[hub], self.cap_i[agent_id])
                if not self._can_dispatch_terminal(
                        agent_id, hub, socs[agent_id], remaining_demand[hub]):
                    continue
                potential_reward = serviceable * self.sigma_h[hub]
                if potential_reward > agent_max_potential:
                    agent_max_potential = potential_reward
                    best_hub = hub
                    best_serviceable = serviceable

            assignment_details.append({
                'agent_id': agent_id,
                'soc': float(socs[agent_id]),
                'best_hub': int(best_hub) if best_hub >= 0 else -1,
                'reward': float(agent_max_potential),
                'serviceable': int(best_serviceable) if best_hub >= 0 else 0
            })

            # Deplete the chosen hub so later drones see only what is left.
            if best_hub >= 0:
                remaining_demand[best_hub] = max(
                    0, remaining_demand[best_hub] - best_serviceable)

            total_potential += agent_max_potential

        terminal_info = {
            'sampled_demands': terminal_demands,
            'assignment_details': assignment_details,
            'final_remaining_demands': remaining_demand,
            'lambda_averages': lambda_averages
        }
        return total_potential, terminal_info

    def _can_dispatch_terminal(self, agent_id, hub, current_soc, available_demand):
        """Return True if *current_soc* covers a round trip to *hub*.

        The payload is ``min(available_demand, capacity)``; the required SOC is
        the round-trip energy plus ``epsilon_min``, rounded up to the zeta grid.
        """
        actual_payload = min(available_demand, self.cap_i[agent_id])
        flight_time = self.distances[hub] / self.speeds[agent_id]
        # Loaded outbound leg (alpha*payload + beta) plus empty return leg (beta).
        energy_needed = (self.alpha * actual_payload + 2 * self.beta) * flight_time
        threshold = self.zeta * math.ceil((energy_needed + self.epsilon_min) / self.zeta)
        return current_soc >= threshold

    def step(self, actions):
        """Apply one joint action and advance time by one step.

        Args:
            actions: one action index per drone (see the class docstring).

        Returns:
            ``(state, reward, done)``: the normalized next state, the summed
            reward of this step (service revenue minus charging costs) and
            whether ``t`` has reached ``N``.
        """
        reward = 0.0
        rem_dem = self.demand.copy()

        # Only drones that are in flight keep the payload they left with.
        self.dispatch_demand[self.mode != 3] = 0

        # Fleet-wide action effects; drones already in flight are skipped.
        for i, a in enumerate(actions):
            if self.mode[i] not in (0, 1, 2):
                continue
            if a == 0:
                self.mode[i], self.remain[i] = 0, 0
            elif a == 1:
                t0 = self.t_slow_from_soc(self.soc[i])
                t1 = min(t0 + self.L, self.t_slow_bp[-1])
                self.soc[i] = min(100, self.soc_slow(t1))
                self.mode[i], self.remain[i] = 1, 0
                reward -= self.eta
            elif a == 2:
                t0 = self.t_fast_from_soc(self.soc[i])
                t1 = min(t0 + self.L, self.t_fast_bp[-1])
                self.soc[i] = min(100, self.soc_fast(t1))
                self.mode[i], self.remain[i] = 2, 0
                reward -= self.psi
            else:
                h = a - 3
                # An infeasible dispatch request is silently ignored.
                if self._can_dispatch(i, h):
                    pay = min(rem_dem[h], self.cap_i[i])
                    rem_dem[h] -= pay
                    reward += self.sigma_h[h] * pay
                    self.dispatch_demand[i] = pay
                    T1 = self.distances[h] / self.speeds[i]
                    k1 = math.ceil(T1 / self.L)
                    # k1 steps out plus k1 steps back.
                    self.mode[i], self.remain[i], self.target[i] = 3, 2 * k1, h

        # In-flight discharge (includes drones dispatched in this very step).
        for i in range(self.M):
            if self.mode[i] != 3:
                continue
            h = int(self.target[i])
            T1 = self.distances[h] / self.speeds[i]
            k1 = math.ceil(T1 / self.L)
            k = 2 * k1
            q = k - self.remain[i] + 1  # 1-based index of the current flight step
            if q <= k1:
                # Outbound leg: carrying the payload.
                c = (self.alpha * self.dispatch_demand[i] + self.beta) * self.L
            else:
                # Return leg: empty.
                c = self.beta * self.L
            new_soc = max(0, self.soc[i] - c)
            self.soc[i] = self.zeta * math.floor(new_soc / self.zeta)
            self.remain[i] -= 1
            if self.remain[i] <= 0:
                self.mode[i], self.remain[i], self.target[i] = 0, 0, -1
                self.dispatch_demand[i] = 0

        self.t += 1
        done = (self.t >= self.N)

        if not done:
            if self.t < self.lambda_h_t.shape[1]:
                self.demand = np.random.poisson(self.lambda_h_t[:, self.t])
            else:
                self.demand = np.zeros(self.H, dtype=int)

        return self._get_state(), reward, done

    def _can_dispatch(self, i, h):
        """Return True if drone *i* has enough SOC for a round trip to hub *h*.

        The payload assumed is the hub's full current demand (or the payload
        already on board when the drone is in flight), capped by capacity.
        """
        carried = self.dispatch_demand[i] if self.mode[i] == 3 else self.demand[h]
        pay = min(carried, self.cap_i[i])
        flight = self.distances[h] / self.speeds[i]
        need = (self.alpha * pay + 2 * self.beta) * flight
        thresh = self.zeta * math.ceil((need + self.epsilon_min) / self.zeta)
        return self.soc[i] >= thresh


# 2) FEASIBLE ACTIONS
def get_feasible_actions(env, i, rem):
    """List the action indices drone *i* may take in the current step.

    Idling (0) is always allowed. A drone that is in flight (mode 3) can do
    nothing else. Otherwise charging (1, 2) is allowed while SOC is below 100,
    and dispatch to hub h (3 + h) is allowed when that hub still has unclaimed
    demand in *rem* and ``env._can_dispatch`` accepts the trip.
    """
    allowed = [0]
    if env.mode[i] == 3:
        return allowed

    if env.soc[i] < 100:
        allowed.extend((1, 2))

    allowed.extend(
        3 + hub
        for hub in range(env.H)
        if rem[hub] > 0 and env._can_dispatch(i, hub)
    )
    return allowed


# 3) NETWORKS FOR COMA
class Actor(nn.Module):
    """Per-agent policy network: two ReLU hidden layers producing action logits."""

    def __init__(self, obs_dim, action_dim, hidden_dim=128):
        super().__init__()
        # Layers are created in input-to-output order and wrapped in a
        # Sequential named ``net`` so parameter names stay ``net.<idx>.*``.
        stack = [
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Map observations to unnormalized action logits (no softmax applied)."""
        logits = self.net(x)
        return logits


class COMACritic(nn.Module):
    """
    Centralized COMA critic.

    Given the global state, the other agents' one-hot actions and an agent
    index, it outputs one Q-value per action of that agent: Q(s, u^-i, i).
    """

    def __init__(self, state_dim, M, action_dim, hidden_dim=256):
        super().__init__()
        self.M = M
        self.action_dim = action_dim
        self.state_dim = state_dim

        # state + one-hot actions of the other M-1 agents + scalar agent id
        in_features = state_dim + (M - 1) * action_dim + 1

        stack = [
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),  # one Q-value per action
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, state, other_actions, agent_id):
        """
        Args:
            state: (batch_size, state_dim)
            other_actions: (batch_size, (M-1)*action_dim), one-hot encoded
            agent_id: (batch_size, 1) index of the agent being evaluated
        Returns:
            (batch_size, action_dim) Q-values, one per action of that agent
        """
        # The agent index is fed as a single feature scaled by the fleet size.
        scaled_id = agent_id.float() / self.M
        features = torch.cat((state, other_actions, scaled_id), dim=1)
        return self.net(features)


# 4) RUN EPISODE WITH DETAILED LOGGING
def run_episode(env, actors):
    """Roll out one episode by sampling from the agents' masked policies.

    Agents act until ``env.t`` reaches ``env.N - 1``; at that point no action
    is taken and a terminal reward from
    ``env._calculate_terminal_demand_potential_with_details`` is appended
    instead.

    Args:
        env: the dispatch environment (reset at the start of the call).
        actors: one policy network per drone, in drone order.

    Returns:
        ``(traj, episode_log)``. ``traj`` holds the per-step training data
        (normalized states, local observations, joint actions, rewards,
        feasibility masks and a transition log). ``episode_log`` is a list of
        raw-state dicts for inspection, ending with the terminal entry.
    """
    state = env.reset()
    traj = {
        'states': [], 'local_obs': [], 'actions': [], 'rewards': [],
        'feasible_masks': [], 'log': []
    }
    episode_log = []
    n_agents, n_hubs = env.M, env.H

    while env.t < env.N - 1:
        traj['states'].append(state.copy())
        raw_before = env._get_raw_state()

        # Local view of agent k: time, own SOC, own mode, own remaining flight
        # steps, followed by the demand vector shared by every agent.
        shared_demand = state[-n_hubs:]
        per_agent_obs = [
            np.concatenate((
                [state[0],
                 state[1 + k],
                 state[1 + n_agents + k],
                 state[1 + 2 * n_agents + k]],
                shared_demand,
            ))
            for k in range(n_agents)
        ]
        traj['local_obs'].append(per_agent_obs)

        # Agents choose in index order; each dispatch reserves demand so later
        # agents only see what is still unclaimed.
        unclaimed = env.demand.copy()
        joint_action = []
        step_masks = []

        for agent, policy in enumerate(actors):
            allowed = get_feasible_actions(env, agent, unclaimed)

            mask = torch.zeros(env.action_dim, device=device)
            mask[allowed] = 1.0
            step_masks.append(mask)

            obs_tensor = torch.FloatTensor(per_agent_obs[agent]).unsqueeze(0).to(device)
            with torch.no_grad():
                logits = policy(obs_tensor).squeeze(0)

            # Infeasible actions get a huge negative logit, i.e. ~zero probability.
            blocked = torch.full_like(logits, -1e9)
            probs = torch.softmax(torch.where(mask > 0, logits, blocked), dim=-1)

            choice = int(torch.multinomial(probs, 1).item())
            joint_action.append(choice)

            if choice >= 3:
                hub = choice - 3
                unclaimed[hub] = max(0, unclaimed[hub] - env.cap_i[agent])

        traj['feasible_masks'].append(step_masks)
        traj['actions'].append(joint_action.copy())

        next_state, reward, done = env.step(joint_action)
        traj['rewards'].append(reward)

        episode_log.append({
            'step': env.t - 1,  # env.t was already advanced by step()
            'state': raw_before,
            'action': joint_action.copy(),
            'next_state': env._get_raw_state(),
            'reward': float(reward),
            'terminal': False
        })

        traj['log'].append({
            'state': state.copy(), 'action': joint_action.copy(),
            'reward': reward, 'next_state': next_state.copy(), 'done': done
        })

        if done:
            break

        state = next_state
    else:
        # Reached only when the loop condition fails (not on ``break``): the
        # final timestep takes no action and scores the remaining charge.
        discretized_soc = np.floor(env.soc / env.zeta) * env.zeta
        terminal_reward, terminal_info = \
            env._calculate_terminal_demand_potential_with_details(discretized_soc)

        episode_log.append({
            'step': env.t,
            'state': env._get_raw_state(),
            'action': None,  # No action in terminal state
            'next_state': None,  # No next state
            'reward': float(terminal_reward),
            'terminal': True,
            'terminal_info': terminal_info
        })

        traj['rewards'].append(terminal_reward)

    return traj, episode_log


# 5) ON-POLICY COMA TRAINING WITH DETAILED LOGGING
def _discounted_returns(rewards, gamma):
    """Return the discounted reward-to-go G_t for every entry of *rewards*."""
    out = []
    running = 0.0
    for r in reversed(rewards):
        running = float(r) + gamma * running
        out.append(running)
    out.reverse()  # built back-to-front in O(T) instead of insert(0, ...)
    return out


def _other_actions_one_hot(actions, agent, action_dim):
    """One-hot encode all agents' actions except *agent*'s.

    Args:
        actions: (T, M) integer tensor of joint actions.
        agent: index of the agent whose own action is left out.
        action_dim: number of actions per agent.

    Returns:
        (T, (M-1)*action_dim) float tensor, other agents in index order.
    """
    others = torch.cat([actions[:, :agent], actions[:, agent + 1:]], dim=1)
    encoded = F.one_hot(others, num_classes=action_dim).float()
    return encoded.reshape(actions.shape[0], others.shape[1] * action_dim)


def train_coma_standard(env, episodes=8000, gamma=0.95, lr_actor=3e-4, lr_critic=1e-3, entropy_coef=0.01):
    """Train one actor per drone and a shared COMA critic, on-policy.

    Each episode is rolled out with ``run_episode``. The critic is regressed
    onto Monte Carlo discounted returns, then every actor takes one policy
    gradient step using the counterfactual advantage
    ``Q(s, u) - sum_a pi(a) Q(s, (u^-i, a))`` plus an entropy bonus. Timesteps
    where an agent had a single feasible action are excluded from its loss.

    Args:
        env: the dispatch environment.
        episodes: number of training episodes.
        gamma: discount factor.
        lr_actor, lr_critic: Adam learning rates.
        entropy_coef: weight of the entropy bonus in the actor loss.

    Returns:
        ``(actors, critic, all_returns, critic_losses, actor_losses,
        detailed_episode_logs)``. ``all_returns`` holds the discounted return
        from the first step of each episode. ``detailed_episode_logs`` keeps
        the log of every episode, so its size grows with ``episodes``.
    """
    state_dim = len(env.reset())
    M, H = env.M, env.H
    action_dim = 3 + H
    obs_dim = 4 + H

    # Networks
    actors = [Actor(obs_dim, action_dim, hidden_dim=128).to(device) for _ in range(M)]
    critic = COMACritic(state_dim, M, action_dim, hidden_dim=256).to(device)

    # Optimizers
    actor_optimizers = [optim.Adam(actor.parameters(), lr=lr_actor, weight_decay=1e-5)
                        for actor in actors]
    critic_optimizer = optim.Adam(critic.parameters(), lr=lr_critic, weight_decay=1e-4)

    all_returns, critic_losses, actor_losses = [], [], []
    detailed_episode_logs = {}

    for episode in range(episodes):
        traj, episode_log = run_episode(env, actors)
        detailed_episode_logs[episode] = episode_log

        T = len(traj['actions'])  # number of timesteps in which actions were taken

        # Monte Carlo returns over every recorded reward (incl. the terminal one).
        returns = _discounted_returns(traj['rewards'], gamma)
        total_reward = returns[0] if returns else 0
        all_returns.append(total_reward)

        episode_critic_loss = 0.0
        episode_actor_loss = 0.0

        if T > 0:
            states = torch.from_numpy(np.stack(traj['states'], 0)).float().to(device)  # (T, state_dim)
            actions = torch.LongTensor(traj['actions']).to(device)                      # (T, M)
            # Only the first T returns belong to action steps.
            returns_tensor = torch.FloatTensor(returns[:T]).to(device)                  # (T,)

            # Critic inputs per agent, shared by the critic and actor phases.
            others_one_hot = [_other_actions_one_hot(actions, i, action_dim) for i in range(M)]
            agent_ids = [torch.full((T, 1), i, device=device) for i in range(M)]

            # CRITIC UPDATE: regress Q of the taken action onto the return.
            critic_loss = 0.0
            for i in range(M):
                q_values_all = critic(states, others_one_hot[i], agent_ids[i])  # (T, action_dim)
                q_values_taken = q_values_all.gather(1, actions[:, i].unsqueeze(1)).squeeze(1)
                critic_loss = critic_loss + F.mse_loss(q_values_taken, returns_tensor)

            critic_loss = critic_loss / M
            episode_critic_loss = critic_loss.item()

            critic_optimizer.zero_grad()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(critic.parameters(), max_norm=10.0)
            critic_optimizer.step()

            # ACTOR UPDATES (using the freshly updated critic)
            total_actor_loss = 0.0
            for i in range(M):
                masks = torch.stack([traj['feasible_masks'][t][i] for t in range(T)]) > 0  # (T, action_dim)
                # A step with one feasible action carries no learning signal.
                has_choice = masks.sum(dim=1) > 1
                if not has_choice.any():
                    continue

                taken = actions[:, i].unsqueeze(1)

                with torch.no_grad():
                    q_values_all = critic(states, others_one_hot[i], agent_ids[i])  # (T, action_dim)

                obs_i = torch.from_numpy(
                    np.stack([traj['local_obs'][t][i] for t in range(T)])
                ).float().to(device)
                logits_i = actors[i](obs_i)
                masked_logits = torch.where(masks, logits_i, torch.full_like(logits_i, -1e9))
                policy_probs = F.softmax(masked_logits, dim=-1)

                # Counterfactual baseline: expected Q under the agent's own policy.
                baseline = (policy_probs.detach() * q_values_all).sum(dim=1)
                advantage = q_values_all.gather(1, taken).squeeze(1) - baseline

                log_prob = torch.log(policy_probs.gather(1, taken).squeeze(1) + 1e-8)
                entropy = -(policy_probs * torch.log(policy_probs + 1e-8)).sum(dim=1)

                per_step_loss = -advantage * log_prob - entropy_coef * entropy
                agent_actor_loss = per_step_loss[has_choice].mean()
                # Accumulate a plain float so no autograd graph is kept alive.
                total_actor_loss += agent_actor_loss.item()

                actor_optimizers[i].zero_grad()
                agent_actor_loss.backward()
                torch.nn.utils.clip_grad_norm_(actors[i].parameters(), max_norm=10.0)
                actor_optimizers[i].step()

            episode_actor_loss = total_actor_loss / M

        critic_losses.append(episode_critic_loss)
        actor_losses.append(episode_actor_loss)

        print(f"Episode {episode+1:4d}/{episodes} | Total Reward: {total_reward:8.3f} | "
              f"Critic Loss: {episode_critic_loss:8.4f} | Actor Loss: {episode_actor_loss:8.4f}")

    return actors, critic, all_returns, critic_losses, actor_losses, detailed_episode_logs



def save_trained_models(actors, critic, returns=None, save_dir="saved_models"):
    """
    Write the COMA networks to a new timestamped folder under *save_dir*.

    Each actor's state dict goes to ``actor_<i>.pth`` and the critic's to
    ``critic.pth``. When *returns* is given it is pickled as
    ``{'returns': returns}`` into ``training_results.pkl``.

    Returns the path of the folder that was written.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_path = os.path.join(save_dir, f"coma_model_{stamp}")
    os.makedirs(model_path, exist_ok=True)

    # (file name, network) pairs: actors in order, then the critic.
    checkpoints = [(f"actor_{idx}.pth", net) for idx, net in enumerate(actors)]
    checkpoints.append(("critic.pth", critic))
    for filename, net in checkpoints:
        torch.save(net.state_dict(), os.path.join(model_path, filename))

    if returns is not None:
        results_file = os.path.join(model_path, "training_results.pkl")
        with open(results_file, 'wb') as f:
            pickle.dump({'returns': returns}, f)

    print(f" Models saved to: {model_path}")
    return model_path

def load_trained_models(model_path, env):
    """
    Rebuild the actor and critic networks for *env* and load saved weights.

    Expects *model_path* to contain ``actor_<i>.pth`` for every drone and
    ``critic.pth``. The environment is reset to measure the state size. All
    networks are returned in eval mode as ``(actors, critic)``.
    """
    # Network sizes are derived from the environment, as during training.
    state_dim = len(env.reset())
    M, H = env.M, env.H
    action_dim, obs_dim = 3 + H, 4 + H

    actors = [Actor(obs_dim, action_dim, hidden_dim=128).to(device) for _ in range(M)]
    critic = COMACritic(state_dim, M, action_dim, hidden_dim=256).to(device)

    def _restore(network, filename):
        # Load the checkpoint onto the active device and switch to eval mode.
        weights = torch.load(os.path.join(model_path, filename), map_location=device)
        network.load_state_dict(weights)
        network.eval()

    for idx, actor in enumerate(actors):
        _restore(actor, f"actor_{idx}.pth")
    _restore(critic, "critic.pth")

    print(f" Models loaded from: {model_path}")
    return actors, critic