In [None]:
# |hide
from nbdev.showdoc import show_doc

# XOR

> An example of how to solve the XOR problem with MLMVN.

The XOR problem is a classic example of a simple but not linearly separable relationship that a single real-valued neuron cannot learn. At least, this holds as long as we do not extend the dimensionality of the feature space.

## Setup

In [None]:
# |hide
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from mlmvn.layers import FirstLayer, OutputLayer, cmplx_phase_activation
from mlmvn.loss import ComplexMSELoss
from mlmvn.optim import ECL

from res.utils import find_project_root

## Define Variables

In [None]:
# | hide
# Reproducibility and project location
SEED: int = 42
PROJECT_ROOT = find_project_root()

# Step sizes of the two optimizers
MODEL_MLP_LEARNING_RATE: float = 0.1
MODEL_MLMVN_LEARNING_RATE: float = 1

# Sector layout handed to the MLMVN loss
MODEL_MLMVN_CATEGORIES: int = 2
MODEL_MLMVN_PERIODICITY: int = 2

# Training schedule
MODEL_EPOCHS: int = 10
MODEL_BATCH_SIZE: int = 120

# Seed torch's global RNG so weight initialisation is repeatable
torch.manual_seed(SEED)

<torch._C.Generator>

## Load Dataset

The dataset contains four input-output mappings with binary classes. The two-dimensional input $x$ is mapped to a class label $y$. The following table shows the truth table with associated labels for the XOR gate.

$$
\begin{aligned}
    \begin{array}{cc|c|cc}
        x_1 & x_2 & y & z & \arg(z) \\
        \hline
         1 &  1 & 0 &  1+j &  45^\circ \\
         1 & -1 & 1 &  1-j & 315^\circ \\
        -1 &  1 & 1 & -1+j & 135^\circ \\
        -1 & -1 & 0 & -1-j & 225^\circ \\
    \end{array}
\end{aligned}
$$


If we consider $x_1$ as $\operatorname{Re}(z)$ and $x_2$ as $\operatorname{Im}(z)$, the problem can also be represented graphically in the complex domain.

<center>
    <img src="fig/xor_complex_domain.png" width=320 />
</center>

In [None]:
# XOR problem inputs and outputs

# complex case: bipolar (+1 / -1) inputs, created directly as complex doubles
# instead of building a float tensor and re-casting it afterwards
x = torch.tensor(
    [[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]], dtype=torch.cdouble
)
# one class label per row of x, shaped (4, 1)
y = torch.tensor([[0.0], [1.0], [1.0], [0.0]])

# real case: binary (0 / 1) inputs and labels
inputs = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
outputs = torch.tensor([[0.0], [1.0], [1.0], [0.0]])

## MLMVN

In [None]:
# Define the MLMVN model
class MLMVN(nn.Module):
    """XOR network: FirstLayer(2, 2) followed by OutputLayer(2, 1).

    Each of the two layers is followed by its own complex phase activation.
    """

    def __init__(self):
        super().__init__()
        # Modules are registered in the order they are assigned, which is the
        # order `children()` yields them in.
        self.linear = FirstLayer(2, 2)  # input layer
        self.phase_act = cmplx_phase_activation()  # activation after input layer
        self.linear1 = OutputLayer(2, 1)  # output layer
        self.phase_act1 = cmplx_phase_activation()  # activation after output layer

    def forward(self, x):
        """Apply layer -> activation -> layer -> activation to `x`."""
        hidden = self.phase_act(self.linear(x))
        return self.phase_act1(self.linear1(hidden))


# Initialize the MLMVN
model = MLMVN()

# Define loss function and optimizer
# The loss is used through `.apply` and is called below with
# (prediction, target, categories, periodicity).
# NOTE(review): presumably a custom torch.autograd.Function - confirm in mlmvn.loss.
criterion = ComplexMSELoss.apply
optimizer = ECL(model.parameters(), lr=MODEL_MLMVN_LEARNING_RATE)

# Train the MLMVN
# Every epoch feeds the complete tensor `x` (all four XOR samples) at once.
for t in range(MODEL_EPOCHS):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    loss = criterion(y_pred, y, MODEL_MLMVN_CATEGORIES, MODEL_MLMVN_PERIODICITY)
    # Log the epoch index together with the absolute value of the loss.
    print(t, torch.abs(loss))

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    # Unlike a stock torch optimizer, ECL.step() is given the network inputs and
    # the model's child modules (in registration order).
    # NOTE(review): presumably needed for a layer-wise update rule - confirm in mlmvn.optim.
    optimizer.step(inputs=x, layers=list(model.children()))

# Test the MLMVN
# `predictions` is read again by the angle2class cell further down.
predictions = model(x)
predictions

