In [1]:
import numpy as np
import pandas as pd
from pathlib import Path

# Locations of the raw arrays; each file name spells out the axis order of the array it stores.
features_path = Path('../data/features_(n_iterations, n_wells, n_dates, n_features).npy')
targets_path = Path('../data/targets_(n_iterations, n_wells, n_dates).npy')

f = np.load(features_path)
t = np.load(targets_path)

# Features carry one extra trailing axis; the leading three axes must match the targets.
n_iterations, n_wells, n_dates, n_features = f.shape
assert t.shape == f.shape[:3]

print(f.shape, t.shape, sep='\n')

# Targets must be fully observed; for features only report the percentage of missing entries.
assert not np.isnan(t).any()
nan_share = np.isnan(f).sum() / f.size
print(nan_share * 100, '%')

(241, 10, 114, 3)
(241, 10, 114)
15.263157894736842 %


In [2]:
def get_well(f, t, well):
    """Return the flattened feature matrix and target vector of a single well.

    Parameters
    ----------
    f : np.ndarray
        Feature array indexed as (iteration, well, date, feature).
    t : np.ndarray
        Target array indexed as (iteration, well, date).
    well : int
        Index along the second axis of ``f`` and ``t``.

    Returns
    -------
    well_f : np.ndarray
        Array of shape (n_iterations * n_dates, n_features).
    well_t : np.ndarray
        Array of shape (n_iterations * n_dates,), row-aligned with ``well_f``.
    """
    # Take the feature count from the data instead of hard-coding it, so the
    # helper keeps working when the number of features changes.
    n_features = f.shape[-1]
    well_f = f[:, well, :, :].reshape((-1, n_features))
    well_t = t[:, well, :].reshape(-1)
    return well_f, well_t


# Report the percentage of missing feature values separately for every well.
# Iterate over n_wells (unpacked from f.shape) rather than a hard-coded count.
for well in range(n_wells):
    well_f, well_t = get_well(f, t, well)
    print(f'well={well}: {np.isnan(well_f).mean() * 100:.1f}% NaN')

well=0: 89.5% NaN
well=1: 0.0% NaN
well=2: 0.0% NaN
well=3: 0.0% NaN
well=4: 5.3% NaN
well=5: 5.3% NaN
well=6: 15.8% NaN
well=7: 10.5% NaN
well=8: 10.5% NaN
well=9: 15.8% NaN


In [3]:
def get_train_test(f, t, well, train_size=0.7):
    """Build a train/test split for one well from its complete observations.

    The well's rows are taken from ``get_well``; every row with at least one
    NaN feature is discarded together with its target, and the remaining rows
    are split by ``train_test_split`` with ``random_state=0``.

    Returns
    -------
    tuple
        ``(x_train, x_test, y_train, y_test)``.
    """
    features, targets = get_well(f, t, well)

    # Keep only rows in which every feature is present.
    complete = ~np.isnan(features).any(axis=1)

    return tuple(
        train_test_split(
            features[complete],
            targets[complete],
            train_size=train_size,
            random_state=0,
        )
    )

In [4]:
from sklearn.linear_model import RidgeCV, LassoCV, LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

In [5]:
# Regularisation grid shared by the cross-validated models.
alphas = np.geomspace(1e-5, 1e5, 100)
models = [LinearRegression(), RidgeCV(alphas=alphas), LassoCV(alphas=alphas)]
scaler = StandardScaler()

# Column names of the summary table. They are kept in their own *_names
# variables so the per-well value lists built in the loop do not overwrite them.
models_names = [model.__class__.__name__ for model in models]
models_rmse_names = [f'{model_name} RMSE' for model_name in models_names]
models_intercept_names = [f'{model_name} INTERCEPT' for model_name in models_names]
models_coef_names = [f'{model_name} COEF' for model_name in models_names]
models_alpha_names = [f'{model_name} ALPHA' for model_name in models_names if model_name != 'LinearRegression']
columns = (
    ['well', 'n_observations', 'const PRED', 'const RMSE']
    + models_rmse_names
    + models_intercept_names
    + models_coef_names
    + models_alpha_names
)

# One row per well; the frame is built once after the loop instead of being grown row by row.
rows = []

