In [1]:
import torch
from problems.DVRPSR_Dataset_street import DVRPSR_Dataset

In [2]:
# Load a pickled DVRPSR dataset object from disk.
# NOTE(review): despite the name `train_data`, this reads the *validation* split. The four
# values in the directory name (0.05, 0.75, 2, 600) presumably encode instance-generation
# settings -- confirm against the dataset generator.
# NOTE(review): PyTorch >= 2.6 defaults torch.load(weights_only=True), which may refuse to
# unpickle a custom dataset object; pass weights_only=False only for trusted files.
train_data = torch.load("./data/validation/DVRPSR_{}_{}_{}_{}/unnormalized_val.pth".format(0.05,
                                                                                     0.75,
                                                                                     2,
                                                                                     600))

In [3]:
train_data.nodes.shape  # display node-tensor shape (3-D here: batch x nodes x features, per the output)

torch.Size([10, 41, 4])

In [12]:
import torch


class DVRPSR_Environment:
    """Batched environment for the dynamic VRP with stochastic requests (DVRPSR).

    All state tensors carry a leading ``minibatch`` dimension. Node features are
    ``[x, y, service duration, appearance time]`` with node 0 as the depot. Nodes whose
    appearance time is > 0 start hidden and are revealed once the current vehicle's
    elapsed time reaches that value (see ``_update_dynamic_customers``).
    """

    # Vehicle features: [x, y, remaining time budget, elapsed travel time, previous node,
    # current node, reward of the leg before last, reward of the last leg (= -distance)].
    vehicle_feature = 8
    customer_feature = 4

    # TODO: change pending cost for rewards

    def __init__(self,
                 data=None,
                 nodes=None,
                 edges_attributes=None,
                 vehicle_count=2,
                 vehicle_speed=1,
                 vehicle_time_budget=1,
                 pending_cost=1,
                 dynamic_reward=0.2,
                 budget_penalty=1):
        """Build the environment from ``data`` and/or explicit tensors.

        Args:
            data: optional dataset object. When given, its ``vehicle_count``,
                ``vehicle_speed`` and ``vehicle_time_budget`` override the keyword
                arguments. Its ``nodes`` / ``edges_attributes`` are used unless those are
                passed explicitly.
            nodes: (minibatch, nodes_count, customer_feature) node tensor.
            edges_attributes: pairwise travel distances, reshapeable to
                (minibatch, nodes_count, nodes_count).
            vehicle_count, vehicle_speed, vehicle_time_budget: fleet settings used when
                ``data`` is None.
            pending_cost: penalty per customer left unserved when the episode ends.
            dynamic_reward: stored as-is; not used by the methods of this class.
            budget_penalty: stored as-is; not used by the methods of this class.
        """
        self.vehicle_count = data.vehicle_count if data is not None else vehicle_count
        self.vehicle_speed = data.vehicle_speed if data is not None else vehicle_speed
        self.vehicle_time_budget = data.vehicle_time_budget if data is not None else vehicle_time_budget

        self.nodes = data.nodes if nodes is None else nodes
        self.edge_attributes = data.edges_attributes if edges_attributes is None else edges_attributes

        self.minibatch, self.nodes_count, _ = self.nodes.size()

        self.distance_matrix = self.edge_attributes.view((self.minibatch, self.nodes_count, self.nodes_count))
        # NOTE(review): this sets ``edges_attributes`` (with an "s"), not ``edge_attributes``,
        # so the edge tensor stays referenced (the model reads ``env.edge_attributes``).
        self.edges_attributes = None

        self.pending_cost = pending_cost
        self.dynamic_reward = dynamic_reward
        self.budget_penalty = budget_penalty

    def _update_current_vehicles(self, dest, customer_index, tau=0):
        """Move the current vehicle to ``dest`` and update its features.

        Returns:
            (dist, dyn_cust): the (minibatch, 1) distance of the travelled leg, and a float
            flag that is 1 where the destination is a dynamically appearing customer.
        """
        # calculate travel time
        # TODO: 1) in real world setting we need to calculate the distance of arc
        # If nodes i and j are directly connected by a road segment (i, j) ∈ A, then t(i,j)=t_ij;
        # otherwise, t(i,j)=t_ik1 +t_k1k2 +...+t_knj, where k1,...,kn ∈ V are the nodes along the
        # shortest path from node i to node j.
        #      2) calculate stating time for each vehicle $\tau $, currently is set to zero

        # update vehicle previous and next customer id
        self.current_vehicle[:, :, 4] = self.current_vehicle[:, :, 5]
        self.current_vehicle[:, :, 5] = customer_index

        # distance from the vehicle's previous node to its new destination
        current_idx = self.current_vehicle[:, :, 4].long()
        next_idx = self.current_vehicle[:, :, 5].long()
        dist = self.distance_matrix[torch.arange(self.minibatch).unsqueeze(1), current_idx, next_idx].view(self.minibatch, 1)

        # total travel time
        tt = dist / self.vehicle_speed

        # customers which are dynamically appeared
        dyn_cust = (dest[:, :, 3] > 0).float()

        # budget consumed by travelling to the destination plus its service duration
        budget = tt + dest[:, :, 2]

        # update vehicle features based on destination nodes
        self.current_vehicle[:, :, :2] = dest[:, :, :2]
        self.current_vehicle[:, :, 2] -= budget
        self.current_vehicle[:, :, 3] += tt
        self.current_vehicle[:, :, 6] = self.current_vehicle[:, :, 7]
        self.current_vehicle[:, :, 7] = -dist

        # write the current vehicle back into the fleet tensor
        self.vehicles = self.vehicles.scatter(1,
                                              self.current_vehicle_index[:, :, None].expand(-1, -1,
                                                                                            self.vehicle_feature),
                                              self.current_vehicle)

        return dist, dyn_cust

    def _done(self):
        """Mark the current vehicle done if no customer is pending or its budget is used up."""
        # Unserved nodes minus one for the depot, which is never marked as served. This
        # includes not-yet-revealed dynamic customers. It is evaluated before
        # ``_update_mask`` marks the just-visited customer as served.
        # NOTE(review): a vehicle whose remaining customers are all infeasible (overtime) is
        # never marked done while its budget stays > 0 -- confirm rollouts terminate.
        pending_customers = (self.served ^ True).float().sum(-1, keepdim=True) - 1
        self.vehicle_done.scatter_(1,
                                   self.current_vehicle_index,
                                   torch.logical_or((pending_customers == 0), (self.current_vehicle[:, :, 2] <= 0)))
        self.done = bool(self.vehicle_done.all())

    def _update_mask(self, customer_index):
        """Mark ``customer_index`` served and mask customers the current vehicle can't serve in time."""
        self.new_customer = False
        self.served.scatter_(1, customer_index, customer_index > 0)

        # Cost for the vehicle to go from where it now stands (column 5, the node it just
        # moved to) to each customer and on to the depot, plus that customer's service duration.
        current_idx = self.current_vehicle[:, :, 5].long()
        dist_vehicle_customer_depot = self.distance_matrix[torch.arange(self.minibatch).unsqueeze(1), current_idx, :].squeeze(1) + self.distance_matrix[:, :, 0]
        cost = dist_vehicle_customer_depot / self.vehicle_speed
        cost += self.nodes[:, :, 2]

        overtime_mask = self.current_vehicle[:, :, None, 2] - cost.unsqueeze(1)
        overtime = torch.zeros_like(self.mask).scatter_(1,
                                                        self.current_vehicle_index[:, :, None].expand(-1, -1,
                                                                                                      self.nodes_count),
                                                        overtime_mask < 0)

        self.mask = self.mask | self.served[:, None, :] | overtime | self.vehicle_done[:, :, None]
        self.mask[:, :, 0] = 0  # depot is always reachable

    def _update_next_vehicle(self, veh_index=None):
        """Select the vehicle that acts next (given, or the non-done one with least elapsed time)."""
        if veh_index is None:
            avail = self.vehicles[:, :, 3].clone()
            avail[self.vehicle_done] = float('inf')
            self.current_vehicle_index = avail.argmin(1, keepdim=True)
        else:
            self.current_vehicle_index = veh_index

        self.current_vehicle = self.vehicles.gather(1, self.current_vehicle_index[:, :, None].expand(-1, -1,
                                                                                                     self.vehicle_feature))
        self.current_vehicle_mask = self.mask.gather(1, self.current_vehicle_index[:, :, None].expand(-1, -1,
                                                                                                      self.nodes_count))

    def _update_dynamic_customers(self):
        """Reveal hidden customers whose appearance time has been reached by the current vehicle."""
        time = self.current_vehicle[:, :, 3].clone()

        reveal_dyn_reqs = torch.logical_and((self.customer_mask), (self.nodes[:, :, 3] <= time))

        if reveal_dyn_reqs.any():
            self.new_customer = True
            self.customer_mask = self.customer_mask ^ reveal_dyn_reqs
            self.mask = self.mask ^ reveal_dyn_reqs[:, None, :].expand(-1, self.vehicle_count, -1)
            # vehicles of an instance with newly revealed customers become active again
            self.vehicle_done = torch.logical_and(self.vehicle_done, (reveal_dyn_reqs.any(1) ^ True).unsqueeze(1))
            self.vehicles[:, :, 3] = torch.max(self.vehicles[:, :, 3], time)
            self._update_next_vehicle()

    def reset(self):
        """Put every vehicle at the depot with a full budget and hide the dynamic customers."""
        # reset vehicle (minibatch*veh_count*veh_feature)
        self.vehicles = self.nodes.new_zeros((self.minibatch, self.vehicle_count, self.vehicle_feature))
        self.vehicles[:, :, :2] = self.nodes[:, :1, :2]
        self.vehicles[:, :, 2] = self.vehicle_time_budget

        # reset vehicle done
        self.vehicle_done = self.nodes.new_zeros((self.minibatch, self.vehicle_count), dtype=torch.bool)
        self.done = False

        # reward accumulator: minus the travelled distance
        self.tour_length = torch.zeros((self.minibatch, 1)).to(self.nodes.device)

        # customers with appearance time > 0 start hidden
        self.customer_mask = self.nodes[:, :, 3] > 0

        # reset new customers and served customer since now to zero (all false)
        self.new_customer = True
        self.served = torch.zeros_like(self.customer_mask)

        # reset mask (minibatch*veh_count*nodes)
        self.mask = self.customer_mask[:, None, :].repeat(1, self.vehicle_count, 1)

        # reset current vehicle index, current vehicle, current vehicle mask
        self.current_vehicle_index = self.nodes.new_zeros((self.minibatch, 1), dtype=torch.int64)

        self.current_vehicle = self.vehicles.gather(1,
                                                    self.current_vehicle_index[:, :, None].expand(-1, -1,
                                                                                                  self.vehicle_feature))
        self.current_vehicle_mask = self.mask.gather(1,
                                                     self.current_vehicle_index[:, :, None].expand(-1, -1,
                                                                                                   self.nodes_count))

    def get_reward(self):
        """Return the (minibatch, 1) episode reward.

        This is minus the tour length, minus ``pending_cost`` per unserved customer once the
        episode is done. It does not modify ``self.tour_length``, so calling it repeatedly
        returns the same value.
        """
        if self.done:
            # unserved nodes minus one for the depot
            pending_customers = (self.served ^ True).float().sum(-1, keepdim=True) - 1
            return self.tour_length - pending_customers * self.pending_cost
        return self.tour_length

    def step(self, customer_index, veh_index=None):
        """Send the current vehicle to ``customer_index`` and advance the environment.

        Args:
            customer_index: (minibatch, 1) int64 tensor of destination node indices.
            veh_index: optional (minibatch, 1) index of the vehicle that acts next. When
                None, the non-done vehicle with the smallest elapsed time is chosen.

        Returns:
            The (minibatch, 1) reward of this move, i.e. minus the travelled distance.
        """
        dest = self.nodes.gather(1, customer_index[:, :, None].expand(-1, -1, self.customer_feature))
        dist, dyn_cust = self._update_current_vehicles(dest, customer_index)

        self._done()
        self._update_mask(customer_index)
        self._update_next_vehicle(veh_index)
        self.tour_length -= dist
        self._update_dynamic_customers()
        return -dist

    def state_dict(self, dest_dict=None):
        """Return (or copy into ``dest_dict``) the tensors needed to restore the episode state."""
        if dest_dict is None:
            dest_dict = {'vehicles': self.vehicles,
                         'vehicle_done': self.vehicle_done,
                         'served': self.served,
                         'mask': self.mask,
                         'current_vehicle_index': self.current_vehicle_index}

        else:
            dest_dict["vehicles"].copy_(self.vehicles)
            dest_dict["vehicle_done"].copy_(self.vehicle_done)
            dest_dict["served"].copy_(self.served)
            dest_dict["mask"].copy_(self.mask)
            dest_dict["current_vehicle_index"].copy_(self.current_vehicle_index)

        return dest_dict

    def load_state_dict(self, state_dict):
        """Restore the tensors saved by ``state_dict`` and rebuild the current-vehicle views."""
        self.vehicles.copy_(state_dict["vehicles"])
        self.vehicle_done.copy_(state_dict["vehicle_done"])
        self.served.copy_(state_dict["served"])
        self.mask.copy_(state_dict["mask"])
        self.current_vehicle_index.copy_(state_dict["current_vehicle_index"])

        self.current_vehicle = self.vehicles.gather(1,
                                                    self.current_vehicle_index[:, :, None].expand(-1, -1,
                                                                                                  self.vehicle_feature))
        # the mask's last axis is the node axis, so expand to nodes_count
        self.current_vehicle_mask = self.mask.gather(1, self.current_vehicle_index[:, :, None].expand(-1, -1,
                                                                                                      self.nodes_count))

