In [None]:
import random
import torch
from torch_geometric.data import Data
from itertools import permutations
import os
import imageio
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from tqdm import tqdm  # For progress bars

def create_data_model(num_selected_cities=8, seed=None):
    """
    Build a random TSP instance over a fixed table of 12 US cities.

    New York is always selected and placed first, so it is node 0 and acts as the
    depot. The other ``num_selected_cities - 1`` cities are drawn with
    ``random.sample`` from the module-global ``random`` generator.

    Parameters:
    - num_selected_cities (int): Number of cities to include in the TSP instance
      (including depot). Must be between 2 and 12 inclusive.
    - seed (int, optional): If given, ``random.seed(seed)`` re-seeds the global
      generator before sampling, so equal seeds select equal cities.

    Returns:
    - data (dict) with keys:
        "cities_info": {city name: (latitude, longitude)} for all 12 cities,
        "all_cities": the 12 city names in distance-matrix order,
        "cities": the selected city names, depot first,
        "distance_matrix": distances in miles among the selected cities,
        "num_cities": number of selected cities,
        "depot": 0,
        "node_features": float tensor [num_cities, 2] holding (lat, lon),
        "edge_index": long tensor [2, num_cities * (num_cities - 1)] describing a
            fully connected directed graph without self-loops,
        "pyg_data": torch_geometric ``Data`` with ``x`` and ``edge_index``, plus the
            ``depot``, ``num_cities`` and ``distance_matrix`` attributes that the
            pointer decoder and the evaluation code read from the ``Data`` object.

    Raises:
    - ValueError: if ``num_selected_cities`` is outside [2, 12].
    """
    data = {}
    # Coordinates of the cities (latitude, longitude)
    data["cities_info"] = {
        "New York": (40.7128, -74.0060),
        "Los Angeles": (34.0522, -118.2437),
        "Chicago": (41.8781, -87.6298),
        "Denver": (39.7392, -104.9903),
        "Dallas": (32.7767, -96.7970),
        "Seattle": (47.6062, -122.3321),
        "Boston": (42.3601, -71.0589),
        "San Francisco": (37.7749, -122.4194),
        "St. Louis": (38.6270, -90.1994),
        "Houston": (29.7604, -95.3698),
        "Phoenix": (33.4484, -112.0740),
        "Salt Lake City": (40.7608, -111.8910),
    }
    # List of all cities in the order corresponding to the distance matrix
    all_cities = [
        "New York",
        "Los Angeles",
        "Chicago",
        "Denver",
        "Dallas",
        "Seattle",
        "Boston",
        "San Francisco",
        "St. Louis",
        "Houston",
        "Phoenix",
        "Salt Lake City",
    ]
    data["all_cities"] = all_cities  # Keep the original list for reference

    # Distance matrix in miles (12x12 matrix)
    full_distance_matrix = [
        # NY    LA    Chicago Denver Dallas Seattle Boston San Francisco St. Louis Houston Phoenix Salt Lake
        [0,    2451, 713,    1631, 1374, 2408, 213,    2571, 875,    1420, 2145, 1972],  # New York
        [2451, 0,    1745,   831, 1240, 959,  2596, 403,   1589, 1374, 357, 579],      # Los Angeles
        [713,  1745, 0,      920, 803, 1737, 851,  1858, 262,    940, 1453, 1260],      # Chicago
        [1631, 831,  920,    0,   663, 1021, 1769, 949,   796,    879, 586, 371],       # Denver
        [1374, 1240, 803,    663, 0,   1681, 1551, 1765, 547,    225, 887, 999],       # Dallas
        [2408, 959,  1737,  1021, 1681, 0,    2493, 678,   1724, 1891, 1114, 701],      # Seattle
        [213,  2596, 851,    1769, 1551, 2493, 0,    2699, 1038, 1605, 2300, 2099],    # Boston
        [2571, 403,  1858,   949, 1765, 678,  2699, 0,    1744, 1645, 653, 600],       # San Francisco
        [875,  1589, 262,    796, 547, 1724, 1038, 1744, 0,      679, 1272, 1162],      # St. Louis
        [1420, 1374, 940,    879, 225, 1891, 1605, 1645, 679,    0,    1017, 1200],     # Houston
        [2145, 357,  1453,   586, 887, 1114, 2300, 653,   1272, 1017, 0, 504],        # Phoenix
        [1972, 579,  1260,   371, 999, 701,  2099, 600,   1162, 1200, 504, 0],        # Salt Lake City
    ]

    # Validate num_selected_cities before touching the global RNG, so a bad
    # argument raises without side effects.
    if num_selected_cities < 2:
        raise ValueError("At least two cities must be selected (including the depot).")
    if num_selected_cities > len(all_cities):
        raise ValueError(f"Cannot select {num_selected_cities} cities from {len(all_cities)} available.")

    if seed is not None:
        random.seed(seed)

    # The depot is always included and always placed first (index 0).
    depot = "New York"
    remaining_cities = [city for city in all_cities if city != depot]
    selected_cities = [depot] + random.sample(remaining_cities, num_selected_cities - 1)

    # Map the selected cities back to rows/columns of the full matrix.
    position = {city: idx for idx, city in enumerate(all_cities)}
    selected_indices = [position[city] for city in selected_cities]
    subset_distance_matrix = [
        [full_distance_matrix[i][j] for j in selected_indices] for i in selected_indices
    ]

    num_cities = len(selected_cities)
    data["cities"] = selected_cities
    data["distance_matrix"] = subset_distance_matrix
    data["num_cities"] = num_cities
    data["depot"] = 0  # Depot is always the first city in the selected_cities list

    # Node features: (latitude, longitude) per selected city.
    data["node_features"] = torch.tensor(
        [list(data["cities_info"][city]) for city in selected_cities], dtype=torch.float
    )

    # Fully connected directed graph without self-loops.
    edges = [[i, j] for i in range(num_cities) for j in range(num_cities) if i != j]
    data["edge_index"] = torch.tensor(edges, dtype=torch.long).t().contiguous()

    # The decoder reads data.depot / data.num_cities and evaluation reads
    # data["distance_matrix"], so carry them on the Data object as well.
    data["pyg_data"] = Data(
        x=data["node_features"],
        edge_index=data["edge_index"],
        depot=data["depot"],
        num_cities=num_cities,
        distance_matrix=subset_distance_matrix,
    )

    return data


