# Learning XOR

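XOR (exclusive or) is a non-linear operation that yields 1 only when its two input bits differ from each other. Because its four cases are not linearly separable, a single linear layer cannot learn it; the model below therefore adds a hidden layer followed by a tanh non-linearity.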

In [1]:
from microtorch import Tensor, Parameter, Module, Functions
from microtorch.optim import SGD
import numpy as np

In [2]:
# Each row of ``inputs`` is a pair of bits. The matching row of ``targets`` is a
# one-hot label: [1, 0] (class 0) when the two bits are equal and [0, 1]
# (class 1) when they differ -- i.e. the XOR of the pair picks the hot column.
inputs = Tensor([[0, 0], [1, 0], [0, 1], [1, 1]])

targets = Tensor([
    [1 - (left ^ right), left ^ right]
    for left, right in ([0, 0], [1, 0], [0, 1], [1, 1])
])

In [3]:
class Linear(Module):
    """
    Simple linear layer.
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        
        self.Weights = Parameter(in_dim, out_dim)
        self.bias = Parameter(out_dim)
        
    def forward(self, inputs):
        return inputs @ self.Weights + self.bias

In [4]:
class Model(Module):
    """Two-layer network for XOR: ``Linear(2, 2) -> tanh -> Linear(2, 2)``."""

    def __init__(self):
        super().__init__()
        # Hidden layer: two bits in, two hidden units out.
        self.linear1 = Linear(2, 2)
        # Output layer: one score per column of the one-hot targets.
        self.linear2 = Linear(2, 2)

    def forward(self, inputs):
        """Run ``inputs`` through the hidden layer, tanh, and the output layer."""
        # tanh supplies the non-linearity; without it the two layers would
        # collapse into a single linear map, which cannot represent XOR.
        hidden = Functions.Tanh(self.linear1(inputs))
        return self.linear2(hidden)

In [5]:
# Loss function; called below as ``criterion(outputs, targets)``.
criterion = Functions.MSE
model = Model()
# SGD with learning rate 0.001; presumably it collects the Parameters
# registered on ``model`` (confirm in microtorch.optim).
optimizer = SGD(model, lr=0.001)

In [6]:
# Full-batch training: every step uses all four XOR examples at once.
for epoch in range(30000):
    optimizer.zero_grad()  # zero the gradients before this step's backward pass

    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()        # back-propagate the loss
    optimizer.step()       # apply one SGD parameter update

    # Log the loss every 100 epochs (epochs 0, 100, ..., 29900).
    if not epoch % 100:
        print(loss.item())

32.79200786888695
2.8053666482933055
2.141355225214693
2.0941425680468955
2.0753815764939723
2.0619749544764208
2.0516687362731325
2.043575226324076
2.03712181412095
2.031909597287555
2.027652902942579
2.0241427525034554
2.0212233868493445
2.0187767647623613
2.016712094782182
2.0149586209411816
2.013460556814307
2.012173465622617
2.01106163124477
2.0100961196618874
2.0092533290700763
2.008513891020034
2.0078618272880417
2.0072838955702617
2.006769076392233
2.0063081669207525
2.0058934566391713
2.0055184663934074
2.005177736990568
2.004866656907822
2.0045813211340198
2.0043184149859865
2.0040751180989247
2.0038490248139276
2.0036380779654186
2.0034405136713955
2.0032548151956
2.0030796743162647
2.002913958925259
2.00275668581188
2.002606997770668
2.002464144322226
2.0023274654577037
2.0021963779170737
2.002070363593008
2.001948959719536
2.001831750560388
2.0017183603581805
2.0016084473440907
2.001501698639757
2.0013978259099168
2.001296561646744
2.0011976559855857
2.001100873967563
2.00

In [7]:
# Show the trained model's raw outputs for the four XOR inputs, for comparison
# with the one-hot targets.
print(model(inputs).data)

[[ 1.00024777e+00  2.50847914e-04]
 [-2.70428632e-04  9.99726616e-01]
 [-3.93138836e-04  9.99602631e-01]
 [ 1.00024997e+00  2.52030198e-04]]


In [8]:
# Show the one-hot targets, for comparison with the model's predictions.
print(targets.data)

[[1. 0.]
 [0. 1.]
 [0. 1.]
 [1. 0.]]
