In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
X = torch.randn(5, 2)

In [3]:
X

tensor([[ 0.2805, -1.4055],
        [ 0.2095,  0.8429],
        [ 0.6451, -1.1605],
        [ 0.4752,  0.2678],
        [ 0.3381, -1.1565]])

In [4]:
model = nn.Sequential(
    nn.Linear(2, 10),
    nn.ReLU(),
    nn.Linear(10, 1)
)


In [5]:
model

Sequential(
  (0): Linear(in_features=2, out_features=10, bias=True)
  (1): ReLU()
  (2): Linear(in_features=10, out_features=1, bias=True)
)

In [6]:
output = model(X)
print(output)

tensor([[-0.2311],
        [-0.0253],
        [-0.1156],
        [ 0.0276],
        [-0.1791]], grad_fn=<AddmmBackward0>)


In [7]:
print(X.shape)
print(output.shape)

torch.Size([5, 2])
torch.Size([5, 1])


In [8]:
for name, param in model.named_parameters():
    print(name, param.shape)

0.weight torch.Size([10, 2])
0.bias torch.Size([10])
2.weight torch.Size([1, 10])
2.bias torch.Size([1])


In [9]:
y = torch.tensor([[1.0],
                  [0.0],
                  [1.0],
                  [0.0],
                  [1.0]])

In [10]:
print(y.shape)

torch.Size([5, 1])


In [11]:
y_pred = model(X)
print(y_pred)

tensor([[-0.2311],
        [-0.0253],
        [-0.1156],
        [ 0.0276],
        [-0.1791]], grad_fn=<AddmmBackward0>)


In [12]:
loss_fn = nn.MSELoss()

In [13]:
loss = loss_fn(y_pred, y)
print(loss)

tensor(0.8303, grad_fn=<MseLossBackward0>)


In [14]:
print(y_pred.shape, y.shape)

torch.Size([5, 1]) torch.Size([5, 1])


In [15]:
y_pred = y.clone()
loss = loss_fn(y_pred, y)
print(loss)

tensor(0.)


In [16]:
y_pred = y + torch.randn_like(y) * 2
loss = loss_fn(y_pred, y)
print(loss)

tensor(3.9343)


In [17]:
y_pred = model(X)
loss = loss_fn(y_pred, y)
print(loss)


tensor(0.8303, grad_fn=<MseLossBackward0>)


In [18]:
model.zero_grad()

In [19]:
print(model[0].weight.grad)

None


In [20]:
loss.backward()
# Weights are NOT updated
# Only gradients are calculated
# does NOT change weights only computes how to change them

In [21]:
print(model[0].weight.grad)

tensor([[-4.2003e-04,  7.5167e-04],
        [ 4.9719e-02, -1.5068e-01],
        [ 0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00],
        [-6.1866e-02,  1.8550e-01],
        [-2.8579e-03,  5.1411e-03],
        [ 1.7522e-01, -5.2537e-01],
        [ 0.0000e+00,  0.0000e+00],
        [-8.1571e-04,  1.4598e-03]])


In [22]:
for name, param in model.named_parameters():
    print(name)
    print(param.grad)
    print()

0.weight
tensor([[-4.2003e-04,  7.5167e-04],
        [ 4.9719e-02, -1.5068e-01],
        [ 0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00],
        [-6.1866e-02,  1.8550e-01],
        [-2.8579e-03,  5.1411e-03],
        [ 1.7522e-01, -5.2537e-01],
        [ 0.0000e+00,  0.0000e+00],
        [-8.1571e-04,  1.4598e-03]])

0.bias
tensor([-1.2145e-04,  1.1991e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,
        -1.4903e-01, -4.4695e-03,  4.2208e-01,  0.0000e+00, -2.3587e-04])

2.weight
tensor([[-1.8768e-03, -5.8643e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -3.2052e-01, -4.3309e-02, -1.3546e+00,  0.0000e+00, -1.1036e-03]])

2.bias
tensor([-1.4094])



In [23]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# parameter = parameter - lr * gradient

In [24]:
old_weight = model[0].weight.clone()

In [25]:
model.zero_grad()          # clear old gradients
y_pred = model(X)          # forward pass
loss = loss_fn(y_pred, y)  # compute loss
loss.backward()            # compute gradients
optimizer.step()           # UPDATE WEIGHTS

In [26]:
param.grad

tensor([-1.4094])

In [27]:
new_weight = model[0].weight
print(torch.allclose(old_weight, new_weight))

False


In [28]:
y_pred_new = model(X)
loss_new = loss_fn(y_pred_new, y)

print("Old loss:", loss.item())
print("New loss:", loss_new.item())

Old loss: 0.8303484916687012
New loss: 0.4408758580684662


In [29]:
for epoch in range(10):

    optimizer.zero_grad()
    y_pred = model(X)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch}, Loss: {loss.item()}")


Epoch 0, Loss: 0.4408758580684662
Epoch 1, Loss: 0.2788845896720886
Epoch 2, Loss: 0.19391673803329468
Epoch 3, Loss: 0.14309482276439667
Epoch 4, Loss: 0.10971345752477646
Epoch 5, Loss: 0.08617296069860458
Epoch 6, Loss: 0.0686701089143753
Epoch 7, Loss: 0.05516562610864639
Epoch 8, Loss: 0.044497959315776825
Epoch 9, Loss: 0.035961706191301346
