In [None]:
import torch
import torch.nn as nn
import time
import argparse

import os
import datetime
import gc
from torch.distributions.categorical import Categorical
from torch.utils.data import DataLoader

import math
import numpy as np
import torch.nn.functional as F
import random
import torch.optim as optim
from torch.autograd import Variable
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook
from torch.utils.data import Dataset
from torch.autograd import Variable
import matplotlib
matplotlib.use('Agg')

# visualization 
%matplotlib inline
from IPython.display import set_matplotlib_formats, clear_output
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
# Restrict which physical GPU(s) this process may see. The variable has to be
# exported before the first CUDA call below for it to take effect.
gpu_id = '0'       # a single GPU
#gpu_id = '2,3'    # several GPUs
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

# Use CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print('GPU name: {:s}, gpu_id: {:s}'.format(torch.cuda.get_device_name(0),gpu_id))
else:
    device = torch.device("cpu")

print(device)

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.autograd import Variable
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


class VehicleRoutingDataset(Dataset):
    """Random routing instances with signed (positive / negative) demands.

    Every instance has ``input_size`` customer nodes plus a depot stored at
    node index 0. Half of the customers receive a positive integer demand in
    ``[1, max_demand]`` and the other half a negative one in
    ``[-max_demand, -1]``; demands are then divided by ``max_load``.

    Attributes
    ----------
    static : FloatTensor (num_samples, 2, input_size + 1)
        Uniform random coordinates in [0, 1); column 0 is the depot.
    dynamic : FloatTensor (num_samples, 2, input_size + 1)
        Row 0 is the initial load (all zeros), row 1 the normalised demand
        (zero for the depot).

    Raises
    ------
    ValueError
        If ``max_load < max_demand`` or ``input_size`` is odd (the demands are
        split into two equally sized halves).
    """

    def __init__(self, num_samples, input_size, max_load=20, max_demand=9):
        super(VehicleRoutingDataset, self).__init__()

        if max_load < max_demand:
            raise ValueError(':param max_load: must be >= max_demand')
        if input_size % 2 != 0:
            # An odd size would yield input_size - 1 demands and fail later
            # with an opaque shape mismatch when building `dynamic`.
            raise ValueError(':param input_size: must be even')

        self.num_samples = num_samples
        self.max_load = max_load
        self.max_demand = max_demand

        # Depot location is the first node of every instance.
        self.static = torch.rand((num_samples, 2, input_size + 1))

        # Initial load broadcast to every node; kept as float so that it can
        # be concatenated with the normalised demands without dtype promotion.
        loads = torch.zeros((num_samples, 1, input_size + 1))

        # Half positive, half negative integer demands.
        half = input_size // 2
        positive = torch.randint(1, max_demand + 1, (num_samples, half))
        negative = torch.randint(-1 * max_demand, 0, (num_samples, half))
        demands = torch.cat((positive, negative), dim=1).float()

        # Shuffle node positions independently for every sample: sorting
        # i.i.d. uniform keys yields one uniform random permutation per row.
        order = torch.rand(num_samples, input_size).argsort(dim=1)
        demands = demands.gather(1, order)

        # Prepend the zero-demand depot and normalise by the maximum load.
        depot = torch.zeros((num_samples, 1))
        demands = torch.cat((depot, demands), dim=1) / float(max_load)

        self.dynamic = torch.cat((loads, demands.unsqueeze(1)), dim=1)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # (static, dynamic, start_loc) where start_loc is the depot coordinate
        # kept as a (2, 1) slice.
        return (self.static[idx], self.dynamic[idx], self.static[idx, :, 0:1])


def reward_fn(static, tour_indices):
    """
    Total Euclidean length of each tour, starting and ending at node 0.

    static: (batch, feats, num_nodes) coordinates, node 0 being the depot.
    tour_indices: (batch, steps) long tensor of visited node indices.
    Returns a detached (batch,) tensor of tour lengths.
    """
    coords = static.data

    # Look up the coordinates of every visited node -> (batch, steps, feats)
    lookup = tour_indices.unsqueeze(1).expand(-1, static.size(1), -1)
    visited = coords.gather(2, lookup).permute(0, 2, 1)

    # Bracket the tour with the depot; a tour that already starts or ends
    # there only gains a zero-length leg.
    depot = coords[:, :, 0].unsqueeze(1)
    path = torch.cat((depot, visited, depot), dim=1)

    # Length of each leg between consecutive points
    legs = path[:, :-1] - path[:, 1:]
    leg_len = torch.sqrt(torch.sum(torch.pow(legs, 2), dim=2))

    return leg_len.sum(1).detach()


