In [28]:
import torch
import random
import math
from torch.utils.data import Dataset

import pandas as pd
import numpy as np
import networkx as nx
import sys


In [29]:
class DVRPSR_Dataset(Dataset):

    customer_feature = 4 # customer features location (x_i,y_i) and duration of service(d), appearance (u)

    @classmethod
    def create_data(cls,
                    batch_size = 2,
                    vehicle_count = 2,
                    vehicle_speed = 20, # km/hr
                    Lambda = 0.025, # request rate per min
                    dod = 0.5,
                    horizon = 400,
                    fDmean = 10,
                    fDstd = 2.5):


        # static customer counts V = Lambda*horizon*(1-dod)/(dod+0.5)
        V_static = int(Lambda*horizon*(1-dod)/(dod)+0.5)

        # total customer count
        V = int(Lambda*horizon/(dod) + 0.5)

        size = (batch_size, V, 1)
        
        # initialize the graph of vienna network
        graph = cls.initialize_graph()

        # get the coordinates of customers
        data_vienna = pd.read_csv('vienna_cordinates.csv')

        # get depot coordinates: Id, xcoords, ycoords
        depot = cls.get_depot_location(data_vienna)

        # get location of customers: id, xcoords, ycoords
        locations = cls.get_customers_coordinates(data_vienna, batch_size, V, depot)

        # get edges index and attributes, which is distance between one node to others n_i*n_j
        edges_index, edges_attributes = cls.get_edges_attributes(batch_size, graph, depot, locations, V)
        
        ### generate Static_Dynamic customer requests
        dynamic_request = cls.generateRandomDynamicRequests(batch_size,
                                                            V,
                                                            V_static,
                                                            fDmean,
                                                            fDstd,
                                                            Lambda,
                                                            horizon)

        customers = torch.zeros((batch_size,V,cls.customer_feature))
        customers[:,:,:2] = locations[:,:,1:]
        customers[:,:,2:4] = dynamic_request



        depo = torch.zeros((batch_size, 1, cls.customer_feature))
        depo[:,:,0:2] = torch.from_numpy(depot[0][1:])
        depo[:,:,2] =  0

        nodes = torch.cat((depo, customers), 1)
        
        dataset = cls(vehicle_count, vehicle_speed, horizon, nodes, V, 
                      edges_index, edges_attributes, customer_mask = None)
        
        return dataset
    
    def __init__(self, vehicle_count, vehicle_speed, horizon, nodes, V,
                 edges_index, edges_attributes, customer_mask=None):
        
        self.vehicle_count = vehicle_count
        self.vehicle_speed = vehicle_speed
        self.nodes = nodes
        self.vehicle_time_budget = horizon
        self.edges_index = edges_index
        self.edges_attributes = edges_attributes

        self.batch_size, self.nodes_count, d = self.nodes.size()

        if d!= self.customer_feature:
            raise ValueError("Expected {} customer features per nodes, got {}".format(
                self.customer_feature, d))

        self.customer_mask = customer_mask
        self.customer_count = V
        
        
    def initialize_graph():
    
        coordinates = pd.read_csv("vienna_dist.csv", header = None, sep=' ')
        coordinates.columns = ['coord1','coord2','dist']
        graph = nx.DiGraph()

        # add the rows to the graph for shortest path and distance calculations
        for _, row in coordinates.iterrows():
            graph.add_edge(row['coord1'], row['coord2'], weight=row['dist'])

        return graph


    def precompute_shortest_path(graph, start_node, end_node):

        shortest_path = nx.shortest_path(graph, start_node, end_node)

        # TODO: distance need to be normalized afterwords
        shortest_path_length = sum(graph.get_edge_data(u, v)['weight'] 
                                   for u, v in zip(shortest_path, shortest_path[1:]))

        return shortest_path, shortest_path_length 
    
    
    def get_distanceLL(lat1, lon1, lat2, lon2):
    
        R = 6371  # Radius of the Earth in kilometers

        lat1_rad = math.radians(lat1)
        lon1_rad = math.radians(lon1)
        lat2_rad = math.radians(lat2)
        lon2_rad = math.radians(lon2)

        dlat = lat2_rad - lat1_rad
        dlon = lon2_rad - lon1_rad

        a = math.sin(dlat / 2) ** 2 + math.cos(lat1_rad) * math.cos(lat2_rad) * math.sin(dlon / 2) ** 2
        c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
        distance = R * c
        return distance
    

    def get_NearestNodeLL(lat, lon, lats, lons):
        nearest = (-1, sys.float_info.max)
        for i in range(len(lats)):
            dist = DVRPSR_Dataset.get_distanceLL(lat, lon, lats[i], lons[i])
            if dist < nearest[1]:
                nearest = (i, dist)
        return nearest[0]
    


    def get_depot_location(data_vienna):

        ll = (48.178808, 16.438460)
        lat = ll[0] / 180 * math.pi
        lon = ll[1] / 180 * math.pi
        lats = data_vienna['lats']
        lons = data_vienna['lons']
        depot = DVRPSR_Dataset.get_NearestNodeLL(lat, lon, lats, lons)
        depot_coordinates = np.array(data_vienna[data_vienna['id']==depot][['id','xcoords', 'ycoords']])

        return depot_coordinates
    
    def get_customers_coordinates(data_vienna, batch_size, customers_count, depot):
        
        torch.manual_seed(42)

        # Excluding depot id from the customers selection
        data_vienna_without_depot = data_vienna[data_vienna['id'] != int(depot[0][0])].reset_index()

        # Sample customers indices for all batches at once
        sampled_customers = torch.multinomial(torch.tensor(data_vienna_without_depot['id'], dtype=torch.float32),
                                              num_samples=batch_size * customers_count, replacement=True)
        
        sampled_customers = sampled_customers.reshape(batch_size, customers_count)

        # Gather the sampled locations using the indices
        sampled_locations = data_vienna_without_depot.loc[sampled_customers.flatten()].reset_index(drop=True)

        # Reshape the locations to match the batch size
        locations = sampled_locations.groupby(sampled_locations.index // customers_count)

        # Create PyTorch tensors for the batched data
        locations_tensors = []
        for _, batch in locations:
            id_tensor = torch.tensor(batch['id'].values, dtype=torch.long)
            coords_tensor = torch.tensor(batch[['xcoords', 'ycoords']].values, dtype=torch.float32)
            batch_tensor = torch.cat((id_tensor.unsqueeze(1), coords_tensor), dim=1)
            locations_tensors.append(batch_tensor)

        return torch.stack(locations_tensors)
    
    def c_dist(x1,x2):
        return ((x1[0]-x2[0])**2+(x1[1]-x2[1])**2)**0.5
    
    def get_edges_attributes(batch_size, graph, depot, locations, V):
    
        # all customers ID inclusing depot
        
        print('Initialzing edges')
        edge_depot = torch.zeros((batch_size, 1, 2))
        edge_depot[:,:,0] = depot[0][1]
        edge_depot[:,:,1] = depot[0][2]
        edge_data = torch.cat((edge_depot, locations[:,:,1:3]), dim=1)

        # generate edge index
        edges_index = []

        for i in range(V+1):
            for j in range(V+1):
                edges_index.append([i, j])
        edges_index = torch.LongTensor(edges_index)
        edges_index = edges_index.transpose(dim0=0,dim1=1)

        # generate nodes attributes
        edges_batch = []

        for batch in edge_data:
            edges = torch.zeros((V+1, V+1, 1), dtype=torch.float32)
            for i, node1 in enumerate(batch):
                for j, node2 in enumerate(batch):
                    distance = DVRPSR_Dataset.c_dist(node1, node2)
                    edges[i][j][0] = distance

            edges = edges.reshape(-1, 1)
            edges_batch.append(edges)

        return edges_index, torch.stack(edges_batch)
    
    
    def generateRandomDynamicRequests(batch_size=2 ,
                                      V=20,
                                      V_static=10,
                                      fDmean=10,
                                      fDstd=2.5,
                                      Lambda=0.025,
                                      horizon=400,
                                      dep = 0,
                                      u = 0):
        gen = random.Random()
        gen.seed() # uses the default system seed
        unifDist = gen.random # uniform distribution
        durDist = lambda: max(0.01, gen.gauss(fDmean, fDstd)) # normal distribution with fDmean and fDstd

        # TODO: in actual data , we need to add a depo node with corrdinate, which should be removed from selected
        #       nodes as well.

        requests = []
        for b in range(batch_size):

            static_request = []
            dynamic_request = []
            u = 0

            while True:
                unif = unifDist()
                u += -(1/Lambda) * math.log(unif)
                if u > horizon or len(dynamic_request) > (V-V_static+2):
                    break
                d = round(durDist(),2)
                while d<=0:
                    d = round(durDist(),2)

                dynamic_request.append([d, round(u,2)])

            for j in range(V-len(dynamic_request)):
                d = round(durDist(),2)
                while d<=0:
                    d = round(durDist(),2)
                static_request.append([d,0])

            request = static_request+dynamic_request
            random.shuffle(request)
            requests.append(request)

        return torch.tensor(requests)
    
    

    
    def __len__(self):
        return self.batch_size


    def __getitem__(self, i):
        if self.customer_mask is None:
            return self.nodes[i]
        else:
            return self.nodes[i], self.customer_mask[i]

    def nodes_generate(self):
        if self.customer_mask is None:
            yield from self.nodes
        else:
            yield from (n[m^1] for n,m in zip(self.nodes, self.customer_mask))  
            
            
    def normalize(self):
        loc_max, loc_min = self.nodes[:,:,:2].max().item(), self.nodes[:,:,:2].min().item()
        loc_max -= loc_min
        edge_max_length = self.edges_attributes[:,:,0].max().item()

        self.nodes[:,:,:2] -= loc_min
        self.nodes[:,:,:2] /= loc_max
        self.nodes[:,:,2:] /=self.vehicle_time_budget

        self.vehicle_speed *= self.vehicle_time_budget/edge_max_length
        self.vehicle_time_budget = 1
        self.edges_attributes /= edge_max_length
        return loc_max, 1

    def save(self, folder_path):
        torch.save({
            'veh_count':self.vehicle_count,
            'veh_speed':self.vehicle_speed,
            'nodes':self.nodes,
            'edges_index':self.edges_index,
            'edges_attributes':self.edges_attributes,
            'customer_count':self.customer_count,
            'customer_mask':self.customer_mask
        }, folder_path)

    @classmethod
    def load(cls, folder_path):
        return cls(**torch.load(folder_path))
        
        
        
        

In [30]:
data = DVRPSR_Dataset.create_data(batch_size=2, vehicle_count=2,vehicle_speed=1/3, Lambda=0.2, dod=0.75, horizon=600)

Initialzing edges


## Environment

In [31]:
import torch


class DVRPSR_Environment:
    vehicle_feature = 8  # vehicle coordinates(x_i,y_i), veh_time_time_budget, total_travel_time, last_customer,
                         # next(destination) customer, last rewards, next rewards
    customer_feature = 4

    # TODO: change pending cost for rewards

    def __init__(self, data, nodes=None, customer_mask=None, 
                 pending_cost=1, 
                 dynamic_reward=0.2,
                 budget_penalty = 10):

        self.vehicle_count = data.vehicle_count
        self.vehicle_speed = data.vehicle_speed
        self.vehicle_time_budget = data.vehicle_time_budget

        self.nodes = data.nodes if nodes is None else nodes
        self.edge_index = data.edges_index
        self.edge_attributes = data.edges_attributes
        self.init_customer_mask = data.customer_mask if customer_mask is None else customer_mask

        self.minibatch, self.nodes_count, _ = self.nodes.size()
        self.distance_matrix = data.edges_attributes.view((self.minibatch, self.nodes_count, self.nodes_count))
        self.pending_cost = pending_cost
        self.dynamic_reward = dynamic_reward
        self.budget_penalty = budget_penalty

    def _update_current_vehicles(self, dest, customer_index, tau=0):

        # calculate travel time
        # TODO: 1) in real world setting we need to calculate the distance of arc
        # If nodes i and j are directly connected by a road segment (i, j) ∈ A, then t(i,j)=t_ij;
        # otherwise, t(i,j)=t_ik1 +t_k1k2 +...+t_knj, where k1,...,kn ∈ V are the nodes along the
        # shortest path from node i to node j.
        #      2) calculate stating time for each vehicle $\tau $, currently is set to zero
        
        # update vehicle previous and next customer id
        self.current_vehicle[:, :, 4] = self.current_vehicle[:, :, 5]
        self.current_vehicle[:, :, 5] = customer_index
        
        # get the distance from current vehicle to its next destination
        dist = torch.zeros((self.minibatch, 1))
        for i in range(self.minibatch):
            dist[i, 0] = self.distance_matrix[i][int(self.current_vehicle[i, :, 4])][int(self.current_vehicle[i, :, 5])]
            
        
        # total travel time    
        tt = dist / self.vehicle_speed

        # customers which are dynamicaly appeared
        dyn_cust = (dest[:, :, 3] > 0).float()

        # budget left while travelling to destination nodes
        budget = tau + tt + dest[:, :, 2]
        # print(budget, tau, tt, dest[:,:,2])

        # update vehicle features based on destination nodes
        self.current_vehicle[:, :, :2] = dest[:, :, :2]
        self.current_vehicle[:, :, 2] -= budget
        self.current_vehicle[:, :, 3] += tt
        self.current_vehicle[:, :, 6] = self.current_vehicle[:, :, 7]
        self.current_vehicle[:, :, 7] = -dist

        # update vehicles states
        self.vehicles = self.vehicles.scatter(1,
                                              self.current_vehicle_index[:, :, None].expand(-1, -1, self.vehicle_feature),
                                              self.current_vehicle)

        return dist, dyn_cust

    def _done(self, customer_index):
        
        self.vehicle_done.scatter_(1, self.current_vehicle_index, torch.logical_or((customer_index == 0), 
                                                                     (self.current_vehicle[:, :, 2] <= 0)))
        # print(self.veh_done, cust_idx==0,self.cur_veh[:,:,2]<=0, (cust_idx==0) | (self.cur_veh[:,:,2]<=0))
        self.done = bool(self.vehicle_done.all())

    def _update_mask(self, customer_index):

        self.new_customer = False
        self.served.scatter_(1, customer_index, customer_index > 0)

        # cost for a vehicle to go to customer and back to deport considering service duration
        cost = torch.zeros((self.minibatch, self.nodes_count,1))
        for i in range(self.minibatch):
            for j in range(self.nodes_count):
                dist_vehicle_customer_depot = self.distance_matrix[i][int(self.current_vehicle[i, :, 4])][j] + \
                                              self.distance_matrix[i][j][0]
                cost[i,j] = dist_vehicle_customer_depot
                
        cost = cost / self.vehicle_speed

        cost += self.nodes[:, :, None, 2]

        overtime_mask = self.current_vehicle[:, :, None, 2] - cost
        overtime_mask = overtime_mask.squeeze(2).unsqueeze(1)
        overtime = torch.zeros_like(self.mask).scatter_(1,
                                                        self.current_vehicle_index[:, :, None].expand(-1, -1, self.nodes_count),
                                                        overtime_mask < 0)

        self.mask = self.mask | self.served[:, None, :] | overtime | self.vehicle_done[:, :, None]
        self.mask[:, :, 0] = 0  # depot

    # updating current vehicle to find the next available vehicle
    def _update_next_vehicle(self):
        
        avail = self.vehicles[:, :, 3].clone()
        avail[self.vehicle_done] = float('inf')

        self.current_vehicle_index = avail.argmin(1, keepdim=True)
        self.current_vehicle = self.vehicles.gather(1, self.current_vehicle_index[:, :, None].expand(-1, -1, self.vehicle_feature))
        self.current_vehicle_mask = self.mask.gather(1, self.current_vehicle_index[:, :, None].expand(-1, -1, self.nodes_count))

    def _update_dynamic_customers(self):

        time = self.current_vehicle[:, :, 3].clone()

        if self.init_customer_mask is None:
            reveal_dyn_reqs = torch.logical_and((self.customer_mask), (self.nodes[:, :, 3] <= time))
        else:
            reveal_dyn_reqs = torch.logical_and((self.customer_mask ^ self.init_customer_mask), (self.nodes[:, :, 3] <= time))

        if reveal_dyn_reqs.any():
            self.new_customer = True
            self.customer_mask = self.customer_mask ^ reveal_dyn_reqs
            self.mask = self.mask ^ reveal_dyn_reqs[:, None, :].expand(-1, self.vehicle_count, -1)
            self.vehicle_done = torch.logical_and(self.vehicle_done, (reveal_dyn_reqs.any(1) ^ True).unsqueeze(1))
            self.vehicles[:, :, 3] = torch.max(self.vehicles[:, :, 3], time)
            self._update_next_vehicle()

    def reset(self):
        # reset vehicle (minibatch*veh_count*veh_feature)
        self.vehicles = self.nodes.new_zeros((self.minibatch, self.vehicle_count, self.vehicle_feature))
        self.vehicles[:, :, :2] = self.nodes[:, :1, :2]
        self.vehicles[:, :, 2] = self.vehicle_time_budget

        # reset vehicle done
        self.vehicle_done = self.nodes.new_zeros((self.minibatch, self.vehicle_count), dtype=torch.bool)
        self.done = False

        # reset cust_mask
        self.customer_mask = self.nodes[:, :, 3] > 0
        if self.init_customer_mask is not None:
            self.customer_mask = self.customer_mask | self.init_customer_mask

        # reset new customers and served customer since now to zero (all false)
        self.new_customer = True
        self.served = torch.zeros_like(self.customer_mask)

        # reset mask (minibatch*veh_count*nodes)
        self.mask = self.customer_mask[:, None, :].repeat(1, self.vehicle_count, 1)

        # reset current vehicle index, current vehicle, current vehicle mask
        self.current_vehicle_index = self.nodes.new_zeros((self.minibatch, 1), dtype=torch.int64)
        
        self.current_vehicle = self.vehicles.gather(1, 
                                                    self.current_vehicle_index[:, :, None].expand(-1, -1, self.vehicle_feature))
        self.current_vehicle_mask = self.mask.gather(1, 
                                             self.current_vehicle_index[:, :, None].expand(-1, -1, self.nodes_count))

    
    def step(self, customer_index):
        dest = self.nodes.gather(1, customer_index[:, :, None].expand(-1, -1, self.customer_feature))
        dist, dyn_cust = self._update_current_vehicles(dest, customer_index)

        #cust = (dest[:, :, 3] >= 0).float()

        self._done(customer_index)
        self._update_mask(customer_index)
        self._update_next_vehicle()

        #reward = -dist * (1 - dyn_cust*self.dynamic_reward)
        reward = self.current_vehicle[:, :, 7] - self.current_vehicle[:,:,6] + self.dynamic_reward*dyn_cust
        pending_static_customers = torch.logical_and((self.served ^ True), 
                                                     (self.nodes[:, :, 3] == 0)).float().sum(-1,keepdim=True) - 1
        
        reward -= self.pending_cost*pending_static_customers

        if self.done:

            if self.init_customer_mask is not None:
                self.served += self.init_customer_mask
            # penalty for pending customers
            pending_customers = torch.logical_and((self.served ^ True), 
                                                  (self.nodes[:, :, 3] >= 0)).float().sum(-1, keepdim=True) - 1

            # TODO: penalty for having unused time budget as well not serving customers
            reward -= self.dynamic_reward*pending_customers
            

        self._update_dynamic_customers()

        return reward

    def state_dict(self, dest_dict=None):
        if dest_dict is None:
            dest_dict = {'vehicles': self.vehicles,
                         'vehicle_done': self.vehicle_done,
                         'served': self.served,
                         'mask': self.mask,
                         'current_vehicle_index': self.current_vehicle_index}

        else:
            dest_dict["vehicles"].copy_(self.vehicles)
            dest_dict["vehicle_done"].copy_(self.vehicle_done)
            dest_dict["served"].copy_(self.served)
            dest_dict["mask"].copy_(self.mask)
            dest_dict["current_vehicle_index"].copy_(self.current_vehicle_index)

        return dest_dict

    def load_state_dict(self, state_dict):
        self.vehicles.copy_(state_dict["vehicles"])
        self.vehicle_done.copy_(state_dict["vehicle_done"])
        self.served.copy_(state_dict["served"])
        self.mask.copy_(state_dict["mask"])
        self.current_vehicle_index.copy_(state_dict["current_vehicle_index"])

        self.current_vehicle = self.vehicles.gather(1, 
                                                    self.current_vehicle_index[:, :, None].expand(-1, -1, self.vehicle_feature))
        self.current_vehicle_mask = self.mask.gather(1, self.current_vehicle_index[:, :, None].expand(-1, -1, self.customer_feature))




In [35]:
data.normalize()
env = DVRPSR_Environment(data)
env.reset()

In [36]:
#env.step(torch.tensor([[1],[2],[3],[4]]))

In [37]:
env.edge_index.size()

torch.Size([2, 25921])

In [38]:
env.minibatch

2

## Encoder 

In [39]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
import math
from torch.distributions.categorical import Categorical
from torch.optim.lr_scheduler import LambdaLR
import time

In [40]:
INIT = True


class GatConv(MessagePassing):
    
    def __init__(self, in_channel, out_channel, edge_channel,
                negative_slope = 0.2, dropout = 0):
        
        super(GatConv, self).__init__(aggr='sum')
        
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.edge_channel = edge_channel
        
        self.negative_slope = negative_slope
        self.dropout = dropout
        
        self.fc = nn.Linear(self.in_channel, self.out_channel)
        self.attention = nn.Linear(2*self.out_channel + self.edge_channel , self.out_channel)
        
        if INIT:
            
            for name, p in self.named_parameters():
                if 'weight' in name:
                    if len(p.size()) >= 2:
                        nn.init.orthogonal_(p, gain=1)
                elif 'bias' in name:
                    nn.init.constant_(p, 0)
                    
    def forward(self, x, edge_index, edge_attributes, customer_mask = None, size = None):
        x = self.fc(x)
        return self.propagate(edge_index, size=size, x=x, edge_attributes = edge_attributes)


    def message(self, edge_index_i, x_i, x_j, edge_attributes):

        x = torch.cat([x_i, x_j, edge_attributes], dim=1)
        alpha = self.attention(x)
        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, edge_index_i)

        # Sample attention coefficients stochastically.
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        return x_j * alpha


    def update(self, aggr_out):
        return aggr_out


        

        
