# PyTorch basics

In [1]:
import torch
from torch.autograd import Variable

#### PyTorch has Tensors, too

In [2]:
dtype = torch.FloatTensor  # CPU float32; switch to torch.cuda.FloatTensor to run on a GPU

# A 3x4 matrix of standard-normal samples, cast to the chosen tensor type.
X = torch.randn((3, 4)).type(dtype)
X


-1.4722e-01 -8.1976e-01  1.0090e+00 -5.6232e-01
-5.9177e-05  2.7463e-01  5.0102e-01  4.7825e-01
 9.0050e-01 -2.3513e-01 -5.9586e-01  9.0564e-03
[torch.FloatTensor of size 3x4]

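#### Tensors are wrapped in Variables, which can also store gradients (since PyTorch 0.4 `Variable` is deprecated: plain tensors created with `requires_grad=True` track gradients themselves)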

In [3]:
X_Var = Variable(X)  # wrap the tensor so autograd can record operations on it

### Porting the numpy network to PyTorch

In [4]:
n = 64
num_features = 1000
hidden_dim = 100
output_dim = 10

learning_rate = 1e-6
num_epochs = 500

In [5]:
# Tensor type used by the network below. Any dtype chosen earlier in the notebook
# is overridden here, so flip `use_gpu` in this cell to run the network on CUDA.
use_gpu = False
dtype = torch.cuda.FloatTensor if use_gpu else torch.FloatTensor

In [6]:
# Inputs X and targets y are constants for autograd: nothing needs a gradient
# with respect to them, hence requires_grad=False. (Tuple elements are evaluated
# left to right, so X is still drawn from the RNG before y.)
X, y = (
    Variable(torch.randn(n, num_features).type(dtype), requires_grad=False),
    Variable(torch.randn(n, output_dim).type(dtype), requires_grad=False),
)

In [7]:
# Peek at the raw tensor behind the X Variable; PyTorch elides the middle of
# large tensors ("...") when displaying them.
X.data


 1.2686e-01 -3.9124e-01 -2.4620e-01  ...  -5.0728e-01  1.5596e+00 -1.0778e+00
-2.5183e-01 -1.4531e+00 -1.3367e+00  ...  -1.3785e+00 -7.9791e-01  1.2874e+00
-3.4469e-01  5.2509e-02  1.1070e+00  ...   5.4420e-01  7.3204e-01 -9.4370e-01
                ...                   ⋱                   ...                
 1.5243e-01 -9.3061e-02 -2.3743e+00  ...  -3.1264e-01 -7.1904e-01  5.1194e-02
-1.9438e-02 -1.0221e+00  3.3764e-02  ...  -5.2677e-01  8.6042e-01 -3.9284e-01
 9.5267e-01  1.0459e+00  9.6335e-01  ...   3.6626e-01 -3.8677e-01 -2.4536e+00
[torch.FloatTensor of size 64x1000]

In [8]:
# The weights are what we learn, so autograd must track them (requires_grad=True)
# to produce the gradient of the loss with respect to each on backward().
# Tuple elements are evaluated left to right, so W1 is drawn before W2.
W1, W2 = (
    Variable(torch.randn(num_features, hidden_dim).type(dtype), requires_grad=True),
    Variable(torch.randn(hidden_dim, output_dim).type(dtype), requires_grad=True),
)

In [9]:
for epoch in range(num_epochs):
    # Forward pass: linear -> ReLU (clamp at 0) -> linear.
    # No references to intermediate values are kept, because autograd
    # (not hand-written code) implements the backward pass.
    y_pred = X.mm(W1).clamp(min=0).mm(W2)

    # Sum-of-squares loss. It is a 0-dim tensor in PyTorch >= 0.4, so read the
    # Python number with .item(); indexing it (loss.data[0]) raises IndexError.
    loss = (y_pred - y).pow(2).sum()
    print(epoch, loss.item())

    # Backward pass: autograd fills W1.grad and W2.grad with the gradient of
    # the loss with respect to W1 and W2 respectively.
    loss.backward()

    # Gradient-descent step. The update must not itself be recorded by autograd,
    # hence torch.no_grad(). Gradients accumulate across backward() calls, so
    # they are zeroed manually after each step.
    with torch.no_grad():
        W1 -= learning_rate * W1.grad
        W2 -= learning_rate * W2.grad
        W1.grad.zero_()
        W2.grad.zero_()


0 30955566.0
1 28220284.0
2 27565712.0
3 25139144.0
4 19944146.0
5 13496688.0
6 8146869.0
7 4691667.0
8 2782156.25
9 1778751.875
10 1242939.75
11 936592.25
12 744792.8125
13 613280.9375
14 516364.71875
15 441003.625
16 380298.3125
17 330281.75
18 288491.3125
19 253161.125
20 223079.953125
21 197273.65625
22 175035.0
23 155773.5625
24 139020.171875
25 124395.65625
26 111590.6796875
27 100345.65625
28 90434.1640625
29 81677.4765625
30 73914.3125
31 67012.2734375
32 60862.71875
33 55374.71875
34 50461.24609375
35 46052.8359375
36 42091.40625
37 38524.94921875
38 35307.55859375
39 32399.353515625
40 29767.408203125
41 27381.5078125
42 25217.484375
43 23251.990234375
44 21462.47265625
45 19830.716796875
46 18341.05078125
47 16981.291015625
48 15737.8408203125
49 14598.416015625
50 13553.9521484375
51 12598.0458984375
52 11719.1259765625
53 10909.87890625
54 10163.8916015625
55 9475.7265625
56 8840.0810546875
57 8252.7158203125
58 7710.03759765625
59 7207.73583984375
60 6741.912109375
61 630

386 0.002698805183172226
387 0.002600964391604066
388 0.002512119011953473
389 0.0024262038059532642
390 0.00233875191770494
391 0.002259625354781747
392 0.0021830322220921516
393 0.002109852386638522
394 0.002036720747128129
395 0.0019690992776304483
396 0.0019022980704903603
397 0.0018378103850409389
398 0.001779127516783774
399 0.001720631611533463
400 0.0016638955567032099
401 0.001609415514394641
402 0.0015556184807792306
403 0.001505737192928791
404 0.0014569256454706192
405 0.001410250086337328
406 0.00136574637144804
407 0.0013204481219872832
408 0.0012795031070709229
409 0.0012397506507113576
410 0.001201357808895409
411 0.0011645667254924774
412 0.0011273581767454743
413 0.0010932261357083917
414 0.001060563139617443
415 0.001028475002385676
416 0.0009991417173296213
417 0.0009690485894680023
418 0.0009396715904586017
419 0.0009118677699007094
420 0.0008848932920955122
421 0.0008597898995503783
422 0.0008356799371540546
423 0.0008109200862236321
424 0.0007883226498961449
425 