0 tensor(1.3599, dtype=torch.float64, grad_fn=<AbsBackward0>)
1 tensor(0.3484, dtype=torch.float64, grad_fn=<AbsBackward0>)
2 tensor(0.0889, dtype=torch.float64, grad_fn=<AbsBackward0>)
3 tensor(0.0108, dtype=torch.float64, grad_fn=<AbsBackward0>)
4 tensor(0.0019, dtype=torch.float64, grad_fn=<AbsBackward0>)
5 tensor(0.0003, dtype=torch.float64, grad_fn=<AbsBackward0>)
6 tensor(3.8336e-05, dtype=torch.float64, grad_fn=<AbsBackward0>)
7 tensor(5.5209e-06, dtype=torch.float64, grad_fn=<AbsBackward0>)
8 tensor(8.0576e-07, dtype=torch.float64, grad_fn=<AbsBackward0>)
9 tensor(1.1703e-07, dtype=torch.float64, grad_fn=<AbsBackward0>)


tensor([[ 0.7070+0.7072j],
        [-0.7071+0.7071j],
        [ 0.7071-0.7071j],
        [-0.7070-0.7072j]], dtype=torch.complex128,
       grad_fn=<phase_activationBackward>)

The output of the model is complex and can be converted into a real-valued class label using the function `angle2class`.

In [None]:
def angle2class(x: torch.Tensor, categories: int, periodicity: int) -> torch.Tensor:
    """Convert complex values to discrete class labels based on their phase.

    The full circle is divided into ``categories * periodicity`` equally sized
    sectors. Each element of ``x`` is mapped to the index of the sector that
    contains its argument, and that index is then reduced modulo ``categories``.

    Args:
        x: Tensor whose ``.angle()`` gives the phase of each element.
        categories: Number of distinct class labels.
        periodicity: Factor by which the number of sectors exceeds ``categories``.

    Returns:
        Tensor with the same shape as ``x`` holding floating-point labels in
        ``[0, categories)``.
    """
    full_turn = 2 * np.pi

    # Shift the phase by one full turn and wrap it into [0, 2*pi)
    angle = torch.remainder(x.angle() + full_turn, full_turn)

    # This will be the discrete output (the number of the sector)
    sector = torch.floor(categories * periodicity * angle / full_turn)
    return torch.remainder(sector, categories)


# Decode with the same categories / periodicity that were passed to the loss
angle2class(predictions, MODEL_MLMVN_CATEGORIES, MODEL_MLMVN_PERIODICITY)

tensor([[0.],
        [1.],
        [1.],
        [0.]], dtype=torch.float64, grad_fn=<RemainderBackward0>)

## Multilayer Perceptron (MLP)

In [None]:
# Define the MLP model
class MLP(nn.Module):
    """Real-valued network: Linear(2, 2) and Linear(2, 1), each followed by a sigmoid."""

    def __init__(self):
        super().__init__()
        # Create the linear layers in forward order so parameter initialisation
        # consumes the RNG in the same sequence as the layers are applied.
        hidden = nn.Linear(2, 2)  # input layer
        head = nn.Linear(2, 1)  # output layer
        self.layers = nn.Sequential(hidden, nn.Sigmoid(), head, nn.Sigmoid())

    def forward(self, x):
        """Run `x` through the sequential stack and return the sigmoid output."""
        out = self.layers(x)
        return out


# Initialize the MLP
model = MLP()

# Define loss function and optimizer
criterion = nn.BCELoss()  # Binary Cross Entropy Loss for binary classification
optimizer = optim.SGD(model.parameters(), lr=MODEL_MLP_LEARNING_RATE)

# Training schedule of the MLP (separate from MODEL_EPOCHS used by the MLMVN loop)
MLP_EPOCHS = 10_000
MLP_LOG_INTERVAL = 1_000

# Train the MLP
for epoch in range(MLP_EPOCHS):
    # Forward pass. A dedicated name is used so that the MLMVN `predictions`
    # read by the angle2class cell are not overwritten by this cell.
    train_preds = model(inputs)

    # Compute loss
    loss = criterion(train_preds, outputs)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print the loss every MLP_LOG_INTERVAL epochs
    if epoch % MLP_LOG_INTERVAL == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")

# Test the MLP
test_preds = model(inputs)
# `sample` rather than `input`, so the builtin input() is not shadowed
for sample, pred in zip(inputs, test_preds):
    print(f"Input: {sample}, Prediction: {pred}")

Epoch 0, Loss: 0.7061562538146973
Epoch 1000, Loss: 0.6886123418807983
Epoch 2000, Loss: 0.6580316424369812
Epoch 3000, Loss: 0.41256237030029297
Epoch 4000, Loss: 0.11984607577323914
Epoch 5000, Loss: 0.058993514627218246
Epoch 6000, Loss: 0.03796124458312988
Epoch 7000, Loss: 0.027697304263710976
Epoch 8000, Loss: 0.021699126809835434
Epoch 9000, Loss: 0.017790351063013077
Input: tensor([0., 0.]), Prediction: tensor([0.0146], grad_fn=<UnbindBackward0>)
Input: tensor([0., 1.]), Prediction: tensor([0.9864], grad_fn=<UnbindBackward0>)
Input: tensor([1., 0.]), Prediction: tensor([0.9808], grad_fn=<UnbindBackward0>)
Input: tensor([1., 1.]), Prediction: tensor([0.0124], grad_fn=<UnbindBackward0>)
