## Load Libraries

In [None]:
import numpy as np
import os
import pyarrow
import sys
import json
import math
import mpl_utils
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

import pandas as pd
import polars as pl
import xml.etree.ElementTree as ET

from xopen import xopen
from datetime import time

## Set-up Directories

In [None]:
general_directory = '/Users/andre/Desktop/Cergy/'

# MATSim Berlin v6.4 input directory, relative to general_directory.
berlin_directory = 'MATSim/matsim-berlin/input/v6.4/'
# Backward-compatible alias for the original (misspelled) variable name.
berlin_driectory = berlin_directory

pt_10pct_dir = 'Python_Scripts/runs/pt_10pct/'

# supply paths
NETWORK_PATH = os.path.join(general_directory, berlin_directory, "berlin-v6.4-network.xml.gz")
VEHICLE_PATH = os.path.join(general_directory, berlin_directory, "berlin-v6.4-vehicleTypes.xml")

# demand path (was previously assigned to NETWORK_PATH, which silently
# overwrote the network path defined above)
PLANS_PATH = os.path.join(general_directory, berlin_directory, "berlin-v6.4.output_plans.xml.gz")

# metropolis path
METRO_INPUT = os.path.join(general_directory, pt_10pct_dir, "metro_inputs/")
METRO_OUTPUT = os.path.join(general_directory, pt_10pct_dir, "metro_outputs/")


In [None]:
def hhmmss_str_to_seconds_expr(col: str) -> pl.Expr:
    """Build an expression converting an "HH:MM:SS" string column to seconds.

    The resulting Int32 column is named ``<col>_secs``. Any value that is not a
    string containing ":" maps to null. Fields are weighted 3600 / 60 / 1 from
    the left, so "HH:MM" is read as hours and minutes.
    """
    unit_weights = (3600, 60, 1)

    def _to_seconds(value):
        if not isinstance(value, str) or ":" not in value:
            return None
        # Parsed lazily, in lock-step with the weights.
        fields = map(int, str(value).split(":"))
        total = 0
        for field, weight in zip(fields, unit_weights):
            total += field * weight
        return total

    seconds = pl.col(col).map_elements(_to_seconds, return_dtype=pl.Int32)
    return seconds.alias(f"{col}_secs")

# Config and parameters

In [None]:
POPULATION_SHARE = 0.10  # 10% of total population to match MATSim Berlin

# Parameters to use for the simulation (serialized to parameters.json).
# Input tables are read from METRO_INPUT; results are written to METRO_OUTPUT.
PARAMETERS = {
    "input_files": {
        "agents": os.path.join(METRO_INPUT, "agents.parquet"),
        "alternatives": os.path.join(METRO_INPUT, "alts.parquet"),
        "trips": os.path.join(METRO_INPUT, "trips.parquet"),
        "edges": os.path.join(METRO_INPUT, "edges.parquet"),
        "vehicle_types": os.path.join(METRO_INPUT, "vehicles.parquet"),
    },
    "output_directory": METRO_OUTPUT,
    # Times in this notebook are handled in seconds: [0, 86400] is one day.
    "period": [0.0, 86400.0],
    "road_network": {
        "recording_interval": 950.0,
        "approximation_bound": 1.0,
        "spillback": True,
        "backward_wave_speed": 15.0,
        "max_pending_duration": 30.0,
        "constrain_inflow": True,
        "algorithm_type": "Best",
    },
    "learning_model": {"type": "Linear"},
    # A single iteration, counted from 1.
    "init_iteration_counter": 1,
    "max_iterations": 1,
    "update_ratio": 1.0,
    "random_seed": 13081996,
    "nb_threads": 16,
    "saving_format": "Parquet",
    "only_compute_decisions": False,
}

# Read MATSim output

# Supply

## Vehicles

