In [None]:
from GEO_DATA import GeoData
import pandas as pd
import numpy as np
import networkx as nx
import plotly.graph_objs as go
import random

In [None]:
class NearestNeighbour:
    """Nearest-neighbour heuristic for the travelling salesman problem (TSP).

    Starting from one city, the tour repeatedly moves to the closest city that
    has not been visited yet, then returns to the start. The result is a closed
    tour of city indices (first index == last index) plus its total length.
    By default the tour is also drawn with plotly.

    Args:
        distance_matrix: square matrix where entry ``[i][j]`` is the cost of
            travelling from city ``i`` to city ``j`` (a numpy array, or anything
            ``np.asarray`` turns into a numeric 2-D array).
        cities: city names, in the same order as the matrix rows/columns.
        cordinates: mapping ``city name -> (lat, lng)``. Only used for
            plotting; its keys must match the names in ``cities``.
    """

    def __init__(self, distance_matrix, cities, cordinates):
        self.distance_matrix = distance_matrix
        self.cities = cities
        self.cordinates = cordinates
        self.n_cities = np.asarray(distance_matrix).shape[0]
        # Tours are stored as matrix indices; these map index <-> city name.
        self.city_index_inv = {i: self.cities[i] for i in range(self.n_cities)}
        self.city_index = {name: i for i, name in self.city_index_inv.items()}

    def solve(self, init_give=False, initial_city=0, plot=True):
        """Build a closed tour with the nearest-neighbour heuristic.

        Args:
            init_give: if True, start from ``initial_city``; otherwise pick the
                start city uniformly at random with the ``random`` module
                (call ``random.seed`` first for a reproducible result).
            initial_city: index of the start city; only used when
                ``init_give`` is True.
            plot: if True (the default, matching the previous behaviour), draw
                the tour with plotly.

        Returns:
            ``(tour, total_dist)``. ``tour`` lists the visited city indices and
            repeats the start city at the end; ``total_dist`` is
            ``get_cost(tour)``.

        Raises:
            ValueError: if there are no cities or ``initial_city`` is not a
                valid index.
        """
        if self.n_cities == 0:
            raise ValueError("cannot build a tour: the distance matrix has no cities")
        if init_give:
            if not 0 <= initial_city < self.n_cities:
                raise ValueError(
                    f"initial_city must be in [0, {self.n_cities}), got {initial_city!r}"
                )
        else:
            # randrange excludes the upper bound; randint(0, n) could yield n,
            # which is not a valid city index.
            initial_city = random.randrange(self.n_cities)

        dm = np.asarray(self.distance_matrix)
        tour = [initial_city]
        # Kept in ascending order so argmin ties resolve to the lowest index.
        unvisited = [i for i in range(self.n_cities) if i != initial_city]
        current_city = initial_city
        while unvisited:
            # Row = current city (source), columns = candidate destinations.
            idx = int(np.argmin(dm[current_city, unvisited]))
            current_city = unvisited.pop(idx)
            tour.append(current_city)

        tour.append(initial_city)  # return to the start to close the tour
        total_dist = self.get_cost(tour)
        if plot:
            self.create_graph_plot(tour, total_dist)

        return tour, total_dist

    def get_cost(self, V):
        """Return the total length of the path ``V`` (a sequence of city indices).

        Sums ``distance_matrix[V[i]][V[i + 1]]`` over consecutive pairs, i.e. in
        the direction of travel. This uses the same row-is-source convention as
        ``solve`` and only matters for asymmetric matrices. Paths with fewer
        than two cities cost 0.
        """
        dm = np.asarray(self.distance_matrix)
        return sum(dm[frm, to] for frm, to in zip(V, V[1:]))

    def create_graph_plot(self, seq, dist):
        """Build a networkx graph of the tour ``seq`` and plot it.

        Every key of ``self.cordinates`` becomes a node positioned at
        ``(lng, lat)``. An edge joins each consecutive pair of cities in
        ``seq``, and the final leg is routed back to ``seq[0]``, so the drawn
        tour is always closed.
        """
        G = nx.Graph()
        for city, (lat, lng) in self.cordinates.items():
            G.add_node(city, pos=(lng, lat))

        names = [self.city_index_inv[i] for i in seq]
        if len(names) >= 2:
            # Drop the last entry and append the first one: for a closed tour
            # (last == first) this is identical, for an open one it closes it.
            stops = names[:-1] + names[:1]
            G.add_edges_from(zip(stops, stops[1:]))

        self.Graph_plotter(G, dist)

    def Graph_plotter(self, G, dist):
        """Render ``G`` with plotly; every node must carry a ``pos`` attribute.

        Edges are drawn as thin grey segments. Nodes are markers coloured by
        their number of neighbours, with their node key (the city name) as
        hover text. ``dist`` is shown in the figure title.
        """
        edge_x, edge_y = [], []
        for u, v in G.edges():
            x0, y0 = G.nodes[u]['pos']
            x1, y1 = G.nodes[v]['pos']
            # The trailing None breaks the polyline so each edge is its own segment.
            edge_x += [x0, x1, None]
            edge_y += [y0, y1, None]

        edge_trace = go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=0.5, color='#888'),
            hoverinfo='none',
            mode='lines')

        nodes = list(G.nodes())
        node_trace = go.Scatter(
            x=[G.nodes[n]['pos'][0] for n in nodes],
            y=[G.nodes[n]['pos'][1] for n in nodes],
            mode='markers',
            hoverinfo='text',
            # Label each marker with its own key. The previous code used the
            # enumeration index, which was only right when ``cordinates`` was
            # ordered exactly like ``cities``.
            text=[str(n) for n in nodes],
            marker=dict(
                color=[len(G.adj[n]) for n in nodes],
                # Marker outline; on the trace itself this did nothing in markers mode.
                line_width=2))

        fig = go.Figure(
            data=[edge_trace, node_trace],
            layout=go.Layout(
                # title.font replaces the `titlefont` shorthand removed in newer plotly.
                title=dict(text='TSP Solution - ' + str(dist), font=dict(size=16)),
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=40),
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)))
        fig.show()

In [None]:
# Look up pairwise distances and (lat, lng) coordinates for the capitals below via GeoData.
g = GeoData()
capitals = [
    'Stuttgart', 'Munich', 'Berlin', 'Potsdam',
    'Bremen', 'Hamburg', 'Wiesbaden', 'Hanover',
    'Schwerin', 'Düsseldorf', 'Mainz', 'Saarbrücken',
    'Dresden', 'Magdeburg', 'Kiel', 'Erfurt',
]
distance_matrix = g.get_distance_matrix(capitals)
cordinates = g.get_dict_coordinates(capitals)

In [None]:
NN = NearestNeighbour(
    distance_matrix=distance_matrix,
    cities=capitals,
    cordinates=cordinates,
)

In [None]:
# Start deterministically from index 0 ('Stuttgart', the first entry of `capitals`).
NN.solve(initial_city=0, init_give=True)