class GatEncoder(nn.Module):
    
    def __init__(self, input_node_dim, hidden_node_dim, input_edge_dim, hidden_edge_dim, conv_layers = 3):
        
        super(GatEncoder, self).__init__()
        
        self.hidden_node_dim = hidden_node_dim
        self.fc_node = nn.Linear(input_node_dim, hidden_node_dim)
        self.fc_depot = nn.Linear(input_node_dim, hidden_node_dim)
        
        self.fc_edge = nn.Linear(input_edge_dim, hidden_edge_dim)
        
        self.bn_node = nn.BatchNorm1d(hidden_node_dim)
        self.bn_edge = nn.BatchNorm1d(hidden_edge_dim)
        
        self.convs = nn.ModuleList(
                     [GatConv(hidden_node_dim, hidden_node_dim, hidden_edge_dim) for i in range(conv_layers)])
            
        if INIT:
            for name, p in self.named_parameters():
                if 'weight' in name:
                    if len(p.size()) >= 2:
                        nn.init.orthogonal_(p, gain=1)
                elif 'bias' in name:
                    nn.init.constant_(p, 0)
            
    def forward(self, env, mask=None):
            
        batch_size = env.minibatch
        nodes_count = env.nodes_count
        
        x = env.nodes

        customer_embed = torch.cat((self.fc_depot(x[:, :1, :]),
                                    self.fc_node(x[:, 1:, :])), dim=1)
        if mask is not None:
            customer_embed[mask] = 0
            
        x = customer_embed.view(batch_size*nodes_count, self.hidden_node_dim)
        x = self.bn_node(x)

        edge_attributes = self.fc_edge(env.edge_attributes.view(batch_size*nodes_count*nodes_count,1).float())
        edge_attributes = self.bn_edge(edge_attributes)
        edge_index = env.edge_index.repeat(batch_size,1).view(2,batch_size*nodes_count*nodes_count)

        for conv in self.convs:
            x1 = conv(x, edge_index, edge_attributes)
            x = x + x1

        x = x.reshape((batch_size, -1, self.hidden_node_dim))

        return x

            
    

