In [None]:
import torch
import torch.nn as nn
import time
import argparse

import os
import datetime

from torch.distributions.categorical import Categorical
from scipy.spatial import distance
# visualization 
%matplotlib inline
from IPython.display import set_matplotlib_formats, clear_output
set_matplotlib_formats('png2x','pdf')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
try: 
    import networkx as nx
    from scipy.spatial.distance import pdist, squareform
except:
    pass
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

# Restrict CUDA to the selected GPU id(s) through the environment variable.
gpu_id = '0' # select a single GPU
#gpu_id = '2,3' # select multiple GPUs
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if device.type == "cuda":
    print('GPU name: {:s}, gpu_id: {:s}'.format(torch.cuda.get_device_name(0),gpu_id))

print(device)

# Model Architecture

In [None]:
import math
import numpy as np
import torch.nn.functional as F
import random
import torch.optim as optim
from torch.autograd import Variable
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook

class Transformer_encoder_net(nn.Module):
    """
    Stack of self-attention encoder layers (attention + feed-forward, each
    followed by a residual connection and a normalisation).

    Inputs :
      h of size      (bsz, nb_nodes, dim_emb)    batch of embedded cities
    Outputs :
      h of size      (bsz, nb_nodes, dim_emb)    batch of encoded cities
      score of size  (bsz, nb_nodes, nb_nodes)   attention scores of the LAST layer
    """

    def __init__(self, nb_layers, dim_emb, nb_heads, dim_ff, batchnorm):
        super().__init__()
        # dim_emb has to split evenly across the attention heads
        assert dim_emb % nb_heads == 0

        def stack(make_layer):
            # one independent layer instance per encoder level
            return nn.ModuleList([make_layer() for _ in range(nb_layers)])

        # Creation order is kept (attention, linear1, linear2, norm1, norm2) so
        # parameter names and seeded initialisation stay the same.
        self.MHA_layers = stack(lambda: nn.MultiheadAttention(dim_emb, nb_heads))
        self.linear1_layers = stack(lambda: nn.Linear(dim_emb, dim_ff))
        self.linear2_layers = stack(lambda: nn.Linear(dim_ff, dim_emb))
        norm_cls = nn.BatchNorm1d if batchnorm else nn.LayerNorm
        self.norm1_layers = stack(lambda: norm_cls(dim_emb))
        self.norm2_layers = stack(lambda: norm_cls(dim_emb))

        self.nb_layers = nb_layers
        self.nb_heads = nb_heads
        self.batchnorm = batchnorm

    def _normalize(self, norm, h):
        """Apply `norm` to h of size (nb_nodes, bsz, dim_emb) and return the same layout."""
        if not self.batchnorm:
            return norm(h)
        # nn.BatchNorm1d expects (bsz, dim_emb, nb_nodes): permute in, permute back
        h = norm(h.permute(1, 2, 0).contiguous())
        return h.permute(2, 0, 1).contiguous()

    def forward(self, h):
        # nn.MultiheadAttention is fed (seq_len, bsz, dim_emb)
        h = h.transpose(0, 1)
        layers = zip(self.MHA_layers, self.linear1_layers, self.linear2_layers,
                     self.norm1_layers, self.norm2_layers)
        for attention, linear1, linear2, norm1, norm2 in layers:
            # self-attention sub-layer + residual + norm
            attended, score = attention(h, h, h)
            h = self._normalize(norm1, h + attended)
            # feed-forward sub-layer + residual + norm
            transformed = linear2(torch.relu(linear1(h)))
            h = self._normalize(norm2, h + transformed)
        # back to (bsz, nb_nodes, dim_emb)
        return h.transpose(0, 1), score

