In [1]:
import numpy as np
import pandas as pd
import os
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV

In [2]:
##########################################################################
# Important Note: This notebook only uses one csv file to test the model #
##########################################################################

In [3]:
# Locate the processed CSV files under ./data (relative to the current
# working directory) and load one of them into `data`.
cwd = os.getcwd()
data_dir = os.path.join(cwd, "data/")

# os.listdir returns entries in arbitrary order, so sort the names to make
# data_files[0] refer to the same file on every run and on every machine.
# Match on ".csv" (with the dot) so only names with that extension qualify.
data_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.csv'))

# Fail with an explicit message instead of a bare IndexError on data_files[0].
if not data_files:
    raise FileNotFoundError(f"No .csv files found in {data_dir}")

# Only the first file (in sorted order) is used in this notebook.
data = pd.read_csv(os.path.join(data_dir, data_files[0]))
data

Unnamed: 0,id,tdrift,tdrift50,tdrift10,rea,dcr,peakindex,peakvalue,tailslope,currentamp,lfpr,lq80,areagrowthrate,inflection point,risingedgeslope,energylabel,highavse,lowavse,truedcr,lq
0,650000,75.924,38.0,7.6,0.306251,5.647056e+05,1040,1236.0,-0.107739,0.006022,0.016692,137120.0,-525788.0,332,20.697813,449.096599,True,True,True,True
1,650001,82.917,41.5,8.3,0.610983,5.467540e+05,1049,1341.0,-0.110221,0.003851,0.018183,214469.0,-536552.0,337,17.853155,472.905768,True,True,True,True
2,650002,115.884,58.0,11.6,1.585428,2.801101e+05,1052,775.0,-0.056455,0.005939,0.023463,188754.0,-268393.5,337,4.624622,240.859403,True,True,True,True
3,650003,128.871,64.5,12.9,-0.471883,8.859788e+04,1068,242.0,-0.012898,0.004452,0.010622,109347.0,-64318.0,360,1.477667,56.433683,True,True,True,True
4,650004,95.904,48.0,9.6,0.903343,1.759048e+06,1060,3802.0,-0.358300,0.006815,0.019583,203822.0,-1714558.5,288,46.383844,1512.367517,True,True,True,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1689996,649995,103.896,52.0,10.4,1.386256,2.888775e+05,1037,657.0,-0.055199,0.006327,0.021953,78290.0,-265790.5,308,5.563203,238.264153,True,True,True,True
1689997,649996,44.955,22.5,4.5,-0.555437,8.694774e+05,1023,1930.0,-0.169472,0.007386,0.017672,185348.0,-848319.5,271,50.753887,727.778128,False,False,True,False
1689998,649997,165.834,83.0,16.6,1.516233,2.478339e+05,1066,676.0,-0.050436,0.004362,0.024428,171916.0,-217168.0,300,2.899391,206.356597,True,True,True,True
1689999,649998,95.904,48.0,9.6,0.647120,3.457733e+06,1073,6729.0,-0.701681,0.005781,0.018728,348388.0,-3073543.0,342,85.845334,2615.179984,True,True,True,True


In [11]:


# Label columns reserved for the classification group; excluded from the regression frame
boolean_col = ['highavse','lowavse','truedcr','lq']

# Build the regression frame in a single chain, in this order:
#   1. drop the classification labels and the id column
#   2. drop every row that still contains a missing value
#   3. drop tdrift50 / tdrift10 (irrelevant / multicollinear features)
data_filtered = (
    data
    .drop(columns=[*boolean_col, 'id'])
    .dropna()
    .drop(columns=['tdrift50', 'tdrift10'])
)

# energylabel is the regression target; every remaining column is a predictor
y = data_filtered['energylabel']
X = data_filtered.drop(columns=['energylabel'])

In [17]:
# Train test split + standardization
# Fix the seed so the split (and every metric computed from it) is the same
# on each Restart & Run All; without it every run scores a different random
# 30% test set.
RANDOM_STATE = 42
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=RANDOM_STATE
)

# Fit the scaler on the training split only, then reuse it for the test split
# so no test-set statistics leak into the model.
scaler = StandardScaler()
X_train_standardized = scaler.fit_transform(X_train)
X_test_standardized = scaler.transform(X_test)
# NOTE(review): the scaler is fit on the whole training split before
# cross-validation, so each CV validation fold has already influenced the
# scaling. Wrapping scaler + Ridge in a Pipeline would scale per fold.

# GridCV to find best lambda
alpha_range = np.logspace(-5, 2.5, 100)
param_grid = {"alpha": alpha_range}
ridge = Ridge()
grid_search = GridSearchCV(
    estimator=ridge,
    param_grid=param_grid,  # reuse the dict defined above instead of rebuilding it
    cv=5,
    scoring='neg_mean_squared_error',
)

