# Augment DF with a likely next pickup zone (sampled from the hourly transition matrix)

In [1]:
import pandas as pd
import numpy as np
import os


# NOTE(review): not referenced by the visible cells; presumably the taxi-zone count
# that sizes the transition matrix — confirm against how `ptm` was built.
NUMBER_OF_ZONES: int = 265

In [2]:
# Load the cleaned trip records and the precomputed transition matrix.
df = pd.read_parquet(
    path="all_cleaned_data/all_cleaned_data.parquet",
    engine="fastparquet",
)
ptm = np.load(file="probability_transition_matrix.npy")

In [3]:
# Preview the first rows to sanity-check the loaded schema.
df.head(n=5)


Unnamed: 0,VendorID,PickupDatetime,DropoffDatetime,TripDuration,PassengerCount,TripDistance,PULocationID,DOLocationID,PaymentType,FareAmount,ExtraCharges,MTATax,TipAmount,TollsAmount,ImprovementSurcharge,TotalAmount,CongestionSurcharge,AirportFee
0,2.0,2023-10-01 00:57:33,2023-10-01 01:07:58,10.416667,1.0,1.45,166.0,74.0,1.0,12.1,1.0,0.5,2.92,0.0,1.0,17.52,0.0,0
1,2.0,2023-10-01 01:00:16,2023-10-01 01:06:13,5.95,1.0,0.89,74.0,42.0,2.0,7.9,1.0,0.5,0.0,0.0,1.0,10.4,0.0,0
2,2.0,2023-10-01 00:51:52,2023-10-01 01:00:32,8.666667,1.0,2.38,83.0,129.0,2.0,13.5,1.0,0.5,0.0,0.0,1.0,16.0,0.0,0
3,2.0,2023-10-01 00:03:39,2023-10-01 00:11:20,7.683333,1.0,2.26,74.0,263.0,1.0,11.4,1.0,0.5,3.33,0.0,1.0,19.98,2.75,0
4,2.0,2023-10-01 00:27:42,2023-10-01 00:39:10,11.466667,1.0,2.14,74.0,236.0,1.0,13.5,1.0,0.5,2.81,0.0,1.0,21.559999,2.75,0


In [5]:
from tqdm import tqdm

def _sample_next_pickups(chunk, ptm, dist_ok, sampler):
    """Sample a next-pickup index for each row of one chunk.

    Returns ``(next_pu, ok)``: two arrays aligned positionally with ``chunk``.
    ``ok`` marks rows that could be sampled, and ``next_pu`` is only meaningful
    where ``ok`` is True.
    """
    n_hours, n_zones = dist_ok.shape
    hours = (
        pd.to_datetime(chunk["DropoffDatetime"], errors="coerce")
        .dt.hour.to_numpy(dtype=float, na_value=np.nan)
    )
    zones = (
        pd.to_numeric(chunk["DOLocationID"], errors="coerce")
        .to_numpy(dtype=float, na_value=np.nan)
    )

    ok = np.isfinite(hours) & np.isfinite(zones)
    # astype truncates toward zero like int(); rows already invalid get placeholder 0.
    hour_idx = np.where(ok, hours, 0).astype(np.int64)
    zone_idx = np.where(ok, zones, 0).astype(np.int64)
    # Reject out-of-range indices explicitly; a negative zone would otherwise
    # silently wrap around to the end of the matrix.
    ok &= (hour_idx >= 0) & (hour_idx < n_hours) & (zone_idx >= 0) & (zone_idx < n_zones)
    ok[ok] = dist_ok[hour_idx[ok], zone_idx[ok]]

    next_pu = np.zeros(len(chunk), dtype=np.int64)
    pos = np.flatnonzero(ok)
    if pos.size:
        keys = pd.DataFrame({"hour": hour_idx[pos], "zone": zone_idx[pos]})
        # One draw per (hour, zone) group instead of one np.random call per row.
        for (hour, zone), members in keys.groupby(["hour", "zone"]).indices.items():
            p = ptm[hour, zone]
            # np.isclose accepts sums that choice() rejects (its tolerance is
            # ~1.5e-8), so renormalize the already-validated distribution.
            next_pu[pos[members]] = sampler.choice(p.size, size=members.size, p=p / p.sum())
    return next_pu, ok