class Attention(nn.Module):
    """
    Additive pointer attention: u_j = v . tanh(Wq(q) + Wref(ref_j)).

    The module is created on the default device; move it with ``.to(device)``
    like any other module (it no longer requires CUDA at construction time).
    """

    def __init__(self, n_hidden):
        super(Attention, self).__init__()

        # sizes seen in the most recent forward call
        self.size = 0
        self.batch_size = 0
        self.dim = n_hidden

        # scoring vector, initialised uniformly in [-1/sqrt(dim), 1/sqrt(dim))
        bound = 1 / math.sqrt(n_hidden)
        self.v = nn.Parameter(torch.empty(n_hidden).uniform_(-bound, bound))

        # parameters for pointer attention
        self.Wref = nn.Linear(n_hidden, n_hidden)
        self.Wq = nn.Linear(n_hidden, n_hidden)

    def forward(self, q, ref):
        """
        Inputs
          q:   query of size (B, dim)
          ref: references flattened to size (B*size, dim)
        Outputs
          u:   unnormalised pointer scores of size (B, size)
          ref: projected references of size (B, size, dim)
        """
        self.batch_size = q.size(0)
        self.size = ref.size(0) // self.batch_size
        q = self.Wq(q)                                           # (B, dim)
        ref = self.Wref(ref).view(self.batch_size, self.size, self.dim)  # (B, size, dim)

        # q broadcasts over the city axis; contracting with v gives (B, size)
        u = torch.matmul(torch.tanh(q.unsqueeze(1) + ref), self.v)

        return u, ref
    
class LSTM(nn.Module):
    """
    Single LSTM step in which the input, forget and output gates also read the
    cell state (through the wci / wcf / wco linear layers).
    """

    # Linear layers in creation order:
    #   input gate  : Wxi (x), Whi (h), wci (c)
    #   forget gate : Wxf (x), Whf (h), wcf (c)
    #   cell update : Wxc (x), Whc (h)
    #   output gate : Wxo (x), Who (h), wco (c)
    _LAYER_NAMES = ("Wxi", "Whi", "wci",
                    "Wxf", "Whf", "wcf",
                    "Wxc", "Whc",
                    "Wxo", "Who", "wco")

    def __init__(self, n_hidden):
        super().__init__()
        for layer_name in self._LAYER_NAMES:
            setattr(self, layer_name, nn.Linear(n_hidden, n_hidden))

    def forward(self, x, h, c):
        """Advance one step from input x and state (h, c); returns the new (h, c)."""
        gate_in = torch.sigmoid(self.Wxi(x) + self.Whi(h) + self.wci(c))
        gate_forget = torch.sigmoid(self.Wxf(x) + self.Whf(h) + self.wcf(c))
        candidate = torch.tanh(self.Wxc(x) + self.Whc(h))
        c_next = gate_forget * c + gate_in * candidate
        # the output gate looks at the UPDATED cell state
        gate_out = torch.sigmoid(self.Wxo(x) + self.Who(h) + self.wco(c_next))
        h_next = gate_out * torch.tanh(c_next)
        return h_next, c_next

