<a href="https://colab.research.google.com/github/DanielhCarranza/Curso-Deep-Learning/blob/master/How_to_build_a_JAX_Framework.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# How to build a Jax Framework from Scratch


In [None]:
import gzip
import math
import pickle
import numpy as onp
import matplotlib.pyplot as plt

import jax.numpy as np

from jax import random
from jax import jacfwd, jacrev
from jax import grad, jit, vmap
from jax.scipy.special import logsumexp
from jax.nn import initializers as init


import torch
import torch.nn as nn
from torch import tensor
from torch.nn import functional as F

from collections import OrderedDict, namedtuple

from fastai import datasets

## Dataset: MNIST

Comparing Jax with Pytorch

In [None]:
MNIST_URL='http://deeplearning.net/data/mnist/mnist.pkl'
def get_data(URL:str=MNIST_URL, conv_type=tensor):
    path = datasets.download_data(MNIST_URL, ext='.gz')
    with gzip.open(path, 'rb') as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
    return map(conv_type, (x_train,y_train,x_valid,y_valid))

In [None]:
x_train, y_train, x_valid, y_valid = get_data()
print(f'X shape: {x_train.shape}, Y shape: {y_valid.shape}')

print(f'X mean: {x_train.mean()}, X std: {x_train.std()}')

Downloading http://deeplearning.net/data/mnist/mnist.pkl


X shape: torch.Size([50000, 784]), Y shape: torch.Size([10000])
X mean: 0.1304190456867218, X std: 0.30728983879089355


In [None]:
# Parameters
weights = torch.randn(784,10)
bias = torch.zeros(10)

In [None]:
m1 = x_valid[:5]
m2 = weights

In [None]:
# Pytorch
%timeit -n 10 t2 = m1 @ m2

The slowest run took 479.45 times longer than the fastest. This could mean that an intermediate result is being cached.
10 loops, best of 3: 6.81 µs per loop


### JAX 

In [None]:
x_jtrain, y_jtrain, x_jvalid, y_jvalid = get_data(conv_type=np.array)
type(x_jtrain)

jax.interpreters.xla.DeviceArray

In [None]:
key = random.PRNGKey(0)
weights = random.normal(key, (784,10), dtype=np.float32)

In [None]:
m1 = x_jvalid[:5]
m2 = weights
m1.shape, m2.shape

((5, 784), (784, 10))

In [None]:
%timeit -n 10 t2 = m1 @ m2

The slowest run took 96.64 times longer than the fastest. This could mean that an intermediate result is being cached.
10 loops, best of 3: 275 µs per loop


In [None]:
%timeit -n 10 t2 = np.matmul(m1, m2)

10 loops, best of 3: 290 µs per loop


### FUNCTION computation comparison


In [None]:
@jit
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x)

The slowest run took 848.18 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 172 µs per loop


In [None]:
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * torch.where(x > 0, x, alpha * torch.exp(x) - alpha)

x = torch.randn(1000000,)
%timeit selu(x)

The slowest run took 5.97 times longer than the fastest. This could mean that an intermediate result is being cached.
100 loops, best of 3: 8.44 ms per loop


## NN

In [None]:
@jit
def jnormalize(x, m, s):  return (x-m)/s

def normalize(x, m, s): return (x-m)/s


In [None]:
# PYTORCH
train_mean, train_std = x_train.mean(), x_train.std()
%timeit -n 10  normalize(x_train, train_mean, train_std)

10 loops, best of 3: 73.1 ms per loop


In [None]:
# JAX
train_mean, train_std = x_jtrain.mean(), x_jtrain.std()
%timeit -n 10  jnormalize(x_jtrain,train_mean, train_std)

The slowest run took 91.64 times longer than the fastest. This could mean that an intermediate result is being cached.
10 loops, best of 3: 115 µs per loop


In [None]:
x_jtrain = jnormalize(x_jtrain, train_mean, train_std)
x_jvalid = jnormalize(x_jvalid, train_mean, train_std)
print(f'MEAN: {train_mean},STD: {train_std}')
print(f'MEAN: {x_jtrain.mean()},STD: {x_jtrain.std()}')