In [None]:
def solve_tsp_brute_force(data):
    """
    Find an optimal closed tour by enumerating every ordering of the non-depot cities.

    Runs in O((n-1)! * n) time, so it is only practical for small instances.

    Parameters:
    - data (dict): TSP data model; reads the "distance_matrix", "depot" and
      "num_cities" keys.

    Returns:
    - optimal_route (list): City indices starting and ending at the depot. Among
      equally short tours, the first one enumerated by ``itertools.permutations``
      is kept.
    - min_distance (float): Total distance of the optimal tour.

    Raises:
    - ValueError: if ``depot`` is not in ``range(num_cities)``.
    """
    distance_matrix = data["distance_matrix"]
    depot = data["depot"]

    # Fix the depot at both ends; only the middle cities are permuted.
    middle = list(range(data["num_cities"]))
    middle.remove(depot)

    min_distance = float('inf')
    optimal_route = None

    for perm in permutations(middle):
        route = [depot, *perm, depot]
        total_distance = sum(distance_matrix[a][b] for a, b in zip(route, route[1:]))
        # Strict '<' keeps the earliest of several equally short tours.
        if total_distance < min_distance:
            min_distance = total_distance
            optimal_route = route

    return optimal_route, min_distance


In [None]:
class TSPDataset(Dataset):
    """Map-style dataset pairing each TSP instance with its target tour."""

    def __init__(self, data_list, tours):
        """
        Parameters:
        - data_list: List of PyTorch Geometric Data objects representing TSP instances.
        - tours: List of lists, where each sublist is the sequence of city indices
          representing the tour for the instance at the same position.

        Raises:
        - ValueError: if the two sequences differ in length. Otherwise ``__len__``
          would follow ``data_list`` while indexing could fail or ignore extra tours.
        """
        if len(data_list) != len(tours):
            raise ValueError(
                f"data_list and tours must have the same length, "
                f"got {len(data_list)} and {len(tours)}"
            )
        self.data_list = data_list
        self.tours = tours

    def __len__(self):
        """Number of (instance, tour) pairs."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return the ``(data, tour)`` pair at position ``idx``."""
        return self.data_list[idx], self.tours[idx]


In [None]:
def generate_tsp_dataset(num_instances=100, min_cities=5, max_cities=8, seed=None):
    """
    Build random TSP instances and label each with its brute-force optimal tour.

    Each instance's size is drawn with ``random.randint(min_cities, max_cities)``,
    the instance is built by ``create_data_model`` (without re-seeding) and solved
    exactly by ``solve_tsp_brute_force``.

    Parameters:
    - num_instances (int): How many instances to generate.
    - min_cities (int): Smallest instance size (inclusive).
    - max_cities (int): Largest instance size (inclusive).
    - seed (int, optional): Seeds the global ``random`` generator once up front.

    Returns:
    - data_list (list): The ``pyg_data`` object of every instance.
    - tours (list): The matching optimal tours (city indices).
    """
    if seed is not None:
        random.seed(seed)

    instances, optimal_tours = [], []
    for _ in tqdm(range(num_instances), desc="Generating TSP Instances"):
        size = random.randint(min_cities, max_cities)
        instance = create_data_model(num_selected_cities=size, seed=None)
        best_route, _best_distance = solve_tsp_brute_force(instance)
        instances.append(instance["pyg_data"])
        optimal_tours.append(best_route)

    return instances, optimal_tours


In [None]:
import torch.nn as nn
from torch_geometric.nn import GATConv

