In [None]:
def get_date_from_str(str_date):
    """Parse a compact ``YYYYMMDD`` string (e.g. ``"20200131"``) into a ``datetime``."""
    compact_format = "%Y%m%d"
    parsed = datetime.strptime(str_date, compact_format)
    return parsed


def get_bus_info_path(date, list_info_files, base_path=None):
    """Map each requested bus-info file name to its path for ``date``.

    Files are looked up in the directory ``<base_path>/YYYY/MM/DD``.

    Args:
        date: Day to look up (anything with ``strftime``).
        list_info_files: File names to report, e.g. ``["routes.json", ...]``.
        base_path: Root directory; defaults to the module-level
            ``BUSINFO_PATH``.

    Returns:
        dict with one entry per name in ``list_info_files`` (in that order).
        The value is the file path as a string, or ``None`` when that file is
        not present in the day's directory.

    Raises:
        FileNotFoundError: If the day directory itself does not exist.
    """
    if base_path is None:
        base_path = BUSINFO_PATH
    day_dir = pathlib.Path(str(base_path) + date.strftime("/%Y/%m/%d"))

    wanted = set(list_info_files)
    found = {p.name: str(p) for p in day_dir.iterdir() if p.name in wanted}

    # Missing files map to None; indexing ``found[name]`` directly would raise
    # KeyError instead.
    return {name: found.get(name) for name in list_info_files}


def get_prediction_path(date):
    """Return the path of the prediction CSV for ``date`` as a string.

    The file is expected at ``PREDICTION_PATH + "/YYYY_MM_DD.csv"``; ``None``
    is returned when it does not exist.
    """
    candidate = pathlib.Path(PREDICTION_PATH + date.strftime("/%Y_%m_%d.csv"))
    if not candidate.exists():
        return None
    return str(candidate)


def read_dim_data(path, is_list_json=True):
    """Load a JSON-lines bus-info file into a flat DataFrame.

    Args:
        path: Path of the JSON-lines file.
        is_list_json: When True, every line is a JSON array of records. All
            records of all lines are flattened (line order, then position in
            the line) into one table, with nested keys expanded to dotted
            column names by ``pd.json_normalize``. When False, every line is
            a single record and is read as-is.

    Returns:
        pandas.DataFrame
    """
    if not is_list_json:
        return pd.read_json(path, lines=True)

    # The (line, position) grid of records is stacked into a single Series.
    records = pd.read_json(path, lines=True).stack().reset_index(drop=True)
    # Lines shorter than the longest one leave missing cells; depending on the
    # pandas version ``stack`` may keep them, so drop them explicitly.
    records = records.dropna()
    # ``pd.io.json.json_normalize`` was removed in pandas 2; the top-level
    # ``pd.json_normalize`` is the supported entry point.
    return pd.json_normalize(records.tolist())


def expand_column(stops_df, column):
    """Expand a column holding lists of records into one row per record.

    Each cell of ``column`` is turned into a DataFrame (one row per record).
    Those rows are joined back onto their parent row, so parent columns are
    repeated. ``column`` itself is kept. A parent whose list is empty keeps a
    single row with NaN in the expanded columns.

    Args:
        stops_df: Input frame; its index is ignored.
        column: Name of the column with list-of-record cells.

    Returns:
        pandas.DataFrame with a fresh RangeIndex.
    """
    flat = stops_df.reset_index(drop=True)
    if flat.empty:
        # pd.concat raises on an empty list of frames.
        return flat

    # Key each expanded block by its parent's *position*. Using the original
    # index labels as keys would misalign the join below whenever the input
    # index is not the default 0..n-1 RangeIndex.
    expanded = pd.concat(
        [pd.DataFrame(records) for records in flat[column]],
        keys=list(range(len(flat))),
    ).reset_index(level=1, drop=True)

    return flat.join(expanded).reset_index(drop=True)


def read_paths(path):
    """Read the paths file into one row per path point.

    Rows are indexed by ``RouteId``/``RouteVarId``. Every other column is
    exploded, so list-valued cells (presumably the ``lat``/``lng`` point
    lists — TODO confirm against the file schema) become one row per element.

    Returns:
        pandas.DataFrame with float ``lat`` and ``lng`` columns.
    """
    paths_df = (
        read_dim_data(path, False)
        .set_index(["RouteId", "RouteVarId"])
        .apply(pd.Series.explode)
        .reset_index()
    )

    coord_cols = ["lat", "lng"]
    # Plain column assignment replaces the columns with float dtype. With
    # ``.loc[:, cols] = ...`` pandas >= 2 writes the values in place and keeps
    # the object dtype left over from ``explode``.
    paths_df[coord_cols] = paths_df[coord_cols].astype(float)

    return paths_df


