In [1]:
import mlflow

import pandas as pd
from pandas.tseries.holiday import USFederalHolidayCalendar as calendar

import xgboost as xgb

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.linear_model import LinearRegression , Lasso
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor

from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

import pickle

In [2]:
mlflow.set_tracking_uri("http://127.0.0.1:5000") 
mlflow.set_experiment("Electricity Demand Prediction")

<Experiment: artifact_location='mlflow-artifacts:/1', creation_time=1723467087565, experiment_id='1', last_update_time=1723467087565, lifecycle_stage='active', name='Electricity Demand Prediction', tags={}>

In [5]:
def read_dataframe(filename):
    df = pd.read_csv(filename)
    df['date'] = pd.to_datetime(df['date'])
    df['demand'] = pd.to_numeric(df['demand'], errors='coerce').astype('float')
    df['year'] = df['date'].dt.year
    df['month'] = df['date'].dt.month
    df['day'] = df['date'].dt.day
    df['hr'] = df['date'].dt.hour
    df['day_of_week'] = df['date'].dt.dayofweek  # Monday=0, Sunday=6
    df['is_weekend'] = df['date'].dt.dayofweek >= 5  # True if weekend, False otherwise
    
    holidays = calendar().holidays(start=df['date'].min(), end=df['date'].max())
    df['holiday'] = df['date'].isin(holidays).astype(int)
    #display(df.head())
    #print(df.dtypes)
    
    
    return df

In [6]:
dataset_path = "/workspaces/Electricity-Demand-Prediction/Model Training/Data/dataset.csv"

In [7]:
df = read_dataframe(dataset_path)

In [8]:
# Define features and target
X = df.drop('demand', axis=1)
X = X.drop('date', axis=1)

y = df['demand']

In [9]:
# Split the data into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

In [10]:
train = xgb.DMatrix(X_train, label=y_train)
valid = xgb.DMatrix(X_val, label=y_val)

In [13]:
def objective(params):
    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)
        mlflow.xgboost.autolog()
        booster = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=1000,
            evals=[(valid, 'validation')],
            early_stopping_rounds=50
        )
        y_pred = booster.predict(valid)
        rmse = mean_squared_error(y_val, y_pred, squared=False)
        mlflow.log_metric("rmse", rmse)

    return {'loss': rmse, 'status': STATUS_OK}

In [14]:
search_space = {
    'max_depth': scope.int(hp.quniform('max_depth', 4, 100, 1)),
    'learning_rate': hp.loguniform('learning_rate', -3, 0),
    'reg_alpha': hp.loguniform('reg_alpha', -5, -1),
    'reg_lambda': hp.loguniform('reg_lambda', -6, -1),
    'min_child_weight': hp.loguniform('min_child_weight', -1, 3),
    'objective': 'reg:linear',
    'seed': 42
}

best_result = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=Trials()
)

  0%|          | 0/50 [00:00<?, ?trial/s, best loss=?]




[0]	validation-rmse:2787.03254                        
[1]	validation-rmse:2469.83188                        
  0%|          | 0/50 [00:00<?, ?trial/s, best loss=?]




[2]	validation-rmse:2196.47983                        
[3]	validation-rmse:1959.51200                        
[4]	validation-rmse:1756.22405                        
[5]	validation-rmse:1580.97799                        
[6]	validation-rmse:1432.29101                        
[7]	validation-rmse:1304.16423                        
[8]	validation-rmse:1196.28346                        
[9]	validation-rmse:1103.89176                        
[10]	validation-rmse:1026.88423                       
[11]	validation-rmse:964.29367                        
[12]	validation-rmse:911.04022                        
[13]	validation-rmse:867.15222                        
[14]	validation-rmse:831.31282                        
[15]	validation-rmse:802.56106                        
[16]	validation-rmse:779.07012                        
[17]	validation-rmse:759.39568                        
[18]	validation-rmse:743.60944                        
[19]	validation-rmse:730.14773                        
[20]	valid





2024/08/12 13:36:43 INFO mlflow.tracking._tracking_service.client: 🏃 View run luminous-panda-257 at: http://127.0.0.1:5000/#/experiments/1/runs/0207f760d83b4fc48a74473b80057a6c.

2024/08/12 13:36:43 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/1.



  2%|▏         | 1/50 [00:27<22:32, 27.60s/trial, best loss: 629.9418010287991]





[0]	validation-rmse:2685.18919                                                 
[1]	validation-rmse:2298.99564                                                 
[2]	validation-rmse:1979.12943                                                 
[3]	validation-rmse:1715.92835                                                 
[4]	validation-rmse:1501.32157                                                 
[5]	validation-rmse:1326.79624                                                 
[6]	validation-rmse:1187.49574                                                 
[7]	validation-rmse:1075.50117                                                 
[8]	validation-rmse:986.78835                                                  
[9]	validation-rmse:914.40915                                                  
[10]	validation-rmse:860.62128                                                 
[11]	validation-rmse:816.53067                                                 
[12]	validation-rmse:782.68591          




2024/08/12 13:37:05 INFO mlflow.tracking._tracking_service.client: 🏃 View run respected-slug-726 at: http://127.0.0.1:5000/#/experiments/1/runs/7027400f385844b6bde797b80c83ca9c.

2024/08/12 13:37:05 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/1.



  4%|▍         | 2/50 [00:50<19:40, 24.58s/trial, best loss: 490.25478326898735]





[0]	validation-rmse:2249.30232                                                  
[1]	validation-rmse:1658.89978                                                  
[2]	validation-rmse:1291.95033                                                  
[3]	validation-rmse:1077.31386                                                  
[4]	validation-rmse:962.34974                                                   
[5]	validation-rmse:902.97061                                                   
[6]	validation-rmse:875.53997                                                   
[7]	validation-rmse:862.34148                                                   
[8]	validation-rmse:857.57323                                                   
[9]	validation-rmse:855.88674                                                   
[10]	validation-rmse:855.62285                                                  
[11]	validation-rmse:855.31805                                                  
[12]	validation-rmse:855.282





: 