Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes agencies table #414

Merged
merged 7 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aequilibrae/paths/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def set_graph(self, cost_field) -> None:
if cost_field in self.graph.columns:
self.cost_field = cost_field
self.compact_cost = np.zeros(self.compact_graph.id.max() + 2, self.__float_type)
df = self.__graph_groupby.sum()[[cost_field]].reset_index()
df = self.__graph_groupby.sum(numeric_only=True)[[cost_field]].reset_index()
self.compact_cost[df.index.values] = df[cost_field].values
if self.graph[cost_field].dtype == self.__float_type:
self.cost = np.array(self.graph[cost_field].values, copy=True)
Expand Down
8 changes: 4 additions & 4 deletions aequilibrae/transit/transit_elements/agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from sqlite3 import Connection

from aequilibrae.project.database_connection import database_connection
from aequilibrae.transit.constants import Constants, WALK_AGENCY_ID
from aequilibrae.transit.constants import Constants
from aequilibrae.transit.transit_elements.basic_element import BasicPTElement


Expand Down Expand Up @@ -33,9 +33,9 @@ def save_to_database(self, conn: Connection) -> None:

def __get_agency_id(self):
with closing(database_connection("transit")) as conn:
sql = "Select coalesce(max(distinct(agency_id)), 0) from agencies where agency_id<?;"
data = [x[0] for x in conn.execute(sql, [WALK_AGENCY_ID])]
sql = "Select coalesce(max(distinct(agency_id)), 0) from agencies;"
max_db = int(conn.execute(sql).fetchone()[0])

c = Constants()
c.agencies["agencies"] = max(c.agencies.get("agencies", 1), data[0])
c.agencies["agencies"] = max(c.agencies.get("agencies", 0), max_db) + 1
return c.agencies["agencies"]
5 changes: 4 additions & 1 deletion tests/aequilibrae/transit/test_gtfs_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from os.path import join, dirname, abspath
from pathlib import Path

import pytest

import pandas as pd
Expand Down Expand Up @@ -26,7 +28,8 @@ def test_set_feed_path(gtfs_loader, gtfs_fldr):


def test_load_data(gtfs_loader, gtfs_fldr):
cap = pd.read_csv(join(abspath(dirname("tests")), "tests/data/gtfs/transit_max_speeds.txt"))
pth = Path(__file__).parent.parent.parent
cap = pd.read_csv(pth / "data/gtfs/transit_max_speeds.txt")

df = cap[cap.city == "Coquimbo"]
df.loc[df.min_distance < 100, "speed"] = 10
Expand Down
32 changes: 31 additions & 1 deletion tests/aequilibrae/transit/test_transit.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,44 @@
from os.path import isfile, join

from aequilibrae.project.database_connection import database_connection
from aequilibrae.transit.constants import Constants


def test_new_gtfs_builder(create_gtfs_project, create_path):
c = Constants()
c.agencies["agencies"] = 0

conn = database_connection("transit")
existing = conn.execute("SELECT COALESCE(MAX(DISTINCT(agency_id)), 0) FROM agencies;").fetchone()[0]

transit = create_gtfs_project.new_gtfs_builder(
agency="LISERCO, LISANCO, LINCOSUR",
agency="Agency_1",
day="2016-04-13",
file_path=join(create_path, "gtfs_coquimbo.zip"),
)

assert str(type(transit)) == "<class 'aequilibrae.transit.lib_gtfs.GTFSRouteSystemBuilder'>"

transit2 = create_gtfs_project.new_gtfs_builder(
agency="Agency_2",
day="2016-07-19",
file_path=join(create_path, "gtfs_coquimbo.zip"),
)

transit.save_to_disk()
transit2.save_to_disk()

assert conn.execute("SELECT MAX(DISTINCT(agency_id)) FROM agencies;").fetchone()[0] == existing + 2

transit3 = create_gtfs_project.new_gtfs_builder(
agency="Agency_3",
day="2016-07-19",
file_path=join(create_path, "gtfs_coquimbo.zip"),
)

transit3.save_to_disk()
assert conn.execute("SELECT MAX(DISTINCT(agency_id)) FROM agencies;").fetchone()[0] == existing + 3


def test___create_transit_database(create_gtfs_project):
assert isfile(join(create_gtfs_project.project_base_path, "public_transport.sqlite")) is True