def read_stops(path):
    """Read the stops file and expand its nested ``Stops`` lists into rows.

    Each JSON line is one record; its ``Stops`` column holds the stop records,
    which ``expand_column`` turns into one row per stop.
    """
    raw = read_dim_data(path, False)
    return expand_column(raw, "Stops")


def copy_data(data):
    """Return a new dict with a ``.copy()`` of every value in ``data``."""
    copied = {}
    for name, frame in data.items():
        copied[name] = frame.copy()
    return copied


def upper_columns_names(data):
    """Upper-case the column names of every DataFrame in ``data``, in place."""
    for frame in data.values():
        frame.columns = frame.columns.str.upper()


def load_predictions(path, column_headers):
    """Read a header-less prediction CSV with every column as a string.

    ``column_headers`` names the columns; the ``"UNAMED"`` column is dropped.
    The path is printed before reading.
    """
    print(path)
    frame = pd.read_csv(path, dtype="str", names=column_headers)
    return frame.drop(columns="UNAMED")


def preprocessing_prediction(predic):
    """Return a copy of ``predic`` with id and measure columns converted.

    ``STOPID``, ``ROUTEVARID`` and ``ROUTEID`` go through float to int (so
    strings like ``"1.0"`` are accepted). ``DISTANCE``, ``SPEED`` and
    ``TIMETOSTOP`` become float. The input frame is not modified.
    """
    converted = predic.copy()
    id_cols = ["STOPID", "ROUTEVARID", "ROUTEID"]
    measure_cols = ["DISTANCE", "SPEED", "TIMETOSTOP"]

    as_ints = converted.loc[:, id_cols].astype("float").astype("int")
    converted[id_cols] = as_ints
    as_floats = converted.loc[:, measure_cols].astype("float")
    converted[measure_cols] = as_floats

    return converted


def get_stops_full_data(all_data):
    """Add merged route/variant/stop tables to ``all_data`` in place.

    * ``"full_vars"``: ``vars.json`` left-joined with ``ROUTENAME`` from
      ``routes.json`` on ``ROUTEID``.
    * ``"full_stops"``: ``stops.json`` left-joined with ``full_vars`` on
      (``ROUTEID``, ``ROUTEVARID``), without the ``STOPS`` column and with
      duplicate rows removed.

    Returns None.
    """
    route_names = all_data["routes.json"][["ROUTEID", "ROUTENAME"]]
    full_vars = all_data["vars.json"].merge(
        route_names, on="ROUTEID", how="left", suffixes=("", "_y")
    )
    all_data["full_vars"] = full_vars

    with_vars = all_data["stops.json"].merge(
        full_vars, on=["ROUTEID", "ROUTEVARID"], how="left"
    )
    all_data["full_stops"] = with_vars.drop("STOPS", axis=1).drop_duplicates()
    

def get_full_prediction_data(predic, stop_data):
    """Left-join stop reference data onto every prediction row.

    Rows are matched on (``STOPID``, ``ROUTEID``, ``ROUTEVARID``). Columns that
    exist in both frames keep the prediction value under the plain name and
    the stop value under ``<name>_DIM``. ``predic`` is not modified.
    """
    join_keys = ["STOPID", "ROUTEID", "ROUTEVARID"]
    return predic.merge(stop_data, on=join_keys, how="left", suffixes=("", "_DIM"))

# Loader for each bus-info file name (read-only mapping). Files whose lines
# hold lists of records use ``read_dim_data``; stops and paths have dedicated
# readers that expand their nested data. Every loader is called as
# ``loader(path=...)``.
BUS_INFO_FUNC = MappingProxyType({
    **{
        file_name: functools.partial(read_dim_data, is_list_json=True)
        for file_name in ("timetables.json", "routes.json", "vars.json", "trips.json")
    },
    "stops.json": read_stops,
    "paths.json": read_paths,
})


def get_all_bus_info(date, list_info_files):
    """Load each requested bus-info file for ``date`` with its loader.

    Args:
        date: Day to load.
        list_info_files: File names; each must be a key of ``BUS_INFO_FUNC``.

    Returns:
        dict mapping file name to the loaded DataFrame. Files that
        ``get_bus_info_path`` reports as missing (``None`` path) are skipped
        with a message instead of being handed to a reader as a ``None`` path.
    """
    bus_data = {}
    for name, file_path in get_bus_info_path(date, list_info_files).items():
        if file_path is None:
            print(f"Skip {name}: file not found for {date}")
            continue
        print(f"Get data from {file_path}")
        bus_data[name] = BUS_INFO_FUNC[name](path=file_path)

    return bus_data


