In [None]:
%%capture
! pip install numpy
! pip install matplotlib
! pip install ipympl
! pip install sympy

In [None]:
%matplotlib widget

import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
import time
import sympy
import random

In [None]:
# Symbolic model: written once in sympy, then compiled to numpy functions with lambdify.
# Every compiled function takes its arguments in the order (a1, w1, w2, b, x, y).
x, y = sympy.symbols('x y')
parameters = sympy.symbols('a1 w1 w2 b')
a1, w1, w2, b = parameters

# Decision function z(x, y). Its zero set a1*((x-w1)**2 + (y-w2)**2) + b = 0 is a circle
# centred at (w1, w2) whenever -b/a1 > 0.
# A linear boundary can be used instead: decision_function_sympy = w1*x+w2*y+b
decision_function_sympy = a1*(x-w1)**2+a1*(y-w2)**2+b

decision_function = sympy.lambdify([*parameters, x, y], decision_function_sympy)
# One compiled partial derivative dz/dp per parameter p, in the order of `parameters`.
derivatives = [
    sympy.lambdify([*parameters, x, y], sympy.diff(decision_function_sympy, symbol))
    for symbol in parameters
]

def decision_boundary(args:list[float],x_vec:np.array,y_vec:np.array)->np.array:
    '''
    Evaluates the decision function z(x, y) for the given parameter values.

    args: values for the sympy `parameters` (a1, w1, w2, b), in that order. Any sequence
        (list, tuple, numpy array) is accepted.
    x_vec, y_vec: point coordinates (scalars or numpy arrays).
    Returns z. The decision boundary is the set of points where z == 0.
    '''
    assert len(args) == len(parameters), f"expected {len(parameters)} parameter values, got {len(args)}"
    # Unpack instead of `args + [x_vec, y_vec]`: list concatenation raises for tuples and
    # silently performs elementwise addition when `args` is a numpy array.
    return decision_function(*args, x_vec, y_vec)

def logistic_fun(args:list[float],x_vec:np.array,y_vec:np.array)->np.array:
    '''
    Logistic (sigmoid) function of the decision function, 1/(1+exp(-z)).

    This is the model's predicted probability that a point belongs to category 1, and the
    value used in the category-1 term of the cost.
    '''
    z = decision_boundary(args,x_vec,y_vec)
    # exp(-log(1+exp(-z))) == 1/(1+exp(-z)). np.logaddexp never overflows, so a large
    # negative z no longer emits "overflow encountered in exp" warnings.
    return np.exp(-np.logaddexp(0, -z))

def calc_cost_function(args:list[float], x_vec:np.array, y_vec:np.array, category_vec:np.array)->float:
    '''
    Calculates the logistic regression cost function (average of -y*log(logistic function)-(1-y)*log(1-logistic function))

    Computed in log space. With z the decision function:
        -log(logistic(z))   = log(1+exp(-z)) = logaddexp(0, -z)
        -log(1-logistic(z)) = log(1+exp(z))  = logaddexp(0, z)
    This stays finite when the logistic function saturates to exactly 0 or 1. The direct
    form gives log(0) = -inf there, and then 0*inf = nan.

    Returns the plain average over all points, with no extra 1/2 factor. Its derivative
    is mean((logistic(z) - y) * dz/dp), which is what get_gradients computes.
    '''
    z = decision_boundary(args, x_vec, y_vec)
    casewise_cat_1 = category_vec*np.logaddexp(0, -z)
    casewise_cat_0 = (1-category_vec)*np.logaddexp(0, z)
    return np.average(casewise_cat_0+casewise_cat_1)

def get_gradient(args:list[float], x_vec:np.array, y_vec:np.array, category_vec:np.array)->np.array:
    '''
    Per-point residuals: logistic(z) - y.

    Despite the name, this is not the gradient itself. get_gradients multiplies these
    residuals by each parameter's dz/dp and averages the result to get the gradient.
    '''
    predicted_probabilities = logistic_fun(args, x_vec, y_vec)
    residuals = predicted_probabilities - category_vec
    return residuals

def get_gradients(args:list[float], x_vec:np.array, y_vec:np.array, category_vec:np.array)->list[float]:
    '''
    Gradient of the average logistic regression cost with respect to each parameter.

    For cost = mean(-y*log(s) - (1-y)*log(1-s)) with s = logistic(z), the derivative with
    respect to a parameter p is mean((s - y) * dz/dp). The dz/dp functions come from the
    sympy-derived `derivatives`, which are in the same order as `parameters`.

    Returns a list with one value per parameter, in that order.
    '''
    assert len(args) == len(parameters), f"expected {len(parameters)} parameter values, got {len(args)}"
    residuals = get_gradient(args, x_vec, y_vec, category_vec)
    # Unpack instead of `args + [x_vec, y_vec]` so this call does not require `args` to be a list.
    return [np.average(residuals * derivative(*args, x_vec, y_vec)) for derivative in derivatives]



In [None]:
# Create the figure with interactive mode off so it is not displayed automatically here;
# it is shown explicitly with display_immediately further down.
with plt.ioff():
    fig, ax = plt.subplots()
# Points are clicked inside the unit square, which is also where the boundary is drawn.
ax.set(xlim=(0, 1), ylim=(0, 1), xlabel="x", ylabel="y",
       title="Click to add points, then run logistic regression")