def render_fn(static, tour_indices, save_path):
    """Draw the tours of the first instances in the batch and save the figure.

    A 3x3 grid is used when the batch holds at least nine tours, otherwise a
    single panel showing the first tour. The image is written to `save_path`.
    """

    plt.close('all')
    grid = 3 if int(np.sqrt(len(tour_indices))) >= 3 else 1
    _, axes = plt.subplots(nrows=grid, ncols=grid,sharex='col', sharey='row')
    panels = [axes] if grid == 1 else [a for row in axes for a in row]

    for i, ax in enumerate(panels):

        # Gather the coordinates of the visited nodes for instance i
        tour = tour_indices[i]
        lookup = tour.unsqueeze(0) if len(tour.size()) == 1 else tour
        lookup = lookup.expand(static.size(1), -1)
        coords = torch.gather(static[i].data, 1, lookup).cpu().numpy()

        # Path = depot, visited nodes, depot
        depot = static[i, :, 0].cpu().data.numpy()
        x = np.hstack((depot[0], coords[0], depot[0]))
        y = np.hstack((depot[1], coords[1], depot[1]))

        # Every stretch between two depot visits is one sub-tour; each gets
        # its own line (and legend entry) in the order travelled.
        visits = np.hstack((0, tour_indices[i].cpu().numpy().flatten(), 0))
        depot_pos = np.where(visits == 0)[0]

        for j, (low, high) in enumerate(zip(depot_pos[:-1], depot_pos[1:])):
            if low + 1 == high:
                # consecutive depot visits: nothing to draw
                continue
            ax.plot(x[low: high + 1], y[low: high + 1], zorder=1, label=j)

        ax.legend(loc="upper right", fontsize=3, framealpha=0.5)
        ax.scatter(x, y, s=4, c='r', zorder=2)
        ax.scatter(x[0], y[0], s=20, c='k', marker='*', zorder=3)

        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight', dpi=200)

In [None]:
import torch
import time

def update_state(demand,dynamic_capcity,selected,c = 0):
    """Return the load after visiting `selected`.

    demand: (batch, num_nodes) signed demands.
    dynamic_capcity: (batch, 1) current load.
    selected: (batch, 1) long tensor with the node chosen in each instance.
    c: value the load is reset to for instances that selected node 0 (depot).

    The demand of the chosen node is subtracted from the load; the input
    tensor is left untouched and the result is detached, shape (batch, 1).
    """
    at_depot = selected.squeeze(-1) == 0
    served = demand.gather(1, selected)
    remaining = dynamic_capcity - served

    # Instances that went to the depot restart from the fixed load `c`
    if at_depot.any():
        remaining[at_depot] = c

    return remaining.detach()


def update_mask(demand,capcity,selected,mask,i):
    """Build the selection mask for the next decoding step.

    demand: (batch, num_nodes) signed demands.
    capcity: (batch, 1) current load.
    selected: (batch, 1) long tensor with the node just chosen.
    mask: (batch, num_nodes) visited mask from the previous step (not modified).
    i: index of the current decoding step.

    Returns (mask, mask1), both detached: `mask1` is the updated visited mask
    and `mask` is `mask1` plus the two feasibility masks, so any non-zero
    entry marks a node that must not be selected next.
    """
    left_depot = selected.squeeze(-1).ne(0)

    # Mark the chosen node as visited (out of place)
    visited = mask.scatter(1, selected.expand(mask.size(0), -1), 1)

    if (capcity > 1).any():
        print("warning")

    # The depot becomes selectable again for every instance that just
    # picked a non-depot node
    if left_depot.any():
        visited[left_depot, 0] = 0

    n_nodes = demand.size(1)
    if i+1 > (n_nodes / 2):
        # Instances whose non-depot nodes are all visited keep the depot open
        finished = visited[:, 1:].sum(1) >= (n_nodes - 1)
        visited[finished, 0] = 0

    # Nodes whose demand exceeds the current load (small tolerance)
    over_demand = demand > capcity + 1e-3
    # Nodes whose negative demand would push the load above the limit 1
    negative_part = torch.neg(demand.masked_fill(demand.gt(0), 0.))
    over_limit = negative_part + capcity > 1

    full_mask = over_demand + visited + over_limit
    return full_mask.detach(),visited.detach()

# HPN for Random PDP