def remove_missing_records_on_column(data, cols):
    """Return a copy of ``data`` with only the rows that have no NaN in ``cols``.

    ``cols`` is a list of column names; the original index labels are kept.
    """
    complete_rows = data[cols].notna().all(axis=1)
    return data[complete_rows].copy()


def get_distance_ratio_each_route(stops_ddf):
    """Compute each stop's relative position along its route variant.

    For each (ROUTEID, ROUTEVARID) group:

    * the route ``LINESTRING`` is copied, with its first vertex moved onto the
      stop with ``RANK == 1`` and its last vertex onto the stop with
      ``RANK == MAXRANK``;
    * each stop point ``SPOINT`` is projected onto that adjusted line.

    Added columns: ``MAXRANK``, ``SPOINTFSPOINT`` (first stop point),
    ``SPOINTLSPOINT`` (last stop point), ``STOPS_LINESTRING``,
    ``STOPSDISTANCE``, ``ALLSTOPSDISTANCE`` (projection of the last stop) and
    ``RATIOSTOPSDISTANCE = STOPSDISTANCE / ALLSTOPSDISTANCE``.

    Expects ``RANK``, ``SPOINT`` and ``LINESTRING`` columns holding shapely
    geometries (TODO confirm every caller provides them).
    """
    route_keys = ["ROUTEID", "ROUTEVARID"]

    def _attach_stop_point(frame, row_mask, suffix):
        # Broadcast the SPOINT of the selected stop to every row of its route.
        picked = frame.loc[row_mask, route_keys + ["SPOINT"]].copy()
        return frame.merge(picked, on=route_keys, how="left", suffixes=("", suffix))

    def _snap_line_to_end_stops(row):
        vertices = list(row["LINESTRING"].coords)
        vertices[0] = row["SPOINTFSPOINT"].coords[0]
        vertices[-1] = row["SPOINTLSPOINT"].coords[0]
        return LineString(vertices)

    result = stops_ddf.copy()
    max_rank = result.groupby(route_keys).RANK.max().to_frame("MAXRANK")
    result = result.merge(max_rank, on=route_keys, how="left")

    result = _attach_stop_point(result, result.RANK == 1, "FSPOINT")
    result = _attach_stop_point(result, result.RANK == result.MAXRANK, "LSPOINT")

    result["STOPS_LINESTRING"] = result.apply(_snap_line_to_end_stops, axis=1)
    result["STOPSDISTANCE"] = result.apply(
        lambda row: row["STOPS_LINESTRING"].project(row["SPOINT"]), axis=1
    )
    result["ALLSTOPSDISTANCE"] = result.apply(
        lambda row: row["STOPS_LINESTRING"].project(row["SPOINTLSPOINT"]), axis=1
    )
    result["RATIOSTOPSDISTANCE"] = result.eval("STOPSDISTANCE / ALLSTOPSDISTANCE")

    return result


def get_log_distance_ratio_predict(predic_df):
    """Add derived time, id and distance-ratio columns to the prediction log.

    ``predic_df`` is modified in place and also returned.

    Columns written:
      * ``TIMESTAMP``: ``LOCTIMESTAMP`` as integer seconds since the Unix epoch
        (integers are easier to interpolate over).
      * ``BUSIDSTR``: ``BUSID`` with surrounding whitespace stripped.
      * ``LOCTIMESTAMP``: parsed to datetime.
      * ``STOPIDSTR``: ``STOPID`` converted to string.
      * ``DISTANCERATIO``: ``DISTANCE / DISTANCE_DIM``.
    """
    # Parse once; the result feeds both TIMESTAMP and LOCTIMESTAMP.
    loc_time = pd.to_datetime(predic_df["LOCTIMESTAMP"])

    # Dividing a timedelta by one second does not assume nanosecond storage,
    # unlike ``astype("int64") // 10**9``. For tz-aware input the epoch is
    # tz-aware too, so the difference is still the true elapsed time.
    epoch = pd.Timestamp(0, tz=loc_time.dt.tz)
    predic_df["TIMESTAMP"] = (loc_time - epoch) // pd.Timedelta(seconds=1)

    predic_df["BUSIDSTR"] = predic_df["BUSID"].str.strip()
    predic_df["LOCTIMESTAMP"] = loc_time
    predic_df["STOPIDSTR"] = predic_df["STOPID"].astype("str")
    predic_df["DISTANCERATIO"] = predic_df["DISTANCE"] / predic_df["DISTANCE_DIM"]

    return predic_df