# Workaround, found on GitHub, to make the plot update immediately
# (https://github.com/matplotlib/ipympl/issues/290#issuecomment-755377055).
def display_immediately(fig):
    '''
    Displays the canvas of `fig`, then feeds it by hand the sequence of frontend messages
    'send_image_mode', 'refresh', 'initialized' and 'draw'.

    The messages go through the canvas's underscore-prefixed (private) `_handle_message`
    method, exactly as in the linked workaround. According to that workaround, this makes
    later redraws show up immediately. Presumably this matters for the `fig.canvas.draw()`
    calls made while a widget callback is still running. TODO confirm.

    NOTE(review): `_handle_message` is not public ipympl API and may change between ipympl
    versions. The message order is copied verbatim from the workaround.
    '''
    canvas = fig.canvas
    display(canvas)
    canvas._handle_message(canvas, {'type': 'send_image_mode'}, [])
    canvas._handle_message(canvas, {'type':'refresh'}, [])
    canvas._handle_message(canvas,{'type': 'initialized'},[])
    canvas._handle_message(canvas,{'type': 'draw'},[])

display_immediately(fig)


# Controls.
# - category_toggle: picks the category of the next clicked point
#   (unticked = category 1 / green, ticked = category 0 / red).
# - animate_checkbox: redraw the boundary while the fit runs.
# - lin_reg_button: starts the fit.
# The model is logistic regression, so the labels say so.
category_toggle = widgets.Checkbox(description="Toggle for category")
animate_checkbox = widgets.Checkbox(description="Animate logistic regression",value=True)
lin_reg_button = widgets.Button(description="Run logistic regression")

# Clicked points as (x, y, category) tuples; category is 1 or 0.
points:list[tuple] = []

def onclick(event):
    '''
    Canvas click handler: records the clicked point and marks it on the axes.

    The category comes from `category_toggle`: unticked gives category 1 (green marker),
    ticked gives category 0 (red marker). Clicks outside the axes are ignored.
    '''
    if event.inaxes is not ax:
        # Outside the axes, xdata/ydata are None. A None coordinate in `points` would make
        # the arrays built from it in log_reg non-numeric and break the fit.
        return
    is_category_1:bool = not category_toggle.value
    points.append((event.xdata,event.ydata,1 if is_category_1 else 0))
    ax.plot(event.xdata,event.ydata,marker="o",color="green" if is_category_1 else "red")
    # Request a redraw explicitly so the new marker appears without relying on
    # matplotlib's interactive-mode auto-redraw.
    fig.canvas.draw_idle()

fig.canvas.mpl_connect("button_press_event",onclick)

def plot_implicit(params:list[float], resolution:int=101):
    '''
    Draws the decision boundary (the zero contour of the decision function) on `ax` over
    the unit square.

    params: values for (a1, w1, w2, b), in the order of `parameters`.
    resolution: number of grid samples per axis. np.linspace includes both 0 and 1, so the
        boundary reaches the right and top edges of the plot.
    Returns the contour set so the caller can remove it later.
    '''
    # Separate names from the module-level sympy symbols x, y to avoid confusion.
    grid_axis = np.linspace(0, 1, resolution)
    grid_x, grid_y = np.meshgrid(grid_axis, grid_axis)
    return ax.contour(grid_x, grid_y, decision_boundary(params, grid_x, grid_y), [0])

# Contour set drawn by the most recent completed fit. It is removed when the next fit
# starts, so pressing the button repeatedly does not stack old boundaries on the plot.
_final_contour = None

def log_reg(event, iterations:int=1000, initial_alpha:float=0.01, frame_every:int=20, frame_delay:float=0.05):
    '''
    Button callback: fits the decision-function parameters to the clicked points by
    gradient descent on the logistic regression cost, then draws the resulting boundary.

    event: the object passed in by the button's on_click (unused).
    iterations: number of gradient-descent steps.
    initial_alpha: starting learning rate. It is halved after a step that does not lower
        the cost and increased by 0.01 after a step that does.
    frame_every, frame_delay: when the animate checkbox is ticked, redraw the boundary every
        `frame_every` steps and pause `frame_delay` seconds per frame.

    The parameters start from uniform random values in [-1, 1], so repeated runs can differ.
    The function does nothing when no points have been added.
    '''
    global _final_contour
    if not points:
        # With zero samples every average is NaN, which would turn all parameters into NaN.
        return

    should_animate = animate_checkbox.value
    x, y, categories = map(np.array, zip(*points))
    params = [random.uniform(-1,1) for _ in range(len(parameters))]
    alpha = initial_alpha

    # Clear the boundary left behind by the previous run before drawing a new one.
    if _final_contour is not None:
        _final_contour.remove()
        _final_contour = None
    contour_lines = plot_implicit(params)

    cost = calc_cost_function(params,x,y,categories)
    for i in range(iterations):
        gradients = get_gradients(params,x,y,categories)
        params = [param - alpha*gradient for param, gradient in zip(params, gradients)]
        new_cost = calc_cost_function(params,x,y,categories)
        # Written as `not (<=)` so a NaN cost counts as "worse" and shrinks the step.
        # `new_cost > cost` is False for NaN and would keep growing alpha.
        if not (new_cost <= cost):
            alpha /= 2
        else:
            alpha += 0.01
        cost = new_cost
        if should_animate and i % frame_every == 0:
            contour_lines.remove()
            contour_lines = plot_implicit(params)
            fig.canvas.draw()
            time.sleep(frame_delay)
    contour_lines.remove()
    _final_contour = plot_implicit(params)
    fig.canvas.draw()


# Wire up the controls: the button starts the fit, and the three widgets are shown below the figure.
lin_reg_button.on_click(callback=log_reg)
display(category_toggle, lin_reg_button, animate_checkbox)