MEAN: 0.13044890761375427,STD: 0.3072897791862488
MEAN: 2.6451584744791035e-06,STD: 0.9999991655349731


### Basic Architecture JAX

In [None]:
nh = 50 # hidden layers
n, m = x_jtrain.shape
c = int(y_jtrain.max()+1)
n,m,c

(50000, 784, 10)

In [None]:
w1 = random.normal(key, (m, nh), dtype=np.float32)/np.sqrt(m)
b1 = np.zeros(nh)
w2 = random.normal(key, (nh,1), dtype=np.float32)/np.sqrt(nh)
b2 = np.zeros(1)

In [None]:
def linear(x, w, b): return x @ w + b

In [None]:
%timeit -n 10 linear(x_jtrain, w1, b1)

The slowest run took 63.46 times longer than the fastest. This could mean that an intermediate result is being cached.
10 loops, best of 3: 490 µs per loop


In [None]:
t = linear(x_jtrain, w1, b1)
t.mean(), t.std()

(DeviceArray(0.16632, dtype=float32), DeviceArray(0.987166, dtype=float32))

In [None]:
@jit
def relu(x): return np.maximum(0, x) -0.5

In [None]:
%timeit -n 10 t = relu(linear(x_jvalid, w1, b1))

10 loops, best of 3: 758 µs per loop


In [None]:
t.mean(), t.std()

(DeviceArray(0.166319, dtype=float32), DeviceArray(0.987166, dtype=float32))

### Initialization

In [None]:
from jax.nn import initializers as init

In [None]:
weights= init.kaiming_normal()(key, (m,nh))
bias = np.zeros(nh)
weights.mean(), weights.std()

(DeviceArray(-0.0003, dtype=float32), DeviceArray(0.050342, dtype=float32))

In [None]:
t = relu(linear(x_jtrain, weights, bias))
t.mean(), t.std()

(DeviceArray(0.169347, dtype=float32), DeviceArray(0.894278, dtype=float32))

In [None]:
@jit
def model(xb):
    l1 = linear(xb, w1, b1)
    l2 = relu(l1)
    l3 = linear(l2, w2, b2)
    return l3

In [None]:
%timeit -n 10 _=model(x_jvalid)

The slowest run took 14.83 times longer than the fastest. This could mean that an intermediate result is being cached.
10 loops, best of 3: 250 µs per loop


### Loss Function: MSE

In [None]:
model(x_jvalid).shape

(10000, 1)

In [None]:
@jit
def mse(output, target): return ((output.squeeze(axis=-1) - target)**2).mean()

In [None]:
preds = model(x_jtrain)
preds.shape

(50000, 1)

In [None]:
%time mse(preds, y_jtrain)

CPU times: user 1.14 ms, sys: 0 ns, total: 1.14 ms
Wall time: 741 µs


DeviceArray(28.006317, dtype=float32)

In [None]:
grad(mse)(preds, y_jtrain)

DeviceArray([[-1.855972e-04],
             [ 2.692820e-06],
             [-1.226115e-04],
             [-1.319953e-05],
             ...,
             [ 1.071363e-05],
             [-3.087608e-04],
             [-1.495328e-04],
             [-3.050688e-04]], dtype=float32)

### Gradients and Backward pass