def remove_outlier_whole_trips(df, plot=False):
    """Drop velocity outliers from one trip, split by its ``FIRSTHALF`` label.

    * Rows with ``FIRSTHALF == 1`` are checked with reversed velocity.
    * Rows with ``FIRSTHALF == 2`` are checked with forward velocity.
    * Rows with ``FIRSTHALF == 99`` are kept untouched.

    Rows with any other label are dropped. Checked parts gain a ``VREMOVE``
    column (always 0/False in the output). The result concatenates the first,
    center and second parts in that order. With ``plot`` the before/after
    scatter plots of each checked part are shown.
    """
    def _drop_velocity_outliers(part, reverse):
        part["VREMOVE"] = mark_all_velocity_outlier(
            part["DISTANCE"], part["TIMESTAMP"], is_reverse=reverse, threshold=1
        )
        if plot:
            # Detected outliers highlighted.
            plt.subplots(figsize=(10, 8))
            _draw_time_distance(part, hue="VREMOVE", alpha=0.8)
            plt.show()
            # Remaining points after removal.
            plt.subplots(figsize=(10, 8))
            _draw_time_distance(part[part.VREMOVE == 0], alpha=0.8)
            plt.show()
        return part[part.VREMOVE == 0].copy()

    first_part = _drop_velocity_outliers(df[df.FIRSTHALF == 1].copy(), True)
    center_part = df[df.FIRSTHALF == 99].copy()
    second_part = _drop_velocity_outliers(df[df.FIRSTHALF == 2].copy(), False)

    return pd.concat([first_part, center_part, second_part], axis=0, sort=False)

def remove_outliers_whole_route(df, plot=False):
    """Run ``remove_outlier_whole_trips`` on every ``TRIPS`` group and stack them.

    ``TRIPS`` values must be strings (they are joined for the progress message).
    """
    trip_ids = df.TRIPS.unique()
    print("There are {} trips: {}".format(trip_ids.shape[0], ", ".join(trip_ids)))

    cleaned_trips = []
    for trip_id in trip_ids:
        print(f"Remove outliers for trips: {trip_id}")
        trip_rows = df[df.TRIPS == trip_id].copy()
        cleaned_trips.append(remove_outlier_whole_trips(trip_rows, plot))

    return pd.concat(cleaned_trips, sort=False)

def remove_distance_not_increase(ds, n_iters=500):
    """Flag points where the distance does not drop below the previous kept point.

    Each pass flags every point ``i`` with ``ds[i] >= ds[i-1]`` (comparing
    only among points not yet flagged) and removes it. Removing points exposes
    new neighbours, so passes repeat until none is flagged or ``n_iters`` is
    reached.

    Returns:
        Boolean Series on ``ds.index``; True marks a removed point.
    """
    remaining = ds.copy()
    removed = pd.Series(False, index=remaining.index)

    for _ in range(n_iters):
        offenders = (remaining.shift() - remaining) <= 0
        if not offenders.any():
            print("There is no decrease distance")
            break
        removed |= offenders
        remaining = remaining[~removed].copy()

    return removed.copy()

def interpolate_time(predic_df, stops_df, plot=False):
    """Estimate the time each stop is reached during one trip.

    Points whose ``DISTANCE`` does not decrease are dropped first. A linear
    (k=1) spline then maps route progress ``1 - DISTANCERATIO`` to
    ``TIMESTAMP``, and it is evaluated at each stop's ``RATIOSTOPSDISTANCE``.

    NOTE(review): ``InterpolatedUnivariateSpline`` requires increasing x;
    this presumably relies on the decreasing-distance filter above — confirm.

    Returns:
        numpy array of interpolated timestamps, one per row of ``stops_df``.
    """
    keep = ~remove_distance_not_increase(predic_df.DISTANCE.copy())
    trip = predic_df[keep].copy()

    progress = 1 - trip["DISTANCERATIO"]
    stamps = trip["TIMESTAMP"].astype("int")
    spline = InterpolatedUnivariateSpline(progress, stamps, k=1)

    stop_ratios = stops_df["RATIOSTOPSDISTANCE"]
    if plot:
        plt.subplots(figsize=(10, 8))
        plt.plot(stop_ratios, spline(stop_ratios), 'gv')
        plt.plot(progress, stamps, 'ro', ms=5, alpha=0.2, linewidth=6)
        plt.show()

    return spline(stop_ratios)

def predic_whole_route(route_data, stops, plot, cols=None, pref="ARRIVETIMETRIPS"):
    """Interpolate arrival times at every stop for each trip in ``route_data``.

    One column ``f"{pref}{i:02d}"`` is added per trip, in order of first
    appearance in ``TRIPS``. Every column whose name contains ``pref`` (regex
    match) is then converted with ``convert_timezone``.

    Returns:
        A copy of ``stops``, restricted to ``cols`` plus the ``pref`` columns
        when ``cols`` is truthy, otherwise with all columns.
    """
    stops = stops.copy()

    arrival_curves = []
    for trip_id in route_data.TRIPS.unique():
        print(f"Predict for trips {trip_id}")
        trip_rows = route_data[route_data.TRIPS == trip_id].copy()
        arrival_curves.append(interpolate_time(trip_rows, stops, plot))

    for position, curve in enumerate(arrival_curves):
        stops[pref + "{:02d}".format(position)] = curve

    arrival_cols = stops.columns[stops.columns.str.contains(pref)].tolist()
    for col in arrival_cols:
        print(f"Convert column name {col}")
        stops[col] = convert_timezone(stops[col])

    return stops[cols + arrival_cols].copy() if cols else stops.copy()