class GATEncoder(nn.Module):
    """
    Two-layer graph attention encoder mapping node features to node embeddings.

    The first GATConv uses ``heads`` attention heads with concatenated outputs
    (``hidden_channels * heads`` features per node), followed by a ReLU. The
    second GATConv uses a single head with ``concat=False`` and emits
    ``out_channels`` features per node.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=4, dropout=0.6):
        """
        Parameters:
        - in_channels (int): Input features per node.
        - hidden_channels (int): Output features per head of the first layer.
        - out_channels (int): Size of the final node embedding.
        - heads (int): Number of attention heads in the first layer.
        - dropout (float): Dropout probability given to both GATConv layers (default 0.6).
        """
        super(GATEncoder, self).__init__()
        self.gat1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        self.gat2 = GATConv(hidden_channels * heads, out_channels, heads=1, concat=False, dropout=dropout)
        self.relu = nn.ReLU()

    def forward(self, x, edge_index):
        """Return node embeddings for node features ``x`` on the graph ``edge_index``."""
        x = self.gat1(x, edge_index)
        x = self.relu(x)
        x = self.gat2(x, edge_index)
        return x  # Final node embeddings


In [None]:
import torch

# Pick the compute device once: the default CUDA device when PyTorch reports one
# is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
import torch.nn.functional as F

class PointerNetwork(nn.Module):
    """
    Greedy pointer-network decoder over node embeddings produced by ``encoder``.

    Decoding starts from the depot. An LSTM cell is stepped once per remaining
    city. At each step every node is scored by a linear layer applied to
    ``[decoder_hidden, node_embedding]``, already-visited nodes are masked out,
    and the highest-probability node is chosen next.
    """

    def __init__(self, encoder, hidden_size, output_size):
        """
        Parameters:
        - encoder (nn.Module): Called as ``encoder(x, edge_index)``; must return one
          ``hidden_size``-wide embedding per node.
        - hidden_size (int): LSTM hidden size, equal to the encoder's output width.
        - output_size (int): Stored as ``self.output_size``; not otherwise used here.
        """
        super(PointerNetwork, self).__init__()
        self.encoder = encoder
        self.hidden_size = hidden_size
        self.output_size = output_size

        # Decoder RNN
        self.decoder_rnn = nn.LSTMCell(hidden_size, hidden_size)

        # Scores one (decoder state, node embedding) pair.
        self.pointer = nn.Linear(hidden_size * 2, 1)

    def forward(self, data, target=None):
        """
        Decode one tour greedily.

        Parameters:
        - data: Graph object with ``x`` and ``edge_index``. Optional ``depot``
          (default 0) and ``num_cities`` (default: number of encoded nodes).
        - target: Accepted for interface compatibility; currently unused
          (no teacher forcing).

        Returns:
        - tour (list[int]): Visited node indices in order. Excludes the starting
          depot and ends with the depot.
        - pointers (list[Tensor]): One probability vector over all nodes per
          decoding step. Already-visited nodes have probability 0.
        """
        # Encode the graph
        encoder_outputs = self.encoder(data.x, data.edge_index)  # [num_nodes, hidden_size]
        num_nodes = encoder_outputs.size(0)
        device = encoder_outputs.device
        depot = int(getattr(data, "depot", 0))
        num_cities = int(getattr(data, "num_cities", num_nodes))
        # Never ask for more steps than there are unvisited nodes. Otherwise every
        # score would be -inf and softmax would produce NaNs.
        num_steps = min(num_cities, num_nodes) - 1

        # Single-instance decoding: batch size 1, state on the encoder's device.
        decoder_hidden = torch.zeros(1, self.hidden_size, device=device)
        decoder_cell = torch.zeros(1, self.hidden_size, device=device)
        decoder_input = encoder_outputs[depot].unsqueeze(0)  # [1, hidden_size]

        visited = torch.zeros(num_nodes, dtype=torch.bool, device=device)
        visited[depot] = True  # Start at depot

        tour = []
        pointers = []

        for _ in range(num_steps):
            decoder_hidden, decoder_cell = self.decoder_rnn(decoder_input, (decoder_hidden, decoder_cell))

            # decoder_hidden is [1, H] and encoder_outputs is [N, H]. Expand the
            # state to [N, H] so the two can be concatenated feature-wise.
            query = decoder_hidden.expand(num_nodes, -1)
            attn_scores = self.pointer(torch.cat([query, encoder_outputs], dim=1)).squeeze(-1)  # [num_nodes]
            attn_scores = attn_scores.masked_fill(visited, float('-inf'))  # Mask visited cities
            attn_weights = F.softmax(attn_scores, dim=0)  # [num_nodes]

            # Greedy choice of the next city.
            pointer = int(torch.argmax(attn_weights).item())
            tour.append(pointer)
            pointers.append(attn_weights)

            # Autograd may keep `visited` for masked_fill's backward pass. Mutating
            # it in place would trip the in-place version check, so update a copy.
            visited = visited.clone()
            visited[pointer] = True

            decoder_input = encoder_outputs[pointer].unsqueeze(0)

        # Append depot to complete the tour
        tour.append(depot)

        return tour, pointers


In [None]:
class TSPModel(nn.Module):
    """Graph-attention encoder feeding a pointer-network decoder that emits TSP tours."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        """
        Parameters:
        - in_channels (int): Input features per node.
        - hidden_channels (int): Hidden width of the GAT encoder.
        - out_channels (int): Embedding width; also used as the decoder's hidden size.
        """
        super().__init__()
        encoder = GATEncoder(in_channels, hidden_channels, out_channels)
        self.encoder = encoder
        self.pointer_network = PointerNetwork(
            encoder,
            hidden_size=out_channels,
            output_size=out_channels,
        )

    def forward(self, data, target=None):
        """Delegate to the pointer network and return its ``(tour, pointers)`` pair."""
        outputs = self.pointer_network(data, target)
        tour, pointers = outputs
        return (tour, pointers)


In [None]:
def loss_function(tour_pred, tour_true):
    """
    Negative log-likelihood of the true tour under the model's per-step distributions.

    Parameters:
    - tour_pred (list[Tensor]): Per-step probability vectors over nodes, i.e. the
      ``pointers`` output of the pointer network. The decoder starts at the depot
      without emitting it, so step k is scored against ``tour_true[k + 1]``.
    - tour_true (list[int] | Tensor): Full true tour starting at the depot. Entries
      may be ints or one-element tensors.

    Returns:
    - loss (Tensor): Scalar sum of per-step NLL terms, differentiable with respect
      to ``tour_pred``. If there are no steps, it is a zero tensor without gradient.
      Probabilities are clamped at 1e-9, so a true city the greedy decoder had
      already masked out (probability 0) gives a large finite penalty, not inf.

    Raises:
    - TypeError: if ``tour_pred`` holds plain city indices. An index list carries
      no gradient, so no trainable loss can be built from it.
    """
    eps = 1e-9
    # Skip the fixed starting depot; zip stops at the shorter of the two lists.
    targets = [int(city) for city in tour_true][1:]

    terms = []
    for probs, target in zip(tour_pred, targets):
        if not isinstance(probs, torch.Tensor) or probs.dim() == 0:
            raise TypeError(
                "loss_function expects the per-step probability tensors "
                "(the model's `pointers` output), not predicted city indices."
            )
        terms.append(-torch.log(probs[target].clamp_min(eps)))

    if not terms:
        return torch.zeros(())
    return torch.stack(terms).sum()


In [None]:
import torch.optim as optim
from torch.utils.data import DataLoader

def train_model(model, dataloader, epochs=50, learning_rate=1e-3, device='cpu'):
    """
    Trains the TSP model with Adam.

    The loss is computed from the model's per-step probability distributions (its
    second output). The predicted index list has no gradient, so a loss built from
    it could not be back-propagated.

    Parameters:
    - model (nn.Module): The TSP model to train; called as ``model(data)`` and
      expected to return ``(tour, pointers)``.
    - dataloader (DataLoader): Iterable of ``(data, tour)`` pairs.
    - epochs (int): Number of training epochs.
    - learning_rate (float): Learning rate for the optimizer.
    - device (str): Device to run the training on ('cpu' or 'cuda').

    Returns:
    - model (nn.Module): The trained model.
    """

    def _tour_to_indices(tour):
        # Accept a list of ints, a tensor, or default_collate's list of 1-element tensors.
        if isinstance(tour, torch.Tensor):
            return [int(city) for city in tour.flatten().tolist()]
        return [int(city) for city in tour]

    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    num_batches = max(len(dataloader), 1)  # avoid ZeroDivisionError on an empty loader

    model.train()
    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        for data, tour in tqdm(dataloader, desc=f"Epoch {epoch}/{epochs}"):
            data = data.to(device)
            optimizer.zero_grad()
            _, pointers = model(data)
            loss = loss_function(pointers, _tour_to_indices(tour))
            # A zero-step instance yields a constant loss with nothing to back-propagate.
            if getattr(loss, "requires_grad", False):
                loss.backward()
                optimizer.step()
            total_loss += float(loss)
        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch}/{epochs}, Average Loss: {avg_loss:.4f}")

    return model


In [None]:
def generate_tsp_dataset(num_instances=100, min_cities=5, max_cities=8, seed=None):
    """
    Create ``num_instances`` labelled TSP instances of random size.

    For every instance a size in [min_cities, max_cities] is drawn from the global
    ``random`` generator, ``create_data_model`` builds the instance without
    re-seeding, and ``solve_tsp_brute_force`` supplies the exact optimal tour.

    Parameters:
    - num_instances (int): Number of TSP instances to generate.
    - min_cities (int): Minimum number of cities in an instance.
    - max_cities (int): Maximum number of cities in an instance.
    - seed (int, optional): If given, seeds the global ``random`` generator first.

    Returns:
    - data_list (list): The ``pyg_data`` object of each instance.
    - tours (list): The optimal tour of each instance, in the same order.
    """
    if seed is not None:
        random.seed(seed)

    def _labelled_instance():
        n_cities = random.randint(min_cities, max_cities)
        model_data = create_data_model(num_selected_cities=n_cities, seed=None)
        route, _ = solve_tsp_brute_force(model_data)
        return model_data["pyg_data"], route

    pairs = [
        _labelled_instance()
        for _ in tqdm(range(num_instances), desc="Generating TSP Instances")
    ]
    data_list = [graph for graph, _ in pairs]
    tours = [route for _, route in pairs]
    return data_list, tours


In [None]:
def evaluate_model(model, data, tour_true, device='cpu', distance_matrix=None):
    """
    Evaluates the model on a single TSP instance.

    The pointer decoder starts at the depot without emitting it, so its tour is one
    city shorter than the closed true tour. The starting depot (``tour_true[0]``)
    is restored before distances and accuracy are computed.

    Parameters:
    - model (nn.Module): The trained TSP model; ``model(data)`` must return
      ``(tour, pointers)``.
    - data (Data): PyTorch Geometric Data object representing the TSP instance.
    - tour_true (list): The true optimal tour (city indices, starting and ending
      at the depot).
    - device (str): Device to run the evaluation on ('cpu' or 'cuda').
    - distance_matrix (list of lists, optional): Distances to score with; when
      omitted, ``data["distance_matrix"]`` is used.

    Returns:
    - total_distance_pred (float): Total distance of the predicted tour.
    - total_distance_true (float): Total distance of the true tour.
    - accuracy (float): Percentage of model-chosen positions (all but the fixed
      depot endpoints) that match the true tour; 100.0 if there are none.
    - tour_pred (list): The predicted tour as city indices, starting and ending at
      the depot.
    """
    model.eval()
    with torch.no_grad():
        data = data.to(device)
        tour_pred, _ = model(data)

    if distance_matrix is None:
        distance_matrix = data["distance_matrix"]

    tour_true = [int(city) for city in tour_true]
    tour_pred = [int(city) for city in tour_pred]
    if tour_true and len(tour_pred) == len(tour_true) - 1:
        tour_pred = [tour_true[0]] + tour_pred

    def _tour_length(route):
        # Sum of leg distances between consecutive stops (no implicit closing leg).
        return sum(distance_matrix[a][b] for a, b in zip(route, route[1:]))

    total_distance_pred = _tour_length(tour_pred)
    total_distance_true = _tour_length(tour_true)

    # Only the interior positions are decided by the model; the depot ends are fixed.
    decided_true = tour_true[1:-1]
    decided_pred = tour_pred[1:-1]
    if decided_true:
        correct_steps = sum(pred == true for pred, true in zip(decided_pred, decided_true))
        accuracy = correct_steps / len(decided_true) * 100
    else:
        accuracy = 100.0

    return total_distance_pred, total_distance_true, accuracy, tour_pred


In [None]:
def plot_route(route, cities_info, filename="predicted_route.png"):
    """
    Draw ``route`` on a Lambert-conformal map of the contiguous US and save it.

    Each stop gets a red marker and a name label. Consecutive stops are joined by
    blue lines in visiting order, so repeat the first city at the end to draw a
    closed tour.

    Parameters:
    - route (list): City names in visiting order.
    - cities_info (dict): Maps each city name to a (latitude, longitude) tuple.
    - filename (str): Path of the image file to write.
    """
    coords = [cities_info[name] for name in route]
    lats = [lat for lat, _ in coords]
    lons = [lon for _, lon in coords]

    geodetic = ccrs.Geodetic()

    plt.figure(figsize=(15, 10))
    ax = plt.axes(projection=ccrs.LambertConformal())
    ax.set_extent([-125, -66.5, 24, 50], geodetic)

    # Base map: solid features first, then dotted political boundaries.
    for feature in (cfeature.LAND, cfeature.OCEAN, cfeature.COASTLINE):
        ax.add_feature(feature)
    for feature in (cfeature.BORDERS, cfeature.STATES):
        ax.add_feature(feature, linestyle=':')

    ax.scatter(lons, lats, color='red', s=100, transform=geodetic)
    for name, lat, lon in zip(route, lats, lons):
        ax.text(lon + 0.5, lat + 0.5, name, fontsize=9, transform=geodetic)
    ax.plot(lons, lats, color='blue', linewidth=2, marker='o', transform=geodetic)

    plt.title("Predicted TSP Route", fontsize=16)
    plt.savefig(filename)
    plt.close()


In [None]:
import random
import torch
from torch_geometric.data import Data
from itertools import permutations
import os
import imageio
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

# =======================
# Data Model Function
# =======================

def create_data_model(num_selected_cities=12, seed=None):
    """
    Build a random TSP instance over a fixed table of 12 US cities.

    New York is always selected and placed first, so it is node 0 and acts as the
    depot. The other ``num_selected_cities - 1`` cities are drawn with
    ``random.sample`` from the module-global ``random`` generator.

    Parameters:
    - num_selected_cities (int): Number of cities to include in the TSP instance
      (including depot). Must be between 2 and 12 inclusive.
    - seed (int, optional): If given, both ``random.seed(seed)`` and
      ``torch.manual_seed(seed)`` are called before sampling, so equal seeds
      select equal cities.

    Returns:
    - data (dict) with keys:
        "cities_info": {city name: (latitude, longitude)} for all 12 cities,
        "all_cities": the 12 city names in distance-matrix order,
        "cities": the selected city names, depot first,
        "distance_matrix": distances in miles among the selected cities,
        "num_cities": number of selected cities,
        "depot": 0,
        "node_features": float tensor [num_cities, 2] holding (lat, lon),
        "edge_index": long tensor [2, num_cities * (num_cities - 1)] describing a
            fully connected directed graph without self-loops,
        "pyg_data": torch_geometric ``Data`` with ``x`` and ``edge_index``, plus the
            ``depot``, ``num_cities`` and ``distance_matrix`` attributes that the
            pointer decoder and the evaluation code read from the ``Data`` object.

    Raises:
    - ValueError: if ``num_selected_cities`` is outside [2, 12].
    """
    data = {}
    # Coordinates of the cities (latitude, longitude)
    data["cities_info"] = {
        "New York": (40.7128, -74.0060),
        "Los Angeles": (34.0522, -118.2437),
        "Chicago": (41.8781, -87.6298),
        "Denver": (39.7392, -104.9903),
        "Dallas": (32.7767, -96.7970),
        "Seattle": (47.6062, -122.3321),
        "Boston": (42.3601, -71.0589),
        "San Francisco": (37.7749, -122.4194),
        "St. Louis": (38.6270, -90.1994),
        "Houston": (29.7604, -95.3698),
        "Phoenix": (33.4484, -112.0740),
        "Salt Lake City": (40.7608, -111.8910),
    }
    # List of all cities in the order corresponding to the distance matrix
    all_cities = [
        "New York",
        "Los Angeles",
        "Chicago",
        "Denver",
        "Dallas",
        "Seattle",
        "Boston",
        "San Francisco",
        "St. Louis",
        "Houston",
        "Phoenix",
        "Salt Lake City",
    ]
    data["all_cities"] = all_cities  # Keep the original list for reference

    # Distance matrix in miles (12x12 matrix)
    full_distance_matrix = [
        # NY    LA    Chicago Denver Dallas Seattle Boston San Francisco St. Louis Houston Phoenix Salt Lake
        [0,    2451, 713,    1631, 1374, 2408, 213,    2571, 875,    1420, 2145, 1972],  # New York
        [2451, 0,    1745,   831, 1240, 959,  2596, 403,   1589, 1374, 357, 579],      # Los Angeles
        [713,  1745, 0,      920, 803, 1737, 851,  1858, 262,    940, 1453, 1260],      # Chicago
        [1631, 831,  920,    0,   663, 1021, 1769, 949,   796,    879, 586, 371],       # Denver
        [1374, 1240, 803,    663, 0,   1681, 1551, 1765, 547,    225, 887, 999],       # Dallas
        [2408, 959,  1737,  1021, 1681, 0,    2493, 678,   1724, 1891, 1114, 701],      # Seattle
        [213,  2596, 851,    1769, 1551, 2493, 0,    2699, 1038, 1605, 2300, 2099],    # Boston
        [2571, 403,  1858,   949, 1765, 678,  2699, 0,    1744, 1645, 653, 600],       # San Francisco
        [875,  1589, 262,    796, 547, 1724, 1038, 1744, 0,      679, 1272, 1162],      # St. Louis
        [1420, 1374, 940,    879, 225, 1891, 1605, 1645, 679,    0,    1017, 1200],     # Houston
        [2145, 357,  1453,   586, 887, 1114, 2300, 653,   1272, 1017, 0, 504],        # Phoenix
        [1972, 579,  1260,   371, 999, 701,  2099, 600,   1162, 1200, 504, 0],        # Salt Lake City
    ]

    # Validate num_selected_cities before touching the global RNGs, so a bad
    # argument raises without side effects.
    if num_selected_cities < 2:
        raise ValueError("At least two cities must be selected (including the depot).")
    if num_selected_cities > len(all_cities):
        raise ValueError(f"Cannot select {num_selected_cities} cities from {len(all_cities)} available.")

    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    # The depot is always included and always placed first (index 0).
    depot = "New York"
    remaining_cities = [city for city in all_cities if city != depot]
    selected_cities = [depot] + random.sample(remaining_cities, num_selected_cities - 1)

    # Map the selected cities back to rows/columns of the full matrix.
    position = {city: idx for idx, city in enumerate(all_cities)}
    selected_indices = [position[city] for city in selected_cities]
    subset_distance_matrix = [
        [full_distance_matrix[i][j] for j in selected_indices] for i in selected_indices
    ]

    num_cities = len(selected_cities)
    data["cities"] = selected_cities
    data["distance_matrix"] = subset_distance_matrix
    data["num_cities"] = num_cities
    data["depot"] = 0  # Depot is always the first city in the selected_cities list

    # Node features: (latitude, longitude) per selected city.
    data["node_features"] = torch.tensor(
        [list(data["cities_info"][city]) for city in selected_cities], dtype=torch.float
    )

    # Fully connected directed graph without self-loops.
    edges = [[i, j] for i in range(num_cities) for j in range(num_cities) if i != j]
    data["edge_index"] = torch.tensor(edges, dtype=torch.long).t().contiguous()

    # The decoder reads data.depot / data.num_cities and evaluation reads
    # data["distance_matrix"], so carry them on the Data object as well.
    data["pyg_data"] = Data(
        x=data["node_features"],
        edge_index=data["edge_index"],
        depot=data["depot"],
        num_cities=num_cities,
        distance_matrix=subset_distance_matrix,
    )

    return data

# =======================
# Calculate Total Distance
# =======================

def calculate_total_distance(route, distance_matrix):
    """
    Sum the matrix distances between each pair of consecutive stops in ``route``.

    No closing leg is added: to score a round trip, ``route`` must already end
    where it starts. Routes with fewer than two stops have distance 0.

    Parameters:
    - route (list): City indices in visiting order.
    - distance_matrix (list of lists): ``distance_matrix[a][b]`` is the distance from a to b.

    Returns:
    - total_distance (float): Sum of the leg distances.
    """
    legs = zip(route, route[1:])
    return sum(distance_matrix[src][dst] for src, dst in legs)

# =======================
# Plot Route Function
# =======================

def plot_route(route, cities_info, temp_dir, frame_number):
    """
    Render one animation frame showing ``route`` on a US map.

    The image is written to ``temp_dir/frame_{frame_number:05d}.png``. Stops are
    drawn as red markers with name labels and joined by blue lines in visiting
    order. The frame number appears as a small super-title.

    Parameters:
    - route (list): City names in visiting order.
    - cities_info (dict): Maps each city name to a (latitude, longitude) tuple.
    - temp_dir (str): Directory the frame image is written into.
    - frame_number (int): Number embedded in the file name and the super-title.
    """
    coords = [cities_info[name] for name in route]
    lats = [lat for lat, _ in coords]
    lons = [lon for _, lon in coords]

    geodetic = ccrs.Geodetic()

    plt.figure(figsize=(15, 10))
    ax = plt.axes(projection=ccrs.LambertConformal())
    ax.set_extent([-125, -66.5, 24, 50], geodetic)

    # Base map: solid features first, then dotted political boundaries.
    for feature in (cfeature.LAND, cfeature.OCEAN, cfeature.COASTLINE):
        ax.add_feature(feature)
    for feature in (cfeature.BORDERS, cfeature.STATES):
        ax.add_feature(feature, linestyle=':')

    ax.scatter(lons, lats, color='red', s=100, transform=geodetic)
    for name, lat, lon in zip(route, lats, lons):
        ax.text(lon + 0.5, lat + 0.5, name, fontsize=9, transform=geodetic)
    ax.plot(lons, lats, color='blue', linewidth=2, marker='o', transform=geodetic)

    plt.title("Selected TSP Route", fontsize=16)
    plt.suptitle(f"Frame {frame_number}", fontsize=10, y=0.95)

    output_path = os.path.join(temp_dir, f"frame_{frame_number:05d}.png")
    plt.savefig(output_path)
    plt.close()

# =======================
# Create GIF Function
# =======================

def create_gif_from_frames(route_history, temp_dir, gif_filename="tsp_route.gif", duration=0.5):
    """
    Stitch previously saved frame images into an animated GIF.

    One frame is expected per entry of ``route_history``: the k-th entry
    (1-based) maps to ``temp_dir/frame_{k:05d}.png``. Only the number of entries
    matters; the routes themselves are not read. Missing frames are reported and
    skipped. If no frame exists, no GIF is written.

    Parameters:
    - route_history (list): List of routes (each route is a list of city names).
    - temp_dir (str): Directory where the plot images are stored.
    - gif_filename (str): Filename for the output GIF.
    - duration (float): Passed unchanged to ``imageio.mimsave``.
      NOTE(review): the unit depends on the installed imageio version/plugin
      (seconds in the legacy GIF writer, milliseconds in newer pillow-based
      writers); confirm against the environment.
    """
    print("Creating GIF...")
    frames = []

    for frame_number, _route in enumerate(route_history, start=1):
        path = os.path.join(temp_dir, f"frame_{frame_number:05d}.png")
        if not os.path.exists(path):
            print(f"Warning: {path} does not exist and will be skipped.")
            continue
        frames.append(imageio.imread(path))

    if not frames:
        print("No images found to create GIF.")
        return

    imageio.mimsave(gif_filename, frames, duration=duration)
    print(f"GIF saved as {gif_filename}")

# =======================
# Solve TSP Brute Force
# =======================

def solve_tsp_brute_force_with_history(data, temp_dir):
    """
    Solves the TSP problem using brute force with pruning, recording every improvement.

    All orderings of the non-depot cities are enumerated. A partial tour is
    abandoned as soon as its running distance reaches the best distance found so
    far. Each strictly shorter tour is appended to ``route_history`` (as city
    names), announced on stdout, and, when ``temp_dir`` is not None, rendered as a
    numbered map frame via ``plot_route``.

    Parameters:
    - data (dict): TSP data model; reads "cities", "distance_matrix", "depot",
      "num_cities", and "cities_info" (the last only when frames are rendered).
    - temp_dir (str or None): Directory where frames are saved; None skips rendering.

    Returns:
    - optimal_route (list): List of city indices representing the optimal tour.
    - min_distance (float): Total distance of the optimal tour.
    - permutations_checked (int): Number of permutations evaluated.
    - elapsed_time (float): Wall-clock seconds spent in the search.
    - route_history (list): List of optimal routes found during the search.
    """
    # `time` is not imported at the top of this file, so import what we need here.
    import math
    import time

    city_names = data["cities"]
    distance_matrix = data["distance_matrix"]
    depot = data["depot"]
    num_cities = data["num_cities"]

    # Fix the depot at both ends; only the other cities are permuted.
    city_indices = list(range(num_cities))
    city_indices.remove(depot)

    min_distance = float('inf')
    optimal_route = None
    route_history = []  # Each new best route, as city names

    total_permutations = math.factorial(num_cities - 1)  # (n-1)!
    print(f"Total permutations to evaluate: {total_permutations}")
    start_time = time.perf_counter()

    permutations_checked = 0

    for perm in permutations(city_indices):
        route = [depot, *perm, depot]
        total_distance = 0
        pruned = False
        for from_city, to_city in zip(route, route[1:]):
            total_distance += distance_matrix[from_city][to_city]
            # Prune: this ordering can no longer beat the best tour found so far.
            if total_distance >= min_distance:
                pruned = True
                break
        if not pruned and total_distance < min_distance:
            min_distance = total_distance
            optimal_route = route
            route_cities = [city_names[city] for city in optimal_route]
            route_history.append(list(route_cities))
            if temp_dir is not None:
                plot_route(list(route_cities), data["cities_info"], temp_dir, len(route_history))
            print(f"New optimal distance: {min_distance} miles found at permutation {permutations_checked}")
            print(f"Route: {' -> '.join(route_cities)}")
        permutations_checked += 1
        # Progress indicator every 1,000,000 permutations
        if permutations_checked % 1000000 == 0:
            elapsed = time.perf_counter() - start_time
            print(f"Checked {permutations_checked} permutations in {elapsed:.2f} seconds...")

    elapsed_time = time.perf_counter() - start_time

    if optimal_route:
        return optimal_route, min_distance, permutations_checked, elapsed_time, route_history
    return None, None, permutations_checked, elapsed_time, route_history

# =======================
# GNN/GAT Model Definitions
# =======================

class GATEncoder(nn.Module):
    """
    Two-layer graph attention encoder mapping node features to node embeddings.

    The first GATConv uses ``heads`` attention heads with concatenated outputs
    (``hidden_channels * heads`` features per node), followed by a ReLU. The
    second GATConv uses a single head with ``concat=False`` and emits
    ``out_channels`` features per node.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=4, dropout=0.6):
        """
        Parameters:
        - in_channels (int): Input features per node.
        - hidden_channels (int): Output features per head of the first layer.
        - out_channels (int): Size of the final node embedding.
        - heads (int): Number of attention heads in the first layer.
        - dropout (float): Dropout probability given to both GATConv layers (default 0.6).
        """
        super(GATEncoder, self).__init__()
        self.gat1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        self.gat2 = GATConv(hidden_channels * heads, out_channels, heads=1, concat=False, dropout=dropout)
        self.relu = nn.ReLU()

    def forward(self, x, edge_index):
        """Return node embeddings for node features ``x`` on the graph ``edge_index``."""
        x = self.gat1(x, edge_index)
        x = self.relu(x)
        x = self.gat2(x, edge_index)
        return x  # Final node embeddings

