In [None]:
# Chạy ô này đầu tiên
!pip install "protobuf<=3.20.1" --force-reinstall

In [None]:
import json
import os

# 1. Drone parameters, keyed by drone type id ("1".."4").
#    Types "1"/"2" share the "high" speed profile, "3"/"4" the "low" one; inside
#    each pair the entries differ only in batteryPower and the range label.
# NOTE(review): type "1" is labelled range "low" but has a far larger batteryPower
# (4575030 J) than the "high"-range type "2" (904033 J) — confirm this is not a typo.
drone_config = {
    "1": {
        "takeoffSpeed [m/s]": 15.6464,
        "cruiseSpeed [m/s]": 31.2928,
        "landingSpeed [m/s]": 7.8232,
        "cruiseAlt [m]": 50,
        "capacity [kg]": 2.27,
        "batteryPower [Joule]": 4575030,
        "speed_type": "high",
        "range": "low",
        "beta(w/kg)": 24.2,
        "gama(w)": 1392,
    },
    "2": {
        "takeoffSpeed [m/s]": 15.6464,
        "cruiseSpeed [m/s]": 31.2928,
        "landingSpeed [m/s]": 7.8232,
        "cruiseAlt [m]": 50,
        "capacity [kg]": 2.27,
        "batteryPower [Joule]": 904033,
        "speed_type": "high",
        "range": "high",
        "beta(w/kg)": 24.2,
        "gama(w)": 1392,
    },
    "3": {
        "takeoffSpeed [m/s]": 7.8232,
        "cruiseSpeed [m/s]": 15.6464,
        "landingSpeed [m/s]": 3.9116,
        "cruiseAlt [m]": 50,
        "capacity [kg]": 2.27,
        "batteryPower [Joule]": 291094,
        "speed_type": "low",
        "range": "low",
        "beta(w/kg)": 210.8,
        "gama(w)": 181.2,
    },
    "4": {
        "takeoffSpeed [m/s]": 7.8232,
        "cruiseSpeed [m/s]": 15.6464,
        "landingSpeed [m/s]": 3.9116,
        "cruiseAlt [m]": 50,
        "capacity [kg]": 2.27,
        "batteryPower [Joule]": 562990,
        "speed_type": "low",
        "range": "high",
        "beta(w/kg)": 210.8,
        "gama(w)": 181.2,
    },
}

# 2. Truck parameters: maximum speed plus a speed factor per "start-end" hour window.
truck_config = {
    "V_max (m/s)": 15.557,
    "T (hour)": {
        "0-1": 0.7, "1-2": 0.4, "2-3": 0.6, "3-4": 0.7,
        "4-5": 0.8, "5-6": 0.9, "6-7": 1.0, "7-8": 0.7,
        "8-9": 0.6, "9-10": 0.5, "10-11": 0.7, "11-12": 0.8,
    },
}

# 3-4. Write each config as indented JSON into the current working directory.
for config_path, config_data in (
    ("drone_linear_config.json", drone_config),
    ("Truck_config.json", truck_config),
):
    with open(config_path, 'w') as f:
        json.dump(config_data, f, indent=4)
        print(f"Đã tạo file: {config_path}")

# Show what is now in the working directory
print("\nDanh sách file trong thư mục làm việc:")
print(os.listdir("."))

In [None]:
import torch
import numpy as np
import random
from torch.utils.data import IterableDataset, DataLoader

class MOPVRPGenerator(IterableDataset):
    """Infinite stream of randomly generated MO-PVRP (trucks + drones) batches.

    Every item produced by iteration is already a complete batch, so wrap this
    dataset with ``DataLoader(..., batch_size=None)`` (see ``get_rl_dataloader``).

    Args:
        batch_size: number of instances per yielded batch.
        device: device on which all tensors are allocated.
        truck_prob: probability that a customer is flagged "truck only"
            (default 0.3, the previously hard-coded value).
    """

    def __init__(self, batch_size=32, device='cpu', truck_prob=0.3):
        super(MOPVRPGenerator, self).__init__()
        self.batch_size = batch_size
        self.device = device
        self.truck_prob = truck_prob

        # Scenario profiles: number of customers, trucks (staff), drones, and the
        # real map scale used to de-normalise coordinates.
        self.configs = [
            {'n': 6,   'staff': 1,  'drone': 1, 'scale': 5000.0},
            {'n': 10,  'staff': 2,  'drone': 1, 'scale': 8000.0},
            {'n': 20,  'staff': 2,  'drone': 2, 'scale': 10000.0},
            {'n': 50,  'staff': 4,  'drone': 2, 'scale': 20000.0},
            {'n': 100, 'staff': 4,  'drone': 4, 'scale': 35000.0},
            {'n': 200, 'staff': 10, 'drone': 4, 'scale': 40000.0}
        ]

        self.scenario = 5  # Default: the largest profile

    def set_scenario(self, scenario: int):
        """Select the profile (index into ``self.configs``) used for subsequent batches.

        Raises:
            ValueError: if ``scenario`` is not in ``range(len(self.configs))``.
                Previously an invalid index only failed later, inside iteration.
        """
        if not 0 <= scenario < len(self.configs):
            raise ValueError(
                f"scenario must be in [0, {len(self.configs) - 1}], got {scenario}"
            )
        self.scenario = scenario
        return

    def _generate_instance(self, cfg):
        """Generate one batch for the profile ``cfg``.

        Returns a 7-tuple (B = batch size, N = customers, T = trucks, D = drones):
            static          [B, 4, N+1]: x, y uniform in [0, 1); demand (0 at the
                            depot, uniform in [0.01, 0.1) for customers); truck-only
                            flag (0 at the depot). Node 0 is the depot.
            dynamic_trucks  [B, 2, T]: (loc, time), all zeros.
            dynamic_drones  [B, 4, D]: (loc, time, energy, payload); energy = 1.0 (full).
            mask_customers  [B, N+1]: 1 = still to be served; the depot is 0.
            mask_vehicles   [B, T+D]: all ones.
            scale_tensor    [B, 1]: the profile's map scale, used to convert
                            normalised coordinates to real distances.
            weights         [B, 2]: objective weights, all 0.5.
        """
        batch_size = self.batch_size
        num_customers = cfg['n']
        num_nodes = num_customers + 1  # + depot at index 0
        num_trucks = cfg['staff']
        num_drones = cfg['drone']
        map_scale = cfg['scale']
        device = self.device

        # --- 1. Static data (map) ---
        static = torch.zeros(batch_size, 4, num_nodes, device=device)
        # Coordinates (depot included) are kept normalised; the real scale is returned separately.
        static[:, 0:2, :] = torch.rand(batch_size, 2, num_nodes, device=device)
        # Customer demand in [0.01, 0.1); depot demand stays 0.
        static[:, 2, 1:] = torch.rand(batch_size, num_customers, device=device) * 0.09 + 0.01
        # Truck-only flag drawn with probability self.truck_prob; depot stays 0.
        static[:, 3, 1:] = (torch.rand(batch_size, num_customers, device=device) < self.truck_prob).float()

        # --- 2. Dynamic data ---
        dynamic_trucks = torch.zeros(batch_size, 2, num_trucks, device=device)
        dynamic_drones = torch.zeros(batch_size, 4, num_drones, device=device)
        dynamic_drones[:, 2, :] = 1.0  # full (normalised) battery

        # --- 3. Masks ---
        mask_customers = torch.ones(batch_size, num_nodes, device=device)
        mask_customers[:, 0] = 0  # the depot does not need to be served
        mask_vehicles = torch.ones(batch_size, num_trucks + num_drones, device=device)

        scale_tensor = torch.full((batch_size, 1), map_scale, device=device)
        weights = torch.full((batch_size, 2), 0.5, device=device)

        return static, dynamic_trucks, dynamic_drones, mask_customers, mask_vehicles, scale_tensor, weights

    def __iter__(self):
        """Yield batches forever for the currently selected scenario.

        ``self.scenario`` is re-read before every batch, so ``set_scenario``
        takes effect on the next batch drawn. NOTE: with a multi-worker
        DataLoader each worker holds its own copy of the dataset, so
        ``set_scenario`` calls made in the main process would not reach it.
        """
        while True:
            yield self._generate_instance(self.configs[self.scenario])

# Hàm tiện ích để tạo DataLoader chuẩn của PyTorch
def get_rl_dataloader(batch_size=32, device='cpu', scenario=5):
    """Wrap an infinite ``MOPVRPGenerator`` stream in a PyTorch DataLoader.

    The generator already yields whole batches, so automatic batching is turned
    off (``batch_size=None``) and the loader simply forwards each batch.

    Args:
        batch_size: instances per generated batch.
        device: device for the generated tensors.
        scenario: initial profile index passed to ``set_scenario``.
    """
    generator = MOPVRPGenerator(batch_size=batch_size, device=device)
    generator.set_scenario(scenario)
    loader = DataLoader(generator, batch_size=None, batch_sampler=None)
    return loader


if __name__ == "__main__":
    # Smoke test for the RL DataLoader: draw a few batches, switching the scenario
    # before each draw, and check the shapes against the selected profile and the
    # coordinate range. A failed check raises AssertionError, so the final
    # success message is only printed when every batch passed.

    # 1. Configuration (seeded so the scenario sequence and data are reproducible)
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    BATCH_SIZE = 4  # kept small so the log stays readable
    SEED = 0
    random.seed(SEED)
    torch.manual_seed(SEED)
    print(f"🚀 Bắt đầu kiểm tra DataLoader trên thiết bị: {DEVICE}")

    # 2. Build the loader and one persistent iterator over it
    dataloader = get_rl_dataloader(batch_size=BATCH_SIZE, device=DEVICE)
    data_iter = iter(dataloader)
    generator = dataloader.dataset

    # 3. Pull 5 batches; the generator re-reads the scenario before every batch
    for i in range(1, 6):
        print(f"\n{'='*40}")
        print(f"📡 LẤY BATCH THỨ {i}")
        scenario_number = random.randrange(len(generator.configs))
        generator.set_scenario(scenario_number)
        expected = generator.configs[scenario_number]

        static, dyn_trucks, dyn_drones, mask_cust, mask_veh, scale, weights = next(data_iter)

        # Sizes actually produced
        b_size, _, num_nodes = static.shape
        num_customers = num_nodes - 1
        num_trucks = dyn_trucks.shape[2]
        num_drones = dyn_drones.shape[2]

        print(f"🔹 Kịch bản (Scenario): {num_customers} Khách hàng")
        print(f"🔹 Đội xe: {num_trucks} Trucks + {num_drones} Drones")
        print(f"🔹 Phạm vi bản đồ thực (Scale): {scale[0].item():.0f} mét")

        print(f"\n🔍 Kiểm tra Shape Tensor:")
        print(f"   - Static Input:       {static.shape}  (Mong đợi: [{BATCH_SIZE}, 4, {num_nodes}])")
        print(f"   - Dynamic Trucks:     {dyn_trucks.shape}  (Mong đợi: [{BATCH_SIZE}, 2, {num_trucks}])")
        print(f"   - Dynamic Drones:     {dyn_drones.shape}  (Mong đợi: [{BATCH_SIZE}, 4, {num_drones}])")
        print(f"   - Mask Customers:     {mask_cust.shape}  (Mong đợi: [{BATCH_SIZE}, {num_nodes}])")
        print(f"   - Mask Vehicles:      {mask_veh.shape}  (Mong đợi: [{BATCH_SIZE}, {num_trucks + num_drones}])")

        # The batch must match the profile that was just selected
        assert b_size == BATCH_SIZE
        assert num_customers == expected['n']
        assert num_trucks == expected['staff']
        assert num_drones == expected['drone']
        assert scale[0].item() == expected['scale']
        assert tuple(dyn_trucks.shape) == (BATCH_SIZE, 2, num_trucks)
        assert tuple(dyn_drones.shape) == (BATCH_SIZE, 4, num_drones)
        assert tuple(mask_cust.shape) == (BATCH_SIZE, num_nodes)
        assert tuple(mask_veh.shape) == (BATCH_SIZE, num_trucks + num_drones)

        # Coordinates must be normalised to [0, 1]
        max_coord = static[:, 0:2, :].max().item()
        min_coord = static[:, 0:2, :].min().item()

        print(f"\n📊 Kiểm tra giá trị:")
        print(f"   - Tọa độ Max: {max_coord:.4f} (Phải <= 1.0)")
        print(f"   - Tọa độ Min: {min_coord:.4f} (Phải >= 0.0)")

        if max_coord <= 1.0 and min_coord >= 0.0:
            print("   ✅ Dữ liệu đã được Normalize tốt.")
        else:
            print("   ❌ Cảnh báo: Dữ liệu chưa được Normalize!")
        assert max_coord <= 1.0 and min_coord >= 0.0

    print(f"\n{'='*40}")
    print("✅ Kiểm tra hoàn tất. DataLoader hoạt động đúng thiết kế RL.")

In [None]:
import json
import numpy as np
import torch

