## 3. Model Training

#### A360 MDK interface

In [1]:
# Display the A360 AI interface object for the current project.
# NOTE(review): no import of `a360ai` appears in this notebook — presumably the
# platform injects it into the kernel; confirm before running elsewhere.
a360ai

<A360 AI Interface for project: Product Demand Forecasting>

In [2]:
# Get default data repo
# The listing is indexed by column 'name' and then by label 0, so the entry
# labelled 0 is used as the default repo (assumes a DataFrame-like return with
# a default integer index — TODO confirm).
DATAREPO_LIST = a360ai.list_datarepos()
DATAREPO = DATAREPO_LIST['name'][0]
DATAREPO

'Product Demand Forecasting'

In [3]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score
from matplotlib import pyplot as plt

In [4]:
import warnings
# Suppress all warnings for the rest of the session.
# NOTE(review): this hides every warning category, including deprecations —
# consider narrowing the filter to the specific warnings that are noisy.
warnings.filterwarnings('ignore')

### 1. Set Default Data Repo and load data

In [5]:
# Register the repo selected above as the default data repo
# (presumably the one the dataset calls below operate on — confirm).
a360ai.set_default_datarepo(DATAREPO)

In [6]:
# Show the datasets a360ai reports; the X.parquet / y.parquet entries are the
# files loaded in the next cell.
a360ai.list_datasets()

Unnamed: 0,base_name,extension,size
1,X.csv,csv,97672.0
2,X.parquet,parquet,20323.0
3,product-demand-2015-2020-a.csv,csv,44133.0
4,rf_model.pkl,pkl,11465124.0
5,y.csv,csv,10765.0
6,y.parquet,parquet,5080.0


In [7]:
# Load the feature table (X) and the target table (y).
# Later cells read the target from y's 'sales' column.
X = a360ai.load_dataset("X.parquet")
y = a360ai.load_dataset("y.parquet")

#### Filter Train/Val Set to only include 2015-2019 (pre-COVID)
- Train: 2015 Jan - 2019 Sep
- Val: 2019 Oct - Dec

In [7]:
# Training period: the first 1734 rows (2015 Jan - 2019 Sep).
# Both X and y are sliced by position so they stay row-aligned without
# depending on the index labels or on the total number of rows in X.
df_X_train = X.iloc[:1734]
df_y_train = y['sales'].iloc[:1734]

In [8]:
# Sanity check: the feature and target splits should have the same row count.
df_X_train.shape, df_y_train.shape

((1734, 20), (1734,))

In [9]:
# Validation period: row positions 1734-1824 (2019 Oct - Dec), re-indexed from 0.
# Explicit positional slicing (.iloc, end-exclusive) is used for both X and y so
# the same rows are selected regardless of the index labels.
df_X_val = X.iloc[1734:1825].reset_index(drop=True)
df_y_val = y['sales'].iloc[1734:1825].reset_index(drop=True)

In [10]:
# Sanity check: the feature and target splits should have the same row count.
df_X_val.shape, df_y_val.shape

((91, 20), (91,))

#### Set up data for COVID period (2020 Jan - April) to study potential drift

In [11]:
# COVID period: every row from position 1825 onward (2020 Jan - April),
# re-indexed from 0. Explicit positional slicing (.iloc) is used for both X and
# y so the same rows are selected regardless of the index labels.
df_X_covid = X.iloc[1825:].reset_index(drop=True)
df_y_covid = y['sales'].iloc[1825:].reset_index(drop=True)

In [12]:
# Sanity check: the feature and target splits should have the same row count.
df_X_covid.shape, df_y_covid.shape

((121, 20), (121,))

### 2. Model training

In [13]:
# Hold out 15% of the pre-COVID rows as a test set, with a fixed seed so the
# split is repeatable.
# NOTE(review): the split is shuffled rather than chronological — confirm this
# is intended for a time-indexed dataset.
X_train, X_test, y_train, y_test = train_test_split(
    df_X_train,
    df_y_train,
    test_size=0.15,
    random_state=0,
)

#### Use a360ai MDK to track model experiments

In [14]:
# Fetch the tracked regression model by name (created if it does not exist yet,
# going by the method name).
model = a360ai.get_or_create_model(model_name="demand_forecast_demo_6_9", model_type="regression")

In [15]:
# Fetch (or create) the tracked experiment for the random-forest sweep.
# The training features and target are passed along with
# enable_drift_monitoring=True.
# NOTE(review): the experiment name starts with "cdemand" — possibly a typo for
# "demand"; left unchanged because the name identifies the existing experiment.
experiment = model.get_or_create_experiment(
    experiment_name="cdemand_forecast_demo_RF_6_9",
    model_flavor="sklearn",
    enable_drift_monitoring=True,
    train_features=X_train,
    train_target=y_train,
    feature_names=list(X.columns),
    # Notebooks recorded with the experiment for lineage.
    data_exploration_file="/home/jovyan/01_exploratory-data-analysis.ipynb",
    data_preparation_file="/home/jovyan/02_data-preprocessing.ipynb",
    model_training_file="/home/jovyan/03_model-training.ipynb",
)