class PointerNetwork(nn.Module):
    """
    Greedy pointer-network decoder over node embeddings produced by ``encoder``.

    Decoding starts from the depot. An LSTM cell is stepped once per remaining
    city. At each step every node is scored by a linear layer applied to
    ``[decoder_hidden, node_embedding]``, already-visited nodes are masked out,
    and the highest-probability node is chosen next.
    """

    def __init__(self, encoder, hidden_size, output_size):
        """
        Parameters:
        - encoder (nn.Module): Called as ``encoder(x, edge_index)``; must return one
          ``hidden_size``-wide embedding per node.
        - hidden_size (int): LSTM hidden size, equal to the encoder's output width.
        - output_size (int): Stored as ``self.output_size``; not otherwise used here.
        """
        super(PointerNetwork, self).__init__()
        self.encoder = encoder
        self.hidden_size = hidden_size
        self.output_size = output_size

        # Decoder RNN
        self.decoder_rnn = nn.LSTMCell(hidden_size, hidden_size)

        # Scores one (decoder state, node embedding) pair.
        self.pointer = nn.Linear(hidden_size * 2, 1)

    def forward(self, data, target=None):
        """
        Decode one tour greedily.

        Parameters:
        - data: Graph object with ``x`` and ``edge_index``. Optional ``depot``
          (default 0) and ``num_cities`` (default: number of encoded nodes).
        - target: Accepted for interface compatibility; currently unused
          (no teacher forcing).

        Returns:
        - tour (list[int]): Visited node indices in order. Excludes the starting
          depot and ends with the depot.
        - pointers (list[Tensor]): One probability vector over all nodes per
          decoding step. Already-visited nodes have probability 0.
        """
        # Encode the graph
        encoder_outputs = self.encoder(data.x, data.edge_index)  # [num_nodes, hidden_size]
        num_nodes = encoder_outputs.size(0)
        device = encoder_outputs.device
        depot = int(getattr(data, "depot", 0))
        num_cities = int(getattr(data, "num_cities", num_nodes))
        # Never ask for more steps than there are unvisited nodes. Otherwise every
        # score would be -inf and softmax would produce NaNs.
        num_steps = min(num_cities, num_nodes) - 1

        # Single-instance decoding: batch size 1, state on the encoder's device.
        decoder_hidden = torch.zeros(1, self.hidden_size, device=device)
        decoder_cell = torch.zeros(1, self.hidden_size, device=device)
        decoder_input = encoder_outputs[depot].unsqueeze(0)  # [1, hidden_size]

        visited = torch.zeros(num_nodes, dtype=torch.bool, device=device)
        visited[depot] = True  # Start at depot

        tour = []
        pointers = []

        for _ in range(num_steps):
            decoder_hidden, decoder_cell = self.decoder_rnn(decoder_input, (decoder_hidden, decoder_cell))

            # decoder_hidden is [1, H] and encoder_outputs is [N, H]. Expand the
            # state to [N, H] so the two can be concatenated feature-wise.
            query = decoder_hidden.expand(num_nodes, -1)
            attn_scores = self.pointer(torch.cat([query, encoder_outputs], dim=1)).squeeze(-1)  # [num_nodes]
            attn_scores = attn_scores.masked_fill(visited, float('-inf'))  # Mask visited cities
            attn_weights = F.softmax(attn_scores, dim=0)  # [num_nodes]

            # Greedy choice of the next city.
            pointer = int(torch.argmax(attn_weights).item())
            tour.append(pointer)
            pointers.append(attn_weights)

            # Autograd may keep `visited` for masked_fill's backward pass. Mutating
            # it in place would trip the in-place version check, so update a copy.
            visited = visited.clone()
            visited[pointer] = True

            decoder_input = encoder_outputs[pointer].unsqueeze(0)

        # Complete the tour by returning to depot
        tour.append(depot)
        return tour, pointers

