# Setup

In [None]:
import pickle
import matplotlib.pyplot as plt
import sys
import os
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import dataset
import torch
from pprint import pprint
import pandas as pd
import seaborn as sns
import sklearn
from  sklearn.manifold import TSNE
from  sklearn.decomposition import PCA
from  sklearn.preprocessing import KBinsDiscretizer
from  sklearn.preprocessing import MinMaxScaler
from  sklearn.preprocessing import RobustScaler
from  sklearn.preprocessing import StandardScaler
from  sklearn.preprocessing import QuantileTransformer
import numpy as np
import math

In [None]:
# Apply the display options globally. The previous bare `pd.option_context(...)` call only
# built a context manager and discarded it, so it changed nothing; it also passed -1 for
# `display.max_colwidth`, which pandas replaced with None ("no limit").
# `display.max_rows` is deliberately left at its default: several cells below end with a bare
# full-size DataFrame, and an unlimited row count would render every row of it.
pd.set_option("display.max_columns", None)
pd.set_option("display.max_colwidth", None)

In [3]:
# Which dataset to explore; the loading cell below only handles "prsa".
data_type = "prsa"
# Suffix inserted into the pickle file name (PRSADataset_labeled{name_suffix}.pkl).
name_suffix = ""
# name_suffix = "timedelta"

In [4]:
# Load the pickled dataset object for the selected data type.
# NOTE: this rebinds the name `dataset`, shadowing the `dataset` module imported in the setup
# cell (presumably that import is what lets pickle resolve the dataset class - confirm).
# NOTE(security): pickle.load can execute arbitrary code; only load files you produced yourself.
if data_type == "prsa":
    with open(f"../data/prsa/PRSADataset_labeled{name_suffix}.pkl", "rb") as f:
        dataset = pickle.load(f)
else:
    # Fail here instead of letting later cells break on the still-bound `dataset` module.
    raise ValueError(f"Unsupported data_type: {data_type!r} (expected 'prsa')")

# Full data

In [None]:
# FYI: order of the stations at the time the dataset samples were prepared
stations = [station_name for station_name, _ in dataset.data.groupby("station")]
print(stations)

In [None]:
# FYI: raw initial data
# Work from the project root before reading the raw files (presumably dataset.data_root is
# relative to it - confirm). The absolute `module_path` from the setup cell is used instead of
# a relative "..", so re-running this cell no longer climbs one more directory each time.
# NOTE: once this cell has run, the relative "../data/..." path of the loading cell no longer
# resolves if that cell is executed again.
os.chdir(module_path)
raw_data = dataset.read_data(dataset.data_root, dataset.nrows)

# Build the hourly timestamp once and derive both time columns from it.
hourly_timestamp = pd.to_datetime(dict(year=raw_data['year'], month=raw_data['month'], day=raw_data['day'], hour=raw_data['hour']))
# Explicit 64-bit width: the width of plain `int` depends on the platform.
raw_data["timestamp_raw"] = hourly_timestamp.astype("int64")
raw_data["date_raw"] = hourly_timestamp
raw_data['year_month'] = raw_data['date_raw'].dt.strftime('%Y-%m')
raw_data['weekday'] = raw_data['date_raw'].dt.dayofweek
raw_data = raw_data.sort_values(by=["station", "timestamp_raw"])
raw_data

In [None]:
# Columns of the raw frame, including the time columns derived above
raw_data.columns

In [None]:
preproc_data = dataset.data # Do NOT sort again (keep original timestamp sort)
# Stack the per-station frames one after another, in groupby order
preprocessed_data = pd.concat(
    [station_frame for _, station_frame in preproc_data.groupby("station")]
)
preprocessed_data

In [None]:
# Flag (1/0) every rolling window of the "Aotizhongxin" station that contains an outlier row
aotizhongxin_rows = preprocessed_data["station"] == "Aotizhongxin"
outliers_per_window = (
    preprocessed_data[aotizhongxin_rows]["outlier"]
    .rolling(dataset.seq_len, step=dataset.stride, closed="left")
    .sum()[2:]  # the first two windows are dropped; why two is not evident here - TODO confirm
)
windows_outliers = np.where(outliers_per_window >= 1, 1, 0) # Rolling in pd starts with the n-1 prior rows (unintuitive)
windows_outliers.shape