class SystemConfig:
    """Truck and drone parameters loaded from the two JSON config files.

    Args:
        truck_config_path: truck JSON with keys ``"V_max (m/s)"`` and
            ``"T (hour)"`` (a mapping of ``"start-end"`` hour windows to speed factors).
        drone_config_path: drone JSON, a dict keyed by drone type id.
        drone_type: drone entry to use; converted to ``str`` before lookup.

    Raises:
        ValueError: if ``drone_type`` is not a key of the drone config.
    """

    def __init__(self, truck_config_path, drone_config_path, drone_type="1"):
        self.truck_config = self._load_json(truck_config_path)
        self.drone_config_full = self._load_json(drone_config_path)

        self.drone_type = str(drone_type)
        if self.drone_type not in self.drone_config_full:
            raise ValueError(f"Drone type {drone_type} not found")

        # Parameters of the selected drone type
        p = self.drone_config_full[self.drone_type]
        self.drone_params = p
        self.drone_max_energy = p['batteryPower [Joule]']
        self.drone_speed = p['cruiseSpeed [m/s]']
        self.drone_capacity_kg = p['capacity [kg]']

        # Vertical climb / descent times (altitude / vertical speed), computed once
        self.t_takeoff = p['cruiseAlt [m]'] / p['takeoffSpeed [m/s]']
        self.t_landing = p['cruiseAlt [m]'] / p['landingSpeed [m/s]']

        # Truck speed factors as (start_hour, end_hour, factor) tuples.
        # Hours not covered by any window use factor 1.0.
        self.truck_time_factors = []
        for key, factor in self.truck_config['T (hour)'].items():
            start, end = map(int, key.split('-'))
            self.truck_time_factors.append((start, end, factor))
        self.truck_v_max = self.truck_config['V_max (m/s)']

    def _load_json(self, path):
        """Read and parse a JSON file (JSON text is UTF-8)."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def get_truck_speed_batch(self, current_time_seconds):
        """Truck speed at the given time(s), following Eq. (11).

        ``speed = truck_v_max * factor(hour)`` with ``hour = (seconds / 3600) % 24``.
        ``factor`` comes from the window ``[start, end)`` containing ``hour``; if
        windows overlap, the last one listed wins. It is 1.0 outside every window.

        Args:
            current_time_seconds: a torch Tensor (any shape), a numpy array, or
                a Python/numpy scalar.

        Returns:
            A Tensor for Tensor input, an ndarray for array input, or a float for
            scalar input. The scalar path previously returned ``truck_v_max``
            unconditionally and ignored the time factor.
        """
        if isinstance(current_time_seconds, torch.Tensor):
            hours = (current_time_seconds / 3600) % 24
            factors = torch.ones_like(hours)
            for start, end, f in self.truck_time_factors:
                mask = (hours >= start) & (hours < end)
                factors[mask] = f
            return self.truck_v_max * factors

        # Scalars and array-likes: same rule, computed with numpy
        hours = (np.asarray(current_time_seconds, dtype=float) / 3600) % 24
        factors = np.ones_like(hours)
        for start, end, f in self.truck_time_factors:
            factors = np.where((hours >= start) & (hours < end), f, factors)
        speed = self.truck_v_max * factors
        return float(speed) if np.ndim(speed) == 0 else speed


In [None]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import numpy as np

# # from dataloader import MOPVRPGenerator, get_rl_dataloader

# class Encoder(nn.Module):
#     """Encodes static & dynamic features using 1D Convolution."""
#     def __init__(self, input_size, hidden_size):
#         super(Encoder, self).__init__()
#         self.conv = nn.Conv1d(input_size, hidden_size, kernel_size=1)
    
#     def forward(self, x):
#         return self.conv(x)

# class MultiAgentDecoder(nn.Module):
#     """Decoder for multi-agent vehicle routing."""
#     def __init__(self, hidden_size, num_layers=1, dropout=0.2):
#         super(MultiAgentDecoder, self).__init__()
#         self.hidden_size = hidden_size
#         self.num_layers = num_layers
        
#         self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers,
#                            batch_first=True, dropout=dropout if num_layers > 1 else 0)
        
#         self.v_veh = nn.Parameter(torch.zeros((1, 1, hidden_size)))
#         self.W_veh = nn.Parameter(torch.zeros((1, hidden_size, hidden_size * 3)))
        
#         self.v_node = nn.Parameter(torch.zeros((1, 1, hidden_size)))
#         self.W_node = nn.Parameter(torch.zeros((1, hidden_size, hidden_size * 3)))
        
#         # Init weights
#         nn.init.xavier_uniform_(self.v_veh)
#         nn.init.xavier_uniform_(self.W_veh)
#         nn.init.xavier_uniform_(self.v_node)
#         nn.init.xavier_uniform_(self.W_node)
        
#         self.drop_rnn = nn.Dropout(p=dropout)
#         if num_layers == 1:
#             self.drop_hh = nn.Dropout(p=dropout)
    
#     def forward(self, customer_embeds, vehicle_embeds, decoder_hidden, last_hh):
#         batch_size = customer_embeds.size(0)
        
#         # Update LSTM
#         rnn_out, last_hh = self.lstm(decoder_hidden.transpose(2, 1), last_hh)
#         rnn_out = rnn_out.squeeze(1)
#         rnn_out = self.drop_rnn(rnn_out)
        
#         if self.num_layers == 1:
#             h_n, c_n = last_hh
#             h_n = self.drop_hh(h_n)
#             last_hh = (h_n, c_n)
        
#         # --- Attention Mechanism ---
        
#         # 1. Global Context
#         C_node = customer_embeds.mean(dim=2, keepdim=True) 
#         C_veh = vehicle_embeds.mean(dim=2, keepdim=True)   
        
#         # 2. Vehicle Selection Attention
#         h_expanded = rnn_out.unsqueeze(2).expand_as(vehicle_embeds)
#         C_node_expanded = C_node.expand_as(vehicle_embeds)
        
#         veh_input = torch.cat([C_node_expanded, h_expanded, vehicle_embeds], dim=1) 
        

#         v_veh = self.v_veh.expand(batch_size, -1, -1)
#         W_veh = self.W_veh.expand(batch_size, -1, -1)
        
#         veh_energy = torch.bmm(v_veh, torch.tanh(torch.bmm(W_veh, veh_input)))
#         veh_probs = veh_energy.squeeze(1)
        
#         # 3. Customer Selection Attention
#         h_expanded_node = rnn_out.unsqueeze(2).expand_as(customer_embeds)
#         C_veh_expanded = C_veh.expand_as(customer_embeds)
        
#         node_input = torch.cat([C_veh_expanded, h_expanded_node, customer_embeds], dim=1)
        
#         v_node = self.v_node.expand(batch_size, -1, -1)
#         W_node = self.W_node.expand(batch_size, -1, -1)
        
#         node_energy = torch.bmm(v_node, torch.tanh(torch.bmm(W_node, node_input)))
#         node_probs = node_energy.squeeze(1)
        
#         return veh_probs, node_probs, last_hh

# class MOPVRP_Actor(nn.Module):
#     def __init__(self, static_size, dynamic_size_truck, dynamic_size_drone, 
#                  hidden_size, num_layers=1, dropout=0.2):
#         super(MOPVRP_Actor, self).__init__()
        
#         # Encoders
#         self.static_encoder = Encoder(static_size, hidden_size)
#         self.truck_encoder = Encoder(dynamic_size_truck, hidden_size)
#         self.drone_encoder = Encoder(dynamic_size_drone, hidden_size)
        
#         # Decoder input is 2D (x, y) coordinates of the last visited node
#         self.decoder = Encoder(2, hidden_size) 
#         self.pointer = MultiAgentDecoder(hidden_size, num_layers, dropout)
        
#         # Learnable initial placeholder for decoder input
#         self.x0 = nn.Parameter(torch.zeros(1, 2, 1)) 
    
#     def forward(self, static, dynamic_trucks, dynamic_drones, 
#                 decoder_input=None, last_hh=None, mask_customers=None, mask_vehicles=None):
        
#         batch_size = static.size(0)
        
#         # Prepare Decoder Input (First step uses x0)
#         if decoder_input is None:
#             decoder_input = self.x0.expand(batch_size, -1, -1)
        
#         # Prepare Masks
#         if mask_customers is None:
#             mask_customers = torch.ones(batch_size, static.size(2), device=static.device)
#         if mask_vehicles is None:
#             num_veh = dynamic_trucks.size(2) + dynamic_drones.size(2)
#             mask_vehicles = torch.ones(batch_size, num_veh, device=static.device)
        
#         # --- 1. Encoding ---
#         customer_hidden = self.static_encoder(static)      # (B, 128, N)
#         truck_hidden = self.truck_encoder(dynamic_trucks)  # (B, 128, T)
#         drone_hidden = self.drone_encoder(dynamic_drones)  # (B, 128, D)
        
#         # Combine vehicles
#         vehicle_hidden = torch.cat([truck_hidden, drone_hidden], dim=2) # (B, 128, T+D)
        
#         # --- 2. Decoding Step ---
#         decoder_hidden = self.decoder(decoder_input)
        
#         veh_logits, node_logits, last_hh = self.pointer(
#             customer_hidden, vehicle_hidden, decoder_hidden, last_hh
#         )
        
#         # --- 3. Masking & Softmax ---
#         # Masking: Set logits of invalid actions to -inf
#         # mask = 1 (valid), 0 (invalid)
#         veh_logits = veh_logits.masked_fill(mask_vehicles == 0, float('-inf'))
#         node_logits = node_logits.masked_fill(mask_customers == 0, float('-inf'))
        
#         veh_probs = F.softmax(veh_logits, dim=1)
#         node_probs = F.softmax(node_logits, dim=1)
        
#         return veh_probs, node_probs, last_hh

import torch
import torch.nn as nn
import torch.nn.functional as F

class PairwiseEmbedding(nn.Module):
    """Embed every (vehicle, customer) pair with one shared 1x1 convolution.

    For each pair, the customer's static features and the vehicle's dynamic
    features are concatenated (static first) and projected to ``hidden_size``.
    Output: [Batch, Hidden, Num_Vehicles, Num_Customers].
    """

    def __init__(self, static_size, dynamic_size, hidden_size):
        super(PairwiseEmbedding, self).__init__()
        in_channels = static_size + dynamic_size
        # A kernel-size-1 Conv2d is a linear layer applied independently to every pair.
        self.conv2d = nn.Conv2d(in_channels, hidden_size, kernel_size=1)

    def forward(self, static, dynamic):
        """
        Args:
            static:  [Batch, Static_Feat, Num_Customers]
            dynamic: [Batch, Dyn_Feat, Num_Vehicles]
        Returns:
            [Batch, Hidden, Num_Vehicles, Num_Customers]
        """
        _, _, n_customers = static.size()
        _, _, n_vehicles = dynamic.size()

        # Broadcast both inputs onto a (vehicle, customer) grid
        static_grid = static[:, :, None, :].expand(-1, -1, n_vehicles, -1)
        dynamic_grid = dynamic[:, :, :, None].expand(-1, -1, -1, n_customers)

        pair_features = torch.cat((static_grid, dynamic_grid), dim=1)
        return self.conv2d(pair_features)

class HierarchicalDecoder(nn.Module):
    """Two-stage pointer decoder: choose a vehicle, then score customers for it.

    Each call advances an LSTMCell by one step. Vehicles are scored by attention
    over mean-pooled pairwise embeddings. One vehicle is chosen (argmax or
    sampled), and customers are scored by attention over that vehicle's slice of
    the pairwise embeddings.
    """
    def __init__(self, hidden_size, dropout=0.1):
        # NOTE(review): `dropout` is accepted but never used anywhere in this class.
        super(HierarchicalDecoder, self).__init__()
        self.hidden_size = hidden_size
        
        # LSTM cell that carries the decoding history between steps
        self.lstm = nn.LSTMCell(hidden_size, hidden_size)
        
        # --- Attention for step 1: vehicle selection ---
        # Per-vehicle input: [vehicle representation; LSTM state] (2*H) projected to H.
        # Vehicle representation = mean of that vehicle's pairwise embeddings over customers.
        self.W_veh = nn.Linear(hidden_size * 2, hidden_size) # Project Context
        # Scoring vector, initialised uniformly in [0, 1)
        self.v_veh = nn.Parameter(torch.rand(hidden_size))
        
        # --- Attention for step 2: customer selection ---
        # Per-customer input: [pairwise embedding of (selected vehicle, customer); LSTM state]
        self.W_cust = nn.Linear(hidden_size * 2, hidden_size)
        self.v_cust = nn.Parameter(torch.rand(hidden_size))
        
    def forward(self, pairwise_embeds, decoder_input, last_hh, mask_veh=None, mask_cust=None, deterministic=False):
        """Run one decoding step.

        Args:
            pairwise_embeds: [B, H, N_Veh, N_Cust] pair embeddings.
            decoder_input: [B, H] LSTMCell input (embedding of the node visited last).
            last_hh: (h, c) LSTMCell state, each [B, H].
            mask_veh: optional, expected [B, N_Veh]; entries equal to 0 are excluded.
            mask_cust: optional, expected [B, N_Cust]; entries equal to 0 are
                excluded. The same mask is applied whichever vehicle is selected.
            deterministic: if True take the argmax vehicle, otherwise sample one.

        Returns:
            (veh_probs [B, N_Veh], cust_probs [B, N_Cust], selected_veh_idx [B], (h, c)).
            Only the vehicle is chosen here; choosing the customer from
            cust_probs is left to the caller.

        Note:
            If a row's mask is entirely 0, all its energies become -inf and
            softmax returns NaN for that row; Categorical sampling will then likely raise.
        """
        h_t, c_t = last_hh
        h_t, c_t = self.lstm(decoder_input, (h_t, c_t)) # Update LSTM
        
        B, H, N_Veh, N_Cust = pairwise_embeds.size()
        
        # =========================================================
        # STEP 1: VEHICLE SELECTION
        # =========================================================
        
        # 1. One vector per vehicle: mean over all customers of its pairwise embeddings
        # (visited / masked customers are included in the mean).
        # Shape: [B, H, N_Veh, N_Cust] -> [B, H, N_Veh]
        veh_repr = pairwise_embeds.mean(dim=3) 
        
        # 2. Attention score per vehicle:
        # Score = v^T * tanh(W_veh * [veh_repr; h_t])
        
        h_t_expanded_v = h_t.unsqueeze(2).expand(-1, -1, N_Veh) # [B, H, V]
        
        # Concatenate along the feature dim: [B, 2*H, V] -> transpose -> [B, V, 2*H]
        veh_att_input = torch.cat([veh_repr, h_t_expanded_v], dim=1).transpose(1, 2)
        
        # Energy: [B, V, H] -> [B, V]
        veh_energy = torch.matmul(torch.tanh(self.W_veh(veh_att_input)), self.v_veh)
        
        # Masking & softmax
        if mask_veh is not None:
            veh_energy = veh_energy.masked_fill(mask_veh == 0, float('-inf'))
        veh_probs = F.softmax(veh_energy, dim=1)
        
        # 3. Pick the vehicle (greedy or sampled)
        if deterministic:
            selected_veh_idx = torch.argmax(veh_probs, dim=1) # [B]
        else:
            dist = torch.distributions.Categorical(veh_probs)
            selected_veh_idx = dist.sample() # [B]

        # =========================================================
        # STEP 2: CUSTOMER SELECTION
        # =========================================================
        
        # 1. Take the pairwise embeddings of (selected vehicle, every customer)
        
        # Gather index: [B, H, 1, N_Cust]
        idx_view = selected_veh_idx.view(B, 1, 1, 1).expand(-1, H, 1, N_Cust)
        
        # Gather [B, H, 1, N_Cust] -> squeeze -> [B, H, N_Cust]
        selected_veh_cust_embeds = pairwise_embeds.gather(2, idx_view).squeeze(2)
        
        # 2. Attention score per customer, context = LSTM output h_t
        h_t_expanded_c = h_t.unsqueeze(2).expand(-1, -1, N_Cust) # [B, H, N]
        
        # Input: [Pairwise(V_selected, C); h_t] -> [B, N, 2*H]
        cust_att_input = torch.cat([selected_veh_cust_embeds, h_t_expanded_c], dim=1).transpose(1, 2)
        
        # Energy: [B, N]
        cust_energy = torch.matmul(torch.tanh(self.W_cust(cust_att_input)), self.v_cust)
        
        # Masking & softmax
        if mask_cust is not None:
            cust_energy = cust_energy.masked_fill(mask_cust == 0, float('-inf'))
        cust_probs = F.softmax(cust_energy, dim=1)
        
        # Also return the chosen vehicle index so the caller uses the same choice
        return veh_probs, cust_probs, selected_veh_idx, (h_t, c_t)


class MOPVRP_Actor(nn.Module):
    """Hierarchical pointer actor for MO-PVRP: selects a vehicle, then scores customers.

    Truck and drone dynamic features are zero-padded to a common width and
    concatenated (trucks first). They are then combined with the static customer
    features into pairwise (vehicle, customer) embeddings and decoded by
    ``HierarchicalDecoder``.

    Args:
        static_size: number of static (customer) features.
        dynamic_size_truck: number of dynamic features per truck.
        dynamic_size_drone: number of dynamic features per drone.
        hidden_size: embedding / LSTM width.
        dropout: forwarded to ``HierarchicalDecoder`` (previously ignored here).
    """
    def __init__(self, static_size, dynamic_size_truck, dynamic_size_drone, 
                 hidden_size, dropout=0.1):
        super(MOPVRP_Actor, self).__init__()
        
        # Trucks and drones have different feature counts; both are padded to the larger one
        self.max_dyn_size = max(dynamic_size_truck, dynamic_size_drone)
        
        # Encoder producing the (vehicle, customer) pair grid
        self.pairwise_encoder = PairwiseEmbedding(static_size, self.max_dyn_size, hidden_size)
        
        # Embeds the (x, y) of the previously visited node as the LSTM input
        self.coords_embedding = nn.Linear(2, hidden_size)
        
        # Main decoder (vehicle first, then customer)
        self.decoder = HierarchicalDecoder(hidden_size, dropout=dropout)
        
        # Learnable initial decoder input and LSTM state
        self.x0 = nn.Parameter(torch.zeros(1, 2))
        self.h0 = nn.Parameter(torch.zeros(1, hidden_size))
        self.c0 = nn.Parameter(torch.zeros(1, hidden_size))
        
        self._init_weights()

    def _init_weights(self):
        """Basic initialisation.

        Applies Xavier-uniform to every parameter whose name contains 'weight' and
        has more than one dimension. Zeroes every parameter whose name contains
        'bias'. Other parameters (x0, h0, c0, decoder.v_veh, decoder.v_cust) keep
        their construction-time values.
        """
        for name, param in self.named_parameters():
            if 'weight' in name and param.dim() > 1:
                nn.init.xavier_uniform_(param)
            if 'bias' in name:
                nn.init.constant_(param, 0)
    
    def _pad_and_combine_vehicles(self, trucks, drones):
        """Zero-pad truck/drone features to ``max_dyn_size`` and join them along the vehicle axis.

        Args:
            trucks: [B, F_T, N_T]
            drones: [B, F_D, N_D]
        Returns:
            [B, max_dyn_size, N_T + N_D]; truck columns first, then drones.
        """
        def _pad(features):
            missing = self.max_dyn_size - features.size(1)
            if missing <= 0:
                return features
            # new_zeros keeps the input's dtype and device
            padding = features.new_zeros(features.size(0), missing, features.size(2))
            return torch.cat([features, padding], dim=1)

        return torch.cat([_pad(trucks), _pad(drones)], dim=2)

    def forward(self, static, dynamic_trucks, dynamic_drones, 
                decoder_input=None, last_hh=None, mask_customers=None, mask_vehicles=None, deterministic=False):
        """Run one decoding step.

        Args:
            static: [B, static_size, N] customer features.
            dynamic_trucks: [B, dynamic_size_truck, N_T].
            dynamic_drones: [B, dynamic_size_drone, N_D].
            decoder_input: [B, 2] coordinates of the previous node; the learnable
                ``x0`` is used when None.
            last_hh: (h, c) LSTM state; the learnable ``(h0, c0)`` is used when None.
            mask_customers / mask_vehicles: optional masks, 0 = not allowed.
            deterministic: greedy vehicle choice if True, sampled otherwise.

        Returns:
            (veh_probs, node_probs, selected_veh_idx, last_hh). The vehicle is
            chosen inside the decoder; callers (e.g. a PPO loop) should use the
            returned ``selected_veh_idx`` rather than resampling from veh_probs.
        """
        batch_size = static.size(0)
        
        # 1. Pad and combine truck + drone features
        dynamic_vehicles = self._pad_and_combine_vehicles(dynamic_trucks, dynamic_drones)
        
        # 2. Pairwise embedding [B, H, V, N]
        pairwise_embeds = self.pairwise_encoder(static, dynamic_vehicles)
        
        # 3. LSTM input
        if decoder_input is None:
            decoder_input = self.x0.expand(batch_size, -1)
        
        decoder_input_embed = self.coords_embedding(decoder_input)
        
        if last_hh is None:
            last_hh = (self.h0.expand(batch_size, -1), self.c0.expand(batch_size, -1))
            
        # 4. Hierarchical decoding (selects the vehicle internally)
        veh_probs, node_probs, selected_veh_idx, last_hh = self.decoder(
            pairwise_embeds, 
            decoder_input_embed, 
            last_hh, 
            mask_vehicles, 
            mask_customers,
            deterministic
        )
        
        return veh_probs, node_probs, selected_veh_idx, last_hh

    # ======================================================================
    # PERTURB WEIGHTS (intended for testing / debugging only)
    # ======================================================================
    def perturb_weights(self, noise_scale=1.0):
        """Inject strong noise to try to escape a local optimum.

        The pairwise Conv2d weights and bias are replaced with Uniform(-noise_scale,
        noise_scale). The decoder's linear weights get additive Gaussian noise, so
        part of what was learned is kept. Both scoring vectors are resampled from
        N(0, (2 * noise_scale)^2).
        """
        print(f"⚡ [Hierarchical_Actor] Adding STRONG noise (scale={noise_scale})...")
        with torch.no_grad():
            # 1. Encoder (pairwise Conv2d): full reset to a wide uniform distribution
            if hasattr(self.pairwise_encoder, 'conv2d'):
                self.pairwise_encoder.conv2d.weight.data.uniform_(-noise_scale, noise_scale)
                if self.pairwise_encoder.conv2d.bias is not None:
                    self.pairwise_encoder.conv2d.bias.data.uniform_(-noise_scale, noise_scale)

            # 2. Decoder: additive noise on W (keeps learned structure), reset of v
            # --- Vehicle branch ---
            self.decoder.W_veh.weight.data += torch.randn_like(self.decoder.W_veh.weight.data) * noise_scale
            self.decoder.v_veh.data.normal_(0, noise_scale * 2)
            
            # --- Customer branch ---
            self.decoder.W_cust.weight.data += torch.randn_like(self.decoder.W_cust.weight.data) * noise_scale
            self.decoder.v_cust.data.normal_(0, noise_scale * 2)
            
        print("✓ Hierarchical Weights perturbed successfully.")

    def _init_weights_high_variance(self):
        """Re-initialise with high-variance scoring vectors to break initial symmetry.

        The decoder's W matrices get Xavier-uniform and zero bias. v_veh and v_cust
        are drawn from N(0, 2^2), so initial energies are spread out and the
        softmax is sharper. The pairwise Conv2d weight gets Kaiming-normal.
        """
        scale_factor = 2.0  # large std -> large logits -> sharper softmax
        
        with torch.no_grad():
            # 1. Projection matrices W: Xavier keeps gradients through tanh() stable
            nn.init.xavier_uniform_(self.decoder.W_veh.weight)
            nn.init.xavier_uniform_(self.decoder.W_cust.weight)
            
            if self.decoder.W_veh.bias is not None: nn.init.zeros_(self.decoder.W_veh.bias)
            if self.decoder.W_cust.bias is not None: nn.init.zeros_(self.decoder.W_cust.bias)

            # 2. Scoring vectors v: normal with a LARGE std
            nn.init.normal_(self.decoder.v_veh, mean=0.0, std=scale_factor)
            nn.init.normal_(self.decoder.v_cust, mean=0.0, std=scale_factor)
            
            # 3. Pairwise encoder: Kaiming-normal for the Conv2d weight
            nn.init.kaiming_normal_(self.pairwise_encoder.conv2d.weight, mode='fan_out', nonlinearity='relu')
        
        print(f"⚡ Weights initialized with High Variance (std={scale_factor}) to force random bias.")

class Critic(nn.Module):
    """
    State-value baseline for the MOPVRP actor.

    Each input (static nodes, truck states, drone states) is embedded by a 1x1 Conv1d,
    mean-pooled over its last axis, concatenated, and passed through a 3-layer MLP
    (ReLU + dropout 0.1 between layers) that outputs one scalar per batch element.
    """

    def __init__(self, static_size, dynamic_size_truck, dynamic_size_drone, hidden_size):
        super().__init__()
        # Per-element embeddings (kernel_size=1 -> a shared linear map along the last axis).
        self.static_conv = nn.Conv1d(static_size, hidden_size, kernel_size=1)
        self.truck_conv = nn.Conv1d(dynamic_size_truck, hidden_size, kernel_size=1)
        self.drone_conv = nn.Conv1d(dynamic_size_drone, hidden_size, kernel_size=1)

        # MLP head on the three pooled embeddings.
        half_hidden = hidden_size // 2
        self.fc1 = nn.Linear(3 * hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, half_hidden)
        self.fc3 = nn.Linear(half_hidden, 1)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, static, dynamic_trucks, dynamic_drones):
        """Return a (B,) tensor of value estimates."""
        pooled = [
            self.static_conv(static).mean(dim=2),
            self.truck_conv(dynamic_trucks).mean(dim=2),
            self.drone_conv(dynamic_drones).mean(dim=2),
        ]
        features = torch.cat(pooled, dim=1)

        for hidden_layer in (self.fc1, self.fc2):
            features = self.dropout(self.relu(hidden_layer(features)))

        return self.fc3(features).squeeze(-1)



# def check_model_compatibility():
#     print("\n🚀 STARTING COMPATIBILITY CHECK...")
    
#     # 1. Setup
#     DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     BATCH_SIZE = 4
#     HIDDEN_SIZE = 128
    
#     # Kích thước đặc trưng theo DataLoader
#     STATIC_SIZE = 4       # x, y, demand, type
#     DYN_TRUCK_SIZE = 2    # loc, time
#     DYN_DRONE_SIZE = 4    # loc, time, energy, payload
    
#     # 2. Init Model
#     print(f"🔹 Initializing Model on {DEVICE}...")
#     model = MOPVRP_Actor(
#         static_size=STATIC_SIZE,
#         dynamic_size_truck=DYN_TRUCK_SIZE,
#         dynamic_size_drone=DYN_DRONE_SIZE,
#         hidden_size=HIDDEN_SIZE
#     ).to(DEVICE)



#     # =================================================================
#     # QUAN TRỌNG: GỌI HÀM LÀM NHIỄU Ở ĐÂY
#     # =================================================================
#     try:
#         checkpoint_path = "/Users/nguyentrithanh/Documents/20251/Project3/IT3940E-RL_for_MOVRP/ptr-net/checkpoints/checkpoint_epoch_497.pth"  # Đường dẫn checkpoint nếu có
    
#         checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
#         model.load_state_dict(checkpoint["actor_state_dict"])
#         # Gọi hàm perturb_weights với noise lớn để thấy rõ sự khác biệt
#         model.perturb_weights(noise_scale=1.0)
#         model._init_weights_high_variance()
#     except:
#         print("⚠️ Warning: Model chưa có hàm perturb_weights. Hãy cập nhật class MOPVRP_Actor trước.")
#     # =================================================================
    
#     # 3. Init DataLoader
#     print("🔹 Initializing DataLoader...")
#     # Giả sử bạn đã định nghĩa class MOPVRPGenerator ở trên
#     dataloader = get_rl_dataloader(batch_size=BATCH_SIZE, device=DEVICE)
#     data_iter = iter(dataloader)
    
#     # 4. Run Test
#     try:
#         # Lấy 1 batch
#         print("🔹 Fetching Batch data...")
#         static, dyn_trucks, dyn_drones, mask_cust, mask_veh, scale, weights = next(data_iter)
        
#         batch_size, _, num_nodes = static.shape
#         num_trucks = dyn_trucks.shape[2]
#         num_drones = dyn_drones.shape[2]
        
#         print(f"   Input Shapes:")
#         print(f"   - Static: {static.shape}")
#         print(f"   - Trucks: {dyn_trucks.shape}")
#         print(f"   - Drones: {dyn_drones.shape}")

#         # --- CAN THIỆP THỦ CÔNG ĐỂ TEST (Nuclear Option) ---
#         print("\n☢️  MANUALLY HACKING WEIGHTS TO FORCE SKEW...")
#         with torch.no_grad():
#             # 1. Ép xe đầu tiên (Index 0) có điểm số cực cao
#             # model.pointer.v_veh shape: (1, 1, hidden)
#             # Ta cộng một số rất lớn vào phần tử đầu tiên của vector v
#             model.pointer.v_veh.data.fill_(10.0) # Tăng độ lớn vector v lên
            
#             # Ép bias của xe đầu tiên trong lớp Linear W_veh (nếu có)
#             # Nhưng ở đây W_veh không có bias, ta hack vào input trucks
#             # Thay vào đó, ta hack trực tiếp vào decoder output của xe
            
#             # Cách hiệu quả nhất: Hack vào lớp Conv1d của Truck Encoder
#             # Làm cho đặc trưng của Truck 0 cực kỳ khác biệt so với các xe khác
#             # Truck Encoder weights: (hidden, input_size, 1)
#             model.truck_encoder.conv.weight.data.normal_(0, 5.0) 
#             model.drone_encoder.conv.weight.data.normal_(0, 2.0) # Drone nhỏ xíu
            
#             # Hack vào v_node để làm lệch Node Probs
#             model.pointer.v_node.data.normal_(0, 5.0)
            
#         print("✅ Weights hacked successfully.")
#         # ---------------------------------------------------

        
#         # Forward Pass
#         print("🔹 Running Forward Pass...")
#         veh_probs, node_probs, last_hh = model(
#             static, dyn_trucks, dyn_drones, 
#             decoder_input=None, 
#             last_hh=None, 
#             mask_customers=mask_cust, 
#             mask_vehicles=mask_veh
#         )
        
#         print("✅ Forward Pass Successful!")
#         print(f"   Output Shapes:")
#         print(f"   - Vehicle Probs: {veh_probs.shape} (Expected: [{batch_size}, {num_trucks + num_drones}])")
#         print(f"   - Node Probs:    {node_probs.shape} (Expected: [{batch_size}, {num_nodes}])")

#         print(f"   Output Explicit Probability:")
#         print(f"   - Vehicle Probs: {veh_probs.detach().cpu().numpy()}")
#         print(f"   - Node Probs:    {node_probs.detach().cpu().numpy()}")
        
#         # Kiểm tra tổng xác suất = 1
#         print(f"   - Sum Vehicle Probs: {veh_probs.sum(dim=1).detach().cpu().numpy()}")
#         print(f"   - Sum Node Probs:    {node_probs.sum(dim=1).detach().cpu().numpy()}")
        
#     except Exception as e:
#         print(f"\n❌ FAILED! Error: {e}")
#         import traceback
#         traceback.print_exc()

# if __name__ == "__main__":
#     check_model_compatibility()

import torch
import torch.nn.functional as F

def check_model_compatibility(checkpoint_path=None, use_dummy_data=False):
    """
    Smoke-test MOPVRP_Actor end to end on a single batch.

    Steps: build the actor, optionally load a checkpoint, exercise the weight-perturbation
    helpers, zero-pad truck/drone dynamic features to a common feature size, run one
    forward pass and print output shapes and probability sums.

    Args:
        checkpoint_path: optional checkpoint file containing an "actor_state_dict" entry.
            When None, the MOPVRP_CHECKPOINT environment variable is used; if neither points
            to an existing file, the model keeps its random weights.
        use_dummy_data: if True, feed random tensors instead of a batch from
            get_rl_dataloader (useful when the dataloader cell has not been run).
    """
    import os
    import traceback

    print("\n🚀 STARTING HIERARCHICAL MODEL COMPATIBILITY CHECK...")

    # 1. Setup
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    BATCH_SIZE = 4
    HIDDEN_SIZE = 128

    # Feature sizes: static = x, y, demand, type; trucks carry fewer dynamic features than drones.
    STATIC_SIZE = 4
    DYN_TRUCK_SIZE = 2
    DYN_DRONE_SIZE = 4
    # Both vehicle tensors are zero-padded up to the larger feature size before the forward pass.
    MAX_DYN_SIZE = max(DYN_TRUCK_SIZE, DYN_DRONE_SIZE)

    # 2. Model
    print(f"🔹 Initializing MOPVRP_Actor on {DEVICE}...")
    model = MOPVRP_Actor(
        static_size=STATIC_SIZE,
        dynamic_size_truck=DYN_TRUCK_SIZE,
        dynamic_size_drone=DYN_DRONE_SIZE,
        hidden_size=HIDDEN_SIZE,
    ).to(DEVICE)

    # 3. Optional checkpoint (no hard-coded machine-specific path)
    if checkpoint_path is None:
        checkpoint_path = os.environ.get("MOPVRP_CHECKPOINT")
    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f"🔹 Loading checkpoint from {checkpoint_path}...")
        try:
            # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
            checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
            # strict=False tolerates renamed/added modules; any mismatch is reported instead of hidden.
            incompatible = model.load_state_dict(checkpoint["actor_state_dict"], strict=False)
            if incompatible.missing_keys or incompatible.unexpected_keys:
                print(f"⚠️ Partial load: missing={incompatible.missing_keys}, "
                      f"unexpected={incompatible.unexpected_keys}")
        except Exception as e:
            print(f"⚠️ Warning: could not load checkpoint ({e}). Continuing with current weights.")
    else:
        print("⚠️ Checkpoint file not found. Using random weights.")

    # 4. Perturbation helpers
    try:
        print("⚡ Testing perturb_weights()...")
        model.perturb_weights(noise_scale=5.0)
        # NOTE: _init_weights_high_variance() re-initializes W_veh/W_cust/v_veh/v_cust, so the
        # noise perturb_weights() just added to those tensors does not survive this call.
        model._init_weights_high_variance()
    except Exception as e:
        print(f"⚠️ Warning during perturbation: {e}")

    # 5. Input batch
    if use_dummy_data:
        print("🔹 Generating Dummy Data with different dimensions...")
        num_nodes, num_trucks, num_drones = 20, 2, 3
        static = torch.rand(BATCH_SIZE, STATIC_SIZE, num_nodes, device=DEVICE)
        dyn_trucks = torch.rand(BATCH_SIZE, DYN_TRUCK_SIZE, num_trucks, device=DEVICE)
        dyn_drones = torch.rand(BATCH_SIZE, DYN_DRONE_SIZE, num_drones, device=DEVICE)
        mask_cust = torch.ones(BATCH_SIZE, num_nodes, device=DEVICE)
        mask_veh = torch.ones(BATCH_SIZE, num_trucks + num_drones, device=DEVICE)
    else:
        print("🔹 Initializing DataLoader...")
        dataloader = get_rl_dataloader(batch_size=BATCH_SIZE, device=DEVICE)
        print("🔹 Fetching Batch data...")
        static, dyn_trucks, dyn_drones, mask_cust, mask_veh, _scale, _weights = next(iter(dataloader))

    # Expected output sizes come from the batch actually fed to the model.
    batch_size, _, num_nodes = static.shape
    num_trucks = dyn_trucks.shape[2]
    num_drones = dyn_drones.shape[2]

    # 6. Pad dynamic features to a common feature dimension
    print("🔹 Processing Dynamic Features (Padding)...")
    print(f"   Original Truck Shape: {dyn_trucks.shape}")
    print(f"   Original Drone Shape: {dyn_drones.shape}")

    def pad_feature_dim(tensor, target_dim):
        """Zero-pad the feature dimension (dim 1) of a (B, F, N) tensor up to target_dim."""
        b, f, n = tensor.size()
        diff = target_dim - f
        if diff <= 0:
            return tensor
        padding = torch.zeros(b, diff, n, device=tensor.device, dtype=tensor.dtype)
        return torch.cat([tensor, padding], dim=1)

    dyn_trucks_padded = pad_feature_dim(dyn_trucks, MAX_DYN_SIZE)
    dyn_drones_padded = pad_feature_dim(dyn_drones, MAX_DYN_SIZE)

    print(f"   -> Padded Truck Shape: {dyn_trucks_padded.shape}")
    print(f"   -> Padded Drone Shape: {dyn_drones_padded.shape}")

    # 7. Forward pass
    try:
        print("\n🔹 Running Forward Pass...")
        veh_probs, node_probs, idx, last_hh = model(
            static,
            dyn_trucks_padded,
            dyn_drones_padded,
            decoder_input=None,
            last_hh=None,
            mask_customers=mask_cust,
            mask_vehicles=mask_veh,
        )

        print("✅ Forward Pass Successful!")
        print("\n📊 OUTPUT ANALYSIS:")

        expected_veh = num_trucks + num_drones
        print(f"   - Vehicle Probs Shape: {veh_probs.shape} (Expected: [{batch_size}, {expected_veh}])")
        print(f"   - Node Probs Shape:    {node_probs.shape} (Expected: [{batch_size}, {num_nodes}])")
        print(f"   - Selected Index Shape: {idx.shape}")

        print("\n   Example Probs (Batch 0):")
        print(f"   - Vehicle Probs: {veh_probs[0].detach().cpu().numpy().round(3)}")
        print(f"   - Node Probs:    {node_probs[0].detach().cpu().numpy().round(3)}")

        # Each probability row should sum to 1.
        sum_veh = veh_probs.sum(dim=1).detach().cpu().numpy()
        sum_node = node_probs.sum(dim=1).detach().cpu().numpy()
        print("\n   Probability Integrity Check (Should be all 1.0):")
        print(f"   - Sum Veh:  {sum_veh}")
        print(f"   - Sum Node: {sum_node}")

    except Exception as e:
        print(f"\n❌ FAILED! Error during forward pass: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    check_model_compatibility()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

def visualize_mopvrp(static_tensor, truck_routes, drone_routes, pretrained, title="MOPVRP Solution", save_path=None):
    """
    Plot the truck and drone routes of one MOPVRP instance, save the figure and show it.

    Args:
        static_tensor: tensor or array of shape (4, N) with rows [x, y, demand, truck_only];
            node 0 is drawn as the depot.
        truck_routes: one node-index list per truck, e.g. [[0, 5, 2, 0], [0, 1, 0]] (solid lines).
        drone_routes: one node-index list per drone, e.g. [[0, 8, 0], [0, 4, 0]] (dashed lines).
        pretrained: only selects the title suffix and the default file name.
        title: base title; " (Pretrained Model)" or " (Random Model)" is appended.
        save_path: output image path. Defaults to "<suffix>.png" (the previous fixed name).
    """
    if isinstance(static_tensor, torch.Tensor):
        # detach() so tensors that require grad can be converted as well.
        static = static_tensor.detach().cpu().numpy()
    else:
        static = np.asarray(static_tensor)

    coords = static[:2, :]  # (2, N)
    is_depot = np.arange(coords.shape[1]) == 0
    truck_only_mask = static[3, :] == 1

    fig, ax = plt.subplots(figsize=(10, 8))

    # 1. Nodes: depot (node 0), flexible customers, truck-only customers.
    ax.scatter(coords[0, 0], coords[1, 0], c='red', marker='s', s=200, label='Depot', zorder=10)

    flexible_indices = np.where(~truck_only_mask & ~is_depot)[0]
    ax.scatter(coords[0, flexible_indices], coords[1, flexible_indices],
               c='blue', marker='o', s=50, label='Customer (Flexible)', alpha=0.6)

    # The depot is already drawn above; never re-draw it as a truck-only customer.
    truck_only_indices = np.where(truck_only_mask & ~is_depot)[0]
    ax.scatter(coords[0, truck_only_indices], coords[1, truck_only_indices],
               c='black', marker='X', s=80, label='Customer (Truck-Only)')

    # Node ids next to each point.
    for i in range(coords.shape[1]):
        ax.text(coords[0, i] + 0.01, coords[1, i] + 0.01, str(i), fontsize=9)

    # One color per vehicle. plt.get_cmap replaces matplotlib.cm.get_cmap (removed in Matplotlib 3.9);
    # max(1, ...) avoids an empty lookup table when no routes are given.
    colors = plt.get_cmap('tab10', max(1, len(truck_routes) + len(drone_routes)))

    # 2. Truck routes (solid lines)
    for i, route in enumerate(truck_routes):
        if len(route) < 2:  # truck never left the depot
            continue

        route_coords = coords[:, route]
        color = colors(i)
        ax.plot(route_coords[0], route_coords[1], c=color, linewidth=2,
                label=f'Truck {i}', linestyle='-', marker='.')

        # Direction arrow on the segment that starts at the middle stop.
        mid_idx = len(route) // 2
        if mid_idx < len(route) - 1:
            p1 = coords[:, route[mid_idx]]
            p2 = coords[:, route[mid_idx + 1]]
            ax.arrow(p1[0], p1[1], (p2[0] - p1[0]) * 0.5, (p2[1] - p1[1]) * 0.5,
                     head_width=0.015, color=color)

    # 3. Drone routes (dashed lines); trips such as [0, 5, 0, 8, 0] are drawn as one polyline.
    for i, route in enumerate(drone_routes):
        if len(route) < 2:
            continue
        color = colors(len(truck_routes) + i)
        route_coords = coords[:, route]
        ax.plot(route_coords[0], route_coords[1], c=color, linewidth=1.5,
                label=f'Drone {i}', linestyle='--')

    suffix = " (Pretrained Model)" if pretrained else " (Random Model)"
    ax.set_title(title + suffix)
    ax.set_xlabel("X Coordinate (Normalized)")
    ax.set_ylabel("Y Coordinate (Normalized)")
    ax.legend(loc='upper right', bbox_to_anchor=(1.15, 1))
    ax.grid(True, linestyle=':', alpha=0.6)
    fig.tight_layout()
    fig.savefig(save_path if save_path is not None else f"{suffix}.png", dpi=300)
    plt.show()

In [None]:
import torch
import numpy as np

# from config import SystemConfig
# from model import MOPVRP_Actor
# from dataloader import get_rl_dataloader
# from visualizer import visualize_mopvrp

class MOPVRPEnvironment:
    def __init__(self, config, dataloader, device='cpu'):
        self.config = config
        self.dataloader = dataloader
        self.device = device
        self.data_iter = iter(dataloader)
        
        # State placeholders
        self.static = None       # [B, 4, N]
        self.dynamic_truck = None # [B, 2, K]
        self.dynamic_drone = None # [B, 4, D]
        self.mask_cust = None     # [B, N]
        self.scale = None         # [B, 1]
        self.weights = None       # [B, 2]

        # Scenario Configuraion
        self.scenario = 5     # Default

    def set_scenario(self, scenario):
        self.scenario = scenario
        return

    def reset(self):
        self.dataloader.dataset.set_scenario(self.scenario)
        batch_data = next(self.data_iter)
        static_src, self.dynamic_truck, self.dynamic_drone, \
        self.mask_cust, mask_veh, self.scale, self.weights = batch_data
        
        self.static = static_src.clone()
        self.batch_size = self.static.size(0)
        self.num_nodes = self.static.size(2)
        self.num_trucks = self.dynamic_truck.size(2)
        self.num_drones = self.dynamic_drone.size(2)

        # print(f"Current Scenario: {self.scenario}, Num Trucks: {self.num_trucks}, Num Drones: {self.num_drones}")
        
        self.routes = [[{'trucks': [[0] for _ in range(self.num_trucks)], 
                         'drones': [[0] for _ in range(self.num_drones)]} 
                        for _ in range(self.batch_size)]]
        
        self.total_waiting_time = torch.zeros(self.batch_size, device=self.device)
        
        return (self.static, self.dynamic_truck, self.dynamic_drone, self.mask_cust, self._update_vehicle_mask())

    def step(self, selected_vehicle_idx, selected_node_idx):
        """
        Apply one (vehicle, node) decision to every instance in the batch.

        Args:
            selected_vehicle_idx: (B,) integer tensor; values < num_trucks select a truck,
                values >= num_trucks select drone (value - num_trucks).
            selected_node_idx: (B,) integer tensor with the target node; 0 is the depot.

        Returns:
            (next_state, step_reward, done, info): next_state is
            (static, dynamic_truck, dynamic_drone, mask_cust, vehicle_mask); step_reward and
            done are (B,) tensors; info is an empty dict.
        """
        is_drone = selected_vehicle_idx >= self.num_trucks
        drone_idx = selected_vehicle_idx - self.num_trucks
        truck_idx = selected_vehicle_idx
        # Rows whose target is a customer (not the depot); reused by every update below.
        is_customer = (selected_node_idx != 0)

        # Clock snapshots (row 1) to measure the fleet time spent in this step.
        prev_truck_times = self.dynamic_truck[:, 1, :].clone()
        prev_drone_times = self.dynamic_drone[:, 1, :].clone()

        # 1. Physics: each helper only modifies the rows where its mask is True.
        self._update_truck_state(truck_idx, selected_node_idx, ~is_drone)
        self._update_drone_state(drone_idx, selected_node_idx, is_drone)

        if is_customer.any():
            # Served customer -> its demand (static row 2) becomes 0. This runs after the
            # physics update, which still reads the demand of the target node.
            batch_indices = torch.arange(self.batch_size, device=self.device)
            self.static[batch_indices[is_customer], 2, selected_node_idx[is_customer]] = 0.0

            # Accumulate the arrival time of the vehicle that just served a customer.
            current_times = torch.zeros(self.batch_size, device=self.device)
            if (~is_drone).any():
                b_tr = torch.where(~is_drone)[0]
                current_times[b_tr] = self.dynamic_truck[b_tr, 1, truck_idx[b_tr]]
            if is_drone.any():
                b_dr = torch.where(is_drone)[0]
                current_times[b_dr] = self.dynamic_drone[b_dr, 1, drone_idx[b_dr]]
            self.total_waiting_time += current_times * is_customer.float()

        # 2. Customer mask: mark visited customers (a depot target never masks node 0).
        node_mask_update = torch.nn.functional.one_hot(selected_node_idx.long(), num_classes=self.num_nodes)
        node_mask_update = node_mask_update * is_customer.float().unsqueeze(1)
        self.mask_cust = self.mask_cust * (1 - node_mask_update)

        # 3. Termination: every customer served and every vehicle back at node 0 ...
        unvisited_count = self.mask_cust[:, 1:].sum(dim=1)
        all_trucks_home = (self.dynamic_truck[:, 0, :] == 0).all(dim=1)
        all_drones_home = (self.dynamic_drone[:, 0, :] == 0).all(dim=1)
        done = (unvisited_count == 0) & all_trucks_home & all_drones_home

        # ... or deadlock: no selectable vehicle left. The episode ends and is penalized below.
        next_mask_veh = self._update_vehicle_mask()
        done = done | (next_mask_veh.sum(dim=1) == 0)

        # 4. Reward: negative (scaled) total time spent by the whole fleet in this step.
        delta_truck = (self.dynamic_truck[:, 1, :] - prev_truck_times).sum(dim=1)
        delta_drone = (self.dynamic_drone[:, 1, :] - prev_drone_times).sum(dim=1)
        step_cost = delta_truck + delta_drone
        step_reward = -(step_cost / 10000.0)

        if done.any():
            final_rewards = self._calculate_terminal_reward()
            # Finished by deadlock (customers left or vehicles away from home): heavy penalty.
            is_failure = (unvisited_count > 0) | (~all_trucks_home) | (~all_drones_home)
            penalty = torch.where(is_failure, torch.tensor(-10.0, device=self.device),
                                  torch.tensor(0.0, device=self.device))
            step_reward = torch.where(done, final_rewards + penalty, step_reward)

        self._update_routes_history(selected_vehicle_idx, selected_node_idx)

        next_state = (self.static, self.dynamic_truck, self.dynamic_drone, self.mask_cust, next_mask_veh)
        return next_state, step_reward, done, {}

    def get_valid_customer_mask(self, selected_vehicle_idx):
        """
        Logic Mask Node:
        - Nếu HẾT khách: Bắt buộc chọn Node 0 (cho cả Truck và Drone).
        """
        valid_mask = self.mask_cust.clone()
        
        unvisited_count = self.mask_cust[:, 1:].sum(dim=1)
        is_all_served = (unvisited_count == 0)
        
        # --- END LOGIC ---
        if is_all_served.any():
            served_indices = torch.where(is_all_served)[0]
            valid_mask[served_indices, :] = 0
            valid_mask[served_indices, 0] = 1
        
        # --- NORMAL LOGIC (Khi còn khách) ---
        batch_indices = torch.arange(self.batch_size, device=self.device)
        is_drone = selected_vehicle_idx >= self.num_trucks
        
        # Truck Logic (Start at Depot restriction)
        if (~is_drone).any():
            t_indices = torch.where(~is_drone)[0]
            t_ids = selected_vehicle_idx[t_indices]
            current_times = self.dynamic_truck[t_indices, 1, t_ids]
            
            # Nếu đang ở Depot (Time=0) -> Không được chọn Node 0
            at_start = (current_times == 0)
            if at_start.any():
                start_indices = t_indices[at_start]
                valid_mask[start_indices, 0] = 0
        
        if is_drone.any():
            d_indices = torch.where(is_drone)[0]
            d_ids = selected_vehicle_idx[d_indices] - self.num_trucks
            
            # 1. Enable Node 0 (Luôn cho phép về Depot sạc/chờ)
            valid_mask[d_indices, 0] = 1
            
            # 2. Mask Truck-only customers
            truck_only_flags = self.static[d_indices, 3, :]
            valid_mask[d_indices] *= (1 - truck_only_flags)
            
            # 3. CHECK ENERGY: (Curr -> Node) + (Node -> Depot)
            p = self.config.drone_params
            speed = self.config.drone_speed
            t_const = self.config.t_takeoff + self.config.t_landing
            
            # A. Tính chặng đi: Current -> Candidate Node
            curr_nodes = self.dynamic_drone[d_indices, 0, d_ids].long()
            coords = self.static[d_indices, :2, :] # (N_drone, 2, N_nodes)
            
            curr_xy = self._gather_coords(coords, curr_nodes)
            curr_xy_exp = curr_xy.unsqueeze(2) 
            
            # Khoảng cách đi (N_drone, N_nodes)
            dist_go = torch.norm(coords - curr_xy_exp, dim=1) * self.scale[d_indices, 0].unsqueeze(1)
            
            # B. Tính chặng về: Candidate Node -> Depot (Node 0)
            depot_xy = coords[:, :, 0].unsqueeze(2) # (N_drone, 2, 1)
            dist_return = torch.norm(coords - depot_xy, dim=1) * self.scale[d_indices, 0].unsqueeze(1)
            
            # C. Tính Năng Lượng
            payloads = self.static[d_indices, 2, :] 
            
            # Công suất (Power)
            power = p['gama(w)'] + p['beta(w/kg)'] * payloads
            
            # Năng lượng chặng đi
            time_go = dist_go / speed + t_const
            energy_go = power * time_go
            
            # Năng lượng chặng về
            time_return = dist_return / speed + t_const
            energy_return = power * time_return
            
            # Tổng năng lượng cần thiết
            total_energy_req = energy_go + energy_return

            # --- Xử lý ngoại lệ cho Node 0 (Depot) ---
            # Nếu ĐANG Ở Depot và chọn Node 0 (Đứng chờ): Energy req gần như bằng 0 (hoặc rất nhỏ).
            is_at_depot = (curr_nodes == 0).unsqueeze(1)
            target_is_depot = torch.zeros_like(total_energy_req, dtype=torch.bool)
            target_is_depot[:, 0] = True
            
            # Nếu (Đang ở 0) VÀ (Chọn 0) -> Energy Req = 0
            stay_mask = is_at_depot & target_is_depot
            total_energy_req[stay_mask] = 0.0
            
            # D. So sánh với Pin hiện tại
            energy_req_norm = total_energy_req / self.config.drone_max_energy
            curr_energy = self.dynamic_drone[d_indices, 2, d_ids].unsqueeze(1)
            
            # Tạo mask
            energy_mask = (curr_energy >= energy_req_norm).float()
            valid_mask[d_indices] *= energy_mask

            # 4. Check payloads

            # Payload hiện tại
            curr_load = self.dynamic_drone[d_indices, 3, d_ids].unsqueeze(1) # (N, 1)
            # Demand của các khách hàng tiềm năng
            demands = self.static[d_indices, 2, :] # (N, Nodes)
            # Capacity Max
            max_cap = self.config.drone_capacity_kg 
            
            # Điều kiện: Load sau khi nhặt <= Max Cap
            cap_mask = (curr_load + demands) <= max_cap
            valid_mask[d_indices] *= cap_mask.float()

        # Re-apply End Game logic (để chắc chắn không bị logic energy làm sai mask node 0)
        if is_all_served.any():
            served_indices = torch.where(is_all_served)[0]
            valid_mask[served_indices, 1:] = 0
            
        return valid_mask

    def _update_vehicle_mask(self):
        """
        Logic chọn xe:
        - Xe hết pin/hư hỏng -> Mask 0.
        - Nếu Hết khách (End Game):
            + Xe ĐANG Ở Node 0 -> Mask 0 (Đã về đích, không chọn nữa).
            + Xe CHƯA Ở Node 0 -> Mask 1 (Cần chọn để đưa về).
        """
        mask = torch.ones(self.batch_size, self.num_trucks + self.num_drones, device=self.device)
        
        # 1. Check Energy (Drone)
        energies = self.dynamic_drone[:, 2, :]
        mask[:, self.num_trucks:] = (energies > 0.05).float()
        
        # 2. Check End Game Status
        unvisited_count = self.mask_cust[:, 1:].sum(dim=1)
        is_all_served = (unvisited_count == 0)
        
        if is_all_served.any():
            served_idx = torch.where(is_all_served)[0]
            
            # Logic: Chỉ kích hoạt những xe CHƯA về nhà
            # Truck locations
            t_locs = self.dynamic_truck[served_idx, 0, :]
            t_needs_return = (t_locs != 0).float()
            mask[served_idx, :self.num_trucks] = t_needs_return
            
            # Drone locations
            d_locs = self.dynamic_drone[served_idx, 0, :]
            d_needs_return = (d_locs != 0).float()
            
            # Kết hợp với mask energy cũ
            mask[served_idx, self.num_trucks:] *= d_needs_return
            
        return mask

    def _update_drone_state(self, veh_idx, node_idx, active_mask):
        """Drone: Cập nhật vị trí, thời gian, năng lượng"""
        if not active_mask.any(): return
        b_idx = torch.where(active_mask)[0]
        coords = self.static[b_idx, :2, :]
        curr_nodes = self.dynamic_drone[b_idx, 0, veh_idx[b_idx]].long()
        target_nodes = node_idx[b_idx]
        
        # Check: Đang ở Depot VÀ Đi đến Depot -> Đứng yên (Wait/Recharge)
        is_staying = (curr_nodes == 0) & (target_nodes == 0)
        move_factor = (~is_staying).float()
        
        curr_xy = self._gather_coords(coords, curr_nodes)
        target_xy = self._gather_coords(coords, target_nodes)
        
        dist = torch.norm(target_xy - curr_xy, dim=1) * self.scale[b_idx, 0]
        payloads = self.static[b_idx, 2, target_nodes] 
        
        p = self.config.drone_params
        power = p['gama(w)'] + p['beta(w/kg)'] * payloads
        t_takeoff, t_landing = self.config.t_takeoff, self.config.t_landing
        t_cruise = dist / self.config.drone_speed
        
        # Tính năng lượng: Chỉ tốn khi di chuyển
        energy_joule = power * ((t_takeoff + t_landing) * move_factor + t_cruise)
        norm_cost = energy_joule / self.config.drone_max_energy
        
        # Tính thời gian: Di chuyển tốn time bay, Đứng yên tốn time chờ (ví dụ 30s)
        wait_time = 30.0 
        total_time = ((t_takeoff + t_landing) * move_factor + t_cruise)
        total_time = torch.where(is_staying, torch.tensor(wait_time, device=self.device), total_time)

        # Tính payloads
        current_payloads = self.dynamic_drone[b_idx, 3, veh_idx[b_idx]]
        target_demands = self.static[b_idx, 2, target_nodes]
        new_payloads = current_payloads + target_demands
        
        # Update State
        self.dynamic_drone[b_idx, 0, veh_idx[b_idx]] = target_nodes.float()
        self.dynamic_drone[b_idx, 1, veh_idx[b_idx]] += total_time
        self.dynamic_drone[b_idx, 2, veh_idx[b_idx]] -= norm_cost
        
        
        # Recharge Logic: Cứ về Depot là đầy pin
        is_back_depot = (target_nodes == 0)
        new_payloads = torch.where(is_back_depot, torch.zeros_like(new_payloads), new_payloads)
        # Update payloads
        self.dynamic_drone[b_idx, 3, veh_idx[b_idx]] = new_payloads

        if is_back_depot.any():
            idx_reset = b_idx[is_back_depot]
            d_ids_reset = veh_idx[idx_reset]
            self.dynamic_drone[idx_reset, 2, d_ids_reset] = 1.0
            self.dynamic_drone[idx_reset, 3, d_ids_reset] = 0.0

    def _update_truck_state(self, veh_idx, node_idx, active_mask):
        """Truck: Cập nhật vị trí, thời gian"""
        if not active_mask.any(): return
        b_idx = torch.where(active_mask)[0]
        coords = self.static[b_idx, :2, :]
        curr_nodes = self.dynamic_truck[b_idx, 0, veh_idx[b_idx]].long()
        target_nodes = node_idx[b_idx]
        curr_xy = self._gather_coords(coords, curr_nodes)
        target_xy = self._gather_coords(coords, target_nodes)
        dist = torch.norm(target_xy - curr_xy, dim=1) * self.scale[b_idx, 0]
        current_times = self.dynamic_truck[b_idx, 1, veh_idx[b_idx]]
        speeds = self.config.get_truck_speed_batch(current_times)
        travel_time = dist / speeds
        self.dynamic_truck[b_idx, 0, veh_idx[b_idx]] = target_nodes.float()
        self.dynamic_truck[b_idx, 1, veh_idx[b_idx]] += travel_time

        

    def _gather_coords(self, coords_batch, node_indices):
        idx_expanded = node_indices.view(-1, 1, 1).expand(-1, 2, -1)
        return torch.gather(coords_batch, 2, idx_expanded).squeeze(2)


    def _calculate_terminal_reward(self):
        w1 = self.weights[:, 0]
        w2 = self.weights[:, 1]
        
        # Makespan: Max thời gian của các xe
        truck_times = self.dynamic_truck[:, 1, :]
        drone_times = self.dynamic_drone[:, 1, :]
        all_times = torch.cat([truck_times, drone_times], dim=1)
        makespan, _ = torch.max(all_times, dim=1)
        
        waiting_time = self.total_waiting_time
        
        # Reward = -(w1 * Makespan + w2 * WaitingTime)
        reward = -(w1 * makespan + w2 * waiting_time)
        
        return reward / 1000000.0
        # return reward

    def _update_routes_history(self, veh_idx, node_idx):
        v_idx = veh_idx.cpu().numpy()
        n_idx = node_idx.cpu().numpy()
        for b in range(self.batch_size):
            node_val = int(n_idx[b])
            if v_idx[b] < self.num_trucks:
                self.routes[0][b]['trucks'][v_idx[b]].append(node_val)
            else:
                d_id = v_idx[b] - self.num_trucks
                self.routes[0][b]['drones'][d_id].append(node_val)

In [None]:
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.utils.tensorboard import SummaryWriter
# import numpy as np
# from collections import deque
# import time
# import os
# import json
# from datetime import datetime
# from tqdm import tqdm 

# from ipywidgets import Output
# from IPython.display import display, clear_output
# import random

# # from config import SystemConfig
# # from environment import MOPVRPEnvironment
# # from model import MOPVRP_Actor
# # from dataloader import get_rl_dataloader
# # from visualizer import visualize_mopvrp

# class RolloutBuffer:
#     """Buffer để lưu trữ experience cho PPO (Giữ nguyên)"""
#     def __init__(self):
#         self.states = []
#         self.actions_veh = []
#         self.actions_node = []
#         self.logprobs_veh = []
#         self.logprobs_node = []
#         self.rewards = []
#         self.dones = []
#         self.values = []
#         self.masks_cust = []
        
#     def clear(self):
#         self.states = []
#         self.actions_veh = []
#         self.actions_node = []
#         self.logprobs_veh = []
#         self.logprobs_node = []
#         self.rewards = []
#         self.dones = []
#         self.values = []
#         self.masks_cust = []
    
#     def __len__(self):
#         return len(self.rewards)

# class Critic(nn.Module):
#     def __init__(self, static_size, dynamic_size_truck, dynamic_size_drone, hidden_size):
#         super(Critic, self).__init__()
#         self.static_conv = nn.Conv1d(static_size, hidden_size, kernel_size=1)
#         self.truck_conv = nn.Conv1d(dynamic_size_truck, hidden_size, kernel_size=1)
#         self.drone_conv = nn.Conv1d(dynamic_size_drone, hidden_size, kernel_size=1)
#         self.fc1 = nn.Linear(hidden_size * 3, hidden_size)
#         self.fc2 = nn.Linear(hidden_size, hidden_size // 2)
#         self.fc3 = nn.Linear(hidden_size // 2, 1)
#         self.relu = nn.ReLU()
#         self.dropout = nn.Dropout(0.1)
        
#     def forward(self, static, dynamic_trucks, dynamic_drones):
#         static_embed = self.static_conv(static)
#         truck_embed = self.truck_conv(dynamic_trucks)
#         drone_embed = self.drone_conv(dynamic_drones)
#         combined = torch.cat([static_embed.mean(2), truck_embed.mean(2), drone_embed.mean(2)], dim=1)
#         x = self.relu(self.fc1(combined))
#         x = self.dropout(x)
#         x = self.relu(self.fc2(x))
#         x = self.dropout(x)
#         return self.fc3(x).squeeze(-1)

# class PPOConfig:
#     """Configuration CHỈ DÀNH CHO TRAINING (Hyperparameters)"""
#     def __init__(self):
#         # Environment settings (Training related)
#         self.batch_size = 256
#         self.max_steps = 200
        
#         # Model Architecture
#         self.hidden_size = 128
#         self.num_layers = 1
#         self.dropout = 0.2
        
#         # PPO Hyperparameters
#         self.lr_actor = 1e-4
#         self.lr_critic = 1e-4
#         self.gamma = 0.99
#         self.gae_lambda = 0.95
#         self.clip_epsilon = 0.2
#         self.entropy_coef = 0.01
#         self.value_loss_coef = 0.1
#         self.max_grad_norm = 0.5
        
#         # Training loop
#         self.num_epochs = 50
#         self.update_epochs = 4
#         self.rollout_steps = 2048
#         self.update_batch_size = 64
#         self.minibatch_size = 512    
        
#         # Target KL (optional)
#         self.target_kl = 0.03        
        
#         # Logging
#         self.log_interval = 1
#         self.save_interval = 1
#         self.eval_interval = 1
        
#         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#         self.checkpoint_dir = "checkpoints"
#         self.log_dir = "logs"
    

# class PPOTrainer:
#     def __init__(self, ppo_config, sys_config): 
#         self.config = ppo_config
#         self.sys_config = sys_config 
#         self.device = ppo_config.device
        
#         os.makedirs(ppo_config.checkpoint_dir, exist_ok=True)
#         os.makedirs(ppo_config.log_dir, exist_ok=True)
        
#         self._init_networks()
        
#         self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=ppo_config.lr_actor)
#         self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=ppo_config.lr_critic)
        
#         self.actor_scheduler = optim.lr_scheduler.StepLR(self.actor_optimizer, step_size=200, gamma=0.95)
#         self.critic_scheduler = optim.lr_scheduler.StepLR(self.critic_optimizer, step_size=200, gamma=0.95)
        
#         # Initialize DataLoader
#         self.dataloader = get_rl_dataloader(
#             batch_size=ppo_config.batch_size, 
#             device=self.device
#         )
        
#         self.env = MOPVRPEnvironment(self.sys_config, self.dataloader, device=self.device)
        
#         self.buffer = RolloutBuffer()
#         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
#         self.writer = SummaryWriter(f"{ppo_config.log_dir}/ppo_{timestamp}")
        
#         self.episode_rewards = deque(maxlen=100)
#         self.episode_lengths = deque(maxlen=100)
#         self.best_reward = float('-inf')
#         self.total_steps = 0
#         self.num_updates = 0
        
#     def _init_networks(self):
#         self.actor = MOPVRP_Actor(4, 2, 4, self.config.hidden_size, self.config.num_layers, self.config.dropout).to(self.device)
#         self.critic = Critic(4, 2, 4, self.config.hidden_size).to(self.device)
#         self._init_weights(self.actor)
#         self._init_weights(self.critic)
        
#     def _init_weights(self, module):
#         for m in module.modules():
#             if isinstance(m, (nn.Linear, nn.Conv1d)):
#                 nn.init.xavier_uniform_(m.weight)
#                 if m.bias is not None: nn.init.zeros_(m.bias)

    
#     def select_action(self, state, last_hh=None, deterministic=False):
#         """
#         Select action using current policy
#         Returns: vehicle_idx, node_idx, logprob_veh, logprob_node, last_hh
#         """
#         static, dyn_truck, dyn_drone, mask_cust, mask_veh = state
        