class TSPModel(nn.Module):
    """Graph-attention encoder combined with a pointer-network decoder."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # The encoder instance is handed to the pointer network, whose forward
        # pass runs it; its output width sets the decoder's hidden size.
        self.encoder = GATEncoder(in_channels, hidden_channels, out_channels)
        self.pointer_network = PointerNetwork(
            self.encoder,
            hidden_size=out_channels,
            output_size=out_channels,
        )

    def forward(self, data, target=None):
        """Run the pointer network on `data` and return its (tour, pointers) pair."""
        predicted_tour, step_distributions = self.pointer_network(data, target)
        return predicted_tour, step_distributions

# =======================
# Dataset Class
# =======================

class TSPDataset(Dataset):
    """Pairs each TSP graph with its optimal tour so a DataLoader can iterate them."""

    def __init__(self, data_list, tours):
        """
        data_list: List of PyTorch Geometric Data objects representing TSP instances.
        tours: List of lists, where each sublist is the sequence of city indices representing the optimal tour.

        Raises ValueError when the two lists differ in length, since every
        instance needs exactly one tour (otherwise indexing fails mid-epoch).
        """
        if len(data_list) != len(tours):
            raise ValueError(
                "data_list and tours must have the same length "
                f"(got {len(data_list)} and {len(tours)})."
            )
        self.data_list = data_list
        self.tours = tours

    def __len__(self):
        """Number of (instance, tour) pairs."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return the (graph, tour) pair stored at position `idx`."""
        return self.data_list[idx], self.tours[idx]