\begin{align} 
& \bullet \text{Initialize } W^{[1]} .. W^{[L]}, b^{[1]} … b^{[L]} \\ 
& \bullet \text{Set } A^{[0]} = X \text{ ( Input ) }, L = \text{Total Layers}  \\ 
& \bullet \text{Loop } \text{epoch} = 1 \text{ to } \text{ max iteration } \\ 
& \rule{1cm}{0pt} \bullet \text{Forward Propagation} \\ 
& \rule{2cm}{0pt} \bullet \text{Loop } l=1 \text{ to } L-1 \\ 
& \rule{3cm}{0pt} \bullet Z^{[l]} = W^{[l]}A^{[l-1]}+b^{[l]} \\ 
& \rule{3cm}{0pt} \bullet A^{[l]} = g \left (  b^{[l]} \right )  \\ 
& \rule{3cm}{0pt} \bullet \text{Save }  A^{[l]},W^{[l]} \text{ in memory for later use } \\ 
& \rule{2cm}{0pt} \bullet Z^{[L]} = W^{[L]}A^{[L-1]}+b^{[L]} \\ 
& \rule{2cm}{0pt} \bullet A^{[L]} = \sigma \left (  Z^{[L]} \right )  \\ 
& \rule{1cm}{0pt} \bullet \text{Cost } J= – \frac{1}{n} \bigg(  Ylog \left ( A^{[2]} \right ) – \left ( 1-Y  \right )log \left  ( 1 – A^{[2]} \right ) \bigg)\\ 
& \rule{1cm}{0pt} \bullet \text{Backward Propagation} \\ 
& \rule{2cm}{0pt} \bullet dA^{[L]} = -\frac{Y}{A^{[L]}}  + \frac{1-Y}{1- A^{[L]}} \\ 
& \rule{2cm}{0pt} \bullet dZ^{[L]} = dA^{[L]} \sigma’\left ( dA^{[L]} \right )  \\ 
& \rule{2cm}{0pt} \bullet dW^{[L]} = dZ^{[L]} dA^{[L-1]} \\ 
& \rule{2cm}{0pt} \bullet db^{[L]} = dZ^{[L]} \\ 
& \rule{2cm}{0pt} \bullet dA^{[L-1]} = dZ^{[L]} W^{[L]} \\ 
& \rule{2cm}{0pt} \bullet \text{Loop } l=L-1 \text{ to } 1 \\ 
& \rule{3cm}{0pt} \bullet dZ^{[l]} = dA^{[l]} g’\left ( dA^{[l]} \right )  \\ 
& \rule{3cm}{0pt} \bullet dW^{[l]} = dZ^{[l]} dA^{[l-1]} \\ 
& \rule{3cm}{0pt} \bullet db^{[l]} = dZ^{[l]} \\ 
& \rule{3cm}{0pt} \bullet dA^{[l-1]} = dZ^{[l]} W^{[l]} \\ 
& \rule{1cm}{0pt} \bullet \text{Update W and b}  \\ 
& \rule{2cm}{0pt} \bullet \text{Loop } l=1 \text{ to } L \\ 
& \rule{3cm}{0pt} \bullet W^{[l]} =W^{[l]} -\alpha . dW^{[l]} \\ 
& \rule{3cm}{0pt} \bullet b^{[l]} =b^{[l]} -\alpha . db^{[l]} 
\end{align}

In [None]:
def mse_grad(inp, target):
  inp.g = 2 * (inp.squeeze() - target)[...,None]/inp.shape[0]

In [None]:
def relu_grad(inp, out):
  inp.g = float(inp>0)*out.g 

$$
d A^{[l-1]}=\frac{\partial \mathscr{L}}{\partial A^{[l-1]}}=W^{[l] T} d Z^{[l]}
$$

$$
d Z^{[l]}=d A^{[l] *} g^{\prime}\left(Z^{[l]}\right)
$$


In [None]:
def linear_grad(inp, out, w, b):
  #grad of matmul with respect to input
  inp.g = out.g @ w.T()
  w.g = (inp[...,None] * out.g[:,None,:]).sum(axis=0)
  b.g = out.g.sum(axis=0, keepdims=True)

In [None]:
def forward_and_backward(inp, target):
  # Forward
  l1 = linear(inp,w1,b1)
  l2 = relu(l1)
  out = linear(l2, w2,b2)
  mse(out, target)

  # Backward
  # mse_grad(out, target)
  linear_grad(l2, out, w2, b2)
  relu_grad(l1, l2)
  linear_grad(inp, l1, w1,b1)



In [None]:
forward_and_backward(x_jtrain,y_jtrain)

In [None]:
a = x_jtrain[:5]
a.shape

(5, 784)

In [None]:
def forward(inp, target):
  # Forward
  l1 = linear(inp,w1,b1)
  l2 = relu(l1)
  out = linear(l2, w2,b2)
  return mse(out, target)

