<a href="https://colab.research.google.com/github/ZaraGiraffe/xor/blob/main/torch_xor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import numpy as np
from torch import nn
import plotly.graph_objects as go

In [2]:
mytype = torch.float32

In [3]:
class Mydata:
    def __init__(self):
        self.X = np.array([
            [0, 0],
            [0, 1],
            [1, 0],
            [1, 1]
        ])
        self.y = np.array([0, 1, 1, 0])

In [88]:
class Mymodel(nn.Module):
    def __init__(self, lr=0.1):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 1), nn.ReLU())
        nn.init.constant_(self.net[0].bias, 0.1)
        nn.init.constant_(self.net[2].bias, 0.1)
        self.optim = torch.optim.SGD(self.parameters(), lr)
        self.myloss = torch.nn.MSELoss()
    def forward(self, X):
        return self.net(X)
    
    def loss(self, y, y_hat):
        return self.myloss(y_hat, y)

class Mytrainer:
    def __init__(self, epochs=300):
        self.epochs = epochs
        self.params = []
        self.errors = []
    
    def fit(self, model, data):
        for epoch in range(self.epochs):
            model.optim.zero_grad()
            y_hat = model.forward(torch.tensor(data.X, dtype=mytype))
            loss = model.loss(torch.tensor(data.y, dtype=mytype), torch.flatten(y_hat))
            loss.backward()
            model.optim.step()
            with torch.no_grad():
                self.errors.append(loss.detach())
                self.params.append(model.net[0].weight.detach().numpy().flatten())

In [89]:
data = Mydata()
model = Mymodel()
trainer = Mytrainer()
trainer.fit(model, data)

In [90]:
trainer.errors[-1]

tensor(0.0074)

In [91]:
model.forward(torch.tensor(data.X, dtype=mytype))

tensor([[0.0932],
        [0.8897],
        [0.9444],
        [0.0737]], grad_fn=<ReluBackward0>)

In [92]:
fig = go.Figure()
all = np.array(trainer.params)
for i in range(len(trainer.params[0])):
    fig.add_trace(go.Scatter(x=list(range(trainer.epochs)), y=all[:,i]))
fig.show()

In [93]:
fig = go.Figure()
fig.add_trace(go.Scatter(x=list(range(len(trainer.errors))), y=trainer.errors))
fig.show()

In [94]:
mas = []
for i in range(50):
    data = Mydata()
    model = Mymodel()
    trainer = Mytrainer()
    trainer.fit(model, data)
    mas.append(trainer.errors[-1])
fig = go.Figure()
fig.add_trace(go.Histogram(x=mas))
fig.show()

In [95]:
fig = go.Figure()
for i in range(50):
    data = Mydata()
    model = Mymodel()
    trainer = Mytrainer()
    trainer.fit(model, data)
    fig.add_trace(go.Scatter(x=list(range(trainer.epochs)), y=trainer.errors))
fig.show()