# Augment DF with most likely next pickup

In [1]:
import pandas as pd
import numpy as np
import os


NUMBER_OF_ZONES = 265

In [2]:
df = pd.read_parquet("all_cleaned_data/all_cleaned_data.parquet", engine='fastparquet')
ptm = np.load("probability_transition_matrix.npy")

In [3]:
df.head()


Unnamed: 0,VendorID,PickupDatetime,DropoffDatetime,TripDuration,PassengerCount,TripDistance,PULocationID,DOLocationID,PaymentType,FareAmount,ExtraCharges,MTATax,TipAmount,TollsAmount,ImprovementSurcharge,TotalAmount,CongestionSurcharge,AirportFee
0,2.0,2023-10-01 00:57:33,2023-10-01 01:07:58,10.416667,1.0,1.45,166.0,74.0,1.0,12.1,1.0,0.5,2.92,0.0,1.0,17.52,0.0,0
1,2.0,2023-10-01 01:00:16,2023-10-01 01:06:13,5.95,1.0,0.89,74.0,42.0,2.0,7.9,1.0,0.5,0.0,0.0,1.0,10.4,0.0,0
2,2.0,2023-10-01 00:51:52,2023-10-01 01:00:32,8.666667,1.0,2.38,83.0,129.0,2.0,13.5,1.0,0.5,0.0,0.0,1.0,16.0,0.0,0
3,2.0,2023-10-01 00:03:39,2023-10-01 00:11:20,7.683333,1.0,2.26,74.0,263.0,1.0,11.4,1.0,0.5,3.33,0.0,1.0,19.98,2.75,0
4,2.0,2023-10-01 00:27:42,2023-10-01 00:39:10,11.466667,1.0,2.14,74.0,236.0,1.0,13.5,1.0,0.5,2.81,0.0,1.0,21.559999,2.75,0


In [5]:
from tqdm import tqdm

def augment_df(df, ptm, chunk_size=100_000):
    # Initialize the column in-place to avoid full dataframe copy
    df["NextPU"] = 0
    bad_indices = []
    valid_count = 0

    # Use tqdm to track the progress of chunks
    num_chunks = (len(df) // chunk_size) + 1
    for start in tqdm(range(0, len(df), chunk_size), total=num_chunks, desc="Processing chunks"):
        end = min(start + chunk_size, len(df))
        chunk = df.iloc[start:end]

        for row in chunk.itertuples(index=True):
            i = row.Index

            try:
                hour = row.DropoffDatetime.hour
                do_zone = int(row.DOLocationID)
                distribution = ptm[hour][do_zone]

                if not np.isclose(distribution.sum(), 1.0):
                    bad_indices.append(i)  # mark for deletion
                else:
                    next_pu = np.random.choice(len(distribution), p=distribution)
                    df.at[i, "NextPU"] = int(next_pu)
                    valid_count += 1

            except Exception as e:
                bad_indices.append(i)  # if any error, also mark for deletion
    df.drop(index=bad_indices, inplace=True)
    df.reset_index(drop=True, inplace=True)
    print(f"\n✅ {valid_count} valid entries retained.")
    print(f"🗑️  {len(bad_indices)} invalid entries removed.")



    return df

In [10]:
output_filename = "all_cleaned_data_augmented.parquet"

df_augmented = augment_df(df, ptm)


df_augmented.to_parquet(output_filename, index=False, engine="fastparquet")



Processing chunks: 100%|██████████████████████| 710/710 [21:08<00:00,  1.79s/it]



✅ 24145198 valid entries retained.
🗑️  46800829 invalid entries removed.


In [11]:
df_augmented.head(500)

Unnamed: 0,VendorID,PickupDatetime,DropoffDatetime,TripDuration,PassengerCount,TripDistance,PULocationID,DOLocationID,PaymentType,FareAmount,ExtraCharges,MTATax,TipAmount,TollsAmount,ImprovementSurcharge,TotalAmount,CongestionSurcharge,AirportFee,NextPU
0,2.0,2023-10-01 00:57:33,2023-10-01 01:07:58,10.416667,1.0,1.45,166.0,74.0,1.0,12.100000,1.0,0.5,2.92,0.0,1.0,17.520000,0.00,0,74
1,2.0,2023-10-01 01:00:16,2023-10-01 01:06:13,5.950000,1.0,0.89,74.0,42.0,2.0,7.900000,1.0,0.5,0.00,0.0,1.0,10.400000,0.00,0,82
2,2.0,2023-10-01 00:51:52,2023-10-01 01:00:32,8.666667,1.0,2.38,83.0,129.0,2.0,13.500000,1.0,0.5,0.00,0.0,1.0,16.000000,0.00,0,116
3,2.0,2023-10-01 00:03:39,2023-10-01 00:11:20,7.683333,1.0,2.26,74.0,263.0,1.0,11.400000,1.0,0.5,3.33,0.0,1.0,19.980000,2.75,0,7
4,2.0,2023-10-01 00:27:42,2023-10-01 00:39:10,11.466667,1.0,2.14,74.0,236.0,1.0,13.500000,1.0,0.5,2.81,0.0,1.0,21.559999,2.75,0,80
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
495,2.0,2023-10-01 12:23:14,2023-10-01 12:35:30,12.266667,1.0,2.22,41.0,239.0,1.0,14.200000,0.0,0.5,4.61,0.0,1.0,23.059999,2.75,0,75
496,2.0,2023-10-01 12:48:59,2023-10-01 12:52:15,3.266667,1.0,0.78,41.0,41.0,2.0,5.800000,0.0,0.5,0.00,0.0,1.0,7.300000,0.00,0,75
497,2.0,2023-10-01 12:35:20,2023-10-01 12:48:32,13.200000,1.0,3.39,92.0,129.0,2.0,18.400000,0.0,0.5,0.00,0.0,1.0,19.900000,0.00,0,74
498,2.0,2023-10-01 12:36:25,2023-10-01 12:44:31,8.100000,1.0,0.81,166.0,151.0,1.0,9.300000,0.0,0.5,0.00,0.0,1.0,10.800000,0.00,0,196