class HPN(nn.Module):
    """
    Pointer network decoder step combining two city encodings:
    a linear/"GNN" embedding scored by `pointer` and a transformer embedding
    scored by `TransPointer`. One call to `forward` selects one city.

    The module is created on the default device; move it with ``.to(device)``.
    """

    def __init__(self, n_feature, n_hidden):
        """
        n_feature: number of input features per city
        n_hidden:  hidden/embedding width (must be divisible by the 8
                   attention heads of the transformer encoder)
        """
        super(HPN, self).__init__()
        # sizes seen in the most recent forward call
        self.city_size = 0
        self.batch_size = 0
        self.dim = n_hidden

        # pointer layers (one per encoding)
        self.pointer = Attention(n_hidden)
        self.TransPointer = Attention(n_hidden)

        # lstm encoder
        self.encoder = LSTM(n_hidden)

        # trainable first hidden / cell state
        bound = 1 / math.sqrt(n_hidden)
        self.h0 = nn.Parameter(torch.empty(n_hidden).uniform_(-bound, bound))
        self.c0 = nn.Parameter(torch.empty(n_hidden).uniform_(-bound, bound))

        # trainable latent variable coefficient
        self.alpha = nn.Parameter(torch.ones(1))

        # trainable mixing coefficients of the three GNN layers
        self.r1 = nn.Parameter(torch.ones(1))
        self.r2 = nn.Parameter(torch.ones(1))
        self.r3 = nn.Parameter(torch.ones(1))

        # embedding
        self.embedding_x = nn.Linear(n_feature, n_hidden)
        self.embedding_all = nn.Linear(n_feature, n_hidden)
        # The transformer consumes the output of embedding_all, so its width
        # has to follow n_hidden (identical to the former constant for 128).
        self.Transembedding_all = Transformer_encoder_net(6, n_hidden, 8, 512, batchnorm=True)

        # vector to start decoding
        self.start_placeholder = nn.Parameter(torch.randn(n_hidden))

        # weights for GNN
        self.W1 = nn.Linear(n_hidden, n_hidden)
        self.W2 = nn.Linear(n_hidden, n_hidden)
        self.W3 = nn.Linear(n_hidden, n_hidden)

        # aggregation function for GNN
        self.agg_1 = nn.Linear(n_hidden, n_hidden)
        self.agg_2 = nn.Linear(n_hidden, n_hidden)
        self.agg_3 = nn.Linear(n_hidden, n_hidden)

    def forward(self,context,Transcontext, x, X_all, mask, h=None, c=None, latent=None):
        '''
        Inputs (B: batch size, size: city size, dim: hidden dimension)

        context, Transcontext: city encodings returned by the previous call;
            ignored and recomputed from X_all when h or c is None (first step)
        x: features of the current city (B, n_feature); ignored on the first
            step, where the trainable start placeholder is used instead
        X_all: features of all cities (B, size, n_feature)
        mask: additive mask (B, size) added to the pointer logits; callers in
            this file use 0 for available and -inf for visited cities
        h: hidden variable (B, dim)
        c: cell gate (B, dim)
        latent: optional pointer logits from a previous layer; added to the
            logits scaled by alpha, so it must broadcast against (B, size)

        Outputs

        context: encoded cities, flattened to (B*size, dim)
        Transcontext: transformer-encoded cities, flattened to (B*size, dim)
        softmax: probability distribution of next city (B, size)
        h: hidden variable (B, dim)
        c: cell gate (B, dim)
        latent_u: raw pointer logits (before tanh clipping and masking)
        '''

        self.batch_size = X_all.size(0)
        self.city_size = X_all.size(1)

        if h is None or c is None:
            # First decoding step: encode every city once (weights are shared
            # across cities) and start from the trainable initial state.
            x = self.start_placeholder
            context = self.embedding_all(X_all)
            Transcontext, _ = self.Transembedding_all(context)

            # =============================
            # graph neural network encoder
            # =============================

            # flatten (B, size, dim) -> (B*size, dim)
            context = context.reshape(-1, self.dim)
            Transcontext = Transcontext.reshape(-1, self.dim)

            # guard against a division by zero for single-city instances
            n_other = max(self.city_size - 1, 1)
            gnn_layers = ((self.r1, self.W1, self.agg_1),
                          (self.r2, self.W2, self.agg_2),
                          (self.r3, self.W3, self.agg_3))
            for r, W, agg in gnn_layers:
                context = r * W(context) + (1 - r) * F.relu(agg(context / n_other))

            # broadcast the trainable initial state over the batch
            h = self.h0.unsqueeze(0).expand(self.batch_size, self.dim).contiguous()
            c = self.c0.unsqueeze(0).expand(self.batch_size, self.dim).contiguous()
        else:
            x = self.embedding_x(x)

        # LSTM encoder
        h, c = self.encoder(x, h, c)
        # query vector
        q = h
        # pointer scores from both encodings
        u1, _ = self.pointer(q, context)
        u2, _ = self.TransPointer(q, Transcontext)
        u = u1 + u2
        latent_u = u.clone()
        # clip logits to [-10, 10], then mask
        u = 10 * torch.tanh(u) + mask

        if latent is not None:
            u = u + self.alpha * latent

        return context, Transcontext, F.softmax(u, dim=1), h, c, latent_u