In [41]:
device = torch.device('cpu')
encoder = GatEncoder(input_node_dim = 4, hidden_node_dim=128, input_edge_dim=1, hidden_edge_dim=16, conv_layers=3)
encoder.to(device)

GatEncoder(
  (fc_node): Linear(in_features=4, out_features=128, bias=True)
  (fc_depot): Linear(in_features=4, out_features=128, bias=True)
  (fc_edge): Linear(in_features=1, out_features=16, bias=True)
  (bn_node): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn_edge): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (convs): ModuleList(
    (0-2): 3 x GatConv()
  )
)

In [42]:
env.nodes = env.nodes.to(device)
env.edge_attributes = env.edge_attributes.to(device)
env.edge_index = env.edge_index.to(device)
x = encoder(env)

In [43]:
mm

NameError: name 'mm' is not defined

## Graph Encoder

In [44]:
class GraphMultiHeadAttention(nn.Module):
    
    def __init__(self, num_head, query_size, key_size = None, value_size = None, edge_dim_size = None, bias = False):
        
        super(GraphMultiHeadAttention, self).__init__()
        self.num_head = num_head
        self.query_size = query_size
        
        self.key_size = self.query_size if key_size is None else key_size
        self.value_size = self.key_size if value_size is None else value_size
        self.edge_dim_size = self.query_size//2 if edge_dim_size is None else edge_dim_size
        
        self.scaling_factor = self.key_size**-0.5
        
        self.keys_per_head = self.key_size // self.num_head
        self.values_per_head = self.value_size // self.num_head
        self.edge_size_per_head = self.edge_dim_size
        
        self.edge_embedding = nn.Linear(self.edge_dim_size, self.edge_size_per_head, bias = bias)
        self.query_embedding = nn.Linear(self.query_size, self.num_head * self.keys_per_head, bias = bias)
        self.key_embedding = nn.Linear(self.key_size, self.num_head * self.keys_per_head, bias = bias)
        self.value_embedding = nn.Linear(self.value_size, self.num_head * self.values_per_head, bias = bias)
        self.recombine = nn.Linear(self.num_head * self.values_per_head, self.value_size, bias = bias)
        
        
        self.K_project_pre = None
        self.V_project_pre = None
        
        self.initialize_weights()


    def initialize_weights(self):
        #TODO: add xavier initialziation as well
        
        nn.init.uniform_(self.query_embedding.weight, -self.scaling_factor, self.scaling_factor)
        nn.init.uniform_(self.key_embedding.weight, -self.scaling_factor, self.scaling_factor)
        inv_sq_dv = self.value_size ** -0.5
        nn.init.uniform_(self.value_embedding.weight, -inv_sq_dv, inv_sq_dv)
        
    def precompute(self, keys, values = None):
        
        values = keys if values is None else values
        
        size_KV = keys.size(-2)
        
        self.K_project_pre = self.key_embedding(keys).view(
                             -1, size_KV, self.num_head, self.keys_per_head).permute(0, 2, 3, 1)
        
        self.V_project_pre = self.value_embedding(values).view(
                              -1, size_KV, self.num_head, self.values_per_head).permute(0, 2, 1, 3)
        
        
    def forward(self, queries, keys = None, values = None, edge_attributes = None, mask = None, edge_mask=None):
        
        *batch_size, size_Q, _ = queries.size()
        
        # get queries projection
        Q_project = self.query_embedding(queries).view(
                              -1, size_Q, self.num_head, self.keys_per_head).permute(0, 2, 1, 3)
        
        # get keys projection
        if keys is None:
            if self.K_project_pre is None:
                size_KV = size_Q
                K_project = self.key_embedding(queries).view(
                                -1, size_KV, self.num_head, self.keys_per_head).permute(0, 2, 3, 1)
            else:
                size_KV = self.K_project_pre.size(-1)
                K_project = self.K_project_pre
        else:
            size_KV = keys.size(-2)
            K_project = self.key_embedding(keys).view(
                            -1, size_KV, self.num_head, self.keys_per_head).permute(0, 2, 3, 1)
            
         # get values projection   
        if values is None:
            if self.V_project_pre is None:
                V_project = self.value_embedding(queries).view(
                                -1, size_KV, self.num_head, self.values_per_head).permute(0, 2, 1, 3)
            else:
                V_project = self.V_project_pre
        else:
            V_project = self.value_embedding(values).view(
                            -1, size_KV, self.num_head, self.values_per_head).permute(0, 2, 1, 3)
            
        
        # calculate the compability
        attention = Q_project.matmul(K_project)
        attention *= self.scaling_factor
        
        # if edge attributes are required
        if edge_attributes is not None:
                
            #TODO: edge mask (is it required)
            edge_project = self.edge_embedding(edge_attributes).view(
                                -1, size_Q, size_Q, self.edge_size_per_head)
                
            # get enhanced attention inclusing edge attributes
            attention_expanded = attention.unsqueeze(-1).expand(-1, -1, -1, -1, self.edge_size_per_head)
            
            # Expand edge attributes to match the number of attention heads
            edge_project_expanded = edge_project.unsqueeze(1).expand(-1, attention.size(1), -1, -1, -1)
            
            attention = attention_expanded * edge_project_expanded
            attention = attention.mean(-1)
            
            #print(attention.size())
        
            
            
        if mask is not None:

            if mask.numel() * self.num_head == attention.numel():
                m = mask.view(-1, 1, size_Q, size_KV).expand_as(attention)
            else:
                m = mask.view(-1, 1, 1, size_KV).expand_as(attention)

            attention[m.bool()] = -float('inf')
            
        attention = F.softmax(attention, dim=-1)
        attention = attention.matmul(V_project).permute(0, 2, 1, 3).contiguous().view(
                                                *batch_size, size_Q, self.num_head * self.values_per_head)
        
        output = self.recombine(attention)
        
        return output
        
             