In [None]:
def vehicle_reader(vehcile_path):
    """Parse a MATSim vehicle-definitions XML file into a DataFrame of vehicle types.

    Each ``<vehicleType>`` element becomes one record:

    * its own XML attributes and those of nested elements (other than
      ``attribute``/``length``/``width``) are merged in via ``parse_attributes``;
    * each ``<attribute name="...">`` child is stored under its ``name`` with
      the element text as value;
    * ``<length>`` / ``<width>`` store their ``meter`` attribute.

    Content outside ``<vehicleType>`` elements is ignored.

    Parameters
    ----------
    vehcile_path : str
        Path to the (optionally compressed) vehicles file, opened with ``xopen``.

    Returns
    -------
    pandas.DataFrame
        One row per vehicle type. Known numeric columns are converted; a failed
        conversion is reported and the column is left as read.
    """
    vehicle_types = []
    current_vehicle_type = {}
    is_parsing_vehicle_type = False
    # `with` guarantees the (possibly compressed) file handle is closed.
    with xopen(vehcile_path, "r") as xml_file:
        for xml_event, elem in ET.iterparse(xml_file, events=["start", "end"]):
            # Strip an optional "{namespace}" prefix. rpartition also works for
            # un-namespaced tags, where partition would return an empty name.
            elem_tag = elem.tag.rpartition("}")[2]
            # VEHICLETYPES
            if elem_tag == "vehicleType":
                if xml_event == "start":
                    parse_attributes(elem, current_vehicle_type)
                    is_parsing_vehicle_type = True
                else:
                    vehicle_types.append(current_vehicle_type)
                    current_vehicle_type = {}
                    elem.clear()
                    is_parsing_vehicle_type = False
            elif not is_parsing_vehicle_type:
                # Outside a <vehicleType>: nothing belongs to the current record.
                continue
            # ATTRIBUTES
            elif elem_tag == "attribute":
                # iterparse only guarantees element text at the "end" event.
                if xml_event == "end":
                    current_vehicle_type[elem.attrib["name"]] = elem.text
            # LENGTH / WIDTH
            elif elem_tag in ("length", "width"):
                if xml_event == "start":
                    current_vehicle_type[elem_tag] = elem.attrib["meter"]
            # EVERYTHING ELSE
            else:
                parse_attributes(elem, current_vehicle_type)

    vehicle_types = pd.DataFrame.from_records(vehicle_types)
    col_types = {
        "accessTimeInSecondsPerPerson": float,
        "egressTimeInSecondsPerPerson": float,
        "seats": int,
        "standingRoomInPersons": int,
        "length": float,
        "width": float,
        "pce": float,
        "factor": float,
    }
    for col, dtype in col_types.items():
        if col in vehicle_types.columns:
            try:
                vehicle_types[col] = vehicle_types[col].astype(dtype)
            except (ValueError, TypeError):
                # Best effort: keep the column as read and report it.
                print(f"dataframe types conversion failed for column {col}")
    return vehicle_types

## Network

In [None]:
def read_network(network_path):
    """Read a MATSim network XML file into a pandas DataFrame of car links.

    Each ``<link>`` element becomes one record containing its XML attributes plus:

    * ``link_id``: ``id`` with every "#" removed, cast to int;
    * ``numeric_link_id``: the part of ``id`` before the first "#", as int;
    * ``from_node`` / ``to_node``: the ``from`` / ``to`` attributes, where ids
      containing "cluster" are reduced to the first token after "cluster_",
      then cast to int.

    ``length``, ``freespeed``, ``capacity``, ``permlanes`` and (if present)
    ``volume`` are parsed as floats. Only links whose ``modes`` contain "car"
    are kept.

    Parallel links (same ``from_node`` and ``to_node``) are made unique. The
    first link of each group is kept as-is. Every further duplicate is
    redirected to a fresh node, and a zero-length copy of it (new ``link_id``,
    same MATSim ``id``) connects that node to the original target. The copies
    are appended after all other links.

    Parameters
    ----------
    network_path : str
        Path to the (optionally compressed) network file, opened with ``xopen``.

    Returns
    -------
    pandas.DataFrame
        One row per (possibly split) car link.
    """
    links = []
    with xopen(network_path, "r") as xml_file:
        for xml_event, elem in ET.iterparse(xml_file, events=["start", "end"]):
            if elem.tag == "link" and xml_event == "start":
                # Copy: elem.attrib belongs to the element, and elem.clear()
                # below may empty that very dict (pure-Python ElementTree).
                atts = dict(elem.attrib)

                # Remove '#' from link_id
                atts["link_id"] = atts["id"].replace("#", "")
                atts["numeric_link_id"] = int(atts["id"].split("#")[0])

                # Rename from/to; reduce "cluster_<n>_..." node ids to "<n>"
                for node_key, xml_key in (("from_node", "from"), ("to_node", "to")):
                    node_id = atts.pop(xml_key)
                    if "cluster" in node_id:
                        node_id = node_id.replace("cluster_", "").split("_")[0]
                    atts[node_key] = node_id

                for numeric_key in ("length", "freespeed", "capacity", "permlanes"):
                    atts[numeric_key] = float(atts[numeric_key])
                if "volume" in atts:
                    atts["volume"] = float(atts["volume"])

                links.append(atts)

            # clear the element when we're done, to keep memory usage low
            if elem.tag in ("node", "link") and xml_event == "end":
                elem.clear()

    links = pd.DataFrame.from_records(links)
    links = links.loc[links["modes"].str.contains("car")].copy()
    links["link_id"] = links["link_id"].astype(int)
    links["from_node"] = links["from_node"].astype(int)
    links["to_node"] = links["to_node"].astype(int)

    node_pair_counts = links[["from_node", "to_node"]].value_counts()
    if node_pair_counts.max() > 2:
        print("More than two parallel edges")

    parallel_idx = node_pair_counts.loc[node_pair_counts > 1].index
    if len(parallel_idx):
        print("Found {} parallel edges".format(len(parallel_idx)))
        next_node_id = max(links["from_node"].max(), links["to_node"].max()) + 1
        next_link_id = links["link_id"].max() + 1
        new_rows = []
        for (source, target) in parallel_idx:
            mask = (links["from_node"] == source) & (links["to_node"] == target)
            # Keep the first link; split EVERY further duplicate so that no
            # (source, target) pair stays parallel, even with 3+ links.
            for dup_label in mask[mask].index[1:]:
                # Zero-length connector: new node -> original target.
                row = links.loc[dup_label].copy()
                row["length"] = 0.0
                row["from_node"] = next_node_id
                row["link_id"] = next_link_id
                new_rows.append(row)
                # The duplicate itself now ends at the new node.
                links.loc[dup_label, "to_node"] = next_node_id
                next_link_id += 1
                next_node_id += 1
        links = pd.concat((links, pd.DataFrame(new_rows)))

    return links