def predict_whole_route_all_bus(predic,
                                stops,
                                route_id,
                                routevar_id,
                                cols,
                                plot=(False, False),
                                ratio_distance=0.8):
    """Predict stop arrival timetables of every bus on one route variant.

    Args:
        predic: Prediction log for all routes.
        stops: Ranked stops for all routes.
        route_id, routevar_id: Route variant to process.
        cols: Stop columns kept in each timetable (see ``predic_whole_route``).
        plot: ``(outlier_plots, interpolation_plots)`` flags.
        ratio_distance: Peak ratio forwarded to ``preprocessing_route_bus``.

    Returns:
        All per-bus timetables stacked, each tagged with ``BUSPREDIC``. A bus
        that fails is reported and skipped (best effort). If no bus yields a
        timetable, an empty DataFrame is returned instead of raising.
    """
    stops_df = stops[(stops.ROUTEID == route_id) & (stops.ROUTEVARID == routevar_id)].copy()
    predic_df = predic[(predic.ROUTEID == route_id) & (predic.ROUTEVARID == routevar_id)].copy()

    time_tables = []
    for bus_id in predic_df.BUSID.unique():
        try:
            bus_data = predic_df[predic_df.BUSID == bus_id].copy()
            print(f"Predic for bus {bus_id} with shape {bus_data.shape}")
            dt_removed = preprocessing_route_bus(bus_data, ratio_distance)
            dt_removed = remove_outliers_whole_route(dt_removed, plot[0])
            print(f"Shape after preprocessing {dt_removed.shape}")

            predicted = predic_whole_route(dt_removed, stops_df, plot[1], cols=cols)
        except Exception as e:
            # Best effort: a bus with unusable data must not abort the route.
            print(f"Skip bus {bus_id}: {e}")
            continue

        predicted["BUSPREDIC"] = bus_id
        time_tables.append(predicted.copy())

    if not time_tables:
        # pd.concat raises on an empty list; report "no timetable" instead.
        return pd.DataFrame()
    return pd.concat(time_tables, axis=0, sort=True, ignore_index=True)


def mark_first_half(df, rate=(0.1, 0.9)):
    """Label each row's position on the route in a float ``FIRSTHALF`` column (in place).

    With ``top = df.DISTANCE_DIM.max()``:
    * 1 when ``DISTANCE > top * rate[1]``;
    * 2 when ``DISTANCE < top * rate[0]`` (1 wins if both hold);
    * 99 otherwise, including NaN distances.
    """
    top = df.DISTANCE_DIM.max()
    upper_band = df.DISTANCE > (top * rate[1])
    lower_band = df.DISTANCE < (top * rate[0])
    df["FIRSTHALF"] = np.where(upper_band, 1.0, np.where(lower_band, 2.0, 99.0))


def preprocessing_route_bus(df, ratio_distance_peak=0.8):
    """Prepare one bus's log for trip splitting.

    Steps:
    * keep the first row per ``TIMESTAMP`` and sort by time;
    * flag distance jumps as ``IS_PEAK``;
    * label trips in ``TRIPS``, dropping rows that belong to no trip;
    * add ``FIRSTHALF`` via ``mark_first_half``.
    """
    cleaned = (
        df.drop_duplicates(["TIMESTAMP"], keep="first")
        .sort_values(["TIMESTAMP"])
        .copy()
    )
    cleaned["IS_PEAK"] = mark_peak_point(cleaned, cleaned.DISTANCE.max(), ratio_distance_peak)
    cleaned["TRIPS"] = split_trips(cleaned, xffix=("TRIPS", ""))
    cleaned = cleaned.dropna(subset=["TRIPS"]).copy()
    mark_first_half(cleaned)

    return cleaned


def mark_peak_point(df, distance, ratio=1):
    """Flag rows where ``DISTANCE`` rises from the previous row by more than ``distance * ratio``.

    The first row is never flagged. Returns a boolean Series named ``DISTANCE``.
    """
    distances = df["DISTANCE"]
    threshold = distance * ratio
    step_up = distances - distances.shift()
    return step_up > threshold