In [None]:
loss= forward(x_jtrain, y_jtrain)

In [None]:
class Tensorjax(np.DeviceArray):
  def __init__(self, data):
    self.data = np.DeviceArray(data)
  def __repr__(self): print(f'{data.shape}, {data} ')
  def gradient(self, x):
    pass

In [None]:
# Tensorjax.gradient = 2

## JAX Framework

In [None]:
class Module():
  def __init__(self): self._modules={}
  def __call__(self, *args): 
    self.args = args
    self.out = self.forward(*args)
    return self.out
  def forward(self): raise Exception("Not implemented method")
  
  def __setattr__(self, k, v):
    if not k.startswith("_"): self._modules[k]=v
    super().__setattr__(k,v)
 
  def __repr__(self): return f'{self._modules}'
 
  def parameters(self):
    for l in self._modules.values():
      for p in l.parameters: yield p

      

In [None]:
# class Module():
#   def __call__(self, *args): 
#     self.args = args
#     self.out = self.forward(*args)
#     return self.out
#   def forward(self): raise Exception("Not implemented method")

class ReLU(Module):
  def forward(self, x):
    return np.maximum(0,x) - 0.5

class Sigmoid(Module): 
  def forward(self, x):
    return 0.5*(np.tanh(x/2) + 1)

class Linear(Module):
  def __init__(self, W, b):
    super().__init__() 
    self.W, self.b = W, b
  def forward(self, inputs):
    return (inputs @ self.W + self.b)


class LinearV2(Module):
  def forward(self, W, b, inputs):
    return (inputs @ W + b)

class Loss(Module):
  def __init__(self, inputs, targets):
    super().__init__()
    self.x, self.y = inputs, targets

  def forward(self, W, b):
    preds = Sigmoid()(LinearV2()(W, b, self.x))
    label_probs = preds*self.y + (1-preds)*(1-self.y)
    return -np.sum(np.log(label_probs))

def cross_entropy(preds, targets):
    label_probs = preds*targets + (1-preds)*(1-targets)
    return -np.sum(np.log(label_probs)) 

class LossV2(Module):
  def __init__(self, model, inp):
    self.model, self.inp = model, inp
  def forward(self, params, preds, targets):
    preds = self.model(params, self.inp)
    return cross_entropy(preds, targets)

In [None]:
# Build a toy dataset.
inputs = np.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = np.array([True, True, False, True])


# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())

grad(Loss(inputs, targets), argnums=(0,1))(W, b)

(DeviceArray([-0.169656, -0.877464, -1.490134], dtype=float32),
 DeviceArray(-0.292272, dtype=float32))

In [None]:
preds = Sigmoid()(Linear(W,b)(inputs))
print(preds, preds.shape == targets.shape)

preds = Sigmoid()(LinearV2()(W,b, inputs))
print(preds, preds.shape == targets.shape)
# grad(LossV2(preds))

[0.758798 0.761036 0.038418 0.994809] True
[0.758798 0.761036 0.038418 0.994809] True


In [None]:
pred = Linear(W,b)
pred

{'W': DeviceArray([-0.09006 , -0.718414,  2.477131], dtype=float32), 'b': DeviceArray(0.090163, dtype=float32)}

In [None]:
class Linear(Module):
  def __init__(self, params): self.params = params
  def forward(self, inputs):
    return (inputs @ self.params['W'] + self.params['b'])

In [None]:
def compute_loss(params, x, targets):
  layers = [Linear(params), Sigmoid()]
  for l in layers: x = l(x) 
  return cross_entropy(x, targets)

In [None]:
grad(compute_loss)({'W':W, 'b':b}, inputs, targets)

{'W': DeviceArray([-1.72562 ,  2.016514, -2.178545], dtype=float32),
 'b': DeviceArray(-2.381826, dtype=float32)}

In [None]:
nh = 50 # hidden layers
n, m = x_jtrain.shape
c = int(y_jtrain.max()+1)
n,m,c

(50000, 784, 10)