In [13]:
env = DVRPSR_Environment(data=train_data)  # fleet settings and tensors all taken from the dataset

In [14]:
# Initialise vehicles, masks and the reward accumulator before taking any step.
env.reset()

In [15]:
env.step(torch.tensor([3, 2, 3, 3, 2, 3, 3, 2, 3, 3]).unsqueeze(1))  # one destination per instance, shape (10, 1)

In [18]:
(~env.served).float().sum(dim=-1, keepdim=True) - 1  # unserved nodes per instance, excluding the depot

tensor([[39.],
        [39.],
        [39.],
        [39.],
        [39.],
        [39.],
        [39.],
        [39.],
        [39.],
        [39.]])

In [61]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphMultiHeadAttention(nn.Module):
    """Multi-head attention whose scores can be modulated by per-edge features.

    Args:
        num_head: number of attention heads.
        query_size: width of the query vectors.
        key_size: width of the key vectors (defaults to ``query_size``).
        value_size: width of the value vectors and of the output (defaults to ``key_size``).
        edge_dim_size: width of the edge features (defaults to ``query_size // 2``).
        bias: whether the projection layers use a bias.
    """

    def __init__(self, num_head, query_size, key_size=None, value_size=None, edge_dim_size=None, bias=False):

        super(GraphMultiHeadAttention, self).__init__()
        self.num_head = num_head
        self.query_size = query_size

        self.key_size = self.query_size if key_size is None else key_size
        self.value_size = self.key_size if value_size is None else value_size
        self.edge_dim_size = self.query_size // 2 if edge_dim_size is None else edge_dim_size

        # NOTE(review): scaled by the full key_size, not the per-head width keys_per_head.
        self.scaling_factor = self.key_size ** -0.5

        self.keys_per_head = self.key_size // self.num_head
        self.values_per_head = self.value_size // self.num_head
        self.edge_size_per_head = self.edge_dim_size // self.num_head

        self.edge_embedding = nn.Linear(self.edge_dim_size, self.num_head * self.edge_size_per_head, bias=bias)
        self.query_embedding = nn.Linear(self.query_size, self.num_head * self.keys_per_head, bias=bias)
        self.key_embedding = nn.Linear(self.key_size, self.num_head * self.keys_per_head, bias=bias)
        self.value_embedding = nn.Linear(self.value_size, self.num_head * self.values_per_head, bias=bias)
        self.recombine = nn.Linear(self.num_head * self.values_per_head, self.value_size, bias=bias)

        # Cached key/value projections filled by ``precompute``. ``forward`` checks them when
        # called without keys/values, so they must exist before any precompute call.
        self.K_project_pre = None
        self.V_project_pre = None

        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init query/key projections; uniform(+-1/sqrt(value_size)) for values."""
        nn.init.xavier_uniform_(self.query_embedding.weight)
        nn.init.xavier_uniform_(self.key_embedding.weight)
        inv_sq_dv = self.value_size ** -0.5
        nn.init.uniform_(self.value_embedding.weight, -inv_sq_dv, inv_sq_dv)

    def precompute(self, keys, values=None):
        """Project and cache ``keys`` (and ``values``, defaulting to ``keys``) for reuse by ``forward``."""
        values = keys if values is None else values

        size_KV = keys.size(-2)

        self.K_project_pre = self.key_embedding(keys).view(
            -1, size_KV, self.num_head, self.keys_per_head).permute(0, 2, 3, 1)

        self.V_project_pre = self.value_embedding(values).view(
            -1, size_KV, self.num_head, self.values_per_head).permute(0, 2, 1, 3)

    def forward(self, queries, keys=None, values=None, edge_attributes=None, mask=None, edge_mask=None):
        """Attend from ``queries`` over keys/values.

        Without ``keys``/``values``, the cached ``precompute`` projections are used when
        available; otherwise this is self-attention over ``queries``.

        Args:
            queries: (..., size_Q, query_size).
            keys, values: optional (..., size_KV, key_size/value_size).
            edge_attributes: optional edge features viewable as
                (batch, size_Q, size_Q, edge_dim_size); requires size_KV == size_Q.
            mask: optional mask of shape (batch, size_Q, size_KV) or (batch, size_KV);
                True entries are excluded from attention.
            edge_mask: unused.

        Returns:
            (..., size_Q, value_size) tensor.
        """
        *batch_size, size_Q, _ = queries.size()

        # get queries projection
        Q_project = self.query_embedding(queries).view(
            -1, size_Q, self.num_head, self.keys_per_head).permute(0, 2, 1, 3)

        # get keys projection
        if keys is None:
            if self.K_project_pre is None:
                size_KV = size_Q
                K_project = self.key_embedding(queries).view(
                    -1, size_KV, self.num_head, self.keys_per_head).permute(0, 2, 3, 1)
            else:
                size_KV = self.K_project_pre.size(-1)
                K_project = self.K_project_pre
        else:
            size_KV = keys.size(-2)
            K_project = self.key_embedding(keys).view(
                -1, size_KV, self.num_head, self.keys_per_head).permute(0, 2, 3, 1)

        # get values projection
        if values is None:
            if self.V_project_pre is None:
                V_project = self.value_embedding(queries).view(
                    -1, size_KV, self.num_head, self.values_per_head).permute(0, 2, 1, 3)
            else:
                V_project = self.V_project_pre
        else:
            V_project = self.value_embedding(values).view(
                -1, size_KV, self.num_head, self.values_per_head).permute(0, 2, 1, 3)

        # calculate the compatibility
        attention = Q_project.matmul(K_project)
        attention *= self.scaling_factor

        # if edge attributes are required
        if edge_attributes is not None:
            # TODO: edge mask (is it required)
            edge_project = self.edge_embedding(edge_attributes).view(
                -1, size_Q, size_Q, self.num_head, self.edge_size_per_head)
            edge_project_expanded = edge_project.mean(-1).permute(0, 3, 1, 2)

            attention = attention * edge_project_expanded

        if mask is not None:

            # a per-query mask has one entry per (query, key) pair; otherwise mask keys only
            if mask.numel() * self.num_head == attention.numel():
                m = mask.view(-1, 1, size_Q, size_KV).expand_as(attention)
            else:
                m = mask.view(-1, 1, 1, size_KV).expand_as(attention)

            attention[m.bool()] = -float('inf')

        attention = F.softmax(attention, dim=-1)
        attention = attention.matmul(V_project).permute(0, 2, 1, 3).contiguous().view(
            *batch_size, size_Q, self.num_head * self.values_per_head)

        output = self.recombine(attention)

        return output




import torch
import torch.nn as nn
import torch.nn.functional as F
from nets import GraphMultiHeadAttention

class GraphEncoderlayer(nn.Module):
    """Graph-attention encoder layer.

    The layer computes (attention + residual -> BatchNorm), then
    (feed-forward + residual -> BatchNorm).

    Args:
        num_head: number of attention heads.
        model_size: width of the node embeddings.
        ff_size: hidden width of the feed-forward sub-layer.
    """

    def __init__(self, num_head, model_size, ff_size):
        super().__init__()

        # Creation order is kept fixed: it determines parameter initialisation order and
        # the state_dict layout.
        self.attention = GraphMultiHeadAttention(num_head, query_size=model_size)
        self.BN1 = nn.BatchNorm1d(model_size)
        self.FFN_layer1 = nn.Linear(model_size, ff_size)
        self.FFN_layer2 = nn.Linear(ff_size, model_size)
        self.BN2 = nn.BatchNorm1d(model_size)

    @staticmethod
    def _channel_norm(norm, x):
        """Apply a BatchNorm1d over the last axis of a 3-D (batch, items, features) tensor."""
        return norm(x.transpose(1, 2)).transpose(1, 2)

    def forward(self, h, e=None, mask=None):
        """Encode node embeddings ``h`` with optional edge features ``e``.

        Rows selected by ``mask`` are zeroed in the output.
        """
        attended = self.attention(h, edge_attributes=e, mask=mask)
        normed = self._channel_norm(self.BN1, attended + h)

        hidden = F.relu(self.FFN_layer1(normed))
        out = self._channel_norm(self.BN2, normed + self.FFN_layer2(hidden))

        if mask is not None:
            # zero the embeddings of masked nodes
            out[mask] = 0

        return out


class GraphEncoder(nn.Module):
    """Stack of ``encoder_layer`` GraphEncoderlayer modules applied in sequence.

    Layers are registered as children named "0", "1", ...; all of them receive the same
    edge features and mask.
    """

    def __init__(self, encoder_layer, num_head, model_size, ff_size):
        super().__init__()

        for layer_idx in range(encoder_layer):
            self.add_module(f"{layer_idx}", GraphEncoderlayer(num_head, model_size, ff_size))

    def forward(self, h_in, e_in=None, mask=None):
        """Run ``h_in`` through every layer and return the final node embeddings."""
        hidden = h_in
        for layer in self.children():
            hidden = layer(hidden, e_in, mask=mask)
        return hidden

    
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from nets import GraphMultiHeadAttention
from nets.Encoder import GraphEncoder

class GraphAttentionModel(nn.Module):
    """Attention policy that picks the next node for the environment's current vehicle.

    Customers are encoded with a graph encoder over node and edge features. The fleet is
    attended against the customer encodings, and the current vehicle's representation
    scores every node.
    """

    def __init__(self, customer_feature, vehicle_feature, model_size=128, encoder_layer=3,
                 num_head=8, ff_size=128, tanh_xplor=10, edge_embedding_dim=64, greedy=False):
        super().__init__()

        # get models parameters for encoding-decoding
        self.model_size = model_size
        # NOTE(review): compatibilities are *multiplied* by sqrt(model_size) in
        # score_customers; scaled dot-product attention usually divides -- confirm intent.
        self.scaling_factor = self.model_size ** 0.5
        self.tanh_xplor = tanh_xplor
        self.greedy = greedy

        # Initialize encoder and embeddings (layer and head counts follow the arguments)
        self.customer_encoder = GraphEncoder(encoder_layer=encoder_layer, num_head=num_head,
                                             model_size=model_size, ff_size=ff_size)
        self.customer_embedding = nn.Linear(customer_feature, model_size)
        self.depot_embedding = nn.Linear(customer_feature, model_size)

        # Edge embedding. NOTE(review): assuming nets.Encoder matches the GraphEncoderlayer
        # in this notebook, its attention expects edge features of width model_size // 2,
        # so edge_embedding_dim must equal that (true for the defaults 128 / 64).
        self.edge_embedding = nn.Linear(1, edge_embedding_dim)

        # Fleet (vehicle-to-customer) and vehicle-to-fleet attention
        self.fleet_attention = GraphMultiHeadAttention(num_head, vehicle_feature, model_size)

        self.vehicle_attention = GraphMultiHeadAttention(num_head, model_size)

        # customer projection
        self.customer_projection = nn.Linear(self.model_size, self.model_size)  # TODO: MLP instaed of nn.Linear

    def encode_customers(self, env, customer_mask=None):
        """Encode ``env.nodes`` and cache the results on the model.

        Sets ``customer_encoding`` and ``customer_representation`` and precomputes the
        fleet attention keys. Nodes selected by ``customer_mask`` are zeroed.
        """
        customer_emb = torch.cat((self.depot_embedding(env.nodes[:, :1, :]),
                                  self.customer_embedding(env.nodes[:, 1:, :])), dim=1)
        if customer_mask is not None:
            customer_emb[customer_mask] = 0

        edge_emb = self.edge_embedding(env.edge_attributes)

        self.customer_encoding = self.customer_encoder(customer_emb, edge_emb, mask=customer_mask)

        self.fleet_attention.precompute(self.customer_encoding)

        self.customer_representation = self.customer_projection(self.customer_encoding)
        if customer_mask is not None:
            self.customer_representation[customer_mask] = 0

    def vehicle_representation(self, vehicles, vehicle_index, vehicle_mask=None):
        """Return the (minibatch, 1, model_size) representation of the vehicle at ``vehicle_index``."""
        fleet_representation = self.fleet_attention(vehicles, mask=vehicle_mask)
        # pick each instance's current vehicle along the vehicle axis (dim 1)
        vehicle_query = fleet_representation.gather(1, vehicle_index.unsqueeze(2).expand(-1, -1, self.model_size))

        return self.vehicle_attention(vehicle_query, fleet_representation, fleet_representation)

    def score_customers(self, vehicle_representation):
        """Return compatibility scores between the vehicle representation and every node."""
        compact = torch.bmm(vehicle_representation,
                            self.customer_representation.transpose(2, 1))
        compact *= self.scaling_factor

        if self.tanh_xplor is not None:
            compact = self.tanh_xplor * compact.tanh()

        return compact

    def get_prop(self, compact, vehicle_mask=None):
        """Softmax ``compact`` over nodes, excluding positions where ``vehicle_mask`` is True.

        ``vehicle_mask`` has the same shape as ``compact`` (minibatch, 1, nodes).
        """
        if vehicle_mask is not None:
            compact = compact.masked_fill(vehicle_mask.bool(), float('-inf'))
        compact = F.softmax(compact, dim=-1)
        return compact

    def step(self, env, old_action=None):
        """Score nodes for the current vehicle and choose (or re-evaluate) one action.

        Without ``old_action``, returns ``(customer_index, logp)``. The index is the
        arg-max when ``self.greedy`` and a sample otherwise.

        With ``old_action`` (columns [vehicle index, customer index]), returns
        ``(customer_index, entropy, logp)`` for the stored customer choice.

        Entropy and log-probabilities are zeroed once ``env.done`` is set.
        """
        _vehicle_representation = self.vehicle_representation(env.vehicles,
                                                              env.current_vehicle_index,
                                                              env.current_vehicle_mask)

        compact = self.score_customers(_vehicle_representation)
        prop = self.get_prop(compact, env.current_vehicle_mask)

        # step actions based on model act or evaluate
        if old_action is not None:

            dist = Categorical(prop)
            old_actions_logp = dist.log_prob(old_action[:, 1].unsqueeze(-1))
            entropy = dist.entropy()

            is_done = float(env.done)

            entropy *= (1. - is_done)
            old_actions_logp *= (1. - is_done)
            return old_action[:, 1].unsqueeze(-1), entropy, old_actions_logp

        dist = Categorical(prop)

        if self.greedy:
            _, customer_index = prop.max(dim=-1)
        else:
            customer_index = dist.sample()

        is_done = float(env.done)

        logp = dist.log_prob(customer_index)
        logp *= (1. - is_done)

        return customer_index, logp

    def forward(self, env, old_actions=None, is_update=False):
        """Roll out an episode, or re-evaluate stored actions when ``is_update`` is True.

        Rollout mode returns ``(actions, logp, rewards)``:
            - ``actions``: a list of (vehicle index, customer index) pairs;
            - ``logp``: the summed log-probability per instance;
            - ``rewards``: ``env.get_reward()`` flattened to (minibatch,).

        Update mode replays ``old_actions`` (steps, minibatch, 2) and returns
        ``(entropy, logp, 0)``. ``entropy`` is the mean over non-zero step entropies.
        """
        if is_update:
            env.reset()
            entropys, old_actions_logps = [], []

            steps = old_actions.size(0)

            for i in range(steps):
                if env.new_customer:
                    self.encode_customers(env, env.customer_mask)

                old_action = old_actions[i, :, :]
                next_action = old_actions[i + 1, :, :] if i < steps - 1 else old_action

                # the stored vehicle index of the next step drives the env's vehicle choice
                next_vehicle_index = next_action[:, 0].unsqueeze(-1)

                customer_index, entropy, logp = self.step(env, old_action)

                env.step(customer_index, next_vehicle_index)

                old_actions_logps.append(logp)
                entropys.append(entropy)

            entropys = torch.cat(entropys, dim=1)
            # guard against division by zero when every step's entropy is 0
            num_e = entropys.ne(0).float().sum(1).clamp(min=1)
            entropy = entropys.sum(1) / num_e

            old_actions_logps = torch.cat(old_actions_logps, dim=1)
            old_actions_logps = old_actions_logps.sum(1)

            return entropy, old_actions_logps, 0

        env.reset()
        actions, logps = [], []

        while not env.done:
            if env.new_customer:
                self.encode_customers(env, env.customer_mask)

            customer_index, logp = self.step(env)
            actions.append((env.current_vehicle_index, customer_index))
            logps.append(logp)
            env.step(customer_index)

        logps = torch.cat(logps, dim=1)
        logp = logps.sum(dim=1)

        # The environment accumulates the episode reward itself (tour length plus
        # pending-customer penalty), so read it once at the end.
        rewards = env.get_reward().squeeze(-1)

        return actions, logp, rewards





In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

class Critic(nn.Module):
    """Value head built on top of the (shared) actor model.

    It maps the current vehicle's masked compatibility scores through a two-layer MLP
    and reads off the value at a sampled node.

    Args:
        model: the actor; must provide ``encode_customers``, ``vehicle_representation``,
            ``score_customers`` and ``get_prop``.
        customers_count: number of nodes (width of the compatibility vector).
        ff_size: hidden width of the MLP.
    """

    def __init__(self, model, customers_count, ff_size=512):
        super(Critic, self).__init__()

        self.model = model
        self.ff_layer1 = nn.Linear(customers_count, ff_size)
        self.ff_layer2 = nn.Linear(ff_size, customers_count)

    def eval_step(self, env, compatibility, customer_index):
        """Return the (minibatch, 1) MLP value at ``customer_index``; masked scores are zeroed."""
        compact = compatibility.clone()
        compact[env.current_vehicle_mask] = 0

        value = self.ff_layer1(compact)
        value = F.relu(value)
        value = self.ff_layer2(value)

        val = value.gather(2, customer_index.unsqueeze(1)).expand(-1, 1, -1)
        return val.squeeze(1)

    def forward(self, env):
        """Encode ``env``, reset it, and return the value of the first decision.

        The original loop never advanced ``env`` (so it never terminated) and only its
        first value was ever returned; that value is computed directly here.
        """
        self.model.encode_customers(env)
        env.reset()

        vehicle_presentation = self.model.vehicle_representation(env.vehicles,
                                                                 env.current_vehicle_index,
                                                                 env.current_vehicle_mask)
        compatibility = self.model.score_customers(vehicle_presentation)
        prop = self.model.get_prop(compatibility, env.current_vehicle_mask)
        customer_index = Categorical(prop).sample()

        return self.eval_step(env, compatibility, customer_index)




In [None]:
import torch
import torch.nn as nn
from nets import GraphAttentionModel
from agents.Critic import Critic


class Actor_Critic(nn.Module):
    """Pair a GraphAttentionModel actor with a Critic that shares the actor's weights."""

    def __init__(self,
                 customer_feature,
                 vehicle_feature,
                 customers_count,
                 model_size=128,
                 encoder_layer=3,
                 num_head=8,
                 ff_size_actor=128,
                 ff_size_critic=512,
                 tanh_xplor=10,
                 edge_embedding_dim=64,
                 greedy=False):
        super(Actor_Critic, self).__init__()

        model = GraphAttentionModel(customer_feature, vehicle_feature, model_size, encoder_layer,
                                    num_head, ff_size_actor, tanh_xplor, edge_embedding_dim, greedy)
        self.actor = model

        self.critic = Critic(model, customers_count, ff_size_critic)

    def act(self, env, old_actions=None, is_update=False):
        """Run the actor on ``env``, forwarding ``old_actions`` and ``is_update``.

        With the defaults this rolls out an episode and returns
        ``(actions, logps, rewards)``. With ``is_update=True`` the actor replays
        ``old_actions``, and its 3-tuple is ``(entropy, old_logps, 0)``.
        """
        actions, logps, rewards = self.actor(env, old_actions, is_update)
        return actions, logps, rewards

    def evaluate(self, env, old_actions, is_update):
        """Return ``(entropy, old_logps, values)`` for replayed ``old_actions``.

        The critic runs first; both it and the actor reset ``env`` themselves.
        """
        values = self.critic(env)
        entropys, old_logps, _ = self.actor(env, old_actions, is_update)
        return entropys, old_logps, values


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch.optim.lr_scheduler import LambdaLR
import time
from torch.nn.utils import clip_grad_norm_