# Data Generation

In [None]:
'''
generate training data
'''
# Checkpoint of the pre-trained model "HPN for TSP"
checkpoint_file = "../input/checkpoint_21-09-05--08-53-44-n50-gpu0.pkl"

# Two-feature (x, y) network used only to produce tours; kept in eval mode.
DataGen = HPN(n_feature=2, n_hidden=128).to(device)
DataGen.eval()

# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_file, map_location=device)
DataGen.load_state_dict(checkpoint['model_baseline'])
del checkpoint
print("Done")

def ModelSolution(B,size,Critic):
    """
    Sample B random instances of `size` cities in the unit square and decode
    one greedy tour per instance with `Critic`.

    B:      batch size
    size:   number of cities per instance
    Critic: model called as Critic(context, Transcontext, x=, X_all=, h=, c=, mask=)
            and returning (context, Transcontext, probs, h, c, latent)

    Returns a CPU tensor of shape (B, size, 2): city coordinates in visiting
    order. Uses the module-level `device` for all intermediate tensors.
    """
    rows = torch.arange(B, device=device) # [0,1,...,B-1]
    X = torch.rand(B, size, 2, device=device)
    mask = torch.zeros(B, size, device=device)
    visited = []
    x = X[:, 0, :]
    h = c = context = Transcontext = None

    with torch.no_grad():
        for _step in range(size):
            context, Transcontext, output, h, c, _ = Critic(
                context, Transcontext, x=x, X_all=X, h=h, c=c, mask=mask)
            idx = torch.argmax(output, dim=1)
            x = X[rows, idx]
            # keep the step on-device; one stack at the end replaces the
            # per-step numpy copy and the slow tensor-from-list-of-arrays build
            visited.append(x)
            mask[rows, idx] = float("-inf")
    return torch.stack(visited, dim=1).cpu() # Shape = (B,size,2)

def generate_data(model,B=512, size=50):
    """
    Build a batch of time-window instances around tours produced by `model`.

    The tour returned by ModelSolution is walked in order; each city gets an
    entering time no later than, and a leaving time later than, the moment
    the walk reaches it.

    model: tour generator handed to ModelSolution
    B:     batch size
    size:  number of cities per instance

    Returns
      X:         (B, size, 4) tensor on the module-level `device`;
                 columns are x, y, entering time, leaving time
      solutions: (B,) closed-tour length of the generating tour, aligned
                 row by row with X
    """
    coords = ModelSolution(B, size, model).to(device)   # (B, size, 2), visiting order

    with torch.no_grad():
        # length of every leg between consecutive cities and arrival time at city k >= 1
        legs = torch.norm(coords[:, 1:] - coords[:, :-1], dim=2)   # (B, size-1)
        arrival = torch.cumsum(legs, dim=1)                         # (B, size-1)

        enter = torch.zeros(B, size, device=device)
        leave = torch.zeros(B, size, device=device)
        leave[:, 0] = 2 * torch.rand(B, device=device)              # l0 = rand
        # entering time 0 <= ei <= arrival
        enter[:, 1:] = torch.clamp(arrival - 2 * torch.rand(B, size - 1, device=device), min=0)
        # leaving time li >= arrival
        leave[:, 1:] = arrival + 2 * torch.rand(B, size - 1, device=device) + 1

        # close the tour: last city back to the first one
        solutions = legs.sum(dim=1) + torch.norm(coords[:, -1] - coords[:, 0], dim=1)

        X = torch.cat((coords, enter.unsqueeze(2), leave.unsqueeze(2)), dim=2)

        # shuffle the instances; the same permutation keeps `solutions` aligned
        perm = torch.randperm(B, device=device)

    return X[perm], solutions[perm]

# Training