In [None]:
def init_params(n_in, n_out, init=init.kaiming_normal(), key=random.PRNGKey(0)):
  W_key, b_key = random.split(key, 2)
  weights= init(W_key, (n_in,n_out))
  bias = np.zeros(n_out)
  return weights, bias 

In [None]:
w1, b1 = init_params(m, nh)
w2, b2 = init_params(nh, 1)

In [None]:
t = ReLU()(Linear(w1, b1)(x_jtrain))
t.mean(), t.std()

(DeviceArray(-0.352507, dtype=float32), DeviceArray(0.236954, dtype=float32))

#### Jax models

In [None]:
def log_softmax(logits):
  return logits - logsumexp(logits)[..., None]

def cross_entropy(preds, targets):
  return -np.sum(log_softmax(preds) * targets)


In [None]:
def one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), dtype)

def model(x,y):
  layers = [Linear(w1,b1), ReLU(), Linear(w2,b2)]
  for l in layers: x = l(x)
  return x 

In [None]:
y = one_hot(y_jtrain, c)

In [None]:
%timeit -n 10 _ = model(x_jtrain, y)

10 loops, best of 3: 1.99 ms per loop


In [None]:
def compute_loss( x,y):
  preds= model(x,y)
  return cross_entropy(preds, y)

%time compute_loss(x_jtrain, y)

CPU times: user 5.03 ms, sys: 3.46 ms, total: 8.5 ms
Wall time: 10.7 ms


DeviceArray(543525.4, dtype=float32)

### JAX Module

In [None]:
class Module():
  def __init__(self):
    self._modules = OrderedDict()
    self._parameters = OrderedDict()

  def __call__(self, *args): 
    self.args = args
    self.out = self.forward(*args)
    return self.out

  def forward(self): raise Exception("Not implemented method")
  
  def __setattr__(self, k, v):
    if not k.startswith("_"): self._modules[k]=v
    super().__setattr__(k,v)

  def __getattr__(self, name):
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                return _parameters[name]

        if '_modules' in self.__dict__:
            modules = self.__dict__['_modules']
            if name in modules:
                return modules[name]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, name))
 
  def __repr__(self): return f'{self._modules}'

  def register_parameter(self, name, param): 
    if '_parameters' not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call")
    self._parameters[name] = param

  def parameters(self):
    for l in self._modules.values():
      for p in l.parameters: yield p

In [None]:
class Model(Module):
  def __init__(self, n_in, nh, n_out):
    super().__init__()
    self.l1 = nn.Linear(n_in, nh)
    self.l2 = nn.Linear(nh, n_out)
  
  def forward(self, x):
    return self.l2(nn.ReLU((self.l2(x))))

In [None]:
nh = 50 # hidden layers
n, m = x_jtrain.shape
c = int(y_jtrain.max()+1)
n,m,c

(50000, 784, 10)

In [None]:
mo = Model(m, nh, c)
mo

OrderedDict([('l1', Linear(in_features=784, out_features=50, bias=True)), ('l2', Linear(in_features=50, out_features=10, bias=True))])

In [None]:
mo.parameters

<bound method Module.parameters of OrderedDict([('l1', Linear(in_features=784, out_features=50, bias=True)), ('l2', Linear(in_features=50, out_features=10, bias=True))])>

In [None]:
class Model(nn.Module):
  def __init__(self, n_in, nh, n_out):
    super().__init__()
    self.l1 = nn.Linear(n_in, nh)
    self.l2 = nn.Linear(nh, n_out)
  
  def forward(self, x):
    return self.l2(nn.ReLU((self.l2(x))))

In [None]:
mo = Model(m, nh, c)
mo

Model(
  (l1): Linear(in_features=784, out_features=50, bias=True)
  (l2): Linear(in_features=50, out_features=10, bias=True)
)

In [None]:
mo.parameters

<bound method Module.parameters of Model(
  (l1): Linear(in_features=784, out_features=50, bias=True)
  (l2): Linear(in_features=50, out_features=10, bias=True)
)>

#### JAX PARAMS

In [None]:
def init_params(n_in, n_out, init=init.kaiming_normal(), key=random.PRNGKey(0)):
  W_key, b_key = random.split(key, 2)
  weights= init(W_key, (n_in,n_out))
  bias = np.zeros(n_out)
  return weights, bias

