In [3]:
import numpy as np
import torch


In [4]:

# Input (temp, rainfall, humidity)
inputs = np.array([[73, 67, 43], 
                   [91, 88, 64], 
                   [87, 134, 58], 
                   [102, 43, 37], 
                   [69, 96, 70]], dtype='float32')

In [5]:
# Targets (apples, oranges)
targets = np.array([[56, 70], 
                    [81, 101], 
                    [119, 133], 
                    [22, 37], 
                    [103, 119]], dtype='float32')

In [6]:
# Convert inputs and targets to tensors
inputs = torch.from_numpy(inputs)
targets = torch.from_numpy(targets)
print(inputs)
print(targets)

tensor([[ 73.,  67.,  43.],
        [ 91.,  88.,  64.],
        [ 87., 134.,  58.],
        [102.,  43.,  37.],
        [ 69.,  96.,  70.]])
tensor([[ 56.,  70.],
        [ 81., 101.],
        [119., 133.],
        [ 22.,  37.],
        [103., 119.]])


In [7]:
w = torch.randn(2,3,requires_grad=True)
b = torch.randn(2, requires_grad=True)

In [8]:
print(w)
print(b)

tensor([[ 1.7426,  2.3574,  0.5301],
        [ 0.3513, -0.4058,  1.3611]], requires_grad=True)
tensor([-0.1261,  0.6217], requires_grad=True)


In [9]:
def model(x):
    return x @ w.t() + b

In [10]:
preds = model(inputs)
print(preds)

tensor([[307.8259,  57.6004],
        [399.8307,  83.9835],
        [498.1223,  55.7448],
        [298.6022,  69.3602],
        [383.5335,  81.1756]], grad_fn=<AddBackward0>)


In [11]:
# Compare with targets
print(targets)

tensor([[ 56.,  70.],
        [ 81., 101.],
        [119., 133.],
        [ 22.,  37.],
        [103., 119.]])


In [12]:
# MSE loss
def mse(t1, t2):
    diff = t1 - t2
    return torch.sum(diff * diff) / diff.numel()

In [13]:
loss = mse(preds, targets)
print(loss)

tensor(47290.0391, grad_fn=<DivBackward0>)


In [14]:
# Compute gradients
loss.backward()

In [15]:
# Gradients for weights
print(w)
print(w.grad)

tensor([[ 1.7426,  2.3574,  0.5301],
        [ 0.3513, -0.4058,  1.3611]], requires_grad=True)
tensor([[25590.1523, 26911.3867, 16618.8789],
        [-1696.8036, -2984.0149, -1510.6842]])


In [16]:
with torch.no_grad():
    w -= w.grad * 1e-5
    b -= b.grad * 1e-5

In [17]:
# Let's verify that the loss is actually lower
loss = mse(preds, targets)
print(loss)

tensor(47290.0391, grad_fn=<DivBackward0>)


In [18]:
w.grad.zero_()
b.grad.zero_()
print(w.grad)
print(b.grad)

tensor([[0., 0., 0.],
        [0., 0., 0.]])
tensor([0., 0.])


In [19]:
# Generate predictions
preds = model(inputs)
print(preds)

tensor([[263.9653,  61.4882],
        [342.2226,  89.1206],
        [430.1556,  62.0960],
        [254.7764,  72.9332],
        [328.4051,  86.2687]], grad_fn=<AddBackward0>)


In [20]:
# Calculate the loss
loss = mse(preds, targets)
print(loss)

tensor(32090.0371, grad_fn=<DivBackward0>)


In [21]:
# Compute gradients
loss.backward()
print(w.grad)
print(b.grad)

tensor([[21063.8828, 22052.8789, 13619.7734],
        [-1292.8607, -2542.7769, -1240.0760]])
tensor([247.7050, -17.6186])


In [22]:
# Adjust weights & reset gradients
with torch.no_grad():
    w -= w.grad * 1e-5
    b -= b.grad * 1e-5
    w.grad.zero_()
    b.grad.zero_()

In [23]:
print(w)
print(b)


tensor([[ 1.2761,  1.8678,  0.2277],
        [ 0.3812, -0.3505,  1.3886]], requires_grad=True)
tensor([-0.1315,  0.6221], requires_grad=True)


In [24]:
# Calculate loss
preds = model(inputs)
loss = mse(preds, targets)
print(loss)

tensor(21844.4785, grad_fn=<DivBackward0>)


In [25]:
# Train for 100 epochs
for i in range(2500):
    preds = model(inputs)
    loss = mse(preds, targets)
    loss.backward()
    with torch.no_grad():
        w -= w.grad * 1e-5
        b -= b.grad * 1e-5
        w.grad.zero_()
        b.grad.zero_()

In [26]:
# Calculate loss
preds = model(inputs)
loss = mse(preds, targets)
print(loss)

tensor(1.3148, grad_fn=<DivBackward0>)


In [27]:
preds

tensor([[ 57.2316,  70.3567],
        [ 81.4635, 101.0672],
        [120.2764, 131.9807],
        [ 21.5270,  36.7332],
        [100.3391, 120.1141]], grad_fn=<AddBackward0>)

In [28]:
targets

tensor([[ 56.,  70.],
        [ 81., 101.],
        [119., 133.],
        [ 22.,  37.],
        [103., 119.]])