# In [None]:
#!/usr/bin/python
# coding=utf-8

import requests
import time
import json
import ast
import os
import math
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from lxml import etree
import networkx as nx
import matplotlib.pyplot as plt

# Landing page of the AMap subway site; fetchAllCity() scrapes the city list from it.
PAGE_URL = 'http://map.amap.com/subway/index.html?&1100'
# Prefix of the per-city data endpoint; parseCityDataFromApi() appends
# "<id>_drw_<name>.json" to it.
DATA_URL = 'http://map.amap.com/service/subway?srhdata='
# Previous User-Agent value (Chrome 71), kept for reference:
#HEADER  = {"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36"}
## Replace Google Chrome version to suit my computer
HEADER  = {"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/67.0.3396.62 Safari/537.36"}
# Hard-coded local data directory; no code visible in this section reads it.
SAVEPATH = 'C:/Users/zengluci/jupyters_and_slides/2019-autumn/deliverable/Assignment2/data/'

def fetchAllCity(url, header):
    """Scrape the list of cities offered by the subway landing page.

    Parameters
    ----------
    url : str
        Address of the page to download.
    header : dict
        HTTP request headers (e.g. a User-Agent) sent with the request.

    Returns
    -------
    list of dict
        One entry per ``<a>`` element whose ``class`` contains ``'city'``,
        with keys ``'id'`` (the ``id`` attribute), ``'name'`` (the
        ``cityname`` attribute) and ``'text'`` (the element text).
    """
    # `header` must be passed by keyword: the second positional argument of
    # requests.get() is `params`, which would put the User-Agent into the
    # query string instead of the request headers.
    response = requests.get(url, headers=header)
    element = etree.HTML(response.content)
    options = element.xpath("//a[contains(@class, 'city')]")

    return [
        {
            'id': option.get('id'),
            'name': option.get('cityname'),
            'text': option.text,
        }
        for option in options
    ]



def parseOneCityData(city):
    """Build the station list for one city and return it wrapped in a list.

    A Chrome WebDriver session is opened on ``PAGE_URL`` and handed to
    ``parseCityData``; the session is always shut down afterwards so that no
    browser process is left running, even when parsing raises.

    Parameters
    ----------
    city : dict
        City entry as produced by ``fetchAllCity`` (``'id'`` and ``'name'``
        are read further down the call chain).

    Returns
    -------
    list
        A one-element list containing the result of ``parseCityData``.
    """
    # Start a Chrome browser
    browser = webdriver.Chrome()
    try:
        browser.get(PAGE_URL)
        return [parseCityData(city, browser)]
    finally:
        # Release the browser process and its driver.
        browser.quit()



def parseCityData(city, browser):
    """Download one city's subway data and convert it to a station list.

    Parameters
    ----------
    city : dict
        City entry forwarded to ``parseCityDataFromApi``.
    browser : object
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    list
        Whatever ``formatCityMetroTree`` builds from the API payload.
    """
    return formatCityMetroTree(parseCityDataFromApi(city))
        

def parseCityDataFromApi(city):
    """Fetch the raw subway description of one city from the data endpoint.

    Parameters
    ----------
    city : dict
        Must provide ``'id'`` and ``'name'``; both are inserted into the
        request URL as ``"<id>_drw_<name>.json"``.

    Returns
    -------
    object
        The decoded JSON document returned by the service.

    Raises
    ------
    ValueError
        If the response body is not valid JSON (``json.JSONDecodeError``).
    """
    url = DATA_URL + "{}_drw_{}.json".format(city['id'], city['name'])
    r = requests.get(url)
    # SECURITY: this used to eval() the response body, which executes whatever
    # text the remote server (or anyone tampering with the plain-HTTP
    # connection) sends back. The endpoint is requested as ".json", so the
    # body is parsed as data with json.loads instead.
    return json.loads(r.text)
    