In [None]:
nn.Parameter??

In [None]:
from jax.interpreters.xla import  DeviceArray


In [None]:
class Parameter(onp.ndarray, metaclass=_ArrayMeta):
  def __new__(cls, data=None, *args):
    # data=np.asarray(data).view(cls)
    return super(Parameter, cls).__new__(cls,data, *args)

  def __repr__(self):
    return 'Parameter containing:\n' + super(Parameter, self).__repr__()

In [None]:
type(Parameter(w1)), Parameter(w1)

In [None]:
class Parameter(onp.ndarray):
    def __new__(cls, x, *args, **kwargs): 
        return super().__new__(cls, x, *args, **kwargs) 

In [None]:
p = Parameter((3,))
isinstance(p, Parameter)

True

In [None]:
from jax.lax import convert_element_type
convert_element_type?

In [None]:
a = onp.array([1.0,2.0,3.0])
convert_element_type(a, new_dtype=np.float32)
# np.array

array([1., 2., 3.])

In [None]:
class InfoArray(onp.ndarray):

    def __new__(subtype, shape, dtype=float, buffer=None, offset=0,
                strides=None, order=None, info=None):
        # Create the ndarray instance of our type, given the usual
        # ndarray input arguments.  This will call the standard
        # ndarray constructor, but return an object of our type.
        # It also triggers a call to InfoArray.__array_finalize__
        return super(InfoArray, subtype).__new__(subtype, shape, dtype,
                                                buffer, offset, strides,
                                                order)
         

In [None]:
obj = InfoArray(shape=(3,), )

In [None]:
w1, b1 = init_params(m, nh)
w2, b2 = init_params(nh, 1)

In [None]:
type(w1)

jax.interpreters.xla.DeviceArray

### Linear

In [None]:
class Model(Module):
  def __init__(self, params):
    self.layers = [Linear(w1,b1), ReLU(), Linear(w2,b2)]
    self.loss = cross_entropy

  def forward(self, x, y):
    for l in self.layers: x = l(x)
    if hasattr(l, 'weight'):
      
    return self.loss(x, y)

In [None]:
def linear(inp, weight, bias=None):
  if inp.size()==2 and bias is not None:
    ret = inp @ weight + bias
  else:
    output = inp @ weight.T
    if bias is not None: output += bias
    ret = output
  return ret

class Linear(Module):
  def __init__(self, in_features, out_features, bias=True): 
    super().__init__()
    self.weight, self.bias =  init_params(in_features, out_features)
  def forward(self, inputs):
    return linear(inputs, self.weight, self.bias )

In [None]:
l1 = Linear(m,nh)
l1

{'weight': DeviceArray([[-0.049711, -0.035382, -0.054953, -0.001373, ..., -0.021679, -0.109848, -0.020343, -0.039765],
             [ 0.00608 ,  0.002496,  0.020257, -0.107654, ...,  0.030652,  0.010056,  0.00069 ,  0.016928],
             [ 0.08513 ,  0.05751 , -0.005576, -0.059141, ..., -0.057368, -0.066713,  0.042022,  0.050596],
             [ 0.035249, -0.006973, -0.065923, -0.094957, ...,  0.066852, -0.022646,  0.082308, -0.028639],
             ...,
             [ 0.03995 ,  0.041362, -0.084902, -0.015864, ..., -0.00484 ,  0.048652,  0.010824,  0.015427],
             [-0.012211, -0.05021 , -0.039358,  0.058195, ..., -0.039469,  0.074852,  0.073005,  0.024522],
             [ 0.034812,  0.076179, -0.075141, -0.097438, ..., -0.09212 , -0.078528,  0.021074, -0.090035],
             [ 0.027886, -0.005652, -0.025512, -0.017421, ...,  0.006612,  0.004661, -0.097214, -0.013013]],            dtype=float32), 'bias': DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

In [None]:
l2 = nn.Linear(m,nh)
l2.parameters

<bound method Module.parameters of Linear(in_features=784, out_features=50, bias=True)>