In [None]:
class Transformer_encoder_net(nn.Module):
    """
    Stack of self-attention encoder layers (attention -> add & norm ->
    feed-forward -> add & norm), with either batch or layer normalisation.
    Inputs :
      h of size      (bsz, nb_nodes+1, dim_emb)    batch of input cities
    Outputs :
      h of size      (bsz, nb_nodes+1, dim_emb)    batch of encoded cities
      score of size  (bsz, nb_nodes+1, nb_nodes+1) attention scores of the last layer
    """

    def __init__(self, nb_layers, dim_emb, nb_heads, dim_ff, batchnorm):
        super(Transformer_encoder_net, self).__init__()
        assert dim_emb % nb_heads == 0 # dim_emb must be divisible by nb_heads
        layers = range(nb_layers)
        self.MHA_layers = nn.ModuleList( [nn.MultiheadAttention(dim_emb, nb_heads) for _ in layers] )
        self.linear1_layers = nn.ModuleList( [nn.Linear(dim_emb, dim_ff) for _ in layers] )
        self.linear2_layers = nn.ModuleList( [nn.Linear(dim_ff, dim_emb) for _ in layers] )
        # Same normalisation type after the attention and feed-forward blocks
        norm_cls = nn.BatchNorm1d if batchnorm else nn.LayerNorm
        self.norm1_layers = nn.ModuleList( [norm_cls(dim_emb) for _ in layers] )
        self.norm2_layers = nn.ModuleList( [norm_cls(dim_emb) for _ in layers] )
        self.nb_layers = nb_layers
        self.nb_heads = nb_heads
        self.batchnorm = batchnorm

    def _normalise(self, norm, h):
        """Apply `norm` to h of size (nb_nodes, bsz, dim_emb), keeping that layout."""
        if not self.batchnorm:
            return norm(h)
        # nn.BatchNorm1d expects (bsz, dim_emb, nb_nodes)
        h = norm(h.permute(1,2,0).contiguous())
        return h.permute(2,0,1).contiguous()

    def forward(self, h):
        # nn.MultiheadAttention (default layout) expects (seq_len, bsz, dim_emb)
        h = h.transpose(0,1) # size(h)=(nb_nodes, bsz, dim_emb)
        blocks = zip(self.MHA_layers, self.linear1_layers, self.linear2_layers,
                     self.norm1_layers, self.norm2_layers)
        for attention, lin_in, lin_out, norm_att, norm_ff in blocks:
            # self-attention + residual connection + normalisation
            attended, score = attention(h, h, h) # size(score)=(bsz, nb_nodes, nb_nodes)
            h = self._normalise(norm_att, h + attended)
            # feed-forward + residual connection + normalisation
            transformed = lin_out(torch.relu(lin_in(h)))
            h = self._normalise(norm_ff, h + transformed)
        # back to batch-first layout
        return h.transpose(0,1), score # size(h)=(bsz, nb_nodes, dim_emb)

class Attention(nn.Module):
    """Additive pointer attention: u = v^T tanh(Wq q + Wref ref)."""

    def __init__(self, n_hidden):
        super(Attention, self).__init__()
        self.size = 0
        self.batch_size = 0
        self.dim = n_hidden

        # scoring vector, uniform in [-1/sqrt(n_hidden), 1/sqrt(n_hidden)]
        bound = 1/math.sqrt(n_hidden)
        self.v = nn.Parameter(torch.empty(n_hidden).uniform_(-bound, bound))

        # projections of the references and of the query
        self.Wref = nn.Linear(n_hidden, n_hidden)
        self.Wq = nn.Linear(n_hidden, n_hidden)

    def forward(self, q, ref):
        """Score every reference against the query.

        q: (B, dim) query; ref: (B*size, dim) flattened references.
        Returns u of size (B, size) and the projected references (B, size, dim).
        The batch size and number of references are cached on the module.
        """
        self.batch_size = q.size(0)
        self.size = int(ref.size(0) / self.batch_size)

        query = self.Wq(q)                                                  # (B, dim)
        ref = self.Wref(ref).view(self.batch_size, self.size, self.dim)     # (B, size, dim)

        # the query is broadcast over the `size` axis
        hidden = torch.tanh(query.unsqueeze(1) + ref)                       # (B, size, dim)
        v_col = self.v.view(1, self.dim, 1).expand(self.batch_size, -1, -1) # (B, dim, 1)

        # (B, size, dim) x (B, dim, 1) -> (B, size)
        u = torch.bmm(hidden, v_col).squeeze(2)

        return u, ref
    
