In [33]:
%pip install pandas numpy matplotlib plotly networkx pyDatalog

Note: you may need to restart the kernel to use updated packages.


In [34]:
# Boilerplate for AI Assignment — Knowledge Representation, Reasoning and Planning
# CSE 643

# Import necessary libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import networkx as nx
from pyDatalog import pyDatalog
from collections import defaultdict, deque

## ****IMPORTANT****
## Don't import or use any other libraries other than defined above
## Otherwise your code file will be rejected in the automated testing

In [35]:
# ------------------ Global Variables ------------------
route_to_stops = defaultdict(list)  # Mapping of route IDs to lists of stops
trip_to_route = {}                   # Mapping of trip IDs to route IDs
stop_trip_count = defaultdict(int)    # Count of trips for each stop
fare_rules = {}                      # Mapping of route IDs to fare information
merged_fare_df = None                # To be initialized in create_kb()

# Load static data from GTFS (General Transit Feed Specification) files
df_stops = pd.read_csv('GTFS/stops.txt')
df_routes = pd.read_csv('GTFS/routes.txt')
df_stop_times = pd.read_csv('GTFS/stop_times.txt')
df_fare_attributes = pd.read_csv('GTFS/fare_attributes.txt')
df_trips = pd.read_csv('GTFS/trips.txt')
df_fare_rules = pd.read_csv('GTFS/fare_rules.txt')

In [41]:
# ------------------ Function Definitions ------------------

# Function to create knowledge base from the loaded data
def create_kb():
    """
    Create knowledge base by populating global variables with information from loaded datasets.
    It establishes the relationships between routes, trips, stops, and fare rules.
    
    Returns:
        None
    """
    global route_to_stops, trip_to_route, stop_trip_count, fare_rules, merged_fare_df

    df_stops['stop_id'] = df_stops['stop_id'].astype(str)
    df_stop_times['trip_id'] = df_stop_times['trip_id'].astype(str)
    df_stop_times['stop_id'] = df_stop_times['stop_id'].astype(str)
    df_routes['route_id'] = df_routes['route_id'].astype(str)
    df_trips['trip_id'] = df_trips['trip_id'].astype(str)
    df_trips['route_id'] = df_trips['route_id'].astype(str)
    df_fare_rules['fare_id'] = df_fare_rules['fare_id'].astype(str)
    df_fare_rules['route_id'] = df_fare_rules['route_id'].astype(str)
    df_stop_times['arrival_time'] = pd.to_datetime(df_stop_times['arrival_time'], format='%H:%M:%S', errors='coerce')
    df_stop_times['departure_time'] = pd.to_datetime(df_stop_times['departure_time'], format='%H:%M:%S', errors='coerce')

    for tmp, row in df_trips.iterrows():
        trip_to_route[row['trip_id']] = row['route_id']

    for tmp, row in df_stop_times.iterrows():
        route_id = trip_to_route.get(row['trip_id'])
        if route_id:
            if route_id not in route_to_stops:
                route_to_stops[route_id] = []
            route_to_stops[route_id].append((row['stop_sequence'], row['stop_id']))
            stop_trip_count[row['stop_id']] += 1

    for route_id, stops in route_to_stops.items():
        if all(isinstance(stop, tuple) and len(stop) == 2 for stop in stops):
            unique_stops = sorted(set(stops), key=lambda x: x[0])
            route_to_stops[route_id] = [stop_id for _, stop_id in unique_stops]
        else:
            print(f"Unexpected structure in stops for route_id {route_id}: {stops}")

    fare_rules = df_fare_rules.set_index('route_id').T.to_dict()

    merged_fare_df = pd.merge(df_fare_rules, df_fare_attributes, on='fare_id', how='inner')

In [42]:
# Function to find the top 5 busiest routes based on the number of trips
def get_busiest_routes():
    """
    Identify the top 5 busiest routes based on trip counts.

    Returns:
        list: A list of tuples, where each tuple contains:
              - route_id (int): The ID of the route.
              - trip_count (int): The number of trips for that route.
    """
    route_trip_count = defaultdict(int)

    for trip_id, route_id in trip_to_route.items():
        route_trip_count[route_id] += 1

    busiest_routes = sorted(route_trip_count.items(), key=lambda x: x[1], reverse=True)[:5]

    return busiest_routes