# Demand

### Read `output_plans`

In [None]:
def plan_reader_dataframe(plan_path, selected_plans_only=True):
    """Parse a MATSim plans XML file into persons/plans/activities/legs/routes tables.

    The file is streamed with ``ET.iterparse`` and each element is cleared on
    its "end" event. Plans, activities, legs and routes get sequential integer
    ``id`` values:

    * activities and legs reference their plan through ``plan_id``;
    * routes reference their leg through ``leg_id``;
    * plans reference their person through ``person_id``.

    Element attributes are copied in via ``parse_attributes``. ``<attribute>``
    children are stored under their ``name`` with the element text as value.

    Parameters
    ----------
    plan_path : str
        Path to the (optionally compressed) plans file, opened with ``xopen``.
    selected_plans_only : bool, default True
        If True, only plans with ``selected="yes"`` and their activities, legs
        and routes are kept.

    Returns
    -------
    tuple of polars.DataFrame
        ``(persons, plans, activities, legs, routes)``.
    """
    persons = []
    plans = []
    activities = []
    legs = []
    routes = []

    # Records being filled for the element currently open in the stream.
    current_person = {}
    current_plan = {}
    current_activity = {}
    current_leg = {}
    current_route = {}

    is_parsing_person = False
    is_parsing_activity = False
    is_parsing_leg = False
    is_selected_plan = True

    current_person_id = None
    current_plan_id = 0
    current_activity_id = 0
    current_leg_id = 0
    current_route_id = 0

    # `with` closes the (possibly compressed) file handle once parsing is done.
    with xopen(plan_path) as xml_file:
        for xml_event, elem in ET.iterparse(xml_file, events=["start", "end"]):
            if xml_event == "start":

                # PERSON
                if elem.tag == "person":
                    current_person["id"] = elem.attrib["id"]
                    current_person_id = elem.attrib["id"]
                    is_parsing_person = True

                # PLAN
                if elem.tag == "plan":
                    is_selected_plan = not selected_plans_only or elem.attrib.get("selected", "no") == "yes"
                    if not is_selected_plan:
                        continue
                    current_plan["id"] = current_plan_id
                    current_plan["person_id"] = current_person_id
                    current_plan_id += 1
                    parse_attributes(elem, current_plan)

                # ACTIVITY
                elif elem.tag == "activity" and is_selected_plan:
                    is_parsing_activity = True
                    current_activity_id += 1
                    current_activity["id"] = current_activity_id
                    current_activity["plan_id"] = current_plan_id - 1
                    parse_attributes(elem, current_activity)

                # LEG
                elif elem.tag == "leg" and is_selected_plan:
                    is_parsing_leg = True
                    current_leg_id += 1
                    current_leg["id"] = current_leg_id
                    current_leg["plan_id"] = current_plan_id - 1
                    parse_attributes(elem, current_leg)

                # ROUTE
                elif elem.tag == "route" and is_selected_plan:
                    current_route_id += 1
                    current_route["id"] = current_route_id
                    current_route["leg_id"] = current_leg_id
                    parse_attributes(elem, current_route)

            elif xml_event == "end":

                # PERSON
                if elem.tag == "person":
                    persons.append(current_person)
                    current_person = {}
                    is_parsing_person = False

                # PLAN
                elif elem.tag == "plan" and is_selected_plan:
                    plans.append(current_plan)
                    current_plan = {}

                # ACTIVITY
                elif elem.tag == "activity" and is_parsing_activity and is_selected_plan:
                    activities.append(current_activity)
                    current_activity = {}
                    is_parsing_activity = False

                # LEG
                elif elem.tag == "leg" and is_parsing_leg and is_selected_plan:
                    legs.append(current_leg)
                    current_leg = {}
                    is_parsing_leg = False

                # ROUTE
                elif elem.tag == "route" and is_selected_plan:
                    current_route["value"] = elem.text
                    routes.append(current_route)
                    current_route = {}

                # ATTRIBUTE (text is only complete at the "end" event)
                elif elem.tag == "attribute":
                    # NOTE(review): an <attribute> inside neither an activity nor
                    # a leg (e.g. one attached to a <plan>, if present) is stored
                    # on the person record.
                    attribs = elem.attrib
                    if is_parsing_activity and is_selected_plan:
                        current_activity[attribs["name"]] = elem.text
                    elif is_parsing_leg and is_selected_plan:
                        current_leg[attribs["name"]] = elem.text
                    elif is_parsing_person:
                        current_person[attribs["name"]] = elem.text
                elem.clear()

    # Convert to DataFrames. infer_schema_length=None lets polars look at every
    # record when inferring columns, so keys that only show up late in the
    # file are not left out.
    return (
        pl.DataFrame(persons, infer_schema_length=None),
        pl.DataFrame(plans, infer_schema_length=None),
        pl.DataFrame(activities, infer_schema_length=None),
        pl.DataFrame(legs, infer_schema_length=None),
        pl.DataFrame(routes, infer_schema_length=None),
    )

