In [1]:
import torch
import numpy as np
import os
import json
from typing import Optional, List, Tuple
from multiprocessing import cpu_count, Pool

from tqdm.notebook import tqdm

from tsp_tinker_utils import TSPPackage, TSPProblem, TSPSolution, get_tsp_problem_folders

In [2]:
# Root folder that holds the TSP problem files. Set the DATA_PATH environment
# variable to override it; otherwise the local ./data folder is used.
DATA_PATH = os.environ.get('DATA_PATH', './data')

In [3]:
def load_and_process_tsp_problem(file_path: os.PathLike) -> Tuple[np.ndarray, List[int], float]:
    """Read one packaged TSP problem from disk and flatten it into a tuple.

    Kept at module level so it can be handed to a ``multiprocessing.Pool``.

    Args:
        file_path: Path of a JSON file understood by ``TSPPackage.from_json``.

    Returns:
        ``(city_connections_w_costs, tour, tot_cost)``: the problem's cost
        matrix followed by the tour and total cost of the package's best
        solution.
    """
    with open(file_path, 'r') as json_file:
        raw_package = json.load(json_file)

    package = TSPPackage.from_json(raw_package)
    cost_matrix = package.problem.city_connections_w_costs
    solution = package.best_solution

    return (cost_matrix, solution.tour, solution.tot_cost)

class TSPDataset(torch.utils.data.Dataset):
    """In-memory dataset of solved TSP problems loaded from JSON files.

    Every item is the tuple produced by ``load_and_process_tsp_problem`` for one
    file. All files are loaded eagerly in ``__init__``.

    Args:
        data_folder_path: Root folder passed to ``get_tsp_problem_folders``.
        problem_size_lower_bound: Inclusive minimum problem size, or ``None`` for
            no lower limit.
        problem_size_upper_bound: Inclusive maximum problem size, or ``None`` for
            no upper limit.
        undirected_only: Accepted for interface compatibility.
            NOTE(review): currently not used by the loader -- no filtering on
            directedness happens here.
        max_workers: Number of loader processes. Values below 1 (e.g. the default
            ``cpu_count() - 1`` on a single-core machine) are raised to 1.
    """

    def __init__(self, data_folder_path: os.PathLike, problem_size_lower_bound: Optional[int] = None, problem_size_upper_bound: Optional[int] = None, undirected_only: Optional[bool] = True, max_workers: int = cpu_count() - 1):
        super().__init__()
        self.data = self.multi_threaded_load(data_folder_path, problem_size_lower_bound, problem_size_upper_bound, max_workers, undirected_only)

    @staticmethod
    def _resolve_worker_count(max_workers: Optional[int]) -> Optional[int]:
        """Clamp the requested worker count to at least one process.

        ``None`` is passed through so ``Pool`` can pick its own default.
        """
        if max_workers is None:
            return None
        return max(1, max_workers)

    @staticmethod
    def _select_problem_files(possible_folders, problem_size_lower_bound: Optional[int], problem_size_upper_bound: Optional[int]) -> List[str]:
        """Return the sorted ``.json`` paths of all folders whose size is within bounds.

        ``possible_folders`` is an iterable of 3-tuples whose second element is
        the problem size and whose third element is the folder path.
        """
        selected_paths = []
        for _, problem_size, folder_path in possible_folders:
            if problem_size_lower_bound is not None and problem_size < problem_size_lower_bound:
                continue
            if problem_size_upper_bound is not None and problem_size > problem_size_upper_bound:
                continue
            selected_paths.extend(
                os.path.join(folder_path, file_name)
                for file_name in os.listdir(folder_path)
                if file_name.endswith('.json')
            )
        # os.listdir returns entries in arbitrary order; sorting makes dataset
        # indices reproducible from one run to the next.
        return sorted(selected_paths)

    def multi_threaded_load(self, data_folder_path: os.PathLike, problem_size_lower_bound: Optional[int], problem_size_upper_bound: Optional[int], max_workers: int, undirected_only: Optional[bool] = True):
        """Load every matching problem file in parallel and return the results as a list.

        Despite the name, the work is spread over a pool of worker *processes*.
        Results are returned in sorted file-path order.

        NOTE(review): ``undirected_only`` is accepted but has no effect here.
        """
        possible_folders = get_tsp_problem_folders(data_folder_path)
        problem_file_paths = self._select_problem_files(possible_folders, problem_size_lower_bound, problem_size_upper_bound)

        # Pool(processes=0) raises ValueError, which the default of
        # cpu_count() - 1 would trigger on a single-core machine.
        worker_count = self._resolve_worker_count(max_workers)

        formatted_data = []
        with Pool(processes=worker_count) as worker_pool, tqdm(total=len(problem_file_paths), desc="Loading data from disk...") as p_bar:
            # Ordered imap (not imap_unordered) keeps results aligned with the
            # sorted path list; chunksize batches tasks to cut IPC overhead.
            for result in worker_pool.imap(load_and_process_tsp_problem, problem_file_paths, chunksize=10):
                formatted_data.append(result)
                p_bar.update(1)

        return formatted_data

    def __len__(self):
        """Number of loaded problems."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the loaded tuple stored at position ``idx``."""
        return self.data[idx]