In [None]:
samples = dataset.samples
targets = dataset.targets

# Number of samples, number of targets, and number of preprocessed rows per stride
for size in (len(samples), len(targets), preprocessed_data.shape[0]/dataset.stride):
    print(size)

In [None]:
# Size of the token vocabulary of each field
vocab_keys = list(dataset.vocab.token2id.keys())
for k in vocab_keys:
    field_vocab = dataset.vocab.token2id[k]
    print(f"\n--{k}--")
    # pprint(field_vocab)
    print(len(field_vocab))

In [None]:
# Token-to-id mapping stored under the "SPECIAL" key of the vocabulary
dataset.vocab.token2id["SPECIAL"]

# Sample

In [13]:
def get_final_id(windows_outliers, sample_id):
    """
    Find the window index in the original data that corresponds to a given
    dataset sample ID, accounting for the outlier windows detected when
    building the dataset.

    Sample ``k`` of the dataset is taken to be the ``k``-th window (0-based)
    that is not flagged as an outlier, so the result is the position of that
    window in ``windows_outliers``.

    The previous implementation walked forward from ``sample_id`` but never
    advanced its cursor after meeting a non-outlier window, so it could stop
    too early and return the window of an earlier sample: the flags
    ``[1, 1, 1, 0, 1, 0, 1, 0, 1, 0]`` with ``sample_id=3`` gave 7 instead of 9.

    Args:
        windows_outliers: 1-D array-like of flags, 1 marking an outlier window.
        sample_id: 0-based index of the sample in the outlier-free dataset.

    Returns:
        int: index into ``windows_outliers`` of the matching window.

    Raises:
        IndexError: if ``sample_id`` is negative or there are fewer than
            ``sample_id + 1`` non-outlier windows.
    """
    # Positions of every window that is kept (same "== 1 means outlier" rule as before)
    safe_window_ids = np.flatnonzero(np.asarray(windows_outliers) != 1)
    if not 0 <= sample_id < safe_window_ids.size:
        raise IndexError(
            f"sample_id {sample_id} is out of range: "
            f"only {safe_window_ids.size} non-outlier windows are available"
        )
    return int(safe_window_ids[sample_id])

In [14]:
sample_id = 38 # Only works for samples of the first station (because the rolling windows for outliers is NOT grouped by station first)
stride = dataset.stride
final_id = get_final_id(windows_outliers, sample_id)

# The same row range is cut out of the raw and of the preprocessed frame
sample_rows = slice(stride * final_id, dataset.seq_len + stride * final_id)
raw_sample = raw_data[sample_rows]
preprocessed_sample = preprocessed_data[sample_rows]

# Tensor views of the stored sample (one row per time step) and of its target
pytorch_sample = torch.tensor(samples[sample_id]).reshape(dataset.seq_len, -1)
pytorch_target = torch.tensor(targets[sample_id])

In [None]:
# Rows of the raw data selected for the sample
raw_sample

In [None]:
# Same row range taken from the preprocessed data
preprocessed_sample

In [None]:
# Stored dataset sample, reshaped to (seq_len, -1)
pytorch_sample

In [None]:
# Stored target of the same sample
pytorch_target