# Convert MATSim output to Metropolis input

# Supply

## Vehicles

In [None]:
def make_vehicles_df(vehicle_types):
    """Build the METROPOLIS vehicles table from the parsed MATSim vehicle types.

    Each row of ``vehicle_types`` gives one record:

    * ``vehicle_id`` is the row's index label;
    * ``headway`` is the vehicle ``length``;
    * ``pce`` is the MATSim ``pce`` divided by POPULATION_SHARE, except for the
      "ride" type, which gets 0.0;
    * all speed functions are of type "Base".

    Only records with ``vehicle_id < 5`` are returned.
    """
    def _as_vehicle(vehicle_id, spec):
        is_ride = spec["id"] == "ride"
        headway = float(spec["length"])
        pce = 0.0 if is_ride else float(spec["pce"]) / POPULATION_SHARE
        return {
            "vehicle_id": vehicle_id,
            "vehicle_type": spec["id"],
            "headway": headway,
            "pce": pce,
            "speed_function.type": "Base",
            "speed_function.upper_bound": None,
            "speed_function.coef": None,
        }

    records = [_as_vehicle(label, spec) for label, spec in vehicle_types.iterrows()]
    vehicles = pl.DataFrame(records)
    # Per the original note these are car, bike, ride, truck, freight; this
    # relies on the row order of the vehicle types file.
    return vehicles.filter(pl.col('vehicle_id') < 5)

## Network

In [None]:
def make_edges_df(links, alpha=1):  # lower alpha for increased congestion
    """Build the METROPOLIS edges table from the link table of ``read_network``.

    Parameters
    ----------
    links : pandas.DataFrame
        Links with ``id``, ``from_node``, ``to_node``, ``freespeed``, ``length``,
        ``permlanes`` and ``capacity`` columns.
    alpha : float, default 1
        Multiplier on link capacity; lower values increase congestion.

    Returns
    -------
    polars.DataFrame
        One row per link, in input order, with ``edge_id`` numbered from 1.
        The MATSim link id is kept in ``MATSim_id``. All edges use the
        "FreeFlow" speed-density function and allow overtaking.
    """
    # Iterate the needed columns directly instead of DataFrame.iterrows(),
    # which builds a full Series object for every row.
    link_columns = zip(
        links["id"], links["from_node"], links["to_node"], links["freespeed"],
        links["length"], links["permlanes"], links["capacity"],
    )
    edge_list = []
    for edge_id, (matsim_id, source, target, freespeed, length, permlanes, capacity) in enumerate(link_columns, start=1):
        freeflow_tt = float(length) / float(freespeed)
        edge_list.append({
            "edge_id": edge_id,
            "MATSim_id": matsim_id,
            "source": int(source),
            "target": int(target),
            "speed": float(freespeed),
            "length": float(length),
            "lanes": float(permlanes),
            "speed_density.type": "FreeFlow",
            "speed_density.capacity": None,
            "speed_density.min_density": None,
            "speed_density.jam_density": None,
            "speed_density.jam_speed": None,
            "speed_density.beta": None,
            # MATSim capacity covers the whole link (all lanes), presumably in
            # vehicles per hour. Dividing by permlanes * 3600 yields a per-lane
            # flow in vehicles per SECOND.
            # TODO(review): confirm METROPOLIS expects a per-lane value here.
            "bottleneck_flow": float(capacity) * alpha / (permlanes * 3600.0),
            # MATSim advances in 1-second steps while METROPOLIS uses continuous
            # time: pad the free-flow travel time up to the next whole second
            # (the padding is 0 when length / freespeed is already an integer).
            "constant_travel_time": math.ceil(freeflow_tt) - freeflow_tt,
            "overtaking": True,
        })

    return pl.DataFrame(edge_list)

# Demand