In [45]:
class GraphEncoderlayer(nn.Module):
    
    def __init__(self, num_head, model_size, ff_size):
        super().__init__()
        
        self.attention = GraphMultiHeadAttention(num_head, query_size=model_size)
        self.BN1 = nn.BatchNorm1d(model_size)
        self.FFN_layer1 = nn.Linear(model_size, ff_size)
        
        self.FFN_layer2 = nn.Linear(ff_size, model_size)
        self.BN2 = nn.BatchNorm1d(model_size)
        
    def forward(self, h, e = None, mask = None):
        
        h_attn = self.attention(h, edge_attributes = e, mask=mask)
        h_bn = self.BN1((h_attn + h).permute(0, 2, 1)).permute(0, 2, 1)
        
        h_layer1 = F.relu(self.FFN_layer1(h_bn))
        h_layer2 = self.FFN_layer2(h_layer1)
        
        h_out = self.BN2((h_bn + h_layer2).permute(0, 2, 1)).permute(0, 2, 1)
        
        if mask is not None:
            h_out[mask] = 0
            
        return h_out
    
class GraphEncoder(nn.Module):
    
    def __init__(self, encoder_layer, num_head, model_size, ff_szie):
        super().__init__()
        
        for l in range(encoder_layer):
            self.add_module(str(l), GraphEncoderlayer(num_head, model_size, ff_szie))
            
    def forward(self, h_in, e_in = None,  mask=None):
        
        h = h_in
        e = e_in
        
        for child in self.children():
            h = child(h, e, mask=mask)
        return h
        
        
        
          

