# XOR game

The XOR function takes two binary digits as input and outputs a single binary digit, 
$$ \mathrm{XOR}:\{0,1\}^2 \to \{0,1\}$$
and is defined by
$$\mathrm{XOR}(0,0) = 0, \,
\mathrm{XOR}(1,0) = 1, \\
\mathrm{XOR}(0,1) = 1, \, \mathrm{XOR}(1,1) = 0. $$

It is known that there exists a neural network $f: \mathbb{R}^2 \to \mathbb{R}$ which exactly fits the XOR function [GBC, pg. 174]. On the other hand, there does not exist a linear function from $\mathbb{R}^2$ to $\mathbb{R}$ exactly fitting XOR. 

The widget below demonstrates this. Run the full 'Code' cell at the bottom of the notebook, followed by the cell below.

**Aim:** Using the sliders, find the values of the weights and biases of the neural network which exactly fit the XOR function. 

The widget shows the output of the neural network (left), the output of the linear output layer applied directly to the input (middle), as well as a discontinuous function which trivially fits XOR (right).

**Neural network specifications:**
- One hidden layer of width 2 with ReLU activation function. 
    - For this layer, the matrix of weights is denoted by $W \in \mathbb{R}^{2 \times 2}$ and vector of biases is denoted by $c \in \mathbb{R}^{2}$.
- Linear output layer.
    - For this layer, the weights are denoted by $w \in \mathbb{R}^{2}$ and the bias is denoted by $b \in \mathbb{R}$.
    
[GBC] Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep learning. MIT press, 2016.


In [None]:
# Run this cell first - widget is below

import numpy as np
import matplotlib.pyplot as plt
from __future__ import print_function
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from IPython.display import display, Latex

def lin(x, w, b):
    """Return the affine map np.dot(w, x) + b."""
    weighted_sum = np.dot(w, x)
    return weighted_sum + b

def relu(t):
    """Rectified linear unit: element-wise maximum of 0 and t."""
    threshold = 0
    return np.maximum(threshold, t)

def nn1_reg(x, w1, w2, b1, b2):
    '''
    Evaluate a network with one ReLU hidden layer and a linear output unit.

    x  : input point; converted with np.array before use
    w1 : hidden-layer weights (2x2 array in this notebook)
    b1 : hidden-layer biases (length-2 array in this notebook)
    w2 : output-layer weights (length-2 array in this notebook)
    b2 : output-layer bias (float)

    Returns np.dot(w2, relu(np.matmul(w1, x) + b1)) + b2.
    '''
    inputs = np.array(x)
    # Hidden layer: affine map followed by ReLU.
    pre_activation = np.matmul(w1, inputs) + b1
    hidden = relu(pre_activation)
    # Output layer: affine map, no activation.
    weighted_hidden = np.dot(w2, hidden)
    return weighted_hidden + b2

def XOR(x1, x2):
    """
    XOR of two inputs after rounding each with np.round.

    Intended for floats between -0.5 and 1.49, which round to 0 or 1.
    Returns 0 when the rounded inputs sum to 2, otherwise returns their sum.
    """
    rounded_sum = np.round(x1) + np.round(x2)
    return 0 if rounded_sum == 2 else rounded_sum
# Rebind the name so XOR can also be called on whole arrays (e.g. a meshgrid).
XOR = np.vectorize(XOR)

def _param_slider(value, description):
    """Build one vertical parameter slider on [-2, 2] with step 0.01.

    continuous_update=False makes the slider report its value only when it is
    released, so the plots are not recomputed for every intermediate position.
    """
    return widgets.FloatSlider(value=value, min=-2, max=2, step=0.01,
                               description=description, orientation='vertical',
                               continuous_update=False)