In [None]:
# ---- hyperparameters ----
size = 20            # number of cities per instance
learn_rate = 1e-4    # learning rate
B = 512              # batch_size
TOL = 1e-3           # not referenced in the visible part of the notebook
B_val = 1000         # validation size
B_valLoop = 20       # number of evaluation batches per epoch
steps = 2500         # training steps per epoch
n_epoch = 100        # epochs

# ---- summary banner ----
_banner = '========================='
print(_banner)
print('prepare to train')
print(_banner)
print('Hyperparameters:')
for _label, _value in (('size', size),
                       ('learning rate', learn_rate),
                       ('batch size', B),
                       ('validation size', B_val),
                       ('steps', steps),
                       ('epoch', n_epoch)):
    print(_label, _value)
print(_banner)

###################
# Instantiate a training network and a baseline network
###################

# Drop models left over from a previous run of this cell. Only the
# "name not defined yet" case (first run) is expected and ignored.
try:
    del ActorLow # remove existing model
    del CriticLow # remove existing model
except NameError:
    pass

# Actor (trained) and critic (greedy baseline) work on 4 features per city.
ActorLow  = HPN(n_feature=4, n_hidden=128)
CriticLow = HPN(n_feature=4, n_hidden=128)
optimizer = optim.Adam(ActorLow.parameters(), lr=learn_rate)

ActorLow = ActorLow.to(device)
CriticLow = CriticLow.to(device)

# The critic is only ever used for inference
CriticLow.eval()

########################
# Remember to first initialize the model and optimizer, then load the dictionary locally.
#######################
epoch_ckpt = 0      # first epoch index (non-zero when resuming)
tot_time_ckpt = 0   # training time already spent (seconds) when resuming
val_mean = []
val_std = []
val_accuracy = []
plot_performance_train = []
plot_performance_baseline = []
#********************************************# Uncomment these lines to re-start training with saved checkpoint #********************************************#
"""
checkpoint_file = "../input/nonhiersize20/checkpoint_21-09-05--08-55-01-n50-gpu0.pkl"
checkpoint = torch.load(checkpoint_file, map_location=device)
epoch_ckpt = checkpoint['epoch'] + 1
tot_time_ckpt = checkpoint['tot_time']
plot_performance_train = checkpoint['plot_performance_train']
plot_performance_baseline = checkpoint['plot_performance_baseline']
CriticLow.load_state_dict(checkpoint['model_baseline'])
ActorLow.load_state_dict(checkpoint['model_train'])
optimizer.load_state_dict(checkpoint['optimizer'])

print('Re-start training with saved checkpoint file={:s}\n  Checkpoint at epoch= {:d} and time={:.3f}min\n'.format(checkpoint_file,epoch_ckpt-1,tot_time_ckpt/60))
del checkpoint
"""
#*********************************************# Uncomment these lines to re-start training with saved checkpoint #********************************************#