#         if mask_cust.sum(dim=1).eq(0).any():
#             # Clone để không ảnh hưởng đến state gốc
#             mask_cust = mask_cust.clone()
#             # Tìm các dòng có tổng = 0
#             zero_mask_indices = mask_cust.sum(dim=1) == 0
#             # Mở Node 0 (Depot) cho các dòng đó
#             mask_cust[zero_mask_indices, 0] = 1

#         with torch.no_grad():
#             # Get probabilities from actor
#             veh_probs, node_probs, last_hh = self.actor(
#                 static, dyn_truck, dyn_drone,
#                 decoder_input=None,
#                 last_hh=last_hh,
#                 mask_customers=mask_cust,
#                 mask_vehicles=mask_veh
#             )
        
#         if torch.isnan(node_probs).any() or (node_probs.sum(dim=1) == 0).any():
#             # Tạo một phân phối mặc định: 100% về Depot (Node 0)
#             fallback_probs = torch.zeros_like(node_probs)
#             fallback_probs[:, 0] = 1.0
            
#             # Tìm các dòng bị lỗi (NaN hoặc Sum=0)
#             invalid_rows = torch.isnan(node_probs).any(dim=1) | (node_probs.sum(dim=1) == 0)
            
#             # Gán đè phân phối mặc định vào các dòng lỗi
#             node_probs[invalid_rows] = fallback_probs[invalid_rows]