In [4]:
# Training split: every problem folder whose size is at most 200.
train_dataset = TSPDataset(data_folder_path=DATA_PATH, problem_size_upper_bound=200)
print(len(train_dataset))

Loading data from disk...:   0%|          | 0/254000 [00:00<?, ?it/s]

254000


In [5]:
# Validation split: problem sizes of 201 and up, i.e. strictly larger than
# anything in the training split (which is capped at 200), so the two never overlap.
validation_dataset = TSPDataset(data_folder_path=DATA_PATH, problem_size_lower_bound=201)
print(len(validation_dataset))

Loading data from disk...:   0%|          | 0/1000 [00:00<?, ?it/s]

1000


In [6]:
def convert_problem_and_tour_to_mse_problems(city_connections_w_costs: np.ndarray, tour: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Expand one solved problem into a batch of per-step (input, target) pairs.

    For step ``i`` the input stacks two ``(n, n)`` channels:

    * channel 0 -- the cost matrix, identical for every step;
    * channel 1 -- a 0/1 matrix marking the directed edges
      ``tour[k - 1] -> tour[k]`` for all ``1 <= k <= i`` (all zeros at step 0).

    The target for step ``i`` is the next city, ``tour[i + 1]``. For the last
    step it is the tour's first city, which closes the cycle.

    NOTE: despite the ``mse`` in the name, targets are integer city indices
    (``torch.long``), not regression values.

    Args:
        city_connections_w_costs: Square ``(n, n)`` array of connection costs.
        tour: Visiting order, as city indices into the cost matrix.

    Returns:
        ``(batch, targets)`` where ``batch`` is a float32 tensor of shape
        ``(len(tour), 2, n, n)`` and ``targets`` is an int64 tensor of shape
        ``(len(tour),)``. An empty tour yields empty tensors of those shapes.

    Raises:
        ValueError: If ``city_connections_w_costs`` is not a square 2-D array.
    """
    prob_size = city_connections_w_costs.shape[0]
    if city_connections_w_costs.shape != (prob_size, prob_size):
        raise ValueError(
            f"city_connections_w_costs must be a square 2-D array, got shape {city_connections_w_costs.shape}"
        )

    num_steps = len(tour)

    # One preallocated float32 buffer replaces the per-step copy/stack/convert chain.
    problems_with_context = np.empty((num_steps, 2, prob_size, prob_size), dtype=np.float32)
    problems_with_context[:, 0] = city_connections_w_costs

    current_path = np.zeros((prob_size, prob_size), dtype=np.float32)
    previous_city = None
    for step, current_city in enumerate(tour):
        # The first step has no incoming edge. Every later step records the edge
        # from the city visited just before, starting at tour[0] rather than an
        # assumed city 0.
        if previous_city is not None:
            current_path[previous_city, current_city] = 1
        previous_city = current_city
        # Slice assignment copies, so each step keeps its own snapshot of the path.
        problems_with_context[step, 1] = current_path

    if num_steps:
        # Next-city labels; the final step points back to the tour's start.
        targets = list(tour[1:]) + [tour[0]]
    else:
        targets = []

    batch = torch.from_numpy(problems_with_context)
    targets = torch.tensor(targets, dtype=torch.long)

    return batch, targets

    


In [7]:
# Smoke test: take one training sample near the end of the dataset, expand it
# into per-step examples, and inspect the resulting tensor shapes.
item_indx = len(train_dataset) - 23

sample_costs, sample_tour = train_dataset[item_indx][:2]
formatted_batch = convert_problem_and_tour_to_mse_problems(sample_costs, sample_tour)

# formatted_batch is (batch, targets); print the shape of each in turn.
for component in formatted_batch:
    print(component.shape)

torch.Size([22, 2, 22, 22])
torch.Size([22])