//-- Experiment Loaded --//
Model Name: demand_forecast_demo_6_9
Experiment Name: cdemand_forecast_demo_RF_6_9
Final Run Id: None
Model Flavor: sklearn
Input Signature: trend-index float, day_of_month big_integer, day_of_year big_integer, week_of_year big_integer, year big_integer, is_wknd big_integer, is_month_start big_integer, is_month_end big_integer, month_1 small_integer, month_2 small_integer, month_3 small_integer, month_4 small_integer, month_5 small_integer, month_6 small_integer, month_7 small_integer, month_8 small_integer, month_9 small_integer, month_10 small_integer, month_11 small_integer, month_12 small_integer
Output Signature: sales float
Data Exploration File: /home/jovyan/01_exploratory-data-analysis.ipynb
Data Preparation File: /home/jovyan/02_data-preprocessing.ipynb
Model Training File: /home/jovyan/03_model-training.ipynb
Drift Monitoring Enabled: True



In [16]:
# List the experiments registered under this model and display them.
experiments = model.list_experiments()
experiments

Unnamed: 0,id,bestRun,model_id,experiment_name,best_run_id,model_flavor,input_signature,output_signature,data_exploration_file,data_preparation_file,model_training_file,baseline,train_shape,model_name,updated_at
0,9e60ec24-5bdd-4d1b-bd20-bfc4986a295e,,c789c008-c7cc-4f7a-b4ac-ce9378b9c6f0,cdemand_forecast_demo_RF_6_9,,sklearn,"trend-index float, day_of_month big_integer, d...",sales float,/home/jovyan/01_exploratory-data-analysis.ipynb,/home/jovyan/02_data-preprocessing.ipynb,/home/jovyan/03_model-training.ipynb,"{'trend-index': {'mean': 36.36070604209099, 's...",[20],demand_forecast_demo_6_9,2022-06-09 16:14


In [17]:
# Train one random forest per n_estimators value and log each as a tracked run.
# Every fitted model is kept so that, after the sweep, `rf` can be bound to the
# model with the highest test score (the criterion used below to pick the final
# run) instead of to whichever model happened to be trained last.
fitted_runs = []
for run_number, n in enumerate([25, 50, 75, 100, 125, 150], start=1):
    with experiment.run_experiment() as run:

        hyperparams = {
            "n_estimators": n,
            "max_depth": 6,
        }

        # Fixed seed (same value as the train/test split) so re-running the
        # notebook reproduces the scores. It is kept out of `hyperparams` so
        # the set of logged hyperparameters is unchanged.
        candidate = RandomForestRegressor(random_state=0, **hyperparams)
        candidate.fit(X_train, y_train)

        # RandomForestRegressor.score returns the R^2 of the predictions.
        metrics = {
            "train_score": candidate.score(X_train, y_train),
            "test_score": candidate.score(X_test, y_test),
        }

        run.log_metrics(metrics)
        run.log_hyperparameters(hyperparams)
        run.log_model(candidate)

        fitted_runs.append((metrics["test_score"], candidate))
        print(f"Run {run_number} Complete!")

# max() returns the first maximal entry, so ties go to the earlier run.
_, rf = max(fitted_runs, key=lambda scored: scored[0])

Run 1 Complete!
Run 2 Complete!
Run 3 Complete!
Run 4 Complete!
Run 5 Complete!
Run 6 Complete!


In [18]:
# Fetch the runs logged for this experiment (one row per run, with its metrics
# and hyperparameters) and display them.
runs = experiment.list_runs()
runs