#         # Tương tự cho Vehicle Probs (Phòng hờ)
#         if torch.isnan(veh_probs).any() or (veh_probs.sum(dim=1) == 0).any():
#             fallback_veh = torch.zeros_like(veh_probs)
#             fallback_veh[:, 0] = 1.0 # Chọn xe đầu tiên
#             invalid_rows_veh = torch.isnan(veh_probs).any(dim=1) | (veh_probs.sum(dim=1) == 0)
#             veh_probs[invalid_rows_veh] = fallback_veh[invalid_rows_veh]

#         if deterministic:
#             # Greedy selection
#             veh_idx = torch.argmax(veh_probs, dim=1)
#             node_idx = torch.argmax(node_probs, dim=1)
#         else:
#             # Stochastic sampling
#             veh_dist = torch.distributions.Categorical(veh_probs)
#             node_dist = torch.distributions.Categorical(node_probs)
            
#             veh_idx = veh_dist.sample()
#             node_idx = node_dist.sample()
        
#         # Calculate log probabilities
#         logprob_veh = torch.log(veh_probs.gather(1, veh_idx.unsqueeze(1)) + 1e-10).squeeze(1)
#         logprob_node = torch.log(node_probs.gather(1, node_idx.unsqueeze(1)) + 1e-10).squeeze(1)
        
#         return veh_idx, node_idx, logprob_veh, logprob_node, last_hh

#     def compute_returns_and_advantages(self, rewards, values, dones):
#         returns, advantages, gae = [], [], 0
#         rewards, values, dones = np.array(rewards), np.array(values), np.array(dones)
#         for t in reversed(range(len(rewards))):
#             next_value = 0 if t == len(rewards) - 1 else values[t + 1]
#             delta = rewards[t] + self.config.gamma * next_value * (1 - dones[t]) - values[t]
#             gae = delta + self.config.gamma * self.config.gae_lambda * (1 - dones[t]) * gae
#             advantages.insert(0, gae)
#             returns.insert(0, gae + values[t])
#         return np.array(returns), np.array(advantages)

#     def collect_rollout(self):
#         """Collect rollout data với Progress Bar"""
#         self.buffer.clear()
#         state = self.env.reset()
#         last_hh = None
        
#         episode_reward = 0
#         episode_length = 0
        
#         pbar = tqdm(range(self.config.rollout_steps), desc="🔄 Collecting Rollout", leave=False)
        
#         for step in pbar:
#             # Select action
#             veh_idx, node_idx, logprob_veh, logprob_node, last_hh = self.select_action(state, last_hh)
            
#             # --- Logic Valid Mask & Fallback ---
#             valid_mask = self.env.get_valid_customer_mask(veh_idx)
#             invalid_nodes = (valid_mask.gather(1, node_idx.unsqueeze(1)) == 0).squeeze(1)
#             if invalid_nodes.any():
#                 node_idx = torch.where(invalid_nodes, torch.zeros_like(node_idx), node_idx)
            
#             static, dyn_truck, dyn_drone, _, _ = state
#             with torch.no_grad(): value = self.critic(static, dyn_truck, dyn_drone)
            
#             next_state, reward, done, _ = self.env.step(veh_idx, node_idx)
            
#             # Store in buffer
#             self.buffer.states.append(state)
#             self.buffer.actions_veh.append(veh_idx)
#             self.buffer.actions_node.append(node_idx)
#             self.buffer.logprobs_veh.append(logprob_veh)
#             self.buffer.logprobs_node.append(logprob_node)
#             self.buffer.rewards.append(reward.cpu().numpy())
#             self.buffer.dones.append(done.cpu().numpy())
#             self.buffer.values.append(value.cpu().numpy())
            
#             # Stats
#             step_reward = reward.mean().item()
#             episode_reward += step_reward
#             episode_length += 1
#             self.total_steps += 1
            
#             pbar.set_postfix({
#                 'reward': f"{step_reward:.2f}", 
#                 'ep_len': f"{episode_length}"
#             })
            
#             if done.any():
#                 self.episode_rewards.append(episode_reward)
#                 self.episode_lengths.append(episode_length)
#                 state = self.env.reset()
#                 last_hh = None
#                 episode_reward = 0
#                 episode_length = 0
#             else:
#                 state = next_state
            
#             if len(self.buffer) >= self.config.rollout_steps:
#                 break
        
#         pbar.close()

#     def update_policy(self):
#         """Update policy với Progress Bar hiển thị Loss"""
#         returns, advantages = self.compute_returns_and_advantages(
#             self.buffer.rewards, self.buffer.values, self.buffer.dones
#         )
        
#         advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
#         returns = torch.FloatTensor(returns).to(self.device)
#         advantages = torch.FloatTensor(advantages).to(self.device)

        
#         dashboard_area = Output()
        
#         pbar = tqdm(range(self.config.update_epochs), desc="🧠 Updating Policy", leave=False)
        
#         display(dashboard_area)
        
#         total_loss_log = 0
        
#         for epoch_i in pbar:
#             epoch_actor_loss = 0
#             epoch_critic_loss = 0
#             epoch_backprop_loss = 0
#             epoch_approx_kl = 0     
#             epoch_explained_var = 0 
            
#             # for i in range(0, len(self.buffer), self.update_batch_size):
#                 # state = self.buffer.states[i]
#                 # old_logprob_veh = self.buffer.logprobs_veh[i]
#                 # old_logprob_node = self.buffer.logprobs_node[i]
#                 # action_veh = self.buffer.actions_veh[i]
#                 # action_node = self.buffer.actions_node[i]
                
#                 # static, dyn_truck, dyn_drone, mask_cust, mask_veh = state
                
#                 # if mask_cust.sum(dim=1).eq(0).any():
#                 #     mask_cust = mask_cust.clone()
#                 #     mask_cust[mask_cust.sum(dim=1) == 0, 0] = 1

#             indices = np.random.permutation(buffer_size)
        
#             # ✅ Loop qua mini-batches (4 batches × 4 epochs = 16 updates)
#             for start in range(0, len(self.buffer), self.update_batch_size):
#                 end = min(start + minibatch_size, buffer_size)
#                 mb_indices = indices[start:end]
                
#                 # Lấy mini-batch data
#                 mb_states = [self.buffer.states[i] for i in mb_indices]
#                 mb_actions_veh = torch.cat([self.buffer.actions_veh[i] for i in mb_indices])
#                 mb_actions_node = torch.cat([self.buffer.actions_node[i] for i in mb_indices])
#                 mb_old_logprob_veh = torch.cat([self.buffer.logprobs_veh[i] for i in mb_indices])
#                 mb_old_logprob_node = torch.cat([self.buffer.logprobs_node[i] for i in mb_indices])
#                 mb_returns = returns[mb_indices]
#                 mb_advantages = advantages[mb_indices]
                
#                 # ✅ Forward pass cho TOÀN BỘ mini-batch cùng lúc
#                 # Concatenate states từ các steps
#                 static_list = [s[0] for s in mb_states]
#                 dyn_truck_list = [s[1] for s in mb_states]
#                 dyn_drone_list = [s[2] for s in mb_states]
#                 mask_cust_list = [s[3] for s in mb_states]
#                 mask_veh_list = [s[4] for s in mb_states]
                
#                 static = torch.cat(static_list, dim=0)
#                 dyn_truck = torch.cat(dyn_truck_list, dim=0)
#                 dyn_drone = torch.cat(dyn_drone_list, dim=0)
#                 mask_cust = torch.cat(mask_cust_list, dim=0)
#                 mask_veh = torch.cat(mask_veh_list, dim=0)
            
                
#                 veh_probs, node_probs, _ = self.actor(
#                     static, dyn_truck, dyn_drone, None, None, mask_cust, mask_veh
#                 )

#                 # [FIX] RE-APPLY CONDITIONAL MASKING
#                 # Tái hiện logic chặn Drone vào Truck-Only để tính Loss đúng
#                 # refined_mask = self.env.get_valid_customer_mask(action_veh)
                
#                 # node_probs = node_probs_raw.clone()
#                 # node_probs[refined_mask == 0] = 0.0
                
#                 # prob_sum = node_probs.sum(dim=1, keepdim=True)
#                 # zero_sum = (prob_sum.squeeze(1) == 0)
#                 # if zero_sum.any():
#                 #     node_probs[zero_sum, 0] = 1.0
#                 #     prob_sum[zero_sum] = 1.0
                
#                 # node_probs = node_probs / prob_sum
#                 # =================================================================
#                 # 🖥️ LIVE DASHBOARD (Cập nhật liên tục 10 bước/lần)
#                 # =================================================================
#                 if i % 20 == 0: # Chỉ cập nhật mỗi 10 bước để đỡ giật màn hình
#                     from IPython.display import clear_output
#                     with dashboard_area:
#                         clear_output(wait=True) # Xóa màn hình cũ
                        
#                         # Lấy mẫu đầu tiên để hiển thị
#                         sample_idx = random.randint(0, self.config.batch_size)
                        
#                         # 1. Xử lý dữ liệu để in
#                         v_probs = veh_probs[sample_idx].detach().cpu().numpy()
#                         n_probs = node_probs[sample_idx].detach().cpu().numpy()
                        
#                         # Tìm Top 10 Node có xác suất cao nhất (để in cho gọn)
#                         top_node_indices = np.argsort(n_probs)[:][::-1]
                        
#                         # 2. In giao diện Dashboard
#                         print(f"╔══════════════════════════════════════════════════════════════╗")
#                         print(f"║  🚀 TRAINING MONITORING | Epoch {epoch_i+1} | Batch {i} | Vehicle Nums: {len(v_probs)} | Current position: {sample_idx}  ║")
#                         print(f"╠══════════════════════════════════════════════════════════════╣")
#                         # Chèn vào trong đoạn debug dashboard
#                         print(f"DEBUG: Num Trucks = {len(dyn_truck[sample_idx][0])}, Num Drones = {len(dyn_drone[sample_idx][0])}")
#                         print(f"╠══════════════════════════════════════════════════════════════╣")
                        
#                         # In xác suất Xe
#                         print(f"║  🚛 VEHICLE PROBABILITIES:                                   ")
#                         print(f"║     Trucks: {v_probs[:len(dyn_truck[sample_idx][0])].round(4)}")
#                         print(f"║     Drones: {v_probs[len(dyn_truck[sample_idx][0]):].round(4)}")
#                         print(f"╠══════════════════════════════════════════════════════════════╣")
                

                
#                 if torch.isnan(node_probs).any() or (node_probs.sum(dim=1) == 0).any():
#                     fallback = torch.zeros_like(node_probs); fallback[:, 0] = 1.0
#                     inv = torch.isnan(node_probs).any(dim=1) | (node_probs.sum(dim=1) == 0)
#                     node_probs[inv] = fallback[inv]

#                 # Tính toán Loss 
#                 new_logprob_veh = torch.log(veh_probs.gather(1, action_veh.unsqueeze(1)) + 1e-10).squeeze(1)
#                 new_logprob_node = torch.log(node_probs.gather(1, action_node.unsqueeze(1)) + 1e-10).squeeze(1)


#                 log_ratio = (new_logprob_veh + new_logprob_node) - (old_logprob_veh + old_logprob_node)
#                 ratio = torch.exp(log_ratio)
#                 # ratio = torch.exp(new_logprob_veh - old_logprob_veh) * torch.exp(new_logprob_node - old_logprob_node)
#                 adv = advantages[i].expand_as(ratio)
#                 surr1 = ratio * adv
#                 surr2 = torch.clamp(ratio, 1 - self.config.clip_epsilon, 1 + self.config.clip_epsilon) * adv
#                 actor_loss = -torch.min(surr1, surr2).mean()
                
#                 entropy = -((veh_probs * torch.log(veh_probs + 1e-10)).sum(1) + (node_probs * torch.log(node_probs + 1e-10)).sum(1)).mean()
                
#                 value_pred = self.critic(static, dyn_truck, dyn_drone)
#                 critic_loss = nn.MSELoss()(value_pred, returns[i].expand_as(value_pred))
                
#                 loss = actor_loss - self.config.entropy_coef * entropy + self.config.value_loss_coef * critic_loss

#                 # --- [NEW] CALCULATE METRICS ---
#                 with torch.no_grad():
#                     # 1. Approx KL Divergence (http://joschu.net/blog/kl-approx.html)
#                     # Công thức chuẩn hơn: (ratio - 1) - log(ratio)
#                     approx_kl = ((ratio - 1) - log_ratio).mean()
                    
#                     # 2. Explained Variance
#                     # 1 - Var(y_true - y_pred) / Var(y_true)
#                     y_true = returns[i]
#                     y_pred = value_pred
#                     var_y = torch.var(y_true)
#                     if var_y == 0:
#                         explained_var = torch.tensor(0.0)
#                     else:
#                         explained_var = 1 - torch.var(y_true - y_pred) / var_y
                
#                 self.actor_optimizer.zero_grad()
#                 self.critic_optimizer.zero_grad()
#                 loss.backward()
#                 nn.utils.clip_grad_norm_(self.actor.parameters(), self.config.max_grad_norm)
#                 nn.utils.clip_grad_norm_(self.critic.parameters(), self.config.max_grad_norm)
#                 self.actor_optimizer.step()
#                 self.critic_optimizer.step()
                
#                 epoch_actor_loss += actor_loss.item()
#                 epoch_critic_loss += critic_loss.item()
#                 epoch_backprop_loss += loss.item()
#                 epoch_approx_kl += approx_kl.item()
#                 epoch_explained_var += explained_var.item()
            
#             avg_act_loss = epoch_actor_loss / len(self.buffer)
#             avg_cri_loss = epoch_critic_loss / len(self.buffer)
#             avg_backprop_loss = epoch_backprop_loss / len(self.buffer)
#             avg_kl = epoch_approx_kl / len(self.buffer)
#             avg_ev = epoch_explained_var / len(self.buffer)
#             # pbar.set_postfix({'ActLoss': f"{avg_act_loss:.3f}", 'CriLoss': f"{avg_cri_loss:.3f}", 'BackPropLoss': f"{float(avg_backprop_loss):.3f}"})
#             pbar.set_postfix({
#                 'Act': f"{avg_act_loss:.2f}", 
#                 'Cri': f"{avg_cri_loss:.1f}", 
#                 'BackPropLoss': f"{float(avg_backprop_loss):.3f}",
#                 'KL': f"{avg_kl:.4f}",
#                 'EV': f"{avg_ev:.2f}"
#             })
#             total_loss_log += (avg_act_loss + avg_cri_loss)
#             self.writer.add_scalar('Loss/Actor', avg_act_loss, self.num_updates)
#             self.writer.add_scalar('Loss/Critic', avg_cri_loss, self.num_updates)
#             self.writer.add_scalar('Loss/Approx_KL', avg_kl, self.num_updates)
#             self.writer.add_scalar('Loss/Explained_Var', avg_ev, self.num_updates)

            
#         self.num_updates += 1
#         self.writer.add_scalar('Loss/Total', total_loss_log / self.config.update_epochs, self.num_updates)

    



#     def train(self):
#         print(f"🚀 Starting PPO Training on {self.device}")
#         start_time = time.time()
#         for epoch in range(self.config.num_epochs):
#             epoch_start = time.time()
#             self.collect_rollout()
#             self.update_policy()
#             self.actor_scheduler.step()
#             self.critic_scheduler.step()
            
#             if (epoch + 1) % self.config.log_interval == 0:
#                 avg_r = np.mean(self.episode_rewards) if self.episode_rewards else 0
#                 print(f"Epoch {epoch+1}: Avg Reward {avg_r:.4f} | Time: {time.time()-epoch_start:.2f}s")
#                 self.writer.add_scalar('Reward/Average', avg_r, epoch)
            
#             if (epoch + 1) % self.config.save_interval == 0:
#                 self.save_checkpoint(epoch + 1)

#     def evaluate(self, num_episodes=10):
#         """Evaluate current policy"""
#         self.actor.eval()
#         self.critic.eval()
        
#         eval_rewards = []
        
#         with torch.no_grad():
#             for _ in range(num_episodes):
#                 state = self.env.reset()
#                 last_hh = None
#                 episode_reward = 0
#                 done_flag = False
                
#                 for _ in range(self.config.max_steps):
#                     if done_flag:
#                         break
                    
#                     # Select action (deterministic)
#                     veh_idx, node_idx, _, _, last_hh = self.select_action(
#                         state, last_hh, deterministic=True
#                     )
                    
#                     # Validate action
#                     static, _, _, _, _ = state
#                     valid_mask = self.env.get_valid_customer_mask(veh_idx)
#                     invalid = (valid_mask.gather(1, node_idx.unsqueeze(1)) == 0).squeeze(1)
#                     if invalid.any():
#                         node_idx = torch.where(invalid, torch.zeros_like(node_idx), node_idx)
                    
#                     # Step
#                     state, reward, done, _ = self.env.step(veh_idx, node_idx)
#                     episode_reward += reward.mean().item()
                    
#                     if done.any():
#                         done_flag = True
                
#                 eval_rewards.append(episode_reward)
        
#         self.actor.train()
#         self.critic.train()
        
#         return np.mean(eval_rewards)
    
#     def save_checkpoint(self, epoch, is_best=False):
#         """Save model checkpoint"""
#         # Create a serializable config dict
#         config_dict = {
#             'batch_size': self.config.batch_size,
#             'max_steps': self.config.max_steps,
#             'hidden_size': self.config.hidden_size,
#             'num_layers': self.config.num_layers,
#             'dropout': self.config.dropout,
#             'lr_actor': self.config.lr_actor,
#             'lr_critic': self.config.lr_critic,
#             'gamma': self.config.gamma,
#             'gae_lambda': self.config.gae_lambda,
#             'clip_epsilon': self.config.clip_epsilon,
#             'entropy_coef': self.config.entropy_coef,
#             'value_loss_coef': self.config.value_loss_coef,
#             'max_grad_norm': self.config.max_grad_norm,
#             'drone_speed': self.sys_config.drone_speed,
#             'drone_max_energy': self.sys_config.drone_max_energy,
#             't_takeoff': self.sys_config.t_takeoff,
#             't_landing': self.sys_config.t_landing,
#             'drone_params': self.sys_config.drone_params,
#         }
        
#         checkpoint = {
#             'epoch': epoch,
#             'actor_state_dict': self.actor.state_dict(),
#             'critic_state_dict': self.critic.state_dict(),
#             'actor_optimizer': self.actor_optimizer.state_dict(),
#             'critic_optimizer': self.critic_optimizer.state_dict(),
#             'best_reward': self.best_reward,
#             'config': config_dict  # Save as dict instead of object
#         }
        
#         if is_best:
#             path = os.path.join(self.config.checkpoint_dir, 'best_model.pth')
#         else:
#             path = os.path.join(self.config.checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
        
#         torch.save(checkpoint, path)
#         print(f"💾 Checkpoint saved: {path}")
    
#     def load_checkpoint(self, path):
#         """Load model checkpoint"""
#         checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        
#         self.actor.load_state_dict(checkpoint['actor_state_dict'])
#         self.critic.load_state_dict(checkpoint['critic_state_dict'])
#         self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
#         self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer'])
#         self.best_reward = checkpoint['best_reward']
        
#         print(f"✅ Checkpoint loaded from {path}")
#         return checkpoint['epoch']


# def main():
#     sys_config = SystemConfig('Truck_config.json', 'drone_linear_config.json', drone_type="1")
    
#     ppo_config = PPOConfig()
    
#     trainer = PPOTrainer(ppo_config, sys_config)
    
#     trainer.train()

# if __name__ == "__main__":
#     main()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from collections import deque
import time
import os
import json
from datetime import datetime
from tqdm import tqdm 

from ipywidgets import Output
from IPython.display import display, clear_output
import random

# from config import SystemConfig
# from environment import MOPVRPEnvironment
# from model import MOPVRP_Actor
# from dataloader import get_rl_dataloader
# from visualizer import visualize_mopvrp

class RolloutBuffer:
    """Per-step experience storage for PPO rollouts.

    Every field named in ``_FIELDS`` is a plain list that receives one entry per
    environment step. ``clear()`` rebinds each field to a new empty list, and
    ``len(buffer)`` is the number of stored rewards.
    """

    # Names of the per-step lists managed by this buffer.
    _FIELDS = (
        "states",
        "actions_veh",
        "actions_node",
        "logprobs_veh",
        "logprobs_node",
        "rewards",
        "dones",
        "values",
        "masks_cust",
        "masks_veh",
    )

    def __init__(self):
        # A new buffer starts in exactly the state clear() produces.
        self.clear()

    def clear(self):
        """Drop all stored experience by giving every field a fresh empty list."""
        for field in self._FIELDS:
            setattr(self, field, [])

    def __len__(self):
        return len(self.rewards)

