In [6]:

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.nn.functional import conv2d, max_pool2d, cross_entropy


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
torch.cuda.set_device(device)


def init_weights(shape,device =  torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')):
    # Kaiming He initialization (a good initialization is important)
    # https://arxiv.org/abs/1502.01852
    std = np.sqrt(2. / shape[0])
    w = torch.randn(size=shape, device = device) * std
    w.requires_grad = True
    print(w.is_leaf)
    return w


class RMSprop(optim.Optimizer):
    """
    This is a reduced version of the PyTorch internal RMSprop optimizer
    It serves here as an example
    """

    def __init__(self, params, lr=1e-3, alpha=0.5, eps=1e-8):
        defaults = dict(lr=lr, alpha=alpha, eps=eps)
        super(RMSprop, self).__init__(params, defaults)

    def step(self):
        for group in self.param_groups:
            for p in group['params']:
                grad = p.grad.data
                state = self.state[p]

                # state initialization
                if len(state) == 0:
                    state['square_avg'] = torch.zeros_like(p.data)

                square_avg = state['square_avg']
                alpha = group['alpha']

                # update running averages
                square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
                avg = square_avg.sqrt().add_(group['eps'])

                # gradient update
                p.data.addcdiv_(grad, avg, value=-group['lr'])



# input shape is (B, 784)
w_h = init_weights((784, 625))
# hidden layer with 625 neurons
w_h2 = init_weights((625, 625)).to(device)
# hidden layer with 625 neurons
w_o = init_weights((625, 10)).to(device)
# output shape is (B, 10)

print(w_h.is_leaf,w_h2.is_leaf,w_o.is_leaf)
optimizer = RMSprop(params=[w_h, w_h2, w_o])

Device: cuda:0
True
True
True
True True True
