In [1]:
import sys; sys.path.append("..")
from models.utils import *

import statsmodels.api as sm
from statsmodels.regression.linear_model import OLS

In [2]:
data_home = "/mnt/g/My Drive/GTC/solodoch_data_minimal"
lats = ["26N", "30S", "55S", "60S"]

In [3]:
lat = lats[0]
data = xr.open_dataset(f"{data_home}/{lat}.nc")

In [4]:
# apply whatever preprocessing we want *before* calling reshape_inputs
pp_data = apply_preprocessing(data,
                              mode="inputs",
                              remove_season=True,
                              remove_trend=True,
                              standardize=True,
                              lowpass=True)

# reshape as desired and convert to a numpy array
pp_data_np = reshape_inputs(pp_data, keep_coords=["time"])

# dummy strength data
strength = np.random.rand(*(pp_data_np.shape[:-1]))
X = pp_data_np; y = strength
# shuffle data
p = np.random.permutation(len(y))
X, y = X[p], y[p]

class SMWrapper(BaseEstimator, RegressorMixin):
    def __init__(self, alpha=0.1, L1_wt=0.1):
        self.alpha = alpha
        self.L1_wt = L1_wt
        self.model = None

    def fit(self, X, y):
        self.model = sm.OLS(y, X).fit_regularized(alpha=self.alpha, L1_wt=self.L1_wt)
        return self

    def predict(self, X):
        return self.model.predict(X)
    
# add bias term
X = sm.add_constant(X)
# train/val/test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=123456)
# hyperparamters to optimise
param_grid = {
    "alpha": np.logspace(-4, 4, 10),
    "L1_wt": np.linspace(0, 1, 10)
}
# grid search - 10-fold cross-validation
grid_search = GridSearchCV(SMWrapper(), param_grid, cv=10, scoring="neg_mean_squared_error")
grid_search.fit(X_train, y_train)
# report best result
print(f"Best MSE (cross-validation): {round(-grid_search.best_score_, 3)}")
# report test performance
y_pred = grid_search.best_estimator_.predict(X_test)
r2 = round(r2_score(y_test, y_pred), 3)
mse = round(mean_squared_error(y_test, y_pred), 3)
print(f"Test R^2: {r2}")
print(f"Test MSE: {mse}")

axes: ['time', 'feature']
variables: ['SSH', 'SST', 'SSS', 'OBP', 'ZWS']
shape: (288, 5)
Best MSE (cross-validation): 0.081
Test R^2: -0.026
Test MSE: 0.084