class Critic(nn.Module):
    """PPO value network: maps a (static, truck, drone) observation to one scalar per batch row.

    Each input is embedded by a kernel-size-1 ``Conv1d`` followed by ReLU and
    averaged over its last axis. The three pooled vectors are concatenated and
    passed through ``fc1 -> ReLU -> dropout -> fc2 -> ReLU -> dropout -> fc3``.
    The output layer has no activation, so values are unbounded.

    Args:
        static_size: channel count of ``static``.
        dynamic_size_truck: channel count of ``dynamic_trucks``.
        dynamic_size_drone: channel count of ``dynamic_drones``.
        hidden_size: width of the embeddings and of ``fc1``; ``fc2`` uses ``hidden_size // 2``.
        dropout: dropout probability applied after ``fc1`` and ``fc2``. Defaults to
            0.1, the value that used to be hard-coded, so existing callers are unaffected.
    """

    def __init__(self, static_size, dynamic_size_truck, dynamic_size_drone, hidden_size, dropout=0.1):
        super(Critic, self).__init__()
        self.static_conv = nn.Conv1d(static_size, hidden_size, kernel_size=1)
        self.truck_conv = nn.Conv1d(dynamic_size_truck, hidden_size, kernel_size=1)
        self.drone_conv = nn.Conv1d(dynamic_size_drone, hidden_size, kernel_size=1)
        self.fc1 = nn.Linear(hidden_size * 3, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size // 2)
        self.fc3 = nn.Linear(hidden_size // 2, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, static, dynamic_trucks, dynamic_drones):
        """Return the state value with shape ``(batch,)``.

        Inputs must be 3-D ``(batch, channels, length)`` tensors: ``Conv1d`` embeds
        the channel axis and ``mean(2)`` pools over ``length``. The three lengths
        may differ from each other.
        """
        # Non-linear per-position embeddings.
        static_embed = self.relu(self.static_conv(static))
        truck_embed = self.relu(self.truck_conv(dynamic_trucks))
        drone_embed = self.relu(self.drone_conv(dynamic_drones))

        # Global average pooling over the last axis of each embedding.
        combined = torch.cat(
            [static_embed.mean(2), truck_embed.mean(2), drone_embed.mean(2)], dim=1
        )

        x = self.dropout(self.relu(self.fc1(combined)))
        x = self.dropout(self.relu(self.fc2(x)))

        # No output activation. squeeze(1) rather than squeeze() keeps the batch
        # axis when batch == 1.
        return self.fc3(x).squeeze(1)

class PPOConfig:
    """Training-only hyperparameters for PPO.

    Environment/system settings are not stored here; ``PPOTrainer`` receives them
    separately as ``sys_config``. Every field is a plain instance attribute with a
    fixed default, so a run can override individual values after construction.
    """

    def __init__(self):
        # Data / batching
        self.batch_size = 256            # passed to get_rl_dataloader; also the minibatch slice length in update_policy
        self.minibatch_size = 256        # stride of the minibatch loop in update_policy
        self.update_batch_size = 64      # NOTE(review): not read by the trainer code visible here; confirm it is still used
        self.max_steps = 1000

        # Network architecture
        self.hidden_size = 128
        self.num_layers = 1
        self.dropout = 0.2               # forwarded to the actor

        # Optimisers
        self.lr_actor = 1e-5
        self.lr_critic = 5e-5
        self.max_grad_norm = 0.5

        # PPO objective / GAE
        self.gamma = 0.9                 # discount used in the GAE recursion
        self.gae_lambda = 0.95           # GAE lambda
        self.clip_epsilon = 0.2
        self.entropy_coef = 0.09
        self.value_loss_coef = 0.5
        self.target_kl = 0.03            # target KL (optional)

        # Loop lengths
        self.num_epochs = 50
        self.update_epochs = 2           # passes over each collected rollout in update_policy
        self.rollout_steps = 2048        # environment steps collected per rollout

        # Logging / checkpoint cadence
        self.log_interval = 1
        self.save_interval = 1
        self.eval_interval = 1

        # Runtime placement and output locations (relative to the working directory)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.checkpoint_dir = "checkpoints"
        self.log_dir = "logs"

class PPOTrainer:
    def __init__(self, ppo_config, sys_config):
        """Build networks, optimisers, data pipeline, environment and logging for PPO.

        Args:
            ppo_config: ``PPOConfig`` carrying hyperparameters, device and output dirs.
            sys_config: system configuration handed to ``MOPVRPEnvironment``.

        Side effects: creates ``checkpoint_dir`` and ``log_dir`` if missing and opens
        a TensorBoard ``SummaryWriter`` under ``<log_dir>/ppo_<timestamp>``.
        """
        self.config = ppo_config
        self.sys_config = sys_config
        self.device = ppo_config.device

        for output_dir in (ppo_config.checkpoint_dir, ppo_config.log_dir):
            os.makedirs(output_dir, exist_ok=True)

        # Creates self.actor / self.critic; must run before the optimisers are built.
        self._init_networks()

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=ppo_config.lr_actor)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=ppo_config.lr_critic)

        def _step_decay(optimizer):
            # Multiply the learning rate by 0.95 every 200 scheduler.step() calls.
            return optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.95)

        self.actor_scheduler = _step_decay(self.actor_optimizer)
        self.critic_scheduler = _step_decay(self.critic_optimizer)

        self.dataloader = get_rl_dataloader(batch_size=ppo_config.batch_size, device=self.device)
        self.env = MOPVRPEnvironment(self.sys_config, self.dataloader, device=self.device)

        self.buffer = RolloutBuffer()
        run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.writer = SummaryWriter(f"{ppo_config.log_dir}/ppo_{run_stamp}")

        # Rolling statistics over the most recent 1000 finished episodes.
        self.episode_rewards = deque(maxlen=1000)
        self.episode_lengths = deque(maxlen=1000)
        self.best_reward = float('-inf')
        self.total_steps = 0
        self.num_updates = 0
        
    # def _init_networks(self):
    #     self.actor = MOPVRP_Actor(4, 2, 4, self.config.hidden_size, self.config.num_layers, self.config.dropout).to(self.device)
    #     self.critic = Critic(4, 2, 4, self.config.hidden_size).to(self.device)
    #     self._init_weights(self.actor)
    #     self._init_weights(self.critic)

    def _init_networks(self):
        """Create the actor and critic on ``self.device`` and re-initialise the critic.

        Both networks are built with static_size=4, dynamic_size_truck=2,
        dynamic_size_drone=4 and ``config.hidden_size``; the actor additionally gets
        ``config.dropout``. Only the critic goes through ``_init_weights``
        (Xavier-uniform weights, zero biases). The actor keeps whatever
        initialisation ``MOPVRP_Actor`` applies itself. Prints a short status report.
        """
        print(f"🛠️ Initializing Networks on {self.device}...")

        shared_dims = {
            "static_size": 4,
            "dynamic_size_truck": 2,
            "dynamic_size_drone": 4,
            "hidden_size": self.config.hidden_size,
        }

        # Order matters for reproducibility: actor, then critic, then the critic
        # re-initialisation consume the global RNG in this sequence.
        self.actor = MOPVRP_Actor(**shared_dims, dropout=self.config.dropout).to(self.device)
        self.critic = Critic(**shared_dims).to(self.device)
        self._init_weights(self.critic)

        # A high-variance actor init (self.actor._init_weights_high_variance()) was
        # tried to break symmetry between vehicles/nodes; it is intentionally disabled.

        print("✅ Networks initialized successfully.")
        print("   - Actor: MOPVRP_HierarchicalActor (Pairwise Attention)")
        print("   - Critic: Standard Critic (Global Pooling)")
        
    def _init_weights(self, module):
        for m in module.modules():
            if isinstance(m, (nn.Linear, nn.Conv1d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)

    
    def select_action(self, state, last_hh=None, deterministic=False):
        """
        Select action using current policy
        Returns: vehicle_idx, node_idx, logprob_veh, logprob_node, last_hh
        """
        static, dyn_truck, dyn_drone, mask_cust, mask_veh = state
        
        if mask_cust.sum(dim=1).eq(0).any():
            # Clone để không ảnh hưởng đến state gốc
            mask_cust = mask_cust.clone()
            # Tìm các dòng có tổng = 0
            zero_mask_indices = mask_cust.sum(dim=1) == 0
            # Mở Node 0 (Depot) cho các dòng đó
            mask_cust[zero_mask_indices, 0] = 1

        with torch.no_grad():
            # Get probabilities from actor
            veh_probs, node_probs, internal_veh_idx, last_hh = self.actor(
                static, dyn_truck, dyn_drone,
                decoder_input=None,
                last_hh=last_hh,
                mask_customers=mask_cust,
                mask_vehicles=mask_veh
            )
        
        if torch.isnan(node_probs).any() or (node_probs.sum(dim=1) == 0).any():
            # Tạo một phân phối mặc định: 100% về Depot (Node 0)
            fallback_probs = torch.zeros_like(node_probs)
            fallback_probs[:, 0] = 1.0
            
            # Tìm các dòng bị lỗi (NaN hoặc Sum=0)
            invalid_rows = torch.isnan(node_probs).any(dim=1) | (node_probs.sum(dim=1) == 0)
            
            # Gán đè phân phối mặc định vào các dòng lỗi
            node_probs[invalid_rows] = fallback_probs[invalid_rows]

        # Tương tự cho Vehicle Probs (Phòng hờ)
        if torch.isnan(veh_probs).any() or (veh_probs.sum(dim=1) == 0).any():
            fallback_veh = torch.zeros_like(veh_probs)
            fallback_veh[:, 0] = 1.0 # Chọn xe đầu tiên
            invalid_rows_veh = torch.isnan(veh_probs).any(dim=1) | (veh_probs.sum(dim=1) == 0)
            veh_probs[invalid_rows_veh] = fallback_veh[invalid_rows_veh]

        veh_idx = internal_veh_idx
        if deterministic:
            # Greedy selection
            # veh_idx = torch.argmax(veh_probs, dim=1)
            
            node_idx = torch.argmax(node_probs, dim=1)
        else:
            # Stochastic sampling
            veh_dist = torch.distributions.Categorical(veh_probs)
            node_dist = torch.distributions.Categorical(node_probs)
            
            # veh_idx = veh_dist.sample()
            node_idx = node_dist.sample()
        
        # Calculate log probabilities
        logprob_veh = torch.log(veh_probs.gather(1, veh_idx.unsqueeze(1)) + 1e-10).squeeze(1)
        logprob_node = torch.log(node_probs.gather(1, node_idx.unsqueeze(1)) + 1e-10).squeeze(1)
        
        return veh_idx, node_idx, logprob_veh, logprob_node, last_hh

    def compute_returns_and_advantages(self, rewards, values, dones):
        returns, advantages, gae = [], [], 0
        rewards, values, dones = np.array(rewards), np.array(values), np.array(dones)
        for t in reversed(range(len(rewards))):
            next_value = 0 if t == len(rewards) - 1 else values[t + 1]
            delta = rewards[t] + self.config.gamma * next_value * (1 - dones[t]) - values[t]
            gae = delta + self.config.gamma * self.config.gae_lambda * (1 - dones[t]) * gae
            advantages.insert(0, gae)
            returns.insert(0, gae + values[t])
        return np.array(returns), np.array(advantages)
        
    # def _render_dashboard(self, output_widget, step, state, action_veh, action_node, reward, done):
    #     """
    #     Dashboard dạng bảng, căn lề cố định, dễ nhìn hơn.
    #     """
    #     # --- Config & Helpers ---
    #     MAX_LOAD = getattr(self.sys_config, 'truck_capacity', 1.0)
    #     MAX_BATTERY = getattr(self.sys_config, 'drone_max_energy', 1.0)
        
    #     def get_bar(current, max_val, width=10, type="load"):
    #         if max_val <= 0: return "ERR"
    #         ratio = max(0.0, min(1.0, current / max_val))
    #         filled = int(ratio * width)
    #         empty = width - filled
            
    #         # Ký tự block mượt hơn
    #         fill_char = "█" 
    #         empty_char = "░"
            
    #         return f"{fill_char * filled}{empty_char * empty}"

    #     # --- Data Extraction ---
    #     batch_avg_reward = reward.mean().item()
    #     batch_dones = done.sum().item()
        
    #     # Data cho Env #0
    #     env_idx = 0
    #     static, dyn_truck, dyn_drone, mask_cust, mask_veh = state
    #     cur_r = reward[env_idx].item()
        
    #     # Progress
    #     total_nodes = mask_cust.shape[1] - 1
    #     visited = total_nodes - mask_cust[env_idx].sum().item()
    #     prog_pct = visited / total_nodes if total_nodes > 0 else 0
        
    #     # --- RENDER ---
    #     with output_widget:
    #         clear_output(wait=True)
            
    #         # 1. HEADER (Thông tin chung)
    #         print(f"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓")
    #         print(f"┃ 🕹️  STEP {step:04d}/{self.config.rollout_steps} ┃ 🏆 AVG RWD: {batch_avg_reward:>7.3f} ┃ ✅ DONES: {batch_dones:>3d}/256 ┃")
    #         print(f"┣━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┫")
            
    #         # 2. ENV #0 STATUS (Tiến độ cụ thể)
    #         prog_bar = get_bar(visited, total_nodes, width=30)
    #         print(f"┃ 🎯 Env #0 Progress: {prog_bar} {prog_pct*100:>5.1f}% ┃")
    #         print(f"┃ ⚡ Last Action: Vehicle {action_veh[env_idx]:<2} ➝ Node {action_node[env_idx]:<3} (R: {cur_r:.2f})    ┃")
    #         print(f"┠──────────────────────────────────────────────────────────────┨")
            
    #         # 3. TRUCK TABLE
    #         # Header căn lề: ID (3 ký tự), Loc (4), Load Val (6), Bar (10)
    #         print(f"┃ 🚛 TRUCKS   │ Loc  │ Load (kg) │ Capacity Bar             ┃")
    #         num_trucks = dyn_truck.shape[2]
    #         for i in range(num_trucks):
    #             loc = int(dyn_truck[env_idx, 0, i].item())
    #             load = dyn_truck[env_idx, 1, i].item()
    #             bar = get_bar(load, MAX_LOAD, width=12)
                
    #             # {:>3} nghĩa là căn phải, rộng 3 ký tự
    #             print(f"┃    #{i:<2}      │ {loc:>4d} │ {load:>9.1f} │ [{bar}] ┃")

    #         # 4. DRONE TABLE (Nếu có)
    #         num_drones = dyn_drone.shape[2]
    #         if num_drones > 0:
    #             print(f"┠──────────────────────────────────────────────────────────────┨")
    #             print(f"┃ 🚁 DRONES   │ Loc  │ Battery   │ Energy Level             ┃")
    #             for i in range(num_drones):
    #                 loc = int(dyn_drone[env_idx, 0, i].item())
    #                 bat = dyn_drone[env_idx, 2, i].item()
    #                 bar = get_bar(bat, MAX_BATTERY, width=12)
                    
    #                 # Cảnh báo pin yếu
    #                 icon = "⚠️" if (bat) < 0.2 else " "
                    
    #                 print(f"┃    #{i:<2} {icon}    │ {loc:>4d} │ {bat:>9.1f} │ [{bar}] ┃")
            
    #         print(f"┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛")

    # def _render_dashboard(self, output_widget, step, state, action_veh, action_node, reward, done):
    #     """
    #     Dashboard nâng cao: Hiển thị trạng thái Trucks, Drones và Mask của các Node.
    #     """
    #     MAX_LOAD = getattr(self.sys_config, 'truck_capacity', 1.0)
    #     MAX_BATTERY = getattr(self.sys_config, 'drone_max_energy', 1.0)
        
    #     def get_bar(current, max_val, width=10):
    #         if max_val <= 0: return "░" * width
    #         ratio = max(0.0, min(1.0, current / max_val))
    #         filled = int(ratio * width)
    #         return f"{'█' * filled}{'░' * (width - filled)}"

    #     # Extraction
    #     env_idx = 0
    #     static, dyn_truck, dyn_drone, mask_cust, mask_veh = state
    #     batch_avg_reward = reward.mean().item()
        
    #     # --- RENDER ---
    #     with output_widget:
    #         clear_output(wait=True)
    #         print(f"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓")
    #         print(f"┃ 🕹️  STEP {step:04d}/{self.config.rollout_steps} ┃ 🏆 AVG RWD: {batch_avg_reward:>7.3f} ┃")
    #         print(f"┣━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┫")
            
    #         # --- PHẦN 1: CUSTOMER MASK STATUS ---
    #         # Hiển thị 20 node đầu tiên để quan sát mask
    #         m_cust = mask_cust[env_idx].cpu().numpy()
    #         total_nodes = len(m_cust)
    #         visited = total_nodes - m_cust.sum()
            
    #         print(f"┃ 🎯 NODES VISITED: {int(visited)}/{total_nodes}                                 ┃")
    #         mask_str = ""
    #         # Hiển thị biểu tượng: ✅ (Available), ⬛ (Masked/Visited)
    #         for i in range(min(total_nodes, 30)): # Xem 30 node đầu
    #             mask_str += "✅" if m_cust[i] == 1 else "⬛"
    #         print(f"┃ Mask (Top 30): {mask_str} ┃")
    #         print(f"┃   (✅: Available to visit | ⬛: Masked/Already Visited)   ┃")
    #         print(f"┠──────────────────────────────────────────────────────────────┨")
            
    #         # --- PHẦN 2: TRUCKS & DRONES ---
    #         print(f"┃ 🚛 TRUCKS Status:                                           ┃")
    #         for i in range(dyn_truck.shape[2]):
    #             loc = int(dyn_truck[env_idx, 0, i].item())
    #             load = dyn_truck[env_idx, 1, i].item()
    #             print(f"┃  #{i:<2} @Node {loc:>3} | Load: {load:>7.1f} [{get_bar(load, MAX_LOAD, 10)}]  ┃")

    #         if dyn_drone.shape[2] > 0:
    #             print(f"┠──────────────────────────────────────────────────────────────┨")
    #             print(f"┃ 🚁 DRONES Status:                                           ┃")
    #             for i in range(dyn_drone.shape[2]):
    #                 loc = int(dyn_drone[env_idx, 0, i].item())
    #                 bat = dyn_drone[env_idx, 2, i].item()
    #                 print(f"┃  #{i:<2} @Node {loc:>3} | Batt: {bat:>7.1f} [{get_bar(bat, MAX_BATTERY, 10)}]  ┃")
            
    #         print(f"┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛")


    def _render_dashboard(self, output_widget, step, state, action_veh, action_node, reward, done, live_mask):
        """Draw a text panel for batch row 0 into ``output_widget``.

        The panel contains three parts:
        - The selected vehicle. Indices ``>= env.num_trucks`` are shown as drones.
        - A strip over the first 30 nodes: ⬛ already served (global customer mask
          is 0), ✅ reachable by this vehicle (``live_mask`` is 1), ❌ open but
          blocked for this vehicle.
        - For drones only, channel 2 of ``dyn_drone``, labelled as battery.

        ``action_node``, ``reward`` and ``done`` are accepted for call compatibility
        but are not used.
        """
        row = 0
        _, _, dyn_drone, mask_cust, _ = state

        cur_veh_idx = action_veh[row].item()
        num_trucks = self.env.num_trucks
        is_drone = cur_veh_idx >= num_trucks
        drone_id = (cur_veh_idx - num_trucks) if is_drone else None

        with output_widget:
            clear_output(wait=True)
            print(f"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓")
            print(f"┃ 🕹️  STEP {step:04d} ┃ 🚛 VEHICLE SELECTED: {cur_veh_idx:<2} ({'DRONE' if is_drone else 'TRUCK'}) ┃")
            print(f"┣━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┫")

            m_global = mask_cust[row].cpu().numpy()  # 0 => node already served
            m_live = live_mask[row].cpu().numpy()    # 1 => reachable by the selected vehicle

            print(f"┃ 📍 LIVE MASK FOR SELECTED VEHICLE (Top 30 Nodes):           ┃")
            view_str = "".join(
                "⬛" if m_global[n] == 0 else ("✅" if m_live[n] == 1 else "❌")
                for n in range(min(len(m_global), 30))
            )
            print(f"┃ {view_str} ┃")
            print(f"┃ (✅: OK | ⬛: Visited | ❌: Blocked for this vehicle)     ┃")

            if is_drone:
                bat = dyn_drone[row, 2, drone_id].item()
                print(f"┃ 🔋 DRONE #{drone_id} BATTERY: {bat:.2f}                               ┃")

            print(f"┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛")

    def _render_drone_vision(self, output_widget, step, state, all_drone_masks, max_nodes=100):
        """Draw a per-drone reachability panel for batch row 0 into ``output_widget``.

        For each drone ``i`` one line shows:
        - its status: ``ACTIVE`` when ``mask_veh[0, env.num_trucks + i] == 1``,
          otherwise ``OFFLINE``;
        - channel 2 of ``dyn_drone`` (shown as battery);
        - channel 0 of ``dyn_drone`` (shown as location);
        - a strip over the first ``max_nodes`` nodes: ⬛ already served (global
          customer mask is 0), ✅ reachable by this drone (its mask is 1),
          ❌ still open but not reachable by this drone.

        Args:
            output_widget: ipywidgets ``Output`` used as a context manager.
            step: rollout step shown in the header.
            state: ``(static, dyn_truck, dyn_drone, mask_cust, mask_veh)``.
            all_drone_masks: one customer mask per drone, each indexed ``[row][node]``.
            max_nodes: number of leading nodes drawn per drone. Defaults to 100,
                the previously hard-coded width.
        """
        env_idx = 0
        _, _, dyn_drone, mask_cust, mask_veh = state
        m_global = mask_cust[env_idx].cpu().numpy()
        num_nodes_shown = min(len(m_global), max_nodes)
        first_drone_col = self.env.num_trucks  # drones follow the trucks in mask_veh

        with output_widget:
            clear_output(wait=True)
            print(f"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓")
            print(f"┃ 🕹️  STEP {step:04d} ┃ 🚁 ALL DRONES MONITORING                           ┃")
            print(f"┣━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┫")
            print(f"┃ Chú thích: ✅ Có thể bay | ⬛ Đã xong | ❌ Quá tầm/Hết pin            ┃")
            print(f"┣━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┫")

            for i, drone_mask in enumerate(all_drone_masks):
                bat = dyn_drone[env_idx, 2, i].item()
                loc = int(dyn_drone[env_idx, 0, i].item())
                m_live = drone_mask[env_idx].cpu().numpy()

                status = "ACTIVE" if mask_veh[env_idx, first_drone_col + i] == 1 else "OFFLINE"

                # Strip over the first `num_nodes_shown` nodes.
                view_str = "".join(
                    "⬛" if m_global[n] == 0 else ("✅" if m_live[n] == 1 else "❌")
                    for n in range(num_nodes_shown)
                )

                print(f"┃ 🚁 D#{i:02d} [{status}] Pin:{bat:5.2f} | Loc:{loc:3d} | {view_str} ┃")

            print(f"┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛")

    def collect_rollout(self, scenario=5):
        """Collect one rollout of up to ``config.rollout_steps`` steps into ``self.buffer``.

        Each step does the following:
        - pick an action with the current policy;
        - replace any node the chosen vehicle cannot serve
          (``env.get_valid_customer_mask`` is 0) with node 0;
        - evaluate the critic on the pre-step state;
        - step the environment and append the transition to the buffer.

        A tqdm bar reports progress, and every 2 steps an ipywidgets drone
        dashboard is redrawn.

        When any batch row reports ``done``, the whole batch is reset. The done
        flags stored for that step are then set for *every* row. Otherwise GAE
        would bootstrap unfinished rows from the value of a freshly reset,
        unrelated episode.

        Args:
            scenario: passed to ``env.set_scenario`` before the initial reset.
        """
        self.buffer.clear()
        self.env.set_scenario(scenario)
        state = self.env.reset()
        last_hh = None

        episode_reward = 0
        episode_length = 0

        dashboard_ui = Output()
        display(dashboard_ui)

        pbar = tqdm(range(self.config.rollout_steps), desc="🔄 Collecting Rollout", leave=False)
        try:
            for step in pbar:
                veh_idx, node_idx, logprob_veh, logprob_node, last_hh = self.select_action(state, last_hh)

                # The actor only sees the global customer mask, so the sampled node can
                # be unreachable for the chosen vehicle; such rows are sent to node 0.
                # NOTE(review): logprob_node is not recomputed for these rows, so the
                # stored log-prob belongs to the sampled node, not the executed one.
                valid_mask = self.env.get_valid_customer_mask(veh_idx)
                invalid_nodes = (valid_mask.gather(1, node_idx.unsqueeze(1)) == 0).squeeze(1)
                if invalid_nodes.any():
                    node_idx = torch.where(invalid_nodes, torch.zeros_like(node_idx), node_idx)

                static, dyn_truck, dyn_drone, mask_cust, mask_veh = state
                with torch.no_grad():
                    value = self.critic(static, dyn_truck, dyn_drone)

                next_state, reward, done, _ = self.env.step(veh_idx, node_idx)

                # Per-drone masks only feed the dashboard, so they are computed only
                # when it is redrawn (every 2 steps). Assumes get_valid_customer_mask
                # has no side effects; TODO confirm.
                if step % 2 == 0:
                    first_drone = self.env.num_trucks
                    all_drone_masks = [
                        self.env.get_valid_customer_mask(torch.tensor([d_idx], device=self.device))
                        for d_idx in range(first_drone, first_drone + self.env.num_drones)
                    ]
                    self._render_drone_vision(dashboard_ui, step, state, all_drone_masks)

                episode_over = bool(done.any())
                done_np = done.cpu().numpy()
                if episode_over:
                    # The whole batch is reset below, so cut bootstrapping for every row.
                    done_np = np.ones_like(done_np)

                self.buffer.states.append(state)
                self.buffer.actions_veh.append(veh_idx)
                self.buffer.actions_node.append(node_idx)
                self.buffer.logprobs_veh.append(logprob_veh)
                self.buffer.logprobs_node.append(logprob_node)
                self.buffer.rewards.append(reward.cpu().numpy())
                self.buffer.dones.append(done_np)
                self.buffer.values.append(value.cpu().numpy())
                self.buffer.masks_cust.append(mask_cust.cpu().numpy())
                self.buffer.masks_veh.append(mask_veh.cpu().numpy())

                # Statistics are tracked on the batch-mean reward.
                step_reward = reward.mean().item()
                episode_reward += step_reward
                episode_length += 1
                self.total_steps += 1

                pbar.set_postfix({
                    'reward': f"{step_reward:.2f}",
                    'ep_len': f"{episode_length}"
                })

                if episode_over:
                    self.episode_rewards.append(episode_reward)
                    self.episode_lengths.append(episode_length)
                    state = self.env.reset()
                    last_hh = None
                    episode_reward = 0
                    episode_length = 0
                else:
                    state = next_state
        finally:
            pbar.close()


    def update_policy(self):
        """Update policy với thêm metrics: Explained Variance & Approx KL"""
        
        # --- Import hiển thị ---
        from ipywidgets import Output
        from IPython.display import display, clear_output
        
        # --- 1. Prepare Data ---
        rewards = self.buffer.rewards
        values = self.buffer.values
        dones = self.buffer.dones
        
        returns, advantages = self.compute_returns_and_advantages(
            rewards, values, dones
        )

        # returns = (raw_returns - raw_returns.mean()) / (raw_returns.std() + 1e-8)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        returns = torch.FloatTensor(returns).to(self.device)
        advantages = torch.FloatTensor(advantages).to(self.device)
        
        # --- 2. Flatten & Stack Data ---
        def pad_and_stack(tensor_list, pad_dim, pad_value=0):
            if not tensor_list: return torch.tensor([])
            max_size = max(t.size(pad_dim) for t in tensor_list)
            padded_list = []
            for t in tensor_list:
                if t.size(pad_dim) == max_size:
                    padded_list.append(t)
                else:
                    pad_shape = list(t.shape); pad_shape[pad_dim] = max_size - t.size(pad_dim)
                    padding = torch.full(pad_shape, pad_value, device=t.device, dtype=t.dtype)
                    padded_list.append(torch.cat([t, padding], dim=pad_dim))
            return torch.stack(padded_list)

        states_transposed = list(zip(*self.buffer.states))
        
        # Flatten dữ liệu: [Time, Batch, ...] -> [Total_Samples, ...]
        t_static = pad_and_stack(states_transposed[0], pad_dim=2).flatten(0, 1)
        t_dyn_truck = pad_and_stack(states_transposed[1], pad_dim=2).flatten(0, 1)
        t_dyn_drone = pad_and_stack(states_transposed[2], pad_dim=2).flatten(0, 1)
        t_mask_cust = pad_and_stack(states_transposed[3], pad_dim=1).flatten(0, 1)
        t_mask_veh = pad_and_stack(states_transposed[4], pad_dim=1).flatten(0, 1)
        
        t_actions_veh = torch.stack(self.buffer.actions_veh).flatten(0, 1)
        t_actions_node = torch.stack(self.buffer.actions_node).flatten(0, 1)
        t_old_logprobs_veh = torch.stack(self.buffer.logprobs_veh).flatten(0, 1)
        t_old_logprobs_node = torch.stack(self.buffer.logprobs_node).flatten(0, 1)
        returns = returns.flatten(0, 1)
        advantages = advantages.flatten(0, 1)

        buffer_size = t_static.size(0)
        batch_size = self.config.batch_size
        
        # --- 3. Dashboard Setup ---
        dashboard_area = Output()
        pbar = tqdm(range(self.config.update_epochs), desc="🧠 Updating Policy", leave=False)
        display(dashboard_area)
        
        total_loss_log = 0
        
        for epoch_i in pbar:
            epoch_actor_loss = 0
            epoch_critic_loss = 0
            epoch_approx_kl = 0     
            epoch_explained_var = 0 
            
            indices = torch.randperm(buffer_size)
            
            for start_idx in range(0, buffer_size, self.config.minibatch_size):
                end_idx = start_idx + batch_size
                mb_idx = indices[start_idx:end_idx]
                
                # Get Mini-batch
                b_static = t_static[mb_idx]
                b_dyn_truck = t_dyn_truck[mb_idx]
                b_dyn_drone = t_dyn_drone[mb_idx]
                b_mask_cust = t_mask_cust[mb_idx]
                b_mask_veh = t_mask_veh[mb_idx]
                b_act_veh = t_actions_veh[mb_idx]
                b_act_node = t_actions_node[mb_idx]
                b_old_log_veh = t_old_logprobs_veh[mb_idx]
                b_old_log_node = t_old_logprobs_node[mb_idx]
                b_adv = advantages[mb_idx]
                b_ret = returns[mb_idx]
                
                # Forward Actor
                # b_depot_xy = b_static[:, 0:2, 0].unsqueeze(2)
                b_depot_xy = b_static[:, 0:2, 0]
                
                veh_probs, node_probs, _, _ = self.actor(
                    b_static, b_dyn_truck, b_dyn_drone, 
                    decoder_input=b_depot_xy, last_hh=None, 
                    mask_customers=b_mask_cust, mask_vehicles=b_mask_veh
                )

                if torch.isnan(node_probs).any():
                    # print(f"⚠️ Warning: NaN detected in node_probs from model!")
                    nan_mask = torch.isnan(node_probs).any(dim=1)
                    num_nan = nan_mask.sum().item()
                    # print(f"   {num_nan} samples affected")
                    
                    # Tạo tensor mới thay vì modify inplace
                    uniform_probs = torch.ones_like(node_probs) / node_probs.shape[1]
                    node_probs = torch.where(
                        nan_mask.unsqueeze(1).expand_as(node_probs),
                        uniform_probs,
                        node_probs
                    )

                # Tương tự cho veh_probs
                if torch.isnan(veh_probs).any():
                    # print(f"⚠️ Warning: NaN detected in veh_probs from model!")
                    nan_mask = torch.isnan(veh_probs).any(dim=1)
                    uniform_probs = torch.ones_like(veh_probs) / veh_probs.shape[1]
                    veh_probs = torch.where(
                        nan_mask.unsqueeze(1).expand_as(veh_probs),
                        uniform_probs,
                        veh_probs
                    )
                
                
                # # node_probs = node_probs_raw.clone()
                # node_probs[b_mask_cust == 0] = 0.0

                # prob_sum = node_probs.sum(dim=1, keepdim=True)
                # zero_sum = (prob_sum.squeeze(1) == 0)
                # if zero_sum.any():
                #     node_probs[zero_sum, 0] = 1.0
                #     prob_sum[zero_sum] = 1.0
                # node_probs = node_probs / prob_sum

                # # --- LOSS CALCULATION ---
                # new_log_veh = torch.log(veh_probs.gather(1, b_act_veh.unsqueeze(1)) + 1e-10).squeeze(1)
                # new_log_node = torch.log(node_probs.gather(1, b_act_node.unsqueeze(1)) + 1e-10).squeeze(1)
                
                # # Tổng Log Prob (Xe + Node)
                # log_ratio = (new_log_veh + new_log_node) - (b_old_log_veh + b_old_log_node)
                # ratio = torch.exp(log_ratio)


                # veh_probs = torch.clamp(veh_probs, 1e-8, 1.0)
                # node_probs = torch.clamp(node_probs, 1e-8, 1.0)

                # # 2. Tạo bản clone để tránh lỗi "Inplace"
                # node_probs_fixed = node_probs.clone()

                # # 3. Tính tổng hàng
                # prob_sum = node_probs_fixed.sum(dim=1, keepdim=True)

                # # 4. Tìm các hàng bị "chết" (tổng quá nhỏ - thường do Mask chặn hết)
                # # Ngưỡng 1e-6 là an toàn
                # invalid_rows = (prob_sum < 1e-6).squeeze()

                # if invalid_rows.any():
                #     # Ép các hàng lỗi phải chọn Node 0 (Depot)
                #     node_probs_fixed[invalid_rows] = 1e-10 # Reset về gần 0
                #     node_probs_fixed[invalid_rows, 0] = 1.0 # Gán 100% cho Depot
                #     # Cập nhật lại tổng cho các hàng vừa sửa để lát nữa chia không bị nan
                #     prob_sum[invalid_rows] = 1.0

                # # 5. Re-normalize chuẩn (không inplace)
                # node_probs_final = node_probs_fixed / prob_sum


                # # 2. Sử dụng Categorical với Probs (PyTorch sẽ tự xử lý ổn định hơn là bạn tự log)
                # dist_veh = torch.distributions.Categorical(probs=veh_probs, validate_args=False)
                # dist_node = torch.distributions.Categorical(probs=node_probs_final, validate_args=False)

                # new_log_veh = dist_veh.log_prob(b_act_veh)
                # new_log_node = dist_node.log_prob(b_act_node)
                # current_log_probs = new_log_veh + new_log_node

                # # 3. Tính KL xấp xỉ và Ratio
                # log_ratio = current_log_probs - (b_old_log_veh + b_old_log_node)
                # # Kẹp log_ratio để tránh exp(log_ratio) bị nổ NaN
                # log_ratio = torch.clamp(log_ratio, -20, 2) 
                # ratio = torch.exp(log_ratio)
                
                # surr1 = ratio * b_adv
                # surr2 = torch.clamp(ratio, 1 - self.config.clip_epsilon, 1 + self.config.clip_epsilon) * b_adv
                # actor_loss = -torch.min(surr1, surr2).mean()

                # 1. KHÔNG clamp probs trước - để PyTorch xử lý
# 1. Normalize probs đúng cách
                veh_probs = veh_probs / (veh_probs.sum(dim=-1, keepdim=True) + 1e-10)
                node_probs = node_probs / (node_probs.sum(dim=-1, keepdim=True) + 1e-10)

                # 2. Xử lý invalid rows cho node_probs
                node_probs_fixed = node_probs.clone()
                prob_sum = node_probs_fixed.sum(dim=1, keepdim=True)
                invalid_rows = (prob_sum < 1e-6).squeeze()

                if invalid_rows.any():
                    # Reset toàn bộ hàng invalid về one-hot tại depot
                    num_invalid = invalid_rows.sum().item()
                    node_probs_fixed[invalid_rows] = 0.0  # Tất cả về 0
                    node_probs_fixed[invalid_rows, 0] = 1.0  # Depot = 1
                    # print(f"Warning: {num_invalid} invalid rows detected, forced to depot")

                # 3. Re-normalize với epsilon nhỏ hơn
                node_probs_final = node_probs_fixed / (node_probs_fixed.sum(dim=1, keepdim=True) + 1e-10)

                # 4. Clamp để đảm bảo trong [0, 1]
                node_probs_final = torch.clamp(node_probs_final, 0.0, 1.0)

                # 5. Force normalize lại lần cuối
                node_probs_final = node_probs_final / node_probs_final.sum(dim=1, keepdim=True)
                
                import torch.nn.functional as F
                node_probs_final = F.normalize(node_probs_final, p=1, dim=1)


                # 6. Debug check
                prob_sums = node_probs_final.sum(dim=1)

                # Tạo distributions
                dist_veh = torch.distributions.Categorical(probs=veh_probs, validate_args=False)
                dist_node = torch.distributions.Categorical(probs=node_probs_final, validate_args=False)

                # Tính log probs
                new_log_veh = dist_veh.log_prob(b_act_veh)
                new_log_node = dist_node.log_prob(b_act_node)

                # Tìm vị trí NaN
                if torch.isnan(new_log_veh).any():
                    nan_idx = torch.where(torch.isnan(new_log_veh))[0]
                    # print(f"\nNaN in new_log_veh at indices: {nan_idx[:10]}")  # 10 cái đầu
                    for idx in nan_idx[:3]:
                        print(f"  Sample {idx}: action={b_act_veh[idx]}, probs={veh_probs[idx]}")

                if torch.isnan(new_log_node).any():
                    nan_idx = torch.where(torch.isnan(new_log_node))[0]
                    # print(f"\nNaN in new_log_node at indices: {nan_idx[:10]}")
                    # for idx in nan_idx[:3]:
                        # print(f"  Sample {idx}: action={b_act_node[idx]}, probs={node_probs_final[idx]}")

                # 7. Tạo distributions (TẮT validate_args)
                dist_veh = torch.distributions.Categorical(probs=veh_probs, validate_args=False)
                dist_node = torch.distributions.Categorical(probs=node_probs_final, validate_args=False)

                # 3. Tính log probs
                new_log_veh = dist_veh.log_prob(b_act_veh)
                new_log_node = dist_node.log_prob(b_act_node)
                current_log_probs = new_log_veh + new_log_node


                # 6. Tính ratio
                log_ratio = current_log_probs - (b_old_log_veh + b_old_log_node)
                log_ratio = torch.clamp(log_ratio, -10, 10)
                ratio = torch.exp(log_ratio)

                # 7. PPO loss
                surr1 = ratio * b_adv
                surr2 = torch.clamp(ratio, 1 - self.config.clip_epsilon, 1 + self.config.clip_epsilon) * b_adv
                actor_loss = -torch.min(surr1, surr2).mean()

                # 8. Kiểm tra loss
                assert not torch.isnan(actor_loss), f"Loss is NaN! ratio range: [{ratio.min()}, {ratio.max()}], adv range: [{b_adv.min()}, {b_adv.max()}]"
                
                entropy = -((veh_probs * torch.log(veh_probs + 1e-10)).sum(1) + (node_probs * torch.log(node_probs + 1e-10)).sum(1)).mean()
                
                value_pred = self.critic(b_static, b_dyn_truck, b_dyn_drone)
                critic_loss = nn.MSELoss()(value_pred, b_ret)
                
                loss = actor_loss - self.config.entropy_coef * entropy + self.config.value_loss_coef * critic_loss
                
                # --- CALCULATE METRICS ---
                with torch.no_grad():
                    # 1. Approx KL Divergence (http://joschu.net/blog/kl-approx.html)
                    # (ratio - 1) - log(ratio)
                    approx_kl = ((ratio - 1) - log_ratio).mean()
                    
                    # 2. Explained Variance
                    # 1 - Var(y_true - y_pred) / Var(y_true)
                    y_true = b_ret
                    y_pred = value_pred
                    var_y = torch.var(y_true)
                    if var_y == 0:
                        explained_var = torch.tensor(0.0)
                    else:
                        explained_var = 1 - torch.var(y_true - y_pred) / var_y

                # Backprop
                self.actor_optimizer.zero_grad(); self.critic_optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.actor.parameters(), self.config.max_grad_norm)
                nn.utils.clip_grad_norm_(self.critic.parameters(), self.config.max_grad_norm)
                self.actor_optimizer.step(); self.critic_optimizer.step()
                
                # Accumulate Logs
                epoch_actor_loss += actor_loss.item()
                epoch_critic_loss += critic_loss.item()
                epoch_approx_kl += approx_kl.item()
                epoch_explained_var += explained_var.item()

                # --- DASHBOARD UPDATE ---
                if start_idx % 1 == 0: 
                    with dashboard_area:
                        clear_output(wait=True)
                        
                        # In sample ngẫu nhiên
                        sample_idx = 0 # Hoặc random
                        v_probs = veh_probs[sample_idx].detach().cpu().numpy()
                        n_probs = node_probs[sample_idx].detach().cpu().numpy()
                        top_k = 10
                        top_node_indices = np.argsort(n_probs)[-top_k:][::-1]
                        
                        # Đếm số lượng xe thực tế để in cho chuẩn
                        num_total_veh = v_probs.shape[0]
                        num_trucks_cfg = self.env.num_trucks
                        
                        print(f"╔══════════════════════════════════════════════════════════════╗")
                        print(f"║  🚀 TRAINING MONITORING | Epoch {epoch_i+1}/{self.config.update_epochs} | Batch {start_idx}/{buffer_size}      ║")
                        print(f"║  Start  Index: {start_idx}      ║")
                        
                        print(f"╠══════════════════════════════════════════════════════════════╣")
                        print(f"║  📉 Loss Metrics:                                            ")
                        print(f"║     Actor Loss: {actor_loss.item():.4f} | Critic Loss: {critic_loss.item():.4f}")
                        print(f"║     Approx KL : {approx_kl.item():.4f}  (Target: ~0.01)    ")
                        print(f"║     Expl. Var : {explained_var.item():.4f}  (Target: >0.5)     ")
                        print(f"║     Return   : {b_ret[sample_idx]}      ")
                        print(f"╠══════════════════════════════════════════════════════════════╣")
                        print(f"║  🚛 VEHICLES (Total {num_total_veh}):")
                        print(f"║     Trucks: {v_probs[:num_trucks_cfg].round(3)}")
                        if num_total_veh > num_trucks_cfg:
                            print(f"║     Drones: {v_probs[num_trucks_cfg:].round(3)}")
                        else:
                            print(f"║     Drones: [None]")
                        print(f"╠══════════════════════════════════════════════════════════════╣")
                        print(f"║  📍 TOP {top_k} NODE PROBS (Sample {sample_idx}):                   ")
                        for rank, node_idx in enumerate(top_node_indices):
                            prob = n_probs[node_idx]
                            bar_len = int(prob * 20)
                            bar = "█" * bar_len + "░" * (20 - bar_len)
                            is_truck_only = b_static[sample_idx, 3, node_idx] == 1
                            note = "(TO)" if is_truck_only else ""
                            if prob == 0.0: note += " [BLK]"
                            print(f"║  #{rank+1:<2} Node {node_idx:3d} : {prob:.4f} {bar} {note}")
                        print(f"╚══════════════════════════════════════════════════════════════╝")

                        print(f"Value Pred: ", value_pred[sample_idx].item())
                        print(f"Return: ", b_ret[sample_idx])

            # Chèn đoạn này vào bên trong vòng lặp batch của update_policy
            with torch.no_grad():
                avg_pred = value_pred.mean()
                avg_ret = b_ret.mean()
                # So sánh sai số tuyệt đối chưa qua MSE
                abs_error = torch.abs(value_pred - b_ret).mean().item()
            
            print(f"║ 🔍 Critic Monitoring:                                            ")
            print(f"║     Avg Prediction: {avg_pred.item():.4f} | Avg Return: {avg_ret.item():.4f}")
            print(f"║     Mean Abs Error: {abs_error:.4f}                              ")
            print(f"║     Mean Error: {torch.abs(avg_pred - avg_ret):.4f}                              ")

            
            # Tính trung bình Epoch
            num_batches = max(1, (buffer_size + batch_size - 1) // batch_size)
            avg_act = epoch_actor_loss / num_batches
            avg_cri = epoch_critic_loss / num_batches
            avg_kl = epoch_approx_kl / num_batches
            avg_ev = epoch_explained_var / num_batches
            
            # Cập nhật Pbar
            pbar.set_postfix({
                'Act': f"{avg_act:.2f}", 
                'Cri': f"{avg_cri:.2f}", 
                'KL': f"{avg_kl:.4f}",
                'EV': f"{avg_ev:.2f}",
                'Num Update': f"{self.num_updates}"
            })
            
            total_loss_log += (avg_act + avg_cri)
            
            # Log lên Tensorboard
            self.writer.add_scalar('Loss/Actor', avg_act, self.num_updates)
            self.writer.add_scalar('Loss/Critic', avg_cri, self.num_updates)
            self.writer.add_scalar('Loss/Approx_KL', avg_kl, self.num_updates)
            self.writer.add_scalar('Loss/Explained_Var', avg_ev, self.num_updates)
            
        self.num_updates += 1


    def train(self, scenario_number=5):
        """Run the outer PPO loop: collect a rollout, update the policy, step both LR schedulers.

        Args:
            scenario_number: Scenario id passed to ``collect_rollout`` every epoch.
                Defaults to 5, which is the value the loop previously hard-coded
                (it drew ``random.randint(0, 5)`` and then immediately overwrote it
                with 5). Pass ``None`` to draw a scenario uniformly from 0..5
                (inclusive) at the start of every epoch.

        Side effects:
            Prints a start banner. Every ``config.log_interval`` epochs it prints the
            mean of ``self.episode_rewards`` (0 when empty) and logs it as
            ``Reward/Average`` to ``self.writer``. Every ``config.save_interval``
            epochs it calls ``save_checkpoint(epoch + 1)``.
        """
        print(f"🚀 Starting PPO Training on {self.device}")
        for epoch in range(self.config.num_epochs):
            # The scenario choice is explicit instead of a random draw that was
            # silently overwritten on the next line.
            if scenario_number is None:
                scenario = random.randint(0, 5)
            else:
                scenario = scenario_number
            epoch_start = time.time()
            self.collect_rollout(scenario)
            self.update_policy()
            self.actor_scheduler.step()
            self.critic_scheduler.step()

            if (epoch + 1) % self.config.log_interval == 0:
                avg_r = np.mean(self.episode_rewards) if self.episode_rewards else 0
                print(f"Epoch {epoch+1}: Avg Reward {avg_r:.4f} | Time: {time.time()-epoch_start:.2f}s")
                self.writer.add_scalar('Reward/Average', avg_r, epoch)

            if (epoch + 1) % self.config.save_interval == 0:
                self.save_checkpoint(epoch + 1)

    def evaluate(self, num_episodes=10):
        """Roll out the current policy greedily and return the mean episode reward.

        Runs ``num_episodes`` episodes of at most ``config.max_steps`` steps each.
        Actions come from ``select_action(..., deterministic=True)`` under
        ``torch.no_grad()``. A node that the environment's customer mask rejects
        for the chosen vehicle is replaced by node 0. Each step adds the batch mean
        of ``reward``, and an episode stops as soon as ``done.any()`` is true.

        The actor and critic are put in eval mode for the rollout. They are then
        restored to the train/eval mode they had on entry, also when an exception
        escapes. Previously they stayed in eval mode on error and were forced into
        train mode otherwise.

        Args:
            num_episodes: Number of episodes to average over.

        Returns:
            ``np.mean`` of the per-episode summed rewards. This is NaN (with a
            NumPy RuntimeWarning) when ``num_episodes`` is 0.
        """
        actor_was_training = self.actor.training
        critic_was_training = self.critic.training
        self.actor.eval()
        self.critic.eval()

        eval_rewards = []
        try:
            with torch.no_grad():
                for _ in range(num_episodes):
                    state = self.env.reset()
                    last_hh = None
                    episode_reward = 0

                    for _ in range(self.config.max_steps):
                        # Greedy action selection.
                        veh_idx, node_idx, _, _, last_hh = self.select_action(
                            state, last_hh, deterministic=True
                        )

                        # Replace nodes that are invalid for the selected vehicle with node 0.
                        valid_mask = self.env.get_valid_customer_mask(veh_idx)
                        invalid = (valid_mask.gather(1, node_idx.unsqueeze(1)) == 0).squeeze(1)
                        if invalid.any():
                            node_idx = torch.where(invalid, torch.zeros_like(node_idx), node_idx)

                        state, reward, done, _ = self.env.step(veh_idx, node_idx)
                        episode_reward += reward.mean().item()

                        if done.any():
                            break

                    eval_rewards.append(episode_reward)
        finally:
            # Restore the caller's modes even if the rollout raised.
            self.actor.train(actor_was_training)
            self.critic.train(critic_was_training)

        return np.mean(eval_rewards)
    
    def save_checkpoint(self, epoch, is_best=False):
        """Serialize networks, optimizers, LR schedulers and a config snapshot with ``torch.save``.

        Args:
            epoch: Epoch number stored in the checkpoint and used in the file name.
            is_best: When true, write ``best_model.pth``; otherwise write
                ``checkpoint_epoch_{epoch}.pth``. Both go in ``config.checkpoint_dir``.

        The config is stored as a plain dict of selected PPO and system fields
        rather than as the config objects. Scheduler state dicts
        (``actor_scheduler`` / ``critic_scheduler``) and ``num_updates`` are added
        when those attributes exist, so a resumed run can continue its LR schedule
        and TensorBoard step counter. The checkpoint directory is created if it
        does not exist yet.
        """
        # Create a serializable config dict
        config_dict = {
            'batch_size': self.config.batch_size,
            'max_steps': self.config.max_steps,
            'hidden_size': self.config.hidden_size,
            'num_layers': self.config.num_layers,
            'dropout': self.config.dropout,
            'lr_actor': self.config.lr_actor,
            'lr_critic': self.config.lr_critic,
            'gamma': self.config.gamma,
            'gae_lambda': self.config.gae_lambda,
            'clip_epsilon': self.config.clip_epsilon,
            'entropy_coef': self.config.entropy_coef,
            'value_loss_coef': self.config.value_loss_coef,
            'max_grad_norm': self.config.max_grad_norm,
            'drone_speed': self.sys_config.drone_speed,
            'drone_max_energy': self.sys_config.drone_max_energy,
            't_takeoff': self.sys_config.t_takeoff,
            't_landing': self.sys_config.t_landing,
            'drone_params': self.sys_config.drone_params,
        }

        checkpoint = {
            'epoch': epoch,
            'actor_state_dict': self.actor.state_dict(),
            'critic_state_dict': self.critic.state_dict(),
            'actor_optimizer': self.actor_optimizer.state_dict(),
            'critic_optimizer': self.critic_optimizer.state_dict(),
            'best_reward': self.best_reward,
            'config': config_dict  # Save as dict instead of object
        }

        # Resumption state: without these, reloading restarts the LR schedule
        # and the update counter from scratch.
        for key in ('actor_scheduler', 'critic_scheduler'):
            scheduler = getattr(self, key, None)
            if scheduler is not None:
                checkpoint[key] = scheduler.state_dict()
        if hasattr(self, 'num_updates'):
            checkpoint['num_updates'] = self.num_updates

        # torch.save does not create missing directories.
        os.makedirs(self.config.checkpoint_dir, exist_ok=True)
        filename = 'best_model.pth' if is_best else f'checkpoint_epoch_{epoch}.pth'
        path = os.path.join(self.config.checkpoint_dir, filename)

        torch.save(checkpoint, path)
        print(f"💾 Checkpoint saved: {path}")
    
    def load_checkpoint(self, path):
        """Restore networks, optimizers and, when present, LR schedulers from ``path``.

        Args:
            path: Checkpoint file, e.g. one written by ``save_checkpoint``.

        Returns:
            The ``epoch`` value stored in the checkpoint.

        Optional entries (``actor_scheduler``, ``critic_scheduler``,
        ``num_updates``) are restored only when the checkpoint contains them, so
        older checkpoints without them still load.

        Security: ``weights_only=False`` lets ``torch.load`` unpickle arbitrary
        Python objects, which can execute code. Only load checkpoints you
        produced yourself or otherwise trust.
        """
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)

        self.actor.load_state_dict(checkpoint['actor_state_dict'])
        self.critic.load_state_dict(checkpoint['critic_state_dict'])
        self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
        self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer'])
        self.best_reward = checkpoint['best_reward']

        # Restoring these lets a resumed run continue its LR schedule and step counter.
        for key in ('actor_scheduler', 'critic_scheduler'):
            scheduler = getattr(self, key, None)
            if key in checkpoint and scheduler is not None:
                scheduler.load_state_dict(checkpoint[key])
        if 'num_updates' in checkpoint:
            self.num_updates = checkpoint['num_updates']

        print(f"✅ Checkpoint loaded from {path}")
        return checkpoint['epoch']


