# Parameter

In [1]:
# Reload edited modules (such as `utils`, imported in the next cell) before each cell runs.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [2]:
#export
# Make the sibling `scripts` directory importable: take sys.path[0] (presumably the
# directory of this notebook / exported script — verify), drop its last path component,
# append 'scripts', and put the result first on the import path.
# NOTE(review): splitting on '/' assumes POSIX-style paths.
import sys
sys.path.insert(0, '/'.join(sys.path[0].split('/')[:-1] + ['scripts']))

# NOTE(review): star import; `torch` (used below without an explicit import) is presumably provided by utils — verify.
from utils import *

In [3]:
#export
class Parameter():
    '''Minimal model parameter imitating a basic pytorch tensor with a manually managed gradient.

    Gradients are not computed here: a caller stores one with `update`, applies it with
    `step`, and resets it with `zero_grad`.
    '''
    def __init__(self, data=None, requires_grad=True):
        '''Model parameter class with tensor data and gradient to imitate a basic pytorch tensor.
            data: tensor data (with autograd turned off); when None (the default) an empty
                `torch.Tensor()` is used instead
            requires_grad: whether gradient of data is computed; stored and shown by `__repr__`,
                but not consulted by the other methods of this class
        '''
        # `is not None` instead of `!= None`: array-likes may overload `!=` element-wise
        # (e.g. numpy arrays), which makes the result ambiguous in a boolean context.
        self.data = data if data is not None else torch.Tensor()
        self.requires_grad = requires_grad
        # Scalar zero until `update` stores a real gradient, so `step` subtracts zero before then.
        self.grad = 0.

    def __get__(self, instance, owner):
        '''Descriptor hook: when a Parameter is assigned as a class attribute, reading that
        attribute (from the class or an instance) yields the raw `data` rather than the Parameter.'''
        return self.data

    def step(self, learning_rate):
        '''Apply one gradient-descent step: data <- data - learning_rate * grad.

        For tensors `-=` works in place, so the object passed to the constructor is modified too.
        '''
        self.data -= learning_rate * self.grad

    def zero_data(self):
        '''Fill `data` with zeros in place (via the tensor's `zero_`), keeping its shape.'''
        self.data.zero_()

    def zero_grad(self):
        '''Reset the stored gradient to the scalar 0.'''
        self.grad = 0.

    def update(self, grad):
        '''Store `grad` as the current gradient (replaces the previous value; does not accumulate).'''
        self.grad = grad

    def __repr__(self):
        # Note: the `grad:` field reports `requires_grad`, not the gradient value itself.
        return f'shape: {tuple(self.data.shape)}, grad: {self.requires_grad}'

# Tests

In [4]:
# Wrap a random 3x3 tensor as a trainable parameter for the checks below.
x = Parameter(data=torch.randn(3, 3), requires_grad=True)

In [5]:
print(x)
print(x.grad)
print(x.data)
# A freshly constructed Parameter requires grad, starts with a zero scalar gradient
# and keeps the shape of the tensor it wraps.
assert x.requires_grad
assert x.grad == 0.
assert tuple(x.data.shape) == (3, 3)

shape: (3, 3), grad: True
0.0
tensor([[ 0.0327, -0.7079, -0.3881],
        [-1.1837,  1.1981, -0.4327],
        [-1.2923, -0.5472, -0.1192]])


In [6]:
x.zero_data()
print(x.data)
# zero_data works in place: the shape is preserved and every entry is now zero.
assert tuple(x.data.shape) == (3, 3)
assert (x.data == 0).all()

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