# Split the log into trips.
def split_trips(df, cols="IS_PEAK", xffix=("", "")):
    """Assign a trip label to each row from the running count of peak flags.

    Rows between two peaks form a segment. Segments with fewer rows than half
    the mean segment size are dropped (their rows are absent from the
    result). The remaining segments are renumbered 0, 1, ... and formatted as
    ``xffix[0] + zero-padded number + xffix[1]``.

    Returns:
        String Series named ``PEAK`` on the kept rows' index.
    """
    peak_ids = df[cols].cumsum().copy().astype("int").to_frame("PEAK")

    segment_sizes = peak_ids.groupby("PEAK").size()
    min_rows = segment_sizes.mean() * 0.5
    too_small = segment_sizes.index[segment_sizes < min_rows]

    kept = peak_ids.loc[~peak_ids.PEAK.isin(too_small), "PEAK"]
    trip_no = (kept.rank(method="dense") - 1).astype("int").astype("str")

    return xffix[0] + trip_no.str.zfill(2) + xffix[1]


def get_velocity(dfd, dft, is_reverse=False):
    """Per-point velocity from distance ``dfd`` and time ``dft`` Series.

    Forward:  ``v[i] = (d[i-1] - d[i]) / |t[i] - t[i-1]|``; first value NaN.
    Reverse:  ``v[i] = (d[i] - d[i+1]) / |t[i] - t[i+1]|``; last value NaN.
    The result is always in the original order of ``dfd``.
    """
    if is_reverse:
        dist, times = dfd[::-1].copy(), dft[::-1].copy()
        step = dist - dist.shift()
    else:
        dist, times = dfd.copy(), dft.copy()
        step = dist.shift() - dist

    elapsed = (times - times.shift()).abs()
    speed = step / elapsed

    if is_reverse:
        return speed[::-1]
    return speed


def mark_velocity_outlier(v_col, threshold=1.5):
    """Flag velocities strictly below ``threshold`` (NaN is never flagged).

    Returns an unnamed boolean Series on ``v_col.index``.
    """
    below = (v_col < threshold).to_numpy(dtype=bool)
    return pd.Series(below, index=v_col.index)


def mark_all_velocity_outlier(dfd, dft, threshold=1, is_reverse=False, n_iters=5):
    """Iteratively flag points whose velocity is below ``threshold``.

    Each round recomputes velocities (``get_velocity``) on the points still
    kept, flags the slow ones (``mark_velocity_outlier``) and removes them.
    The loop stops early when a round flags nothing, and runs at most
    ``n_iters`` rounds.

    Returns:
        Boolean Series on ``dfd``'s original index; True marks an outlier.
    """
    flagged = pd.Series(False, index=dfd.index.copy())

    print(f"Starting to remove outlier with velocity threshold {threshold} "
          f"and max number of iterates is {n_iters} times")
    for _ in range(n_iters):
        kept = ~flagged
        speed = get_velocity(dfd[kept], dft[kept], is_reverse)
        new_outliers = mark_velocity_outlier(speed, threshold)
        count = new_outliers.sum()
        print(f"Detected {count} outlier.")

        if count <= 0:
            print("There is no outliers remain!")
            break

        dfd = dfd[~new_outliers].copy()
        dft = dft[~new_outliers].copy()
        flagged |= new_outliers

    return flagged


def _draw_time_distance(df, cols=("TIMESTAMP", "DISTANCE"), hue=None, alpha=0.2, size=None):
    """Seaborn scatter plot of ``cols[1]`` against ``cols[0]``.

    ``size`` is accepted for compatibility but not used.
    """
    x_col, y_col = cols[0], cols[1]
    sns.scatterplot(data=df, x=x_col, y=y_col, hue=hue, alpha=alpha)
    

def convert_timezone(sr):
    """Convert Unix-epoch seconds to naive Asia/Ho_Chi_Minh local datetimes."""
    as_utc = pd.to_datetime(sr, unit="s").dt.tz_localize("UTC")
    as_local = as_utc.dt.tz_convert("Asia/Ho_Chi_Minh")
    return as_local.dt.tz_localize(None)


