# Step6-plotting & analysis

## 1. Import necessary modules & start a Spark session

In [None]:
# Import necessary modules
from pyspark.sql import SparkSession
from statsmodels.formula.api import ols
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import scipy.stats as stats
import statsmodels.api as sm
import numpy as np

In [None]:
# Create (or reuse) the Spark session used to read the merged parquet data.
# NOTE(review): the app name below looks copied from the step-1 script; it only
# labels the job in the Spark UI, so it is left unchanged here -- confirm.
spark = (
    SparkSession.builder.appName('step_1-download_data.py')
    # Render DataFrames eagerly when they are the last expression of a cell
    .config('spark.sql.repl.eagerEval.enabled', True)
    .config('spark.sql.parquet.cacheMetadata', 'true')
    # Interpret timestamps in UTC regardless of the machine's local time zone
    .config('spark.sql.session.timeZone', 'Etc/UTC')
    .config('spark.driver.memory', '16g')
    # The key was previously misspelled as 'spark.executer.memory', which Spark
    # does not recognise, so the executor memory setting never took effect.
    .config('spark.executor.memory', '16g')
    .getOrCreate()
)

## 2. Import merged_data and sample

Import `merged_data` from directory `data/merged_data/`

In [None]:
# Location of the merged parquet dataset, relative to this notebook
merged_data_path = '../data/merged_data/merged_data.parquet/'

# Load the parquet files into a Spark DataFrame
merged_data = spark.read.format('parquet').load(merged_data_path)

Draw a random sample from `merged_data` using a sampling fraction of 0.01 (roughly 1% of the rows)

In [None]:
# Fraction of rows to keep (not an absolute row count)
sample_size = 0.01

# Sample on the Spark side with a fixed seed, then collect the result into pandas
sampled_rows = merged_data.sample(fraction=sample_size, seed=1)
sample_merged_data = sampled_rows.toPandas()

print('#rows_of_sample_data: ', sample_merged_data.shape[0])

## 3. Distribution of trip duration

Plot the distribution of 'trip_duration'

In [None]:
# Histogram of the raw trip durations with a kernel density estimate overlaid.
# `sns.distplot` is deprecated in seaborn; `sns.histplot` with kde=True is its
# replacement, and stat='density' keeps the y-axis on the density scale that
# the axis label below describes.
plt.figure(figsize=(10, 6))
sns.histplot(
    sample_merged_data['trip_duration'],
    bins=30,
    kde=True,  # also draw the density estimate on top of the histogram
    stat='density'
)
plt.title('Distribution of Trip Duration')
plt.xlabel('Trip Duration')
plt.ylabel('Density')

# Save before plt.show(), which may clear the current figure
plt.savefig('../plots/distribution_of_trip_duration.png')
plt.show()

Replot after applying a log transformation (`log1p`) to 'trip_duration'

In [None]:
# Apply log(1 + x) to 'trip_duration' to reduce skew and get closer to normality.
# The result is stored as a temporary column that the next cell removes again.
sample_merged_data['log_trip_duration'] = np.log1p(sample_merged_data['trip_duration'])

# `sns.distplot` is deprecated in seaborn; `sns.histplot` with kde=True and
# stat='density' is its replacement and keeps the y-axis on a density scale.
plt.figure(figsize=(10, 6))
sns.histplot(
    sample_merged_data['log_trip_duration'],
    bins=30,
    kde=True,
    stat='density'
)
plt.title('Distribution of Log Transformed Trip Duration')
plt.xlabel('Log(Trip Duration)')
plt.ylabel('Density')

# Save before plt.show(), which may clear the current figure
plt.savefig('../plots/distribution_of_log_transformed_trip_duration')
plt.show()

In [None]:
# Drop the temporary 'log_trip_duration' column from `sample_merged_data` in place
sample_merged_data.drop(columns=['log_trip_duration'], inplace=True)

## 4. Correlation of continuous features

Separate discussion of continuous and discrete features

In [None]:
# Continuous columns used for the pair plot and correlation heat maps;
# the target 'trip_duration' is listed last.
continuous_feature_list = [
    '#passenger',
    'trip_distance',
    'average_speed',
    'congestion_fee',
    'toll_fee',
    'temperature',
    'uv_index',
    'visibility',
    'trip_duration',
]

### 4.1 Pair plot for continuous features

Change the continuous features of type `int` or `object` to float type before plotting

In [None]:
# Columns cast directly to float
for column_name in ('trip_duration', 'uv_index'):
    sample_merged_data[column_name] = sample_merged_data[column_name].astype(float)

# Columns parsed as numbers; values that cannot be parsed become NaN
for column_name in ('temperature', 'visibility'):
    sample_merged_data[column_name] = pd.to_numeric(
        sample_merged_data[column_name], errors='coerce'
    )

sample_merged_data.head()

