In [1]:
import nujo as nj
import nujo.nn as nn
import nujo.objective as obj
import nujo.optim as optim

from nujo.utils.viz import ComputationGraphPlotter

In [2]:
# Create example data
x = nj.rand(30, 3, name='X_train', diff=False)
y = nj.Tensor(x.value @ [[2], [3], [4]] - 10, name='Y_labels', diff=False)

In [3]:
params = [[nj.randn(3, 1, name='Weights')]]
print('Params:', params)

loss_fn = obj.L2Loss()
print('Loss:', loss_fn)

optimizer = optim.Adam(params, lr=0.1)
print('Optimizer:', optimizer)

Params: [[<Weights>]]
Loss: <|L2Loss>
Optimizer: <nujo.optim.optimizers.Adam object at 0x10a058250>


In [4]:
for epoch in range(1, 100):
    # Forward
    output = x @ params[0][0]
    # Compute Loss
    loss = (output - y)**2

    # Print the loss every 10th epoch for monitoring
    if epoch % 10 == 0:
        print('EPOCH:', epoch, '| LOSS: ', loss.value)

    # Backprop
    loss.backward()

    # Update
    optimizer.step()

    # Zero grad
    optimizer.zero_grad()

EPOCH: 10 | LOSS:  [[6.20586150e+00]
 [1.77785758e+01]
 [3.01488043e+00]
 [4.50788849e+01]
 [4.21401886e+00]
 [7.48372485e-01]
 [6.31202397e-02]
 [1.09350388e+01]
 [1.13196522e+01]
 [2.43701412e+01]
 [2.83190371e+01]
 [4.50349346e+00]
 [2.37448507e+01]
 [5.45714997e+00]
 [4.23459498e+00]
 [1.14250185e+01]
 [4.27486010e+01]
 [4.58444353e+01]
 [8.56691273e+00]
 [7.08187211e+00]
 [1.26099688e+01]
 [1.75621757e+01]
 [3.63502574e+01]
 [1.23393660e+00]
 [1.06870193e+01]
 [2.41287913e+01]
 [5.37766557e-01]
 [2.58043225e+00]
 [8.29230415e+01]
 [4.39775811e-02]]
EPOCH: 20 | LOSS:  [[8.26513123e-01]
 [8.69442726e+00]
 [7.50869325e-04]
 [3.84038125e+01]
 [1.37356479e-01]
 [9.03248078e+00]
 [5.36798360e+00]
 [4.67789122e+00]
 [4.30267764e+00]
 [1.45537069e+01]
 [1.91092936e+01]
 [1.99934272e-01]
 [1.48995047e+01]
 [4.28449817e-01]
 [2.88109291e-01]
 [3.37853055e+00]
 [3.26257935e+01]
 [3.26917546e+01]
 [2.41268606e+00]
 [2.23974743e+00]
 [4.29331804e+00]
 [8.67994038e+00]
 [2.50075233e+01]
 [1.110

In [5]:
cg_plot = ComputationGraphPlotter(filename='graph').create(loss)
cg_plot.view()