# Flight Delay Prediction

## Data Loading and Preprocessing

In [None]:
import pandas as pd

In [None]:
# Load the cleaned flight data and keep a random 80% of the rows.
# The sample is seeded so the selected rows -- and their shuffled order, which
# the positional train/test split further down depends on -- are the same on
# every run.
FLIGHTS_PARQUET_URL = "https://github.com/JakeMalis/DS-3000-Final/raw/refs/heads/main/cleaned_flights.parquet"
flights_df = pd.read_parquet(FLIGHTS_PARQUET_URL).sample(frac=0.8, random_state=42)

In [None]:
# Derive the hour component of the 'DATE' timestamps and store it as 'DEPARTURE_HOUR'
departure_hour = flights_df['DATE'].dt.hour
flights_df['DEPARTURE_HOUR'] = departure_hour

In [None]:
# Treat rows with no recorded 'DAILY_SNOWFALL' as zero snowfall
daily_snowfall = flights_df['DAILY_SNOWFALL']
flights_df['DAILY_SNOWFALL'] = daily_snowfall.fillna(0)

In [None]:
# Fill missing 'ARRIVAL_DELAY' values with 0.
# NOTE(review): 'ARRIVAL_DELAY' is the regression target used below; filling
# gaps with 0 labels those flights as having no delay -- confirm this is
# intended rather than dropping the rows.
arrival_delay = flights_df['ARRIVAL_DELAY']
flights_df['ARRIVAL_DELAY'] = arrival_delay.fillna(0)

## JAX Implementation

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import random, grad, jit, value_and_grad
from sklearn.metrics import mean_squared_error, r2_score

In [None]:
# Standardize numeric columns (subtract the mean, divide by the standard
# deviation) for the JAX models.
# NOTE(review): the statistics are computed over the whole frame, i.e. before
# the train/test split performed further down -- confirm that is acceptable.
for col in ['MONTH', 'DEPARTURE_HOUR', 'DAY_OF_WEEK', 'DISTANCE']:
    col_mean = flights_df[col].mean()
    col_std = flights_df[col].std()
    # A constant column has std == 0 (and std is NaN with fewer than two rows);
    # dividing by it would fill the feature with NaN/inf, so only center it.
    if not col_std > 0:
        col_std = 1.0
    flights_df[col] = (flights_df[col] - col_mean) / col_std

In [None]:
# Assemble the feature matrix X and target vector y as float32 NumPy arrays
numeric_cols = ['MONTH', 'DEPARTURE_HOUR', 'DAY_OF_WEEK', 'DISTANCE', 'DAILY_SNOWFALL']
categorical_cols = ['AIRLINE', 'origin_airport/AIRPORT', 'destination_airport/AIRPORT']

# Numeric columns are used as-is; the remaining columns go through pd.get_dummies
numeric_feats = flights_df[numeric_cols].astype(np.float32).values
categorical_feats = pd.get_dummies(flights_df[categorical_cols]).values.astype(np.float32)

# Numeric features come first, followed by the dummy-encoded columns
X = np.hstack((numeric_feats, categorical_feats))
y = flights_df['ARRIVAL_DELAY'].values.astype(np.float32)

In [None]:
# Prefer TPU devices for JAX, falling back to the default backend when no TPU
# is attached so the notebook still runs on CPU/GPU machines.
try:
    jax_devices = jax.devices("tpu")
except RuntimeError:
    # jax.devices raises RuntimeError when the requested backend is unavailable
    jax_devices = jax.devices()
print(f"Using device: {jax_devices[0]}")

In [None]:
# Positional 80/20 split: the first 80% of rows train, the remainder test
n_samples = len(X)
train_size = int(0.8 * n_samples)
test_size = n_samples - train_size

train_rows = slice(None, train_size)
test_rows = slice(train_size, None)
X_train, y_train = X[train_rows], y[train_rows]
X_test, y_test = X[test_rows], y[test_rows]

# Number of rows per mini-batch
batch_size = 32

