# LSTM Ray Training

## Data Preprocessing
1. target mro selection -- ["mro"] or ["sub_mro1", "sub_mro2", ...]
2. add previous mro
3. dealing with purchase time
1. Standardization
    - Continuous Features
    - Categorical Features


--- 

### Target MRO Selection

In [1]:
import pandas as pd

file_name = "./Data/mro_daily_clean.csv"
data = pd.read_csv(file_name, index_col=0, engine="pyarrow")

# control parameter: target_mro
# a list defined by user
target_mro: list = ["mro"]


mro_detail = [
    "battery_dummy",
    "brake_dummy",
    "tire_dummy",
    "lof_dummy",
    "wiper_dummy",
    "filter_dummy",
    "others",
]
if target_mro == ["mro"]:
    data["target_mro"] = data["mro"]
elif isinstance(target_mro, list) and all(col in mro_detail for col in target_mro):
    data["target_mro"] = data[target_mro].max(axis=1)
else:
    print("Target MRO is defined with error")
    print("Use the mro as default mro")
    target_mro = ["mro"]
    data["target_mro"] = data["mro"]

---

### Add Previous MRO

In [2]:
# control parameter: add_mro_prev
add_mro_prev: bool = True


if add_mro_prev:
    data.sort_values(by=["id", "yr_nbr", "week_nbr"], inplace=True)
    data["mro_prev"] = data.groupby("id")["mro"].shift(1)
    mro_prev = ["mro_prev"]
else:
    mro_prev = []

---

### Dealing with Purchase Time

In [3]:
# control parameter: add_purchase_time
add_purchase_time: bool = True


if add_purchase_time:
    data["purchase_month"] = data["purchase_mth_nbr"].astype(int)
    # devide into 2 bins: 1-6 is the first half, 7-12 is the second half
    data["purchase_half_year"] = pd.cut(
        data["purchase_month"], bins=[0, 6, 12], labels=["first_half", "second_half"]
    )

    data["purchase_time"] = (
        data["purchase_yr_nbr"].astype(int).astype(str)
        + "_"
        + data["purchase_half_year"].astype(str)
    )

    purchase_time = ["purchase_time"]
else:
    purchase_time = []

---

### Weekly Aggregation

In [4]:
continuous_variable = [
    "hard_braking",
    "hard_acceleration",
    "speeding_sum",
    "day_mileage",
    "engn_size",
    "est_hh_incm_prmr_cd",
    "purchaser_age_at_tm_of_purch",
    "tavg",
    "random_avg_traffic",
]

category_variable = [
    "gmqualty_model",
    "umf_xref_finc_gbl_trim",
    "input_indiv_gndr_prmr_cd",
] + purchase_time

driver_navigation = [
    "id",
    "yr_nbr",
    "mth_nbr",
    "week_nbr",
]

data = data[
    driver_navigation
    + continuous_variable
    + category_variable
    + mro_prev
    + ["target_mro"]
]

In [5]:
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from utils import create_train_test_group

# control parameter: aggregation function
agg_fun = ["mean", "sum", "max", "min", "std", "skew"]


agg_rules = {
    "mth_nbr": "first",
    "target_mro": "max",
    "hard_braking": agg_fun,
    "hard_acceleration": agg_fun,
    "speeding_sum": agg_fun,
    "day_mileage": agg_fun,
    "est_hh_incm_prmr_cd": "first",
    "purchaser_age_at_tm_of_purch": "first",
    "input_indiv_gndr_prmr_cd": "first",
    "gmqualty_model": "first",
    "umf_xref_finc_gbl_trim": "first",
    "engn_size": "first",
    "tavg": agg_fun,
    "random_avg_traffic": agg_fun,
}
if add_mro_prev:
    agg_rules["mro_prev"] = "max"
if add_purchase_time:
    agg_rules["purchase_time"] = "first"


data = data.groupby(["id", "yr_nbr", "week_nbr"]).agg(agg_rules)

data.reset_index(inplace=True)


