In [1]:
import torch
import torch.nn as nn
import numpy as np
import os
import json
import random
from typing import Optional, List, Tuple
from multiprocessing import cpu_count, Pool

from tqdm.notebook import tqdm

from tsp_tinker_utils import TSPPackage, get_tsp_problem_folders

In [2]:
# --- Paths (overridable through environment variables) ---
DATA_PATH = os.environ.get('DATA_PATH', './data')
BSSF_PATH = os.environ.get('BSSF_PATH', './bssf')

# --- Hardware: use the GPU when one is visible to torch, else the CPU ---
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# --- Data selection ---
VALIDATION_PROP = 0.1            # fraction of the loaded problems held out for validation
TRAIN_PROBLEM_SIZE_CUTOFF = 100  # passed as the upper problem-size bound when loading training data
TEST_PROBLEM_SIZE_CUTOFF = 201   # only referenced by the commented-out test-set cell at the end

# --- Optimisation ---
LR = 1e-3
NUM_PATH_VARIATIONS_PER_EXAMPLE = 8  # partial-tour samples drawn per problem
NUM_EPOCHS = 50

In [3]:
def sample_path_variations_from_problem(city_connections_w_costs: np.ndarray, edge_matrix: np.ndarray, edges: List[Tuple[int, int]], num_samples: int) -> Tuple[List[List[Tuple[int, int]]], np.ndarray, np.ndarray]:
    """Draw random partial selections of a tour's edges for one problem.

    Each sample picks a random number of edges (between 0 and
    ``num_cities - 2``) from a fresh shuffle of ``edges`` and, with
    probability 1/2, flips every selected edge to ``(out_c, in_c)``.
    Once the empty selection has been drawn, the lower bound is raised so it
    is not drawn again (when a non-empty selection is possible at all).

    Args:
        city_connections_w_costs: square cost matrix; only ``shape[0]`` is
            read here, the array itself is returned untouched.
        edge_matrix: matrix marking the tour edges; returned as the target.
        edges: the tour as ``(in_city, out_city)`` pairs. Not modified.
        num_samples: number of variations to draw.

    Returns:
        ``(variations, edge_matrix, city_connections_w_costs)`` where
        ``variations`` is a list of ``num_samples`` edge lists.
    """
    num_cities = city_connections_w_costs.shape[0]

    # Shuffle a private copy so the caller's list keeps its original order.
    shuffled_edges = list(edges)

    batch = []
    num_edges_lb = 0
    # Clamp at 0 so problems with fewer than two cities do not hand
    # random.randint an empty range.
    num_edges_ub = max(num_cities - 2, 0)
    for _ in range(num_samples):
        num_edges_to_include = random.randint(num_edges_lb, num_edges_ub)
        direction = bool(random.randint(0, 1))

        if num_edges_to_include == 0:
            # There is only one empty variation, so avoid drawing it twice --
            # but never push the lower bound above the upper bound.
            num_edges_lb = min(1, num_edges_ub)

        random.shuffle(shuffled_edges)
        selected = shuffled_edges[:num_edges_to_include]
        if direction:
            batch.append(selected)
        else:
            batch.append([(out_c, in_c) for in_c, out_c in selected])

    return batch, edge_matrix, city_connections_w_costs

def load_and_process_tsp_problem(file_path: os.PathLike) -> Tuple[List[List[Tuple[int, int]]], np.ndarray, np.ndarray]:
    """Read one packaged TSP problem from JSON and sample partial tours from it.

    The best known tour is converted into both an edge list and a
    ``num_cities x num_cities`` 0/1 edge matrix, which are then handed to
    ``sample_path_variations_from_problem`` together with the problem's cost
    matrix.

    Returns:
        Whatever ``sample_path_variations_from_problem`` returns for
        ``NUM_PATH_VARIATIONS_PER_EXAMPLE`` samples.
    """
    with open(file_path, 'r') as handle:
        raw_package = json.load(handle)

    package = TSPPackage.from_json(raw_package)
    tsp_problem = package.problem
    solution = package.best_solution

    # Walk the tour, recording each hop both as a matrix entry and as a pair.
    tour_matrix = np.zeros((tsp_problem.num_cities, tsp_problem.num_cities))
    tour_edges = []
    # NOTE(review): the walk always starts from city 0. If `tour` itself begins
    # with 0 this records a (0, 0) self-loop and never adds the edge that closes
    # the cycle -- confirm the tour format produced by TSPPackage.
    origin = 0
    for destination in solution.tour:
        tour_matrix[origin, destination] = 1
        tour_edges.append((origin, destination))
        origin = destination

    return sample_path_variations_from_problem(
        tsp_problem.city_connections_w_costs,
        tour_matrix,
        tour_edges,
        NUM_PATH_VARIATIONS_PER_EXAMPLE,
    )
    

class TSPDataset(torch.utils.data.Dataset):
    """Dataset of TSP problems, each expanded into several partial-tour inputs.

    Every entry of ``data`` is the tuple produced by
    ``load_and_process_tsp_problem``:
    ``(path_variation_edges, edge_matrix, cost_matrix)``.

    Indexing returns ``(batch, target)`` where ``batch`` has shape
    ``(num_path_variations, 2, n, n)`` -- channel 0 is the cost matrix,
    channel 1 marks the edges of one sampled partial tour -- and ``target``
    is the tour edge matrix repeated ``num_path_variations`` times.
    """

    def __init__(self, data: List[Tuple[List[List[Tuple[int, int]]], np.ndarray, np.ndarray]], num_path_variations: int = NUM_PATH_VARIATIONS_PER_EXAMPLE):
        super(TSPDataset, self).__init__()
        self.data = data
        self.num_path_variations = num_path_variations

    @classmethod
    def from_disk(cls, data_folder_path: os.PathLike, problem_size_lower_bound: Optional[int] = None, problem_size_upper_bound: Optional[int] = None, undirected_only: Optional[bool] = True, max_workers: int = cpu_count() - 1):
        """Build a dataset from every ``.json`` problem under ``data_folder_path``.

        Only folders whose problem size lies within the (inclusive, optional)
        bounds are read. ``undirected_only`` is forwarded to the loader, which
        currently does not use it.
        """
        new_instance = cls([])
        new_instance.data = new_instance._multi_threaded_load(data_folder_path, problem_size_lower_bound, problem_size_upper_bound, max_workers, undirected_only)

        return new_instance

    def _multi_threaded_load(self, data_folder_path: os.PathLike, problem_size_lower_bound: Optional[int], problem_size_upper_bound: Optional[int], max_workers: int, undirected_only: Optional[bool] = True):
        """Load and pre-process all matching problem files with a process pool."""
        possible_folders = get_tsp_problem_folders(data_folder_path)
        problem_file_paths = []
        for _, problem_size, folder_path in possible_folders:
            above_lower = problem_size_lower_bound is None or problem_size >= problem_size_lower_bound
            below_upper = problem_size_upper_bound is None or problem_size <= problem_size_upper_bound
            if above_lower and below_upper:
                # Add all files in the folder to the list of files to load
                for file in os.listdir(folder_path):
                    if file.endswith('.json'):
                        problem_file_paths.append(os.path.join(folder_path, file))

        formatted_data = []
        # The default `cpu_count() - 1` is 0 on a single-core machine, which
        # Pool rejects; always keep at least one worker.
        with Pool(processes=max(1, max_workers)) as worker_pool:
            with tqdm(total=len(problem_file_paths), desc="Loading data from disk...") as p_bar:
                for result in worker_pool.imap_unordered(load_and_process_tsp_problem, problem_file_paths, chunksize=10):
                    p_bar.update(1)
                    formatted_data.append(result)

        return formatted_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Stored order is (variations, edge_matrix, cost_matrix): the edge
        # matrix is the training target, the cost matrix is an input channel.
        path_variation_edges, edge_matrix, cost_matrix = self.data[idx]

        target = torch.tensor(edge_matrix, dtype=torch.float32)
        target = target.repeat(self.num_path_variations, 1, 1)

        batch = np.zeros((self.num_path_variations, 2, cost_matrix.shape[0], cost_matrix.shape[1]))
        for i, path in enumerate(path_variation_edges):
            # Channel 0: edge costs. Channel 1: edges already fixed in this sample.
            batch[i, 0] = cost_matrix
            for in_c, out_c in path:
                batch[i, 1, in_c, out_c] = 1

        return torch.tensor(batch, dtype=torch.float32), target

    def split(self, ration: float = 0.1):
        """Shuffle in place, then move the first ``ration`` fraction into a new dataset.

        The returned dataset keeps this dataset's ``num_path_variations``;
        this dataset retains the remaining entries.
        """
        num_to_take = int(len(self) * ration)
        random.shuffle(self.data)

        split_data = self.data[:num_to_take]
        self.data = self.data[num_to_take:]

        return TSPDataset(split_data, self.num_path_variations)
        

In [4]:
class ConvolutionalSalesmanNet(nn.Module):
    """Fully convolutional network mapping a 2-channel matrix to a 1-channel map.

    The stack has 40 ``Conv2d`` stages. Channels widen 2 -> 256 over the
    first seven stages, stay at 256 for twenty-five stages, then narrow
    256 -> 1 over the last eight. Every stage except the last is followed by
    ``BatchNorm2d`` and ``GELU``; the last is followed by ``Sigmoid``.
    With the default ``kernel_size=5, padding=2`` the spatial size is kept.

    Layers are registered in the same order as a hand-written
    ``nn.Sequential`` would be, so ``state_dict`` keys are ``net.0`` ...
    ``net.117``.
    """

    # Channel count before the first conv and after each subsequent conv.
    _CHANNELS = [2, 4, 8, 16, 32, 64, 128] + [256] * 26 + [128, 64, 32, 16, 8, 4, 2, 1]

    def __init__(self, kernel_size: int = 5, padding: int = 2):
        super(ConvolutionalSalesmanNet, self).__init__()

        stages = list(zip(self._CHANNELS[:-1], self._CHANNELS[1:]))
        layers = []
        for in_channels, out_channels in stages[:-1]:
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.GELU())

        # Output stage: no normalisation, squash to (0, 1).
        final_in, final_out = stages[-1]
        layers.append(nn.Conv2d(final_in, final_out, kernel_size=kernel_size, padding=padding))
        layers.append(nn.Sigmoid())

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on ``x`` of shape ``(batch, 2, H, W)``.

        Returns a ``(batch, H, W)`` tensor. Only the singleton channel
        dimension is removed, so a batch of size one keeps its batch axis.
        """
        return self.net(x).squeeze(1)

    @property
    def device(self):
        """Device of the first parameter (all parameters are assumed co-located)."""
        return next(self.parameters()).device

In [5]:
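# NOTE: commented-out earlier draft of sample_path_variations_from_problem; the active definition is in the dataset cell above.
# def sample_path_variations_from_problem(city_connections_w_costs: np.ndarray, tour_edges: np.ndarray, num_samples: int) -> Tuple[torch.Tensor, torch.Tensor]: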
#     num_cities = city_connections_w_costs.shape[0]
#     edges = tour_edges.nonzero()
#     edges = [(in_c, out_c) for in_c, out_c in zip(edges[0], edges[1])]

    
#     batch: np.ndarray = np.zeros((num_samples, 2, num_cities, num_cities))
#     num_edges_lb = 0
#     num_edges_ub = num_cities - 2
#     for i in range(num_samples):
#         num_edges_to_include = random.randint(num_edges_lb, num_edges_ub)
#         direction = bool(random.randint(0, 1))

#         if num_edges_to_include == 0:
#             # only one possible variation, don't double dip
#             num_edges_lb = 1

#         random.shuffle(edges)
#         for in_c, out_c in edges[:num_edges_to_include]:
#             if direction:
#                 batch[i, 1, in_c, out_c] = 1
#             else:
#                 batch[i, 1, out_c, in_c] = 1
#         batch[i, 0] = city_connections_w_costs

#     target = torch.tensor(tour_edges, dtype=torch.long)
#     target = target.repeat(num_samples, 1, 1)

#     return torch.tensor(np.array(batch), dtype=torch.float32), target

In [6]:
# Load every problem up to the training size cutoff, then carve the
# validation set out of it (split() removes those entries from train_dataset).
train_dataset = TSPDataset.from_disk(
    DATA_PATH,
    problem_size_upper_bound=TRAIN_PROBLEM_SIZE_CUTOFF,
)
validation_dataset = train_dataset.split(ration=VALIDATION_PROP)

Loading data from disk...:   0%|          | 0/218000 [00:00<?, ?it/s]

In [7]:

# Each dataset item is already a stack of path variations for one problem,
# so the loaders hand out one item at a time.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=1,
)
validation_loader = torch.utils.data.DataLoader(
    dataset=validation_dataset,
    shuffle=True,
    batch_size=1,
)

In [8]:
# Network, objective and optimiser, all placed on the selected device.
model = ConvolutionalSalesmanNet()
model = model.to(DEVICE)

loss_fn = nn.MSELoss()
loss_fn = loss_fn.to(DEVICE)

optimizer = torch.optim.AdamW(params=model.parameters(), lr=LR)

In [9]:
# Best validation loss seen so far; start at +inf so the first epoch always
# counts as an improvement (replaces an arbitrary large magic number).
bssf_validation = float('inf')

In [10]:
# Per-epoch mean losses, appended to by the training loop.
train_losses, validation_losses = [], []

In [12]:
# Train for NUM_EPOCHS epochs, validating after each one and checkpointing
# the weights to BSSF_PATH whenever the mean validation loss improves.
try:
    num_train_batches = len(train_loader)
    num_validation_batches = len(validation_loader)
    for i in range(NUM_EPOCHS):
        model.train()

        tot_train_loss = 0
        train_bar = tqdm(range(num_train_batches), desc=f"Epoch {i+1}/{NUM_EPOCHS}")
        for j, (batch, target) in enumerate(train_loader):
            optimizer.zero_grad()

            # The loader uses batch_size=1, so drop that outer dimension: the
            # item itself is the stack of path variations. Then move to DEVICE.
            batch = batch[0].to(DEVICE)
            target = target[0].to(DEVICE)

            # Forward pass and adjust
            path_predictions: torch.Tensor = model(batch)
            loss: torch.Tensor = loss_fn(path_predictions, target)

            loss.backward()
            optimizer.step()

            tot_train_loss += loss.item()
            train_bar.update(1)

            # Refresh the running mean only occasionally to keep the bar cheap.
            if j % 100 == 0:
                train_bar.set_postfix(train_loss= tot_train_loss / (j + 1))

        train_losses.append(tot_train_loss / num_train_batches)

        # Validation: no gradients are needed, so skip building autograd graphs.
        model.eval()
        tot_validation_loss = 0
        with torch.no_grad():
            for batch, target in validation_loader:
                batch = batch[0].to(DEVICE)
                target = target[0].to(DEVICE)

                path_predictions = model(batch)
                loss = loss_fn(path_predictions, target)

                tot_validation_loss += loss.item()

        validation_loss = tot_validation_loss / num_validation_batches
        validation_losses.append(validation_loss)
        # validation_loss is a float (the list is validation_losses), so it
        # must not be indexed here.
        train_bar.set_postfix(train_loss= train_losses[-1], validation_loss= validation_loss)
        train_bar.close()

        if validation_loss < bssf_validation:
            print(f"New BSSF: {validation_loss} ----> {bssf_validation - validation_loss} improvement!")
            bssf_validation = validation_loss
            torch.save(model.state_dict(), BSSF_PATH)

except Exception as e:
    print(e)

    # Dang memory leaks
    # Dropping the notebook's reference to the model lets its memory be
    # reclaimed; the model cell must be re-run before training again.
    del model

    # Bare raise keeps the original traceback intact.
    raise




Epoch 1/50:   0%|          | 0/196200 [00:00<?, ?it/s]

In [None]:
# test_dataset = TSPDataset.from_disk(DATA_PATH, problem_size_lower_bound=TEST_PROBLEM_SIZE_CUTOFF)