def load_data(date):
    """Load the bus reference data and the prediction log for one day.

    Returns:
        all_data_cp: Bus-info DataFrames keyed by file name (columns
            upper-cased), plus merged ``"full_vars"`` and ``"full_stops"``.
        full_predic: Prediction log joined with stop data. Rows without a
            stop ``CODE`` are dropped, and timestamp/ratio columns are added.

    Raises:
        FileNotFoundError: If there is no prediction file for ``date``.
    """
    # Resolve the prediction file first so a missing log fails fast, before
    # any bus-info file is read (read_csv(None) would fail obscurely later).
    prediction_path = get_prediction_path(date)
    if prediction_path is None:
        raise FileNotFoundError(f"No prediction file found for date {date:%Y-%m-%d}")

    # Load the bus route reference data.
    print("Get bus info.")
    all_data_cp = get_all_bus_info(
        date,
        list_info_files=list_bus_info_file_names
    )
    # Upper-case all column names; the merges below use upper-case keys.
    upper_columns_names(all_data_cp)
    # Merge information from several files into "full_vars" / "full_stops".
    get_stops_full_data(all_data_cp)

    # Read the predictions file.
    print("Get log data.")
    prediction_df = load_predictions(prediction_path, PREDICTION_COLUMN_HEADERS)
    # Convert some columns from str to int/float.
    prediction_df = preprocessing_prediction(prediction_df)
    # Attach stop information to each prediction row.
    full_predic = get_full_prediction_data(prediction_df, all_data_cp["full_stops"])
    full_predic = remove_missing_records_on_column(full_predic, ["CODE"])
    full_predic = get_log_distance_ratio_predict(full_predic)

    return all_data_cp, full_predic


def get_linestring_all_path(path):
    """Build one ``LineString`` per (ROUTEID, ROUTEVARID) from path points.

    Rows with any NaN are dropped first. Points are ``Point(LNG, LAT)`` in
    row order within each group.

    Returns:
        DataFrame with ``ROUTEID``, ``ROUTEVARID`` and ``LINESTRING``.
    """
    points = path.dropna().copy()
    points["POINT"] = points.apply(lambda row: Point(row["LNG"], row["LAT"]), axis=1)

    point_lists = points.groupby(["ROUTEID", "ROUTEVARID"]).POINT.apply(list)
    lines = point_lists.apply(LineString).to_frame("LINESTRING")
    return lines.reset_index()


def get_data_processed(date):
    """Load one day's data and order every route's stops along its path.

    Stops without coordinates, or without a path for their route variant,
    are dropped. The columns computed for each stop are:

    * ``SPOINT``: the stop point;
    * ``DISTANCE``: its projection onto the route line;
    * ``NEARESTPOINT``: the closest point on the route line;
    * ``RANK``: its order along the route.

    ``get_distance_ratio_each_route`` then adds the distance ratios.

    Returns:
        (stops_ddf, full_predic): ranked stops without the ``STOPS`` column,
        and the prediction log from ``load_data``.
    """
    all_data_cp, full_predic = load_data(date)

    # Rank the stops of each route.
    print("Ranking stops sequences.")
    route_keys = ["ROUTEID", "ROUTEVARID"]
    route_lines = get_linestring_all_path(all_data_cp["paths.json"])

    stops = all_data_cp["stops.json"].copy().dropna(subset=["LNG", "LAT"]).copy()
    stops = stops.merge(route_lines, on=route_keys, how="inner", suffixes=("", "LINEINFO"))

    stops["SPOINT"] = stops.apply(lambda row: Point(row["LNG"], row["LAT"]), axis=1)
    stops["DISTANCE"] = stops.apply(
        lambda row: row["LINESTRING"].project(row["SPOINT"]), axis=1
    )
    stops["NEARESTPOINT"] = stops.apply(
        lambda row: nearest_points(row["LINESTRING"], row["SPOINT"])[0], axis=1
    )
    stops["RANK"] = stops.groupby(route_keys).DISTANCE.rank(method="first")

    stops = stops.sort_values(route_keys + ["RANK"])
    print("Get Ratio distance for each stops.")
    stops = get_distance_ratio_each_route(stops).drop("STOPS", axis=1)

    return stops, full_predic


def generate_date_from_range(from_date, to_date):
    """Lazily yield ``from_date``, ``from_date + 1 day``, ... while ``<= to_date``."""
    one_day = timedelta(days=1)
    day = from_date
    while day <= to_date:
        yield day
        day = day + one_day

        
def run_predict_all_route(pre, stop, list_route, plot=(False, False)):
    
    result = []
    for lr in list_route:
        print("Running for ROUTEID {} and ROUTEVARID {}".format(
            lr[0],
            lr[1],
        ))
        
        time_tables = predict_whole_route_all_bus(
            pre, 
            stop, 
            route_id=lr[0], 
            routevar_id=lr[1], 
            plot=plot,
            cols=cols_to_get, 
            ratio_distance=0.8
        )

        result.append(time_tables)
    
    return pd.concat(result, axis=0, sort=True, ignore_index=True)
         