def formatCityMetroTree(apiData):
    """Flatten the per-line API payload into one record per station.

    Parameters
    ----------
    apiData : dict
        Payload whose ``'l'`` entry is a list of lines; each line has an
        ``'st'`` list of stations, and each station provides ``'sid'``,
        ``'n'``, ``'t'`` and ``'sl'``.

    Returns
    -------
    list of dict
        Stations in order of first appearance. Each record has the keys
        ``'sid'``, ``'sname'`` (from ``'n'``), ``'t'``, ``'geoCoord'`` (from
        ``'sl'``) and ``'ns'``, the list of neighbouring station names
        accumulated by ``FindNextStations`` over every line the station
        appears on.
    """
    metrotree = []
    # sid -> record already stored in metrotree, so that a station shared by
    # several lines is found without rescanning the whole list.
    known_stations = {}

    for line in apiData['l']:
        linestations = line['st']

        for sdx, station in enumerate(linestations):
            record = known_stations.get(station['sid'])

            if record is not None:
                # Station already seen on another line: merge the neighbours
                # it has on this line into the existing record.
                record['ns'] = FindNextStations(linestations, sdx, record['ns'])
                continue

            record = {
                'sid': station['sid'],
                'sname': station['n'],
                't': station['t'],
                'geoCoord': station['sl'],
                'ns': FindNextStations(linestations, sdx, []),
            }
            metrotree.append(record)
            known_stations[station['sid']] = record

    return metrotree

def FindNextStations(stations,currentsin,nextstations):
    """Add the neighbours of one station on a line to ``nextstations``.

    Parameters
    ----------
    stations : list of dict
        Stations of a single line, in line order; each provides ``'n'``.
    currentsin : int
        Index of the station of interest within ``stations``.
    nextstations : list
        Names collected so far. It is extended in place and also returned.

    Returns
    -------
    list
        ``nextstations`` with the name of the previous station (if any)
        followed by the name of the next station (if any) appended, each only
        when not already present.
    """
    # Previous neighbour first, then the next one. Indices outside the line
    # are skipped, which covers both termini and a line with a single station.
    for neighbour_index in (currentsin - 1, currentsin + 1):
        if not 0 <= neighbour_index < len(stations):
            continue
        name = stations[neighbour_index]['n']
        if name not in nextstations:
            nextstations.append(name)

    return nextstations
              



def DrawMetroMap(metro_info, matro_connection):
    """Draw the stations and their connections with networkx/matplotlib.

    Parameters
    ----------
    metro_info : dict
        Station name -> coordinate pair; used both as the node set and as the
        layout positions passed to ``nx.draw``.
    matro_connection : dict
        Station name -> list of neighbouring station names. (The parameter
        name keeps its historical spelling so existing keyword callers
        continue to work.)

    Returns
    -------
    None
    """
    # Presumably chosen so CJK station names can be rendered - confirm the
    # font is installed on the target machine.
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False

    # First pass: stations only.
    metro_graph = nx.Graph()
    metro_graph.add_nodes_from(list(metro_info.keys()))
    nx.draw(metro_graph, metro_info, with_labels=False, node_size=5)

    # Second pass: stations plus edges. This must read the argument; the
    # previous body referred to an undefined name `metro_connection`.
    metro_connection_graph = nx.Graph(matro_connection)
    nx.draw(metro_connection_graph, metro_info, with_labels=False, node_size=5)



def geo_distance(origin, destination):
    """
    Calculate the Haversine distance.

    Parameters
    ----------
    origin : tuple of float
        (lat, long), in degrees
    destination : tuple of float
        (lat, long), in degrees

    Returns
    -------
    distance_in_km : float

    Examples
    --------
    >>> origin = (48.1372, 11.5756)  # Munich
    >>> destination = (52.5186, 13.4083)  # Berlin
    >>> round(geo_distance(origin, destination), 1)
    504.2
    """
    lat1, lon1 = origin
    lat2, lon2 = destination
    radius = 6371  # km

    dlat = math.radians(lat2 - lat1)
    dlon = math.radians(lon2 - lon1)
    a = (math.sin(dlat / 2) * math.sin(dlat / 2) +
         math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) *
         math.sin(dlon / 2) * math.sin(dlon / 2))
    # Rounding can push `a` marginally outside [0, 1] for (near-)antipodal
    # points, which would make sqrt(1 - a) raise ValueError.
    a = min(1.0, max(0.0, a))
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))

    return radius * c