###################
#  Main training loop 
###################
def _rollout(model, X, sample=False, charge_penalty=True, late_penalty=10):
    """
    Decode one tour per instance with `model` and score it against the
    time windows stored in X.

    model: network called as model(context, Transcontext, x=, X_all=, h=, c=, mask=)
           returning (context, Transcontext, probs, h, c, latent)
    X:     (B, size, 4) instances; columns 0-1 are coordinates, column 2 the
           entering time and column 3 the leaving time
    sample: draw the next city from the predicted distribution and accumulate
            log-probabilities (True) or pick the argmax (False)
    charge_penalty: add the lateness penalty to the running cost
    late_penalty: amount charged when a city is reached after its leaving time

    Returns a dict of (B,) tensors:
      tour_len  travelled distance of the closed tour
      cost      distance + waiting time (+ penalties when charge_penalty)
      wait      accumulated waiting time
      penalty   accumulated lateness penalties
      logprobs  sum of log-probabilities of the chosen cities (zeros when not sampling)
    Gradients are tracked unless the caller wraps the call in torch.no_grad().
    """
    batch, n_city = X.size(0), X.size(1)
    dev = X.device
    rows = torch.arange(batch, device=dev)
    enter_all, leave_all = X[:, :, 2], X[:, :, 3]

    mask = torch.zeros(batch, n_city, device=dev)
    tour_len = torch.zeros(batch, device=dev)
    cost = torch.zeros(batch, device=dev)
    wait = torch.zeros(batch, device=dev)
    penalty = torch.zeros(batch, device=dev)
    logprobs = torch.zeros(batch, device=dev)

    x = X[:, 0, :]
    h = c = context = Transcontext = None
    y_ini = y_pre = None

    for k in range(n_city):
        context, Transcontext, output, h, c, _ = model(
            context, Transcontext, x=x, X_all=X, h=h, c=c, mask=mask)
        if sample:
            idx = torch.distributions.Categorical(output).sample()
            logprobs = logprobs + torch.log(output[rows, idx])
        else:
            idx = torch.argmax(output, dim=1)

        y_cur = X[rows, idx]
        if k == 0:
            y_ini = y_cur
        else:
            # travel from the previous city
            leg = torch.norm(y_cur[:, :2] - y_pre[:, :2], dim=1)
            tour_len = tour_len + leg
            cost = cost + leg
        y_pre = y_cur
        x = y_cur

        enter = enter_all[rows, idx]
        leave = leave_all[rows, idx]

        # arriving early: wait until the window opens
        step_wait = torch.clamp(enter - cost, min=0)
        wait = wait + step_wait
        cost = cost + step_wait

        # arriving after the window closed: fixed penalty
        step_penalty = torch.lt(leave, cost).float() * late_penalty
        if charge_penalty:
            cost = cost + step_penalty
        penalty = penalty + step_penalty

        mask[rows, idx] = float("-inf")

    # return to the first city
    closing = torch.norm(y_cur[:, :2] - y_ini[:, :2], dim=1)
    tour_len = tour_len + closing
    cost = cost + closing

    return {"tour_len": tour_len, "cost": cost, "wait": wait,
            "penalty": penalty, "logprobs": logprobs}


start_training_time = time.time()
time_stamp = datetime.datetime.now().strftime("%y-%m-%d--%H-%M-%S")

C = 0     # baseline
R = 0     # reward

# Index vectors are no longer needed by the loop below (each rollout builds
# its own); kept defined in case other cells reference them.
zero_to_bsz = torch.arange(B, device=device) # [0,1,...,bsz-1]
zero_to_bsz_val = torch.arange(B_val, device=device) # [0,1,...,B_val-1]

