## PyTorch: Optim

Up to this point, we have updated the weights of our models by manually mutating the Tensors that hold learnable parameters (using `torch.no_grad()` or `.data` to avoid tracking history in autograd).
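For reference, here is a minimal sketch of that manual approach, next to the equivalent call into the `optim` package. It assumes a tiny `torch.nn.Linear(3, 1)` model, random data and a learning rate of `1e-2`, all chosen only for illustration. It applies one plain gradient-descent step by hand inside `torch.no_grad()`, then checks that `torch.optim.SGD` (with its default settings: no momentum, no weight decay) produces the same weights.

```python
import torch

torch.manual_seed(0)

# Tiny model and random batch, chosen only for illustration.
x = torch.randn(8, 3)
y = torch.randn(8, 1)
model = torch.nn.Linear(3, 1)
loss_fn = torch.nn.MSELoss(reduction='sum')
learning_rate = 1e-2

# Copy of the model so both update styles start from identical weights.
model_sgd = torch.nn.Linear(3, 1)
model_sgd.load_state_dict(model.state_dict())

# 1) Manual update: mutate the parameters inside torch.no_grad() so the
#    in-place update is not recorded by autograd, then clear the gradients.
loss = loss_fn(model(x), y)
loss.backward()
with torch.no_grad():
    for param in model.parameters():
        param -= learning_rate * param.grad
        param.grad.zero_()

# 2) The same step expressed with torch.optim.SGD.
optimizer = torch.optim.SGD(model_sgd.parameters(), lr=learning_rate)
optimizer.zero_grad()
loss_fn(model_sgd(x), y).backward()
optimizer.step()

# Prints True twice: once for the weight, once for the bias.
for p_manual, p_optim in zip(model.parameters(), model_sgd.parameters()):
    print(torch.allclose(p_manual, p_optim))
```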

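Manual updates like this are easy to write for plain stochastic gradient descent, but more sophisticated optimizers such as AdaGrad, RMSprop and Adam keep extra per-parameter state (for example, running averages of past gradients) that is tedious to manage by hand. The `optim` package in PyTorch abstracts the idea of an optimization algorithm and provides implementations of commonly used optimization algorithms.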
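Because every optimizer in `torch.optim` exposes the same `zero_grad()` / `step()` interface, switching algorithms only means changing the constructor. The sketch below assumes a small two-layer model, random data and a hypothetical helper `train`, none of which come from this notebook. The learning rates are illustrative, not tuned, and the final losses it prints are not meant as a ranking of the optimizers.

```python
import torch

torch.manual_seed(0)

# Random data and a sum-reduced MSE loss, chosen only for illustration.
x = torch.randn(64, 10)
y = torch.randn(64, 1)
loss_fn = torch.nn.MSELoss(reduction='sum')

def train(optimizer_cls, steps=100, **kwargs):
    """Hypothetical helper: train a fresh model with optimizer_cls, return the last loss.

    Only the optimizer constructor changes between runs; the loop body is identical.
    """
    torch.manual_seed(1)  # same initial weights for every optimizer
    model = torch.nn.Sequential(
        torch.nn.Linear(10, 32),
        torch.nn.ReLU(),
        torch.nn.Linear(32, 1),
    )
    optimizer = optimizer_cls(model.parameters(), **kwargs)
    for _ in range(steps):
        loss = loss_fn(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss.item()

results = {
    "SGD": train(torch.optim.SGD, lr=1e-4),
    "RMSprop": train(torch.optim.RMSprop, lr=1e-3),
    "Adam": train(torch.optim.Adam, lr=1e-3),
}
for name, final_loss in results.items():
    print(f"{name}: {final_loss:.3f}")
```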

```python
import torch

# N: batch size
# D_in: input dimension
# H: hidden dimension
# D_out: output dimension
N, D_in, H, D_out = 64, 1000, 100, 10
```

```python
# Create random Tensors to hold inputs and outputs.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
```

```python
# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

loss_fn = torch.nn.MSELoss(reduction='sum')
```
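`MSELoss(reduction='sum')` adds up the squared error of every output element instead of averaging it (the default is `reduction='mean'`). This minimal sketch uses random tensors shaped like this example's predictions (64 × 10), purely for illustration, and checks both reductions against the formula written out by hand.

```python
import torch

torch.manual_seed(0)
y_pred = torch.randn(64, 10)
y = torch.randn(64, 10)

sum_loss = torch.nn.MSELoss(reduction='sum')(y_pred, y)
mean_loss = torch.nn.MSELoss()(y_pred, y)  # default: reduction='mean'

# reduction='sum' adds the squared error of every element ...
manual_sum = ((y_pred - y) ** 2).sum()
# ... while 'mean' divides that total by the number of elements (64 * 10 = 640).
manual_mean = manual_sum / y.numel()

print(torch.allclose(sum_loss, manual_sum))    # True
print(torch.allclose(mean_loss, manual_mean))  # True
```

Because the sum runs over all 640 elements, the losses printed by the training loop for this model start in the hundreds rather than near 1.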

```python
# Use the optim package to define an Optimizer that will update the
# weights for us. Here we use Adam; the optim package contains many
# other optimization algorithms. The first argument to the Adam
# constructor tells the optimizer which Tensors it should update.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
```
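To make the per-parameter state concrete, here is a minimal sketch of the standard Adam update written out by hand and compared with `torch.optim.Adam`. It assumes a single 5-element parameter, a simple quadratic loss, the default `betas` and `eps`, no weight decay, and a larger learning rate (`1e-2`) so the change is visible; the names `w_ref`, `w_man`, `m` and `v` are made up for this sketch.

```python
import torch

torch.manual_seed(0)

# One parameter tensor and a fixed quadratic loss, chosen only for illustration.
w_ref = torch.randn(5, requires_grad=True)  # updated by torch.optim.Adam
w_man = w_ref.detach().clone()              # updated by hand
target = torch.randn(5)

lr, beta1, beta2, eps = 1e-2, 0.9, 0.999, 1e-8  # Adam defaults, except lr
optimizer = torch.optim.Adam([w_ref], lr=lr, betas=(beta1, beta2), eps=eps)

m = torch.zeros_like(w_man)  # running average of gradients (first moment)
v = torch.zeros_like(w_man)  # running average of squared gradients (second moment)

for t in range(1, 11):
    # Reference: let torch.optim.Adam perform the update.
    optimizer.zero_grad()
    loss = ((w_ref - target) ** 2).sum()
    loss.backward()
    optimizer.step()

    # Manual: the same update rule written out explicitly.
    g = 2 * (w_man - target)          # gradient of the same loss
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)      # bias correction for the zero initialization
    v_hat = v / (1 - beta2 ** t)
    w_man = w_man - lr * m_hat / (v_hat.sqrt() + eps)

print(torch.allclose(w_ref.detach(), w_man, atol=1e-5))  # True
```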

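```python
# Number of training iterations. Each iteration uses the whole batch x,
# so one iteration here is also one epoch.
epochs = 500
```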
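The training loop calls `optimizer.zero_grad()` before every backward pass. This minimal sketch, which uses a made-up loss `compute_loss` on a single tensor purely for illustration, shows why: `.backward()` adds into an existing `.grad` buffer rather than overwriting it.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)

def compute_loss():
    # Hypothetical loss whose gradient w.r.t. w is exactly [3.0, 3.0].
    return (3 * w).sum()

compute_loss().backward()
print(w.grad)  # tensor([3., 3.])

# A second backward() without zeroing ADDS to the existing .grad buffer.
compute_loss().backward()
print(w.grad)  # tensor([6., 6.])

# optimizer.zero_grad() clears the buffers so each step sees only the current gradient.
optimizer = torch.optim.SGD([w], lr=0.1)
optimizer.zero_grad()
print(w.grad)  # None on PyTorch 2.x (set_to_none=True by default); zeros on older versions
compute_loss().backward()
print(w.grad)  # tensor([3., 3.])
```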

```python
for t in range(epochs):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print the loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero the
    # gradients of all the parameters it will update. This is needed
    # because .backward() accumulates gradients into each parameter's
    # .grad buffer instead of overwriting them.
    optimizer.zero_grad()

    # Backward pass: compute gradients of the loss w.r.t. the model
    # parameters.
    loss.backward()

    # Calling step() on the optimizer updates the parameters it was
    # given at construction time.
    optimizer.step()
```

0 657.004150390625
1 640.4996948242188
2 624.4949340820312
3 608.9241333007812
4 593.7139892578125
5 578.9049682617188
6 564.523681640625
7 550.6159057617188
8 537.0457153320312
9 523.79833984375
10 510.9039306640625
11 498.3888854980469
12 486.2148132324219
13 474.3717956542969
14 462.813232421875
15 451.62469482421875
16 440.7097473144531
17 430.0853576660156
18 419.740478515625
19 409.67156982421875
20 399.8558654785156
21 390.2663269042969
22 380.8879699707031
23 371.7198486328125
24 362.83331298828125
25 354.1883544921875
26 345.746337890625
27 337.52569580078125
28 329.4541015625
29 321.53680419921875
30 313.8265380859375
31 306.327392578125
32 298.9854431152344
33 291.8112487792969
34 284.8535461425781
35 278.0677185058594
36 271.4591064453125
37 264.978271484375
38 258.640380859375
39 252.4136505126953
40 246.33876037597656
41 240.39572143554688
42 234.56874084472656
43 228.88214111328125
44 223.3340301513672
45 217.90415954589844
46 212.56756591796875
47 207.33726501464844
48 