In [3]:
import torch
import random
import math
from torch.utils.data import Dataset

import pandas as pd
import numpy as np
import networkx as nx
import sys


In [146]:
class DVRPSR_Dataset(Dataset):

    customer_feature = 4 # customer features location (x_i,y_i) and duration of service(d), appearance (u)

    @classmethod
    def create_data(cls,
                    batch_size = 2,
                    vehicle_count = 2,
                    vehicle_speed = 20, # km/hr
                    Lambda = 0.025, # request rate per min
                    dod = 0.5,
                    horizon = 400,
                    fDmean = 10,
                    fDstd = 2.5):


        # static customer counts V = Lambda*horizon*(1-dod)/(dod+0.5)
        V_static = int(Lambda*horizon*(1-dod)/(dod)+0.5)

        # total customer count
        V = int(Lambda*horizon/(dod) + 0.5)

        size = (batch_size, V, 1)
        
        # initialize the graph of vienna network
        graph = cls.initialize_graph()

        # get the coordinates of customers
        data_vienna = pd.read_csv('./vienna_data/vienna_cordinates.csv')

        # get depot coordinates: Id, xcoords, ycoords
        depot = cls.get_depot_location(data_vienna)

        # get location of customers: id, xcoords, ycoords
        locations = cls.get_customers_coordinates(data_vienna, batch_size, V, depot)

        # get edges index and attributes, which is distance between one node to others n_i*n_j
        edges_index, edges_attributes = cls.get_edges_attributes(batch_size, graph, depot, locations, V)
        
        ### generate Static_Dynamic customer requests
        dynamic_request = cls.generateRandomDynamicRequests(batch_size,
                                                            V,
                                                            V_static,
                                                            fDmean,
                                                            fDstd,
                                                            Lambda,
                                                            horizon)

        customers = torch.zeros((batch_size,V,cls.customer_feature))
        customers[:,:,:2] = locations[:,:,1:]
        customers[:,:,2:4] = dynamic_request



        depo = torch.zeros((batch_size, 1, cls.customer_feature))
        depo[:,:,0:2] = torch.from_numpy(depot[0][1:])
        depo[:,:,2] =  0

        nodes = torch.cat((depo, customers), 1)
        
        dataset = cls(vehicle_count, vehicle_speed, horizon, nodes, V, 
                      edges_index, edges_attributes, customer_mask = None)
        
        return dataset
    
    def __init__(self, vehicle_count, vehicle_speed, horizon, nodes, V,
                 edges_index, edges_attributes, customer_mask=None):
        
        self.vehicle_count = vehicle_count
        self.vehicle_speed = vehicle_speed
        self.nodes = nodes
        self.vehicle_time_budget = horizon
        self.edges_index = edges_index
        self.edges_attributes = edges_attributes

        self.batch_size, self.nodes_count, d = self.nodes.size()

        if d!= self.customer_feature:
            raise ValueError("Expected {} customer features per nodes, got {}".format(
                self.customer_feature, d))

        self.customer_mask = customer_mask
        self.customer_count = V
        
        
    def initialize_graph():
    
        coordinates = pd.read_csv("./vienna_data/vienna_dist.csv", header = None, sep=' ')
        coordinates.columns = ['coord1','coord2','dist']
        graph = nx.DiGraph()

        # add the rows to the graph for shortest path and distance calculations
        for _, row in coordinates.iterrows():
            graph.add_edge(row['coord1'], row['coord2'], weight=row['dist'])

        return graph


    def precompute_shortest_path(graph, start_node, end_node):

        shortest_path = nx.shortest_path(graph, start_node, end_node)

        # TODO: distance need to be normalized afterwords
        shortest_path_length = sum(graph.get_edge_data(u, v)['weight'] 
                                   for u, v in zip(shortest_path, shortest_path[1:]))

        return shortest_path, shortest_path_length 
    
    
    def get_distanceLL(lat1, lon1, lat2, lon2):
    
        R = 6371  # Radius of the Earth in kilometers

        lat1_rad = math.radians(lat1)
        lon1_rad = math.radians(lon1)
        lat2_rad = math.radians(lat2)
        lon2_rad = math.radians(lon2)

        dlat = lat2_rad - lat1_rad
        dlon = lon2_rad - lon1_rad

        a = math.sin(dlat / 2) ** 2 + math.cos(lat1_rad) * math.cos(lat2_rad) * math.sin(dlon / 2) ** 2
        c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
        distance = R * c
        return distance
    

    def get_NearestNodeLL(lat, lon, lats, lons):
        nearest = (-1, sys.float_info.max)
        for i in range(len(lats)):
            dist = DVRPSR_Dataset.get_distanceLL(lat, lon, lats[i], lons[i])
            if dist < nearest[1]:
                nearest = (i, dist)
        return nearest[0]
    


    def get_depot_location(data_vienna):

        ll = (48.178808, 16.438460)
        lat = ll[0] / 180 * math.pi
        lon = ll[1] / 180 * math.pi
        lats = data_vienna['lats']
        lons = data_vienna['lons']
        depot = DVRPSR_Dataset.get_NearestNodeLL(lat, lon, lats, lons)
        depot_coordinates = np.array(data_vienna[data_vienna['id']==depot][['id','xcoords', 'ycoords']])

        return depot_coordinates
    
    def get_customers_coordinates(data_vienna, batch_size, customers_count, depot):
        
        torch.manual_seed(42)

        # Excluding depot id from the customers selection
        data_vienna_without_depot = data_vienna[data_vienna['id'] != int(depot[0][0])].reset_index()

        # Sample customers indices for all batches at once
        sampled_customers = torch.multinomial(torch.tensor(data_vienna_without_depot['id'], dtype=torch.float32),
                                              num_samples=batch_size * customers_count, replacement=True)
        
        sampled_customers = sampled_customers.reshape(batch_size, customers_count)

        # Gather the sampled locations using the indices
        sampled_locations = data_vienna_without_depot.loc[sampled_customers.flatten()].reset_index(drop=True)

        # Reshape the locations to match the batch size
        locations = sampled_locations.groupby(sampled_locations.index // customers_count)

        # Create PyTorch tensors for the batched data
        locations_tensors = []
        for _, batch in locations:
            id_tensor = torch.tensor(batch['id'].values, dtype=torch.long)
            coords_tensor = torch.tensor(batch[['xcoords', 'ycoords']].values, dtype=torch.float32)
            batch_tensor = torch.cat((id_tensor.unsqueeze(1), coords_tensor), dim=1)
            locations_tensors.append(batch_tensor)

        return torch.stack(locations_tensors)
    
    def c_dist(x1,x2):
        return ((x1[0]-x2[0])**2+(x1[1]-x2[1])**2)**0.5
    
    def get_edges_attributes(batch_size, graph, depot, locations, V):
    
        # all customers ID inclusing depot
        
        print('Initialzing edges')
        edge_depot = torch.zeros((batch_size, 1, 2))
        edge_depot[:,:,0] = depot[0][1]
        edge_depot[:,:,1] = depot[0][2]
        edge_data = torch.cat((edge_depot, locations[:,:,1:3]), dim=1)

        # generate edge index
        edges_index = []

        for i in range(V+1):
            for j in range(V+1):
                edges_index.append([i, j])
        edges_index = torch.LongTensor(edges_index)
        edges_index = edges_index.transpose(dim0=0,dim1=1)

        # generate nodes attributes
        edges_batch = []

        for batch in edge_data:
            edges = torch.zeros((V+1, V+1, 1), dtype=torch.float32)
            for i, node1 in enumerate(batch):
                for j, node2 in enumerate(batch):
                    distance = DVRPSR_Dataset.c_dist(node1, node2)
                    edges[i][j][0] = distance

            edges = edges.reshape(-1, 1)
            edges_batch.append(edges)

        return edges_index, torch.stack(edges_batch)
    
    
    def generateRandomDynamicRequests(batch_size=2 ,
                                      V=20,
                                      V_static=10,
                                      fDmean=10,
                                      fDstd=2.5,
                                      Lambda=0.025,
                                      horizon=400,
                                      dep = 0,
                                      u = 0):
        gen = random.Random()
        gen.seed() # uses the default system seed
        unifDist = gen.random # uniform distribution
        durDist = lambda: max(0.01, gen.gauss(fDmean, fDstd)) # normal distribution with fDmean and fDstd

        # TODO: in actual data , we need to add a depo node with corrdinate, which should be removed from selected
        #       nodes as well.

        requests = []
        for b in range(batch_size):

            static_request = []
            dynamic_request = []
            u = 0

            while True:
                unif = unifDist()
                u += -(1/Lambda) * math.log(unif)
                if u > horizon or len(dynamic_request) > (V-V_static+2):
                    break
                d = round(durDist(),2)
                while d<=0:
                    d = round(durDist(),2)

                dynamic_request.append([d, round(u,2)])

            for j in range(V-len(dynamic_request)):
                d = round(durDist(),2)
                while d<=0:
                    d = round(durDist(),2)
                static_request.append([d,0])

            request = static_request+dynamic_request
            random.shuffle(request)
            requests.append(request)

        return torch.tensor(requests)
    
    

    
    def __len__(self):
        return self.batch_size


    def __getitem__(self, i):
        if self.customer_mask is None:
            return self.nodes[i], self.edges_attributes[i]
        else:
            return self.nodes[i], self.customer_mask[i], self.edges_attributes[i]

    def nodes_generate(self):
        if self.customer_mask is None:
            yield from self.nodes
        else:
            yield from (n[m^1] for n,m in zip(self.nodes, self.customer_mask))  
            
            
    def normalize(self):
        loc_max, loc_min = self.nodes[:,:,:2].max().item(), self.nodes[:,:,:2].min().item()
        loc_max -= loc_min
        edge_max_length = self.edges_attributes.max().item()

        self.nodes[:,:,:2] -= loc_min
        self.nodes[:,:,:2] /= loc_max
        self.nodes[:,:,2:] /=self.vehicle_time_budget

        self.vehicle_speed *= self.vehicle_time_budget/loc_max
        self.vehicle_time_budget = 1
        self.edges_attributes /= loc_max
        return loc_max, 1

    def save(self, folder_path):
        torch.save({
            'veh_count':self.vehicle_count,
            'veh_speed':self.vehicle_speed,
            'nodes':self.nodes,
            'edges_index':self.edges_index,
            'edges_attributes':self.edges_attributes,
            'customer_count':self.customer_count,
            'customer_mask':self.customer_mask
        }, folder_path)

    @classmethod
    def load(cls, folder_path):
        return cls(**torch.load(folder_path))
        
        
        
        

In [147]:
## testing Normalize function

data = DVRPSR_Dataset.create_data(batch_size=1, vehicle_count=2, vehicle_speed=, Lambda=0.0125, )

Initialzing edges


In [148]:
data.nodes
nodes = data.nodes
nodes

tensor([[[ 1.5348e+02, -2.7868e+00,  0.0000e+00,  0.0000e+00],
         [ 1.3548e+02,  6.3629e+00,  8.4400e+00,  3.0431e+02],
         [ 1.3853e+02,  4.1748e+01,  1.2450e+01,  4.3070e+01],
         [ 1.5720e+02,  3.8793e+01,  1.2310e+01,  0.0000e+00],
         [ 1.2608e+02,  1.9041e+00,  8.5100e+00,  1.8300e+00],
         [ 1.3274e+02,  7.2597e+00,  9.3400e+00,  1.0158e+02],
         [ 1.1312e+02, -7.5473e-02,  8.4100e+00,  2.9705e+02],
         [ 1.2074e+02,  1.7892e+01,  1.2140e+01,  2.0499e+02],
         [ 1.4845e+02, -8.4152e+00,  1.1550e+01,  4.7340e+01],
         [ 1.3364e+02, -4.9474e+00,  9.4400e+00,  0.0000e+00],
         [ 1.3663e+02,  2.6947e+00,  1.0350e+01,  2.3596e+02]]])

In [161]:
data.edges_attributes

tensor([[[0.0000],
         [0.1219],
         [0.2837],
         [0.2521],
         [0.1678],
         [0.1392],
         [0.2442],
         [0.2339],
         [0.0456],
         [0.1205],
         [0.1070],
         [0.1219],
         [0.0000],
         [0.2145],
         [0.2357],
         [0.0628],
         [0.0174],
         [0.1405],
         [0.1130],
         [0.1187],
         [0.0692],
         [0.0232],
         [0.2837],
         [0.2145],
         [0.0000],
         [0.1141],
         [0.2521],
         [0.2112],
         [0.2955],
         [0.1797],
         [0.3088],
         [0.2835],
         [0.2361],
         [0.2521],
         [0.2357],
         [0.1141],
         [0.0000],
         [0.2914],
         [0.2410],
         [0.3548],
         [0.2538],
         [0.2899],
         [0.3000],
         [0.2509],
         [0.1678],
         [0.0628],
         [0.2521],
         [0.2914],
         [0.0000],
         [0.0516],
         [0.0792],
         [0.1018],
         [0.

In [162]:
data.normalize()

(1.0, 1)

In [151]:
l_max = data.nodes[:,:,:2].max().item()
l_min = data.nodes[:,:,:2].min().item()
print(l_max, l_min)
l_max -= l_min
print(l_max, l_min)

1.0 0.0
1.0 0.0


In [152]:
e_max = data.edges_attributes[:,:,:1].max()
e_max

tensor(0.3548)

In [153]:
data.vehicle_speed*400/e_max

tensor(3267.2869)

In [154]:
data.vehicle_speed*400/l_max

1159.3320040839722

In [155]:
edges = torch.pairwise_distance(data.nodes[:,0,:2], data.nodes[:,1,:2])/l_max
edges

tensor([0.1219])

In [156]:
data.normalize()

(1.0, 1)

In [157]:
data.nodes

tensor([[[0.9776, 0.0340, 0.0000, 0.0000],
         [0.8689, 0.0892, 0.0211, 0.7608],
         [0.8873, 0.3029, 0.0311, 0.1077],
         [1.0000, 0.2851, 0.0308, 0.0000],
         [0.8121, 0.0623, 0.0213, 0.0046],
         [0.8523, 0.0946, 0.0234, 0.2539],
         [0.7339, 0.0504, 0.0210, 0.7426],
         [0.7798, 0.1588, 0.0304, 0.5125],
         [0.9472, 0.0000, 0.0289, 0.1183],
         [0.8577, 0.0209, 0.0236, 0.0000],
         [0.8758, 0.0671, 0.0259, 0.5899]]])

In [164]:
edges_norm = torch.pairwise_distance(data.nodes[:,0,:2], data.nodes[:,2,:2])
edges_norm

tensor([0.2837])

In [165]:
edges_norm/data.vehicle_speed

tensor([0.0979])

In [166]:
data.vehicle_speed

2.8983300102099303

## Environment

In [None]:
import torch


class DVRPSR_Environment:
    vehicle_feature = 8  # vehicle coordinates(x_i,y_i), veh_time_time_budget, total_travel_time, last_customer,
                         # next(destination) customer, last rewards, next rewards
    customer_feature = 4

    # TODO: change pending cost for rewards

    def __init__(self, data, nodes=None, customer_mask=None, edges_attributes = None,
                 pending_cost=1, 
                 dynamic_reward=0.2,
                 budget_penalty = 10):

        self.vehicle_count = data.vehicle_count
        self.vehicle_speed = data.vehicle_speed
        self.vehicle_time_budget = data.vehicle_time_budget

        self.nodes = data.nodes if nodes is None else nodes
        self.edge_index = data.edges_index
        self.edge_attributes = data.edges_attributes if edges_attributes is None else edges_attributes
        self.init_customer_mask = data.customer_mask if customer_mask is None else customer_mask

        self.minibatch, self.nodes_count, _ = self.nodes.size()
        self.distance_matrix = self.edge_attributes.view((self.minibatch, self.nodes_count, self.nodes_count))
        self.pending_cost = pending_cost
        self.dynamic_reward = dynamic_reward
        self.budget_penalty = budget_penalty

    def _update_current_vehicles(self, dest, customer_index, tau=0):

        # calculate travel time
        # TODO: 1) in real world setting we need to calculate the distance of arc
        # If nodes i and j are directly connected by a road segment (i, j) ∈ A, then t(i,j)=t_ij;
        # otherwise, t(i,j)=t_ik1 +t_k1k2 +...+t_knj, where k1,...,kn ∈ V are the nodes along the
        # shortest path from node i to node j.
        #      2) calculate stating time for each vehicle $\tau $, currently is set to zero
        
        # update vehicle previous and next customer id
        self.current_vehicle[:, :, 4] = self.current_vehicle[:, :, 5]
        self.current_vehicle[:, :, 5] = customer_index
        
        # get the distance from current vehicle to its next destination
        dist = torch.zeros((self.minibatch, 1))
        for i in range(self.minibatch):
            dist[i, 0] = self.distance_matrix[i][int(self.current_vehicle[i, :, 4])][int(self.current_vehicle[i, :, 5])]
            
        
        # total travel time    
        tt = dist / self.vehicle_speed

        # customers which are dynamicaly appeared
        dyn_cust = (dest[:, :, 3] > 0).float()

        # budget left while travelling to destination nodes
        budget = tau + tt + dest[:, :, 2]
        # print(budget, tau, tt, dest[:,:,2])

        # update vehicle features based on destination nodes
        self.current_vehicle[:, :, :2] = dest[:, :, :2]
        self.current_vehicle[:, :, 2] -= budget
        self.current_vehicle[:, :, 3] += tt
        self.current_vehicle[:, :, 6] = self.current_vehicle[:, :, 7]
        self.current_vehicle[:, :, 7] = -dist

        # update vehicles states
        self.vehicles = self.vehicles.scatter(1,
                                              self.current_vehicle_index[:, :, None].expand(-1, -1, self.vehicle_feature),
                                              self.current_vehicle)

        return dist, dyn_cust

    def _done(self, customer_index):
        
        self.vehicle_done.scatter_(1, self.current_vehicle_index, torch.logical_or((customer_index == 0), 
                                                                     (self.current_vehicle[:, :, 2] <= 0)))
        # print(self.veh_done, cust_idx==0,self.cur_veh[:,:,2]<=0, (cust_idx==0) | (self.cur_veh[:,:,2]<=0))
        self.done = bool(self.vehicle_done.all())

    def _update_mask(self, customer_index):

        self.new_customer = False
        self.served.scatter_(1, customer_index, customer_index > 0)

        # cost for a vehicle to go to customer and back to deport considering service duration
        cost = torch.zeros((self.minibatch, self.nodes_count,1))
        for i in range(self.minibatch):
            for j in range(self.nodes_count):
                dist_vehicle_customer_depot = self.distance_matrix[i][int(self.current_vehicle[i, :, 4])][j] + \
                                              self.distance_matrix[i][j][0]
                cost[i,j] = dist_vehicle_customer_depot
                
        cost = cost / self.vehicle_speed

        cost += self.nodes[:, :, None, 2]

        overtime_mask = self.current_vehicle[:, :, None, 2] - cost
        overtime_mask = overtime_mask.squeeze(2).unsqueeze(1)
        overtime = torch.zeros_like(self.mask).scatter_(1,
                                                        self.current_vehicle_index[:, :, None].expand(-1, -1, self.nodes_count),
                                                        overtime_mask < 0)

        self.mask = self.mask | self.served[:, None, :] | overtime | self.vehicle_done[:, :, None]
        self.mask[:, :, 0] = 0  # depot

    # updating current vehicle to find the next available vehicle
    def _update_next_vehicle(self, veh_index=None):
        
        if veh_index is None:
            avail = self.vehicles[:, :, 3].clone()
            avail[self.vehicle_done] = float('inf')
            self.current_vehicle_index = avail.argmin(1, keepdim=True)
        else:
            self.current_vehicle_index = veh_index
            
        self.current_vehicle = self.vehicles.gather(1, self.current_vehicle_index[:, :, None].expand(-1, -1, self.vehicle_feature))
        self.current_vehicle_mask = self.mask.gather(1, self.current_vehicle_index[:, :, None].expand(-1, -1, self.nodes_count))

    def _update_dynamic_customers(self):

        time = self.current_vehicle[:, :, 3].clone()

        if self.init_customer_mask is None:
            reveal_dyn_reqs = torch.logical_and((self.customer_mask), (self.nodes[:, :, 3] <= time))
        else:
            reveal_dyn_reqs = torch.logical_and((self.customer_mask ^ self.init_customer_mask), (self.nodes[:, :, 3] <= time))

        if reveal_dyn_reqs.any():
            self.new_customer = True
            self.customer_mask = self.customer_mask ^ reveal_dyn_reqs
            self.mask = self.mask ^ reveal_dyn_reqs[:, None, :].expand(-1, self.vehicle_count, -1)
            self.vehicle_done = torch.logical_and(self.vehicle_done, (reveal_dyn_reqs.any(1) ^ True).unsqueeze(1))
            self.vehicles[:, :, 3] = torch.max(self.vehicles[:, :, 3], time)
            self._update_next_vehicle()

    def reset(self):
        # reset vehicle (minibatch*veh_count*veh_feature)
        self.vehicles = self.nodes.new_zeros((self.minibatch, self.vehicle_count, self.vehicle_feature))
        self.vehicles[:, :, :2] = self.nodes[:, :1, :2]
        self.vehicles[:, :, 2] = self.vehicle_time_budget

        # reset vehicle done
        self.vehicle_done = self.nodes.new_zeros((self.minibatch, self.vehicle_count), dtype=torch.bool)
        self.done = False

        # reset cust_mask
        self.customer_mask = self.nodes[:, :, 3] > 0
        if self.init_customer_mask is not None:
            self.customer_mask = self.customer_mask | self.init_customer_mask

        # reset new customers and served customer since now to zero (all false)
        self.new_customer = True
        self.served = torch.zeros_like(self.customer_mask)

        # reset mask (minibatch*veh_count*nodes)
        self.mask = self.customer_mask[:, None, :].repeat(1, self.vehicle_count, 1)

        # reset current vehicle index, current vehicle, current vehicle mask
        self.current_vehicle_index = self.nodes.new_zeros((self.minibatch, 1), dtype=torch.int64)
        
        self.current_vehicle = self.vehicles.gather(1, 
                                                    self.current_vehicle_index[:, :, None].expand(-1, -1, self.vehicle_feature))
        self.current_vehicle_mask = self.mask.gather(1, 
                                             self.current_vehicle_index[:, :, None].expand(-1, -1, self.nodes_count))

    
    def step(self, customer_index, veh_index = None):
        dest = self.nodes.gather(1, customer_index[:, :, None].expand(-1, -1, self.customer_feature))
        dist, dyn_cust = self._update_current_vehicles(dest, customer_index)

        #cust = (dest[:, :, 3] >= 0).float()

        self._done(customer_index)
        self._update_mask(customer_index)
        self._update_next_vehicle(veh_index)

        #reward = -dist * (1 - dyn_cust*self.dynamic_reward)
        reward = self.current_vehicle[:, :, 7] - self.current_vehicle[:,:,6] + self.dynamic_reward*dyn_cust
        pending_static_customers = torch.logical_and((self.served ^ True), 
                                                     (self.nodes[:, :, 3] == 0)).float().sum(-1,keepdim=True) - 1
        
        reward -= self.pending_cost*pending_static_customers

        if self.done:

            if self.init_customer_mask is not None:
                self.served += self.init_customer_mask
            # penalty for pending customers
            pending_customers = torch.logical_and((self.served ^ True), 
                                                  (self.nodes[:, :, 3] >= 0)).float().sum(-1, keepdim=True) - 1

            # TODO: penalty for having unused time budget as well not serving customers
            reward -= self.dynamic_reward*pending_customers
            

        self._update_dynamic_customers()

        return reward

    def state_dict(self, dest_dict=None):
        if dest_dict is None:
            dest_dict = {'vehicles': self.vehicles,
                         'vehicle_done': self.vehicle_done,
                         'served': self.served,
                         'mask': self.mask,
                         'current_vehicle_index': self.current_vehicle_index}

        else:
            dest_dict["vehicles"].copy_(self.vehicles)
            dest_dict["vehicle_done"].copy_(self.vehicle_done)
            dest_dict["served"].copy_(self.served)
            dest_dict["mask"].copy_(self.mask)
            dest_dict["current_vehicle_index"].copy_(self.current_vehicle_index)

        return dest_dict

    def load_state_dict(self, state_dict):
        self.vehicles.copy_(state_dict["vehicles"])
        self.vehicle_done.copy_(state_dict["vehicle_done"])
        self.served.copy_(state_dict["served"])
        self.mask.copy_(state_dict["mask"])
        self.current_vehicle_index.copy_(state_dict["current_vehicle_index"])

        self.current_vehicle = self.vehicles.gather(1, 
                                                    self.current_vehicle_index[:, :, None].expand(-1, -1, self.vehicle_feature))
        self.current_vehicle_mask = self.mask.gather(1, self.current_vehicle_index[:, :, None].expand(-1, -1, self.customer_feature))




In [None]:
def reinforce_loss(logprobs, rewards, baseline = None, weights = None, discount = 1.0, reduction = 'mean'):
    r"""
    :param logprobs:  Iterable of length :math:`L` on tensors of size :math:`N \times 1`
    :param rewards:   Iterable of length :math:`L` on tensors of size :math:`N \times 1`
                    or single tensor of size :math:`N \times 1` to use rewards cumulated on the whole trajectory
    :param baseline:  Iterable of length :math:`L` on tensors of size :math:`N \times 1`
                    or single tensor of size :math:`N \times 1` to use rewards cumulated on the whole trajectory
    :param weights:   Iterable of length :math:`L` on tensors of size :math:`N \times 1`
    :param discount:  Discount applied to cumulated future reward
    :param reduction: 'none' No reduction,
                      'sum'  Compute sum of loss on batch,
                      'mean' Compute mean of loss on batch
    """
    if weights is None:
        weights = repeat(1.0)

    if isinstance(rewards, torch.Tensor):
        if baseline is None:
            baseline = torch.zeros_like(rewards)

        loss = torch.stack([-logp * w for logp,w in zip(logprobs, weights)]).sum(dim = 0)
        loss *= (rewards - baseline.detach())

        if baseline.requires_grad:
            loss += F.smooth_l1_loss(baseline, rewards)

    else:
        if baseline is None:
            baseline = repeat(torch.zeros_like(rewards[0]))

        cumul = torch.zeros_like(rewards[0])
        vals = []
        for r in reversed(rewards):
            cumul = r + discount * cumul
            vals.append(cumul)
        vals.reverse()

        loss = []
        bl_loss = []
        for val, logp, bl, w in zip(vals, logprobs, baseline, weights):
            loss.append( -logp * (val - bl.detach()) * w )
            if bl.requires_grad:
                bl_loss.append( F.smooth_l1_loss(bl, val) )
        loss = torch.stack(loss).sum(dim = 0)

        if bl_loss:
            loss += torch.stack(bl_loss).sum(dim = 0)

    if reduction == 'none':
        return loss
    elif reduction == 'sum':
        return loss.sum()
    else: # reduction == 'mean'
        return loss.mean()


In [None]:
data = DVRPSR_Dataset.create_data(batch_size=4, 
                                  vehicle_count=2,
                                  vehicle_speed=1/3, 
                                  Lambda=0.2, 
                                  dod=0.75, 
                                  horizon=600)
data.normalize()
env = DVRPSR_Environment(data)
env.reset()

In [None]:
#device = 'mps' if torch.device('mps') else 'cpu'
device = 'cpu'
env.nodes = env.nodes.to(device)
env.edge_attributes = env.edge_attributes.to(device)
env.edge_index = env.edge_index.to(device)
env.edge_index.size()

## Encoder 

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
import math
from torch.distributions.categorical import Categorical
from torch.optim.lr_scheduler import LambdaLR
import time

## Graph Encoder

In [None]:
class GraphMultiHeadAttention(nn.Module):
    
    def __init__(self, num_head, query_size, key_size = None, value_size = None, edge_dim_size = None, bias = False):
        
        super(GraphMultiHeadAttention, self).__init__()
        self.num_head = num_head
        self.query_size = query_size
        
        self.key_size = self.query_size if key_size is None else key_size
        self.value_size = self.key_size if value_size is None else value_size
        self.edge_dim_size = self.query_size//2 if edge_dim_size is None else edge_dim_size
        
        self.scaling_factor = self.key_size**-0.5
        
        self.keys_per_head = self.key_size // self.num_head
        self.values_per_head = self.value_size // self.num_head
        self.edge_size_per_head = self.edge_dim_size
        
        self.edge_embedding = nn.Linear(self.edge_dim_size, self.edge_size_per_head, bias = bias)
        self.query_embedding = nn.Linear(self.query_size, self.num_head * self.keys_per_head, bias = bias)
        self.key_embedding = nn.Linear(self.key_size, self.num_head * self.keys_per_head, bias = bias)
        self.value_embedding = nn.Linear(self.value_size, self.num_head * self.values_per_head, bias = bias)
        self.recombine = nn.Linear(self.num_head * self.values_per_head, self.value_size, bias = bias)
        
        
        self.K_project_pre = None
        self.V_project_pre = None
        
        self.initialize_weights()


    def initialize_weights(self):
        #TODO: add xavier initialziation as well
        
        nn.init.uniform_(self.query_embedding.weight, -self.scaling_factor, self.scaling_factor)
        nn.init.uniform_(self.key_embedding.weight, -self.scaling_factor, self.scaling_factor)
        inv_sq_dv = self.value_size ** -0.5
        nn.init.uniform_(self.value_embedding.weight, -inv_sq_dv, inv_sq_dv)
        
    def precompute(self, keys, values = None):
        
        values = keys if values is None else values
        
        size_KV = keys.size(-2)
        
        self.K_project_pre = self.key_embedding(keys).view(
                             -1, size_KV, self.num_head, self.keys_per_head).permute(0, 2, 3, 1)
        
        self.V_project_pre = self.value_embedding(values).view(
                              -1, size_KV, self.num_head, self.values_per_head).permute(0, 2, 1, 3)
        
        
    def forward(self, queries, keys = None, values = None, edge_attributes = None, mask = None, edge_mask=None):
        
        *batch_size, size_Q, _ = queries.size()
        
        # get queries projection
        Q_project = self.query_embedding(queries).view(
                              -1, size_Q, self.num_head, self.keys_per_head).permute(0, 2, 1, 3)
        
        # get keys projection
        if keys is None:
            if self.K_project_pre is None:
                size_KV = size_Q
                K_project = self.key_embedding(queries).view(
                                -1, size_KV, self.num_head, self.keys_per_head).permute(0, 2, 3, 1)
            else:
                size_KV = self.K_project_pre.size(-1)
                K_project = self.K_project_pre
        else:
            size_KV = keys.size(-2)
            K_project = self.key_embedding(keys).view(
                            -1, size_KV, self.num_head, self.keys_per_head).permute(0, 2, 3, 1)
            
         # get values projection   
        if values is None:
            if self.V_project_pre is None:
                V_project = self.value_embedding(queries).view(
                                -1, size_KV, self.num_head, self.values_per_head).permute(0, 2, 1, 3)
            else:
                V_project = self.V_project_pre
        else:
            V_project = self.value_embedding(values).view(
                            -1, size_KV, self.num_head, self.values_per_head).permute(0, 2, 1, 3)
            
        
        # calculate the compability
        attention = Q_project.matmul(K_project)
        attention *= self.scaling_factor
        
        # if edge attributes are required
        if edge_attributes is not None:
                
            #TODO: edge mask (is it required)
            edge_project = self.edge_embedding(edge_attributes).view(
                                -1, size_Q, size_Q, self.edge_size_per_head)
                
            # get enhanced attention inclusing edge attributes
            attention_expanded = attention.unsqueeze(-1).expand(-1, -1, -1, -1, self.edge_size_per_head)
            
            # Expand edge attributes to match the number of attention heads
            edge_project_expanded = edge_project.unsqueeze(1).expand(-1, attention.size(1), -1, -1, -1)
            
            attention = attention_expanded * edge_project_expanded
            attention = attention.mean(-1)
            
            #print(attention.size())
        
            
            
        if mask is not None:

            if mask.numel() * self.num_head == attention.numel():
                m = mask.view(-1, 1, size_Q, size_KV).expand_as(attention)
            else:
                m = mask.view(-1, 1, 1, size_KV).expand_as(attention)

            attention[m.bool()] = -float('inf')
            
        attention = F.softmax(attention, dim=-1)
        attention = attention.matmul(V_project).permute(0, 2, 1, 3).contiguous().view(
                                                *batch_size, size_Q, self.num_head * self.values_per_head)
        
        output = self.recombine(attention)
        
        return output
        
             

In [None]:
class GraphEncoderlayer(nn.Module):
    
    def __init__(self, num_head, model_size, ff_size):
        super().__init__()
        
        self.attention = GraphMultiHeadAttention(num_head, query_size=model_size)
        self.BN1 = nn.BatchNorm1d(model_size)
        self.FFN_layer1 = nn.Linear(model_size, ff_size)
        
        self.FFN_layer2 = nn.Linear(ff_size, model_size)
        self.BN2 = nn.BatchNorm1d(model_size)
        
    def forward(self, h, e = None, mask = None):
        
        h_attn = self.attention(h, edge_attributes = e, mask=mask)
        h_bn = self.BN1((h_attn + h).permute(0, 2, 1)).permute(0, 2, 1)
        
        h_layer1 = F.relu(self.FFN_layer1(h_bn))
        h_layer2 = self.FFN_layer2(h_layer1)
        
        h_out = self.BN2((h_bn + h_layer2).permute(0, 2, 1)).permute(0, 2, 1)
        
        if mask is not None:
            h_out[mask] = 0
            
        return h_out
    
class GraphEncoder(nn.Module):
    
    def __init__(self, encoder_layer, num_head, model_size, ff_szie):
        super().__init__()
        
        for l in range(encoder_layer):
            self.add_module(str(l), GraphEncoderlayer(num_head, model_size, ff_szie))
            
    def forward(self, h_in, e_in = None,  mask=None):
        
        h = h_in
        e = e_in
        
        for child in self.children():
            h = child(h, e, mask=mask)
        return h
                

In [None]:
## tetsing the encoder-decoder model

customer_encoder = GraphEncoder(encoder_layer=3, num_head=8, model_size=128, ff_szie=256)

cust_feature = 4
edge_feature = 1
model_size = 128
edge_embedding_dim = 64
customers = env.nodes
edges = env.edge_attributes
customer_mask = env.customer_mask
vehicle_mask = env.current_vehicle_mask

depot_embedding = nn.Linear(cust_feature, model_size)
cust_embedding = nn.Linear(cust_feature, model_size)
edge_embedding = nn.Linear(edge_feature, edge_embedding_dim)

cust_emb = torch.cat((
            depot_embedding(customers[:, 0:1, :]),
            cust_embedding(customers[:, 1:, :])), dim=1)

edge_emb = edge_embedding(edges)

if customer_mask is not None:
    cust_emb[customer_mask] = 0

print(edge_emb.size())

cust_encoding = customer_encoder(cust_emb, edge_emb)

## Graph Model

In [None]:

class GraphAttentionModel(nn.Module):
    
    def __init__(self, customer_feature, vehicle_feature, model_size=128, encoder_layer=3, 
                 num_head=8, ff_size=128, tanh_xplor=10, edge_embedding_dim = 64, greedy = False):
        super().__init__()
        
        # get models parameters for encoding-decoding
        self.model_size = model_size
        self.scaling_factor = self.model_size**0.5
        self.tanh_xplor = tanh_xplor
        self.greedy = greedy
        
        # Initialize encoder and embeddings
        self.customer_encoder = GraphEncoder(encoder_layer=3, num_head=8, model_size=128, ff_szie=512)
        self.customer_embedding = nn.Linear(customer_feature, model_size)
        self.depot_embedding = nn.Linear(customer_feature, model_size)
        
        # initialize edge embedding
        self.edge_embedding = nn.Linear(1, edge_embedding_dim)
        
        # Initialize vehicle embedding and encoding
        #self.vehicle_embedding = nn.Linear(vehicle_feature, ff_size, bias=False)
         
        self.fleet_attention = GraphMultiHeadAttention(num_head, vehicle_feature, model_size)
        
        self.vehicle_attention = GraphMultiHeadAttention(num_head, model_size)
        
        # customer projection
        self.customer_projection = nn.Linear(self.model_size, self.model_size) # TODO: MLP instaed of nn.Linear
        
        
    
    def encode_customers(self, env, customer_mask = None):
        
        customer_emb = torch.cat((self.depot_embedding(env.nodes[:,:1,:]),
                                  self.customer_embedding(env.nodes[:,1:,:])), dim=1)
        if customer_mask is not None:
            customer_emb[customer_mask] = 0
            
        edge_emb = self.edge_embedding(env.edge_attributes)
        
        self.customer_encoding = self.customer_encoder(customer_emb, edge_emb, mask = customer_mask)
        
        self.fleet_attention.precompute(self.customer_encoding)
        
        self.customer_representation = self.customer_projection(self.customer_encoding)
        if customer_mask is not None:
            self.customer_representation[customer_mask] = 0
            
    
        
    def vehicle_representation(self, vehicles, vehicle_index, vehicle_mask=None):
        
        #vehicles_embedding = self.vehicle_embedding(vehicles)
        
        #print(vehicles_embedding.size(), self.customer_representation.size())
        
        fleet_representation = self.fleet_attention(vehicles, mask = vehicle_mask)
        
#         print(fleet_representation.size())
        
        vehicle_query = fleet_representation.gather(0, vehicle_index.unsqueeze(2).expand(-1, -1, self.model_size))
        
        self._vehicle_representation = self.vehicle_attention(vehicle_query, 
                                                              fleet_representation,
                                                              fleet_representation)
        
        return self._vehicle_representation
    
    
    
    def score_customers(self, vehicle_representation):
        
        #print(vehicle_representation.size(), self.customer_representation.size())
        compact = torch.bmm(vehicle_representation,
                            self.customer_representation.transpose(2,1))
        compact *= self.scaling_factor
        
        if self.tanh_xplor is not None:
            compact = self.tanh_xplor*compact.tanh()
        
        return compact
    
    
    def get_prop(self, compact, vehicle_mask = None):
        
        compact = compact
        
        compact[vehicle_mask] = -float('inf')
        compact = F.softmax(compact, dim=-1)
        return compact
    
    
    def step(self, env, old_action=None):
        
        _vehicle_representation = self.vehicle_representation(env.vehicles, 
                                                              env.current_vehicle_index,
                                                              env.current_vehicle_mask)
        
        compact = self.score_customers(_vehicle_representation)
        prop = self.get_prop(compact, env.current_vehicle_mask)
        #print(compact.size())
        
        # step actions based on model act or evalaute
        if old_action is not None:
            
            # get entropy
            dist = Categorical(prop)
            old_actions_logp = dist.log_prob(old_action[:, 1].unsqueeze(-1))
            entropy = dist.entropy()
            
            is_done = float(env.done)
            
            entropy = entropy * (1. - is_done)
            old_actions_logp = old_actions_logp*(1. - is_done)
            return old_action[:, 1].unsqueeze(-1), entropy, old_actions_logp
            
            
        else:
            dist = Categorical(prop)
            
            if self.greedy:
                _, customer_index = p.max(dim=-1)
            else:
                customer_index = dist.sample()
                
            is_done = float(env.done)

            logp = dist.log_prob(customer_index)
            logp = logp * (1. - is_done)
            
            return customer_index, logp
    
    
    
    def forward(self, env, old_actions=None, is_update=False):
        
        if is_update:
            env.reset()
            entropys, old_actions_logps = [], []
            
            steps = old_actions.size(0)

            for i in range(steps):
                if env.new_customer:
                    self.encode_customers(env, env.customer_mask)
                    
                
                if i < steps-1:
                    old_action = old_actions[i,:,:]
                    next_action = old_actions[i+1,:,:]
                else:
                    # this would be the last action which the agent takes and envrionment is done
                    old_action = old_actions[i,:,:]
                    next_action = old_actions[i,:,:]
                    
        
                next_vehicle_index = next_action[:,0].unsqueeze(-1)
                #print(next_vehicle_index)
            

                customer_index, entropy, logp = self.step(env, old_action)
                
                env.step(customer_index, next_vehicle_index)
                
                old_actions_logps.append(logp)
                entropys.append(entropy)
                
            entropys = torch.cat(entropys, dim=1)
            num_e = entropys.ne(0).float().sum(1)
            entropy = entropys.sum(1) / num_e
            
            old_actions_logps = torch.cat(old_actions_logps, dim=1)
            old_actions_logps = old_actions_logps.sum(1)
            
                
            return entropy, old_actions_logps, 0
        
        else:
            env.reset()
            actions, logps, rewards = [], [], []

            while not env.done:
                if env.new_customer:
                    self.encode_customers(env, env.customer_mask)

                customer_index, logp = self.step(env)
                actions.append((env.current_vehicle_index, customer_index))
                logps.append(logp)
                rewards.append(env.step(customer_index))

            #actions = torch.cat(actions, dim=1)
            logps = torch.cat(logps, dim=1)
            logp = logps.sum(dim=1)
            
            rewards = torch.cat(rewards, dim=1)
            rewards = rewards.sum(dim=1)
            

            return actions, logp, rewards

            
        
        

In [None]:
## testing GraphAttentionModel

model = GraphAttentionModel(4, 8)
actions, logps, rewards = model(env, old_actions=None, is_update = False)
print("Forward pass ok for DVRPSR")

from itertools import repeat
#loss = reinforce_loss(logps, rewards)
#loss.backward()
print("Backward pass ok for DVRPSR")


In [None]:
logps

In [None]:
def formate_old_actions(actions):
    old_actions = []

    for action in actions:
        old_action = []
        for i in range(action[0].size(0)):
            old_action.append([action[0][i].item(), action[1][i].item()])
        old_actions.append(old_action)             
    return old_actions

In [None]:
old_actions = formate_old_actions(actions)
old_actions = torch.tensor(old_actions)
old_actions[0]

## PPO agent

In [None]:
class Critic(nn.Module):
    
    # critic will take environment as imput and ouput the values for loss function 
    # which is basically the estimation of complexity of actions
    
    def __init__(self, model, customers_count, ff_size = 512):
        super(Critic, self).__init__()
        
        self.model = model
        self.ff_layer1 = nn.Linear(customers_count, ff_size)
        self.ff_layer2 = nn.Linear(ff_size, customers_count)
        
    def eval_step(self, env, compatibility, customer_index):
        compact = compatibility.clone()
        compact[env.current_vehicle_mask] = 0
        
        value = self.ff_layer1(compact)
        value = F.relu(value)
        value = self.ff_layer2(value)
        
        val = value.gather(2, customer_index.unsqueeze(1)).expand(-1, 1, -1)
        return val.squeeze(1)
        
        
    def __call__(self, env):
        self.model.encode_customers(env)
        env.reset()
        
        values = []
        
        while not env.done:
            
            _vehicle_presentation = self.model.vehicle_representation(env.vehicles,
                                                                     env.current_vehicle_index,
                                                                     env.current_vehicle_mask)
            compatibility = self.model.score_customers(_vehicle_presentation)
            prop = self.model.get_prop(compatibility, env.current_vehicle_mask)
            dist = Categorical(prop)
            customer_index = dist.sample()
            
            values.append(self.eval_step(env, compatibility, customer_index))
            
            return values[0]
            
            

In [None]:
critic_net = Critic(model, 161, 512)
critic_net(env)

In [None]:
class Actor_Critic(nn.Module):
    
    def __init__(self,
                customer_feature,
                vehicle_feature,
                customers_count,
                model_size = 128,
                encoder_layer = 3,
                num_head = 8,
                ff_size_actor = 128,
                ff_size_critic = 512,
                tanh_xplor = 10,
                edge_embedding_dim = 64,
                greedy = False):
        
        super(Actor_Critic, self).__init__()
        
        model = GraphAttentionModel(customer_feature, vehicle_feature, model_size, encoder_layer, 
                                        num_head, ff_size_actor, tanh_xplor, edge_embedding_dim, greedy)
        self.actor = model
        
        self.critic = Critic(model, customers_count, ff_size_critic)
        
    def act(self, env, old_actions=None, is_update=False):
        
        actions, logps, rewards = self.actor(env)
        return actions, logps, rewards
    
    
    def evaluate(self, env, old_actions, is_update):
        
        entropys, old_logps, _ = self.actor(env, old_actions, is_update)
        values = self.critic(env)
        return entropys, old_logps, values
        

In [None]:
class Memory:
    
    def __init__(self):
        
        self.nodes = []
        self.edge_attributes = []
        self.actions = []
        self.rewards = []
        self.log_probs = []
        
    def clear(self):
        self.nodes.clear()
        self.edge_attributes.clear()
        self.actions.clear()
        self.rewards.clear()
        self.log_probs.clear()
        
        

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
import math
from torch.distributions.categorical import Categorical
from torch.optim.lr_scheduler import LambdaLR
import time

# from PPORolloutBaselin import RolloutBaseline
from sklearn.preprocessing import MinMaxScaler


device = 'cpu'
max_grad_norm = 2

class AgentPPO:
    
    def __init__(self,
                customer_feature,
                vehicle_feature,
                customers_count,
                model_size = 128,
                encoder_layer = 3,
                num_head = 8,
                ff_size_actor = 128,
                ff_size_critic = 512,
                tanh_xplor = 10,
                edge_embedding_dim = 64,
                greedy = False,
                learning_rate = 3e-4,
                ppo_epoch = 3,
                batch_size = 4,
                entropy_value = 0.2,
                epsilon_clip = 0.2):
        
        self.policy = Actor_Critic(customer_feature, vehicle_feature, customers_count, model_size = 128,
                                   encoder_layer = 3, num_head = 8, ff_size_actor = 128, ff_size_critic = 512,
                                   tanh_xplor = 10, edge_embedding_dim = 64, greedy = False)
        
        self.old_policy = Actor_Critic(customer_feature, vehicle_feature, customers_count, model_size = 128,
                                   encoder_layer = 3, num_head = 8, ff_size_actor = 128, ff_size_critic = 512,
                                   tanh_xplor = 10, edge_embedding_dim = 64, greedy = False)
        
        self.old_policy.load_state_dict(self.policy.state_dict())
        
        # ppo update parameters
        self.learning_rate = learning_rate
        self.ppo_epoch = ppo_epoch
        self.batch_size = batch_size
        self.entropy_value = entropy_value
        self.epsilon_clip = epsilon_clip
        self.batch_index = 1
        
        # initialize the Adam optimizer
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=learning_rate)
        self.MSE_loss = nn.MSELoss()
        
        # actor-critic parameters
        self.customer_feature = customer_feature
        self.vehicle_feature = vehicle_feature
        self.customers_count = customers_count
        self.model_size = model_size
        self.encoder_layer = encoder_layer
        self.num_head = num_head
        self.ff_size_actor = ff_size_actor
        self.ff_size_critic = ff_size_critic
        self.tanh_xplor = tanh_xplor
        self.edge_embedding_dim = edge_embedding_dim
        self.greedy = greedy
        
        self.times, self.losses, self.rewards, self.critic_rewards = [], [], [], []
        
    def advantage_normalization(self, advantage):
        
        std = advantage.std()
        
        assert std != 0. and not torch.isnan(std), 'Need nonzero std'
        
        norm_advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
        return norm_advantage
    
    def pad_actions(self, actions):
        max_len = max([a.size(0) for a in actions])
        padded_actions = []
        for a in actions:
            pad_length = max_len - a.size(0)
            padded_a = F.pad(a, (0, 0, 0, pad_length))
            padded_actions.append(padded_a)
        return torch.stack(padded_actions), max_len

        
    
    
    
    def update(self, memory, epoch, data=None, env=None):
        
        old_nodes = torch.stack(memory.nodes)
        old_edge_attributes = torch.stack(memory.edge_attributes)
        old_rewards = torch.stack(memory.rewards).unsqueeze(-1)
        old_log_probs = torch.stack(memory.log_probs).unsqueeze(-1)
        
        # preprocessing on old actions
        padded_actions, max_length = self.pad_actions(memory.actions)
        
        # create update data for PPO
        datas = []
        
        #print(memory.actions.size())
        
        for i in range(old_nodes.size(0)):
            
            data_to_load = Data(nodes = old_nodes[i],
                                edge_attributes = old_edge_attributes[i],
                                actions = padded_actions[i],
                                rewards = old_rewards[i],
                                log_probs = old_log_probs[i])
            
            datas.append(data_to_load)
        #print(datas[0], self.batch_size)
            
        
            
        self.policy.to(device)
        
        data_loader = DataLoader(datas, batch_size=self.batch_size, shuffle=False)
        
        scheduler = LambdaLR(self.optimizer, lr_lambda=lambda f: 0.96 ** epoch)
        value_buffer = 0
        
        env = env if env is not None else DVRPSR_Environment
        
        for i in range(self.ppo_epoch):
            
            self.policy.train()
            epoch_start = time.time()
            start = epoch_start
            
            self.times, self.losses, self.rewards, self.critic_rewards = [], [], [], []
            
            for batch_index, minibatch_data in enumerate(data_loader):
                
                self.batch_index += 1
                
                if data.customer_mask is None:
                    nodes = minibatch_data.nodes.to(device)
                    customer_mask = None
                    edge_attributes = minibatch_data.edge_attributes.to(device)
                    
                nodes = nodes.view(self.batch_size, self.customers_count, self.customer_feature)
                edge_attributes = edge_attributes.view(self.batch_size, self.customers_count*self.customers_count,1)
                
                old_actions_for_env = minibatch_data.actions.view(self.batch_size, max_length,2).permute(1,0,2)
                
                dyna_env = env(data, nodes, customer_mask, edge_attributes)
                
                entropy, log_probs, values = self.policy.evaluate(dyna_env, old_actions_for_env, True)
                
                # normalize the rewards and get the MSE loss with critics values
                R = minibatch_data.rewards
                R_norm = self.advantage_normalization(R)
                
                print(R_norm, values.size())
                
                mse_loss = self.MSE_loss(R_norm, values)
                
                # PPO ration (r(0)_t)
                ratio = torch.exp(log_probs - minibatch_data.log_probs)
                
                # PPO advantage
                advantage = R_norm - values.detach()
                
                # PPO overall loss function
                actor_loss1 = ratio * advantage
                actor_loss2 = torch.clamp(ratio, 1-self.epsilon_clip, 1+self.epsilon_clip)*advantage
                
                actor_loss = torch.min(actor_loss1, actor_loss2)
                
                # total loss
                loss = actor_loss + 0.5*mse_loss - self.entropy_value*entropy
                
                # optimizer and backpropogation
                self.optimizer.zero_grad()
                loss.mean().backward()
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), max_grad_norm)
                self.optimizer.step()

                scheduler.step()

                self.rewards.append(torch.mean(R_norm.detach()).item())
                self.losses.append(torch.mean(loss.detach()).item())
                self.critic_rewards.append(torch.mean(values.detach()).item())
            
        self.old_policy.load_state_dict(self.policy.state_dict())
        
# if __name__ == '__main__':
#     raise Exception('Cannot be called from main')
                
                
        

## Train PPO Agent

In [None]:
import os
import time
import torch

from collections import OrderedDict
from collections import namedtuple
from itertools import product
import numpy as np



device = torch.device('cpu' if torch.backends.mps.is_available()else 'cpu')

customers_count = 21

class TrainPPOAgent:
    
    def __init__(self,
                customer_feature,
                vehicle_feature,
                customers_count,
                model_size = 128,
                encoder_layer = 3,
                num_head = 8,
                ff_size_actor = 128,
                ff_size_critic = 512,
                tanh_xplor = 10,
                edge_embedding_dim = 64,
                greedy = False,
                learning_rate = 3e-4,
                ppo_epoch = 3,
                batch_size = 4,
                entropy_value = 0.2,
                epsilon_clip = 0.2,
                epoch = 40,
                timestep = 2):
        
        self.greedy = greedy
        self.memory = Memory()
        self.batch_size = batch_size
        self.update_timestep = timestep
        self.epoch = epoch
        self.agent = AgentPPO(customer_feature, vehicle_feature, customers_count, model_size = 128,
                              encoder_layer = 3, num_head = 8, ff_size_actor = 128, ff_size_critic = 512,
                              tanh_xplor = 10, edge_embedding_dim = 64, greedy = False, learning_rate = 3e-4,
                              ppo_epoch = 3, batch_size = 4, entropy_value = 0.2, epsilon_clip = 0.2)
        
        
    def run_train(self, datas, env, batch_size):
        
        train_data_loader = DataLoader(datas, batch_size=self.batch_size, shuffle=True)
        
        memory = Memory()
        self.agent.old_policy.to(device)
        
        costs = []
        
        for i in range(self.epoch):
            
            print('running epoch {}'.format(i))
            self.agent.old_policy.train()
            times, losses, rewards1, critic_rewards = [], [], [], []

            
            epoch_start = time.time()
            start_time = epoch_start
            
            for batch_index, minibatch in enumerate(train_data_loader):
                
                if datas.customer_mask is None:
                    nodes, customer_mask, edge_attributes = minibatch[0], None, minibatch[1]
                    
                #print(nodes.size(), edge_attributes.size())
                
                nodes = nodes.view(batch_size, customers_count, 4)
                edge_attributes = edge_attributes.view(batch_size, customers_count*customers_count, 1)
                
                nodes.to(device)
                edge_attributes.to(device)
                
                #print(nodes.size(), self.batch_size)
                
                
                dyna_env = env(datas, nodes, customer_mask, edge_attributes)
                
                actions, logps, rewards = self.agent.old_policy.act(dyna_env)
                
                ## formate the actions for memory
                actions = formate_old_actions(actions)
                actions = torch.tensor(actions)
                actions = actions.permute(1, 0, 2)
                
                actions = actions.to(torch.device('cpu')).detach()
                logps   = logps.to(torch.device('cpu')).detach()
                rewards = rewards.to(torch.device('cpu')).detach()
                
                for i in range(self.batch_size):
                    memory.nodes.append(nodes[i])
                    memory.edge_attributes.append(edge_attributes[i])
                    memory.rewards.append(rewards[i])
                    memory.log_probs.append(logps[i])
                    memory.actions.append(actions[i])
                    
                if (batch_index+1) % self.update_timestep == 0:
                    self.agent.update(memory, i, datas, env)
                    memory.clear()
                    
                time_Space = 100
                rewards1.append(torch.mean(rewards.detach()).item())
                
                if (batch_index+1) % time_Space == 0:
                    end = time.time()
                    times.append(end - start)
                    start = end
                    mean_reward = np.mean(rewards1[-time_Space:])
                    print('  Batch %d/%d, reward: %2.3f,took: %2.4fs' %(batch_idx, len(data_loader), 
                                                                        mean_reward, times[-1]))
                    
                    
                ### TODO: test epoch code/function    
                
    
def train():
        
    class RunBuilder():
        @staticmethod

        def get_runs(params):

            Run = namedtuple('Run', params.keys())
            runs = []
            for v in product(*params.values()):
                runs.append(Run(*v))
            return runs

    params = OrderedDict(customer_feature = [4],
             vehicle_feature = [8],
             customers_count = [21],
             model_size = [128],
             encoder_layer = [3],
             num_head = [8],
             ff_size_actor = [128],
             ff_size_critic = [512],
             tanh_xplor = [10],
             edge_embedding_dim = [64],
             greedy = [False],
             learning_rate = [3e-4],
             ppo_epoch = [3],
             batch_size = [4],
             entropy_value = [0.2],
             epsilon_clip = [0.2],
             epoch = [4],
             timestep = [2])  


    runs = RunBuilder.get_runs(params)

    for customer_feature, vehicle_feature, customers_count, model_size, \
        encoder_layer, num_head, ff_size_actor, ff_size_critic, tanh_xplor, edge_embedding_dim, greedy,\
        learning_rate, ppo_epoch, batch_size, entropy_value, epsilon_clip, epoch, timestep in runs:

        data = DVRPSR_Dataset.create_data(batch_size=batch_size*2, vehicle_speed=1/3)
        data.normalize()
        env = DVRPSR_Environment
        
        trainppo = TrainPPOAgent(customer_feature, vehicle_feature, customers_count, model_size,
                                 encoder_layer, num_head, ff_size_actor, ff_size_critic, tanh_xplor, 
                                 edge_embedding_dim, greedy,learning_rate, ppo_epoch, batch_size, 
                                 entropy_value, epsilon_clip, epoch, timestep)
        
        trainppo.run_train(data, env, batch_size)
        

   

train()           

In [None]:
# train_data = []
# for i in range(4):
#     datas = Data(nodes = data.nodes[i],
#                  customer_mask = data.customer_mask[i] if data.customer_mask is not None else 'None',
#                  edge_attributes = data.edges_attributes[i])
#     train_data.append(datas)
    
train_data_loader = DataLoader(data, batch_size=4, shuffle=True)
print(train_data_loader)



In [None]:
for batch_index, minibatch in enumerate(train_data_loader):
    print(minibatch[1])
                


In [None]:
    nodes, customer_mask, edge_attributes = minibatch.nodes, minibatch.customer_mask, minibatch.edge_attributes
    print(nodes.size(), customer_mask.size(), edge_attributes.size())

    nodes = nodes.view(batch_size, customers_count, 4)
    edge_attributes = edge_attributes.view(batch_size, customers_count*customers_count, 1)
    if customer_mask is not None:
        customer_mask = customer_mask.view(batch_size, customers_count, 1)
        customer_mask.to(device)