# Function to find the top 5 stops with the most frequent trips
def get_most_frequent_stops():
    """
    Identify the top 5 stops with the highest number of trips.

    Returns:
        list: A list of tuples, where each tuple contains:
              - stop_id (int): The ID of the stop.
              - trip_count (int): The number of trips for that stop.
    """
    most_frequent_stops = sorted(stop_trip_count.items(), key=lambda x: x[1], reverse=True)[:5]

    return most_frequent_stops

# Function to find the top 5 busiest stops based on the number of routes passing through them
def get_top_5_busiest_stops():
    """
    Identify the top 5 stops with the highest number of different routes.

    Returns:
        list: A list of tuples, where each tuple contains:
              - stop_id (int): The ID of the stop.
              - route_count (int): The number of routes passing through that stop.
    """
    stop_to_routes = defaultdict(set)

    for route_id, stops in route_to_stops.items():
        for stop_id in stops:
            stop_to_routes[stop_id].add(route_id)

    stop_route_count = {stop_id: len(routes) for stop_id, routes in stop_to_routes.items()}

    top_5_busiest_stops = sorted(stop_route_count.items(), key=lambda x: x[1], reverse=True)[:5]

    return top_5_busiest_stops

# Function to identify the top 5 pairs of stops with only one direct route between them
def get_stops_with_one_direct_route():
    """
    Identify the top 5 pairs of consecutive stops (start and end) connected by exactly one direct route. 
    The pairs are sorted by the combined frequency of trips passing through both stops.

    Returns:
        list: A list of tuples, where each tuple contains:
              - pair (tuple): A tuple with two stop IDs (stop_1, stop_2).
              - route_id (int): The ID of the route connecting the two stops.
    """
    stop_pair_to_route = defaultdict(list)

    for route_id, stops in route_to_stops.items():
        for i in range(len(stops) - 1):
            stop_pair = (stops[i], stops[i + 1])
            reverse_pair = (stops[i + 1], stops[i])
            stop_pair_to_route[stop_pair].append(route_id)
            stop_pair_to_route[reverse_pair].append(route_id)

    result = []
    for stop_pair, routes in stop_pair_to_route.items():
        if len(routes) == 1:
            stop_1, stop_2 = stop_pair
            combined_trip_count = stop_trip_count[stop_1] + stop_trip_count[stop_2]
            result.append((stop_pair, routes[0], combined_trip_count))

    result_sorted = sorted(result, key=lambda x: x[2], reverse=True)

    top_5_pairs = [(pair, route_id) for pair, route_id, _ in result_sorted[:5]]
    return top_5_pairs

# Function to get merged fare DataFrame
# No need to change this function
def get_merged_fare_df():
    """
    Retrieve the merged fare DataFrame.

    Returns:
        DataFrame: The merged fare DataFrame containing fare rules and attributes.
    """
    global merged_fare_df
    return merged_fare_df

# Visualize the stop-route graph interactively
def visualize_stop_route_graph_interactive(route_to_stops):
    """
    Visualize the stop-route graph using Plotly for interactive exploration.

    Args:
        route_to_stops (dict): A dictionary mapping route IDs to lists of stops.

    Returns:
        None
    """
    G = nx.Graph()

    for route_id, stops in route_to_stops.items():
        for i in range(len(stops) - 1):
            G.add_edge(stops[i], stops[i + 1], route=route_id)

    pos = nx.spring_layout(G)

    edge_x = []
    edge_y = []
    edge_text = []

    for edge in G.edges(data=True):
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x.append(x0)
        edge_x.append(x1)
        edge_x.append(None)
        edge_y.append(y0)
        edge_y.append(y1)
        edge_y.append(None)
        edge_text.append(f"Route: {edge[2]['route']}")

    node_x = []
    node_y = []
    node_text = []
    for node in G.nodes:
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(f"Stop ID: {node}")

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='text',
        text=edge_text,
        mode='lines'
    )

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers',
        hoverinfo='text',
        marker=dict(
            size=10,
            color='#00bfff',
            line_width=2),
        text=node_text
    )

    fig = go.Figure(data=[edge_trace, node_trace],
                    layout=go.Layout(
                        title='Stop-Route Graph',
                        titlefont_size=16,
                        showlegend=False,
                        hovermode='closest',
                        margin=dict(b=0, l=0, r=0, t=40),
                        xaxis=dict(showgrid=False, zeroline=False),
                        yaxis=dict(showgrid=False, zeroline=False))
                    )

    # Save as HTML for viewing in a browser
    fig.write_html("stop_route_graph.html")
    print("Plot saved as 'stop_route_graph.html'. Open this file in a browser to view the interactive plot.")
    
    fig.show()