In [None]:
def generate_sequence(activities, legs, routes, plans=None):
    """Rebuild each plan's activity/leg sequence and return one row per leg.

    Activities get even ``seq_index`` values (0, 2, 4, ...) and legs odd ones
    (1, 3, 5, ...) within each plan, so sorting by ``(plan_id, seq_index)``
    interleaves them. ``trip_id`` counts the legs seen so far in the plan, and
    an activity shares the ``trip_id`` of the leg before it. ``tour_id``
    increments at every activity whose type appears with an ``end_time`` in
    ``activities``.

    Parameters
    ----------
    activities, legs, routes : polars.DataFrame
        Tables as returned by ``plan_reader_dataframe``.
    plans : polars.DataFrame, optional
        Plans table with ``id`` and ``person_id``, used to attach ``person_id``.
        When omitted, a global ``plans`` is used (the original behavior).

    Returns
    -------
    polars.DataFrame
        One row per leg with person_id, plan_id, tour_id, trip_id, seq_index,
        mode, start_time, end_time, duration, route, start_link and end_link.
        It is left-joined on ``(plan_id, trip_id)`` with the activity that
        follows the leg. That join adds ``stopping_time`` (the activity
        duration) and ``*_right`` copies of overlapping columns.

    Raises
    ------
    NameError
        If ``plans`` is not given and no global ``plans`` exists.
    """
    if plans is None:
        # Backward compatibility: fall back to the notebook-level ``plans``.
        try:
            plans = globals()["plans"]
        except KeyError:
            raise NameError(
                "generate_sequence needs a `plans` DataFrame: pass plans=... "
                "or define a global `plans`"
            ) from None

    # Attach each leg's route and convert its departure time to seconds
    legs = (legs
        .join(routes, how='left', left_on='id', right_on='leg_id')
        .with_columns(pl.col('id').alias('leg_id'),
                      hhmmss_str_to_seconds_expr("dep_time")))

    activities = (activities
                  .drop(["facility", "initialEndTime", "orig_duration"])
                  .with_columns([
                      hhmmss_str_to_seconds_expr("end_time"),
                      hhmmss_str_to_seconds_expr("max_dur"),
                  ]))

    # Even seq IDs for activities
    activities = activities.with_columns([
        ((pl.cum_count("plan_id").over("plan_id") - 1) * 2).alias('seq_index'),
        pl.lit('activity').alias('element_type'),
        pl.col('type').alias('type_or_mode'),
        hhmmss_str_to_seconds_expr("max_dur").cast(pl.Float64).alias("duration"),
        pl.col('link').alias('route'),
        pl.col('link').alias('start_link'),
        pl.col('link').alias('end_link'),
    ])

    # Odd seq IDs for legs.
    # NOTE(review): 'trav_time' is used numerically below (added to
    # dep_time_secs) but is not converted from "HH:MM:SS" here; presumably the
    # caller passes legs whose trav_time is already in seconds — confirm.
    legs = legs.with_columns([
        ((pl.cum_count("plan_id").over("plan_id") - 1) * 2 + 1).alias('seq_index'),
        pl.lit('leg').alias('element_type'),
        pl.col('mode').alias('type_or_mode'),
        pl.col('trav_time').alias('duration'),
        pl.col('value').alias('route'),
    ])

    # Timing columns, keyed by (plan_id, seq_index)
    activities_secs = activities.select([
        "plan_id",
        "seq_index",
        "end_time_secs",
        "max_dur_secs",
        pl.lit(None).cast(pl.Int32).alias("dep_time_secs"),
        pl.lit(None).cast(pl.Float64).alias("trav_time_secs"),
    ])
    legs_secs = legs.select([
        "plan_id",
        "seq_index",
        pl.lit(None).cast(pl.Int32).alias("end_time_secs"),
        pl.lit(None).cast(pl.Int32).alias("max_dur_secs"),
        "dep_time_secs",
        pl.col("trav_time").alias('trav_time_secs'),
    ])
    extra_cols = pl.concat([activities_secs, legs_secs])

    clean_cols = ["plan_id", "seq_index", "element_type", "type_or_mode", "start_link",
                  "end_link", "route", "duration"]
    matsim_trips = (pl.concat([activities.select(clean_cols), legs.select(clean_cols)])
                    .sort(['plan_id', 'seq_index']))

    # Every leg starts a new trip: trip_id = number of legs so far in the plan
    matsim_trips = (matsim_trips
                    .with_columns([
                        (pl.col('element_type') == 'leg').cast(pl.Int8)
                        .cum_sum().over('plan_id').alias('trip_id')
                    ])
                    .join(extra_cols, on=["plan_id", "seq_index"], how="left"))

    # Timing of the previous row (the leg preceding an activity)
    matsim_trips = matsim_trips.with_columns([
        pl.col("dep_time_secs").shift(1).alias("prev_leg_dep_secs"),
        pl.col("trav_time_secs").shift(1).alias("prev_leg_trav_secs"),
    ])

    # Activity duration: max_dur when given, otherwise end_time minus the
    # arrival time of the preceding leg
    matsim_trips = matsim_trips.with_columns([
        pl.when((pl.col("element_type") == "activity") & pl.col("max_dur_secs").is_not_null())
          .then(pl.col("max_dur_secs"))
          .when((pl.col("element_type") == "activity") &
                pl.col("end_time_secs").is_not_null() &
                pl.col("prev_leg_dep_secs").is_not_null() &
                pl.col("prev_leg_trav_secs").is_not_null())
          .then(pl.col("end_time_secs") - (pl.col("prev_leg_dep_secs") + pl.col("prev_leg_trav_secs")))
          .otherwise(None)
          .alias("activity_duration_secs")
    ])

    # Gather "activity_duration" and "travel_time" into a single variable
    matsim_trips = matsim_trips.with_columns([
        pl.when(pl.col("element_type") == "activity")
          .then(pl.col("activity_duration_secs"))
          .when(pl.col("element_type") == "leg")
          .then(pl.col("trav_time_secs"))
          .otherwise(None)
          .alias("duration")
    ])

    # Start times
    matsim_trips = matsim_trips.with_columns([
        pl.when(pl.col("element_type") == "leg")
          .then(pl.col("dep_time_secs"))
          .when((pl.col("element_type") == "activity") & pl.col("end_time_secs").is_not_null())
          .then(pl.col("end_time_secs") - pl.col("duration"))
          .when((pl.col("element_type") == "activity") & pl.col("prev_leg_dep_secs").is_not_null())
          .then(pl.col("prev_leg_dep_secs") + pl.col("prev_leg_trav_secs"))
          .otherwise(None)
          .alias("start_time_secs")
    ])

    # End times, person_id, and column selection/ordering
    matsim_trips = (
        matsim_trips
        .with_columns([
            pl.when(pl.col("element_type") == "leg")
              .then(pl.col("dep_time_secs") + pl.col("duration"))
              .when(pl.col("element_type") == "activity")
              .then(pl.col("start_time_secs") + pl.col("duration"))
              .otherwise(None)
              .alias("end_time_secs")
        ])
        .join(plans.select(['id', 'person_id']), how='left', left_on='plan_id', right_on='id')
        .select(['person_id', "plan_id", "trip_id", "seq_index", "element_type", "type_or_mode",
                 "start_time_secs", "end_time_secs", "duration",
                 "route", "start_link", "end_link"])
    )

    # Tours: activity types that appear with an explicit end_time anchor tours
    tour_anchor_types = (
        activities.filter(pl.col("end_time").is_not_null())
        .select("type").unique().to_series().to_list()
    )
    matsim_trips = matsim_trips.with_columns([
        pl.col("type_or_mode").is_in(tour_anchor_types).alias("is_tour_anchor")
    ])
    matsim_trips = matsim_trips.with_columns([
        pl.col("is_tour_anchor")
          .cast(pl.Int32)
          .cum_sum()
          .over("plan_id")
          .alias("tour_id")
    ])

    # Stopping time of a trip = duration of the activity that follows its leg
    stopping_time_df = (
        matsim_trips
        .filter(pl.col("element_type") == "activity")
        .with_columns([pl.col("duration").alias("stopping_time")])
        .sort(['plan_id', 'trip_id'])
    )

    # Keep legs only. "stopping_time" is NOT selected here: it does not exist
    # yet and is added by the join below (selecting it raised an error).
    matsim_trips = (
        matsim_trips
        .filter(pl.col("element_type") == "leg")
        .rename({"start_time_secs": "start_time",
                 "end_time_secs": "end_time",
                 "type_or_mode": "mode"})
        .with_columns([
            # Travel_time per trip
            (pl.col("end_time") - pl.col("start_time")).alias("duration")
        ])
        .select([
            "person_id", "plan_id", "tour_id", "trip_id", "seq_index", "mode", "start_time", "end_time",
            "duration", "route", "start_link", "end_link"
        ])
        .sort(["plan_id", "trip_id", "tour_id"])
    )
    # Join stopping_time
    matsim_trips = matsim_trips.join(stopping_time_df, on=["plan_id", "trip_id"], how="left")

    return matsim_trips

