Hand-written implementation of backward, forward and optimizers

In [5]:
import torch
import torch.nn as nn
from math import *

# Forward and backward

In [30]:
class MyModele (nn.Module):
    def __init__(self, tensor):
        self.__init__()
        self.param = nn.Parameter(tensor, requires_grad=False)

    def forward(self, x):
        return x * self.param

    def backward(self, grad):
        grad = 2 * grad - 7

# Optimizers: Adagrad and Adam

In [41]:
class Adagrad:
    def __init__(self, model_weight, lr: float=0.001):
        self.lr = lr
        self.weight = model_weight
        self.accumulated = 0

    def step(self, grad):
        self.accumulated = self.accumulated + grad ** 2
        adapt_lr = self.lr/(sqrt(self.accumulated) + 10**(-16))
        self.weight = self.weight - adapt_lr * grad


In [79]:
class Adam:
    def __init__(self, model_weight, lr: float=0.001, beta_1: float=0.9, beta_2: float=0.999):
        self.lr = lr
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.weight = model_weight
        self.velocity = torch.zeros_like(model_weight)
        self.accumulated = 0

    def step(self, grad):
        self.velocity = self.beta_1 * self.velocity + (1 - self.beta_1) * grad
        self.accumulated = self.beta_2 * self.accumulated + (1 - self.beta_2) * grad ** 2
        adapt_lr = self.lr/sqrt(self.accumulated)
        self.weight = self.weight - adapt_lr * self.velocity

# Solution of function

Function used $$ x^2 - 7x + 6 $$

In [66]:
grads = lambda x: 2*(x ** 2 - 7 * x + 6)*(2 * x - 7)
func = lambda x: (x ** 2 - 7 * x + 6)**2

In [83]:
def solve(x):
    x = torch.tensor(x).float()
    grad = grads(x)
    optim = Adagrad(x, 1)
    prew = optim.weight
    optim.step(grad)
    grad = grads(optim.weight)
    while abs(func(prew) - func(optim.weight)) >= 0.001:
        prew = optim.weight
        optim.step(grad)
        grad = grads(optim.weight)
    return optim.weight

In [85]:
print('First root:', solve(5).numpy())
print('Second root:', solve(-5).numpy())

First root: 6.0
Second root: 0.96306723