def flatten_columns(df: pd.DataFrame):
    def clean_col(col):
        if isinstance(col, tuple):
            col_name, agg_func = col
            agg_func = agg_func.strip()
            if col_name in (["target_mro"] + mro_prev) and agg_func == "max":
                return col_name
            if agg_func in ("first", ""):
                return col_name
            return f"{col_name}_{agg_func}"
        else:
            return col

    df.columns = [clean_col(col) for col in df.columns]
    return df


data = flatten_columns(data)
data.fillna(0, inplace=True)
data = data.drop(["yr_nbr", "week_nbr", "mth_nbr"], axis=1)

data

Unnamed: 0,id,target_mro,hard_braking_mean,hard_braking_sum,hard_braking_max,hard_braking_min,hard_braking_std,hard_braking_skew,hard_acceleration_mean,hard_acceleration_sum,...,tavg_std,tavg_skew,random_avg_traffic_mean,random_avg_traffic_sum,random_avg_traffic_max,random_avg_traffic_min,random_avg_traffic_std,random_avg_traffic_skew,mro_prev,purchase_time
0,w4HClcKnwrzCv8KgwrjDi8Klwr3Cm8KVwqfCrsKowprClg==,0.0,1.000000,1,1,1,0.000000,0.000000,1.000000,1,...,0.000000,0.000000,12886.225115,12886.225115,12886.225115,12886.225115,0.000000,0.000000,0.0,2018_first_half
1,w4HClcKnwrzCv8KgwrjDi8Klwr3Cm8KVwqfCrsKowprClg==,0.0,0.500000,1,1,0,0.707107,0.000000,0.000000,0,...,0.013972,0.000000,14554.620499,29109.240997,14584.257628,14524.983370,41.913230,0.000000,0.0,2018_first_half
2,w4HClcKnwrzCv8KgwrjDi8Klwr3Cm8KVwqfCrsKowprClg==,0.0,6.500000,39,14,0,5.856620,0.058243,0.833333,5,...,0.032139,0.632224,14559.007102,87354.042609,14632.187360,14500.482353,49.103724,0.235793,0.0,2018_first_half
3,w4HClcKnwrzCv8KgwrjDi8Klwr3Cm8KVwqfCrsKowprClg==,0.0,8.571429,60,20,0,6.754187,0.564530,1.714286,12,...,0.436051,0.370157,14433.908044,101037.356310,14573.240317,14282.035211,119.119692,-0.006491,0.0,2018_first_half
4,w4HClcKnwrzCv8KgwrjDi8Klwr3Cm8KVwqfCrsKowprClg==,0.0,7.142857,50,22,0,8.395010,0.989005,1.428571,10,...,0.033094,-0.056727,14390.440682,100733.084776,14463.821187,14301.403433,64.521592,-0.267357,0.0,2018_first_half
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3972098,wrbCt8K1wrvDi8KtwrjCtMONwrvCrsKXwqbCqcKqwpzCnA==,0.0,5.000000,30,10,1,3.633180,0.150131,0.666667,4,...,0.027423,-0.143138,9881.400949,59288.405694,9977.887549,9821.257020,58.926955,0.867606,0.0,2018_first_half
3972099,wrbCt8K1wrvDi8KtwrjCtMONwrvCrsKXwqbCqcKqwpzCnA==,0.0,4.333333,13,9,0,4.509250,0.330832,0.333333,1,...,0.035240,1.139904,9743.572901,29230.718703,9790.153155,9690.645645,50.056462,-0.561412,0.0,2018_first_half
3972100,wrbCt8K1wrvDi8KtwrjCtMONwrvCrsKXwqbCqcKqwpzCnA==,0.0,2.600000,13,5,0,2.302173,0.196697,0.400000,2,...,0.027980,0.262250,9732.679965,48663.399826,9838.082886,9652.890848,69.868338,0.727995,0.0,2018_first_half
3972101,wrbCt8K1wrvDi8KtwrjCtMONwrvCrsKXwqbCqcKqwpzCnA==,0.0,1.166667,7,3,0,1.169045,0.667628,0.000000,0,...,6.938263,-2.449397,10792.873524,64757.241143,16161.000000,9681.779573,2629.969913,2.448921,0.0,2018_first_half


---

### Standardization

In [6]:
col_need_std = [
    item
    for item in data.columns.values.tolist()
    if item not in (["target_mro"] + mro_prev + ["id"] + category_variable)
]