class LSTM(nn.Module):
    """Single LSTM cell whose gates also read the cell state.

    Every weight is an n_hidden -> n_hidden linear layer: W*x* act on the
    input, W*h* on the hidden state and w*c* on the cell state.
    """

    # Layer names in creation order: input gate, forget gate, cell candidate,
    # output gate.
    _LAYER_NAMES = (
        "Wxi", "Whi", "wci",
        "Wxf", "Whf", "wcf",
        "Wxc", "Whc",
        "Wxo", "Who", "wco",
    )

    def __init__(self, n_hidden):
        super(LSTM, self).__init__()
        for layer_name in self._LAYER_NAMES:
            setattr(self, layer_name, nn.Linear(n_hidden, n_hidden))

    def forward(self, x, h, c):
        """Advance one step; returns the new hidden and cell states (h, c)."""
        in_gate = torch.sigmoid(self.Wxi(x) + self.Whi(h) + self.wci(c))
        forget_gate = torch.sigmoid(self.Wxf(x) + self.Whf(h) + self.wcf(c))

        candidate = torch.tanh(self.Wxc(x) + self.Whc(h))
        c_next = forget_gate * c + in_gate * candidate

        # the output gate reads the *updated* cell state
        out_gate = torch.sigmoid(self.Wxo(x) + self.Who(h) + self.wco(c_next))
        h_next = out_gate * torch.tanh(c_next)

        return h_next, c_next