def main(truck_config_path='Truck_config.json',
         drone_config_path='drone_linear_config.json',
         drone_type="1"):
    """Build the system and PPO configs, construct a ``PPOTrainer`` and run training.

    Args:
        truck_config_path: Path of the truck JSON config. The default is the file
            written by the setup cell at the top of this notebook.
        drone_config_path: Path of the drone JSON config. The default is the file
            written by the setup cell at the top of this notebook.
        drone_type: Forwarded to ``SystemConfig``. Presumably it selects one of the
            drone profiles keyed "1".."4" in the drone config; confirm in SystemConfig.

    Returns:
        The trained ``PPOTrainer``, so the notebook can evaluate or save it afterwards.
    """
    sys_config = SystemConfig(truck_config_path, drone_config_path, drone_type=drone_type)

    ppo_config = PPOConfig()

    trainer = PPOTrainer(ppo_config, sys_config)

    trainer.train()
    return trainer

if __name__ == "__main__":
    main()

🛠️ Initializing Networks on cuda...
✅ Networks initialized successfully.
   - Actor: MOPVRP_HierarchicalActor (Pairwise Attention)
   - Critic: Standard Critic (Global Pooling)
🚀 Starting PPO Training on cuda


Output()

🧠 Updating Policy:   0%|          | 0/2 [00:00<?, ?it/s]                                           