from agents.Actor_Critic import Actor_Critic
from problems import DVRPSR_Environment


class AgentPPO:
    """Proximal Policy Optimization trainer around two ``Actor_Critic`` networks.

    ``policy`` is optimised by :meth:`update`. ``old_policy`` receives a copy of the
    policy weights at construction and at the end of every update.
    """

    def __init__(self,
                 customer_feature,
                 vehicle_feature,
                 customers_count,
                 model_size=128,
                 encoder_layer=3,
                 num_head=8,
                 ff_size_actor=128,
                 ff_size_critic=128,
                 tanh_xplor=10,
                 edge_embedding_dim=64,
                 greedy=False,
                 learning_rate=3e-4,
                 ppo_epoch=3,
                 batch_size=128,
                 entropy_value=0.2,
                 epsilon_clip=0.2,
                 max_grad_norm = 2):

        self.policy = Actor_Critic(customer_feature, vehicle_feature, customers_count, model_size,
                                   encoder_layer, num_head, ff_size_actor, ff_size_critic,
                                   tanh_xplor, edge_embedding_dim, greedy)

        self.old_policy = Actor_Critic(customer_feature, vehicle_feature, customers_count, model_size,
                                       encoder_layer, num_head, ff_size_actor, ff_size_critic,
                                       tanh_xplor, edge_embedding_dim, greedy)

        self.old_policy.load_state_dict(self.policy.state_dict())

        # ppo update parameters (learning_rate is accepted but the optimizer is external)
        self.ppo_epoch = ppo_epoch
        self.batch_size = batch_size
        self.entropy_value = entropy_value
        self.epsilon_clip = epsilon_clip
        self.batch_index = 1

        self.MSE_loss = nn.MSELoss()

        # actor-critic parameters
        self.customer_feature = customer_feature
        self.vehicle_feature = vehicle_feature
        self.customers_count = customers_count
        self.model_size = model_size
        self.encoder_layer = encoder_layer
        self.num_head = num_head
        self.ff_size_actor = ff_size_actor
        self.ff_size_critic = ff_size_critic
        self.tanh_xplor = tanh_xplor
        self.edge_embedding_dim = edge_embedding_dim
        self.greedy = greedy
        self.max_grad_norm = max_grad_norm

        self.times, self.losses, self.rewards, self.critic_rewards = [], [], [], []

    def advantage_normalization(self, advantage):
        """Return ``advantage`` standardised to zero mean and unit standard deviation.

        Raises:
            AssertionError: if the standard deviation is zero or NaN (e.g. one element).
        """
        std = advantage.std()

        assert std != 0. and not torch.isnan(std), 'Need nonzero std'

        return (advantage - advantage.mean()) / (std + 1e-8)

    def pad_actions(self, actions):
        """Right-pad a list of (steps, k) tensors with zeros to a common length.

        Returns a CPU tensor of shape (len(actions), max_steps, k).
        """
        max_len = max(a.size(0) for a in actions)
        padded_actions = torch.zeros(len(actions), max_len, actions[0].size(1), dtype=actions[0].dtype)
        for i, a in enumerate(actions):
            length = a.size(0)
            padded_actions[i, :length] = a
        return padded_actions

    def _prepare_minibatch(self, minibatch_data, device):
        """Reshape a collated minibatch into per-sample tensors on ``device``.

        Returns ``(nodes, edge_attributes, old_actions)`` with shapes
        ``(B, customers_count, customer_feature)``, ``(B, customers_count ** 2, 1)`` and
        ``(steps, B, 2)``, where ``B`` is the actual number of samples in this minibatch.
        """
        nodes = minibatch_data.nodes.to(device)
        edge_attributes = minibatch_data.edge_attributes.to(device)

        # The final minibatch may be smaller than self.batch_size, so derive B from the data.
        batch = nodes.numel() // (self.customers_count * self.customer_feature)
        nodes = nodes.view(batch, self.customers_count, self.customer_feature)
        edge_attributes = edge_attributes.view(batch, self.customers_count * self.customers_count, 1)

        # each stored action is a (vehicle index, customer index) pair
        steps = minibatch_data.actions.numel() // (batch * 2)
        old_actions = minibatch_data.actions.view(batch, steps, 2).permute(1, 0, 2).to(device)
        return nodes, edge_attributes, old_actions

    def update(self, memory, epoch, data=None, env=None, optim=None, lr_scheduler=None, device=None):
        """Run ``ppo_epoch`` passes of clipped-PPO updates over the rollouts in ``memory``.

        Args:
            memory: rollout store with ``nodes``, ``edge_attributes``, ``rewards``,
                ``log_probs`` and ``actions`` lists (one entry per sample).
            epoch: unused.
            data: dataset object forwarded to the environment constructor.
            env: environment class; defaults to ``DVRPSR_Environment``.
            optim: optimizer over ``self.policy`` parameters (required).
            lr_scheduler: optional scheduler stepped after every minibatch.
            device: device the policy and minibatches are moved to.

        Returns:
            ``(rewards, losses, critic_rewards)``: per-minibatch means from the last PPO epoch.

        Raises:
            NotImplementedError: if ``data`` carries a non-None ``customer_mask``.
        """
        old_nodes = torch.stack(memory.nodes)
        old_edge_attributes = torch.stack(memory.edge_attributes)
        old_rewards = torch.stack(memory.rewards).unsqueeze(-1)
        old_log_probs = torch.stack(memory.log_probs).unsqueeze(-1)
        padded_actions = torch.stack(memory.actions)

        # Create update data for PPO
        datas = [Data(nodes=nodes,
                      edge_attributes=edge_attributes,
                      actions=actions,
                      rewards=rewards,
                      log_probs=log_probs)
                 for nodes, edge_attributes, actions, rewards, log_probs in zip(old_nodes, old_edge_attributes,
                                                                                padded_actions, old_rewards,
                                                                                old_log_probs)]

        self.policy.to(device)
        self.policy.train()

        data_loader = DataLoader(datas, batch_size=self.batch_size, shuffle=False)

        env = env if env is not None else DVRPSR_Environment

        # Only the unmasked path is implemented.
        customer_mask = None if data is None else getattr(data, 'customer_mask', None)
        if customer_mask is not None:
            raise NotImplementedError('PPO update with a customer_mask is not supported')

        for _ in range(self.ppo_epoch):

            self.times, self.losses, self.rewards, self.critic_rewards = [], [], [], []

            for minibatch_data in data_loader:

                self.batch_index += 1

                nodes, edge_attributes, old_actions_for_env = self._prepare_minibatch(minibatch_data, device)

                # NOTE(review): positional call. For the DVRPSR_Environment defined in this
                # notebook, the 3rd parameter is ``edges_attributes`` and the 4th is
                # ``vehicle_count``. Confirm the signature of problems.DVRPSR_Environment.
                dyna_env = env(data, nodes, customer_mask, edge_attributes)

                # Gradients must flow through log-probs, values and entropy (no torch.no_grad()).
                entropy, log_probs, values = self.policy.evaluate(dyna_env, old_actions_for_env, True)

                # One scalar per sample, flattened to (B,) so nothing broadcasts to (B, B).
                entropy = entropy.reshape(-1)
                log_probs = log_probs.reshape(-1)
                values = values.reshape(-1)

                # Rollout quantities are constants for this update.
                R = minibatch_data.rewards.detach().to(device).reshape(-1)
                old_lp = minibatch_data.log_probs.detach().to(device).reshape(-1)

                # normalize the rewards and get the MSE loss with critics values
                R_norm = self.advantage_normalization(R)
                mse_loss = self.MSE_loss(R_norm, values)

                # PPO ratio r_t(theta)
                ratio = torch.exp(log_probs - old_lp)

                # PPO advantage
                advantage = R_norm - values.detach()

                # clipped surrogate; PPO maximises it, so the loss is its negation
                surrogate1 = ratio * advantage
                surrogate2 = torch.clamp(ratio, 1 - self.epsilon_clip, 1 + self.epsilon_clip) * advantage
                actor_loss = -torch.min(surrogate1, surrogate2)

                # total loss
                loss = actor_loss + 0.5 * mse_loss - self.entropy_value * entropy

                # optimizer and backpropagation
                optim.zero_grad()
                loss.mean().backward()
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                optim.step()

                if lr_scheduler is not None:
                    lr_scheduler.step()

                self.rewards.append(torch.mean(R_norm.detach()).item())
                self.losses.append(torch.mean(loss.detach()).item())
                self.critic_rewards.append(torch.mean(values.detach()).item())

        self.old_policy.load_state_dict(self.policy.state_dict())

        return self.rewards, self.losses, self.critic_rewards