In [46]:
customer_encoder = GraphEncoder(encoder_layer=3, num_head=8, model_size=128, ff_szie=256)

In [47]:
cust_feature = 4
edge_feature = 1
model_size = 128
edge_embedding_dim = 64
customers = env.nodes
edges = env.edge_attributes
customer_mask = env.customer_mask
vehicle_mask = env.current_vehicle_mask

In [48]:
edges.size()

torch.Size([2, 25921, 1])

In [49]:
depot_embedding = nn.Linear(cust_feature, model_size)
cust_embedding = nn.Linear(cust_feature, model_size)
edge_embedding = nn.Linear(edge_feature, edge_embedding_dim)

cust_emb = torch.cat((
            depot_embedding(customers[:, 0:1, :]),
            cust_embedding(customers[:, 1:, :])), dim=1)

edge_emb = edge_embedding(edges)

if customer_mask is not None:
    cust_emb[customer_mask] = 0

edge_emb.size()

torch.Size([2, 25921, 64])

In [50]:
# edge_emb = edge_emb.view(env.minibatch, env.nodes_count, env.nodes_count, edge_embedding_dim)
# edge_emb.size()

In [51]:
cust_encoding = customer_encoder(cust_emb, edge_emb)

## Graph Model

In [114]:

class GatAgent(nn.Module):
    
    def __init__(self, customer_feature, vehicle_feature, model_size=128, layer_count=3, 
                 num_head=8, ff_size=128, tanh_xplor=10, greedy = False):
        super().__init__()
        
        self.model_size = model_size
        self.scaling_factor = self.model_size**0.5
        self.tanh_xplor = tanh_xplor
        self.greedy = greedy
        
        self.customer_encoder = GatEncoder(input_node_dim = customer_feature, # TODO: need to automate rest of values
                                           hidden_node_dim=ff_size, 
                                           input_edge_dim=1, 
                                           hidden_edge_dim=16, 
                                           conv_layers=layer_count)
        
        self.vehicle_embedding = nn.Linear(vehicle_feature, ff_size, bias=False)
        
        self.fleet_attention = nn.MultiheadAttention(self.model_size, num_head, 
                                                     kdim = self.model_size, 
                                                     vdim=self.model_size)
        self.vehicle_attention = nn.MultiheadAttention(self.model_size, num_head)
        
        self.customer_to_action_projection = nn.Linear(self.model_size, self.model_size) # TODO: MLP instaed of nn.Linear
        
    
    def encode_customers(self, env, customer_mask = None):
        
        self.customer_representation = self.customer_encoder(env, customer_mask)
        if customer_mask is not None:
            self.customer_representation[customer_mask] = 0
            
    
        
    def vehicle_representation(self, vehicles, vehicle_index, vehicle_mask=None):
        
        vehicles_embedding = self.vehicle_embedding(vehicles)
        
        print(vehicles_embedding.size(), self.customer_representation.size())
        
        fleet_representation, _ = self.fleet_attention(query = vehicles_embedding.permute(1,0, 2),
                                                       key = self.customer_representation.permute(1,0, 2),
                                                       value = self.customer_representation.permute(1,0, 2))
        
        vehicle_query = fleet_representation.gather(1, vehicle_index.unsqueeze(2).expand(-1, -1, self.model_size))
        
        print(vehicle_query.size(), fleet_representation.size())
        self.vehicle_representation_, _ = self.vehicle_attention(query = vehicle_query.permute(1,0,2),
                                                                key = fleet_representation.permute(1,0,2),
                                                                value = fleet_representation.permute(1,0,2))
        
        return self.vehicle_representation_
    
    
    
    def score_customers(self, vehicle_representation):
        
        compact = torch.bmm(vehicle_representation.permute(1,0,2),
                            self.customer_representation.permute(0,2,1))
        compact *= self.scaling_factor
        
        print(compact.size(), vehicle_representation.size())
        
        if self.tanh_xplor is not None:
            compact = self.tanh_xplor*compact.tanh()
        
        return compact
    
    
    def get_logp(self, compact, vehicle_mask = None):
        
        compact[vehicle_mask] = -float('inf')
        return compact.log_softmax(2).squeeze(1)
    
    
    def step(self, env):
        
        vehicle_representation_ = self.vehicle_representation(env.vehicles, 
                                                             env.current_vehicle_index,
                                                             env.current_vehicle_mask)
        compact = self.score_customers(vehicle_representation_)
        logp = self.get_logp(compact, env.current_vehicle_mask)
        
        if self.greedy:
            customer_index = logp.argmax(dim=1, keepdim=True)
        else:
            customer_index = logp.exp().multinomial(1)

        return customer_index, logp.gather(1, customer_index)
    
    
    
    def forward(self, env):
        
        env.reset()
        actions, logps, rewards = [], [], []
        
        while not env.done:
            if env.new_customer:
                self.encode_customers(env, env.customer_mask)
                
            customer_index, logp = self.step(env)
            actions.append((env.current_vehicle_index, customer_index))
            logps.append(logp)
            rewards.append(env.step(customer_index))

        return actions, logps, rewards
        
        

