In [None]:
#Load the required packages 
import torch
from torch import nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import time
import argparse
import pprint as pp
import os
from numpy.random import default_rng
from utils.graph_utils import *
from Gurobi_tsp_reader import GurobiTSPReader
from logging import root
from utils.model_utils import *
import networkx as nx

from models.Descriminator import Critic
from models.Generator import Generator
from optim import *



In [None]:
def plot_model_results(mse_val,status):
    """Plot the validation MSE recorded over epochs.

    Args:
        mse_val: 1-D sequence or array of MSE values; the x axis is labelled
            as the epoch number, so one value per epoch is expected.
        status: Text appended to the plot title (e.g. ``'train'``).
    """
    # Accept plain Python lists as well as arrays; ``.shape`` is used below.
    mse_val = np.asarray(mse_val)

    plt.plot(mse_val, linewidth=2)
    plt.title('MSE over iterations during ' + status, fontsize=14)
    plt.xlabel('Epoch number', fontsize=14)
    plt.ylabel('MSE', fontsize=14)
    # One tick per plotted value.
    plt.xticks(np.arange(0, mse_val.shape[0], step=1))
    plt.show()

In [None]:
# Helper Functions.
# create models
# load the datasets 
# Train the model 
# Test the model
# Other helper functions 

def weights_init(m):
    """Re-initialise a single module in place; intended for ``Module.apply``.

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weight drawn from N(0, 0.02).
    * ``nn.BatchNorm2d``: weight drawn from N(0, 0.02), bias set to 0.

    Any other module type is left untouched.
    """
    init_std = 0.02
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)

    if isinstance(m, conv_types):
        nn.init.normal_(m.weight, 0.0, init_std)

    if isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 0.0, init_std)
        nn.init.constant_(m.bias, 0)

def get_gradient(crit, real, fake, epsilon):
    """Gradient of the critic's scores w.r.t. a real/fake interpolation.

    The inputs are blended as ``real * epsilon + fake * (1 - epsilon)``, the
    blend is scored by ``crit``, and the gradient of those scores with respect
    to the blend is returned. The autograd graph is kept (``create_graph`` and
    ``retain_graph``) so the result can itself be back-propagated through.

    Args:
        crit: Callable mapping the blended tensor to scores.
        real: Real samples.
        fake: Generated samples.
        epsilon: Mixing weight applied to ``real``.

    Returns:
        The gradient tensor, with the same shape as the blended input.
    """
    real_part = real * epsilon
    fake_part = fake * (1 - epsilon)
    interpolated = real_part + fake_part

    scores = crit(interpolated)

    (grad_wrt_input,) = torch.autograd.grad(
        outputs=scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(scores),
        retain_graph=True,
        create_graph=True,
    )
    return grad_wrt_input

def gradient_penalty(gradient):
    """Mean squared deviation of per-sample gradient L2 norms from 1.

    Args:
        gradient: Tensor whose first dimension indexes samples; all remaining
            dimensions are flattened before the norm is taken.

    Returns:
        Scalar tensor ``mean((||g_i||_2 - 1) ** 2)`` over the samples.
    """
    n_samples = len(gradient)
    flat = gradient.view(n_samples, -1)
    per_sample_norm = flat.norm(2, dim=1)
    deviation = per_sample_norm - 1
    return torch.mean(deviation ** 2)

def get_gen_loss(crit_fake_pred, fake, real,g_on_real_pred,epoch_num):
    """Generator loss: adversarial term + weighted L1 reconstruction + L1 identity.

    Computes
    ``-mean(crit_fake_pred) + 200 * L1(real, fake) + L1(real, g_on_real_pred)``.

    Args:
        crit_fake_pred: Critic scores for the generated samples.
        fake: Generated samples.
        real: Target samples.
        g_on_real_pred: Generator output when fed the real samples.
        epoch_num: Not used by this function.

    Returns:
        Scalar loss tensor.
    """
    lambda_recon = 200
    l1 = nn.L1Loss()

    adversarial_term = -1. * torch.mean(crit_fake_pred)
    reconstruction_term = l1(real, fake)
    identity_term = l1(real, g_on_real_pred)

    return adversarial_term + lambda_recon * reconstruction_term + identity_term


def get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda):
    """Critic loss: ``mean(fake scores) - mean(real scores) + c_lambda * gp``.

    Args:
        crit_fake_pred: Critic scores for generated samples.
        crit_real_pred: Critic scores for real samples.
        gp: Gradient-penalty value.
        c_lambda: Weight applied to ``gp``.

    Returns:
        Scalar loss tensor.
    """
    fake_mean = torch.mean(crit_fake_pred)
    real_mean = torch.mean(crit_real_pred)
    score_gap = fake_mean - real_mean
    return score_gap + c_lambda * gp

In [None]:
def _read_tsp_samples(filepath, num_nodes, num_samples):
    """Read ``num_samples`` instances from ``filepath`` into two dense arrays.

    Returns ``(edges_target, edges_values)``, each of shape
    ``(num_samples, 1, num_nodes, num_nodes)``.
    """
    num_neighbors = -1
    batch_size = 1
    reader = GurobiTSPReader(num_nodes, num_neighbors, batch_size, filepath)

    targets = np.zeros((num_samples, 1, num_nodes, num_nodes))
    values = np.zeros((num_samples, 1, num_nodes, num_nodes))

    batches = iter(reader)
    for idx in range(num_samples):
        batch = next(batches)
        targets[idx] = batch.edges_target
        values[idx] = batch.edges_values
        # batch.tour_len / batch.tour_nodes are not part of the returned
        # data, so they are not collected here.
    return targets, values


