In [1]:
#|hide
from nbdev.showdoc import *

# XOR

> An example of how to solve the XOR problem with MLMVN.

The XOR problem is a classic example of a simple but non-linearly separable relationship that a single real-valued neuron cannot learn. At least, this holds as long as we do not extend the dimensionality of the feature space.

In [2]:
#|hide
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
from mlmvn.layers import OutputLayer, cmplx_phase_activation
from mlmvn.loss import ComplexMSELoss

torch.manual_seed(0)  # seed PyTorch's random number generator for repeatable results

<torch._C.Generator at 0x7f0ed4d7e050>

In [3]:
# Experiment settings gathered in a single dictionary.
# NOTE(review): the visible cells never read this dict -- the training loop
# hard-codes 200 iterations and the optimizer hard-codes lr=1. Confirm whether
# these values are meant to drive the run or are only metadata.
config = {
    "epochs": 20,
    "classes": 2,
    "kernels": [2],
    "batch_size": 4,
    "learning_rate": 1,
    "dataset": "XOR",
    "architecture": "MLMVN",
}

The dataset contains four input-output mappings with binary classes. The two-dimensional input $x = (x_1, x_2)$ is mapped to a class label $y$. The following table shows the truth table with the associated labels for the XOR gate, together with the complex number $z = x_1 + j x_2$ and its argument $\arg(z)$.

$$
\begin{aligned}
    \begin{array}{cc|c|cc}
        x_1 & x_2 & y & z & \arg(z) \\
        \hline
         1 &  1 & 0 &  1+j &  45^\circ \\
         1 & -1 & 1 &  1-j & 315^\circ \\
        -1 &  1 & 1 & -1+j & 135^\circ \\
        -1 & -1 & 0 & -1-j & 225^\circ \\
    \end{array}
\end{aligned}
$$


If we interpret $x_1$ as $\mathrm{Re}(z)$ and $x_2$ as $\mathrm{Im}(z)$, the problem can also be represented graphically in the complex plane.

<center>
    <img src="fig/xor_complex_domain.png" width=450 />
</center>

In [4]:
# XOR inputs: one row per sample, columns are (x_1, x_2).
# Stored as complex double, so every feature is a complex number with zero
# imaginary part.
x = torch.tensor(
    [
        [1.0, 1.0],
        [1.0, -1.0],
        [-1.0, 1.0],
        [-1.0, -1.0],
    ],
    dtype=torch.cdouble,
)

# Class labels as a real-valued column vector with one row per sample.
y = torch.tensor([0.0, 1.0, 1.0, 0.0]).reshape(x.shape[0], 1)


In [5]:
class BasicModel(nn.Module):
    """Two stacked ``OutputLayer`` modules followed by a complex phase activation.

    The first layer is built as ``OutputLayer(2, 20)`` and the second as
    ``OutputLayer(20, 1)`` (presumably input size -> output size; verify
    against the ``mlmvn`` layer signature).

    NOTE(review): ``forward`` applies no activation between ``linear`` and
    ``linear1``; confirm that this is intended.
    """

    def __init__(self):
        super().__init__()
        self.linear = OutputLayer(2, 20)
        # One activation instance is sufficient: `forward` uses it exactly once.
        self.phase_act = cmplx_phase_activation()
        self.linear1 = OutputLayer(20, 1)

    def forward(self, x):
        """Pass ``x`` through both layers, then apply the phase activation."""
        x = self.linear(x)
        x = self.linear1(x)
        x = self.phase_act(x)
        return x

In [6]:
# Sector layout handed to the loss: number of classes and the periodicity factor.
categories = 2
periodicity = 2

# Network and the plain SGD optimizer that updates its parameters.
model = BasicModel()
optimizer = torch.optim.SGD(model.parameters(), lr=1)

# Loss callable; it is invoked as criterion(prediction, target, categories, periodicity).
criterion = ComplexMSELoss.apply

