In [None]:
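# Typical PyTorch training pipeline:
# 1) Design the model (input size, output size, forward pass)
# 2) Construct the loss function and the optimizer
# 3) Training loop:
#    - forward pass: compute the prediction and the loss
#    - backward pass: compute the gradients
#    - update the weights, then reset the gradients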

In [1]:
import torch 
import torch.nn as nn

In [10]:
# Toy dataset for the target function f(x) = w * x with true weight w = 2 (no bias).
# Rows are samples and columns are features, the layout nn.Linear expects:
# shape (n_samples, n_features).
X = torch.tensor([[1], [2], [3], [4]], dtype=torch.float32)
Y = torch.tensor([[2], [4], [6], [8]], dtype=torch.float32)

# Held-out input x = 5, shaped (1, n_features) like X, i.e. a batch of one sample.
X_test = torch.tensor([[5]], dtype=torch.float32)

n_samples, n_features = X.shape

print(n_samples, " ", n_features)

input_size = n_features
# The output width is the number of target columns, not the number of input
# features (the two happen to coincide at 1 for this dataset).
output_size = Y.shape[1]

class LinearRegression(nn.Module):
    """Linear regression model made of a single fully connected layer.

    The forward pass computes ``x @ W.T + b`` via ``nn.Linear``.

    Args:
        input_dim: number of input features (size of the last dimension of ``x``).
        output_dim: number of output values produced per sample.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The only layer; registered as the submodule `lin`, so its parameters
        # appear in the state dict as `lin.weight` / `lin.bias`.
        self.lin = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        out = self.lin(x)
        return out

# Seed the RNG so the random initial weights -- and therefore every number
# printed below -- are reproducible across Restart & Run All.
torch.manual_seed(0)
model = LinearRegression(input_size, output_size)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    print(f'Prediction before training: f(5) = {model(X_test).item():.3f}')

# Training hyperparameters
learning_rate = 0.01
n_iters = 300
log_every = 30  # print progress every `log_every` epochs

loss = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

for epoch in range(n_iters):
    # Reset gradients first: backward() accumulates into .grad.
    optimizer.zero_grad()

    # forward pass: prediction
    y_pred = model(X)

    # loss: nn.MSELoss takes (input, target), so the prediction goes first
    l = loss(y_pred, Y)

    # backward pass: dl/dw and dl/db
    l.backward()

    # update weights
    optimizer.step()

    if epoch % log_every == 0:
        # LinearRegression has exactly one nn.Linear (with bias): weight, then bias.
        w, b = model.parameters()
        print(f'epoch {epoch+1}: w = {w[0][0].item():.3f}, loss = {l.item():.8f}')

with torch.no_grad():
    print(f'Prediction after training: f(5) = {model(X_test).item():.3f}')

4   1
Prediction before training: f(5) = 2.419
epoch 1: w = 0.653, loss = 16.45957756
epoch 31: w = 1.774, loss = 0.07174120
epoch 61: w = 1.797, loss = 0.05969217
epoch 91: w = 1.815, loss = 0.04986390
epoch 121: w = 1.831, loss = 0.04165378
epoch 151: w = 1.845, loss = 0.03479553
epoch 181: w = 1.859, loss = 0.02906646
epoch 211: w = 1.871, loss = 0.02428071
epoch 241: w = 1.882, loss = 0.02028291
epoch 271: w = 1.892, loss = 0.01694333
Prediction after training: f(5) = 9.796