# Fit the model
grid_search.fit(X_train_standardized, y_train)

best_alpha = grid_search.best_params_['alpha']
print(f"lambda: {best_alpha}")
# NOTE(review): if best_alpha equals alpha_range[0] or alpha_range[-1], the
# optimum may lie outside the searched grid — consider widening the range.

# GridSearchCV (refit=True by default) has already refit a Ridge with the best
# alpha on the full training split, so reuse it rather than fitting again.
ridge_reg = grid_search.best_estimator_
y_pred = ridge_reg.predict(X_test_standardized)

# Evaluate the model on the held-out test split
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

lambda: 1e-05


In [18]:
# Show the test-set mean squared error and R^2 together as a tuple
(mse, r2)

(4904.623324733912, 0.9890781350334877)

In [19]:
# Mean absolute error of the Ridge predictions on the test split
np.mean(np.abs(y_test - y_pred))

16.361993487043588

In [16]:
# Pairwise Pearson correlation between all remaining features and the target
data_filtered.corr(method="pearson")

Unnamed: 0,tdrift,rea,dcr,peakindex,peakvalue,tailslope,currentamp,lfpr,lq80,areagrowthrate,inflection point,risingedgeslope,energylabel
tdrift,1.0,-0.353646,-0.119906,0.966816,-0.104132,0.112163,-0.12991,-0.287289,0.20041,0.119038,0.131842,-0.207463,-0.104336
rea,-0.353646,1.0,0.028737,-0.431403,0.029845,-0.025223,0.245087,0.649042,-0.261714,-0.038562,-0.289455,-0.069715,0.028187
dcr,-0.119906,0.028737,1.0,0.027008,0.997981,-0.996803,0.213819,0.089385,0.684834,-0.998767,-0.040027,0.917027,0.992463
peakindex,0.966816,-0.431403,0.027008,1.0,0.042621,-0.033552,-0.106819,-0.278013,0.317134,-0.027973,0.147367,-0.058978,0.042785
peakvalue,-0.104132,0.029845,0.997981,0.042621,1.0,-0.996964,0.198594,0.091994,0.711452,-0.999103,-0.045428,0.905183,0.994342
tailslope,0.112163,-0.025223,-0.996803,-0.033552,-0.996964,1.0,-0.201432,-0.0878,-0.704431,0.996307,0.03837,-0.902061,-0.991746
currentamp,-0.12991,0.245087,0.213819,-0.106819,0.198594,-0.201432,1.0,0.093334,-0.110651,-0.214353,-0.02607,0.244163,0.202547
lfpr,-0.287289,0.649042,0.089385,-0.278013,0.091994,-0.0878,0.093334,1.0,-0.078781,-0.094215,-0.240215,0.024977,0.092354
lq80,0.20041,-0.261714,0.684834,0.317134,0.711452,-0.704431,-0.110651,-0.078781,1.0,-0.683208,0.038102,0.554716,0.696343
areagrowthrate,0.119038,-0.038562,-0.998767,-0.027973,-0.999103,0.996307,-0.214353,-0.094215,-0.683208,1.0,0.046493,-0.913401,-0.994058


In [15]:
# Inspect the unscaled training feature frame returned by train_test_split
X_train

Unnamed: 0,tdrift,rea,dcr,peakindex,peakvalue,tailslope,currentamp,lfpr,lq80,areagrowthrate,inflection point,risingedgeslope
305058,116.883,0.609233,543485.40625,1070,1209.0,-0.106624,0.006071,0.021196,123293.0,-514048.0,304,12.729758
637456,97.902,0.621089,275333.50000,1057,674.0,-0.054388,0.003855,0.017666,115033.0,-264010.5,298,7.462171
133917,85.914,0.954826,301019.25000,1034,694.0,-0.056668,0.005541,0.021488,111789.0,-272781.0,342,7.995679
1041655,117.882,1.119602,407047.71875,1061,892.0,-0.080081,0.004780,0.021254,77162.0,-378864.0,286,7.815813
274326,110.889,-0.157120,557659.71875,1067,1197.0,-0.106893,0.004438,0.020023,126866.0,-508111.5,380,14.499245
...,...,...,...,...,...,...,...,...,...,...,...,...
14370,64.935,0.499625,265744.68750,1018,625.0,-0.048555,0.005932,0.018987,112919.0,-240992.0,361,10.656818
292682,209.790,-0.785466,241984.37500,1151,563.0,-0.041015,0.004920,0.015914,192199.0,-177678.0,345,2.519041
183641,72.927,0.127937,614589.84375,1042,1349.0,-0.119925,0.006110,0.018323,145372.0,-578553.0,318,23.982229
849596,97.902,1.050445,159639.59375,1031,411.0,-0.028415,0.005574,0.016779,107261.0,-138976.0,386,3.548273
