## Model training

In [1]:
# base modules
from pathlib import Path
import math

# for manipulating data
import numpy as np
import pandas as pd
pd.set_option('display.max_columns', None)

# for dataviz
from matplotlib import pyplot as plt
import seaborn as sns
%matplotlib inline

# for Machine Learning
from sklearn.ensemble import RandomForestRegressor, AdaBoostRegressor, GradientBoostingRegressor
from catboost import CatBoostRegressor, Pool, cv, CatBoostClassifier
from sklearn.model_selection import cross_val_score, KFold, GridSearchCV, train_test_split
from sklearn import metrics
from xgboost import XGBRegressor, plot_tree  # pip install xgboost
import optuna

# custom functions
from eclyon.transforms import process_df

In [2]:
df = pd.read_csv('cleaned_restaurant_inspection_data_3.csv')
pd.concat([df.head(), df.tail()])

  df = pd.read_csv('cleaned_restaurant_inspection_data_3.csv')


Unnamed: 0,CAMIS,DBA,BORO,BUILDING,STREET,ZIPCODE,PHONE,CUISINE DESCRIPTION,ACTION,VIOLATION CODE,VIOLATION DESCRIPTION,CRITICAL FLAG,SCORE,INSPECTION TYPE,Latitude,Longitude,Community Board,Council District,Census Tract,BIN,BBL,NTA,Location,days_since_last,avg_last_3_scores,inspection_year,inspection_month,inspection_weekday
0,30075445,MORRIS PARK BAKE SHOP,Bronx,1007,MORRIS PARK AVENUE,10462.0,7188924968,Bakery Products/Desserts,Establishment Closed by DOHMH. Violations were...,06D,"Food contact surface not properly washed, rins...",Critical,21.0,Cycle Inspection / Initial Inspection,40.848231,-73.855972,211.0,13.0,25200.0,2045445.0,2041270000.0,BX37,POINT (-73.855971889932 40.848231224526),730.0,25.02185,2023,1,1
1,30075445,MORRIS PARK BAKE SHOP,Bronx,1007,MORRIS PARK AVENUE,10462.0,7188924968,Bakery Products/Desserts,Establishment Closed by DOHMH. Violations were...,08C,Pesticide not properly labeled or used by unli...,Not Critical,21.0,Cycle Inspection / Initial Inspection,40.848231,-73.855972,211.0,13.0,25200.0,2045445.0,2041270000.0,BX37,POINT (-73.855971889932 40.848231224526),0.0,21.0,2023,1,1
2,30075445,MORRIS PARK BAKE SHOP,Bronx,1007,MORRIS PARK AVENUE,10462.0,7188924968,Bakery Products/Desserts,Establishment Closed by DOHMH. Violations were...,04L,Evidence of mice or live mice in establishment...,Critical,21.0,Cycle Inspection / Initial Inspection,40.848231,-73.855972,211.0,13.0,25200.0,2045445.0,2041270000.0,BX37,POINT (-73.855971889932 40.848231224526),0.0,21.0,2023,1,1
3,30075445,MORRIS PARK BAKE SHOP,Bronx,1007,MORRIS PARK AVENUE,10462.0,7188924968,Bakery Products/Desserts,Establishment Closed by DOHMH. Violations were...,06C,"Food, supplies, and equipment not protected fr...",Critical,21.0,Cycle Inspection / Initial Inspection,40.848231,-73.855972,211.0,13.0,25200.0,2045445.0,2041270000.0,BX37,POINT (-73.855971889932 40.848231224526),0.0,21.0,2023,1,1
4,30075445,MORRIS PARK BAKE SHOP,Bronx,1007,MORRIS PARK AVENUE,10462.0,7188924968,Bakery Products/Desserts,Establishment Closed by DOHMH. Violations were...,10F,Non-food contact surface or equipment made of ...,Not Critical,21.0,Cycle Inspection / Initial Inspection,40.848231,-73.855972,211.0,13.0,25200.0,2045445.0,2041270000.0,BX37,POINT (-73.855971889932 40.848231224526),0.0,21.0,2023,1,1
275284,50178121,DUANE PARK PATISSERIE,Manhattan,179,DUANE STREET,10013.0,9178627592,Bakery Products/Desserts,Violations were cited in the following area(s).,10F,Non-food contact surface or equipment made of ...,Not Critical,13.0,Pre-permit (Non-operational) / Initial Inspection,40.717399,-74.010137,101.0,1.0,3900.0,1077430.0,1001438000.0,MN24,POINT (-74.01013680549 40.717399104808),0.0,13.0,2025,10,2
275285,50178178,HISPANIC DELI GROCERY,Brooklyn,4916,4 AVENUE,11220.0,9298131664,Other,Violations were cited in the following area(s).,10F,Non-food contact surface or equipment made of ...,Not Critical,8.0,Pre-permit (Non-operational) / Initial Inspection,40.64688,-74.012158,307.0,38.0,7800.0,3012890.0,3007820000.0,BK32,POINT (-74.012158467614 40.646880154905),730.0,25.02185,2025,10,0
275286,50178178,HISPANIC DELI GROCERY,Brooklyn,4916,4 AVENUE,11220.0,9298131664,Other,Violations were cited in the following area(s).,10G,Dishwashing and ware washing: Cleaning and san...,Not Critical,8.0,Pre-permit (Non-operational) / Initial Inspection,40.64688,-74.012158,307.0,38.0,7800.0,3012890.0,3007820000.0,BK32,POINT (-74.012158467614 40.646880154905),0.0,8.0,2025,10,0
275287,50178190,BARKER CAFETERIA,Brooklyn,395,NOSTRAND AVENUE,11216.0,6462504215,Coffee/Tea,Violations were cited in the following area(s).,06C,"Food, supplies, or equipment not protected fro...",Critical,7.0,Pre-permit (Non-operational) / Initial Inspection,40.684235,-73.950336,303.0,36.0,24900.0,3051618.0,3018230000.0,BK75,POINT (-73.95033642083 40.684234700485),730.0,25.02185,2025,11,0
275288,50178194,BAGEL SCHMAGEL,Brooklyn,7510,3 AVENUE,11209.0,9178061856,Bagels/Pretzels,Violations were cited in the following area(s).,10F,Non-food contact surface or equipment made of ...,Not Critical,2.0,Pre-permit (Non-operational) / Initial Inspection,40.631304,-74.027867,310.0,47.0,6600.0,3148681.0,3059390000.0,BK31,POINT (-74.027867134865 40.631303517933),730.0,25.02185,2025,10,0


