In [None]:
import osmnx as ox
import networkx as nx
import igraph as ig

from roc_bike_growth.loader import POI_graph_from_polygon, bike_infra_from_polygon, carall_from_polygon
from roc_bike_growth.paper_gt import gt_from_scratch

class GTModel:
    """Merge car and bike infrastructure for a place and run ``gt_from_scratch``.

    Constructing an instance performs the whole pipeline: both networks are
    loaded through the ``roc_bike_growth.loader`` helpers, the bike network
    is merged into the car network, the merged graph is converted to igraph,
    handed to ``gt_from_scratch`` and the result is converted back to
    networkx.

    Attributes:
        prune: Value forwarded to ``gt_from_scratch`` as ``prune_factor``.
        place: Geocoding query whose first result delimits the study area.
        all_infra: The car network with the bike network merged into it.
        gt_out: Output of ``gt_from_scratch`` as a networkx graph.
    """

    def __init__(self, prune=0.1, place='rochester, ny'):
        """Load the data, merge the networks and run the model.

        Args:
            prune: Passed to ``gt_from_scratch`` as ``prune_factor``.
            place: Query string handed to ``ox.geocode_to_gdf``. Defaults to
                the value that used to be hard-coded.
        """
        self.prune = prune
        self.place = place
        car_infra, bike_infra = self.get_data()
        self.all_infra = self.merge_infras(car_infra, bike_infra)
        self.gt_out = self.perform_gt()

    def get_data(self):
        """Geocode ``self.place`` and load the car and bike networks inside it.

        Returns:
            A ``(car_infra, bike_infra)`` tuple. The car network is requested
            with ``add_pois=True``.
        """
        # Positional access: take the first geocoding result regardless of
        # how the GeoDataFrame happens to be indexed.
        boundary = ox.geocode_to_gdf(self.place).geometry.iloc[0]
        bike_infra = bike_infra_from_polygon(boundary)
        car_infra = carall_from_polygon(boundary, add_pois=True)

        return car_infra, bike_infra

    def merge_infras(self, car_infra, bike_infra):
        """Add every bike node and edge missing from ``car_infra`` to it.

        ``car_infra`` is modified in place and also returned. Nodes are
        merged first so that they carry their bike-network attributes before
        any edge references them.
        """
        car_infra = self._merge_infras_node(car_infra, bike_infra)
        car_infra = self._merge_infras_edge(car_infra, bike_infra)

        return car_infra

    def perform_gt(self):
        """Run ``gt_from_scratch`` on the merged network.

        Returns:
            The igraph result of ``gt_from_scratch`` converted to networkx.
        """
        all_infra_ig, poi_ids = self._process_networkx_infra()
        roc = gt_from_scratch(all_infra_ig, poi_ids, prune_factor=self.prune)

        return ig.Graph.to_networkx(roc)

    def _merge_infras_node(self, car_infra, bike_infra):
        """Copy bike nodes that ``car_infra`` lacks, including their attributes.

        Nodes already present in ``car_infra`` are left untouched.
        """
        for bike_node, attrs in bike_infra.nodes(data=True):
            # The check must be against the *car* graph; comparing against the
            # bike graph's own node set would skip every node.
            if bike_node in car_infra:
                continue
            car_infra.add_node(bike_node)
            car_infra.nodes[bike_node].update(attrs)

        return car_infra

    def _merge_infras_edge(self, car_infra, bike_infra):
        """Copy bike edges that ``car_infra`` lacks, including their attributes.

        Both graphs are expected to be networkx multigraphs: edges are
        identified by ``(u, v, key)``. An edge whose ``(u, v, key)`` already
        exists in ``car_infra`` is left untouched.
        """
        for u, v, key, attrs in bike_infra.edges(keys=True, data=True):
            if car_infra.has_edge(u, v, key):
                continue
            # Create the edge under the bike edge's own key. Letting networkx
            # pick a key would detach the attributes from the new edge.
            car_infra.add_edge(u, v, key=key)
            car_infra[u][v][key].update(attrs)

        return car_infra

    def _process_networkx_infra(self):
        """Convert ``self.all_infra`` to igraph and prepare it for the model.

        Every edge gets a ``weight`` equal to its ``length`` and every vertex
        an ``id`` equal to its igraph index.

        Returns:
            ``(all_infra_ig, poi_ids)`` where ``poi_ids`` lists the indices of
            the vertices whose ``poi`` attribute is truthy.
        """
        all_infra_ig = ig.Graph.from_networkx(self.all_infra)
        poi_ids = [vertex.index for vertex in all_infra_ig.vs if vertex['poi']]
        for edge in all_infra_ig.es:
            edge['weight'] = edge['length']
        for vertex in all_infra_ig.vs:
            vertex['id'] = vertex.index

        return all_infra_ig, poi_ids
        