# Polynomial Linear Regression - Interactive Lab

Following the linear regression example described in [linear_regression.ipynb](linear_regression.ipynb), let's relax and play around with the same problem here!

<!--<badge>--><a href="https://colab.research.google.com/github/inlab-geo/cofi-examples/blob/main/notebooks/linear_regression/linear_regression_lab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a><!--</badge>-->

---
### Import modules and get prepared

Still, some steps are necessary in preparation for the coming interactive lab.

In [1]:
# -------------------------------------------------------- #
#                                                          #
#     Uncomment below to set up environment on "colab"     #
#                                                          #
# -------------------------------------------------------- #

# !pip install -U cofi

In [2]:
# --------------------------------------------------------------- #
#                                                                 #
#     Problem definition, copied from linear_regression.ipynb     #
#                                                                 #
# --------------------------------------------------------------- #

######## Import required modules
import json
from contextlib import contextmanager
import numpy as np
import matplotlib.pyplot as plt
from cofi import BaseProblem, InversionOptions, Inversion
from cofi.solvers import solvers_table

######## Fix the random seed so that every run draws the same data
np.random.seed(42)

######## Define the polynomial linear regression problem
def _basis_func(x):
    """Return the matrix whose columns are x**0, x**1, x**2 and x**3."""
    powers = [x**degree for degree in range(4)]
    return np.array(powers).T

_m_true = np.array([-6, -5, 2, 1])      # m: coefficients used to generate the data
_sample_size = 20                       # N: number of data points

# x: sample locations picked at random from an evenly spaced grid on [-3.5, 2.5]
x = np.random.choice(np.linspace(-3.5, 2.5), size=_sample_size)

def forward_func(m):
    """Map model coefficients m to synthetic data at the sampled x (m -> y_synthetic)."""
    return _basis_func(x) @ m

# d: prediction of the true model plus Gaussian noise with unit standard deviation
y_observed = forward_func(_m_true) + np.random.normal(0, 1, _sample_size)

sigma = 1.0                                   # noise standard deviation shared by all data
Cdinv = np.eye(len(y_observed)) / sigma**2    # inverse data covariance matrix

def log_likelihood(model):
    """Return -0.5 * r @ Cdinv @ r, where r is the data residual of `model`."""
    misfit = y_observed - forward_func(model)
    weighted_misfit = Cdinv @ misfit
    return -0.5 * misfit @ weighted_misfit.T

m_lower_bound = np.full(4, -10.0)    # lower bound for uniform prior
m_upper_bound = np.full(4, 10.0)     # upper bound for uniform prior

def log_prior(model):
    """Log of a uniform prior: 0.0 inside the bounds (inclusive), -inf outside."""
    for index, (lower, upper) in enumerate(zip(m_lower_bound, m_upper_bound)):
        value = model[index]
        if value < lower or value > upper:
            return -np.inf
    # every component lies within its bounds -> log(1)
    return 0.0

# Number of model parameters; used by walkers_start() as the width of the
# walker starting-position array.
ndim = 4

# Describe the inverse problem to CoFI using the pieces defined earlier in this cell.
inv_problem = BaseProblem()
inv_problem.name = "Polynomial Regression"
inv_problem.set_data(y_observed)
inv_problem.set_forward(forward_func)
inv_problem.set_data_misfit("L2")
# forward_func is a matrix product with the basis matrix, so that matrix is
# handed over as the Jacobian.
inv_problem.set_jacobian(_basis_func(x))
inv_problem.set_initial_model(np.ones(4))
# Log-probability functions; presumably only consulted by sampling tools (confirm in CoFI docs).
inv_problem.set_log_prior(log_prior)
inv_problem.set_log_likelihood(log_likelihood)

######## Parameters that are to be changed by user settings
def walkers_start(nwalkers):
    """Return an (nwalkers, ndim) array of starting positions scattered tightly around zero."""
    origin = np.zeros(4)
    jitter = 1e-4 * np.random.randn(nwalkers, ndim)
    return origin + jitter

######## Review the basic/fixed problem setup
# inv_problem.summary()

In [3]:
# ---------------------------------------------------------------- #
#                                                                  #
#     Auxiliary code for displaying widgets, no need to modify     #
#                                                                  #
# ---------------------------------------------------------------- #

def adjust_problem(regularisation, regularisation_factor):
    """Apply the chosen regularisation to the shared `inv_problem` and return it.

    The module-level problem object is modified in place; the returned object
    is that same instance.
    """
    problem = inv_problem
    problem.set_regularisation(regularisation, regularisation_factor)
    return problem
    
