In [1]:
%load_ext autoreload
%autoreload 2

In [8]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.distributions.normal import Normal
from torch.distributions.uniform import Uniform
from torch.distributions.exponential import Exponential
import jax
from jax import random
from jax import grad, jit
import jax.numpy as np # using jax.numpy instead
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS


In [9]:
N = 50; J = 2
X = random.normal(random.PRNGKey(seed = 123), (N, J))
weight = np.array([1.5, -2.8])
error = random.normal(random.PRNGKey(234), (N, )) # standard Normal
b = 10.5
y_obs = X @ weight + b + error
y = y_obs.reshape((N, 1))
X = jax.device_get(X) # convert jax array into numpy array
y = jax.device_get(y) # convert jax array into numpy array
x_data = Variable(torch.from_numpy(X), requires_grad=True)
y_data = Variable(torch.from_numpy(y), requires_grad=True)


  x_data = Variable(torch.from_numpy(X), requires_grad=True)


In [10]:
class lm_model(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        self.bias = torch.nn.Parameter(torch.randn(out_features))
    def forward(self, input):
        x, y = input.shape
        if y != self.in_features:
            print(f'Wrong Input Features. Pls use tensor with {self.in_features} Input Features')
            return 0
        output = input.matmul(self.weight.t())
        if self.bias is not None:
            output += self.bias
        return output


In [17]:
in_features = 2
out_features = 1

In [20]:
model = lm_model(in_features, out_features)
criterion = nn.MSELoss(reduction="sum")
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)
for epoch in range(8000):
    y_predict = model(x_data) 
    loss = criterion(y_predict, y_data)
    if (epoch + 1) % 2000 == 0 or epoch % 2000 == 0:
        print(epoch)
        print("Estimated weights: ", model.weight.data)
        print("Estimated bias: ", model.bias.data.item())
        print("Estimated loss: ", loss.data.item())
        print("====================")
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


0
Estimated weights:  tensor([[0.4765, 1.1778]])
Estimated bias:  0.07964716851711273
Estimated loss:  7097.109375
1999
Estimated weights:  tensor([[ 1.6097, -2.4827]])
Estimated bias:  10.717047691345215
Estimated loss:  56.24026870727539
2000
Estimated weights:  tensor([[ 1.6097, -2.4827]])
Estimated bias:  10.717047691345215
Estimated loss:  56.24026870727539
3999
Estimated weights:  tensor([[ 1.6097, -2.4827]])
Estimated bias:  10.717047691345215
Estimated loss:  56.24026870727539
4000
Estimated weights:  tensor([[ 1.6097, -2.4827]])
Estimated bias:  10.717047691345215
Estimated loss:  56.24026870727539
5999
Estimated weights:  tensor([[ 1.6097, -2.4827]])
Estimated bias:  10.717047691345215
Estimated loss:  56.24026870727539
6000
Estimated weights:  tensor([[ 1.6097, -2.4827]])
Estimated bias:  10.717047691345215
Estimated loss:  56.24026870727539
7999
Estimated weights:  tensor([[ 1.6097, -2.4827]])
Estimated bias:  10.717047691345215
Estimated loss:  56.24026870727539


In [21]:
class lm_model_lik(nn.Module):
    def __init__(self):
        super().__init__()
        ws = Uniform(-10, 10).sample((3, 1))
        self.weights = torch.nn.Parameter(ws)
        self.sigma = torch.nn.Parameter(Uniform(0, 2).sample())
    
    def forward(self, input, output):
        prior_weights = Normal(0, 10).log_prob(self.weights).sum()
        prior_sigma = Exponential(2.0).log_prob(self.sigma)
        i, _= input.shape
        y_hat = input @ self.weights[0:2] + self.weights[2:3]
        y_hat = y_hat.view(i, -1)
        lik = Normal(y_hat, self.sigma).log_prob(output).sum()
        LL = lik + prior_weights + prior_sigma
        return -LL