In [19]:
def find_sample_above_pm(min_pm: float, samples: list, targets: list, seq_len: int = 10, n: int = 1):
    """
    Return the n-th sample whose largest target value reaches ``min_pm``.

    Args:
        min_pm: threshold compared against the maximum of each target.
        samples: dataset samples, indexable in parallel with ``targets``.
        targets: per-sample target values (anything ``torch.tensor`` accepts).
        seq_len: number of rows the flat sample is reshaped to.
        n: 1-based rank of the match to return (1 = first match).

    Returns:
        tuple: ``(index, sample reshaped to (seq_len, -1), target tensor)``.

    Raises:
        ValueError: if ``n`` is smaller than 1.
        IndexError: if fewer than ``n`` samples reach ``min_pm`` (previously an
            opaque "index out of range" raised from inside the scan loop).
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")

    occurrences = 0
    for i, raw_target in enumerate(targets):
        target = torch.tensor(raw_target)
        if target.max() >= min_pm:
            occurrences += 1
            if occurrences == n:
                return (
                    i,
                    torch.tensor(samples[i]).reshape(seq_len, -1),
                    target,
                )
    raise IndexError(
        f"only {occurrences} sample(s) have a target >= {min_pm}; cannot return match #{n}"
    )

In [None]:
# First sample whose maximum target value is at least 600
find_sample_above_pm(min_pm=600, samples=samples, targets=targets, seq_len=10, n=1)

# Explore

In [None]:
# Summary stats on targets
# Rows with PM10 < PM2.5 are measurement errors (removed from pytorch samples)
pm10_covers_pm25 = preprocessed_data["PM10"] >= preprocessed_data["PM2.5"]
preprocessed_data_wo_meas_outliers = preprocessed_data[pm10_covers_pm25].copy(deep=True)
preprocessed_data_wo_meas_outliers[["PM2.5", "PM10"]].describe()

In [None]:
# Distribution of targets and comparisons of transformations
# One row of violins per transformation, one column per target (PM2.5 left, PM10 right).

targets_frame = preprocessed_data_wo_meas_outliers  # alias: the scaled columns are added to this frame
target_columns = ["PM2.5", "PM10"]
violin_kws = dict(split=True, inner="quart", bw_adjust=.1)

# The transformers keep their own names so they can still be inspected after the cell has run
minmax_scaler = MinMaxScaler()
robust_scaler = RobustScaler()
quantile_scaler = QuantileTransformer(output_distribution="normal")
uniform_scaler = QuantileTransformer(output_distribution="uniform")
std_scaler = StandardScaler()

# (column suffix, transformer or None for the element-wise log, extra violin options)
transformations = [
    ("minmaxscaled", minmax_scaler, dict(cut=0)),
    ("robustscaled", robust_scaler, dict()),
    ("quantilescaled", quantile_scaler, dict()),
    ("uniformscaled", uniform_scaler, dict(cut=0)),
    ("logscaled", None, dict()),
    ("stdscaled", std_scaler, dict()),
]

fig, axes = plt.subplots(nrows=7, ncols=2, figsize=(12, 10))

# First row: untransformed targets
# NOTE(review): seaborn expresses `cut` in bandwidth units, so passing the column minimum is
# probably not what was intended (cut=0 keeps the violin within the data range) - confirm.
for ax, column in zip(axes[0], target_columns):
    sns.violinplot(ax=ax, x=column, data=targets_frame, cut=targets_frame[column].min(), **violin_kws)

# Remaining rows: fit and plot each transformation, in the listed order
for row_axes, (suffix, scaler, extra_kws) in zip(axes[1:], transformations):
    scaled_columns = [f"{column}_{suffix}" for column in target_columns]
    if scaler is None:
        for column, scaled_column in zip(target_columns, scaled_columns):
            targets_frame[scaled_column] = targets_frame[column].map(lambda x: np.log(x))
    else:
        targets_frame[scaled_columns] = scaler.fit_transform(targets_frame[target_columns])
    for ax, scaled_column in zip(row_axes, scaled_columns):
        sns.violinplot(ax=ax, x=scaled_column, data=targets_frame, **extra_kws, **violin_kws)

plt.suptitle("Targets distributions (after removing measurements outliers)")
plt.tight_layout()

# Take aways:
# - Targets distributions are right-skewed, lots of outliers
# - We use StdScaler for the targets (and with no activation in the last layer)

In [None]:
# Dispersion targets outliers

def get_IQR_outlier_cutoff(df, column, iqr_scalar=1.5):
    """
    Compute the interquartile-range fences of one column.

    Args:
        df: frame holding the column.
        column: name of the column to analyse.
        iqr_scalar: number of IQRs added beyond the quartiles.

    Returns:
        tuple: (Q1 - iqr_scalar * IQR, Q3 + iqr_scalar * IQR).
    """
    first_quartile, third_quartile = df[column].quantile([0.25, 0.75])
    fence_width = iqr_scalar * (third_quartile - first_quartile)
    return (first_quartile - fence_width, third_quartile + fence_width)

# IQR fences for PM2.5, with the default iqr_scalar of 1.5
lower_outliers_cutoff_pm25, higher_outliers_cutoff_pm25 = get_IQR_outlier_cutoff(preprocessed_data_wo_meas_outliers, "PM2.5")
print(f"{lower_outliers_cutoff_pm25=}") 
print(f"{higher_outliers_cutoff_pm25=}")

# Same fences for PM10
lower_outliers_cutoff_pm10, higher_outliers_cutoff_pm10 = get_IQR_outlier_cutoff(preprocessed_data_wo_meas_outliers, "PM10")
print(f"{lower_outliers_cutoff_pm10=}") 
print(f"{higher_outliers_cutoff_pm10=}")

In [None]:
# Distribution of targets *without dispersion outliers* and comparisons of transformations
# One row of violins per transformation, one column per target (PM2.5 left, PM10 right).

within_pm25_fence = preprocessed_data_wo_meas_outliers["PM2.5"] <= higher_outliers_cutoff_pm25
within_pm10_fence = preprocessed_data_wo_meas_outliers["PM10"] <= higher_outliers_cutoff_pm10
preprocessed_data_wo_disp_outliers = preprocessed_data_wo_meas_outliers[
    within_pm25_fence & within_pm10_fence
].copy(deep=True)

targets_frame = preprocessed_data_wo_disp_outliers  # alias: the scaled columns are added to this frame
target_columns = ["PM2.5", "PM10"]
violin_kws = dict(split=True, inner="quart", bw_adjust=.1)

# The transformers keep their own names so they can still be inspected after the cell has run
minmax_scaler = MinMaxScaler()
robust_scaler = RobustScaler()
quantile_scaler = QuantileTransformer(output_distribution="normal")
uniform_scaler = QuantileTransformer(output_distribution="uniform")
std_scaler = StandardScaler()

# (column suffix, transformer or None for the element-wise log, extra violin options)
transformations = [
    ("minmaxscaled", minmax_scaler, dict(cut=0)),
    ("robustscaled", robust_scaler, dict()),
    ("quantilescaled", quantile_scaler, dict()),
    ("uniformscaled", uniform_scaler, dict(cut=0)),
    ("logscaled", None, dict()),
    ("stdscaled", std_scaler, dict()),
]

fig, axes = plt.subplots(nrows=7, ncols=2, figsize=(12, 10))

# First row: untransformed targets
# NOTE(review): seaborn expresses `cut` in bandwidth units, so passing the column minimum is
# probably not what was intended (cut=0 keeps the violin within the data range) - confirm.
for ax, column in zip(axes[0], target_columns):
    sns.violinplot(ax=ax, x=column, data=targets_frame, cut=targets_frame[column].min(), **violin_kws)

# Remaining rows: fit and plot each transformation, in the listed order
for row_axes, (suffix, scaler, extra_kws) in zip(axes[1:], transformations):
    scaled_columns = [f"{column}_{suffix}" for column in target_columns]
    if scaler is None:
        for column, scaled_column in zip(target_columns, scaled_columns):
            targets_frame[scaled_column] = targets_frame[column].map(lambda x: np.log(x))
    else:
        targets_frame[scaled_columns] = scaler.fit_transform(targets_frame[target_columns])
    for ax, scaled_column in zip(row_axes, scaled_columns):
        sns.violinplot(ax=ax, x=scaled_column, data=targets_frame, **extra_kws, **violin_kws)

plt.suptitle("Preprocessed targets distributions (after removing dispersion outliers, based on 1.5 IQR)")
plt.tight_layout()

# Take aways:
# - Targets (without dispersion outliers) are still right-skewed

In [None]:
# Distribution of the *outlier* targets (rows above the 1.5 IQR upper fence of either target)

above_pm25_fence = preprocessed_data_wo_meas_outliers["PM2.5"] > higher_outliers_cutoff_pm25
above_pm10_fence = preprocessed_data_wo_meas_outliers["PM10"] > higher_outliers_cutoff_pm10
outliers = preprocessed_data_wo_meas_outliers[above_pm25_fence | above_pm10_fence].copy(deep=True)

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(6, 2))
for ax, column in zip(axes, ["PM2.5", "PM10"]):
    sns.violinplot(
        ax=ax,
        x=column,
        split=True,
        data=outliers,
        inner="quart",
        # `cut` is taken from the full frame (not from `outliers`), as before
        cut=preprocessed_data_wo_meas_outliers[column].min(),
        bw_adjust=.1,
    )

plt.suptitle("Preprocessed outliers targets distributions (outliers based on 1.5 IQR)")
plt.tight_layout()

# Take aways:
# - Outlier targets are also right skewed

In [None]:
# Dispersion targets extreme outliers (fences at 2.5 IQR instead of 1.5)

extreme_lower_outliers_cutoff_pm25, extreme_higher_outliers_cutoff_pm25 = get_IQR_outlier_cutoff(preprocessed_data_wo_meas_outliers, "PM2.5", 2.5)
print(f"{extreme_lower_outliers_cutoff_pm25=}")
print(f"{extreme_higher_outliers_cutoff_pm25=}")

extreme_lower_outliers_cutoff_pm10, extreme_higher_outliers_cutoff_pm10 = get_IQR_outlier_cutoff(preprocessed_data_wo_meas_outliers, "PM10",2.5)
print(f"{extreme_lower_outliers_cutoff_pm10=}")
print(f"{extreme_higher_outliers_cutoff_pm10=}")

# Keep the rows above the upper fence of either target
above_extreme_pm25_fence = preprocessed_data_wo_meas_outliers["PM2.5"] > extreme_higher_outliers_cutoff_pm25
above_extreme_pm10_fence = preprocessed_data_wo_meas_outliers["PM10"] > extreme_higher_outliers_cutoff_pm10
extreme_outliers = preprocessed_data_wo_meas_outliers[
    above_extreme_pm25_fence | above_extreme_pm10_fence
].copy(deep=True)

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(6, 2))
for ax, column in zip(axes, ["PM2.5", "PM10"]):
    sns.violinplot(
        ax=ax,
        x=column,
        split=True,
        data=extreme_outliers,
        inner="quart",
        # `cut` is taken from the full frame (not from `extreme_outliers`), as before
        cut=preprocessed_data_wo_meas_outliers[column].min(),
        bw_adjust=.1,
    )

plt.suptitle("Preprocessed extreme outliers targets distributions (outliers based on 2.5 IQR)")
plt.tight_layout()

# Take aways:
# 


In [None]:
# Evolution of pollution over the years

fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(12, 8))
# One panel per target, one line per station
for ax, target in ((ax1, 'PM2.5'), (ax2, 'PM10')):
    sns.lineplot(ax=ax, x='year_month', y=target, hue='station', data=raw_data)
    ax.tick_params(axis='x', labelrotation=70)
# A single legend (top panel) serves both panels
ax1.legend(loc='upper left', ncol=6)
ax2.get_legend().remove()
plt.tight_layout()

# Take aways: 
# - year and month are important. 
# - PM2.5 and PM10 follow the same trends
# - station name matters

In [None]:
# Evolution of pollution by day of the month

fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(12, 6))
# One panel per target, one line per station
for ax, target in ((ax1, 'PM2.5'), (ax2, 'PM10')):
    sns.lineplot(ax=ax, x='day', y=target, hue='station', data=raw_data)
    ax.tick_params(axis='x', labelrotation=70)
# A single legend (top panel) serves both panels
ax1.legend(loc='upper left', ncol=6)
ax2.get_legend().remove()
plt.tight_layout()

# Take aways: 
# - day of the month matters

In [None]:
# Evolution of pollution by weekday

fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(12, 6))
# One panel per target, one group of boxes per station, one box per weekday
for ax, target in ((ax1, 'PM2.5'), (ax2, 'PM10')):
    sns.boxplot(ax=ax, x='station', y=target, hue='weekday', data=raw_data)
    ax.tick_params(axis='x', labelrotation=0)
# A single legend (top panel) serves both panels
ax1.legend(loc='upper left', ncol=7)
ax2.get_legend().remove()
plt.tight_layout()

# Take aways: 
# - day of the week matters a little bit

In [None]:
# Outliers in pollution levels (PM10 < PM2.5)
pm10_below_pm25 = raw_data["PM10"] < raw_data["PM2.5"]
n_hard_outliers = int(pm10_below_pm25.sum())
# Share of the raw rows that are inconsistent
print(n_hard_outliers / raw_data.shape[0])

# Take away:
# 4% of samples are inconsistent outliers (PM2.5 are included in PM10)
# cf. https://aqicn.org/faq/2013-02-02/why-is-pm25-often-higher-than-pm10/

In [None]:
# Evolution of pollution by hour of the day

fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(12, 4))
# One panel per target, one line per station
for ax, target in ((ax1, 'PM2.5'), (ax2, 'PM10')):
    sns.lineplot(ax=ax, x='hour', y=target, hue='station', data=raw_data)
    ax.tick_params(axis='x', labelrotation=0)
# A single legend (top panel) serves both panels
ax1.legend(loc='upper left', ncol=6)
ax2.get_legend().remove()
plt.tight_layout()

# Take aways: 
# - hour of the day matters, differs by the station name

In [None]:
# Numerical variables: each target regressed against each numerical feature

pollutant_targets = ["PM2.5", "PM10"]
numerical_features = ['SO2', 'NO2', 'CO', 'O3', 'TEMP', 'PRES', 'DEWP', 'RAIN','WSPM']
# Red regression line over faint points
# (alternative tried: kind="scatter" with plot_kws={'alpha': 0.25}, optionally hue="station")
regression_plot_kws = {
    'line_kws': {'color': 'red'},
    'scatter_kws': {'alpha': 0.1},
}
sns.pairplot(
    raw_data,
    kind="reg",
    y_vars=pollutant_targets,
    x_vars=numerical_features,
    plot_kws=regression_plot_kws,
)

# Take away:
# - NO2 and CO correlate well with PM concentration (should be salient in attention maps)
# - Maybe some outliers clusters in NO2, CO and O3

In [None]:
# Comparison of raw vs. preprocessed numerical features
# One row per feature: raw data on the left, preprocessed data on the right.

num_features = ['SO2', 'NO2', 'CO', 'O3', 'TEMP', 'PRES', 'DEWP', 'RAIN','WSPM']
fig, axes = plt.subplots(nrows=len(num_features), ncols=2, figsize=(16, 10))

for (raw_ax, preprocessed_ax), feature in zip(axes, num_features):
    for ax, frame in ((raw_ax, raw_data), (preprocessed_ax, preprocessed_data)):
        # (sns.stripplot was considered for the preprocessed side)
        sns.histplot(ax=ax, stat="percent", x=feature, data=frame)
plt.tight_layout()

# Take away:
# - need to compare various encoding schemes for numerical variables, e.g., 
# - Gorishniy et al.
# - Binning (quantization, decision-tree leaf, hand-made)
# - SAX + anchor (e.g., mean or starting value)
# - TabPFN

In [None]:
# From TabBERT
def _quantization_binning(nbins, data):
    qtls = np.arange(0.0, 1.0 + 1 / nbins, 1 / nbins)
    bin_edges = np.quantile(data, qtls, axis=0)
    bin_widths = np.diff(bin_edges, axis=0)
    bin_centers = bin_edges[:-1] + bin_widths / 2
    return bin_edges, bin_centers, bin_widths

def _quantize(nbins, inputs, bin_edges):
    quant_inputs = np.zeros(inputs.shape[0])
    for i, x in enumerate(inputs):
        quant_inputs[i] = np.digitize(x, bin_edges)
    quant_inputs = quant_inputs.clip(1, nbins) - 1
    return quant_inputs

# How often does each RAIN value occur (counted through the "hour" column)?
# sns.histplot(raw_data["RAIN"], binwidth=1)
t = raw_data.groupby("RAIN")[["hour"]].count()
print(t)

# From KBinsDiscretizer: 50 equal-width bins over RAIN
kbd = KBinsDiscretizer(n_bins=50, strategy="uniform", encode="ordinal", subsample=None)
rain_column = raw_data[["RAIN"]].to_numpy()  # shape (n_rows, 1)
test = kbd.fit_transform(rain_column)

print(kbd.bin_edges_)
# Number of rows falling into each occupied bin
print(np.unique(test[:, 0], return_counts=True))

In [None]:
# Wind direction (wd) categorical variable data exploration

# Swarm plots take too long on the full data, so a sub-sample is plotted. It is drawn once,
# with a fixed seed, so that both figures show the same rows and the cell is reproducible;
# the size is capped so that frames with fewer than 10,000 rows still work.
wd_sample = raw_data.sample(min(10_000, len(raw_data)), random_state=0)

# fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(12, 4))
for target in ('PM2.5', 'PM10'):
    sns.catplot(
        x='wd',
        y=target,
        # hue='station',
        kind="swarm",
        data=wd_sample,
        s=2,
    )

# Take-away:
# - WD alone does not seem very informative about the targets