col_need_encode = category_variable


scaler = StandardScaler()
data[col_need_std] = scaler.fit_transform(data[col_need_std])


encoder = OneHotEncoder(sparse_output=False)
encoded_categorical = encoder.fit_transform(data[col_need_encode])

category_counts = [len(encoder.categories_[i]) for i, _ in enumerate(col_need_encode)]

onehot_feature_names = []
for col_idx, col in enumerate(col_need_encode):
    num_categories = category_counts[col_idx]
    onehot_feature_names.extend([f"{col}_onehot_{i}" for i in range(num_categories)])

encoded_df = pd.DataFrame(
    encoded_categorical, index=data.index, columns=onehot_feature_names
)
data = pd.concat([data, encoded_df], axis=1)
data = data.drop(columns=col_need_encode)

data

Unnamed: 0,id,target_mro,hard_braking_mean,hard_braking_sum,hard_braking_max,hard_braking_min,hard_braking_std,hard_braking_skew,hard_acceleration_mean,hard_acceleration_sum,...,umf_xref_finc_gbl_trim_onehot_3,umf_xref_finc_gbl_trim_onehot_4,umf_xref_finc_gbl_trim_onehot_5,umf_xref_finc_gbl_trim_onehot_6,input_indiv_gndr_prmr_cd_onehot_0,input_indiv_gndr_prmr_cd_onehot_1,purchase_time_onehot_0,purchase_time_onehot_1,purchase_time_onehot_2,purchase_time_onehot_3
0,w4HClcKnwrzCv8KgwrjDi8Klwr3Cm8KVwqfCrsKowprClg==,0.0,-0.926556,-1.047504,-1.252611,-0.216080,-1.272027,-0.574747,-0.160523,-0.725081,...,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0
1,w4HClcKnwrzCv8KgwrjDi8Klwr3Cm8KVwqfCrsKowprClg==,0.0,-1.038839,-1.047504,-1.252611,-0.509626,-1.014112,-0.574747,-0.880683,-0.905059,...,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0
2,w4HClcKnwrzCv8KgwrjDi8Klwr3Cm8KVwqfCrsKowprClg==,0.0,0.308562,0.785086,0.615113,-0.509626,0.864153,-0.512343,-0.280550,-0.005170,...,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0
3,w4HClcKnwrzCv8KgwrjDi8Klwr3Cm8KVwqfCrsKowprClg==,0.0,0.773737,1.797833,1.477139,-0.509626,1.191537,0.030121,0.353877,1.254675,...,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0
4,w4HClcKnwrzCv8KgwrjDi8Klwr3Cm8KVwqfCrsKowprClg==,0.0,0.452927,1.315573,1.764481,-0.509626,1.790021,0.484926,0.148117,0.894719,...,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3972098,wrbCt8K1wrvDi8KtwrjCtMONwrvCrsKXwqbCqcKqwpzCnA==,0.0,-0.028288,0.351052,0.040428,-0.216080,0.053162,-0.413888,-0.400576,-0.185148,...,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0
3972099,wrbCt8K1wrvDi8KtwrjCtMONwrvCrsKXwqbCqcKqwpzCnA==,0.0,-0.177999,-0.468791,-0.103243,-0.509626,0.372705,-0.220276,-0.640630,-0.725081,...,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0
3972100,wrbCt8K1wrvDi8KtwrjCtMONwrvCrsKXwqbCqcKqwpzCnA==,0.0,-0.567249,-0.468791,-0.677927,-0.509626,-0.432318,-0.363996,-0.592619,-0.545103,...,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0
3972101,wrbCt8K1wrvDi8KtwrjCtMONwrvCrsKXwqbCqcKqwpzCnA==,0.0,-0.889128,-0.758148,-0.965269,-0.509626,-0.845622,0.140586,-0.880683,-0.905059,...,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0


---

In [9]:
# control parameter: sample_frac, test_size, valid_size
sample_frac = 1.0
test_size = 0.1
valid_size = 0.1


rnn_features = col_need_std + onehot_feature_names + mro_prev
rnn_target = ["target_mro"]
col_rnn_origin = ["id"] + rnn_features + rnn_target
data_rnn_origin = data[col_rnn_origin].copy()
data_rnn_origin = create_train_test_group(
    data_rnn_origin,
    sample_frac=sample_frac,
    test_size=test_size,
    valid_size=valid_size,
    random_state=42,
)

