# EDA on Processed Training Data

This notebook inspects the output of the `generate_dataset` Beam pipeline. We will check data quality, distributions, and relationships between features.

In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pyarrow.parquet as pq

# Apply seaborn's "whitegrid" style as the plotting default.
sns.set_theme(style="whitegrid")
# IPython magic: render matplotlib figures inline in the cell output.
%matplotlib inline

## 1. Load Data
Load the Parquet file that was downloaded locally from the pipeline's GCS output.

In [2]:
DATA_PATH = '../local_artifacts/processed_data/training_data.parquet'

# Read the processed training data into `df`, which the following cells use.
try:
    df = pd.read_parquet(DATA_PATH)
except FileNotFoundError as err:
    # Stop here with an actionable message. Merely printing would leave `df`
    # undefined and make every later cell fail with an unrelated NameError.
    raise FileNotFoundError(
        f"File not found at {DATA_PATH}. Please ensure you have downloaded the data from GCS."
    ) from err

print(f"Loaded dataset with {df.shape[0]} rows and {df.shape[1]} columns.")

# An empty frame loads without error but makes every plot and statistic below
# meaningless, so call it out explicitly.
if df.empty:
    print(f"WARNING: {DATA_PATH} contains no rows; the analysis below will be empty.")

Loaded dataset with 0 rows and 20 columns.


## 2. Basic Inspection

In [None]:
# Print the column names, non-null counts, dtypes and memory usage of the frame.
df.info()

In [None]:
# Display the first rows of the frame (pandas default: 5).
df.head()

## 3. Data Quality Checks
Count the null values in every column and list only the columns that contain at least one.

In [None]:
# Number of missing values per column.
null_counts = df.isna().sum()
# Show only the columns that have at least one missing value.
null_counts.loc[null_counts.gt(0)]

### Check Target Variable (`service_headway`)

In [None]:
# Histogram (50 bins) of the target column with a KDE curve overlaid.
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(df['service_headway'], bins=50, kde=True, ax=ax)
ax.set_title('Distribution of Service Headway (min)')
ax.set_xlabel('Headway (minutes)')
plt.show()

In [None]:
# Summary statistics (count, mean, std, min, quartiles, max) of the target column.
print("Headway Statistics:")
df['service_headway'].describe()

## 4. Feature Relationships
Analyze the relationship between the available (non-null) travel times and the target headway.

In [None]:
# Filter for rows where travel_time_34th is not null
valid_tt = df.dropna(subset=['travel_time_34th'])

plt.figure(figsize=(10, 6))
sns.scatterplot(data=valid_tt, x='travel_time_34th', y='service_headway', alpha=0.5)
plt.title('Travel Time from 34th St vs Service Headway')
plt.xlabel('Travel Time (34th -> Target) [min]')
plt.ylabel('Service Headway [min]')
plt.show()

### Correlation Matrix

In [None]:
# Select every numeric column. 'number' also matches float32/int32 and the
# nullable numeric dtypes, which an explicit ['float64', 'int64'] list would
# silently leave out of the matrix.
numeric_cols = df.select_dtypes(include='number').columns
corr = df[numeric_cols].corr()

plt.figure(figsize=(12, 10))
# Pin the colour scale to the full correlation range so the midpoint of the
# diverging 'coolwarm' map always corresponds to zero correlation.
sns.heatmap(corr, annot=True, cmap='coolwarm', fmt=".2f", vmin=-1, vmax=1, center=0)
plt.title('Correlation Matrix')
plt.show()

## 5. Temporal Analysis

In [None]:
# Plot the target headway against arrival time for one sample day.
# Kept best-effort: any failure is reported instead of aborting the notebook.
try:
    # Convert arrival_time back to datetime for plotting.
    # NOTE(review): assumes arrival_time is stored in a form pd.to_datetime
    # parses correctly without an explicit unit/format - confirm against the
    # pipeline output schema.
    df['dt'] = pd.to_datetime(df['arrival_time'])

    # Calendar date of each row, computed once and reused below.
    arrival_dates = df['dt'].dt.date
    valid_dates = arrival_dates.dropna()

    if valid_dates.empty:
        # Covers both an empty dataset and one with no parseable timestamps;
        # indexing the first element would otherwise raise a cryptic IndexError.
        print("Could not plot time series: no rows with a valid arrival_time.")
    else:
        # Use the first row that has a valid timestamp as the sample day.
        sample_day = valid_dates.iloc[0]
        daily_data = df[arrival_dates == sample_day].sort_values('dt')

        plt.figure(figsize=(15, 6))
        plt.plot(daily_data['dt'], daily_data['service_headway'], marker='o')
        plt.title(f'Headway over Time on {sample_day}')
        plt.ylabel('Headway (min)')
        plt.xlabel('Time')
        plt.xticks(rotation=45)
        plt.grid(True)
        plt.show()
except Exception as e:
    print(f"Could not plot time series: {e}")