# REASONING

In [43]:
# Run the Knowledge Base creation
create_kb()  # Ensure this line is executed to populate route_to_stops

Unexpected structure in stops for route_id 142: [146, 148, 149, 488, 233, 915, 916, 2161, 2162, 3569, 2163, 2164, 2165, 2166, 2167, 2168, 2169, 2170, 2171, 2172, 2173, 2174, 2175, 2176, 2177]
Unexpected structure in stops for route_id 10001: [3928, 3929, 3930, 20001, 20002, 20003, 20004, 20005, 545, 1998, 1999, 2000, 2001, 998, 999, 2149, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 1437, 1438, 1439, 1440, 951, 952, 953, 954, 955, 293, 195, 196, 197, 2559, 199, 1879, 201, 202, 203, 1880, 22509, 4310, 1883, 1885, 1886, 1887, 2573, 1888, 1889, 1890, 1891, 1892, 1177, 1178, 1179, 1180, 1181, 1182, 1183, 1184, 2799, 2800, 2801, 4284, 65, 66, 67, 1604, 1605, 1305, 22510]
Unexpected structure in stops for route_id 362: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 756, 19, 20, 21, 1145, 1147, 1312, 1313, 1314, 1315, 1316, 1317, 1318, 1319, 1320, 1321, 1322, 1323, 1324, 1325, 1326, 1327, 1328, 1329, 1330, 22511, 1332, 1333, 1334, 1335, 1336, 1337, 1338, 22512, 1340, 1341, 1342

  fare_rules = df_fare_rules.set_index('route_id').T.to_dict()


In [39]:
# Brute-Force Approach for finding direct routes (ignoring direction)
def direct_route_brute_force(start_stop, end_stop):
    """
    Find all valid routes between two stops using a brute-force method, ignoring direction.

    Args:
        start_stop (int): The ID of the starting stop.
        end_stop (int): The ID of the ending stop.

    Returns:
        list: A list of route IDs (int) that connect the two stops.
    """
    direct_routes = []

    for route_id, stops in route_to_stops.items():
        # Check if both stops are in the list of stops for this route
        if start_stop in stops and end_stop in stops:
            direct_routes.append(route_id)

    return direct_routes


In [40]:
## Testing Direct Route Brute Force

test_inputs = {
    "direct_route": [
        ((2573, 1177), [10001, 1117, 1407]),  # Input -> Expected output
        ((2001, 2005), [10001, 1151])
    ]
}

def check_output(expected, actual):
    """Function to compare expected and actual outputs."""
    if isinstance(expected, list) and isinstance(actual, list):
        return sorted(expected) == sorted(actual)  # Ensures order-independent comparison
    return expected == actual  # For non-list types

def test_direct_route_brute_force():
    for (start_stop, end_stop), expected_output in test_inputs["direct_route"]:
        actual_output = direct_route_brute_force(start_stop, end_stop)
        print(f"Test direct_route_brute_force ({start_stop}, {end_stop}): ", 
              "Pass" if check_output(expected_output, actual_output) else f"Fail (Expected: {expected_output}, Got: {actual_output})")
        
test_direct_route_brute_force()

Test direct_route_brute_force (2573, 1177):  Pass
Test direct_route_brute_force (2001, 2005):  Pass