import os
import time
import torch
from torch.utils.data import Dataset, DataLoader

from collections import OrderedDict
from collections import namedtuple
from itertools import product
import numpy as np
from agents import AgentPPO
from utils import Memory
from utils.Misc import formate_old_actions
import tqdm
from tqdm import tqdm

class TrainPPOAgent:
    """Train and evaluate an ``AgentPPO`` on DVRPSR instances.

    Rollouts are sampled with the agent's ``old_policy``. The trajectories are
    buffered in a ``Memory`` object, and every ``timestep`` minibatches a PPO
    update is delegated to ``AgentPPO.update``.
    """

    def __init__(self,
                 customer_feature,
                 vehicle_feature,
                 customers_count,
                 model_size=128,
                 encoder_layer=3,
                 num_head=8,
                 ff_size_actor=128,
                 ff_size_critic=128,
                 tanh_xplor=10,
                 edge_embedding_dim=64,
                 greedy=False,
                 learning_rate=3e-4,
                 ppo_epoch=3,
                 batch_size=128,
                 entropy_value=0.2,
                 epsilon_clip=0.2,
                 epoch=50,
                 timestep=2,
                 max_grad_norm=2):
        """Build the trainer and the underlying ``AgentPPO``.

        Most arguments are forwarded unchanged to ``AgentPPO``. The ones used
        directly here are:

        Args:
            customers_count: Number of nodes per instance. Used to reshape
                ``nodes`` to ``(batch, customers_count, 4)`` and edges to
                ``(batch, customers_count ** 2, 1)``.
            batch_size: DataLoader batch size (also forwarded to the agent).
            epoch: Stored as ``self.epoch``.
            timestep: Number of minibatches collected between PPO updates.
        """
        self.greedy = greedy
        self.memory = Memory()
        self.batch_size = batch_size
        self.customers_count = customers_count
        self.update_timestep = timestep
        self.epoch = epoch
        self.agent = AgentPPO(customer_feature, vehicle_feature, customers_count, model_size,
                              encoder_layer, num_head, ff_size_actor, ff_size_critic,
                              tanh_xplor, edge_embedding_dim, greedy, learning_rate,
                              ppo_epoch, batch_size, entropy_value, epsilon_clip, max_grad_norm)

    def run_train(self, args, datas, env, env_params, optim, lr_scheduler, device, epoch):
        """Run one training epoch over ``datas``.

        Args:
            args: Run configuration. ``args.epoch_count`` is read for the
                progress-bar label.
            datas: Dataset whose minibatches are ``(nodes, edge_attributes)``.
                Its ``customer_mask`` attribute must be ``None``.
            env: Environment factory, called as
                ``env(None, nodes, customer_mask, edge_attributes, *env_params)``.
            env_params: Extra positional arguments forwarded to ``env``.
            optim: Optimizer passed through to ``AgentPPO.update``.
            lr_scheduler: Scheduler passed through to ``AgentPPO.update``.
            device: Device on which rollouts are computed.
            epoch: Zero-based epoch index (display and update bookkeeping).

        Returns:
            Tuple ``(loss, prob, val, c_val)`` with each statistic averaged
            over the minibatches processed in this epoch.

        Raises:
            NotImplementedError: If ``datas.customer_mask`` is not ``None``.
        """
        if datas.customer_mask is not None:
            # The minibatch layout for masked datasets is not handled here.
            raise NotImplementedError(
                "run_train does not support datasets with a customer_mask")

        train_data_loader = DataLoader(datas, batch_size=self.batch_size, shuffle=True)

        memory = Memory()
        self.agent.old_policy.to(device)
        self.agent.old_policy.train()

        epoch_loss = 0.0
        epoch_prop = 0.0
        epoch_val = 0.0
        epoch_c_val = 0.0
        batch_count = 0

        # Statistics returned by the most recent PPO update. Until the first
        # update has run, these default to 0 so the progress bar and the epoch
        # sums stay defined.
        u_losses, u_critic_rewards = [0.0], [0.0]

        cpu = torch.device('cpu')
        desc = "Epoch #{: >3d}/{: <3d}".format(epoch + 1, args.epoch_count)

        with tqdm(train_data_loader, desc=desc) as progress:

            for batch_index, minibatch in enumerate(progress):

                nodes = minibatch[0].to(device)
                edge_attributes = minibatch[1].to(device)
                customer_mask = None

                # The DataLoader keeps the last, possibly smaller, batch, so
                # reshape with the actual batch size rather than self.batch_size.
                current_batch = nodes.size(0)
                nodes = nodes.view(current_batch, self.customers_count, 4)
                edge_attributes = edge_attributes.view(
                    current_batch, self.customers_count * self.customers_count, 1)

                dyna_env = env(None, nodes, customer_mask, edge_attributes, *env_params)

                actions, logps, rewards = self.agent.old_policy.act(dyna_env)

                actions = actions.to(cpu).detach()
                logps = logps.to(cpu).detach()
                rewards = rewards.to(cpu).detach()

                # Re-format the sampled actions into the layout kept in memory.
                actions = torch.tensor(formate_old_actions(actions)).permute(1, 0, 2)

                memory.nodes.extend(minibatch[0])
                memory.edge_attributes.extend(minibatch[1])
                memory.rewards.extend(rewards)
                memory.log_probs.extend(logps)
                memory.actions.extend(actions)

                if (batch_index + 1) % self.update_timestep == 0:
                    _, u_losses, u_critic_rewards = self.agent.update(
                        memory, epoch, datas, env, optim, lr_scheduler, device)
                    memory.clear()

                prob = logps.exp().mean()
                val = rewards.mean()
                c_val = torch.tensor(u_critic_rewards).mean()
                loss_mean = torch.tensor(u_losses).mean()

                progress.set_postfix_str("l={:.4g} p={:9.4g} val={:6.4g} c_val={:6.4g}".format(
                    loss_mean.item(), prob.item(), val.item(), c_val.item()))

                epoch_loss += loss_mean.item()
                epoch_prop += prob.item()
                epoch_val += val.item()
                epoch_c_val += c_val.item()
                batch_count += 1

        # NOTE: `memory` is local to this call, so trajectories gathered after
        # the last full update window are discarded when the epoch ends.
        denominator = max(batch_count, 1)
        return tuple(stats / denominator
                     for stats in (epoch_loss, epoch_prop, epoch_val, epoch_c_val))

    def test_epoch(self, args, env, agent, ref_costs, n_samples=100):
        """Estimate the agent's mean cost on ``env`` and its gap to reference costs.

        The agent is switched to eval mode and rolled out ``n_samples`` times
        with gradient tracking disabled. Per-instance costs are averaged over
        the rollouts.

        Args:
            args: Unused; kept for interface compatibility.
            env: Environment exposing ``nodes`` (a tensor) and ``minibatch``.
            agent: Policy with ``eval()`` and ``act(env)`` returning
                ``(actions, logps, rewards)``.
            ref_costs: Reference cost per instance.
            n_samples: Number of rollouts to average (default 100).

        Returns:
            ``(mean, std, gap)`` as Python floats, where ``gap`` is the mean of
            ``cost / ref_cost - 1``.

        Raises:
            ValueError: If ``n_samples`` is less than 1.
        """
        if n_samples < 1:
            raise ValueError("n_samples must be >= 1, got {}".format(n_samples))

        agent.eval()
        costs = env.nodes.new_zeros(env.minibatch)

        # Evaluation only: skip building autograd graphs for every rollout.
        with torch.no_grad():
            for _ in range(n_samples):
                _, _, rewards = agent.act(env)
                costs += rewards.squeeze(-1)

        costs = costs / n_samples

        mean = costs.mean()
        std = costs.std()
        gap = (costs.to(ref_costs.device) / ref_costs - 1).mean()

        print("Cost on test dataset: {:5.2f} +- {:5.2f} ({:.2%})".format(mean, std, gap))
        return mean.item(), std.item(), gap.item()