class HPN_PDP(nn.Module):
    """Pointer-network policy that builds a tour one node per decoding step.

    Node features are embedded, encoded by a transformer encoder and by three
    linear "GNN" layers, and a custom LSTM cell plus two pointer-attention
    heads (one per encoding) score the next node under a feasibility mask.

    Parameters
    ----------
    n_feature : int
        The embedding layer takes ``2 * n_feature`` inputs. `forward` feeds it
        the static features plus one demand and one distance-to-start column,
        i.e. ``n_feature = 2`` for 2-D coordinates.
    n_hidden : int
        Width of all hidden layers. Must be divisible by 8, the number of
        attention heads used by the transformer encoder.
    """

    def __init__(self, n_feature, n_hidden):
        super(HPN_PDP, self).__init__()

        self.city_size = 0
        self.batch_size = 0
        self.dim = n_hidden
        # pointer layers: one over the GNN context, one over the transformer output
        self.pointer = Attention(n_hidden)
        self.TransPointer = Attention(n_hidden)

        # LSTM cell driving the decoder
        self.encoder = LSTM(n_hidden)

        # trainable initial hidden / cell state
        bound = 1/math.sqrt(n_hidden)
        self.h0 = nn.Parameter(torch.empty(n_hidden))
        self.c0 = nn.Parameter(torch.empty(n_hidden))
        self.h0.data.uniform_(-bound, bound)
        self.c0.data.uniform_(-bound, bound)

        # trainable mixing weights of the three GNN layers
        self.r1 = nn.Parameter(torch.ones(1))
        self.r2 = nn.Parameter(torch.ones(1))
        self.r3 = nn.Parameter(torch.ones(1))

        # embedding for the feature tensor
        self.embedding_all = nn.Linear(2 * n_feature  , n_hidden)

        # fc consumes [node encoding, current load]; fc1 transforms the pooled graph encoding
        self.fc = nn.Linear(n_hidden + 1, n_hidden, bias=False)
        self.fc1 = nn.Linear(n_hidden, n_hidden, bias=False)

        # transformer encoder: 6 layers, 8 heads, feed-forward width 512.
        # The embedding width follows n_hidden (it used to be fixed at 128,
        # which only worked when n_hidden == 128).
        self.Transembedding_all = Transformer_encoder_net(6, n_hidden, 8, 512, batchnorm=True)

        # weights for GNN
        self.W1 = nn.Linear(n_hidden, n_hidden)
        self.W2 = nn.Linear(n_hidden, n_hidden)
        self.W3 = nn.Linear(n_hidden, n_hidden)

        # aggregation function for GNN
        self.agg_1 = nn.Linear(n_hidden, n_hidden)
        self.agg_2 = nn.Linear(n_hidden, n_hidden)
        self.agg_3 = nn.Linear(n_hidden, n_hidden)


    def forward(self,static, dynamic,deterministic = False,decoder_input = None):
        """
        Decode one tour per instance.

        Parameters
        ----------
        static: Tensor of size (batch_size, num_cities, feats)
            Node features that do not change while decoding, e.g. the (x, y)
            coordinates. Node 0 is the depot.
        dynamic: Tensor of size (batch_size, num_cities, 2)
            Feature 0 is the load (only the depot entry is read as the initial
            load), feature 1 is the signed demand of each node.
        deterministic: bool
            If True pick the most probable node at each step (greedy);
            otherwise sample from the predicted distribution.
        decoder_input: Tensor of size (batch_size, >=2) or None
            Reference coordinates; the distance of every node to its first two
            columns is appended to the node features. Defaults to the depot
            coordinates ``static[:, 0, :]`` when None.

        Returns
        -------
        tour_idx: LongTensor (batch_size, steps) with the selected node indices.
        tour_logp: Tensor (batch_size, steps) with the log-probability of each
            selection; it is zeroed for steps taken after an instance had
            already visited every non-depot node.
        """

        tour_idx, tour_logp = [], []
        self.batch_size, self.city_size, _ = static.size() # (B,size,feat)
        # Allocate helper tensors next to the inputs rather than on a
        # notebook-level global device.
        dev = static.device

        if decoder_input is None:
            decoder_input = static[:, 0, :]

        # mask1: visited nodes; mask: visited + infeasible nodes
        mask1 = torch.zeros((self.batch_size, self.city_size), device=dev)
        mask = torch.zeros((self.batch_size, self.city_size), device=dev)

        dynamic_capcity = dynamic[:,0,0].view(self.batch_size,-1) # (B,1)
        demands = dynamic[:,:,1].view(self.batch_size, self.city_size) # (B,size)

        # Broadcast the trainable initial hidden and cell state over the batch
        h = self.h0.unsqueeze(0).expand(self.batch_size, self.dim).contiguous()
        c = self.c0.unsqueeze(0).expand(self.batch_size, self.dim).contiguous()

        max_steps = 2 * self.city_size
        # Node features: static features, demand, distance to the reference point
        all_feature = torch.cat((static,demands.unsqueeze(2)),dim = 2)
        all_feature = torch.cat((all_feature,torch.cdist(static,decoder_input[:,:2].unsqueeze(1),p=2)),dim = 2)

        # init embedding for feature vector
        context = self.embedding_all(all_feature) # (B,size,n_hidden)
        # ==================================================
        # graph neural network encoder & transformer encoder
        # ==================================================
        Trans_hidden,_ = self.Transembedding_all(context) # (B,size,n_hidden)
        TransPooled  = Trans_hidden.mean(dim=1)
        context = context.reshape(-1, self.dim)           # (B*size,n_hidden)

        # GNN layers
        context = self.r1 * self.W1(context) + (1-self.r1) * F.relu(self.agg_1(context/(self.city_size-1)))
        context = self.r2 * self.W2(context) + (1-self.r2) * F.relu(self.agg_2(context/(self.city_size-1)))
        context = self.r3 * self.W3(context) + (1-self.r3) * F.relu(self.agg_3(context/(self.city_size-1)))
        contextPooled = context.reshape(self.batch_size,self.city_size,self.dim).mean(dim=1)
        pool = TransPooled + contextPooled

        # Every instance starts at the depot (node 0)
        index = torch.zeros(self.batch_size, dtype=torch.long, device=dev)

        for t in range(max_steps):
            # stop once every non-depot node of every instance is visited
            if not mask1[:, 1:].eq(0).any():
                break
            if t == 0:
                _input = Trans_hidden[:, 0, :]  # depot

            decoder_input = torch.cat((_input,dynamic_capcity),dim = 1)
            decoder_input = self.fc(decoder_input)

            # NOTE(review): fc1 is re-applied to `pool` at every step, so the
            # pooled vector is transformed t+1 times by step t. Kept as is so
            # that existing checkpoints behave the same - confirm it is intended.
            pool = self.fc1(pool)
            decoder_input = decoder_input + pool

            if t == 0:
                # initial mask for the start at the depot
                mask, mask1 = update_mask(demands, dynamic_capcity, index.unsqueeze(-1), mask1, t)

            # LSTM step
            h, c = self.encoder(decoder_input, h, c)
            # pointer scores over both encodings, clipped to [-10, 10]
            u1, _ = self.pointer(h, context.reshape(-1, self.dim))
            u2 ,_ = self.TransPointer(h,Trans_hidden.reshape(-1, self.dim))
            u = u1 + u2
            u = 10 * torch.tanh(u)
            u = u.masked_fill(mask.bool(), float("-inf"))
            probs = F.softmax(u, dim=1)

            # When training, sample the next step according to its probability.
            # During testing, we can take the greedy approach and choose highest
            if  deterministic:
                prob, index = torch.max(probs,dim=1)  # Greedy
                logp = prob.log()
            else:
                # Sampling
                m = torch.distributions.Categorical(probs)
                index = m.sample()
                logp = m.log_prob(index)

            # Instances that were already complete before this step do not
            # contribute to the log-likelihood
            is_done = (mask1[:, 1:].sum(1) >= (Trans_hidden.size(1) - 1)).float()
            logp = logp * (1. - is_done)

            # Update the load and the masks with the node just visited
            dynamic_capcity = update_state(demands, dynamic_capcity, index.unsqueeze(-1),c = 0.5)
            mask, mask1 = update_mask(demands, dynamic_capcity, index.unsqueeze(-1), mask1, t)

            tour_logp.append(logp.unsqueeze(1))
            tour_idx.append(index.data.unsqueeze(1))

            # Next decoder input: transformer encoding of the selected node
            _input = torch.gather(Trans_hidden, 1,
                                  index.unsqueeze(-1).unsqueeze(-1).expand(Trans_hidden.size(0), -1,
                                                                           Trans_hidden.size(2))).squeeze(1)

        tour_idx = torch.cat(tour_idx, dim=1)  # (batch_size, seq_len)
        tour_logp = torch.cat(tour_logp, dim=1)  # (batch_size, seq_len)
        return tour_idx, tour_logp
    