Define the criteria used to filter invalid observations out of the METROPOLIS input.

In [None]:
def summarize_trips(matsim_trips):
    """Truncate every plan at its first invalid trip.

    A trip is invalid if its ``duration`` exceeds 86400 s (24 h) or its
    ``stopping_time`` is negative. For each plan, the first invalid trip and
    all later trips are removed.

    The helper column and the join-suffixed duplicates produced upstream
    (``*_right``), plus ``seq_index``, are dropped. Only the ones actually
    present are dropped, so a slightly different input schema does not crash.
    """
    # no trips longer than 24 hours nor < 0 activity duration
    is_invalid = (pl.col("duration") > 86400) | (pl.col("stopping_time") < 0)
    invalid_starts = (
        matsim_trips
        .filter(is_invalid)
        .group_by("plan_id")
        .agg(pl.col("trip_id").min().alias("first_invalid_trip"))
    )

    trips_cleaned = (
        matsim_trips
        .join(invalid_starts, on="plan_id", how="left")
        .filter(
            pl.col("first_invalid_trip").is_null()
            | (pl.col("trip_id") < pl.col("first_invalid_trip"))
        )
    )

    unwanted = ("first_invalid_trip", "duration_right", "route_right", "start_link_right",
                "end_link_right", "person_id_right", "tour_id_right", "seq_index")
    return trips_cleaned.drop([c for c in unwanted if c in trips_cleaned.columns])

