<a href="https://colab.research.google.com/github/MarioAuditore/TDA-for-Travelling-Salesman/blob/main/train_tsp_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Based on "The Transformer Network for the Traveling Salesman Problem"

Xavier Bresson, Thomas Laurent, Feb 2021<br>

Arxiv : https://arxiv.org/pdf/2103.03012.pdf<br>
Talk : https://ipam.wistia.com/medias/0jrweluovs<br>
Slides : https://t.co/ySxGiKtQL5<br>

This code trains the transformer network by reinforcement learning.<br>
Use the beam search code to test the trained network.


In [1]:
# Fetch the project repository; later cells change into it and import its `tsp_transformer` package.
!git clone https://github.com/MarioAuditore/TDA-for-Travelling-Salesman.git

Cloning into 'TDA-for-Travelling-Salesman'...
remote: Enumerating objects: 54, done.[K
remote: Counting objects: 100% (54/54), done.[K
remote: Compressing objects: 100% (49/49), done.[K
remote: Total 54 (delta 3), reused 0 (delta 0), pack-reused 0[K
Receiving objects: 100% (54/54), 104.14 MiB | 19.24 MiB/s, done.
Resolving deltas: 100% (3/3), done.


In [2]:
import os

# Work from the cloned repository root so the project-local imports and the
# relative "data", "logs" and "checkpoint" paths used below resolve against it.
repo_root = '/content/TDA-for-Travelling-Salesman'
os.chdir(repo_root)

In [3]:
!pip install 'pyconcorde @ git+https://github.com/jvkersch/pyconcorde'

Collecting pyconcorde@ git+https://github.com/jvkersch/pyconcorde
  Cloning https://github.com/jvkersch/pyconcorde to /tmp/pip-install-ewtszkhi/pyconcorde_607f84db627d499195be1d0b86fe6290
  Running command git clone --filter=blob:none --quiet https://github.com/jvkersch/pyconcorde /tmp/pip-install-ewtszkhi/pyconcorde_607f84db627d499195be1d0b86fe6290
  Resolved https://github.com/jvkersch/pyconcorde to commit 8a6b193b79ebdf8f07e0b0635722b3b4edbc1560
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting tsplib95 (from pyconcorde@ git+https://github.com/jvkersch/pyconcorde)
  Downloading tsplib95-0.7.1-py2.py3-none-any.whl (25 kB)
Collecting Deprecated~=1.2.9 (from tsplib95->pyconcorde@ git+https://github.com/jvkersch/pyconcorde)
  Downloading Deprecated-1.2.14-py2.py3-none-any.whl (9.6 kB)
Collecting network

In [4]:
# ================
# Libs
# ================

import torch
import torch.nn as nn
from tqdm import tqdm

# Models
from tsp_transformer.model import TSP_net, compute_tour_length


import time
import argparse
import os
import datetime


# visualization
%matplotlib inline
# from IPython.display import set_matplotlib_formats, clear_output
# set_matplotlib_formats('png2x','pdf')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


import networkx as nx
from scipy.spatial.distance import pdist, squareform
from concorde.tsp import TSPSolver # !pip install -e pyconcorde


import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [5]:
###################
# Hardware : CPU / GPU(s)
###################

# Selects the compute device in priority order MPS -> CUDA -> CPU and records
# `gpu_id` ('0' for an accelerator, -1 for CPU), which is later embedded in the
# log and checkpoint file names.

# `torch.backends.mps` only exists on PyTorch >= 1.12; look it up defensively
# so older installations fall through to the CUDA / CPU branches instead of
# raising AttributeError.
mps_backend = getattr(torch.backends, "mps", None)

if mps_backend is not None and mps_backend.is_available():
    gpu_id = '0'
    device = torch.device("mps")

elif torch.cuda.is_available():
    gpu_id = '0' # select a single GPU
    # gpu_id = '2,3' # select multiple GPUs
    # NOTE(review): CUDA_VISIBLE_DEVICES is normally read when CUDA is first
    # initialised; setting it after torch has already queried CUDA may have no
    # effect -- confirm if GPU selection matters.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    device = torch.device("cuda")
    print('GPU name: {:s}, gpu_id: {:s}'.format(torch.cuda.get_device_name(0),gpu_id))

else:
    device = torch.device("cpu")
    gpu_id = -1 # select CPU


print(device)

cpu


In [6]:
# ================
# Hyper-parameters
# ================

class DotDict(dict):
    """Dictionary whose entries can also be read and written as attributes.

    The instance's attribute namespace is aliased to the dictionary itself,
    so `d.key` and `d['key']` always refer to the same entry.
    """

    def __init__(self, **kwds):
        # Store the initial keyword entries in the mapping ...
        for key, value in kwds.items():
            self[key] = value
        # ... then make attribute access operate directly on that mapping.
        self.__dict__ = self

# All tunable settings live in one attribute-style dictionary. Keys are given
# in the same order as before so the printed/logged configuration is unchanged.
args = DotDict(
    gpu_id=gpu_id,

    # TSP problem size: number of nodes per instance
    # (use 50 for TSP50 or 100 for TSP100)
    nb_nodes=10,

    # Transformer parameters
    dim_emb=128,            # dimension of embeddings in transformer
    dim_ff=512,             # dimension of feed forward layers
    dim_input_nodes=2,      # dimension of input features
    nb_layers_encoder=6,
    nb_layers_decoder=2,
    nb_heads=8,

    # Optimisation schedule
    nb_epochs=20,           # number of epochs
    batch_size=128,         # batch size
    nb_batch_per_epoch=250, # number of batches to generate on each epoch for training
    nb_batch_eval=100,      # number of batches to generate on each epoch for evaluation
    lr=1e-4,                # optimiser lr
    tol=1e-3,               # train model must beat the baseline by this margin to replace it
    batchnorm=True,         # True: batch norm is used; False: layer norm is used
    max_len_PE=1000,
)

print(args)


{'gpu_id': -1, 'nb_nodes': 10, 'dim_emb': 128, 'dim_ff': 512, 'dim_input_nodes': 2, 'nb_layers_encoder': 6, 'nb_layers_decoder': 2, 'nb_heads': 8, 'nb_epochs': 20, 'batch_size': 128, 'nb_batch_per_epoch': 250, 'nb_batch_eval': 100, 'lr': 0.0001, 'tol': 0.001, 'batchnorm': True, 'max_len_PE': 1000}


# Training
## Setup

In [7]:
###################
# Instantiate a training network and a baseline network
###################

# Drop models left over from a previous run of this cell before building new
# ones. Each name is removed independently, and only the "name does not exist
# yet" case is ignored, so an unrelated error is never silently swallowed and
# a missing `model_train` no longer prevents `model_baseline` from being freed.
try:
    del model_train
except NameError:
    pass
try:
    del model_baseline
except NameError:
    pass


def build_tsp_net():
    """Return a freshly initialised TSP_net configured from `args`."""
    return TSP_net(args.dim_input_nodes,
                   args.dim_emb,
                   args.dim_ff,
                   args.nb_layers_encoder,
                   args.nb_layers_decoder,
                   args.nb_heads,
                   args.max_len_PE,
                   batchnorm=args.batchnorm)


# Two independently initialised networks: one is trained, the other serves as
# the baseline it is compared against.
model_train = build_tsp_net()
model_baseline = build_tsp_net()

if torch.cuda.device_count() > 1:
    # Format the count explicitly: `int + str` raises TypeError.
    print("{} cuda devices found, doing parallel training.".format(torch.cuda.device_count()))
    model_train = nn.DataParallel(model_train)
    model_baseline = nn.DataParallel(model_baseline)

model_train = model_train.to(device)
model_baseline = model_baseline.to(device)
model_baseline.eval()

# Build the optimizer only after the model sits on its final device.
optimizer = torch.optim.Adam(model_train.parameters(), lr = args.lr)

print(args); print('')

# Logs
os.makedirs("logs", exist_ok=True)  # portable, and silent when the directory already exists
time_stamp=datetime.datetime.now().strftime("%y-%m-%d--%H-%M-%S")
file_name = 'logs'+'/'+time_stamp + "-n{}".format(args.nb_nodes) + "-gpu{}".format(args.gpu_id) + ".txt"
# Line-buffered handle; deliberately left open so later cells can append to the
# same log through `file`.
file = open(file_name,"w",1)
file.write(time_stamp+'\n\n')
# Record every hyper-parameter as "name=value", one per line.
for arg in vars(args):
    file.write("{}={}\n".format(arg, getattr(args, arg)))
file.write('\n\n')

# Training history and resume state (overwritten when restarting from a checkpoint).
plot_performance_train = []
plot_performance_baseline = []
all_strings = []
epoch_ckpt = 0
tot_time_ckpt = 0


# # Uncomment these lines to re-start training with saved checkpoint
# ====================================================================
# checkpoint_file = "checkpoint/checkpoint_21-03-01--17-25-00-n50-gpu0.pkl"
# checkpoint = torch.load(checkpoint_file, map_location=device)
# epoch_ckpt = checkpoint['epoch'] + 1
# tot_time_ckpt = checkpoint['tot_time']
# plot_performance_train = checkpoint['plot_performance_train']
# plot_performance_baseline = checkpoint['plot_performance_baseline']
# model_baseline.load_state_dict(checkpoint['model_baseline'])
# model_train.load_state_dict(checkpoint['model_train'])
# optimizer.load_state_dict(checkpoint['optimizer'])
# print('Re-start training with saved checkpoint file={:s}\n  Checkpoint at epoch= {:d} and time={:.3f}min\n'.format(checkpoint_file,epoch_ckpt-1,tot_time_ckpt/60))
# del checkpoint
# ====================================================================



{'gpu_id': -1, 'nb_nodes': 10, 'dim_emb': 128, 'dim_ff': 512, 'dim_input_nodes': 2, 'nb_layers_encoder': 6, 'nb_layers_decoder': 2, 'nb_heads': 8, 'nb_epochs': 20, 'batch_size': 128, 'nb_batch_per_epoch': 250, 'nb_batch_eval': 100, 'lr': 0.0001, 'tol': 0.001, 'batchnorm': True, 'max_len_PE': 1000}



## Test nodes

In [8]:
###################
# Small test set for quick algorithm comparison
# Note : this can be removed
###################

# Produces `x_test_tsp`, a fixed batch of TSP instances kept on the CPU; the
# training loop moves it to the compute device when it evaluates the baseline.

save_1000tsp = False

test_size = 10

data_dir = "data"
# A stored test set is only defined for the standard problem sizes.
has_named_test_set = args.nb_nodes in (20, 50, 100)
test_set_path = '{}.pkl'.format(data_dir + "/1000tsp{}".format(args.nb_nodes))

if save_1000tsp:
    x = torch.rand(test_size, args.nb_nodes, args.dim_input_nodes, device='cpu')
    print(x.size(),x[0])
    os.makedirs(data_dir, exist_ok=True)
    if has_named_test_set:
        torch.save({ 'x': x, }, test_set_path)

if has_named_test_set and os.path.exists(test_set_path):
    # map_location keeps the tensors on the CPU regardless of the device they
    # were saved from. NOTE: torch.load unpickles -- only load trusted files.
    x_test_tsp = torch.load(test_set_path, map_location='cpu')['x']
else:
    if has_named_test_set:
        print('test set file {} not found, using random instances instead'.format(test_set_path))
    x_test_tsp = torch.rand(test_size, args.nb_nodes, args.dim_input_nodes, device='cpu')

n = x_test_tsp.size(1)
print('nb of nodes :',n)


nb of nodes : 10


## Training loop

In [9]:
# ==================
# Main training loop
# ==================

# Each epoch: (1) train `model_train` on freshly sampled random instances using
# the greedy baseline's tour length as reference, (2) evaluate both models on
# new random instances, (3) copy the trained weights into the baseline when the
# trained model is better by more than `args.tol`, (4) save a checkpoint.

# Reference mean tour lengths used to report an optimality gap
# (presumably the optimal/Concorde values for TSP50 and TSP100 -- confirm).
reference_tour_length = {50: 5.692, 100: 7.765}

# The checkpoint location does not change between epochs: resolve it once.
checkpoint_dir = "checkpoint"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = '{}.pkl'.format(checkpoint_dir + "/checkpoint_" + time_stamp + "-n{}".format(args.nb_nodes) + "-gpu{}".format(args.gpu_id))

start_training_time = time.time()

for epoch in tqdm(range(0, args.nb_epochs)):

    # re-start training with saved checkpoint
    # epoch += epoch_ckpt

    # -------------------------
    # Train model for one epoch
    # -------------------------
    start = time.time()
    model_train.train()

    for step in range(1, args.nb_batch_per_epoch + 1):

        # generate a batch of random TSP instances
        x = torch.rand(args.batch_size, args.nb_nodes, args.dim_input_nodes, device=device) # size(x)=(batch_size, nb_nodes, dim_input_nodes)

        # generate new features
        # new_f = ...
        data = x # np.concat(x, new_f)

        # sampled tours for the model being trained
        tour_train, sumLogProbOfActions = model_train(data, deterministic=False) # size(tour_train)=(batch_size, nb_nodes), size(sumLogProbOfActions)=(batch_size)

        # deterministic tours for the baseline (no gradient needed)
        with torch.no_grad():
            tour_baseline, _ = model_baseline(data, deterministic=True)

        # get the lengths of the tours
        L_train = compute_tour_length(x, tour_train) # size(L_train)=(batch_size)
        L_baseline = compute_tour_length(x, tour_baseline) # size(L_baseline)=(batch_size)

        # backprop: log-probabilities weighted by the length difference to the baseline
        loss = torch.mean((L_train - L_baseline) * sumLogProbOfActions )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    time_one_epoch = time.time()-start
    time_tot = time.time()-start_training_time + tot_time_ckpt


    # -----------------
    # Evaluate train model and baseline on
    # args.nb_batch_eval * args.batch_size random TSP instances
    # -----------------
    model_train.eval()
    mean_tour_length_train = 0
    mean_tour_length_baseline = 0

    for step in range(0, args.nb_batch_eval):
        # generate a batch of random tsp instances
        x = torch.rand(args.batch_size, args.nb_nodes, args.dim_input_nodes, device=device)

        # generate new features
        # new_f = ...
        data = x # np.concat(x, new_f)

        # compute tour for model and baseline
        with torch.no_grad():
            tour_train, _ = model_train(data, deterministic=True)
            tour_baseline, _ = model_baseline(data, deterministic=True)

        # get the lengths of the tours
        L_train = compute_tour_length(x, tour_train)
        L_baseline = compute_tour_length(x, tour_baseline)

        # L_train and L_baseline are tensors of shape (batch_size,). Accumulate the mean tour length
        mean_tour_length_train += L_train.mean().item()
        mean_tour_length_baseline += L_baseline.mean().item()

    mean_tour_length_train =  mean_tour_length_train/ args.nb_batch_eval
    mean_tour_length_baseline =  mean_tour_length_baseline/ args.nb_batch_eval

    # update the baseline only if the train model is better by more than the tolerance
    update_baseline = mean_tour_length_train + args.tol < mean_tour_length_baseline
    if update_baseline:
        model_baseline.load_state_dict(model_train.state_dict())

    # For the (possibly new) baseline compute TSPs for the small test set.
    # Instances and tours are kept on the same device, so this works wherever
    # `x_test_tsp` happens to be stored.
    x_test = x_test_tsp.to(device)
    with torch.no_grad():
        tour_test, _ = model_baseline(x_test, deterministic=True)
    mean_tour_length_test = compute_tour_length(x_test, tour_test).mean().item()

    # For checkpoint
    plot_performance_train.append([(epoch+1), mean_tour_length_train])
    plot_performance_baseline.append([(epoch+1), mean_tour_length_baseline])

    # Compute optimality gap (-1.0 when no reference length is known for this size)
    if args.nb_nodes in reference_tour_length:
        gap_train = mean_tour_length_train/reference_tour_length[args.nb_nodes] - 1.0
    else:
        gap_train = -1.0

    # Print the epoch summary (add `file.write(mystring_min+'\n')` to also log it to the txt file)
    mystring_min = 'Epoch: {:d}, epoch time: {:.3f} min, tot time: {:.3f}day, L_train: {:.3f}, L_base: {:.3f}, L_test: {:.3f}, gap_train(%): {:.3f}, update: {}'.format(
        epoch, time_one_epoch/60, time_tot/86400, mean_tour_length_train, mean_tour_length_baseline, mean_tour_length_test, 100*gap_train, update_baseline)
    print(mystring_min) # Comment if plot display

    # all_strings.append(mystring_min) # Uncomment if plot display
    # for string in all_strings:
    #     print(string)

    # Saving checkpoint (overwrites the same file every epoch)
    torch.save({
        'epoch': epoch,
        'time': time_one_epoch,
        'tot_time': time_tot,
        'loss': loss.item(),
        'TSP_length': [torch.mean(L_train).item(), torch.mean(L_baseline).item(), mean_tour_length_test],
        'plot_performance_train': plot_performance_train,
        'plot_performance_baseline': plot_performance_baseline,
        'mean_tour_length_test': mean_tour_length_test,
        'model_baseline': model_baseline.state_dict(),
        'model_train': model_train.state_dict(),
        'optimizer': optimizer.state_dict(),
        }, checkpoint_path)



  0%|          | 0/20 [00:47<?, ?it/s]


KeyboardInterrupt: 

## Final check

In [10]:
# Sanity check: mean greedy tour length of the baseline model on one fresh batch.

# generate a batch of random TSP instances
x = torch.rand(args.batch_size, args.nb_nodes, args.dim_input_nodes, device=device) # size(x)=(batch_size, nb_nodes, dim_input_nodes)

# generate new features
# new_f = ...
data = x # np.concat(x, new_f)

with torch.no_grad():
    tour_baseline, _ = model_baseline(data, deterministic=True)
    # Score the tours just computed for *this* batch. Using the stale
    # `tour_train` left over from the training loop would pair tours with
    # unrelated instances and report a meaningless length.
    print(compute_tour_length(x, tour_baseline).mean())


tensor(5.2266)


# Visualisation (BROKEN)
## Checkpoint
Here you can load checkpoints of models from the original GitHub repository, trained on TSPs of different sizes. Don't run this cell if you want to plot the model trained above, because it overwrites the baseline model's weights.

In [None]:
# ===============
# Load checkpoint
# ===============

def format_checkpoint_summary(epoch, tot_time_sec, l_train, l_base):
    """Return the one-line summary printed after loading a checkpoint.

    Args:
        epoch: integer epoch counter to display.
        tot_time_sec: accumulated training time in seconds.
        l_train: last recorded mean tour length of the trained model.
        l_base: last recorded mean tour length of the baseline model.

    Returns:
        The formatted summary string, terminated by a newline. The training
        time is reported in days (86400 seconds per day).
    """
    return 'Epoch: {:d}, tot_time_ckpt: {:.3f}day, L_train: {:.3f}, L_base: {:.3f}\n'.format(
        epoch, tot_time_sec/86400, l_train, l_base)


load_checkpoint = False # change to True if you need

if load_checkpoint:
    checkpoint_file = "checkpoint/checkpoint_21-03-01--17-25-00-n50-gpu0.pkl"
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_file, map_location=device)

    epoch_ckpt = checkpoint['epoch'] + 1
    tot_time_ckpt = checkpoint['tot_time']

    plot_performance_train = checkpoint['plot_performance_train']
    plot_performance_baseline = checkpoint['plot_performance_baseline']

    # Only the baseline weights are restored here.
    model_baseline.load_state_dict(checkpoint['model_baseline'])

    print('Load checkpoint file={:s}\n  Checkpoint at epoch= {:d} and time={:.3f}min\n'.format(checkpoint_file,epoch_ckpt-1,tot_time_ckpt/60))
    del checkpoint

    mystring_min = format_checkpoint_summary(
        epoch_ckpt, tot_time_ckpt, plot_performance_train[-1][1], plot_performance_baseline[-1][1])
    print(mystring_min)

## Beam search

In [None]:
# ===============================
# Hyper-parameter for beam search
# ===============================

# B is passed to the model together with the greedy/beamsearch switches;
# args.bsz is the number of instances evaluated per batch.

# experiment n50
# B = 1; args.bsz = 10000; greedy = True; beamsearch = False # greedy
# B = 100; args.bsz = 250; greedy = False; beamsearch = True
# B = 1000; args.bsz = 25; greedy = False; beamsearch = True
B = 2500; args.bsz = 10; greedy = False; beamsearch = True

# # experiment n100
# args.nb_nodes = 100
# B = 1; args.bsz = 5000; greedy = True; beamsearch = False # greedy
# B = 1000; args.bsz = 10; greedy = False; beamsearch = True
# args.nb_batch_eval = (10000+args.bsz-1)// args.bsz
# print('nb_nodes: {}, bsz: {}, B: {}, nb_batch_eval: {}, tot TSPs: {}\n'.format(args.nb_nodes, args.bsz, B, args.nb_batch_eval, args.nb_batch_eval* args.bsz))



# ========
# Test set
# ========

load_points = False

if load_points:
    # Stored benchmark files exist only for these problem sizes; fail loudly
    # instead of leaving `x_10k` undefined for any other size.
    if args.nb_nodes not in (50, 100):
        raise ValueError('load_points=True requires args.nb_nodes in (50, 100), got {}'.format(args.nb_nodes))
    # NOTE: torch.load unpickles the files -- only load data you trust.
    x_10k = torch.load('data/10k_TSP{}.pt'.format(args.nb_nodes), map_location=device)
    x_10k_len = torch.load('data/10k_TSP{}_len.pt'.format(args.nb_nodes), map_location=device)
    L_concorde = x_10k_len.mean().item()
else:
    # Random instances, created on the same device as the model.
    # NOTE: no reference lengths exist for random instances, so `x_10k_len`
    # and `L_concorde` are NOT set in this branch.
    test_size = 50
    x_10k = torch.rand(test_size, args.nb_nodes, args.dim_input_nodes, device=device)

# Derive the number of batches from the instances actually available, so the
# evaluation loop never slices past the end of the test set.
args.nb_batch_eval = (x_10k.size(0)+args.bsz-1)// args.bsz
nb_TSPs = args.nb_batch_eval* args.bsz
print('nb_nodes: {}, bsz: {}, B: {}, nb_batch_eval: {}, tot TSPs: {}\n'.format(args.nb_nodes, args.bsz, B, args.nb_batch_eval, args.nb_batch_eval* args.bsz))


In [None]:
# Load the stored TSP50 instances together with their per-instance tour
# lengths (presumably Concorde solutions, given the `L_concorde` name -- confirm).
# NOTE(review): this runs unconditionally and replaces `x_10k` whatever
# `load_points` / `args.nb_nodes` were set to in the previous cell.
tsp50_points_file = 'data/10k_TSP50.pt'
tsp50_lengths_file = 'data/10k_TSP50_len.pt'

x_10k = torch.load(tsp50_points_file).to(device)
x_10k_len = torch.load(tsp50_lengths_file).to(device)
L_concorde = x_10k_len.mean().item()

# Display the per-instance lengths as the cell output.
x_10k_len

In [None]:
# ===============
# Run beam search
# ===============

# Evaluates the baseline model batch by batch on `x_10k` and accumulates mean
# tour lengths, mean scores and the relative gap to the reference lengths in
# `x_10k_len`, for the greedy and/or beam-search decoding selected above.
# NOTE(review): the model is called here as model(x, B, greedy, beamsearch) and
# must return four values, whereas the training cells call it as
# model(data, deterministic=...) with two return values -- confirm that
# `model_baseline` is the beam-search variant of the network before running.

start = time.time()

mean_tour_length_greedy = 0
mean_tour_length_beamsearch = 0
mean_scores_greedy = 0
mean_scores_beamsearch = 0
# Gaps are accumulated as plain Python floats so both end up with the same
# type and the final normalisation also works when no batch was processed.
gap_greedy = 0.0
gap_beamsearch = 0.0

for step in range(0,args.nb_batch_eval):
    print('batch index: {}, tot_time: {:.3f}min'.format(step, (time.time()-start)/60))

    # extract a batch of test tsp instances
    x = x_10k[step*args.bsz:(step+1)*args.bsz,:,:]
    x_len_concorde = x_10k_len[step*args.bsz:(step+1)*args.bsz]

    # compute tour for model and baseline
    with torch.no_grad():
        tours_greedy, tours_beamsearch, scores_greedy, scores_beamsearch = model_baseline(x, B, greedy, beamsearch)

        # greedy
        if greedy:
            L_greedy = compute_tour_length(x, tours_greedy)
            mean_tour_length_greedy += L_greedy.mean().item()
            mean_scores_greedy += scores_greedy.mean().item()
            x_len_greedy = L_greedy
            gap_greedy += (x_len_greedy/ x_len_concorde - 1.0).sum().item()

        # beamsearch
        if beamsearch:
            # flatten the B candidate tours of every instance to score them in one call
            tours_beamsearch = tours_beamsearch.view(args.bsz*B, args.nb_nodes)
            x = x.repeat_interleave(B,dim=0)
            L_beamsearch = compute_tour_length(x, tours_beamsearch)
            tours_beamsearch = tours_beamsearch.view(args.bsz, B, args.nb_nodes)
            L_beamsearch = L_beamsearch.view(args.bsz, B)
            L_beamsearch_tmp = L_beamsearch
            # keep the shortest of the B candidate tours per instance
            L_beamsearch, idx_min = L_beamsearch.min(dim=1)
            mean_tour_length_beamsearch += L_beamsearch.mean().item()
            mean_scores_beamsearch += scores_beamsearch.mean().item()
            x_len_beamsearch = L_beamsearch
            gap_beamsearch += (x_len_beamsearch/ x_len_concorde - 1.0).sum().item()

    if torch.cuda.is_available():
        torch.cuda.empty_cache() # free GPU reserved memory
if greedy:
    mean_tour_length_greedy =  mean_tour_length_greedy/ args.nb_batch_eval
    mean_scores_greedy =  mean_scores_greedy/ args.nb_batch_eval
    gap_greedy = gap_greedy/ nb_TSPs

if beamsearch:
    mean_tour_length_beamsearch =  mean_tour_length_beamsearch/ args.nb_batch_eval
    mean_scores_beamsearch =  mean_scores_beamsearch/ args.nb_batch_eval
    gap_beamsearch = gap_beamsearch/ nb_TSPs
tot_time = time.time()-start

In [None]:
# =================
# Write result file
# =================

# Formats the beam-search statistics gathered above into a single line, prints
# it and writes it to a text file named after the experiment settings.

nb_TSPs = args.nb_batch_eval* args.bsz
file_name = "beamsearch-nb_nodes{}".format(args.nb_nodes) + "-nb_TSPs{}".format(nb_TSPs) + "-B{}".format(B) + ".txt"
mystring = (
    '\nnb_nodes: {:d}, nb_TSPs: {:d}, B: {:d}, L_greedy: {:.6f}, L_concorde: {:.5f}, L_beamsearch: {:.5f}, '
    'gap_greedy(%): {:.5f}, gap_beamsearch(%): {:.5f}, scores_greedy: {:.5f}, scores_beamsearch: {:.5f}, tot_time: {:.4f}min, '
    'tot_time: {:.3f}hr, mean_time: {:.3f}sec'
).format(args.nb_nodes, nb_TSPs, B, mean_tour_length_greedy, L_concorde,
         mean_tour_length_beamsearch, 100*gap_greedy, 100*gap_beamsearch, mean_scores_greedy,
         mean_scores_beamsearch, tot_time/60, tot_time/3600, tot_time/nb_TSPs)
print(mystring)
# The context manager closes the file even if the write fails.
with open(file_name,"w",1) as file:
    file.write(mystring)