def load_dataset(train_filepath,val_filepath,num_nodes = 20,train_dataset_size = 1e10, valid_dataset_size_all = 10000):
    """Load the training and validation splits as dense NumPy arrays.

    Each split is read through ``GurobiTSPReader`` one instance at a time
    (batch size 1, ``num_neighbors=-1``). The first training sample of each
    array is displayed with ``plt.imshow`` as a quick visual sanity check.

    Args:
        train_filepath: Path of the training data file.
        val_filepath: Path of the validation data file.
        num_nodes: Number of nodes per instance.
        train_dataset_size: Number of training instances to read. Float values
            such as ``1e6`` are accepted and truncated to ``int``.
            NOTE(review): the default of ``1e10`` asks for an array far larger
            than typical memory; callers should pass an explicit size.
        valid_dataset_size_all: Number of validation instances to read.

    Returns:
        ``(xx, z_norm, validation_set_sample, z_norm_valid)`` where ``xx`` and
        ``validation_set_sample`` hold each batch's ``edges_target`` and
        ``z_norm`` / ``z_norm_valid`` hold each batch's ``edges_values``; all
        have shape ``(size, 1, num_nodes, num_nodes)``.
    """
    # Array shapes must be integers; sizes frequently arrive as floats
    # (e.g. 1e6), which np.zeros rejects.
    train_dataset_size = int(train_dataset_size)
    valid_dataset_size_all = int(valid_dataset_size_all)

    # Read training dataset
    xx, z_norm = _read_tsp_samples(train_filepath, num_nodes, train_dataset_size)

    # Plot the first training matrices to check validity (skipped when empty).
    if train_dataset_size > 0:
        plt.imshow(xx[0, 0, :, :])
        plt.show()

        plt.imshow(z_norm[0, 0, :, :])
        plt.show()

    # Read validation dataset
    validation_set_sample, z_norm_valid = _read_tsp_samples(
        val_filepath, num_nodes, valid_dataset_size_all
    )

    return xx,z_norm,validation_set_sample,z_norm_valid


In [None]:
if __name__ == "__main__":

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would treat every non-empty string (including
        ``"False"``) as True, so the accepted spellings are listed explicitly.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    # ---- Command-line options -------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_nodes", type=int, default=20)
    parser.add_argument("--train_filepath", type=str, default=None)
    parser.add_argument("--val_filepath", type=str, default=None)
    parser.add_argument("--test_filepath", type=str, default=None)
    # argparse does not pass non-string defaults through ``type``, so the
    # defaults are written as real integers.
    parser.add_argument("--train_dataset_size", type=int, default=1000000)
    parser.add_argument("--valid_dataset_size", type=int, default=10000)
    # The option name (including its spelling) is kept for CLI compatibility.
    parser.add_argument("--testing_datset_size", type=int, default=10000)
    parser.add_argument("--load_best_train", type=_str2bool, default=True)
    parser.add_argument("--load_best_test", type=_str2bool, default=True)
    parser.add_argument("--pretrained", type=_str2bool, default=False)
    parser.add_argument("--n_epochs", type=int, default=100)
    parser.add_argument("--beam_size", type=int, default=1024)

    # parse_known_args tolerates the extra argv entries a Jupyter kernel is
    # launched with (e.g. ``-f <connection file>``); parse_args would exit.
    opts, _unrecognized_args = parser.parse_known_args()

    num_nodes = opts.num_nodes
    train_filepath = opts.train_filepath
    val_filepath = opts.val_filepath
    test_filepath = opts.test_filepath
    train_dataset_size = opts.train_dataset_size
    valid_dataset_size = opts.valid_dataset_size
    testing_datset_size = opts.testing_datset_size
    load_best_train = opts.load_best_train
    load_best_test = opts.load_best_test
    pretrained = opts.pretrained
    n_epochs = opts.n_epochs
    beam_size = opts.beam_size

    # If the file names are not specified, derive them from the node count.
    if train_filepath is None:
        train_filepath = f"mmwave{num_nodes}_Gurobi_multi_proc.txt"

    if val_filepath is None:
        val_filepath = f"mmwave{num_nodes}_val_Gurobi_multi_proc.txt"

    if test_filepath is None:
        test_filepath = f"mmwave{num_nodes}_test_Gurobi_multi_proc.txt"

    # ---- Hyper-parameters -----------------------------------------------
    display_step = 1000
    lr = 0.0002
    beta_1 = 0.5
    beta_2 = 0.999
    c_lambda = 10
    crit_repeats = 5
    # Fall back to the CPU when no CUDA device is present.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    val_mse = []

    # ---- Models and optimisers ------------------------------------------
    gen = Generator().to(device)
    crit = Critic().to(device)

    gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
    crit_opt = torch.optim.Adam(crit.parameters(), lr=lr, betas=(beta_1, beta_2))

    # ``apply`` re-initialises the existing parameters in place, so the
    # optimisers created above keep referring to the same tensors.
    gen = gen.apply(weights_init)
    crit = crit.apply(weights_init)

    # ---- Run summary (appended, one entry per line) ---------------------
    with open('Model_results_summary.txt', "a", encoding="utf-8") as f:
        f.write('Model Parameters\n')
        f.write("Number of epochs = " + str(n_epochs) + "\n")
        f.write("Pretrained = " + str(pretrained) + "\n")
        f.write("Pre-trained with best = " + str(load_best_train) + "\n")
        f.write("Tested with best = " + str(load_best_test) + "\n")

    # ---- Train, test, plot ----------------------------------------------
    gen,val_mse = train_model(gen,crit,train_filepath,val_filepath,train_dataset_size,valid_dataset_size,n_epochs,pretrained,load_best_train)
    model_testing(gen,num_nodes, testing_datset_size, beam_size,test_filepath,load_best_test)
    status = 'train'
    plot_model_results(np.asarray(val_mse), status)