In [None]:
def generate_trips(matsim_trips, edges, vehicles):
    """Convert cleaned MATSim trips into the METROPOLIS trips table.

    Parameters
    ----------
    matsim_trips : polars.DataFrame
        Trips as produced by ``summarize_trips`` (mode, start_time, duration,
        route, start_link, end_link, stopping_time, plan/tour/trip ids, ...).
    edges : polars.DataFrame
        Edges as produced by ``make_edges_df`` (``edge_id``, ``MATSim_id``,
        ``target``, ...).
    vehicles : polars.DataFrame
        Vehicles as produced by ``make_vehicles_df`` (``vehicle_id``,
        ``vehicle_type``).

    Returns
    -------
    polars.DataFrame
        Columns agent_id, alt_id, trip_id, class.type, class.origin,
        class.destination, class.vehicle, class.route, class.travel_time,
        stopping_time, dt_choice.type and dt_choice.departure_time.
    """
    # MATSim link id -> METROPOLIS edge ids, in edge_id order. read_network can
    # split one MATSim link into two edges sharing the same MATSim id: the
    # original link, re-targeted to a new node, and a zero-length connector
    # appended after all other links. A link may therefore expand to several
    # consecutive edges.
    matsim_to_metro_links = {}
    for matsim_id, edge_id in zip(edges["MATSim_id"].cast(pl.Utf8), edges["edge_id"]):
        matsim_to_metro_links.setdefault(matsim_id, []).append(edge_id)

    def _route_to_edge_ids(link_list):
        # Skip the departure link (the trip starts at its target node) and
        # expand every other MATSim link; unknown links become None.
        if link_list is None:
            return None
        edge_ids = []
        for link in link_list[1:]:
            edge_ids.extend(matsim_to_metro_links.get(link, [None]))
        return edge_ids

    # One row per MATSim id. For a split link, the last edge (the connector)
    # ends at the link's real to-node. This also avoids duplicating trips in
    # the joins below.
    link_targets = (edges
                    .select(["MATSim_id", "target"])
                    .unique(subset=["MATSim_id"], keep="last", maintain_order=True))

    # class.vehicle and class.type
    metro_trips = (
        matsim_trips
        .join(vehicles.select([
            pl.col("vehicle_type").alias("mode"),
            pl.col("vehicle_id").alias("class.vehicle")]), on="mode", how="left")
        .with_columns([
            # Road trips are simulated on the network; everything else is Virtual
            pl.when(pl.col("mode").is_in(['truck', 'car', 'freight', 'ride']))
            .then(pl.lit("Road"))
            .otherwise(pl.lit("Virtual"))
            .alias("class.type")])
    )

    # class.route (Road only) and class.travel_time (non-Road only)
    metro_trips = (
        metro_trips
        .rename({'start_time': 'dt_choice.departure_time'})
        .with_columns([
            pl.when(pl.col("class.type") == "Road")
              .then(pl.col("route").str.split(" ")
                    .map_elements(_route_to_edge_ids, return_dtype=pl.List(pl.Int64)))
              .otherwise(None)
              .alias("class.route"),
            pl.when(pl.col("class.type") == "Road")
              .then(None)
              .otherwise(pl.col("duration"))
              .alias("class.travel_time"),
        ])
        .drop(['person_id', 'route', 'duration', 'end_time', 'mode'])
    )

    # Trips start at the target of the departure link and finish at the
    # target of the arrival link
    metro_trips = (
        metro_trips
        .join(
            link_targets.select([
                pl.col("MATSim_id").alias("start_link"),
                pl.col("target").alias("class.origin")]),
            on="start_link", how="left")
        .drop(pl.col('start_link'))
        .join(
            link_targets.select([
                pl.col("MATSim_id").alias("end_link"),
                pl.col("target").alias("class.destination")]),
            on="end_link", how="left")
        .drop(pl.col('end_link'))
    )

    metro_trips = (
        metro_trips
        .with_columns([
            pl.lit(1).alias("alt_id"),
            pl.lit("Constant").alias("dt_choice.type"),
            # agent_id = digits of plan_id*100 followed by the digits of tour_id
            ((pl.col("plan_id") * 100).cast(pl.Utf8) + pl.col("tour_id").cast(pl.Utf8))
            .cast(pl.Int64).alias("agent_id")])
    )

    # class.type of the next trip of the same agent
    metro_trips = (
        metro_trips
        .with_columns([
            pl.col("class.type")
            .shift(-1)
            .over("agent_id")
            .alias("next_class_type")])
    )

    metro_trips = (
        metro_trips
        .with_columns([
            # +1 for person enters vehicle; +1 for 'vehicle_enters_trafic'
            # when the next trip of the agent is a Road trip
            pl.when(
                pl.col("stopping_time").is_not_null() &
                (pl.col("next_class_type") == "Road")
            )
            .then(pl.col("stopping_time") + 2)
            .otherwise(pl.col("stopping_time"))
            .alias("stopping_time")
        ])
        .select(['agent_id', 'alt_id', 'trip_id',
                 'class.type', 'class.origin', 'class.destination', 'class.vehicle', 'class.route',
                 'class.travel_time', 'stopping_time', 'dt_choice.type', 'dt_choice.departure_time'
                ])
    )

    # Set the minimal walking leg travel-time before a freight trip to 1 to
    # match `output_events`.
    # NOTE(review): vehicle_id 3 is assumed to be the freight type; vehicle ids
    # come from the row order of the vehicle types file — confirm.
    freight_agents = metro_trips.filter(pl.col("class.vehicle") == 3).select("agent_id").unique()

    metro_trips = metro_trips.with_columns([
        pl.when(
            pl.col("agent_id").is_in(freight_agents["agent_id"]) &
            pl.col("class.travel_time").is_not_null()
        )
        .then(pl.max_horizontal([pl.col("class.travel_time"), pl.lit(1)]))  # avoid tt = 0
        .otherwise(pl.col("class.travel_time"))
        .alias("class.travel_time")
    ])

    return metro_trips

