In [None]:
# <api>
#The interactive plots are created by using the libraries bqplot and ipywidgets
import matplotlib.pyplot as plt
import numpy as np
import bqplot
from ipywidgets import widgets, Layout
from IPython.display import display, clear_output

In [None]:
# <api>
class InterceptFeature:
    """
    Feature map that prepends a constant bias column of ones to the input.
    """
    def transform(self, x):
        """
        Return ``x`` with a leading column of ones.

        ``x`` is expected to be a 2-D array with one row per sample; the result
        has one extra column at position 0.
        """
        bias_column = np.ones((x.shape[0], 1))
        return np.concatenate((bias_column, x), axis=1)

class PolynomialFeatures:
    """
    Feature map producing the powers x**0, x**1, ..., x**degree of the input.
    """
    def __init__(self, degree):
        """
        Args:
            degree [int]: highest power to generate.
        """
        self.__degree = degree

    def transform(self, x):
        """
        Stack the powers x**p for p = 0, ..., degree side by side.

        For a single input column of shape (N, 1) the result has shape
        (N, degree + 1), with the constant column first.
        """
        columns = [np.power(x, power) for power in range(self.__degree + 1)]
        return np.hstack(columns)
    
class GaussianBasisFunctions:
    """
    Gaussian (RBF) feature map.

    For a single input column x and centres mu_1, ..., mu_M the features are

        phi_ij = exp(-(x_i - mu_j) ** 2 / (2 * sigma ** 2)),

    i.e. one output column per centre.
    """
    def __init__(self, mus, sigma = 1.0):
        """
        Args:
            mus: centres of the basis functions, shape (M,) or (M, 1).
            sigma [float]: common width (standard deviation) of all Gaussians.
        """
        self.mus = mus
        self.sigma = sigma

    def transform(self, X):
        """
        Map inputs of shape (N, 1) (or 1-D, shape (N,)) to features of shape (N, M).

        Raises:
            ValueError: if X has more than one column while there is more than
                one centre.
        """
        X = np.asarray(X, dtype=float)
        if X.ndim == 1:
            X = X[:, None]
        # Lay the centres out as a row so (N, 1) - (1, M) broadcasts to (N, M);
        # this works whether mus was given as (M,) or (M, 1).
        centers = np.asarray(self.mus, dtype=float).reshape(1, -1)
        if X.shape[1] != 1 and centers.shape[1] != 1:
            raise ValueError('GaussianBasisFunctions expects a single input column, '
                             'got input of shape {}'.format(X.shape))
        return np.exp(-0.5 / self.sigma ** 2 * (X - centers) ** 2)
    
class Pipeline:
    """
    Sequence of preprocessing steps followed by a final model.

    Every step except the last is used only through ``transform(X)``. The last
    step must provide ``fit(X, y)`` and ``predict(X)``, and must set ``weights_``
    when fitted; after ``fit`` that value is mirrored as ``self.weights``.
    """
    def __init__(self, steps):
        self.steps = steps

    def _preprocess(self, X):
        """Feed X through every step but the last, in order, and return the result."""
        features = X
        for transformer in self.steps[:-1]:
            features = transformer.transform(features)
        return features

    def fit(self, X, y):
        """Fit the final model on the preprocessed inputs and return self."""
        features = self._preprocess(X)
        model = self.steps[-1]
        model.fit(features, y)
        self.weights = model.weights_
        return self

    def predict(self, X):
        """Return the final model's prediction for the preprocessed inputs."""
        features = self._preprocess(X)
        return self.steps[-1].predict(features)

In [None]:
# <api>
class LinearRegressionL2:
    """
    Linear regression with L2 (ridge) regularization.

    The model t ~ X w is fitted by minimising

        ||X w - t||_2^2 + lambda * ||w||_2^2,

    whose closed-form solution solves (X^T X + lambda I) w = X^T t.
    """

    def __init__(self, lam):
        """
        Create an unfitted model.

        Args:
            lam [float]: regularization strength lambda; 0 gives ordinary
                least squares.
        """
        self.weights_ = None
        self.lambda_ = lam

    def fit(self, X, t):
        """
        Fit the linear model on training data D = (X, t).

        Args:
            X: 2-D design matrix, one row per sample.
            t: targets, one per row of X.

        Returns:
            self, so calls can be chained.
        """
        _, num_features = X.shape
        gram = np.dot(X.T, X) + self.lambda_ * np.identity(num_features)
        rhs = np.dot(X.T, t)
        try:
            self.weights_ = np.linalg.solve(gram, rhs)
        except np.linalg.LinAlgError:
            # Exactly singular normal equations (e.g. lambda == 0 with more
            # features than samples): use the minimum-norm least-squares
            # solution instead of failing.
            self.weights_ = np.linalg.lstsq(gram, rhs, rcond=None)[0]
        return self

    def predict(self, X):
        """
        Predict the model response X w for inputs X (2-D, one row per sample).
        """
        return np.dot(X, self.weights_)