def get_metro_distance(station1,station2):
    """Return the ``geo_distance`` between two stations given by name.

    Both names are looked up in the module-level ``metro_info`` mapping, which
    must already exist and contain them when this function is called.

    NOTE(review): ``geo_distance`` documents its points as (lat, long);
    confirm that the tuples stored in ``metro_info`` use that order.
    """
    origin = metro_info[station1]
    destination = metro_info[station2]
    return geo_distance(origin, destination)



def search_bfs(graph,start,destination):
    """Breadth-first search for a path from ``start`` to ``destination``.

    Parameters
    ----------
    graph : mapping
        Node -> iterable of neighbouring nodes.
    start, destination : hashable
        End points of the requested path.

    Returns
    -------
    list or None
        The nodes of the first path found, ``start`` and ``destination``
        included. Because candidate paths are expanded in first-in-first-out
        order, this is a path with the fewest hops. ``[start]`` when both end
        points coincide, ``None`` when no path exists.

    Raises
    ------
    KeyError
        If a node that has to be expanded is missing from ``graph``.
    """
    from collections import deque

    if start == destination:
        return [start]

    # deque gives O(1) removal from the front; list.pop(0) shifts every item.
    pathes = deque([[start]])
    visited = set()

    while pathes:
        path = pathes.popleft()
        froniter = path[-1]

        if froniter in visited:
            continue

        for city in graph[froniter]:
            if city in path:
                continue  # check loop

            new_path = path + [city]
            if city == destination:
                return new_path
            pathes.append(new_path)

        visited.add(froniter)

    return None


    
def search_dfs(graph,start,destination):
    """Depth-first search for a path from ``start`` to ``destination``.

    Parameters
    ----------
    graph : mapping
        Node -> iterable of neighbouring nodes.
    start, destination : hashable
        End points of the requested path.

    Returns
    -------
    list or None
        The nodes of the first path found, ``start`` and ``destination``
        included; it is not necessarily the shortest one. ``[start]`` when
        both end points coincide, ``None`` when no path exists.

    Raises
    ------
    KeyError
        If a node that has to be expanded is missing from ``graph``.
    """
    from collections import deque

    if start == destination:
        return [start]

    # Used as a stack: the newest path is pushed to and popped from the left,
    # avoiding the full list copy of `[new_path] + pathes`.
    pathes = deque([[start]])
    visited = set()

    while pathes:
        path = pathes.popleft()
        froniter = path[-1]

        if froniter in visited:
            continue

        for city in graph[froniter]:
            if city in path:
                continue  # check loop

            new_path = path + [city]
            if city == destination:
                return new_path
            pathes.appendleft(new_path)

        visited.add(froniter)

    return None


def search_opt(graph,start,destination,search_strategy):
    """Search for a path, letting ``search_strategy`` order the candidates.

    After each expansion the complete list of candidate paths is handed to
    ``search_strategy``, and the list it returns replaces the candidates. The
    search ends as soon as the first path of that list ends at
    ``destination``. No visited set is kept; only paths that would revisit one
    of their own nodes are discarded.

    Parameters
    ----------
    graph : mapping
        Node -> iterable of neighbouring nodes.
    start, destination : hashable
        End points of the requested path.
    search_strategy : callable
        Takes the list of candidate paths and returns the reordered list.

    Returns
    -------
    list or None
        The selected path, or ``None`` when the candidates run out.
    """
    candidates = [[start]]

    while candidates:
        current = candidates.pop(0)
        tip = current[-1]

        # Extend the current path by every neighbour it does not yet contain.
        candidates.extend(
            current + [neighbour]
            for neighbour in graph[tip]
            if neighbour not in current
        )

        candidates = search_strategy(candidates)
        if not candidates:
            continue

        best = candidates[0]
        if destination == best[-1]:
            return best
        
        