In [115]:
model = GatAgent(4,8)


In [116]:
actions, logps, rewards = model(env)
print("Forward pass ok for DVRPSR")

torch.Size([2, 2, 128]) torch.Size([2, 161, 128])
torch.Size([2, 1, 128]) torch.Size([2, 2, 128])
torch.Size([2, 1, 161]) torch.Size([1, 2, 128])
torch.Size([2, 2, 128]) torch.Size([2, 161, 128])
torch.Size([2, 1, 128]) torch.Size([2, 2, 128])
torch.Size([2, 1, 161]) torch.Size([1, 2, 128])
torch.Size([2, 2, 128]) torch.Size([2, 161, 128])
torch.Size([2, 1, 128]) torch.Size([2, 2, 128])
torch.Size([2, 1, 161]) torch.Size([1, 2, 128])
torch.Size([2, 2, 128]) torch.Size([2, 161, 128])
torch.Size([2, 1, 128]) torch.Size([2, 2, 128])
torch.Size([2, 1, 161]) torch.Size([1, 2, 128])
torch.Size([2, 2, 128]) torch.Size([2, 161, 128])
torch.Size([2, 1, 128]) torch.Size([2, 2, 128])
torch.Size([2, 1, 161]) torch.Size([1, 2, 128])
torch.Size([2, 2, 128]) torch.Size([2, 161, 128])
torch.Size([2, 1, 128]) torch.Size([2, 2, 128])
torch.Size([2, 1, 161]) torch.Size([1, 2, 128])
torch.Size([2, 2, 128]) torch.Size([2, 161, 128])
torch.Size([2, 1, 128]) torch.Size([2, 2, 128])
torch.Size([2, 1, 161]) to