def validate(data_loader, Critic, reward_fn, render_fn=None, save_dir='.',num_plot=5):
    """Greedy roll-out of `Critic` over `data_loader`; optionally plots solutions.

    Returns the mean of the per-batch mean rewards. When `render_fn` is given,
    the first `num_plot` batches are rendered into `save_dir` (created if it
    does not exist).
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    batch_means = []
    for batch_idx, (static, dynamic, x0) in enumerate(data_loader):
        with torch.no_grad():
            # the model works on (B, size, feat); the dataset yields (B, feat, size)
            static = static.movedim(1, 2).to(device)
            dynamic = dynamic.movedim(1, 2).to(device)
            if len(x0) > 0:
                x0 = x0.movedim(1, 2).squeeze(1).to(device)
            else:
                x0 = None

            tour_indices, _ = Critic(static, dynamic, decoder_input = x0,deterministic=True)

            # reward_fn and render_fn expect the (B, feat, size) layout
            coords = static.permute(0,2,1)
            reward = reward_fn(coords, tour_indices).mean().item()
            batch_means.append(reward)

            if render_fn is not None and batch_idx < num_plot:
                name = 'batch%d_%2.4f.png'%(batch_idx, reward)
                render_fn(coords, tour_indices, os.path.join(save_dir, name))
    return np.mean(batch_means)

# Training

In [None]:
########################
# Training hyperparameters
########################
# Problem definition
size = 20                 # customer nodes per instance (the depot is added by the dataset)
max_load = 30             # demands are divided by this value
MAX_DEMAND = 9            # largest absolute integer demand of a node

# Optimisation settings shared by all problem sizes
TOL  =  1e-3              # margin by which the actor must beat the baseline to replace it
TINY =  1e-15             # not referenced elsewhere in this notebook
learn_rate = 1e-4         # Adam learning rate
batch_size = 512          # instances per optimisation step
train_size = 512          # instances generated per training step
compare_size = 512        # instances per actor-vs-baseline comparison round

valid_size = 10000        # size of the fixed validation set
valid_batch = 10000       # validation batch size

B_valLoop = 20            # comparison rounds after each epoch
steps = 2500              # training steps per epoch
n_epoch = 100             # epochs

print('\n'.join([
    '=========================',
    'prepare to train',
    '=========================',
    'Hyperparameters:',
    'size {}'.format(size),
    'learning rate {}'.format(learn_rate),
    'batch size {}'.format(batch_size),
    'validation size {}'.format(valid_size),
    'steps {}'.format(steps),
    'epoch {}'.format(n_epoch),
    '=========================',
]))

###################
# Instantiate a training network and a baseline network
###################
# Drop models left over from a previous run of this cell. Each name is
# handled on its own and only the "not defined yet" case is ignored, so a
# missing Actor no longer skips the Critic and real errors are not hidden.
try:
    del Actor # remove existing model
except NameError:
    pass
try:
    del Critic # remove existing model
except NameError:
    pass

# Fixed validation set, reused after every epoch
valid_data = VehicleRoutingDataset(valid_size,size,max_load,MAX_DEMAND)
valid_loader = DataLoader(valid_data, valid_batch, False, num_workers=0)

# Actor is the trained policy; Critic is the greedy roll-out baseline.
# Only the Actor's parameters are optimised.
Actor  = HPN_PDP(n_feature=2, n_hidden=128)
Critic = HPN_PDP(n_feature=2, n_hidden=128)
optimizer = optim.Adam(Actor.parameters(), lr=learn_rate)

# Move both models to the device; the baseline always runs in eval mode
Actor = Actor.to(device)
Critic = Critic.to(device)
Critic.eval()

# Wrap both networks in DataParallel when more than one GPU is visible
print(torch.cuda.device_count())
if torch.cuda.device_count()>1:
    Actor = nn.DataParallel(Actor)
    Critic = nn.DataParallel(Critic)
# uncomment these lines if trained with multiple GPUs

########################
# Checkpoint bookkeeping.
# The models and the optimizer must already exist (see above) before a saved
# state dictionary is loaded into them.
#######################

# Defaults for a fresh run; the resume snippet below overwrites them.
epoch_ckpt = 0        # offset added to the epoch counter of the training loop
tot_time_ckpt = 0     # training time (seconds) accumulated before this run

# NOTE(review): val_mean / val_std are not referenced elsewhere in the visible notebook.
val_mean = []
val_std  = []

# Per-epoch [epoch, mean tour length] pairs for the actor and the baseline
plot_performance_train = []
plot_performance_baseline = []

################################################################# Restart Training With Check Points ######################################################
# To resume from a saved checkpoint, turn the string below back into code
# (remove the surrounding triple quotes) and point checkpoint_file at the file.
#************************************************************************************************************************************************************#
"""
checkpoint_file = "../input/pdpsize50/checkpoint_21-11-22--01-27-17-n50-gpu0.pkl"
checkpoint = torch.load(checkpoint_file, map_location=device)
epoch_ckpt = checkpoint['epoch'] + 1
tot_time_ckpt = checkpoint['tot_time']
plot_performance_train = checkpoint['plot_performance_train']
plot_performance_baseline = checkpoint['plot_performance_baseline']
Critic.load_state_dict(checkpoint['model_baseline'])
Actor.load_state_dict(checkpoint['model_train'])
optimizer.load_state_dict(checkpoint['optimizer'])