def adjust_options(solving_method, tool, solver_params):
    """Build an `InversionOptions` from the user's choices.

    `solver_params` is forwarded to the tool as keyword parameters. For the
    "sampling" method, the shared `inv_problem` is additionally given walker
    starting positions, which requires `solver_params["nwalkers"]`.
    """
    options = InversionOptions()
    options.set_solving_method(solving_method)
    options.set_tool(tool)
    options.set_params(**solver_params)
    if solving_method != "sampling":
        return options
    # samplers need one starting position per walker
    start_positions = walkers_start(solver_params["nwalkers"])
    inv_problem.set_walkers_starting_pos(start_positions)
    return options

def plot_from_result(inv_result, method):
    """Plot an inversion result against the true model and the observed data.

    For `method == "sampling"`, curves for a random selection of posterior
    samples are drawn (read from `inv_result.sampler`); for any other method
    the single curve given by `inv_result.model` is drawn.
    """
    is_sampling = method == "sampling"

    # Gather everything that is needed before a figure is opened.
    if is_sampling:
        # discard/thin are passed straight to the sampler's get_chain
        # (presumably an emcee sampler - confirm)
        chain = inv_result.sampler.get_chain(discard=300, thin=30, flat=True)
        # get a random selection from posterior ensemble
        chosen = np.random.randint(len(chain), size=100)
    x_grid = np.linspace(-3.5, 2.5)
    design = _basis_func(x_grid)
    y_true = design @ _m_true
    if not is_sampling:
        y_fit = design @ inv_result.model

    plt.figure(figsize=(12, 8))
    if is_sampling:
        # One labelled curve so that the legend gets a single "Posterior samples" entry
        plt.plot(x_grid, design @ chain[0], color="seagreen", label="Posterior samples", alpha=0.1)
        for draw in chain[chosen]:
            plt.plot(x_grid, design @ draw, color="seagreen", alpha=0.1)
        plt.plot(x_grid, y_true, color="darkorange", label="true model")
    else:
        plt.plot(x_grid, y_true, color="darkorange", label="true model")
        plt.plot(x_grid, y_fit, color="seagreen", label=f"{method} solution")

    # Decorations shared by both kinds of plot
    plt.scatter(x, y_observed, color="lightcoral", label="observed data")
    plt.xlabel("X")
    plt.ylabel("Y")
    plt.legend()

def inversion(reg, reg_factor, method, tool, solver_params):
    """Configure the shared problem and the options, run the inversion, return its result.

    Nothing is printed or plotted here; the caller decides what to do with
    the returned result (e.g. `result.summary()` or `plot_from_result`).
    """
    problem = adjust_problem(reg, reg_factor)
    options = adjust_options(method, tool, solver_params)
    return Inversion(problem, options).run()
    
# inversion(2, 0.05, "optimisation", "scipy.optimize.minimize", {})

---
### Start the lab

In [4]:
# -------------------------------------------------------------------------------------- #
#                                                                                        #
# Auxiliary code for displaying widgets, no need to modify                               #
#                                                                                        #
# Run this cell and start interacting :)                                                 #
#                                                                                        #
# If you have trouble displaying the interactive widgets locally, check advice here:     #
# https://stackoverflow.com/questions/36351109/ipython-notebook-ipywidgets-does-not-show #
#                                                                                        #
# -------------------------------------------------------------------------------------- #

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

# Status line: shows 'Ready', 'Running...' or an error message.
msg_label = widgets.Label('Ready')

# Solving method selector; its choices are the top-level keys of solvers_table.
method_widget = widgets.ToggleButtons(options=solvers_table.keys(), description="method")
# Tool selector; starts with the tools of "optimisation" and is refreshed by method_updated().
tool_widget = widgets.RadioButtons(options=solvers_table["optimisation"].keys(), description="tool")
# Value passed as `regularisation` to inv_problem.set_regularisation().
reg_widget = widgets.FloatSlider(value=2,min=-2,max=10,step=1,description="reg")
# Value passed as `regularisation_factor`; NOTE(review): per ipywidgets docs min/max
# of a FloatLogSlider are exponents of `base`, i.e. the range is 1e-15 .. 1e1 - confirm.
reg_factor_widget = widgets.FloatLogSlider(base=10,value=0.0001,min=-15,max=1,step=0.2, description="reg_factor")
# Container for the tool-specific text boxes, filled by update_params_widgets().
solver_specific_params_widget = widgets.VBox(children=[])
run_widget = widgets.Button(description="Run Inversion")

def update_params_widgets(tool_solver):
    """Rebuild the parameter text boxes for the given solver class.

    Required options come first, as empty boxes with the placeholder
    "required"; optional ones follow, pre-filled with the string form of
    their default value.
    """
    required_boxes = [
        widgets.Text(value="", placeholder="required", description=name)
        for name in tool_solver.required_in_options
    ]
    optional_boxes = [
        widgets.Text(value=str(default), placeholder="optional", description=name)
        for name, default in tool_solver.optional_in_options.items()
    ]
    solver_specific_params_widget.children = required_boxes + optional_boxes