In [117]:
actions

[(tensor([[0],
          [0]]),
  tensor([[ 0],
          [35]])),
 (tensor([[1],
          [1]]),
  tensor([[137],
          [ 22]])),
 (tensor([[0],
          [1]]),
  tensor([[ 72],
          [147]])),
 (tensor([[1],
          [0]]),
  tensor([[ 32],
          [132]])),
 (tensor([[0],
          [0]]),
  tensor([[ 26],
          [101]])),
 (tensor([[1],
          [1]]),
  tensor([[ 85],
          [148]])),
 (tensor([[0],
          [0]]),
  tensor([[55],
          [ 0]])),
 (tensor([[0],
          [1]]),
  tensor([[8],
          [0]])),
 (tensor([[1],
          [0]]),
  tensor([[  0],
          [102]])),
 (tensor([[0],
          [0]]),
  tensor([[63],
          [37]])),
 (tensor([[0],
          [1]]),
  tensor([[ 88],
          [128]])),
 (tensor([[1],
          [0]]),
  tensor([[0],
          [0]])),
 (tensor([[0],
          [0]]),
  tensor([[0],
          [0]]))]

In [118]:
rewards

[tensor([[-37.],
         [-40.]]),
 tensor([[-36.4398],
         [-39.5029]]),
 tensor([[-36.2398],
         [-38.6726]]),
 tensor([[-35.4363],
         [-36.6738]]),
 tensor([[-35.1854],
         [-36.1840]]),
 tensor([[-33.9776],
         [-35.9468]]),
 tensor([[-33.5333],
         [-35.9081]]),
 tensor([[-33.5521],
         [-35.8597]]),
 tensor([[-34.2660],
         [-35.8776]]),
 tensor([[-33.4706],
         [-35.8776]]),
 tensor([[-33.4493],
         [-35.8504]]),
 tensor([[-34.0729],
         [-36.0961]]),
 tensor([[-64.1527],
         [-65.8343]])]