# =======================
# Generate TSP Dataset
# =======================

def _shortest_closed_tour(distance_matrix, depot, num_cities):
    """Return the shortest route [depot, ..., depot] visiting every city once (None if none is finite)."""
    others = [city for city in range(num_cities) if city != depot]
    best_route = None
    best_length = float('inf')
    # The depot is fixed as the start, so only the (n-1)! orderings of the rest are tried.
    for perm in permutations(others):
        route = [depot, *perm, depot]
        length = sum(distance_matrix[a][b] for a, b in zip(route, route[1:]))
        if length < best_length:
            best_length = length
            best_route = route
    return best_route


def generate_tsp_dataset(num_samples, num_cities, seed=None):
    """
    Generates a dataset of TSP instances and their optimal tours using brute-force.

    Each sample is created with its own seed (``seed + sample_idx``) so that the
    instances differ from one another while the dataset as a whole stays
    reproducible; sample 0 uses ``seed`` itself.

    Parameters:
    - num_samples (int): Number of TSP instances to generate.
    - num_cities (int): Number of cities per TSP instance (including depot).
    - seed (int, optional): Base seed for random selection to ensure reproducibility.

    Returns:
    - data_list (list): List of PyTorch Geometric Data objects representing TSP instances.
    - tours (list): List of lists, where each sublist is the sequence of city indices
      representing the optimal tour, starting and ending at the depot.

    Raises:
    - ValueError: If an instance has more than 10 cities (brute force is impractical).
    - RuntimeError: If no route with a finite length is found.
    """
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    data_list = []
    tours = []

    for sample_idx in range(num_samples):
        # A distinct seed per sample; reusing `seed` would yield identical instances.
        sample_seed = None if seed is None else seed + sample_idx
        data = create_data_model(num_selected_cities=num_cities, seed=sample_seed)

        distance_matrix = data["distance_matrix"]
        depot = data["depot"]
        instance_size = data["num_cities"]

        # For practical purposes, limit to small n (e.g., n <= 10)
        if instance_size > 10:
            raise ValueError("Brute-force TSP solver is impractical for more than 10 cities.")

        optimal_route = _shortest_closed_tour(distance_matrix, depot, instance_size)
        if optimal_route is None:
            raise RuntimeError("No valid route found.")

        data_list.append(data["pyg_data"])
        tours.append(optimal_route)

        print(f"Generated sample {sample_idx + 1}/{num_samples}")

    return data_list, tours