In [22]:
model = lm_model_lik()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(10000):
    neg_log_lik = model(x_data, y_data)
    if (epoch + 1) % 1000 == 0 or epoch % 1000 == 0:
        print(epoch)
        print("Estimated beta1: ", model.weights.data.view(3))
        print("Estimated sigma: ", model.sigma.item())
        print("Estimated neg logLik: ", neg_log_lik.item())
        print("====================")
    optimizer.zero_grad()
    neg_log_lik.backward()
    optimizer.step()

0
Estimated beta1:  tensor([ 2.3313, -4.3577, -4.0766])
Estimated sigma:  0.05082345008850098
Estimated neg logLik:  2063562.75
999
Estimated beta1:  tensor([ 1.4148, -5.2293, -1.8804])
Estimated sigma:  0.5376570820808411
Estimated neg logLik:  13217.0908203125
1000
Estimated beta1:  tensor([ 1.4146, -5.2293, -1.8786])
Estimated sigma:  0.5378520488739014
Estimated neg logLik:  13203.8466796875
1999
Estimated beta1:  tensor([ 1.3909, -5.0172, -0.1442])
Estimated sigma:  0.6939980387687683
Estimated neg logLik:  5940.3095703125
2000
Estimated beta1:  tensor([ 1.3909, -5.0168, -0.1425])
Estimated sigma:  0.6941274404525757
Estimated neg logLik:  5936.25830078125
2999
Estimated beta1:  tensor([ 1.4272, -4.6321,  1.6245])
Estimated sigma:  0.807281494140625
Estimated neg logLik:  3116.16845703125
3000
Estimated beta1:  tensor([ 1.4272, -4.6317,  1.6263])
Estimated sigma:  0.8073810338973999
Estimated neg logLik:  3114.208251953125
3999
Estimated beta1:  tensor([ 1.4657, -4.1898,  3.4995])

In [23]:
def model(X, y=None):
    ndims = np.shape(X)[-1]
    ws = numpyro.sample('betas', dist.Normal(0.0,10*np.ones(ndims)))
    b = numpyro.sample('b', dist.Normal(0.0, 10.0))
    sigma = numpyro.sample('sigma', dist.Uniform(0.0, 10.0))
    mu = X @ ws + b
    return numpyro.sample("y", dist.Normal(mu, sigma), obs = y)


In [25]:

X = jax.device_put(X) # convert numpy array into jax array
y = jax.device_put(y) # convert numpy array into jax array
nuts_kernel = NUTS(model)
num_warmup, num_samples = 500, 500
mcmc = MCMC(nuts_kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(random.PRNGKey(0), X, y = y_obs)
mcmc.get_samples()

sample: 100%|█| 1000/1000 [00:03<00:00, 288.02it/s, 3 steps of size 6.90e-01. ac


{'b': DeviceArray([10.550515 , 10.49154  , 10.707619 , 10.8568945, 10.642052 ,
              10.475721 , 10.834875 , 10.834875 , 10.826376 , 10.554865 ,
              10.760931 , 10.499135 , 10.536117 , 10.811796 , 11.0484085,
              10.837597 , 10.823776 , 10.688841 , 11.0398855, 10.24319  ,
              10.615424 , 10.746661 , 10.788264 , 10.900532 , 10.930208 ,
              10.568759 , 10.547355 , 10.909073 , 10.572882 , 10.858274 ,
              10.661693 , 10.754926 , 10.700899 , 10.704465 , 10.857199 ,
              10.691531 , 10.688172 , 10.816447 , 10.80192  , 10.816098 ,
              10.699189 , 10.6077   , 10.77718  , 10.879744 , 10.713393 ,
              10.686068 , 10.860602 , 10.852782 , 10.896113 , 10.769543 ,
              10.727321 , 10.893265 , 10.977732 , 10.444333 , 10.556506 ,
              10.756605 , 10.614553 , 10.508725 , 10.738361 , 10.754585 ,
              10.333077 , 10.347359 , 10.977198 , 10.581141 , 10.579388 ,
              10.783565 , 10.7082