In [3]:
X, y, nas = process_df(df, y_field = 'SCORE')

In [4]:
X_train, X_valid, y_train, y_valid = train_test_split(
    X, y, test_size=0.2, random_state=42
)

print("Training set size:", X_train.shape)
print("validation set size:", X_valid.shape)

Training set size: (220231, 35)
validation set size: (55058, 35)


XG boost model

In [12]:
def objective(trial, X_training, y_training, X_validation, y_validation):
    """
    This function estimates a trial quality.
    It contains the 4 steps : hp sampling, model definition, training and evaluation.
    
    Parameters:
    ----------
    trial (optuna.trial.Trial): the trial at stake
    X_training: pd.DataFrame. Table containing features used for training.
    X_validation: pd.DataFrame. Table containing features used for validation.
    y_training: array-like. List of true target variable values for training.
    y_validation: array-like. List of true target variable values for validation.

    Returns:
    -------
    float: performance metric value (R^2)
    """
    
    learning_rate_guess = trial.suggest_float("learning_rate", 1e-2, 1, log=True)
    max_depth_guess = trial.suggest_int("max_depth", 2, 10, log=False)
    reg_lambda_guess = trial.suggest_int("reg_lambda", 1, 50, log=False)
    
    xgb = XGBRegressor(
        n_estimators=1000,                    
        early_stopping_rounds=None,           
        random_state=42,
        learning_rate=learning_rate_guess,       
        max_depth=max_depth_guess, 
        reg_lambda=reg_lambda_guess,
        verbosity=0             
    )
    
    xgb.fit(X_training, y_training)
    
    y_pred_validation = xgb.predict(X_validation)
    r2_validation = metrics.r2_score(y_validation, y_pred_validation)
    return r2_validation