# =======================
# Loss Function
# =======================

def loss_function(tour_pred, tour_true):
    """
    Calculates the negative log-likelihood of the true tour under the model's per-step pointer distributions.

    Parameters:
    - tour_pred (sequence of torch.Tensor): One tensor per decoding step holding the
      probability assigned to every node (e.g. the softmax `pointers` returned by the
      model's forward pass). Argmax city indices cannot be used: they carry no gradient.
    - tour_true (sequence of int): The true city index for each step, aligned
      one-to-one with tour_pred.

    Returns:
    - loss (torch.Tensor): Scalar loss summed over all steps, differentiable with
      respect to tour_pred. An empty input yields a zero tensor.

    Raises:
    - TypeError: If tour_pred contains plain indices instead of probability tensors.
    - ValueError: If tour_pred and tour_true have different lengths.
    """
    if len(tour_pred) != len(tour_true):
        raise ValueError(
            f"Expected one target per step, got {len(tour_true)} targets for {len(tour_pred)} steps."
        )
    if not tour_pred:
        return torch.tensor(0.0)
    if not all(torch.is_tensor(step) for step in tour_pred):
        raise TypeError(
            "tour_pred must contain per-step probability tensors, not city indices."
        )

    probs = torch.stack([step.reshape(-1) for step in tour_pred])  # [steps, num_nodes]
    targets = torch.tensor([int(city) for city in tour_true], dtype=torch.long, device=probs.device)
    # Masked (already-visited) nodes have probability exactly 0; clamp so log() stays finite.
    log_probs = torch.log(probs.clamp_min(1e-12))
    return F.nll_loss(log_probs, targets, reduction="sum")