# Format Metropolis input

## Supply

In [None]:
def format_supply(edges, vehicles):
    """Remove MATSim bookkeeping columns from the supply tables.

    Returns a two-element list: edges without "MATSim_id" and vehicles
    without "vehicle_type".
    """
    return [
        edges.drop(["MATSim_id"]),
        vehicles.drop(["vehicle_type"]),
    ]

## Demand

In [None]:
def format_demand(trips):
    """Split METROPOLIS trips into the agents, alternatives and trips tables.

    Trips departing after 108000 s (30 h) are removed, as are Road trips whose
    origin or destination is missing.

    Returns
    -------
    tuple of polars.DataFrame
        ``(agents, alts, trips)``:

        * agents: one row per remaining agent_id, sorted by agent_id;
        * alts: one row per agent, taking dt_choice from its earliest-departing
          trip;
        * trips: the remaining trips without the dt_choice columns.
    """
    departs_in_window = pl.col("dt_choice.departure_time") <= 108000
    road_missing_endpoint = (pl.col("class.type") == "Road") & (
        pl.col("class.origin").is_null() | pl.col("class.destination").is_null()
    )
    kept = trips.filter(departs_in_window, ~road_missing_endpoint)

    agents = (
        kept.select("agent_id")
        .unique()
        .with_columns([
            pl.lit("Deterministic").alias("alt_choice.type"),
            pl.lit(0.0).alias("alt_choice.u"),
            pl.lit(None).alias("alt_choice.mu"),
        ])
        .sort("agent_id")
    )

    first_trip_per_agent = (
        kept.sort("dt_choice.departure_time")
        .unique(subset=["agent_id"], keep="first")
    )
    alts = first_trip_per_agent.select([
        "agent_id",
        "alt_id",
        pl.lit(None).alias("origin_delay"),
        pl.col("dt_choice.type"),
        "dt_choice.departure_time",

        pl.lit(None).alias("dt_choice.interval"),
        pl.lit(None).alias("dt_choice.model.type"),
        pl.lit(0.0).alias("dt_choice.model.u"),
        pl.lit(0.0).alias("dt_choice.model.mu"),
        pl.lit(None).alias("dt_choice.offset"),

        pl.lit(0.0).alias("constant_utility"),
        pl.lit(None).alias("total_travel_utility.one"),
        pl.lit(None).alias("total_travel_utility.two"),
        pl.lit(None).alias("total_travel_utility.three"),
        pl.lit(None).alias("total_travel_utility.four"),

        pl.lit(None).alias("origin_utility.type"),
        pl.lit(0.0).alias("origin_utility.tstar"),
        pl.lit(0.0).alias("origin_utility.beta"),
        pl.lit(0.0).alias("origin_utility.gamma"),
        pl.lit(0.0).alias("origin_utility.delta"),

        pl.lit(None).alias("destination_utility.type"),
        pl.lit(0.0).alias("destination_utility.tstar"),
        pl.lit(0.0).alias("destination_utility.beta"),
        pl.lit(0.0).alias("destination_utility.gamma"),
        pl.lit(0.0).alias("destination_utility.delta"),

        pl.lit(True).alias("pre_compute_route"),
    ]).sort("agent_id")

    trip_table = kept.drop(["dt_choice.type", "dt_choice.departure_time"])
    return agents, alts, trip_table

# Write Metropolis Input

In [None]:
# Parameters
print("Writing Metropolis parameters")
# The input directory may not exist yet on a fresh run; create it first.
os.makedirs(METRO_INPUT, exist_ok=True)
with open(os.path.join(METRO_INPUT, "parameters.json"), "w") as f:
    f.write(json.dumps(PARAMETERS))

In [None]:
# Writing files
print("Writing Metropolis supply in ", METRO_INPUT)
# os.path.join (as used for PARAMETERS) does not depend on METRO_INPUT
# ending with a path separator.
edges_df.write_parquet(os.path.join(METRO_INPUT, "edges.parquet"))
vehicles_df.write_parquet(os.path.join(METRO_INPUT, "vehicles.parquet"))

In [None]:
# Formatting
# format_demand does all the filtering and splitting work, so run it once
# and unpack its (agents, alts, trips) result.
agents_df, alts_df, trips_df = format_demand(metro_trips)