# Linear regression

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

In [81]:
from itertools import product

import numpy as np
import xarray as xr
import xarray.ufuncs as xf
import matplotlib.pyplot as plt
from scipy.optimize import lsq_linear

In [5]:
data = xr.open_dataset("data_smoothed.nc")
data

In [19]:
def linear_regression(f, states, b):
    """

    :param f: Function calculating the rows of the design matrix.
    :param states: States.
    :param b: Signal being fitted.
    :return: Coefficient of the function being fitted, residuals
    """
    # Ensure these are numpy arrays
    states = np.asarray(states)
    b = np.asarray(b).flatten()

    # Check inputs
    assert states.shape[0] == len(b), "states must have the same number of rows as the length of b."

    # Loop through each state, call the function and add the result
    # as a row to the A matrix
    a_matrix = []
    for state in states:
        a_row = f(*state)
        a_matrix.append(a_row)
    a_matrix = np.array(a_matrix)

    # Run least squares
    result = lsq_linear(a_matrix, b)
        
    return result.x, result.cost, result.fun

In [116]:
def f(alpha, beta):
    order = 20
    terms = []
    for alpha_power, beta_power in product(range(order), range(order)):
        terms.append(alpha**alpha_power * beta**beta_power)
    return terms

In [117]:
states = np.vstack((
    data.alpha_estimate.values,
    data.beta_m.values,
)).T

coefficients, r2, residuals = linear_regression(f, states, data.c_m.values)
display(coefficients)
display(r2)

array([-5.96263422e-02,  2.04393054e-02,  3.40408500e-02, -3.57609523e+00,
       -1.40029269e+02,  4.00260254e+02,  1.89735515e+04, -1.78703271e+04,
       -1.26156798e+06,  3.67676438e+05,  4.74353900e+07, -3.25264458e+06,
       -1.03585674e+09,  3.15130738e+06,  1.29389954e+10,  4.98079516e+07,
       -8.54980281e+10,  1.23954602e+09,  2.31352241e+11, -1.11166153e+10,
        1.42389863e-01,  8.39774973e-02,  2.55793522e+01, -6.04344510e+01,
       -2.42132561e+03,  9.09123275e+03,  5.68269566e+04, -6.25860166e+05,
        2.55709592e+06,  2.37232866e+07, -1.63177987e+08, -5.35376127e+08,
        3.41217618e+09,  7.22318483e+09, -3.22118738e+10, -5.35076813e+10,
        1.18010559e+11,  1.60986368e+11, -2.00418229e+10,  6.19649718e+10,
       -6.92787170e-01, -3.86763763e+00, -9.10132942e+01,  1.98394775e+00,
        2.38139281e+03, -1.02413618e+04,  5.52429911e+05,  7.67977227e+05,
       -3.84917217e+07, -1.76729312e+07,  1.06159455e+09,  1.12130402e+08,
       -1.47508011e+10,  

0.0164245017591691

In [118]:
c_m_f = []
for state in states:
    c_m_f.append(coefficients @ f(*state))

In [121]:
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(data.alpha_m, data.beta_m, data.c_m, s=0.1)
ax.scatter(data.alpha_m, data.beta_m, c_m_f, s=0.1)
plt.tight_layout()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [122]:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(data.alpha_m, data.beta_m, abs(data.c_m - c_m_f) / c_m_f * 100., s=0.1)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x7eff94ffea00>