# =======================
# Train Model Function
# =======================

def _first_tour_in_batch(tour):
    """Extract the tour of the first sample in a collated batch as a list of ints."""
    if torch.is_tensor(tour):
        row = tour[0] if tour.dim() > 1 else tour
        return [int(city) for city in row]
    # Default collation transposes a batch of per-sample lists into one entry per
    # tour position, each holding that position's value for every sample.
    first = []
    for city in tour:
        if (torch.is_tensor(city) and city.dim() > 0) or isinstance(city, (list, tuple)):
            city = city[0]
        first.append(int(city))
    return first


def train_model(model, dataloader, epochs=100, learning_rate=1e-3, device='cpu'):
    """
    Trains the GNN/GAT model.

    Parameters:
    - model (nn.Module): The GNN/GAT model to train.
    - dataloader (DataLoader): DataLoader for the training dataset (batch_size=1 is assumed).
    - epochs (int): Number of training epochs.
    - learning_rate (float): Learning rate for the optimizer.
    - device (torch.device): Device to train on (CPU or CUDA).

    Returns:
    - model (nn.Module): The trained model.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0

        for data, tour in dataloader:
            data = data.to(device)
            tour = _first_tour_in_batch(tour)  # Assuming batch_size=1

            optimizer.zero_grad()
            # Pass the true tour as `target` so the decoder can condition on it
            # (teacher forcing); the loss needs the per-step distributions, not
            # the argmax indices, to have a gradient.
            _, pointers = model(data, tour)
            # The true tour is [depot, c1, ..., depot]; each decoding step predicts
            # one of the interior cities, so drop the depot at both ends.
            loss = loss_function(pointers, tour[1:-1])
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / max(len(dataloader), 1)
        print(f"Epoch {epoch}/{epochs}, Loss: {avg_loss:.4f}")

    return model

# =======================
# Evaluate Model Function
# =======================

def evaluate_model(model, data, tour_true, device='cpu'):
    """
    Evaluates the model on a single TSP instance.

    Parameters:
    - model (nn.Module): The trained GNN/GAT model.
    - data (Data): PyTorch Geometric Data object representing the TSP instance.
    - tour_true (list): List of true city indices representing the optimal tour,
      starting and ending at the depot.
    - device (torch.device): Device to perform evaluation on.

    Returns:
    - total_distance (float): Total distance of the predicted tour.
    - accuracy (float): Percentage of visited (non-depot) cities predicted at the
      same position as in the true tour (100 when there is nothing to visit).
    - tour_pred (list): List of predicted city indices representing the tour.
    """
    model.eval()
    with torch.no_grad():
        data = data.to(device)
        tour_pred, _ = model(data)
    # Calculate total distance
    total_distance = calculate_total_distance(tour_pred, data["distance_matrix"])

    # The predicted tour may omit the leading depot ([c1, ..., depot]) while the
    # true tour has it at both ends ([depot, c1, ..., depot]); comparing only the
    # non-depot visits keeps the positions aligned in either form.
    depot = tour_true[0] if len(tour_true) > 0 else None
    true_visits = [city for city in tour_true if city != depot]
    pred_visits = [city for city in tour_pred if city != depot]
    matches = sum(pred == true for pred, true in zip(pred_visits, true_visits))
    accuracy = matches / len(true_visits) if true_visits else 1.0
    return total_distance, accuracy * 100, tour_pred

# =======================
# Main Function
# =======================

def main():
    """
    Run the end-to-end demo: build a TSP instance, generate brute-force optimal
    tours, train the model, evaluate it on the first sample, and plot the result.
    """
    # Desired number of cities (including depot)
    num_selected_cities = 8  # Change this value as needed (e.g., 5, 7, 10, etc.)

    # Optional: Set a seed for reproducibility
    seed = 42

    # Device configuration: use CUDA if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create the data model with the specified number of cities.
    # NOTE(review): used below to map predicted node indices back to city names;
    # this assumes the same seed reproduces the city selection of the first
    # generated sample — confirm create_data_model is deterministic per seed.
    data = create_data_model(num_selected_cities=num_selected_cities, seed=seed)

    # Print selected cities and their coordinates
    print(f"\nSelected {data['num_cities']} Cities:")
    for idx, city in enumerate(data["cities"]):
        print(f"{idx}: {city} at {data['cities_info'][city]}")

    # Display the subset distance matrix
    print("\nSubset Distance Matrix (in miles):")
    for row in data["distance_matrix"]:
        print(row)

    # Create a temporary directory to store plot images
    temp_dir = "tsp_temp_frames"
    os.makedirs(temp_dir, exist_ok=True)

    print("\nGenerating TSP Dataset...")
    # Generate dataset (for demonstration, using a single sample).
    # For effective training, generate multiple samples.
    num_samples = 1  # Change as needed
    data_list, tours = generate_tsp_dataset(num_samples=num_samples, num_cities=num_selected_cities, seed=seed)

    # Initialize dataset and dataloader
    dataset = TSPDataset(data_list, tours)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

    # Initialize the model
    model = TSPModel(in_channels=2, hidden_channels=128, out_channels=128)

    # Train the model
    print("\nStarting Training...")
    model = train_model(model, dataloader, epochs=100, learning_rate=1e-3, device=device)

    # Evaluate the model on the first sample
    print("\nEvaluating the Model on the First Sample...")
    test_data = data_list[0]
    test_tour_true = tours[0]
    total_distance, accuracy, tour_pred = evaluate_model(model, test_data, test_tour_true, device=device)
    print(f"\nPredicted Tour: {tour_pred}")
    print(f"True Tour: {test_tour_true}")
    print(f"Total Distance: {total_distance} miles")
    print(f"Accuracy: {accuracy:.2f}%")

    # Plot the predicted tour as a closed loop: if it does not already start at
    # the depot, prepend it so the depot -> first-city leg is drawn too.
    depot = data["depot"]
    plotted_tour = list(tour_pred)
    if not plotted_tour or plotted_tour[0] != depot:
        plotted_tour.insert(0, depot)
    route_cities = [data["cities"][city] for city in plotted_tour]
    plot_route(route_cities, data["cities_info"], temp_dir=temp_dir, frame_number=1)

    # Create the GIF from the single predicted-route frame
    create_gif_from_frames([route_cities], temp_dir, gif_filename=f"tsp_route_{num_selected_cities}.gif", duration=0.5)

    print("\nProcess Completed.")


if __name__ == "__main__":
    main()
