In [120]:
# 1) Design model (input size, output size, forward pass)
# 2) Construct loss and optimizer
# 3) Training loop
#       - Forward pass: compute prediction
#       - Backward pass: compute gradients
#       - Update weights

In [121]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib as plt

In [122]:
# Toy data for y = 2x: each row is one sample, the single column is the feature.
X = torch.tensor([[1], [2], [3], [4]], dtype=torch.float32)
Y = torch.tensor([[2], [4], [6], [8]], dtype=torch.float32)

# 1-D input (shape (1,)) used to query the model for f(5).
X_test = torch.tensor([5], dtype=torch.float32)

# Rows are samples, columns are features.
n_samples, n_features = X.size()

In [123]:
# model prediction

# Layer sizes come from the data: one input unit per feature column of X and
# one output unit per target column of Y. The output width must follow Y, not
# X -- the two only coincide here because both have a single column.
input_size = n_features
output_size = Y.shape[1]

class LinearRegression(nn.Module):
    """Linear regression model: a single fully connected layer, y = x @ W.T + b."""

    def __init__(self, input_dim, output_dim):
        """Create the one linear layer mapping `input_dim` features to `output_dim` outputs."""
        super().__init__()
        # The only layer of the model; holds the weight matrix and the bias.
        self.lin = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Return the layer's prediction for input `x`."""
        prediction = self.lin(x)
        return prediction

# nn.Linear initialises its weights randomly; seed PyTorch's global RNG first so
# the starting weights (and the "before training" prediction) are identical on
# every Restart & Run All.
torch.manual_seed(42)
model = LinearRegression(input_size, output_size)


In [124]:
# Baseline prediction of the untrained model. Inference only, so skip autograd.
with torch.no_grad():
    print(f'Prediction before training: f(5) = {model(X_test).item()}')

# Training
learning_rate = 0.01
it = 5000  # number of epochs; each epoch is one full-batch pass over X

loss = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)


for epoch in range(it):
    # prediction = forward pass
    y_pred = model(X)

    # loss -- nn.MSELoss is called as (input, target): prediction first, then
    # ground truth. MSE is symmetric, but asymmetric losses are not.
    l = loss(y_pred, Y)

    # gradients = backward pass
    l.backward()  # dl/dw

    # update weights
    optimizer.step()

    # zero gradient: backward() accumulates into .grad, so clear it before the
    # next iteration
    optimizer.zero_grad()

    if epoch % 500 == 0:
        # NOTE: w is read after optimizer.step(), while l was computed before it.
        w, b = model.parameters()
        print(f'epoch {epoch+1}: w = {w[0][0].item():.3f}, loss = {l.item():.8f}')

with torch.no_grad():
    print(f'Prediction after training: f(5) = {model(X_test).item():.3f}')

Prediction before training: f(5) = -1.4166738986968994
epoch 1: w = 0.265, loss = 43.62630844
epoch 501: w = 2.015, loss = 0.00032484
epoch 1001: w = 2.003, loss = 0.00001620
epoch 1501: w = 2.001, loss = 0.00000081
epoch 2001: w = 2.000, loss = 0.00000004
epoch 2501: w = 2.000, loss = 0.00000000
epoch 3001: w = 2.000, loss = 0.00000000
epoch 3501: w = 2.000, loss = 0.00000000
epoch 4001: w = 2.000, loss = 0.00000000
epoch 4501: w = 2.000, loss = 0.00000000
Prediction after training: f(5) = 10.000