In [None]:
# <api>
class LinRegWidget:
    """
    Interactive widget for demonstration of linear regression.

    Sampled data points are shown in a bqplot figure next to controls for the
    data set, the model (basis functions) and optional L2 regularization. The
    fitted prediction curve and the learned weights are refreshed whenever a
    control changes or a point is moved or added (click the canvas to add one).
    """

    # Data sets offered in the dropdown (generated by ``redraw``).
    _DATA_SETS = ('Linear', 'Sinusoidal', 'Quadratic')
    # Model names; the position of a name is the index of its page in ``self.tab``.
    _MODEL_NAMES = ('Linear', 'Polynom', 'RBF')
    # Number of Gaussian basis functions used by the RBF model.
    _NUM_RBF_CENTERS = 12

    def __init__(self, N=10, data='Linear', model='Linear', L2=True):
        """
        Initialize the widget.

        Args:
            N [int]: number of initially sampled points.
            data [str]: type of data set, one of 'Linear', 'Sinusoidal', 'Quadratic'.
            model [str]: type of model (basis functions), one of
                'Linear', 'Polynom', 'RBF'.
            L2 [bool]: if True the widget starts with L2 regularization enabled,
                otherwise unregularized.

        Raises:
            ValueError: if ``data`` or ``model`` is not one of the names above.
        """
        # Validate the arguments before building any widgets.
        if data not in self._DATA_SETS:
            raise ValueError('Unknown data set. Choose between Linear/Sinusoidal/Quadratic.')
        if model not in self._MODEL_NAMES:
            raise ValueError('Unknown model. Choose between Linear/Polynom/RBF.')

        # Create subwidgets
        self.sigma = widgets.FloatSlider(1, min=0.1, max=2, step=0.02, description="RBF width")
        self.deg = widgets.IntSlider(2, min=0, max=N-1, step=1, description="Poly degree")
        self.log_lam = widgets.FloatSlider(0, min=-10, max=3, step=0.01,
                                           description="Log lambda", disabled=(not L2))
        self.data = widgets.Dropdown(options=list(self._DATA_SETS), description='Data set')
        self.regularized = widgets.Checkbox(value=L2, description='L2 regularization', disabled=False)
        self.htitle = 'Linear Regression'
        self.button_redraw = widgets.Button(description="Redraw")
        self.message = '<h4 style="color:Tomato;">Polynom degree is too high, expect difficulties by plotting!</h4>'
        self.label = widgets.HTML(value='')
        # The same log_lam slider widget appears on every tab page.
        self.tab = widgets.Tab(children=[widgets.VBox([self.log_lam]),
                                         widgets.VBox([self.deg, self.log_lam]),
                                         widgets.VBox([self.sigma, self.log_lam])])
        # set_title works on ipywidgets 7 and 8 alike, unlike the private
        # ``_titles`` constructor argument.
        for index, title in enumerate(self._MODEL_NAMES):
            self.tab.set_title(index, title)
        self.weights = widgets.HTML(value='', description='Weights:')

        # Remember the lambda slider position so update_tab can restore it when
        # regularization is switched back on. Without L2 the slider is parked at
        # its minimum; get_model uses lambda = 0 in that case anyway.
        self.last_lam = float(self.log_lam.value)
        if not L2:
            self.log_lam.value = self.log_lam.min

        # Choose the dataset and the model (basis functions)
        self.data.value = data
        self.tab.selected_index = self._MODEL_NAMES.index(model)

        # Define internal constants
        self.N = N        # number of initially sampled points
        self.size = 10    # half-width of the plotted x and y ranges

        # Draw samples from the dataset
        self.redraw()

        # Create the persistent canvas elements (scales, prediction line, axes);
        # the scatter mark itself is (re)created by new_canvas.
        self.sc_x = bqplot.LinearScale()
        self.sc_y = bqplot.LinearScale(min=-self.size, max=self.size)
        self.line = bqplot.Lines(x=[], y=[], scales={'x': self.sc_x, 'y': self.sc_y},
                                 display_legend=False, labels=['prediction'])
        self.ax_x = bqplot.Axis(scale=self.sc_x, label='x data')
        self.ax_y = bqplot.Axis(scale=self.sc_y, orientation='vertical', label='y data')
        self.new_canvas()

        # Set callback functions
        self.button_redraw.on_click(self.redraw_function)
        self.regularized.observe(self.update_tab, names='value')
        # Refit explicitly when L2 is toggled: update_tab only causes a refit
        # indirectly, and not at all if the slider value does not change.
        self.regularized.observe(self.update_line, names='value')
        self.deg.observe(self.update_line, names='value')
        self.sigma.observe(self.update_line, names='value')
        self.data.observe(self.update_all, names='value')
        self.tab.observe(self.update_line, names='selected_index')
        self.log_lam.observe(self.update_line, names='value')

    def show(self):
        """Show the widget."""
        display(self.ui, display_id='ui')

    def get_model(self):
        """Build an unfitted Pipeline for the currently selected model tab."""
        lam = np.exp(self.log_lam.value) if self.regularized.value else 0
        if self.tab.selected_index == 0:
            return Pipeline([InterceptFeature(), LinearRegressionL2(lam)])
        if self.tab.selected_index == 1:
            return Pipeline([PolynomialFeatures(self.deg.value), LinearRegressionL2(lam)])
        # Spread the Gaussian centres over the plotted x range so the sampled
        # points (drawn within it by redraw) are covered by the basis functions.
        centers = np.linspace(-self.size, self.size, self._NUM_RBF_CENTERS)[:, None]
        return Pipeline([GaussianBasisFunctions(centers, self.sigma.value),
                         InterceptFeature(),
                         LinearRegressionL2(lam)])

    def new_canvas(self):
        """(Re)create the scatter mark, the figure and the complete widget layout."""
        self.scat = bqplot.Scatter(x=self.xdata[self.data.value],
                                   y=self.ydata[self.data.value],
                                   scales={'x': self.sc_x, 'y': self.sc_y},
                                   colors=['violet'], marker='diamond',
                                   enable_move=True)

        self.update_line()
        self.canvas = bqplot.Figure(marks=[self.scat, self.line],
                                    axes=[self.ax_x, self.ax_y], title=self.htitle)
        controls = widgets.VBox([self.data, self.regularized, self.tab,
                                 self.button_redraw, self.weights])
        self.ui = widgets.VBox([widgets.HBox([self.canvas, controls]), self.label])
        # Enable moving existing points and adding new ones by clicking; refit
        # whenever the point coordinates change.
        with self.scat.hold_sync():
            self.scat.enable_move = True
            self.scat.interactions = {'click': 'add'}
            self.scat.observe(self.update_line, names=['x', 'y'])

    def update_all(self, change=None):
        """A callback function for choosing a new dataset."""
        clear_output(wait=True)
        self.new_canvas()
        self.show()

    def update_line(self, change=None):
        """A callback function for changing a model (basis functions) or their parameters."""
        # Warn when an unregularized polynomial can interpolate every current
        # point (points may have been added by clicking, so use the live count).
        num_points = len(self.scat.x)
        degree_too_high = (not self.regularized.value
                           and self.tab.selected_index == 1
                           and self.deg.value >= num_points - 1)
        self.label.value = self.message if degree_too_high else ''
        grid = np.linspace(-self.size, self.size, 1001)
        with self.line.hold_sync():
            self.line.x = grid
            self.model = self.get_model()
            self.model.fit(self.scat.x[:, None], self.scat.y)
            self.line.y = self.model.predict(grid[:, None])
            self.weights.value = self.weights2html(self.model.weights)

    def update_tab(self, change=None):
        """A callback function for choosing a regularization."""
        if self.regularized.value:
            self.log_lam.disabled = False
            self.log_lam.value = self.last_lam
        else:
            self.last_lam = float(self.log_lam.value)
            self.log_lam.disabled = True
            # Park the slider at its minimum instead of -inf: non-finite floats
            # are not valid JSON and cannot be synced to the front end. The
            # value is ignored while regularization is off (lambda = 0).
            self.log_lam.value = self.log_lam.min

    def redraw_function(self, button):
        """Action when the Redraw button is clicked: resample and rebuild the canvas."""
        self.redraw()
        self.update_all()

    def redraw(self):
        """Resample all data sets (same x values, data-set specific y values)."""
        x = np.random.uniform(size=self.N) * 1.2 * self.size - self.size * 0.66
        noise_level = np.random.uniform(low=0.1, high=0.3)
        self.xdata = {'Linear': x,
                      'Sinusoidal': x,
                      'Quadratic': x}
        # NOTE: the random draws below happen in a fixed order (dict literal
        # evaluation order), which keeps seeded runs reproducible.
        self.ydata = {'Linear': (np.random.uniform(low=-2, high=2.) * x
                                 - np.random.uniform(low=0.1, high=0.3)
                                 + noise_level * np.random.normal(size=self.N)),
                      'Sinusoidal': (np.sin(np.random.uniform(low=0.2, high=0.7) * np.pi * x
                                            + np.random.uniform(low=-1, high=1))
                                     + noise_level * np.random.normal(size=self.N)),
                      'Quadratic': (x * x * np.random.uniform(low=0.3, high=1.5)
                                    - self.size * np.random.uniform(low=0., high=1.)
                                    + noise_level * np.random.normal(size=self.N))}

    def weights2html(self, weights):
        """Return an HTML table listing each weight's index and value (4 decimals)."""
        rows = ''.join('<tr><th> {}: </th><th> {:0.4f}</th></tr>'.format(index, value)
                       for index, value in enumerate(weights))
        return '<table>' + rows + '</table>'

In [None]:
# Launch the demo: 10 sinusoidal samples fitted by an unregularized polynomial.
lr = LinRegWidget(
    N=10,
    data='Sinusoidal',
    model='Polynom',
    L2=False,
)
lr.show()