def augment_df(df, ptm, chunk_size=100_000, rng=None):
    """Sample a next pickup zone for every trip and drop trips that cannot be sampled.

    The dropoff hour (``DropoffDatetime``) and the dropoff zone
    (``DOLocationID``, truncated to int) select the distribution
    ``ptm[hour, zone]``. A zone index is drawn from it and stored in a new
    ``NextPU`` column. A row is dropped when its hour or zone is missing or
    unparseable, when it falls outside ``ptm``'s bounds, or when the selected
    distribution has negative/NaN entries or a sum not ``np.isclose`` to 1.

    NOTE(review): ``DOLocationID`` is used directly as a 0-based index, and
    ``NextPU`` is a 0-based index into ``ptm``'s last axis. If zone IDs are
    1-based, confirm that ``ptm`` was built with the same convention.

    Parameters
    ----------
    df : pandas.DataFrame
        Trips with ``DropoffDatetime`` and ``DOLocationID`` columns. It is
        modified in place (index reset, ``NextPU`` added, invalid rows
        dropped) and also returned.
    ptm : array-like, shape (n_hours, n_zones, n_next)
        Transition probabilities indexed as ``ptm[hour, dropoff_zone]``.
    chunk_size : int
        Rows processed per progress-bar step.
    rng : None, int, or numpy.random.Generator
        Source of randomness. ``None`` uses the global ``np.random`` state, so
        ``np.random.seed`` applies. Any other value is passed to
        ``np.random.default_rng`` for reproducible sampling.

    Returns
    -------
    pandas.DataFrame
        ``df`` itself, with the ``NextPU`` column and a fresh RangeIndex.

    Raises
    ------
    ValueError
        If ``ptm`` is not 3-dimensional.
    """
    ptm = np.asarray(ptm, dtype=float)
    if ptm.ndim != 3:
        raise ValueError(f"ptm must be 3-D (hour, zone, next_zone), got shape {ptm.shape}")
    # Validate each (hour, zone) distribution once rather than once per trip.
    dist_ok = np.isclose(ptm.sum(axis=-1), 1.0) & (ptm >= 0).all(axis=-1)
    sampler = np.random if rng is None else np.random.default_rng(rng)

    n_rows = len(df)
    next_pu = np.zeros(n_rows, dtype=np.int64)
    keep = np.zeros(n_rows, dtype=bool)

    num_chunks = -(-n_rows // chunk_size)  # ceiling division
    for start in tqdm(range(0, n_rows, chunk_size), total=num_chunks, desc="Processing chunks"):
        stop = min(start + chunk_size, n_rows)
        next_pu[start:stop], keep[start:stop] = _sample_next_pickups(
            df.iloc[start:stop], ptm, dist_ok, sampler
        )

    # Work positionally: label-based .at/.drop would touch every row sharing a
    # duplicated index label. Resetting first makes labels equal positions.
    df.reset_index(drop=True, inplace=True)
    df["NextPU"] = next_pu
    df.drop(index=np.flatnonzero(~keep), inplace=True)
    df.reset_index(drop=True, inplace=True)

    valid_count = int(keep.sum())
    print(f"\n✅ {valid_count} valid entries retained.")
    print(f"🗑️  {n_rows - valid_count} invalid entries removed.")
    return df

In [10]:
SEED = 42  # fixed seed so the sampled NextPU column is reproducible across runs
output_filename = "all_cleaned_data_augmented.parquet"

# augment_df draws from the global np.random state by default, so seed it first.
np.random.seed(SEED)
# augment_df modifies `df` in place and returns it: df and df_augmented are the same object.
df_augmented = augment_df(df, ptm)

df_augmented.to_parquet(output_filename, index=False, engine="fastparquet")



Processing chunks: 100%|██████████████████████| 710/710 [21:08<00:00,  1.79s/it]



✅ 24145198 valid entries retained.
🗑️  46800829 invalid entries removed.


Unnamed: 0,VendorID,PickupDatetime,DropoffDatetime,TripDuration,PassengerCount,TripDistance,PULocationID,DOLocationID,PaymentType,FareAmount,ExtraCharges,MTATax,TipAmount,TollsAmount,ImprovementSurcharge,TotalAmount,CongestionSurcharge,AirportFee,NextPU
36,2.0,2023-10-01 00:12:40,2023-10-01 00:19:55,7.250000,1.0,1.80,80.0,157.0,1.0,10.7,1.0,0.5,2.64,0.0,1.0,15.840000,0.00,0,42
45,2.0,2023-10-01 00:02:10,2023-10-01 00:09:34,7.400000,2.0,0.91,260.0,129.0,1.0,8.6,1.0,0.5,0.00,0.0,1.0,11.100000,0.00,0,42
128,2.0,2023-10-01 02:03:54,2023-10-01 02:23:04,19.166667,1.0,1.86,82.0,173.0,2.0,14.9,1.0,0.5,0.00,0.0,1.0,17.400000,0.00,0,42
141,2.0,2023-10-01 02:24:37,2023-10-01 02:59:18,34.683333,1.0,7.10,80.0,114.0,2.0,60.0,0.0,0.0,0.00,0.0,1.0,63.750000,2.75,0,42
267,2.0,2023-10-01 08:53:38,2023-10-01 09:02:44,9.100000,1.0,1.87,7.0,226.0,1.0,11.4,0.0,0.5,2.10,0.0,1.0,15.000000,0.00,0,42
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1906,2.0,2023-10-03 23:58:43,2023-10-04 00:02:52,4.150000,1.0,1.34,75.0,74.0,1.0,7.9,1.0,0.5,1.00,0.0,1.0,11.400000,0.00,0,42
1973,2.0,2023-10-04 02:00:24,2023-10-04 02:02:30,2.100000,1.0,0.33,74.0,74.0,1.0,4.4,1.0,0.5,0.00,0.0,1.0,6.900000,0.00,0,42
2060,2.0,2023-10-04 07:51:56,2023-10-04 08:11:45,19.816667,1.0,1.34,74.0,75.0,1.0,17.0,0.0,0.5,3.70,0.0,1.0,22.200001,0.00,0,42
2160,1.0,2023-10-04 08:48:15,2023-10-04 09:23:14,34.983333,1.0,1.40,74.0,75.0,1.0,20.5,0.0,1.5,4.40,0.0,1.0,26.400000,0.00,0,42
