In [33]:
import math
import numpy as np
import matplotlib.pyplot as plt
import torch
import random
from typing import List

from sgrad import Value

%matplotlib inline

In [87]:
class Neuron:
    
    def __init__(self, num_in):
        self.w = [Value(random.uniform(-1, 1)) for _ in range(num_in)]
        self.b = Value(random.uniform(-1, 1))
        
    def __call__(self, x):
        act = sum((wi*xi for wi, xi in zip(self.w, x)), self.b)
        out = act.tanh()
        return out
    
    def parameters(self):
        return self.w + [self.b]
    
    
class Layer:
    
    def __init__(self, num_in, num_out):
        self.num_in = num_in
        self.num_out = num_out
        
        self.neurons = [Neuron(num_in) for _ in range(num_out)]
        
    def __call__(self, x):
        out = [[n(xi) for n in self.neurons] for xi in x]
        if len(out[0]) == 1:
            out = [o[0] for o in out]
        return out
    
    def parameters(self):
        return [p for neuron in self.neurons for p in neuron.parameters()]
    
    
class MLP:
    
    def __init__(self, layout: List[int]):
        self.layers = [Layer(layout[i], layout[i+1]) for i in range(len(layout)-1)]
        
    def __call__(self, x):
        for l in self.layers:
            x = l(x)
        return x
    
    def parameters(self):
        return [p for layer in self.layers for p in layer.parameters()]
    
    def zero_grad(self):
        for p in self.parameters():
            p.zero_grad()
            
        
##class SelfAttentionHead:
    
#    def __init__(self, ):

In [109]:
m = MLP([2, 3, 5, 3, 1])

In [110]:
m(x)

[Value(data=-0.9253634511617738),
 Value(data=-0.9253634511617738),
 Value(data=-0.9236762122728861),
 Value(data=-0.8850600679007604)]

In [111]:
def mse(pred, gt):
    cost = sum((p - y)*(p - y) for p, y in zip(pred, gt))
    return cost
    

In [112]:
x = [[1, 2], [1, 2], [1, 4], [347, 3]]
y = [-1, -1, 1, 1]



In [113]:
### Training loop
lr = 0.01
params = m.parameters()

losses = []
preds = []

for _ in range(10):
    
    m.zero_grad()
    
    ypred = m(x)
    
    loss = mse(ypred, y)
    
    loss.backward()
    
    preds.append([xi.data for xi in ypred])
    losses.append(loss.data)
    

    for p in m.parameters():
        p.data -= lr*p.grad
        
        

In [114]:
preds

[[-0.9253634511617738,
  -0.9253634511617738,
  -0.9236762122728861,
  -0.8850600679007604],
 [-0.9200226777254487,
  -0.9200226777254487,
  -0.9183539345785108,
  -0.8724541358261949],
 [-0.9077588630504544,
  -0.9077588630504544,
  -0.9061710735027422,
  -0.8416076331917741],
 [-0.8843878283703399,
  -0.8843878283703399,
  -0.8830853632488063,
  -0.7760860649127872],
 [-0.8400415635157856,
  -0.8400415635157856,
  -0.8396646280153895,
  -0.6325313491317215],
 [-0.7513091931977498,
  -0.7513091931977498,
  -0.7538623733960341,
  -0.3144746562870104],
 [-0.5704836732136097,
  -0.5704836732136097,
  -0.5814048633026307,
  0.21165099834712556],
 [-0.25695205246412933,
  -0.25695205246412933,
  -0.28379579528909565,
  0.6342003847948106],
 [0.11523100733991719,
  0.11523100733991719,
  0.07410968918557183,
  0.8285504829606307],
 [0.38083675389995125,
  0.38083675389995125,
  0.3365888982558384,
  0.9076942082209941]]

In [92]:
loss = mse(ypred, y)

In [380]:
x1 = torch.Tensor([2.0]).double(); x1.requires_grad = True
x2 = torch.Tensor([0.0]).double(); x2.requires_grad = True
w1 = torch.Tensor([-3.0]).double(); w1.requires_grad = True
w2 = torch.Tensor([1.0]).double(); w2.requires_grad = True

b = torch.Tensor([6.8813735870195432]).double(); b.requires_grad = True




In [381]:
n = x1*w1 + x2*w2 + b
o = torch.tanh(n)

In [93]:
loss.backward()

In [387]:
o.backward()

In [346]:
x1w1.grad

0.4999999999999999

In [347]:
x2w2.grad

0.4999999999999999

In [348]:
x1w1x2w2.grad

0.4999999999999999

In [349]:
n.grad

0.4999999999999999

In [350]:
o.grad

1

In [316]:
a = Value(3)
b = Value(6)

In [317]:
c = a / b

In [318]:
c

Value(data=0.5)

In [319]:
c.backward()

In [320]:
a.grad

0.16666666666666666

In [321]:
b.grad

-0.08333333333333333

In [121]:
-0.12*25

-3.0

In [122]:
-3/25

-0.12