Unnamed: 0,id,dataset,artifact_paths,metric_test_score,metric_train_score,hyperparameter_max_depth,hyperparameter_n_estimators,metadata_run_time
0,f3903739-5459-417e-9aa6-7bad4a8f5527,de7922d4-2b3b-4721-883f-605ab14a1954,[{'id': 'd9c76eb1-166f-4b02-84c8-ef0630d3c606'...,0.874706,0.873968,6,25,0.081347
1,089cb908-6d3f-480a-9842-7fd694b161ea,de7922d4-2b3b-4721-883f-605ab14a1954,[{'id': '4640798c-1b43-4597-976b-e1492fb291f2'...,0.875468,0.873814,6,50,0.154114
2,786f754e-9e42-4ec8-be7b-8f043a549ac7,de7922d4-2b3b-4721-883f-605ab14a1954,[{'id': 'baecb487-387c-44f0-beb1-06ecaaed9d6c'...,0.876737,0.877087,6,75,0.225202
3,25a8f488-90e4-45b5-8124-5f5a2660842e,de7922d4-2b3b-4721-883f-605ab14a1954,[{'id': '34b13e1e-ad83-481b-9cf4-aa87561832fe'...,0.875668,0.873892,6,100,0.293164
4,3ebfbb29-d791-4c92-a511-a9cd04c49735,de7922d4-2b3b-4721-883f-605ab14a1954,[{'id': 'a25bca90-9d48-4f12-af01-37bde121089e'...,0.875978,0.876713,6,125,0.369278
5,582e99d5-b366-48c4-a51c-765527d51f51,de7922d4-2b3b-4721-883f-605ab14a1954,[{'id': '6bdeaf14-4f74-4e77-b816-2da36aad6df3'...,0.876782,0.875965,6,150,0.420461


In [19]:
# Show the runs ranked by test score, best first.
runs.sort_values(by='metric_test_score', ascending=False)

Unnamed: 0,id,dataset,artifact_paths,metric_test_score,metric_train_score,hyperparameter_max_depth,hyperparameter_n_estimators,metadata_run_time
5,582e99d5-b366-48c4-a51c-765527d51f51,de7922d4-2b3b-4721-883f-605ab14a1954,[{'id': '6bdeaf14-4f74-4e77-b816-2da36aad6df3'...,0.876782,0.875965,6,150,0.420461
2,786f754e-9e42-4ec8-be7b-8f043a549ac7,de7922d4-2b3b-4721-883f-605ab14a1954,[{'id': 'baecb487-387c-44f0-beb1-06ecaaed9d6c'...,0.876737,0.877087,6,75,0.225202
4,3ebfbb29-d791-4c92-a511-a9cd04c49735,de7922d4-2b3b-4721-883f-605ab14a1954,[{'id': 'a25bca90-9d48-4f12-af01-37bde121089e'...,0.875978,0.876713,6,125,0.369278
3,25a8f488-90e4-45b5-8124-5f5a2660842e,de7922d4-2b3b-4721-883f-605ab14a1954,[{'id': '34b13e1e-ad83-481b-9cf4-aa87561832fe'...,0.875668,0.873892,6,100,0.293164
1,089cb908-6d3f-480a-9842-7fd694b161ea,de7922d4-2b3b-4721-883f-605ab14a1954,[{'id': '4640798c-1b43-4597-976b-e1492fb291f2'...,0.875468,0.873814,6,50,0.154114
0,f3903739-5459-417e-9aa6-7bad4a8f5527,de7922d4-2b3b-4721-883f-605ab14a1954,[{'id': 'd9c76eb1-166f-4b02-84c8-ef0630d3c606'...,0.874706,0.873968,6,25,0.081347


#### Get the best run from model experiments

In [20]:
# Rank the runs by test score (best first) and take the id of the top row.
ranked_runs = runs.sort_values(by="metric_test_score", ascending=False)
best_run_id = ranked_runs["id"].iloc[0]
best_run_id

'582e99d5-b366-48c4-a51c-765527d51f51'

In [21]:
# Mark the best-scoring run as the final run of this experiment on the model.
model.set_final_run(experiment, best_run_id)

#### Prediction for Val data: 2019 Oct-Dec

In [None]:
# Predict sales for the validation period.
# NOTE(review): `rf` is the estimator left bound by the training loop, not a
# model loaded from the run marked as final — confirm they are the same model.
y_pred_val = rf.predict(df_X_val)

In [None]:
# Predicted vs. observed sales over the validation period.
fig, ax = plt.subplots(figsize=(20, 5))
ax.plot(y_pred_val, label='Pred')
ax.plot(df_y_val, label='Obs')
ax.set_title('Product Demand Forecast')
# Both series are drawn against their 0-based row position within the period,
# not against the `day_of_month` feature.
ax.set_xlabel('Day index within period')
ax.set_ylabel("Sales")
ax.legend()
# Run tight_layout last so the title and axis labels are included in the fit.
fig.tight_layout()
plt.show()

#### Model prediction for COVID data: 2020 Jan-April

In [None]:
# Predict sales for the COVID period with the model trained on pre-COVID data.
# NOTE(review): `rf` is the estimator left bound by the training loop, not a
# model loaded from the run marked as final — confirm they are the same model.
y_pred_covid = rf.predict(df_X_covid)

In [None]:
# Predicted vs. observed sales over the COVID period.
fig, ax = plt.subplots(figsize=(20, 5))
ax.plot(y_pred_covid, label='Pred')
ax.plot(df_y_covid, label='Obs')
ax.set_title('Product Demand Forecast')
# Both series are drawn against their 0-based row position within the period,
# not against the `day_of_month` feature.
ax.set_xlabel('Day index within period')
ax.set_ylabel("Sales")
ax.legend()
# Run tight_layout last so the title and axis labels are included in the fit.
fig.tight_layout()
plt.show()