print('Re-start training with saved checkpoint file={:s}\n  Checkpoint at epoch= {:d} and time={:.3f}min\n'.format(checkpoint_file,epoch_ckpt-1,tot_time_ckpt/60))
del checkpoint
"""
#*************************************************************************************************************************************************************#
# End of the optional resume-from-checkpoint snippet


###################
#  Main training loop
#  REINFORCE with a greedy roll-out baseline: the Actor samples tours, the
#  Critic decodes greedily on the same instances, and the difference of the
#  two tour lengths weights the log-likelihood of the sampled tour.
###################
start_training_time = time.time()
# One time stamp per run: every epoch overwrites the same checkpoint file
time_stamp = datetime.datetime.now().strftime("%y-%m-%d--%H-%M-%S")
# NOTE(review): zero_to_bsz is not used anywhere in the visible notebook
zero_to_bsz = torch.arange(batch_size, device=device) # [0,1,...,bsz-1]
R = 0
C = 0
for epoch in range(0,n_epoch):
    # offset the epoch counter when training was resumed from a checkpoint
    epoch += epoch_ckpt
    
    ###################
    # Train model for one epoch
    ###################
    start = time.time()
    Actor.train()
    
    for i in range(1,steps+1):
        
        # Fresh random instances for every training step
        train_data = VehicleRoutingDataset(train_size,size,max_load,MAX_DEMAND)
        train_loader = DataLoader(train_data, batch_size, False, num_workers=0)
        for batch_idx, batch in enumerate(train_loader):

            static, dynamic, x0 = batch
            # The dataset yields (B,feat,size); the models expect (B,size,feat)
            static  = torch.movedim(static,1,2).to(device) # (B,size,feat)
            dynamic = torch.movedim(dynamic,1,2).to(device) # (B,size,feat)
            x0      = torch.movedim(x0,1,2).squeeze(1).to(device) if len(x0) > 0 else None # (B,feat) depot coordinates
            
            # Sampled tour of the actor and its length (reward_fn expects (B,feat,size))
            tour_indices, logprobs = Actor(static, dynamic,decoder_input = x0,deterministic=False)
            R = reward_fn(static.permute(0,2,1), tour_indices)
            
            # Greedy tour of the baseline on the same instances
            with torch.no_grad():
                tour_indices, _ = Critic(static, dynamic, decoder_input = x0,deterministic=True)
            C = reward_fn(static.permute(0,2,1), tour_indices)
            
            ###################
            # Loss and backprop handling 
            ###################
            # (R - C) is detached, so gradients flow through the log-probabilities only
            loss = torch.mean((R - C) * logprobs.sum(dim=1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        if i % 50 == 0:
            print("epoch:{}, batch:{}/{}, reward:{}".format(epoch, i, steps, R.mean().item()))

        # Refreshed every step; the values of the last step are reported below
        time_one_epoch = time.time() - start
        time_tot = time.time() - start_training_time + tot_time_ckpt

    ###################
    # Compare the actor and the baseline (both greedy) on B_valLoop rounds
    # of compare_size freshly generated instances
    ###################
    Actor.eval()
    
    mean_tour_length_actor = 0
    mean_tour_length_critic = 0

    for step in range(0,B_valLoop):
        # compute tour for model and baseline
        comp_data = VehicleRoutingDataset(compare_size,size,max_load,MAX_DEMAND)
        comp_loader = DataLoader(comp_data, compare_size, False, num_workers=0)
        
        # The loader's batch size equals the dataset size, so this loop runs once
        for batch_idx, batch in enumerate(comp_loader):
            
            static, dynamic, x0 = batch
            static  = torch.movedim(static,1,2).to(device) # (B,size,feat)
            dynamic = torch.movedim(dynamic,1,2).to(device) # (B,size,feat)
            x0      = torch.movedim(x0,1,2).squeeze(1).to(device) if len(x0) > 0 else None  # (B,feat)
            
            with torch.no_grad():
                tour_indicesActor, _ = Actor(static, dynamic, decoder_input = x0,deterministic =  True)
                tour_indicesCritic, _ = Critic(static, dynamic, decoder_input = x0,deterministic = True)
            
        # Evaluated after the inner loop, i.e. on its last (only) batch
        R = reward_fn(static.permute(0,2,1), tour_indicesActor)
        C = reward_fn(static.permute(0,2,1), tour_indicesCritic)
                
        mean_tour_length_actor  += R.mean().item()
        mean_tour_length_critic += C.mean().item()
        
    mean_tour_length_actor  =  mean_tour_length_actor  / B_valLoop
    mean_tour_length_critic =  mean_tour_length_critic / B_valLoop

    # Replace the baseline by the actor when the actor is shorter by more than TOL
    update_baseline = mean_tour_length_actor + TOL < mean_tour_length_critic
    print('Avg Actor {} --- Avg Critic {}'.format(mean_tour_length_actor,mean_tour_length_critic))
    if update_baseline:
        Critic.load_state_dict(Actor.state_dict())
        print('My actor is going on the right road Hallelujah :) Updated')
        
    ###################
    # Validate the (possibly updated) baseline on the fixed validation set
    ###################
    
    # Checkpoints and validation images share one directory
    checkpoint_dir = os.path.join("checkpoint")
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
     
    with torch.no_grad():
        tour_len = validate(valid_loader, Critic, reward_fn, render_fn,checkpoint_dir, num_plot=5)
    print('validation tour length:', tour_len)
        
    # Learning curves stored in the checkpoint
    plot_performance_train.append([(epoch+1), mean_tour_length_actor])
    plot_performance_baseline.append([(epoch+1), mean_tour_length_critic])
    
    # Epoch summary (printed only)
    mystring_min = 'Epoch: {:d}, epoch time: {:.3f}min, tot time: {:.3f}day, L_actor: {:.3f}, L_critic: {:.3f}, update: {}'.format(
        epoch, time_one_epoch/60, time_tot/86400, mean_tour_length_actor, mean_tour_length_critic, update_baseline)
    
    print(mystring_min)
    print('Save Checkpoints')
        
    # The file name depends only on the run's time stamp, problem size and
    # gpu id, so each epoch overwrites the previous checkpoint of this run
    torch.save({
        'epoch': epoch,
        'time': time_one_epoch,
        'tot_time': time_tot,
        'loss': loss.item(),
        'plot_performance_train': plot_performance_train,
        'plot_performance_baseline': plot_performance_baseline,
        'mean_tour_length_val': tour_len,
        'model_baseline': Critic.state_dict(),
        'model_train': Actor.state_dict(),
        'optimizer': optimizer.state_dict(),
        }, '{}.pkl'.format(checkpoint_dir + "/checkpoint_" + time_stamp + "-n{}".format(size) + "-gpu{}".format(gpu_id)))

# Simple Test

In [None]:
# Score the current baseline on 1000 freshly generated instances and save
# plots of the first batches next to the checkpoints.
checkpoint_dir = os.path.join("checkpoint")
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

valid_data = VehicleRoutingDataset(1000, size, max_load, MAX_DEMAND)
valid_loader = DataLoader(valid_data, batch_size=1000, shuffle=False, num_workers=0)

with torch.no_grad():
    tour_len = validate(
        valid_loader,
        Critic,
        reward_fn,
        render_fn,
        checkpoint_dir,
        num_plot=5,
    )
print('validation tour length:', tour_len)