# Function to create batches
def create_batches(X, y, batch_size, drop_last=True):
    """Yield consecutive ``(X_batch, y_batch)`` slices of ``batch_size`` rows.

    Args:
        X: Sliceable sequence of inputs (supports ``len`` and ``X[a:b]``).
        y: Sliceable sequence of targets, sliced with the same bounds as ``X``.
        batch_size: Number of rows per batch; must be positive.
        drop_last: When True (the default, matching the previous behavior) a
            trailing batch with fewer than ``batch_size`` rows is skipped.
            When False it is yielded as a final, shorter batch.

    Yields:
        Tuples ``(X[start:end], y[start:end])`` in order.

    Raises:
        ValueError: If ``batch_size`` is not positive (raised when iteration
            starts, since this is a generator).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    n_samples = len(X)
    # With drop_last, stop at the last multiple of batch_size so only full batches are produced.
    stop = n_samples - n_samples % batch_size if drop_last else n_samples
    for start in range(0, stop, batch_size):
        end = start + batch_size
        yield X[start:end], y[start:end]

In [None]:
# Define Feedforward Neural Network in JAX
class FeedForwardNN:
    """Fully connected network: two ReLU hidden layers and a linear output layer.

    Each weight matrix is drawn from a normal distribution and scaled by
    ``sqrt(2 / fan_in)``; every bias vector starts at zero.
    """

    def __init__(self, input_size, hidden_size1, hidden_size2, output_size, key):
        """Create the three layers, using one sub-key of ``key`` per weight matrix."""
        key_w1, key_w2, key_w3 = random.split(key, 3)

        self.W1 = self._scaled_normal(key_w1, input_size, hidden_size1)
        self.W2 = self._scaled_normal(key_w2, hidden_size1, hidden_size2)
        self.W3 = self._scaled_normal(key_w3, hidden_size2, output_size)

        self.b1 = jnp.zeros(hidden_size1)
        self.b2 = jnp.zeros(hidden_size2)
        self.b3 = jnp.zeros(output_size)

    @staticmethod
    def _scaled_normal(key, fan_in, fan_out):
        """Return a ``(fan_in, fan_out)`` normal sample scaled by ``sqrt(2 / fan_in)``."""
        return random.normal(key, (fan_in, fan_out)) * jnp.sqrt(2.0 / fan_in)

    def get_params(self):
        """Return the parameters as a list of ``(weights, bias)`` tuples, one per layer."""
        weights = (self.W1, self.W2, self.W3)
        biases = (self.b1, self.b2, self.b3)
        return list(zip(weights, biases))

    def set_params(self, params):
        """Overwrite the layers from ``params``, laid out as returned by ``get_params``."""
        layer1, layer2, layer3 = params[0], params[1], params[2]
        self.W1, self.b1 = layer1
        self.W2, self.b2 = layer2
        self.W3, self.b3 = layer3

    def forward(self, X):
        """Run ``X`` through both ReLU hidden layers and the linear output layer."""
        activations = X
        for weights, bias in ((self.W1, self.b1), (self.W2, self.b2)):
            activations = jax.nn.relu(jnp.dot(activations, weights) + bias)
        # No activation on the output layer
        return jnp.dot(activations, self.W3) + self.b3

In [None]:
# Instantiate the model (the loss and update step are defined in later cells)
hidden_size1 = 64  # Width of the first hidden layer
hidden_size2 = 32  # Width of the second hidden layer
output_size = 1    # Single regression output (arrival delay)
input_size = X.shape[1]  # One input unit per feature column

# Fixed seed for the weight initialization
rng_key = random.PRNGKey(0)
model = FeedForwardNN(input_size, hidden_size1, hidden_size2, output_size, key=rng_key)

In [None]:
# Shared forward pass, written directly against the parameter list so that no
# throwaway (randomly initialized) model object has to be built on every call.
def _forward_from_params(params, X):
    """Apply the network described by ``params`` to ``X``.

    ``params`` is a sequence of ``(weights, bias)`` pairs; every layer except
    the last is followed by a ReLU, the last layer is linear.
    """
    *hidden_layers, (W_out, b_out) = params
    activations = X
    for W, b in hidden_layers:
        activations = jax.nn.relu(jnp.dot(activations, W) + b)
    return jnp.dot(activations, W_out) + b_out


# Loss function - takes params as the first argument for grad
def mse_loss(params, X, y):
    """Return the mean squared error between the network output and ``y``.

    Args:
        params: Sequence of ``(weights, bias)`` pairs, one per layer.
        X: Input features; the last axis must match the first layer's weights.
        y: Targets. When ``y`` has as many elements as the predictions but a
            different shape (e.g. ``(N,)`` against ``(N, 1)``), it is reshaped
            to the prediction shape; otherwise normal broadcasting applies.

    Returns:
        Scalar JAX array holding the mean squared error.
    """
    preds = _forward_from_params(params, X)
    # Without this, (N, 1) - (N,) silently broadcasts to (N, N) and the loss is wrong.
    targets = jnp.reshape(y, preds.shape) if jnp.size(y) == preds.size else y
    return jnp.mean((preds - targets) ** 2)


# Predict function - takes params and X and returns the network output
@jit
def predict(params, X):
    """Return the network output for ``X`` under ``params`` (jit-compiled)."""
    return _forward_from_params(params, X)

# Define the update step (train_step)
@jit
def train_step(params, X_batch, y_batch, learning_rate):
    """Return ``params`` after one gradient-descent step on a single batch.

    Args:
        params: Parameters accepted by ``mse_loss`` (the training loop passes a
            list of ``(weights, bias)`` tuples).
        X_batch: Batch of input features.
        y_batch: Targets for the batch.
        learning_rate: Step size multiplying each gradient.

    Returns:
        Updated parameters with the same nesting as ``params``.
    """
    gradients = jax.grad(mse_loss)(params, X_batch, y_batch)
    # tree_map walks params and gradients in lockstep, so the update is not
    # tied to a fixed number of layers or to (weights, bias) pairs.
    return jax.tree_util.tree_map(
        lambda param, gradient: param - learning_rate * gradient,
        params,
        gradients,
    )

In [None]:
# Training loop
params = model.get_params()  # Start from the model's initial parameters
num_epochs = 20
learning_rate = 0.00001

# Training loss is monitored on the first 10,000 training rows only, to reduce
# memory usage. The subset never changes, so it is converted to JAX arrays once
# here instead of once per epoch.
X_train_subset_jax = jnp.array(X_train[:10000], dtype=jnp.float32)
y_train_subset_jax = jnp.array(y_train[:10000], dtype=jnp.float32).reshape(-1, 1)

for epoch in range(num_epochs):
    for X_batch_np, y_batch_np in create_batches(X_train, y_train, batch_size):
        # Convert the NumPy batch to JAX arrays; targets become a column so
        # they line up with the (batch, 1) network output.
        X_batch_jax = jnp.array(X_batch_np, dtype=jnp.float32)
        y_batch_jax = jnp.array(y_batch_np, dtype=jnp.float32).reshape(-1, 1)
        params = train_step(params, X_batch_jax, y_batch_jax, learning_rate)

    # Report the loss on the fixed monitoring subset after each epoch
    train_loss = mse_loss(params, X_train_subset_jax, y_train_subset_jax)
    print(f"Epoch {epoch + 1}, Training Loss: {train_loss:.4f}")

In [None]:
# Score the trained parameters on the held-out test split.
# The test features are moved into a JAX array before calling predict.
X_test_jax = jnp.array(X_test, dtype=jnp.float32)
y_pred = predict(params, X_test_jax)

# Regression metrics; RMSE is the square root of the MSE
mse = mean_squared_error(y_test, y_pred)
rmse = jnp.sqrt(mse)
r2 = r2_score(y_test, y_pred)

print(f"Test R-squared (R2): {r2:.4f}")
print(f"Test Mean Squared Error (MSE): {mse:.4f}")
print(f"Test Root Mean Squared Error (RMSE): {rmse:.4f}")