Output()

🧠 Updating Policy:  50%|█████     | 1/2 [01:01<01:01, 61.30s/it, Act=0.08, Cri=22232.77, KL=1.0882, EV=-1.12, Num Update=0]

║ 🔍 Critic Monitoring:                                            
║     Avg Prediction: 157.1485 | Avg Return: 156.3809
║     Mean Abs Error: 77.1253                              
║     Mean Error: 0.7676                              


                                                                                                                            

║ 🔍 Critic Monitoring:                                            
║     Avg Prediction: 160.6104 | Avg Return: 160.4784
║     Mean Abs Error: 77.3605                              
║     Mean Error: 0.1321                              
Epoch 1: Avg Reward -34.9170 | Time: 173.22s
💾 Checkpoint saved: checkpoints/checkpoint_epoch_1.pth




Output()

🧠 Updating Policy:   0%|          | 0/2 [00:00<?, ?it/s]                                           

Output()

🧠 Updating Policy:  50%|█████     | 1/2 [01:01<01:01, 61.56s/it, Act=0.08, Cri=361.08, KL=1.0809, EV=0.02, Num Update=1]

║ 🔍 Critic Monitoring:                                            
║     Avg Prediction: 24.8085 | Avg Return: 26.3067
║     Mean Abs Error: 11.6167                              
║     Mean Error: 1.4982                              


                                                                                                                         

║ 🔍 Critic Monitoring:                                            
║     Avg Prediction: 25.4631 | Avg Return: 23.8211
║     Mean Abs Error: 11.7791                              
║     Mean Error: 1.6420                              
Epoch 2: Avg Reward -34.9186 | Time: 172.78s
💾 Checkpoint saved: checkpoints/checkpoint_epoch_2.pth




Output()

🧠 Updating Policy:   0%|          | 0/2 [00:00<?, ?it/s]                                           

Output()

🧠 Updating Policy:  50%|█████     | 1/2 [01:01<01:01, 61.38s/it, Act=0.11, Cri=11.25, KL=1.2664, EV=0.23, Num Update=2]

║ 🔍 Critic Monitoring:                                            
║     Avg Prediction: 2.0904 | Avg Return: 1.9298
║     Mean Abs Error: 2.2022                              
║     Mean Error: 0.1606                              


                                                                                                                        

║ 🔍 Critic Monitoring:                                            
║     Avg Prediction: 1.6442 | Avg Return: 1.6771
║     Mean Abs Error: 2.1593                              
║     Mean Error: 0.0329                              
Epoch 3: Avg Reward -34.3856 | Time: 173.59s
💾 Checkpoint saved: checkpoints/checkpoint_epoch_3.pth




Output()

🧠 Updating Policy:   0%|          | 0/2 [00:00<?, ?it/s]                                           

Output()

🧠 Updating Policy:  50%|█████     | 1/2 [01:01<01:01, 61.59s/it, Act=0.03, Cri=0.62, KL=1.1703, EV=-1.48, Num Update=3]

║ 🔍 Critic Monitoring:                                            
║     Avg Prediction: -0.5269 | Avg Return: -0.8294
║     Mean Abs Error: 0.4558                              
║     Mean Error: 0.3024                              


                                                                                                                        

║ 🔍 Critic Monitoring:                                            
║     Avg Prediction: -0.4711 | Avg Return: -0.8117
║     Mean Abs Error: 0.4756                              
║     Mean Error: 0.3407                              
Epoch 4: Avg Reward -34.3860 | Time: 173.04s
💾 Checkpoint saved: checkpoints/checkpoint_epoch_4.pth




Output()

🧠 Updating Policy:   0%|          | 0/2 [00:00<?, ?it/s]                                           

Output()

🧠 Updating Policy:  50%|█████     | 1/2 [01:01<01:01, 61.43s/it, Act=0.10, Cri=0.40, KL=1.3325, EV=-3.05, Num Update=4]

║ 🔍 Critic Monitoring:                                            
║     Avg Prediction: -0.7474 | Avg Return: -1.1260
║     Mean Abs Error: 0.4325                              
║     Mean Error: 0.3786                              


                                                                                                                        

║ 🔍 Critic Monitoring:                                            
║     Avg Prediction: -0.8746 | Avg Return: -1.1305
║     Mean Abs Error: 0.3692                              
║     Mean Error: 0.2559                              
Epoch 5: Avg Reward -34.3343 | Time: 173.01s
💾 Checkpoint saved: checkpoints/checkpoint_epoch_5.pth




Output()

🧠 Updating Policy:   0%|          | 0/2 [00:00<?, ?it/s]                                           

Output()

🧠 Updating Policy:  50%|█████     | 1/2 [01:01<01:01, 61.54s/it, Act=0.08, Cri=0.36, KL=1.2139, EV=-2.67, Num Update=5]

║ 🔍 Critic Monitoring:                                            
║     Avg Prediction: -0.7832 | Avg Return: -1.2580
║     Mean Abs Error: 0.4917                              
║     Mean Error: 0.4748                              


                                                                                                                        

║ 🔍 Critic Monitoring:                                            
║     Avg Prediction: -1.2134 | Avg Return: -1.2546
║     Mean Abs Error: 0.2078                              
║     Mean Error: 0.0412                              
Epoch 6: Avg Reward -34.3460 | Time: 173.72s
💾 Checkpoint saved: checkpoints/checkpoint_epoch_6.pth




Output()

🧠 Updating Policy:   0%|          | 0/2 [00:00<?, ?it/s]                                           

Output()

🧠 Updating Policy:  50%|█████     | 1/2 [01:01<01:01, 61.55s/it, Act=0.03, Cri=0.17, KL=1.1613, EV=-0.10, Num Update=6]

║ 🔍 Critic Monitoring:                                            
║     Avg Prediction: -1.4888 | Avg Return: -1.5398
║     Mean Abs Error: 0.3192                              
║     Mean Error: 0.0510                              


                                                                                                                        

║ 🔍 Critic Monitoring:                                            
║     Avg Prediction: -1.4897 | Avg Return: -1.5048
║     Mean Abs Error: 0.2941                              
║     Mean Error: 0.0150                              
Epoch 7: Avg Reward -34.3512 | Time: 173.31s
💾 Checkpoint saved: checkpoints/checkpoint_epoch_7.pth




Output()

🧠 Updating Policy:   0%|          | 0/2 [00:00<?, ?it/s]                                           

Output()

🧠 Updating Policy:  50%|█████     | 1/2 [01:01<01:01, 61.49s/it, Act=0.09, Cri=0.09, KL=1.2414, EV=-0.31, Num Update=7]

║ 🔍 Critic Monitoring:                                            
║     Avg Prediction: -1.5065 | Avg Return: -1.4971
║     Mean Abs Error: 0.2130                              
║     Mean Error: 0.0094                              


                                                                                                                        

KeyboardInterrupt: 

In [None]:
# # import gc, torch
# # gc.collect()
# # torch.cuda.empty_cache()
# !pwd
# !ls


/kaggle/working
checkpoints  drone_linear_config.json  logs  Truck_config.json


In [None]:
# from IPython.display import FileLink
# import base64
# from IPython.display import HTML

# def create_download_link(filename, title = "Click vào đây để tải về"):
#     """
#     Hàm này tạo nút tải file trực tiếp, bỏ qua lỗi path của VS Code
#     """
#     try:
#         data = open(filename, "rb").read()
#         b64 = base64.b64encode(data)
#         payload = b64.decode()
#         html = '<a download="{filename}" href="data:text/csv;base64,{payload}" target="_blank">{title}</a>'
#         html = html.format(payload=payload, title=title, filename=filename)
#         return HTML(html)
#     except FileNotFoundError:
#         return "Lỗi: Không tìm thấy file! Hãy kiểm tra lại tên file."

# # Kiểm tra lại xem file zip lúc nãy tên là gì (thường là file_nen_cua_toi.zip)
# # Nếu chưa nén thì nén lại bằng lệnh: !zip -r output.zip .

# create_download_link('checkpoints/checkpoint_epoch_3.pth', 'Tải file ZIP về ngay')
# # FileLink(r'checkpoints/checkpoint_epoch_3.pth')

In [None]:
# import torch
# import numpy as np
# import matplotlib.pyplot as plt
# import seaborn as sns
# from collections import defaultdict
# import os
# import json
# from datetime import datetime
# import pandas as pd

# class MOPVRPEvaluator:
#     """Evaluator for trained MOPVRP models"""
    
#     def __init__(self, checkpoint_path, config=None, device=None):
#         """
#         Initialize evaluator
        
#         Args:
#             checkpoint_path: Path to model checkpoint
#             config: PPOConfig object (if None, load from checkpoint)
#             device: Device to run evaluation on
#         """
#         self.checkpoint_path = checkpoint_path
#         self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
#         # Load checkpoint
#         print(f"📂 Loading checkpoint from: {checkpoint_path}")
#         checkpoint = torch.load(checkpoint_path, map_location=self.device)
        
#         # Load config
#         if config is None:
#             saved_config = checkpoint.get('config', None)
#             if isinstance(saved_config, dict):
#                 # Config was saved as dict, reconstruct PPOConfig
#                 self.config = PPOConfig()
#                 for key, value in saved_config.items():
#                     setattr(self.config, key, value)
#             else:
#                 # Config is a PPOConfig object or None
#                 self.config = saved_config if saved_config else PPOConfig()
#         else:
#             self.config = config
        
#         # Ensure config has all required attributes for environment
#         self._ensure_config_compatibility()
        
#         # Initialize networks
#         self._init_networks()
        
#         # Load weights
#         self.actor.load_state_dict(checkpoint['actor_state_dict'])
#         self.actor.eval()
        
#         if 'critic_state_dict' in checkpoint:
#             self.critic.load_state_dict(checkpoint['critic_state_dict'])
#             self.critic.eval()
        
#         print(f"✅ Model loaded successfully!")
#         print(f"   Epoch: {checkpoint.get('epoch', 'N/A')}")
#         print(f"   Best Reward: {checkpoint.get('best_reward', 'N/A'):.4f}")
        
#         # Metrics storage
#         self.eval_results = defaultdict(list)
    
#     def _ensure_config_compatibility(self):
#         """Ensure config has all attributes required by MOPVRPEnvironment"""
#         # Add missing attributes if not present
#         if not hasattr(self.config, 'drone_params'):
#             self.config.drone_params = {
#                 'gama(w)': 150.0,
#                 'beta(w/kg)': 80.0
#             }
        
#         if not hasattr(self.config, 'drone_speed'):
#             self.config.drone_speed = 15.0
        
#         if not hasattr(self.config, 'drone_max_energy'):
#             self.config.drone_max_energy = 500000.0
        
#         if not hasattr(self.config, 't_takeoff'):
#             self.config.t_takeoff = 5.0
        
#         if not hasattr(self.config, 't_landing'):
#             self.config.t_landing = 5.0
        
#         if not hasattr(self.config, 'get_truck_speed_batch'):
#             # Add method to config
#             def get_truck_speed_batch(times):
#                 base_speed = 12.0
#                 decay = torch.exp(-times / 3600.0) * 2.0
#                 return base_speed - decay
#             self.config.get_truck_speed_batch = get_truck_speed_batch
        
#         print("✅ Config compatibility ensured")
        
#     def _init_networks(self):
#         """Initialize Actor and Critic networks"""
#         self.actor = MOPVRP_Actor(
#             static_size=4,
#             dynamic_size_truck=2,
#             dynamic_size_drone=4,
#             hidden_size=self.config.hidden_size,
#             num_layers=self.config.num_layers,
#             dropout=self.config.dropout
#         ).to(self.device)
        
#         self.critic = Critic(
#             static_size=4,
#             dynamic_size_truck=2,
#             dynamic_size_drone=4,
#             hidden_size=self.config.hidden_size
#         ).to(self.device)
    
#     def select_action(self, state, last_hh=None, deterministic=True, temperature=1.0):
#         """
#         Select action using policy
        
#         Args:
#             state: Current environment state
#             last_hh: LSTM hidden state
#             deterministic: If True, use greedy selection
#             temperature: Temperature for sampling (lower = more deterministic)
#         """
#         static, dyn_truck, dyn_drone, mask_cust, mask_veh = state
        
#         with torch.no_grad():
#             # Get probabilities
#             veh_probs, node_probs, last_hh = self.actor(
#                 static, dyn_truck, dyn_drone,
#                 decoder_input=None,
#                 last_hh=last_hh,
#                 mask_customers=mask_cust,
#                 mask_vehicles=mask_veh
#             )
            
#             # Apply temperature
#             if not deterministic and temperature != 1.0:
#                 veh_probs = torch.softmax(torch.log(veh_probs + 1e-10) / temperature, dim=1)
#                 node_probs = torch.softmax(torch.log(node_probs + 1e-10) / temperature, dim=1)
            
#             if deterministic:
#                 # Greedy
#                 veh_idx = torch.argmax(veh_probs, dim=1)
#                 node_idx = torch.argmax(node_probs, dim=1)
#             else:
#                 # Stochastic
#                 veh_dist = torch.distributions.Categorical(veh_probs)
#                 node_dist = torch.distributions.Categorical(node_probs)
#                 veh_idx = veh_dist.sample()
#                 node_idx = node_dist.sample()
        
#         return veh_idx, node_idx, last_hh
    
#     # Roll the policy out on the batch returned by env.reset() and collect
#     # per-instance end-of-episode metrics.
#     def evaluate_instance(self, env, max_steps=200, deterministic=True, temperature=1.0):
#         """
#         Evaluate on the batch of instances produced by env.reset()
        
#         Args:
#             env: environment exposing reset / step / get_valid_customer_mask
#             max_steps: hard cap on the number of decoding steps
#             deterministic, temperature: forwarded to select_action
        
#         Returns:
#             dict with keys: reward (scalar: sum over steps of the batch-mean
#             reward), makespan, total_waiting, success, unvisited, truck_times,
#             drone_times (arrays with one entry/row per instance), steps (int),
#             step_rewards (list of per-step reward arrays), routes
#             (env.routes[0] or None).
#         """
#         state = env.reset()
#         last_hh = None
        
#         episode_reward = 0
#         steps = 0
#         done_flag = False
        
#         step_rewards = []
        
#         for step in range(max_steps):
#             if done_flag:
#                 break
            
#             # Select action
#             veh_idx, node_idx, last_hh = self.select_action(
#                 state, last_hh, deterministic, temperature
#             )
            
#             # Validate action
#             # NOTE(review): `static` is unpacked but not used below.
#             static, _, _, _, _ = state
#             valid_mask = env.get_valid_customer_mask(veh_idx)
#             invalid = (valid_mask.gather(1, node_idx.unsqueeze(1)) == 0).squeeze(1)
            
#             if invalid.any():
#                 # Force valid action
#                 # NOTE(review): falls back to node 0 (the depot) without checking that
#                 # node 0 is allowed by valid_mask, so the action may still be invalid.
#                 node_idx = torch.where(invalid, torch.zeros_like(node_idx), node_idx)
            
#             # Step environment
#             state, reward, done, info = env.step(veh_idx, node_idx)
            
#             episode_reward += reward.mean().item()
#             step_rewards.append(reward.cpu().numpy())
#             steps += 1
            
#             # NOTE(review): the rollout stops as soon as ANY row reports done; if
#             # env.step returns per-instance flags, the remaining rows are cut short
#             # (and report success=False). Consider done.all() -- confirm env semantics.
#             if done.any():
#                 done_flag = True
        
#         # Extract final metrics
#         # Assumption (confirm): row 1 of the dynamic tensors holds each vehicle's time.
#         truck_times = env.dynamic_truck[:, 1, :].cpu().numpy()
#         drone_times = env.dynamic_drone[:, 1, :].cpu().numpy()
        
#         all_times = np.concatenate([truck_times, drone_times], axis=1)
#         makespan = np.max(all_times, axis=1)
#         total_waiting = env.total_waiting_time.cpu().numpy()
        
#         # Check completion
#         # Success = no customer column (1:) left set in mask_cust and row 0 of every
#         # truck/drone dynamic tensor equal to 0 (presumably "back at depot" -- confirm).
#         unvisited = env.mask_cust[:, 1:].sum(dim=1).cpu().numpy()
#         trucks_home = (env.dynamic_truck[:, 0, :] == 0).all(dim=1).cpu().numpy()
#         drones_home = (env.dynamic_drone[:, 0, :] == 0).all(dim=1).cpu().numpy()
        
#         success = (unvisited == 0) & trucks_home & drones_home
        
#         return {
#             'reward': episode_reward,
#             'makespan': makespan,
#             'total_waiting': total_waiting,
#             'steps': steps,
#             'success': success,
#             'unvisited': unvisited,
#             'routes': env.routes[0] if hasattr(env, 'routes') else None,
#             'step_rewards': step_rewards,
#             'truck_times': truck_times,
#             'drone_times': drone_times
#         }
    
#     # Evaluate up to ceil(num_instances / batch_size) batches and aggregate metrics.
#     # NOTE(review): makespan / waiting time / success are stored per instance, but
#     # reward and steps are stored per BATCH (one value per evaluate_instance call),
#     # so avg/std_reward and avg/std_steps are statistics over batches.
#     # NOTE(review): every call overwrites self.eval_results, so a later call (e.g.
#     # from compare_strategies) replaces the results of an earlier evaluation.
#     # NOTE(review): whether successive env.reset() calls yield new instances depends
#     # on MOPVRPEnvironment / get_rl_dataloader, which are not visible here.
#     def evaluate_dataset(self, num_instances=100, max_steps=200, 
#                         deterministic=True, temperature=1.0, 
#                         batch_size=None, verbose=True):
#         """
#         Evaluate on multiple instances
        
#         Args:
#             num_instances: Number of instances to evaluate
#             max_steps: Max steps per instance
#             deterministic: Use greedy action selection
#             temperature: Sampling temperature
#             batch_size: Override batch size (default: use config)
#             verbose: Print progress
        
#         Returns:
#             dict with aggregated metrics
#         """
#         if batch_size is None:
#             batch_size = self.ppo_config.batch_size
        
#         # Create dataloader
#         dataloader = get_rl_dataloader(batch_size=batch_size, device=self.device)
#         env = MOPVRPEnvironment(self.config, dataloader, device=self.device)
        
#         all_rewards = []
#         all_makespans = []
#         all_waiting_times = []
#         all_steps = []
#         all_success = []
        
#         num_batches = (num_instances + batch_size - 1) // batch_size
        
#         # NOTE(review): this header is printed even when verbose=False.
#         print(f"\n🎯 Evaluating on {num_instances} instances...")
#         print(f"{'='*60}")
        
#         instances_evaluated = 0
        
#         for batch_idx in range(num_batches):
#             # NOTE(review): reported before this batch runs, so the instance count lags one batch.
#             if verbose and (batch_idx + 1) % 10 == 0:
#                 print(f"Progress: {batch_idx + 1}/{num_batches} batches ({instances_evaluated} instances)")
            
#             # Evaluate batch
#             results = self.evaluate_instance(env, max_steps, deterministic, temperature)
            
#             # Get actual batch size (last batch might be smaller)
#             actual_batch_size = len(results['makespan'])
#             instances_evaluated += actual_batch_size
            
#             # Aggregate results - store per instance
#             all_makespans.extend(results['makespan'].tolist() if hasattr(results['makespan'], 'tolist') else results['makespan'])
#             all_waiting_times.extend(results['total_waiting'].tolist() if hasattr(results['total_waiting'], 'tolist') else results['total_waiting'])
#             all_success.extend(results['success'].tolist() if hasattr(results['success'], 'tolist') else results['success'])
            
#             # For rewards and steps: these are per-batch, so we need to replicate or average
#             # Option 1: Store episode reward (one value per batch)
#             all_rewards.append(results['reward'])
            
#             # Option 2: Store steps (one value per batch)
#             all_steps.append(results['steps'])
            
#             # Break if we have enough instances
#             if instances_evaluated >= num_instances:
#                 break
        
#         # Trim to exact number of instances requested
#         # NOTE(review): rewards/steps are per batch and are not trimmed here.
#         all_makespans = all_makespans[:num_instances]
#         all_waiting_times = all_waiting_times[:num_instances]
#         all_success = all_success[:num_instances]
        
#         # Compute statistics
#         eval_stats = {
#             'num_instances': len(all_makespans),
#             'avg_reward': np.mean(all_rewards) if all_rewards else None,
#             'std_reward': np.std(all_rewards) if all_rewards else None,
#             'avg_makespan': np.mean(all_makespans),
#             'std_makespan': np.std(all_makespans),
#             'avg_waiting_time': np.mean(all_waiting_times),
#             'std_waiting_time': np.std(all_waiting_times),
#             'avg_steps': np.mean(all_steps) if all_steps else None,
#             'std_steps': np.std(all_steps) if all_steps else None,
#             'success_rate': np.mean(all_success) * 100,
#             'min_makespan': np.min(all_makespans),
#             'max_makespan': np.max(all_makespans),
#             'min_waiting': np.min(all_waiting_times),
#             'max_waiting': np.max(all_waiting_times)
#         }
        
#         # Store for later analysis
#         self.eval_results['rewards'] = all_rewards
#         self.eval_results['makespans'] = all_makespans
#         self.eval_results['waiting_times'] = all_waiting_times
#         self.eval_results['steps'] = all_steps
#         self.eval_results['success'] = all_success
        
#         return eval_stats


#     # Pretty-print the stats dict returned by evaluate_dataset.
#     # NOTE(review): the 's' suffixes assume the environment's times are in seconds -- confirm.
#     def print_evaluation_summary(self, stats):
#         """Print formatted evaluation summary"""
#         print(f"\n{'='*60}")
#         print(f"📊 EVALUATION SUMMARY")
#         print(f"{'='*60}")
#         print(f"Instances Evaluated: {stats['num_instances']}")
#         print(f"\n🏆 Performance Metrics:")
        