In [None]:
sns.set(style="ticks", color_codes=True)

# Small blue-edged markers keep the scatter panels readable for a large sample
scatter_options = dict(s=1, edgecolor="b", linewidth=1)
continuous_sample = sample_merged_data[continuous_feature_list]

# One scatter panel per pair of continuous features
pair_plot = sns.pairplot(continuous_sample, plot_kws=scatter_options)
pair_plot.fig.suptitle("Pair Plot", y=1)

plt.savefig('../plots/pair_plot_for_continuous_features')
plt.show()

### 4.2 Heat map for continuous features

In [None]:
# Preview the first rows of the sampled data before plotting the correlation heat map
sample_merged_data.head()

In [None]:
plt.figure(figsize=(12, 9))

# Pairwise correlations between the continuous features (pandas default method)
pearson_corr = sample_merged_data[continuous_feature_list].corr()

# Annotated heat map with a diverging palette centred on zero correlation
sns.heatmap(pearson_corr, annot=True, cmap='coolwarm', center=0)
plt.title('Pearson Correlation Metric')

plt.savefig('../plots/head_plot_for_continuous_features')
plt.show()

### 4.3 Heat map with spearman correlation coefficient

In [None]:
# Rank-based (Spearman) correlations between the continuous features
continuous_sample = sample_merged_data[continuous_feature_list]
spearman_corr = continuous_sample.corr(method='spearman')

plt.figure(figsize=(12, 9))

# Annotated heat map with a diverging palette centred on zero correlation
heatmap_options = dict(annot=True, cmap='coolwarm', center=0)
sns.heatmap(spearman_corr, **heatmap_options)
plt.title('Spearman Correlation Metric')

plt.savefig('../plots/head_plot_with_spearman_for_continuous_features')
plt.show()

## 5. Correlation of discrete features

In [None]:
# Target column followed by the location ids and the 'if_*' indicator columns
discrete_feature_list = [
    'trip_duration',
    'up_location_id',
    'off_location_id',
    'if_weekend',
    'if_peak_hour',
    'if_overnight',
    'if_airport',
    'if_rain',
    'if_snow',
    'if_overcast',
    'if_cloudy',
    'if_clear',
]

### 5.1 Pair plot for discrete features

In [None]:
sns.set(style="ticks", color_codes=True)

# Small blue-edged markers keep the scatter panels readable for a large sample
scatter_options = dict(s=1, edgecolor="b", linewidth=1)
discrete_sample = sample_merged_data[discrete_feature_list]

# One scatter panel per pair of columns in `discrete_feature_list`
pair_plot = sns.pairplot(discrete_sample, plot_kws=scatter_options)
pair_plot.fig.suptitle("Pair Plot", y=1)

plt.savefig('../plots/pair_plot_for_discrete_features')
plt.show()

The pair plot does not show any clear relationship between the discrete features and 'trip_duration', so we run an ANOVA on the discrete features instead

### 5.2 ANOVA for discrete features

Remove 'up_location_id' & 'off_location_id' for now

In [None]:
# Indicator columns only: the location ids and 'trip_duration' are left out
# of the one-way ANOVA below.
discrete_feature_list = [
    'if_weekend',
    'if_peak_hour',
    'if_overnight',
    'if_airport',
    'if_rain',
    'if_snow',
    'if_overcast',
    'if_cloudy',
    'if_clear',
]

In [None]:
results = []

# Run one test per indicator column
for flag in discrete_feature_list:
    flag_values = sample_merged_data[flag]

    # Trip durations split by whether the indicator equals 1 or 0
    durations_when_set = sample_merged_data[flag_values == 1]['trip_duration']
    durations_when_unset = sample_merged_data[flag_values == 0]['trip_duration']

    # One-way ANOVA: do the two groups share the same mean trip duration?
    f_statistic, p_value = stats.f_oneway(durations_when_set, durations_when_unset)

    results.append({'feature': flag, 'F-value': f_statistic, 'p-value': p_value})

# Smallest p-values first
results_df = pd.DataFrame(results).sort_values(by='p-value')

print(results_df)

### 5.3 Two-way ANOVA with interaction for 'up_location_id' & 'off_location_id'

In [None]:
# `a * b` in the formula expands to both main effects plus their interaction.
# NOTE(review): if the location id columns are stored as numbers, the formula
# treats them as continuous predictors rather than categories -- confirm the
# dtype, and consider wrapping them in C(...) if categories are intended.
interaction_formula = 'trip_duration ~ up_location_id * off_location_id'

# Fit the linear model with Ordinary Least Squares
model = ols(interaction_formula, data=sample_merged_data).fit()

# Type-2 ANOVA table for the fitted model
anova_table = sm.stats.anova_lm(model, typ=2)

print(anova_table)

## 6. Stop spark session

In [None]:
# Stop the Spark session now that the analysis is finished
spark.stop()