def sort_by_distance(pathes):
    """Return the paths as a new list ordered by ascending total distance.

    The distance of each path is computed by ``get_distance_of_path``; the
    input is left unmodified.
    """
    ranked = list(pathes)
    ranked.sort(key=get_distance_of_path)
    return ranked
        

def get_distance_of_path(path):
    """Sum ``get_metro_distance`` over every consecutive pair of stations.

    Returns ``0`` for a path with fewer than two stations.
    """
    total = 0
    # Pair each station with its successor; legs are added in path order.
    for here, there in zip(path, path[1:]):
        total += get_metro_distance(here, there)
    return total
  




def CityMetro(CITYNAME):
    """Return the station list of the city named ``CITYNAME``.

    The city list is scraped from ``PAGE_URL`` with ``fetchAllCity``; the
    first entry whose ``'name'`` equals ``CITYNAME`` is parsed with
    ``parseOneCityData``.

    Parameters
    ----------
    CITYNAME : str
        Value compared against each scraped city's ``'name'``.

    Returns
    -------
    list
        The station records of the city, or an empty list (after printing a
        notice) when the site does not list ``CITYNAME``.
    """
    for city in fetchAllCity(PAGE_URL, HEADER):
        if city['name'] == CITYNAME:
            return parseOneCityData(city)[0]

    # Report the requested name: the loop variable would hold the last
    # scraped city, or be unbound when the page listed none.
    print('Sorry,city {}is not in webiste {}'.format(CITYNAME, PAGE_URL))
    return []
    
    
    
def main():
    """Download the Shanghai metro network, draw it and compare three searches.

    The station list comes from ``CityMetro`` (which performs HTTP requests
    and starts a Chrome WebDriver). A path between two fixed stations is then
    searched with ``search_dfs``, ``search_bfs`` and ``search_opt`` combined
    with ``sort_by_distance``; each path and its total distance are printed.

    NOTE(review): coordinates are stored in the order given by the
    ``'geoCoord'`` string, while ``geo_distance`` documents (lat, long) -
    confirm that the two agree.
    """
    # get_metro_distance() reads `metro_info` as a module-level name and has
    # no parameter to receive it, so the mapping is published globally here.
    global metro_info

    CITYNAME = 'shanghai'

    metro_info = {}
    metro_connection = {}
    searchtree = CityMetro(CITYNAME)

    for station in searchtree:
        # Convert the "x,y" coordinate string into a tuple of floats.
        metro_info[station['sname']] = tuple(map(float, station['geoCoord'].split(",")))
        metro_connection[station['sname']] = station['ns']

    DrawMetroMap(metro_info, metro_connection)

    start_station = "桂林路"
    end_station = "海伦路"

    print('\n',"Compare search result from ",start_station,"to ",end_station, " for 3 search stratedgies: DFS,BFS and sort_by_distance.The result is as below:",'\n\n\n')

    dfs_path = search_dfs(metro_connection, start_station, end_station)
    dfs_distance = get_distance_of_path(dfs_path)
    print("DFS pathes is ",dfs_path,'\n',"DFS pathes distance is",dfs_distance,'\n\n')

    bfs_path = search_bfs(metro_connection, start_station, end_station)
    bfs_distance = get_distance_of_path(bfs_path)
    print("BFS pathes is ",bfs_path,'\n',"BFS pathes distance is",bfs_distance,'\n\n')

    opt_path = search_opt(metro_connection, start_station, end_station, search_strategy=sort_by_distance)
    opt_distance = get_distance_of_path(opt_path)
    print("Optimized pathes is ",opt_path,'\n',"Optimized pathes distance is",opt_distance,'\n\n\n')


# Run the demo only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