#         if stats['avg_reward'] is not None:
#             print(f"   Average Reward:       {stats['avg_reward']:>10.4f} ± {stats['std_reward']:.4f}")
        
#         print(f"   Average Makespan:     {stats['avg_makespan']:>10.2f} ± {stats['std_makespan']:.2f}s")
#         print(f"   Average Waiting Time: {stats['avg_waiting_time']:>10.2f} ± {stats['std_waiting_time']:.2f}s")
        
#         if stats['avg_steps'] is not None:
#             print(f"   Average Steps:        {stats['avg_steps']:>10.2f} ± {stats['std_steps']:.2f}")
        
#         print(f"   Success Rate:         {stats['success_rate']:>10.2f}%")
#         print(f"\n📈 Range:")
#         print(f"   Makespan:     [{stats['min_makespan']:.2f}, {stats['max_makespan']:.2f}]s")
#         print(f"   Waiting Time: [{stats['min_waiting']:.2f}, {stats['max_waiting']:.2f}]s")
#         print(f"{'='*60}\n")
    
#     # Run evaluate_dataset once per strategy (greedy, then stochastic at T=1.0/0.5/0.1)
#     # and print/return a pandas DataFrame with one row per strategy.
#     # NOTE(review): each evaluate_dataset call overwrites self.eval_results, so after
#     # this returns, eval_results holds only the 'Stochastic (T=0.1)' run.
#     # NOTE(review): no random seed is set here, so the stochastic rows vary between runs.
#     def compare_strategies(self, num_instances=50, max_steps=200):
#         """
#         Compare different evaluation strategies
        
#         Returns:
#             DataFrame with comparison results
#         """
#         print(f"\n🔬 Comparing Evaluation Strategies...")
        
#         strategies = [
#             {'name': 'Greedy', 'deterministic': True, 'temperature': 1.0},
#             {'name': 'Stochastic (T=1.0)', 'deterministic': False, 'temperature': 1.0},
#             {'name': 'Stochastic (T=0.5)', 'deterministic': False, 'temperature': 0.5},
#             {'name': 'Stochastic (T=0.1)', 'deterministic': False, 'temperature': 0.1},
#         ]
        
#         comparison = []
        
#         for strategy in strategies:
#             print(f"\n📍 Testing: {strategy['name']}")
#             stats = self.evaluate_dataset(
#                 num_instances=num_instances,
#                 max_steps=max_steps,
#                 deterministic=strategy['deterministic'],
#                 temperature=strategy['temperature'],
#                 verbose=False
#             )
            
#             comparison.append({
#                 'Strategy': strategy['name'],
#                 'Avg Reward': stats['avg_reward'],
#                 'Avg Makespan': stats['avg_makespan'],
#                 'Avg Waiting': stats['avg_waiting_time'],
#                 'Success Rate': stats['success_rate'],
#                 'Avg Steps': stats['avg_steps']
#             })
        
#         df = pd.DataFrame(comparison)
#         print(f"\n{'='*80}")
#         print(df.to_string(index=False))
#         print(f"{'='*80}\n")
        
#         return df
    
#     # Greedily roll out one freshly loaded instance (batch_size=1) and draw the vehicle
#     # routes (left panel) and per-vehicle completion times vs. makespan (right panel).
#     # NOTE(review): `instance_idx` is never used; the plotted instance is whatever
#     # env.reset() produces from the new dataloader.
#     # NOTE(review): plt.savefig does not create missing directories; main() passes a
#     # path inside args.save_dir before save_results() runs os.makedirs on it.
#     def visualize_solution(self, instance_idx=0, save_path=None):
#         """
#         Visualize a solution on a single instance
        
#         Args:
#             instance_idx: Index of instance to visualize
#             save_path: Path to save figure (if None, show instead)
#         """
#         # Create environment and evaluate
#         dataloader = get_rl_dataloader(batch_size=1, device=self.device)
#         env = MOPVRPEnvironment(self.config, dataloader, device=self.device)
        
#         results = self.evaluate_instance(env, max_steps=200, deterministic=True)
        
#         # Extract data
#         coords = env.static[0, :2, :].cpu().numpy().T
#         scale = env.scale[0, 0].item()
#         routes = results['routes']
        
#         if routes is None:
#             print("❌ No route information available")
#             return
        
#         # Create figure
#         fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))
        
#         # Plot 1: Routes visualization
#         ax1.scatter(coords[0, 0], coords[0, 1], c='red', s=300, marker='s', 
#                    label='Depot', zorder=5, edgecolors='black', linewidths=2)
#         ax1.scatter(coords[1:, 0], coords[1:, 1], c='blue', s=100, 
#                    label='Customers', zorder=3, alpha=0.7)
        
#         # Annotate nodes
#         ax1.annotate('D', (coords[0, 0], coords[0, 1]), fontsize=12, 
#                     ha='center', va='center', color='white', weight='bold')
#         for i in range(1, len(coords)):
#             ax1.annotate(str(i), (coords[i, 0], coords[i, 1]), 
#                         fontsize=8, ha='center', va='center')
        
#         # Plot truck routes
#         # NOTE(review): `routes` is already env.routes[0] (see evaluate_instance) and is
#         # indexed with [0] again below -- confirm env.routes' nesting.
#         colors_truck = plt.cm.Set1(np.linspace(0, 1, len(routes[0]['trucks'])))
#         for truck_idx, route in enumerate(routes[0]['trucks']):
#             route_coords = coords[route]
#             ax1.plot(route_coords[:, 0], route_coords[:, 1], 
#                     'o-', color=colors_truck[truck_idx], linewidth=2,
#                     markersize=6, label=f'Truck {truck_idx}', alpha=0.7)
        
#         # Plot drone routes
#         colors_drone = plt.cm.Set2(np.linspace(0, 1, len(routes[0]['drones'])))
#         for drone_idx, route in enumerate(routes[0]['drones']):
#             route_coords = coords[route]
#             ax1.plot(route_coords[:, 0], route_coords[:, 1], 
#                     's--', color=colors_drone[drone_idx], linewidth=2,
#                     markersize=5, label=f'Drone {drone_idx}', alpha=0.7)
        
#         ax1.set_xlabel('X Coordinate', fontsize=12)
#         ax1.set_ylabel('Y Coordinate', fontsize=12)
#         ax1.set_title(f'Vehicle Routes (Scale: {scale:.0f}m)', fontsize=14, weight='bold')
#         ax1.legend(loc='best', fontsize=9)
#         ax1.grid(True, alpha=0.3)
#         ax1.set_aspect('equal')
        
#         # Plot 2: Performance metrics
#         truck_times = results['truck_times'][0]
#         drone_times = results['drone_times'][0]
        
#         x_pos = np.arange(len(truck_times) + len(drone_times))
#         times = np.concatenate([truck_times, drone_times])
#         labels = [f'T{i}' for i in range(len(truck_times))] + \
#                  [f'D{i}' for i in range(len(drone_times))]
#         colors = ['steelblue'] * len(truck_times) + ['orange'] * len(drone_times)
        
#         bars = ax2.bar(x_pos, times, color=colors, alpha=0.7, edgecolor='black')
#         ax2.axhline(y=results['makespan'][0], color='red', linestyle='--', 
#                    linewidth=2, label=f"Makespan: {results['makespan'][0]:.2f}s")
        
#         ax2.set_xlabel('Vehicle', fontsize=12)
#         ax2.set_ylabel('Completion Time (seconds)', fontsize=12)
#         ax2.set_title('Vehicle Completion Times', fontsize=14, weight='bold')
#         ax2.set_xticks(x_pos)
#         ax2.set_xticklabels(labels)
#         ax2.legend(fontsize=10)
#         ax2.grid(True, alpha=0.3, axis='y')
        
#         # Add value labels on bars
#         for bar, time in zip(bars, times):
#             height = bar.get_height()
#             ax2.text(bar.get_x() + bar.get_width()/2., height,
#                     f'{time:.1f}s', ha='center', va='bottom', fontsize=9)
        
#         plt.tight_layout()
        
#         if save_path:
#             plt.savefig(save_path, dpi=300, bbox_inches='tight')
#             print(f"💾 Figure saved to: {save_path}")
#         else:
#             plt.show()
        
#         plt.close()
    
#     # Draw a 2x2 grid of histograms (makespan, waiting time, steps, reward) from
#     # self.eval_results, each with a dashed line at the mean.
#     # NOTE(review): `not self.eval_results` only guards an empty/falsy container;
#     # save_results uses the stricter len(eval_results.get('makespans', [])) == 0 check.
#     # NOTE(review): 'steps' and 'rewards' hold one value per batch (as stored by
#     # evaluate_dataset), so those two histograms are over batches, not instances.
#     # NOTE(review): plt.savefig will fail if save_path's directory does not exist yet.
#     def plot_metrics_distribution(self, save_path=None):
#         """Plot distribution of evaluation metrics"""
#         if not self.eval_results:
#             print("❌ No evaluation results available. Run evaluate_dataset first.")
#             return
        
#         fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        
#         # Makespans
#         axes[0, 0].hist(self.eval_results['makespans'], bins=30, 
#                        color='steelblue', alpha=0.7, edgecolor='black')
#         axes[0, 0].axvline(np.mean(self.eval_results['makespans']), 
#                           color='red', linestyle='--', linewidth=2, label='Mean')
#         axes[0, 0].set_xlabel('Makespan (seconds)', fontsize=11)
#         axes[0, 0].set_ylabel('Frequency', fontsize=11)
#         axes[0, 0].set_title('Makespan Distribution', fontsize=12, weight='bold')
#         axes[0, 0].legend()
#         axes[0, 0].grid(True, alpha=0.3)
        
#         # Waiting Times
#         axes[0, 1].hist(self.eval_results['waiting_times'], bins=30, 
#                        color='orange', alpha=0.7, edgecolor='black')
#         axes[0, 1].axvline(np.mean(self.eval_results['waiting_times']), 
#                           color='red', linestyle='--', linewidth=2, label='Mean')
#         axes[0, 1].set_xlabel('Total Waiting Time (seconds)', fontsize=11)
#         axes[0, 1].set_ylabel('Frequency', fontsize=11)
#         axes[0, 1].set_title('Waiting Time Distribution', fontsize=12, weight='bold')
#         axes[0, 1].legend()
#         axes[0, 1].grid(True, alpha=0.3)
        
#         # Steps (if available)
#         if self.eval_results['steps']:
#             axes[1, 0].hist(self.eval_results['steps'], bins=30, 
#                            color='green', alpha=0.7, edgecolor='black')
#             axes[1, 0].axvline(np.mean(self.eval_results['steps']), 
#                               color='red', linestyle='--', linewidth=2, label='Mean')
#             axes[1, 0].set_xlabel('Number of Steps', fontsize=11)
#             axes[1, 0].set_ylabel('Frequency', fontsize=11)
#             axes[1, 0].set_title('Steps Distribution', fontsize=12, weight='bold')
#             axes[1, 0].legend()
#             axes[1, 0].grid(True, alpha=0.3)
#         else:
#             axes[1, 0].text(0.5, 0.5, 'No Steps Data', 
#                            ha='center', va='center', fontsize=14)
#             axes[1, 0].set_title('Steps Distribution', fontsize=12, weight='bold')
        
#         # Rewards (if available)
#         if self.eval_results['rewards']:
#             axes[1, 1].hist(self.eval_results['rewards'], bins=30, 
#                            color='purple', alpha=0.7, edgecolor='black')
#             axes[1, 1].axvline(np.mean(self.eval_results['rewards']), 
#                               color='red', linestyle='--', linewidth=2, label='Mean')
#             axes[1, 1].set_xlabel('Episode Reward', fontsize=11)
#             axes[1, 1].set_ylabel('Frequency', fontsize=11)
#             axes[1, 1].set_title('Reward Distribution', fontsize=12, weight='bold')
#             axes[1, 1].legend()
#             axes[1, 1].grid(True, alpha=0.3)
#         else:
#             axes[1, 1].text(0.5, 0.5, 'No Reward Data', 
#                            ha='center', va='center', fontsize=14)
#             axes[1, 1].set_title('Reward Distribution', fontsize=12, weight='bold')
        
#         plt.tight_layout()
        
#         if save_path:
#             plt.savefig(save_path, dpi=300, bbox_inches='tight')
#             print(f"💾 Figure saved to: {save_path}")
#         else:
#             plt.show()
        
#         plt.close()
    

#     # Write a metrics JSON, a per-instance CSV and a plain-text summary into
#     # output_dir (created if missing), all suffixed with one shared timestamp.
#     # NOTE(review): uses `datetime` (the class), `pd` and `np`, which must be imported
#     # elsewhere in the notebook -- not visible here.
#     def save_results(self, output_dir='evaluation_results'):
#         """Save evaluation results to files"""
#         os.makedirs(output_dir, exist_ok=True)
        
#         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        
#         # Check if we have results
#         if not self.eval_results or len(self.eval_results.get('makespans', [])) == 0:
#             print("⚠️  No evaluation results to save. Run evaluate_dataset() first.")
#             return
        
#         # Save metrics as JSON
#         metrics_file = os.path.join(output_dir, f'metrics_{timestamp}.json')
#         metrics = {
#             'checkpoint': self.checkpoint_path,
#             'timestamp': timestamp,
#             'num_instances': len(self.eval_results['makespans']),
#             'avg_reward': float(np.mean(self.eval_results['rewards'])) if self.eval_results['rewards'] else None,
#             'avg_makespan': float(np.mean(self.eval_results['makespans'])),
#             'avg_waiting_time': float(np.mean(self.eval_results['waiting_times'])),
#             'avg_steps': float(np.mean(self.eval_results['steps'])) if self.eval_results['steps'] else None,
#             'success_rate': float(np.mean(self.eval_results['success']) * 100),
#             'std_makespan': float(np.std(self.eval_results['makespans'])),
#             'std_waiting_time': float(np.std(self.eval_results['waiting_times'])),
#             'min_makespan': float(np.min(self.eval_results['makespans'])),
#             'max_makespan': float(np.max(self.eval_results['makespans'])),
#         }
        
#         with open(metrics_file, 'w') as f:
#             json.dump(metrics, f, indent=4)
        
#         print(f"💾 Metrics saved to: {metrics_file}")
        
#         # Save detailed results as CSV
#         # Make sure all arrays have same length
#         csv_file = os.path.join(output_dir, f'detailed_results_{timestamp}.csv')
        
#         num_instances = len(self.eval_results['makespans'])
        
#         # Create dataframe with per-instance data
#         df_data = {
#             'instance_id': list(range(num_instances)),
#             'makespan': self.eval_results['makespans'],
#             'waiting_time': self.eval_results['waiting_times'],
#             'success': self.eval_results['success']
#         }
        
#         # Add steps if available and same length
#         # NOTE(review): 'steps' is stored per batch, so this length check only passes when
#         # every batch held exactly one instance; 'rewards' is never written to the CSV.
#         if self.eval_results['steps'] and len(self.eval_results['steps']) == num_instances:
#             df_data['steps'] = self.eval_results['steps']
        
#         df = pd.DataFrame(df_data)
#         df.to_csv(csv_file, index=False)
        
#         print(f"💾 Detailed results saved to: {csv_file}")
        
#         # Also save summary statistics
#         summary_file = os.path.join(output_dir, f'summary_{timestamp}.txt')
#         with open(summary_file, 'w') as f:
#             f.write("="*60 + "\n")
#             f.write("EVALUATION SUMMARY\n")
#             f.write("="*60 + "\n")
#             f.write(f"Checkpoint: {self.checkpoint_path}\n")
#             f.write(f"Timestamp: {timestamp}\n")
#             f.write(f"Number of instances: {num_instances}\n")
#             f.write(f"\nPerformance Metrics:\n")
#             f.write(f"  Average Makespan:     {metrics['avg_makespan']:.2f} ± {metrics['std_makespan']:.2f}s\n")
#             f.write(f"  Average Waiting Time: {metrics['avg_waiting_time']:.2f} ± {metrics['std_waiting_time']:.2f}s\n")
#             if metrics['avg_reward'] is not None:
#                 f.write(f"  Average Reward:       {metrics['avg_reward']:.4f}\n")
#             if metrics['avg_steps'] is not None:
#                 f.write(f"  Average Steps:        {metrics['avg_steps']:.2f}\n")
#             f.write(f"  Success Rate:         {metrics['success_rate']:.2f}%\n")
#             f.write(f"\nRange:\n")
#             f.write(f"  Makespan: [{metrics['min_makespan']:.2f}, {metrics['max_makespan']:.2f}]s\n")
#             f.write("="*60 + "\n")
        
#         print(f"💾 Summary saved to: {summary_file}")


# # Evaluation driver: evaluate a checkpoint greedily, optionally compare strategies,
# # plot, and save results.
# # NOTE(review): --visualize uses action='store_true' together with default=True, so it
# # is always True and the flag cannot switch visualization off.
# # NOTE(review): parse_args() receives a hard-coded list, so the real command line is
# # ignored (only --checkpoint is set; every other option keeps its default).
# # NOTE(review): with --compare, compare_strategies() overwrites evaluator.eval_results
# # before the plots and save_results() below read it.
# # NOTE(review): figures are saved into args.save_dir before save_results() creates that
# # directory, so plt.savefig can raise FileNotFoundError on a fresh run.
# def main():
#     """Main evaluation script"""
#     import argparse
    
#     parser = argparse.ArgumentParser(description='Evaluate MOPVRP Model')
#     parser.add_argument('--checkpoint', type=str, required=True,
#                        help='Path to model checkpoint')
#     parser.add_argument('--num_instances', type=int, default=100,
#                        help='Number of instances to evaluate')
#     parser.add_argument('--batch_size', type=int, default=32,
#                        help='Batch size for evaluation')
#     parser.add_argument('--visualize', action='store_true',
#                        help='Visualize sample solutions', default=True)
#     parser.add_argument('--compare', action='store_true',
#                        help='Compare different strategies')
#     parser.add_argument('--save_dir', type=str, default='evaluation_results',
#                        help='Directory to save results')
    
#     args = parser.parse_args(["--checkpoint", "checkpoints/checkpoint_epoch_5.pth"])
    
#     print(f"\n{'='*60}")
#     print(f"🚀 MOPVRP Model Evaluation")
#     print(f"{'='*60}")
    
#     # Initialize evaluator
#     evaluator = MOPVRPEvaluator(checkpoint_path=args.checkpoint)
    
#     # Run main evaluation
#     stats = evaluator.evaluate_dataset(
#         num_instances=args.num_instances,
#         batch_size=args.batch_size,
#         deterministic=True
#     )
    
#     evaluator.print_evaluation_summary(stats)
    
#     # Compare strategies
#     if args.compare:
#         evaluator.compare_strategies(num_instances=50)
    
#     # Visualizations
#     if args.visualize:
#         print("\n📊 Generating visualizations...")
#         evaluator.visualize_solution(save_path=os.path.join(args.save_dir, 'sample_solution.png'))
#         evaluator.plot_metrics_distribution(save_path=os.path.join(args.save_dir, 'metrics_distribution.png'))
    
#     # Save results
#     evaluator.save_results(output_dir=args.save_dir)
    
#     print(f"\n✅ Evaluation completed!")


# # NOTE(review): inside a Jupyter/IPython kernel __name__ == "__main__", so uncommenting
# # this runs main() (evaluator construction + full evaluation) as soon as the cell executes.
# if __name__ == "__main__":
#     main()

In [None]:
# # End-to-end masking check: drive a random masked policy for 50 steps on a batch of
# # 4 instances, counting invalid vehicle/customer picks and steps where no vehicle is
# # available ("mask conflicts").
# # NOTE(review): device='cuda' is hard-coded, so this fails on CPU-only machines.
# # NOTE(review): ppo_config is created but never used.
# def test_integrated_masking():
#     """Test masking inside a real rollout loop."""
#     print("\n" + "="*60)
#     print("🧪 TEST 3: INTEGRATED MASKING (End-to-End)")
#     print("="*60)
    
#     # Setup
#     config = SystemConfig('Truck_config.json', 'drone_linear_config.json', drone_type="1")
#     ppo_config = PPOConfig()
#     dataloader = get_rl_dataloader(batch_size=4, device='cuda')
#     env = MOPVRPEnvironment(config, dataloader, device='cuda')
    
#     # Simple random policy for testing
#     def random_policy(state):
#         static, dyn_truck, dyn_drone, mask_cust, mask_veh = state
#         batch_size = static.size(0)
        
#         # Sample vehicle
#         # NOTE(review): a row with no valid vehicle (or, below, no valid customer) divides
#         # by zero -> NaN probabilities, and torch.multinomial raises on such rows; a
#         # "mask conflict" reported below would therefore crash on the following step.
#         veh_probs = mask_veh.float() / mask_veh.sum(dim=1, keepdim=True)
#         veh_idx = torch.multinomial(veh_probs, 1).squeeze(1)
        
#         # Get valid customer mask
#         valid_mask = env.get_valid_customer_mask(veh_idx)
        
#         # Sample customer
#         cust_probs = valid_mask / valid_mask.sum(dim=1, keepdim=True)
#         node_idx = torch.multinomial(cust_probs, 1).squeeze(1)
        
#         return veh_idx, node_idx
    
#     # Run simulation
#     state = env.reset()
#     episode_stats = {
#         'invalid_actions': 0,
#         'total_actions': 0,
#         'drone_actions': 0,
#         'truck_actions': 0,
#         'depot_returns': 0,
#         'mask_conflicts': 0
#     }
    
#     print("\n🏃 Running 50-step simulation with random policy...")
    
#     for step in range(50):
#         # Select action
#         veh_idx, node_idx = random_policy(state)
        
#         # Validate action before stepping
#         static, dyn_truck, dyn_drone, mask_cust, mask_veh = state
        
#         # Check 1: Vehicle mask
#         for b in range(env.batch_size):
#             if mask_veh[b, veh_idx[b]] == 0:
#                 episode_stats['invalid_actions'] += 1
#                 print(f"  ❌ Step {step}, Env {b}: Selected INVALID vehicle {veh_idx[b].item()}")
        
#         # Check 2: Customer mask
#         valid_mask = env.get_valid_customer_mask(veh_idx)
#         for b in range(env.batch_size):
#             if valid_mask[b, node_idx[b]] == 0:
#                 episode_stats['invalid_actions'] += 1
#                 print(f"  ❌ Step {step}, Env {b}: Selected INVALID customer {node_idx[b].item()} for vehicle {veh_idx[b].item()}")
        
#         # Step environment
#         next_state, reward, done, _ = env.step(veh_idx, node_idx)
        
#         # Statistics
#         episode_stats['total_actions'] += env.batch_size
#         episode_stats['drone_actions'] += (veh_idx >= env.num_trucks).sum().item()
#         episode_stats['truck_actions'] += (veh_idx < env.num_trucks).sum().item()
#         episode_stats['depot_returns'] += (node_idx == 0).sum().item()
        
#         # Check for mask conflicts
#         next_mask_veh = next_state[4]
#         if (next_mask_veh.sum(dim=1) == 0).any():
#             episode_stats['mask_conflicts'] += 1
#             conflict_envs = torch.where(next_mask_veh.sum(dim=1) == 0)[0]
#             print(f"  ⚠️  Step {step}: Mask conflict in envs {conflict_envs.cpu().numpy()}")
        
#         if done.any():
#             print(f"  ✅ Step {step}: {done.sum().item()} envs completed")
#             # NOTE(review): env.reset() resets the whole batch, not only the finished envs.
#             state = env.reset()
#         else:
#             state = next_state
    
#     # Print summary
#     print(f"\n{'='*60}")
#     print("📊 SIMULATION SUMMARY")
#     print(f"{'='*60}")
#     print(f"Total Actions:    {episode_stats['total_actions']}")
#     print(f"Invalid Actions:  {episode_stats['invalid_actions']} ({episode_stats['invalid_actions']/episode_stats['total_actions']*100:.2f}%)")
#     print(f"Truck Actions:    {episode_stats['truck_actions']} ({episode_stats['truck_actions']/episode_stats['total_actions']*100:.2f}%)")
#     print(f"Drone Actions:    {episode_stats['drone_actions']} ({episode_stats['drone_actions']/episode_stats['total_actions']*100:.2f}%)")
#     print(f"Depot Returns:    {episode_stats['depot_returns']} ({episode_stats['depot_returns']/episode_stats['total_actions']*100:.2f}%)")
#     print(f"Mask Conflicts:   {episode_stats['mask_conflicts']}")
    
#     if episode_stats['invalid_actions'] == 0:
#         print("\n✅ ALL TESTS PASSED: No invalid actions detected!")
#     else:
#         print(f"\n❌ FAILED: {episode_stats['invalid_actions']} invalid actions found!")
    
#     if episode_stats['mask_conflicts'] > 0:
#         print("⚠️  WARNING: Mask conflicts detected (no valid vehicles available)")

# # Run the test
# test_integrated_masking()

In [None]:
    # NOTE(review): leftover copy of the parse_args line from main(); this cell has no effect and can be deleted.
    # args = parser.parse_args(["--checkpoint", "checkpoints/checkpoint_epoch_5.pth"])
