Regression

In [26]:
from __future__ import print_function
from itertools import count

import torch
import torch.nn.functional as F

POLY_DEGREE = 4
# Ground-truth polynomial the model must recover. The standard-normal draws
# are scaled by 5 so the true coefficients span a wider range than N(0, 1).
W_target = 5 * torch.randn(POLY_DEGREE, 1)
b_target = 5 * torch.randn(1)

def make_features(x, degree=None):
  """Build a feature matrix with columns [x, x^2, ..., x^degree].

  Args:
    x: tensor of sample points; get_batch passes a 1-D tensor of N points.
    degree: highest power to include. Defaults to POLY_DEGREE, which
      yields columns [x, x^2, x^3, x^4].

  Returns:
    For 1-D ``x`` of length N, a tensor of shape (N, degree) whose column j
    (0-based) holds x ** (j + 1).
  """
  if degree is None:
    degree = POLY_DEGREE
  # unsqueeze(1) inserts a size-1 dimension, turning shape (N,) into (N, 1),
  # so the successive powers can be concatenated side by side along dim 1.
  x = x.unsqueeze(1)
  return torch.cat([x ** i for i in range(1, degree + 1)], 1)

def f(x):
  """Evaluate the ground-truth polynomial on a feature matrix.

  Every row of ``x`` is dotted with the column vector ``W_target`` via a
  matrix product, and the scalar value of ``b_target`` is added to each
  result.
  """
  weighted = torch.mm(x, W_target)
  bias = b_target.item()
  return weighted + bias

def poly_desc(W, b):
  """Create a human-readable string description of a polynomial.

  Args:
    W: iterable of coefficients; the i-th entry (0-based) multiplies
      x^(i+1). Plain numbers and single-element tensors are accepted.
    b: indexable whose first element is the constant term.

  Returns:
    A string such as ``'y = +1.00 x^1 -2.00 x^2 +0.50'``.
  """
  # Each term carries its own sign via the '+' format flag. Terms are
  # space-separated so consecutive coefficients don't run together.
  # float() turns 0-d tensors (including parameters that require grad)
  # into plain Python floats before formatting.
  terms = ['{:+.2f} x^{}'.format(float(w), power)
           for power, w in enumerate(W, start=1)]
  terms.append('{:+.2f}'.format(float(b[0])))
  return 'y = ' + ' '.join(terms)

def get_batch(batch_size=32):
  """Sample one training batch, i.e. an (x, f(x)) pair.

  Draws ``batch_size`` standard-normal points, expands them into polynomial
  features with make_features, and evaluates f on those features.
  """
  features = make_features(torch.randn(batch_size))
  return features, f(features)

# Define the model: one linear layer mapping the polynomial features to a
# single output, so its weight/bias can be compared with W_target/b_target.
fc = torch.nn.Linear(W_target.size(0), 1)
learning_rate = 0.1

# count(1) yields 1, 2, 3, ... so batch_idx is the 1-based batch number.
# NOTE(review): there is no iteration cap; the loop only ends once the loss
# drops below the threshold.
for batch_idx in count(1):
  # Get data
  batch_x, batch_y = get_batch()

  # Reset gradients accumulated by the previous backward pass
  fc.zero_grad()

  # Forward pass. Smooth L1 loss acts like squared (L2) error for small
  # residuals and like absolute (L1) error for large ones, which makes it
  # less sensitive to outliers than plain MSE.
  output = F.smooth_l1_loss(fc(batch_x), batch_y)
  loss = output.item()

  # Backward pass
  output.backward()

  # Apply gradients: plain SGD step, param <- param - learning_rate * grad.
  # Updating in place under no_grad() keeps the step out of the autograd
  # graph; this is the supported replacement for mutating ``param.data``.
  with torch.no_grad():
    for param in fc.parameters():
      param.add_(param.grad, alpha=-learning_rate)

  # Stop criterion
  if loss < 1e-3:
    break

print('Loss : {:.6f} after {} batches'.format(loss, batch_idx))
print('==> Learned function : \t' + poly_desc(fc.weight.view(-1), fc.bias))
print('==> Actual function : \t' + poly_desc(W_target.view(-1), b_target))

Loss : 0.000714 after 685 batched
==> Learned function : 	y = +2.85 x^1-11.08 x^2-0.44 x^3-2.95 x^4-1.23
==> Actual function : 	y = +2.89 x^1-10.98 x^2-0.47 x^3-2.99 x^4-1.24