# Hidden-layer weights W (2x2), initialised to the identity matrix.
w1_11 = _param_slider(1, '$W_{11}$')
w1_12 = _param_slider(0, '$W_{12}$')
w1_21 = _param_slider(0, '$W_{21}$')
w1_22 = _param_slider(1, '$W_{22}$')
# Output-layer weights w.
w2_1 = _param_slider(1, '$w_{1}$')
w2_2 = _param_slider(-1, '$w_{2}$')
# Hidden-layer biases c.
b1_1 = _param_slider(0, '$c_1$')
b1_2 = _param_slider(0, '$c_2$')
# Output-layer bias b. Previously the only slider without continuous_update=False,
# which caused a full redraw for every intermediate value while dragging.
b2 = _param_slider(0, '$b$')
# Toggle that reveals the known solution when pressed.
button = widgets.ToggleButton(value=False, description='Solution', disabled=False,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Description',
    icon='check' # (FontAwesome names without the `fa-` prefix)
)
# All controls laid out in a single row.
ui = widgets.HBox([w1_11, w1_12, w1_21, w1_22,b1_1,b1_2,w2_1,w2_2,b2,button])
def f(w1_11, w1_12, w1_21, w1_22,w2_1,w2_2,b1_1,b1_2,b2,button):
    """
    Recompute and redraw everything for the current control values.

    Prints the absolute error (rounded to 3 decimals) of the neural network and
    of the linear model at the four binary inputs, optionally displays the known
    solution, and plots both models next to the discontinuous XOR function on
    the square [-0.2, 1.2] x [-0.2, 1.2].

    w1_11, w1_12, w1_21, w1_22 : entries of the hidden-layer weight matrix W
    w2_1, w2_2                 : output-layer weights w
    b1_1, b1_2                 : hidden-layer biases c
    b2                         : output-layer bias b
    button                     : truthy to display the solution

    The 'Linear' model applies the output-layer parameters (w, b) directly to
    the input, i.e. it is the network without its hidden layer.
    """
    # Putting parameters into arrays
    w1 = np.array([[w1_11, w1_12], [w1_21, w1_22]])
    w2 = np.array([w2_1, w2_2])
    b1 = np.array([b1_1, b1_2])

    # Errors at the four binary inputs; entry [i, j] belongs to input (i, j).
    nn_err = np.zeros((2, 2))
    lin_err = np.zeros((2, 2))
    for i in range(2):
        for j in range(2):
            bits = np.array([i, j])
            target = XOR(i, j)
            nn_err[i, j] = round(abs(nn1_reg(bits, w1, w2, b1, b2) - target), 3)
            lin_err[i, j] = round(abs(lin(bits, w2, b2) - target), 3)
    print('Neural Network Error: Bit (0,0):', nn_err[0,0], ', Bit (0,1):', nn_err[0,1], ', Bit (1,0):', nn_err[1,0], ', Bit (1,1):', nn_err[1,1])
    print('Linear Error:         Bit (0,0):', lin_err[0,0], ', Bit (0,1):', lin_err[0,1], ', Bit (1,0):', lin_err[1,0], ', Bit (1,1):', lin_err[1,1])
    if button:
        # Raw string: the LaTeX backslashes must not be read as Python escapes.
        display(Latex(r'Solution: $\qquad \qquad \qquad   W_{1 1} = 1, \, W_{1 2} = 1, \,W_{2 1} = 1, \,W_{2 2} = 1, \,c_{1} = 0, \,c_{2} = -1, \, w_{1} = 1, \,w_{2} = -2, \,b = 0$'))

    # Evaluation grid. With meshgrid's default indexing, rows vary with x2 and
    # columns vary with x1.
    x1_lst = np.linspace(-0.2, 1.2, 100)
    x2_lst = np.linspace(-0.2, 1.2, 100)
    x1_mesh, x2_mesh = np.meshgrid(x1_lst, x2_lst)

    # Evaluate both models in a single pass over the grid.
    nn_grid = np.zeros(x1_mesh.shape)
    lin_grid = np.zeros(x1_mesh.shape)
    for idx in np.ndindex(x1_mesh.shape):
        point = np.array([x1_mesh[idx], x2_mesh[idx]])
        nn_grid[idx] = nn1_reg(point, w1, w2, b1, b2)
        lin_grid[idx] = lin(point, w2, b2)

    xor_grid = XOR(x1_mesh, x2_mesh)

    # Plotting: left = neural network, middle = linear, right = discontinuous.
    fig, axs = plt.subplots(1, 3, figsize=(15, 15))
    extent = [x1_lst[0], x1_lst[-1], x2_lst[0], x2_lst[-1]]
    panels = [
        (nn_grid, 'Neural Network'),
        (lin_grid, 'Linear'),
        (xor_grid, 'Discontinuous solution'),
    ]
    for ax, (grid, title) in zip(axs, panels):
        # origin='lower' puts row 0 (smallest x2) at the bottom, so the image
        # agrees with the y-axis ticks given by `extent`. The default
        # origin='upper' would draw it upside down relative to those ticks.
        im = ax.imshow(grid, origin='lower', extent=extent)
        ax.title.set_text(title)
        # Mark each binary input with its XOR value at its true coordinates.
        ax.scatter([0, 1], [0, 1], marker='$0$', color='white')  # (0,0), (1,1)
        ax.scatter([1, 0], [0, 1], marker='$1$', color='white')  # (1,0), (0,1)
        # 'fraction' and 'pad' scale the colorbar to the size of the graph
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    fig.tight_layout()
    
    
        
# Output area that re-runs `f` whenever a control changes; each keyword is the
# parameter name of `f` that receives the corresponding widget's value.
out = widgets.interactive_output(
    f,
    dict(
        w1_11=w1_11, w1_12=w1_12, w1_21=w1_21, w1_22=w1_22,
        w2_1=w2_1, w2_2=w2_2, b1_1=b1_1, b1_2=b1_2,
        b2=b2, button=button,
    ),
)


In [None]:
# Run this after running the above code cell.
# Shows the row of controls (`ui`) followed by the output area (`out`) that
# `interactive_output` fills with the printed errors and the plots from `f`.
display(ui, out)