def run_pipeline_multi_date(date_list, route_id=None, route_var_id=None, plot=(False, False)):
    
    result = {}
    for d in date_list:
        date_str = d.strftime("%Y-%m-%d")
        stop, pre = get_data_processed(d)
        
        if route_id and route_var_id:
            list_route = [(route_id, route_var_id)]
        else:
            list_route = (
                pre[["ROUTEID", "ROUTEVARID"]]
                    .drop_duplicates()
                    .values.tolist()
            )
        
        cols_to_get = ["ROUTEID", "ROUTEVARID", 
                       "STOPID", "CODE", "NAME", "STOPTYPE", 
                       "ZONE", "WARD", "ADDRESSNO", "STREET", "SUPPORTDISABILITY", 
                       "STATUS", "LNG", "LAT", 
                       "SEARCH", "ROUTES", "RANK"]
        print("Predict for date {}".format(date_str))
        
        time_tables = run_predict_all_route(
            pre,
            stop,
            list_route,
            plot,
        )

        result[date_str] = time_tables
        
    return result


def plot_route_stops(df, figsize=(20, 20), show_path=(True, True), show_real_stops=True, show_nearest_point=True,
                     annotation=False):
    """Plot one route's path and stops on a single matplotlib figure.

    ``df`` holds the stops of one route; the path is taken from the
    ``LINESTRING`` of its first row.

    Args:
        figsize: Figure size passed to ``plt.subplots``.
        show_path: ``(draw_line, draw_vertices)``; vertices are drawn in blue.
        show_real_stops: Draw the stop points ``SPOINT`` in red.
        show_nearest_point: Draw ``NEARESTPOINT`` as green triangles.
        annotation: Label each stop with its ``RANK`` at (``LNG``, ``LAT``).
    """
    geo_df = gpd.GeoDataFrame(df, geometry="SPOINT")
    fig, ax = plt.subplots(figsize=figsize)
    ax.grid(False)

    if any(show_path):
        route_line = df.LINESTRING.values[0]
        if show_path[0]:
            gpd.GeoSeries([route_line]).plot(ax=ax, figure=fig)
        if show_path[1]:
            # Vertices of this route's own line (the old code read them from
            # an unrelated notebook variable).
            gpd.GeoSeries([Point(c) for c in route_line.coords]).plot(
                ax=ax,
                figure=fig,
                color="b"
            )

    if annotation:
        for _, row in geo_df.iterrows():
            # Text passed positionally: the ``s=`` keyword was renamed
            # ``text`` in Matplotlib 3.3.
            plt.annotate(row["RANK"], xy=(row["LNG"], row["LAT"]), horizontalalignment="right")

    # Use the GeoDataFrame: a plain pandas DataFrame has no ``set_geometry``.
    if show_real_stops:
        geo_df.set_geometry("SPOINT").plot(ax=ax, figure=fig, color="r", alpha=0.8)

    if show_nearest_point:
        geo_df.set_geometry("NEARESTPOINT").plot(ax=ax, figure=fig, color="g", alpha=1, marker="v")
    
    
def _draw(route_id, route_var_id, date):
    """Title the current figure with the route ids and ``YYYYMMDD`` date, then show it."""
    title = "ROUTEID={}, ROUTEVARID={}, DATE={}".format(
        route_id, route_var_id, date.strftime("%Y%m%d")
    )
    plt.title(title)
    plt.show()

    
def plot_kde_stops(example_stop_list, stops_df, save_folder, plot=False):
    """Save KDE, violin and box plots of positive ``DELTADIS`` for each stop.

    For each stop id three SVGs are written to ``save_folder``:
    ``ST_<id>_KDE.svg``, ``ST_<id>_VIOLIN.svg`` and ``ST_<id>_BOXPLOT.svg``.
    Each figure is also shown when ``plot`` is True, and all figures are
    closed after each stop.
    """
    def _save_and_maybe_show(stop_tag, kind):
        plt.savefig(fname=f"{save_folder}/{stop_tag}_{kind}.svg", format="svg", dpi=1200)
        if plot:
            plt.show()

    for stop in example_stop_list:
        print(f"Save fig/Show for stop: {stop}")
        stop_tag = "ST_" + str(stop)
        subset = stops_df[stops_df.STOPID == stop].copy()
        subset["STOPCODE"] = "ST: " + subset.STOPID.astype("str")
        subset = subset.query("DELTADIS > 0").copy()

        # KDE of DELTADIS, one curve per STOPCODE.
        grid = sns.FacetGrid(subset, hue="STOPCODE", height=10)
        grid.map(sns.kdeplot, "DELTADIS")
        grid.add_legend()
        _save_and_maybe_show(stop_tag, "KDE")

        # Violin plot, then box plot, of DELTADIS per STOPCODE.
        for kind, draw in (("VIOLIN", sns.violinplot), ("BOXPLOT", sns.boxplot)):
            _, ax = plt.subplots(figsize=(16, 8))
            draw(y="DELTADIS", x="STOPCODE", ax=ax, data=subset)
            _save_and_maybe_show(stop_tag, kind)

        plt.close("all")