def method_updated(*args):
    """Offer the tools that belong to the currently selected solving method."""
    tools_for_method = solvers_table[method_widget.value]
    tool_widget.options = tools_for_method.keys()
method_widget.observe(method_updated, 'value')

def tool_updated(*args):
    """React to a new tool selection.

    The regularisation sliders are shown only when the selected solver lists
    "objective" among its problem requirements, and the parameter boxes are
    rebuilt for that solver.
    """
    selected_solver = solvers_table[method_widget.value][tool_widget.value]
    needs_objective = "objective" in selected_solver.required_in_problem
    visibility = "visible" if needs_objective else "hidden"
    reg_factor_widget.layout.visibility = visibility
    reg_widget.layout.visibility = visibility
    update_params_widgets(selected_solver)
tool_widget.observe(tool_updated, 'value')
# Fill in the parameter boxes for the initially selected tool
update_params_widgets(solvers_table["optimisation"]["scipy.optimize.minimize"])

def parse_param_value(text):
    """Convert the raw text of a parameter box into a Python value.

    Recognised forms, checked in this order:
      * "None"                      -> None
      * "True" / "true"             -> True
      * "False" / "false"           -> False
      * text containing "{" and "}" -> parsed as JSON (may raise ValueError)
      * integer literal             -> int
      * float literal               -> float
    Anything else is returned unchanged as a string.
    """
    if text == "None":
        return None
    if text in ("True", "true"):
        return True
    if text in ("False", "false"):
        return False
    if "{" in text and "}" in text:
        return json.loads(text)
    # Try int first so that "100" stays an integer, then fall back to float
    # so that inputs such as "0.5" or "1e-3" are not left as strings.
    for convert in (int, float):
        try:
            return convert(text)
        except ValueError:
            continue
    return text

def button_on_click(*args):
    """Collect the widget settings, run the inversion and store the outcome.

    `args[0]` is the object passed by the click callback (the button); on
    success the result and the settings used are attached to it as `.result`
    and `.settings`. Any failure is shown in `msg_label` and then re-raised.
    """
    method = method_widget.value
    tool = tool_widget.value
    reg = reg_widget.value
    reg_factor = reg_factor_widget.value
    solver_params = {}
    try:
        for widget in solver_specific_params_widget.children:
            raw = widget.value
            if widget.placeholder == "required" and not raw:
                raise ValueError("please fill in required parameters")
            solver_params[widget.description] = parse_param_value(raw)
    except Exception as e:
        msg_label.value = f"Something is wrong:\n{e.__class__.__name__} - {e}"
        raise
    with show_loading():
        try:
            res = inversion(reg, reg_factor, method, tool, solver_params)
            args[0].result = res
            args[0].settings = {"reg with order": reg, "reg_factor": reg_factor,
                                "method": method, "tool": tool, "solver_params": solver_params}
        except Exception as e:
            msg_label.value = f"Something is wrong:\n{e.__class__.__name__} - {e}"
            raise
# Call button_on_click whenever the "Run Inversion" button is clicked
run_widget.on_click(button_on_click)

@contextmanager
def show_loading():
    """Show 'Running...' while the wrapped block executes.

    When the block finishes without raising and the run button carries a
    `result`, the label is switched back to 'Ready (...)' with the result's
    `success_or_not` value. If the block raises, the label is left untouched
    so that a message set by the caller stays visible.
    """
    msg_label.value = 'Running...'
    yield
    if not hasattr(run_widget, "result"):
        return
    outcome = run_widget.result.success_or_not
    msg_label.value = f'Ready ({outcome})'

# Stack all controls vertically, in the order they are meant to be used,
# with the status label at the bottom, and render them in the notebook.
w = widgets.VBox(children=[method_widget, tool_widget, reg_widget, reg_factor_widget,
                          solver_specific_params_widget, run_widget, msg_label])
display(w)

VBox(children=(ToggleButtons(description='method', options=('optimisation', 'linear least square', 'sampling')…

In [5]:
#####################################################################################
#                                                                                   #
#    Uncomment below after you've selected parameters and run an inversion above    #
#                                                                                   #
#####################################################################################

# inv_result = run_widget.result
# plot_from_result(inv_result, run_widget.settings["method"])
# inv_result.summary()

---

In [6]:
import cofi
import emcee
import scipy

%reload_ext watermark
%watermark -n -u -v -iv -w

Last updated: Fri Jun 10 2022

Python implementation: CPython
Python version       : 3.10.4
IPython version      : 8.3.0

json      : 2.0.9
scipy     : 1.8.1
numpy     : 1.22.4
matplotlib: 3.5.2
cofi      : 0.1.2.dev5
emcee     : 3.1.2
ipywidgets: 7.6.5

Watermark: 2.3.1