for well in range(n_wells):
    x_train, x_test, y_train, y_test = get_train_test(f, t, well)

    # Baseline: always predict the mean of the training targets.
    const_prediction = y_train.mean()
    # np.sqrt(MSE) instead of squared=False, which was deprecated in
    # scikit-learn 1.4 and later removed.
    const_rmse = np.sqrt(mean_squared_error(y_test, np.full_like(y_test, const_prediction)))

    # Scaling does not depend on the model: fit on the training rows once per
    # well and reuse the transformed arrays for every model.
    x_train_scaled = scaler.fit_transform(x_train)
    x_test_scaled = scaler.transform(x_test)

    models_rmse = []
    models_intercept = []
    models_coef = []
    models_alpha = []

    for model in models:
        model.fit(x_train_scaled, y_train)
        rmse = np.sqrt(mean_squared_error(y_test, model.predict(x_test_scaled)))
        models_rmse.append(rmse)
        models_intercept.append(model.intercept_)
        models_coef.append(model.coef_)
        # Only the cross-validated models expose the selected alpha_.
        if not isinstance(model, LinearRegression):
            models_alpha.append(model.alpha_)

    rows.append(
        [well, x_train.shape[0], const_prediction, const_rmse]
        + models_rmse
        + models_intercept
        + models_coef
        + models_alpha
    )

results = pd.DataFrame(rows, columns=columns).set_index('well')
results

Unnamed: 0_level_0,n_observations,const PRED,const RMSE,LinearRegression RMSE,RidgeCV RMSE,LassoCV RMSE,LinearRegression INTERCEPT,RidgeCV INTERCEPT,LassoCV INTERCEPT,LinearRegression COEF,RidgeCV COEF,LassoCV COEF,RidgeCV ALPHA,LassoCV ALPHA
well,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
0,2024,4.719395,8.208672,8.079176,8.079134,8.07918,4.719395,4.719395,4.719395,"[0.7975920670553084, 0.0, 1.0157992880090927]","[0.7539243467063641, 0.0, 0.9540377477785056]","[0.7972447266296483, 0.0, 1.0154515837708096]",148.496826,0.000413
1,19231,6.445017,5.042689,3.612191,3.612193,3.612189,6.445017,6.445017,6.445017,"[1.7879273336872115, -0.03834291513380507, 1.9...","[1.7866781707562103, -0.038950149723245175, 1....","[1.7880342580626234, -0.038192712557793745, 1....",29.150531,0.000206
2,19231,7.668444,4.636348,3.295152,3.29512,3.29515,7.668444,7.668444,7.668444,"[1.6764694412256174, -0.14021070051599424, 1.8...","[1.675673129142826, -0.14000818531484172, 1.81...","[1.6765260332588556, -0.13994575038859974, 1.8...",23.101297,0.00026
3,19231,8.061961,4.915286,3.447397,3.447407,3.447404,8.061961,8.061961,8.061961,"[1.9446726584533585, -0.06617940801421736, 1.7...","[1.9421307308460882, -0.06673670740792703, 1.7...","[1.9447641757510326, -0.06592502499592122, 1.7...",29.150531,0.000327
4,18219,7.120315,4.425542,3.170238,3.170224,3.170246,7.120315,7.120315,7.120315,"[1.7059511770770255, -0.18655073345449025, 1.6...","[1.7041874317120946, -0.18639994458340814, 1.6...","[1.7055901215437839, -0.185767049094055, 1.629...",29.150531,0.00083
5,18219,5.729247,5.679621,4.989903,4.990301,4.989969,5.729247,5.729247,5.729247,"[1.5113918607207573, 0.7304563275450469, 1.226...","[1.505375792590998, 0.7280640698392206, 1.2230...","[1.5107158767247815, 0.7295982306440664, 1.226...",93.260335,0.001048
6,16195,3.780596,2.741828,2.262759,2.262776,2.262744,3.780596,3.780596,3.780596,"[0.8110579848754752, 0.046302124153983236, 0.8...","[0.8094874033009396, 0.04628207814422147, 0.86...","[0.8107873669674166, 0.04580581675836958, 0.86...",58.570208,0.000521
7,17207,6.3439,4.157634,3.217761,3.217833,3.217769,6.3439,6.3439,6.3439,"[1.4061398432041439, 0.17215809940339047, 1.41...","[1.4043433070106792, 0.1716533353858134, 1.415...","[1.4059055994133618, 0.17161060373282055, 1.41...",36.783798,0.000521
8,17207,7.716262,4.925156,3.612658,3.612635,3.612657,7.716262,7.716262,7.716262,"[1.8055651723460087, 0.2580540362225022, 1.867...","[1.8038811279560534, 0.2577118659794966, 1.865...","[1.8055861489305356, 0.2580448094398375, 1.867...",29.150531,1e-05
9,16195,7.956582,6.350457,4.896427,4.896437,4.89642,7.956582,7.956582,7.956582,"[2.2918393562583184, -0.0690381483418468, 2.24...","[2.2875126805766257, -0.06804580075882694, 2.2...","[2.2915448335887705, -0.0684415031174111, 2.24...",46.415888,0.000521