for epoch in range(0,n_epoch):

    # re-start training with saved checkpoint
    epoch += epoch_ckpt

    ###################
    # Train model for one epoch
    ###################

    start = time.time()
    ActorLow.train()
    for i in range(1,steps+1):

        X, _ = generate_data(DataGen,B=B, size=size)

        # Actor sampling phase (with gradients)
        actor_out = _rollout(ActorLow, X, sample=True, charge_penalty=True)
        # Critic greedy baseline phase (no gradients)
        with torch.no_grad():
            critic_out = _rollout(CriticLow, X, sample=False, charge_penalty=True)

        R = actor_out["tour_len"]
        C = critic_out["tour_len"]
        logprobs = actor_out["logprobs"]
        total_time_cost_train = actor_out["cost"]
        total_time_wait_train = actor_out["wait"]
        total_time_penalty_train = actor_out["penalty"]
        total_time_cost_base = critic_out["cost"]

        ###################
        # Loss and backprop handling
        ###################

        # REINFORCE with the greedy critic cost as baseline
        loss = torch.mean((total_time_cost_train - total_time_cost_base) * logprobs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 50 == 0:
            print("epoch:{}, batch:{}/{},  total time:{}, reward:{}, time:{}"
                .format(epoch, i, steps, total_time_cost_train.mean().item(),
                        R.mean().item(), total_time_wait_train.mean().item()))

    time_one_epoch = time.time() - start
    time_tot = time.time() - start_training_time + tot_time_ckpt

    ###################
    # Evaluate train model and baseline greedily on B_valLoop fresh batches
    ###################
    # NOTE(review): here the lateness penalty is tracked but NOT added to the
    # cost (the original had that line commented out), unlike training and
    # validation - confirm this is intended.

    ActorLow.eval()
    mean_tour_length_actor = 0
    mean_tour_length_critic = 0

    with torch.no_grad():
        for step in range(0,B_valLoop):
            X, solutions = generate_data(DataGen,B=B, size=size)
            total_time_cost_train = _rollout(ActorLow, X, charge_penalty=False)["cost"]
            total_time_cost_base = _rollout(CriticLow, X, charge_penalty=False)["cost"]
            mean_tour_length_actor  += total_time_cost_train.mean().item()
            mean_tour_length_critic += total_time_cost_base.mean().item()

    mean_tour_length_actor  =  mean_tour_length_actor  / B_valLoop
    mean_tour_length_critic =  mean_tour_length_critic / B_valLoop

    # evaluate train model and baseline and update if train model is better
    update_baseline = mean_tour_length_actor < mean_tour_length_critic
    print('Avg Actor {} --- Avg Critic {}'.format(mean_tour_length_actor,mean_tour_length_critic))
    if update_baseline:
        CriticLow.load_state_dict(ActorLow.state_dict())
        print('My actor is going on the right road Hallelujah :) Updated')

    ###################
    # Validation of the (possibly updated) baseline on B_val fresh instances
    ###################

    with torch.no_grad():
        # tour length of the generating tours of the last evaluation batch
        print("optimal upper bound:{}".format(solutions.mean()))
        X_val, _ = generate_data(DataGen,B=B_val, size=size)
        val_out = _rollout(CriticLow, X_val, charge_penalty=True)
        total_time_cost = val_out["cost"]
        total_time_penalty = val_out["penalty"]

        # share of instances solved without any time-window violation
        accuracy = 1 - (total_time_penalty > 0).sum().float() / total_time_penalty.size(0)
        print('validation result:{}, accuracy:{}'
                  .format(total_time_cost.mean().item(), accuracy))

        val_mean.append(total_time_cost.mean().item())
        val_std.append(total_time_cost.std().item())
        val_accuracy.append(accuracy)

    # For checkpoint
    plot_performance_train.append([(epoch+1), mean_tour_length_actor])
    plot_performance_baseline.append([(epoch+1), mean_tour_length_critic])

    # Compute optimality gap against the hard-coded reference values for
    # size 50 / 100; any other size reports -1.0 (no reference available)
    if size==50: gap_train = mean_tour_length_actor/5.692- 1.0
    elif size==100: gap_train = mean_tour_length_actor/7.765- 1.0
    else: gap_train = -1.0

    # Print the epoch summary
    mystring_min = 'Epoch: {:d}, epoch time: {:.3f}min, tot time: {:.3f}day, L_actor: {:.3f}, L_critic: {:.3f}, gap_train(%): {:.3f}, update: {}'.format(
        epoch, time_one_epoch/60, time_tot/86400, mean_tour_length_actor, mean_tour_length_critic, 100 * gap_train, update_baseline)

    print(mystring_min)
    print('Save Checkpoints')

    # Saving checkpoint
    checkpoint_dir = os.path.join("checkpoint")
    os.makedirs(checkpoint_dir, exist_ok=True)

    torch.save({
        'epoch': epoch,
        'time': time_one_epoch,
        'tot_time': time_tot,
        'loss': loss.item(),
        'plot_performance_train': plot_performance_train,
        'plot_performance_baseline': plot_performance_baseline,
        # mean validation cost (previously the per-instance penalty tensor was stored here)
        'mean_tour_length_val': total_time_cost.mean().item(),
        'model_baseline': CriticLow.state_dict(),
        'model_train': ActorLow.state_dict(),
        'optimizer': optimizer.state_dict(),
        }, '{}/checkpoint_{}-n{}-gpu{}.pkl'.format(checkpoint_dir, time_stamp, size, gpu_id))