data_rnn_origin

Unnamed: 0,id,hard_braking_mean,hard_braking_sum,hard_braking_max,hard_braking_min,hard_braking_std,hard_braking_skew,hard_acceleration_mean,hard_acceleration_sum,hard_acceleration_max,...,umf_xref_finc_gbl_trim_onehot_6,input_indiv_gndr_prmr_cd_onehot_0,input_indiv_gndr_prmr_cd_onehot_1,purchase_time_onehot_0,purchase_time_onehot_1,purchase_time_onehot_2,purchase_time_onehot_3,mro_prev,target_mro,group
0,w4HClcKnwrzCv8KgwrjDi8Klwr3Cm8KVwqfCrsKowprClg==,-0.926556,-1.047504,-1.252611,-0.216080,-1.272027,-0.574747,-0.160523,-0.725081,-0.695076,...,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,train
1,w4HClcKnwrzCv8KgwrjDi8Klwr3Cm8KVwqfCrsKowprClg==,-1.038839,-1.047504,-1.252611,-0.509626,-1.014112,-0.574747,-0.880683,-0.905059,-1.131083,...,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,train
2,w4HClcKnwrzCv8KgwrjDi8Klwr3Cm8KVwqfCrsKowprClg==,0.308562,0.785086,0.615113,-0.509626,0.864153,-0.512343,-0.280550,-0.005170,1.048950,...,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,train
3,w4HClcKnwrzCv8KgwrjDi8Klwr3Cm8KVwqfCrsKowprClg==,0.773737,1.797833,1.477139,-0.509626,1.191537,0.030121,0.353877,1.254675,0.176937,...,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,train
4,w4HClcKnwrzCv8KgwrjDi8Klwr3Cm8KVwqfCrsKowprClg==,0.452927,1.315573,1.764481,-0.509626,1.790021,0.484926,0.148117,0.894719,0.612944,...,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,train
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3972098,wrbCt8K1wrvDi8KtwrjCtMONwrvCrsKXwqbCqcKqwpzCnA==,-0.028288,0.351052,0.040428,-0.216080,0.053162,-0.413888,-0.400576,-0.185148,-0.259070,...,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,test
3972099,wrbCt8K1wrvDi8KtwrjCtMONwrvCrsKXwqbCqcKqwpzCnA==,-0.177999,-0.468791,-0.103243,-0.509626,0.372705,-0.220276,-0.640630,-0.725081,-0.695076,...,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,test
3972100,wrbCt8K1wrvDi8KtwrjCtMONwrvCrsKXwqbCqcKqwpzCnA==,-0.567249,-0.468791,-0.677927,-0.509626,-0.432318,-0.363996,-0.592619,-0.545103,-0.259070,...,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,test
3972101,wrbCt8K1wrvDi8KtwrjCtMONwrvCrsKXwqbCqcKqwpzCnA==,-0.889128,-0.758148,-0.965269,-0.509626,-0.845622,0.140586,-0.880683,-0.905059,-1.131083,...,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,test


---

## Model Training
### Build the data loader

In [12]:
from model import mroRnnDataset
# ---------------------------------------------------------
max_seq_length = 8

train_data_set = mroRnnDataset(
    data_rnn_origin=data_rnn_origin,
    rnn_features=rnn_features,
    rnn_target=rnn_target,
    group="train",
    max_seq_length=max_seq_length,
)

val_data_set = mroRnnDataset(
    data_rnn_origin=data_rnn_origin,
    rnn_features=rnn_features,
    rnn_target=rnn_target,
    group="valid",
    max_seq_length=max_seq_length,
)

In [13]:
input_feature_size = len(rnn_features)
output_size = len(rnn_target)

In [15]:
alpha = 1 - data_rnn_origin["target_mro"].eq(1).mean()
print(f"Alpha value for Focal Loss: {alpha}")
gamma = 4
print(f"Gamma value for Focal Loss: {gamma}")

Alpha value for Focal Loss: 0.9523861793110602
Gamma value for Focal Loss: 4