In [7]:
# Training loop: every iteration feeds the complete data set `x` to the model.
for step in range(200):
    # Forward pass on all samples, then compare the flattened output with the labels.
    y_pred = model(x)
    loss = criterion(y_pred.view(-1), y, categories, periodicity)
    # wandb.log({"loss": torch.abs(loss)})

    # Report the loss magnitude on every tenth iteration (9, 19, 29, ...).
    if (step + 1) % 10 == 0:
        print(step, torch.abs(loss))

    # Reset old gradients, back-propagate the current loss, apply the update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Optional per-parameter logging (disabled):
    # for idx, param in enumerate(model.parameters()):
    #     wandb.log({"weights_layer"+str(idx)+"_real": param.real})
    #     wandb.log({"weights_layer"+str(idx)+"_imag": param.imag})

9 tensor(4.5521e-06, dtype=torch.float64, grad_fn=<AbsBackward0>)
19 tensor(3.9741e-12, dtype=torch.float64, grad_fn=<AbsBackward0>)
29 tensor(1.1070e-16, dtype=torch.float64, grad_fn=<AbsBackward0>)
39 tensor(1.3549e-16, dtype=torch.float64, grad_fn=<AbsBackward0>)
49 tensor(1.3554e-16, dtype=torch.float64, grad_fn=<AbsBackward0>)
59 tensor(1.3554e-16, dtype=torch.float64, grad_fn=<AbsBackward0>)
69 tensor(1.3554e-16, dtype=torch.float64, grad_fn=<AbsBackward0>)
79 tensor(1.3554e-16, dtype=torch.float64, grad_fn=<AbsBackward0>)
89 tensor(1.3554e-16, dtype=torch.float64, grad_fn=<AbsBackward0>)
99 tensor(1.3554e-16, dtype=torch.float64, grad_fn=<AbsBackward0>)
109 tensor(1.3554e-16, dtype=torch.float64, grad_fn=<AbsBackward0>)
119 tensor(1.3554e-16, dtype=torch.float64, grad_fn=<AbsBackward0>)
129 tensor(1.3554e-16, dtype=torch.float64, grad_fn=<AbsBackward0>)
139 tensor(1.3554e-16, dtype=torch.float64, grad_fn=<AbsBackward0>)
149 tensor(1.3554e-16, dtype=torch.float64, grad_fn=<AbsBac

In [8]:
# Show the real part of every model parameter.
# A bare expression inside a loop is not echoed by Jupyter (only the last
# top-level expression of a cell is), so each value is printed explicitly.
for idx, param in enumerate(model.parameters()):
    print(idx, param.detach().real)

In [9]:
# Run the model on the four XOR inputs after the training loop above.
# Autograd tracking is not disabled here, so the result carries a grad_fn.
predictions = model(x)

In [10]:
def angle2class(x: torch.Tensor, categories: int, periodicity: int) -> torch.Tensor:
    """Convert complex values to class indices based on their phase.

    The full circle is split into ``categories * periodicity`` equally sized
    sectors, numbered from angle 0 upwards. Each element of ``x`` is assigned
    the number of the sector its phase falls into, and that number is reduced
    modulo ``categories``.

    Args:
        x: Tensor whose ``angle()`` is taken element-wise (complex-valued in
            this notebook).
        categories: Number of distinct class labels.
        periodicity: Multiplier for the number of sectors; the class pattern
            repeats this many times around the circle.

    Returns:
        A floating-point tensor shaped like ``x`` with values in
        ``[0, categories)``.
    """
    # Shift the phase by a full turn and wrap it, giving an angle in [0, 2*pi).
    shifted = x.angle() + 2 * np.pi
    angle = torch.remainder(shifted, 2 * np.pi)

    # Discrete output: the number of the sector the angle lies in.
    sector = torch.floor(categories * periodicity * angle / (2 * np.pi))
    return torch.remainder(sector, categories)

# Decode with the same `categories` / `periodicity` that were passed to the
# loss during training, instead of repeating the literal values.
angle2class(predictions, categories, periodicity)

tensor([[0.],
        [1.],
        [1.],
        [0.]], dtype=torch.float64, grad_fn=<RemainderBackward0>)