In [1]:
import torch
import torch.nn as nn

In [32]:
# Typical training procedure

# 1) Design model (input, output, forward pass with different layers)
# 2) Construct loss and optimizer
# 3) Training loop
#       - Forward = compute prediction and loss
#       - Backward = compute gradients
#       - Update weights


X = torch.tensor([[1], [2], [3], [4]], dtype=torch.float32)
Y = torch.tensor([[2], [4], [6], [8]], dtype=torch.float32)

X_test = torch.tensor([5], dtype=torch.float32)

# 4 samples, 1 feature
n_sample, n_features = X.shape
print(n_sample, n_features)

input_size = n_features
output_size = n_features

# model = nn.Linear(input_size, output_size)

class LinearRegression(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearRegression, self).__init__()
        # Define layers
        self.lin = nn.Linear(input_dim, output_dim)
    def forward(self, x):
        return self.lin(x)
    
model = LinearRegression(input_size, output_size)


4 1


In [33]:
print(f'Prediction before training: f(5) = {model(X_test).item():.3f}')

# Training
learning_rate = 0.01
n_iters = 200

loss = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)

for epoch in range(n_iters):
    # predict = forward pass
    y_pred = model(X)

    # loss
    l = loss(Y, y_pred)
    
    # calculate gradients - backward pass
    l.backward() # dl/dw

    optimizer.step()

    optimizer.zero_grad()

    if epoch % 10 == 0:
        [w, b] = model.parameters()
        print(f'epoch {epoch+1}: w = {w[0][0].item():.3f}, b = {b[0].item()}, loss = {l:.8f}')
     
print(f'Prediction after training: f(5) = {model(X_test).item():.3f}')

Prediction before training: f(5) = -4.361
epoch 1: w = -0.572, b = 0.7319250702857971, loss = 58.56463623
epoch 11: w = 1.190, b = 1.2836995124816895, loss = 1.83438468
epoch 21: w = 1.485, b = 1.338010311126709, loss = 0.34805971
epoch 31: w = 1.544, b = 1.3133221864700317, loss = 0.29210851
epoch 41: w = 1.564, b = 1.2769142389297485, loss = 0.27418274
epoch 51: w = 1.578, b = 1.239579439163208, loss = 0.25820002
epoch 61: w = 1.591, b = 1.203025460243225, loss = 0.24317083
epoch 71: w = 1.603, b = 1.1674995422363281, loss = 0.22901700
epoch 81: w = 1.615, b = 1.133014440536499, loss = 0.21568689
epoch 91: w = 1.626, b = 1.0995465517044067, loss = 0.20313281
epoch 101: w = 1.637, b = 1.067067265510559, loss = 0.19130953
epoch 111: w = 1.648, b = 1.0355472564697266, loss = 0.18017432
epoch 121: w = 1.658, b = 1.0049583911895752, loss = 0.16968721
epoch 131: w = 1.668, b = 0.9752731323242188, loss = 0.15981054
epoch 141: w = 1.678, b = 0.9464646577835083, loss = 0.15050869
epoch 151: w