In [1]:
import argparse
import os
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from jax.numpy.fft import irfft, rfft, fft, ifft

from jax import grad, jit, vmap
import jax.numpy as jnp
import jax.random as random
import jax

!pip install numpyro
import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import numpyro.contrib.module as module

!pip install -q flax
from flax import linen as nn

import torch
from torchvision import transforms, datasets



matplotlib.use("Agg")  # noqa: E402
np.random.seed(0)
key = numpyro.prng_key()



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print(jax.devices(backend='gpu'))

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]


In [4]:
def numpy_collate(batch):
  if isinstance(batch[0], np.ndarray):
    return np.stack(batch)
  elif isinstance(batch[0], (tuple,list)):
    transposed = zip(*batch)
    return [numpy_collate(samples) for samples in transposed]
  else:
    return np.array(batch)


class NumpyLoader(torch.utils.data.DataLoader):
  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    super(self.__class__, self).__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

class FlattenAndCast(object):
  def __call__(self, pic):
    return np.ravel(np.array(pic, dtype=jnp.float32))

In [5]:
def get_mnist(n, m):
    """
    Download MNIST and return train and evaluation sets.
    """
    mnist = datasets.MNIST('data',
        train=True,
        download=True,
        transform=FlattenAndCast())
    mnist = list(mnist)
    # One batch with all of mnist
    train_loader = NumpyLoader(mnist, batch_size=len(mnist), num_workers=0)
    x, y = list(train_loader)[0]
    # Normalize
    x = (x - x.mean()) / x.std()
    # Train and test set
    train_x, train_y = x[0:n], y[0:n]
    val_x, val_y = x[n:n+m], y[n:n+m]
    return train_x, train_y, val_x, val_y

In [6]:
class flax_CNN(nn.Module):
    @nn.compact   
    def __call__(self, x):
        #print(x.shape)
        x_length = x.shape[0]
        x = nn.Conv(features = 8, kernel_size = (5,5),strides = (1,1), 
                   padding = (2,2), 
                   use_bias= False)(x)
        #print(x.shape[0])
        x = nn.max_pool(x,window_shape=(2,2),strides = (2,2))
        all_len = len(x.flatten())
        x = x.reshape((x_length, int((all_len/x_length)) ))

        return x


In [7]:
_x, _y, _xv, _yv = get_mnist(50000,10000)

# Training
N=1000 
# Test
M=300

# Get the training and test data from the MNIST global variables
x, y, xv, yv = _x[0:N], _y[0:N], _xv[N:N+M], _yv[N:N+M]


print(x.shape)
print(y.shape)
print(jnp.ravel(x.reshape((-1,1,28,28))).shape)

(1000, 784)
(1000,)
(784000,)


In [8]:
dim1 = 800
dim2 = 128
# dim3 = 128
# dim4 = 32
key = random.PRNGKey(0)
d1 = dist.Bernoulli(jnp.array(0.5)).expand([2*28*28]).sample(key)
d2 = dist.Bernoulli(jnp.array(0.5)).expand([dim1]).sample(key)
d1 = 2*d1 - 1
d2 = 2*d2 - 1
conv = flax_CNN() #from (200,28,28) to (200, 1568)

rng_key = random.PRNGKey(1)

print(conv.tabulate(jax.random.PRNGKey(0), x.reshape((-1,28,28,1))))

# with numpyro.handlers.seed(rng_seed=0):
#     numpyro_conv = module.random_flax_module("conv", conv, dist.Normal(0,1))


[3m                                flax_CNN Summary                                [0m
┏━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath  [0m[1m [0m┃[1m [0m[1mmodule  [0m[1m [0m┃[1m [0m[1minputs           [0m[1m [0m┃[1m [0m[1moutputs          [0m[1m [0m┃[1m [0m[1mparams          [0m[1m [0m┃
┡━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩
│        │ flax_CNN │ [2mfloat32[0m[1000,28,… │ [2mfloat32[0m[1000,156… │                  │
├────────┼──────────┼───────────────────┼───────────────────┼──────────────────┤
│ Conv_0 │ Conv     │ [2mfloat32[0m[1000,28,… │ [2mfloat32[0m[1000,28,… │ kernel:          │
│        │          │                   │                   │ [2mfloat32[0m[5,5,1,8] │
│        │          │                   │                   │                  │
│        │          │                   │                   │ [1m200 [0m[1;2m(800 B)[0m      │
├──

In [9]:
def circ_matmul(x, w):
    # xw = fft(fft(w)*ifft(x)).real
    # Note the use of the n argument to get right output shape
    xw = irfft(jnp.conj(rfft(w)) * rfft(x), n=w.shape[0])
    return xw

In [10]:
# the non-linearity we use in our neural network
def nonlin(x):
    return jnp.tanh(x)

In [11]:
def model(x, y=None, dim1 = dim1, dim2 = dim2):
    w1 = numpyro.sample("w1", dist.Normal(0,1).expand([2*28*28]).to_event(1))
    b1 = numpyro.sample("b1", dist.Normal(0,1).expand([dim1]).to_event(1))

    w2 = numpyro.sample("w2", dist.Normal(0,1).expand([dim1]).to_event(1))
    b2 = numpyro.sample("b2", dist.Normal(0,1).expand([dim2]).to_event(1))

#     w3 = numpyro.sample("w3", dist.Normal(0,1).expand([dim2]).to_event(1))
#     b3 = numpyro.sample("b3", dist.Normal(0,1).expand([dim3]).to_event(1))

#     w4 = numpyro.sample("w4", dist.Normal(0,1).expand([dim3]).to_event(1))
#     b4 = numpyro.sample("b4", dist.Normal(0,1).expand([dim4]).to_event(1))        

    # w5 = numpyro.sample("w5", dist.Normal(0,1).expand([dim4,10]).to_event(2))
    # b5 = numpyro.sample("b5", dist.Normal(0,1).expand([10]).to_event(1))
    
    
    w5 = numpyro.sample("w5", dist.Normal(0,1).expand([dim2,10]).to_event(2))
    b5 = numpyro.sample("b5", dist.Normal(0,1).expand([10]).to_event(1))    

    # Convolution
    #variables = conv.init(random.PRNGKey(3),x.reshape((-1,1,28,28)))
    #conv_model = conv.apply(variables,x.reshape((-1,1,28,28)))
    #conv = flax_CNN()
    conv_numpyro = module.random_flax_module("conv", conv, dist.Normal(0, 1), input_shape=((x.shape[0],28,28,1)))
    cx = nonlin(conv_numpyro(x.reshape((-1,28,28,1))))
    #print(conv_numpyro(x.reshape((-1,28,28,1))))

    # Layer 1: dim1
    h1 = circ_matmul(cx, w1)
    h1 = nonlin(h1[:, 0:dim1] + b1)
    # Layer 2: dim2
    h2 = circ_matmul(h1, w2)
    h2 = nonlin(h2[:, 0:dim2] + b2)


#     h3 = circ_matmul(h2, w3)
#     h3 = nonlin(h3[:, 0:dim3] + b3)

#     h4 = circ_matmul(h3, w4)
#     h4 = nonlin(h4[:, 0:dim4] + b4)
    

  
    # Layer 3: dim=10 (logits)
    h5 = jnp.matmul(h2,w5) + b5
    # Register the logits for easy prediction
    numpyro.deterministic("logits", h5)

    # Likelihood
    with numpyro.plate("labels", x.shape[0]):
        y_obs = numpyro.sample("y_obs", dist.CategoricalLogits(logits = h5), obs = y, rng_key=key)
    

In [34]:
def model_2(x, y=None):
  w1_f = numpyro.sample("w1_f", dist.Normal(0,1).expand([2*28*28,dim1]).to_event(1)) #weight matrix dimension: (indim(1568), h1dim(dim1 128))
  b1_f = numpyro.sample("b1_f", dist.Normal(0,1).expand([dim1]).to_event(1)) # bias dimension: (128,)

  w2_f = numpyro.sample("w2_f", dist.Normal(0,1).expand([dim1,dim2]).to_event(1)) #weight matrix dimension: (h1dim(dim1), h2dim(dim2))
  b2_f = numpyro.sample("b2_f", dist.Normal(0,1).expand([dim2]).to_event(1)) # bias dimension: (dim2,)

#   w3 = numpyro.sample("w3", dist.Normal(0,1).expand([dim2,dim3]).to_event(2)) #weight matrix dimension: (h2dim(dim2), outdim(10))
#   b3 = numpyro.sample("b3", dist.Normal(0,1).expand([dim3]).to_event(1)) # bias dimension: (outdim(10))

#   w4 = numpyro.sample("w4", dist.Normal(0,1).expand([dim3,dim4]).to_event(2)) #weight matrix dimension: (h2dim(dim2), outdim(10))
#   b4 = numpyro.sample("b4", dist.Normal(0,1).expand([dim4]).to_event(1)) # bias dimension: (outdim(10))

  w5_f = numpyro.sample("w5_f", dist.Normal(0,1).expand([dim2,10]).to_event(2)) #weight matrix dimension: (h2dim(dim2), outdim(10))
  b5_f = numpyro.sample("b5_f", dist.Normal(0,1).expand([10]).to_event(1)) # bias dimension: (outdim(10))  

  conv_numpyro = module.random_flax_module("conv_f", conv, dist.Normal(0, 1), input_shape=((x.shape[0],28,28,1)))
  cx = nonlin(conv_numpyro(x.reshape((-1,28,28,1))))


  h1 = nonlin(jnp.matmul(cx,w1_f) + b1_f)

  h2 = nonlin(jnp.matmul(h1,w2_f) + b2_f)


#   h3 = nonlin(jnp.matmul(h2,w3_f) + b3_f)

#   h4 = nonlin(jnp.matmul(h3,w4_f) + b4_f)  

  h5 = jnp.matmul(h2,w5_f) + b5_f

  # Register the logits for easy prediction
  numpyro.deterministic("logits_f", h5)


  # Likelihood
  with numpyro.plate("labels_f", x.shape[0]):
      y_obs_f = numpyro.sample("y_obs_f", dist.CategoricalLogits(logits = h5), obs = y, rng_key=key)




In [15]:
jax.local_device_count()

1

In [35]:
# Default max_tree_depth is 10
kernel = NUTS(model_2,
              target_accept_prob = 0.8,
              max_tree_depth = 10
              )

mcmc = MCMC(kernel,
            num_samples = 100,
            num_warmup = 50,
            num_chains = 2,
            progress_bar = True)

mcmc.run(random.PRNGKey(0), x,y)

  mcmc = MCMC(kernel,
sample: 100%|██████████| 150/150 [14:47<00:00,  5.92s/it, 1023 steps of size 2.05e-03. acc. prob=0.86]
sample: 100%|██████████| 150/150 [14:51<00:00,  5.94s/it, 1023 steps of size 2.00e-03. acc. prob=0.87]


In [None]:
print(jax.devices)

<function devices at 0x7f85b9f517e0>


In [12]:
# Default max_tree_depth is 10
kernel_1 = NUTS(model,
              target_accept_prob = 0.8,
              max_tree_depth = 10
              )

mcmc_1 = MCMC(kernel_1,
            num_samples = 100,
            num_warmup = 50,
            num_chains = 2,
            progress_bar = True)

mcmc_1.run(random.PRNGKey(0), x,y)

  mcmc_1 = MCMC(kernel_1,
sample: 100%|██████████| 150/150 [10:49<00:00,  4.33s/it, 1023 steps of size 2.24e-03. acc. prob=0.95]
sample: 100%|██████████| 150/150 [10:44<00:00,  4.30s/it, 1023 steps of size 2.76e-03. acc. prob=0.76]


In [36]:
posterior_samples = mcmc.get_samples()


posterior_predictive_test = numpyro.infer.Predictive(model_2, posterior_samples)(
        jax.random.PRNGKey(3),xv)

posterior_predictive_train = numpyro.infer.Predictive(model_2, posterior_samples)(
        jax.random.PRNGKey(3),x)

prior_predictive = numpyro.infer.Predictive(model_2, num_samples=500)(
        jax.random.PRNGKey(3),xv)


In [13]:
posterior_samples_1 = mcmc_1.get_samples()


posterior_predictive_test_1 = numpyro.infer.Predictive(model, posterior_samples_1)(
        jax.random.PRNGKey(3),xv)

posterior_predictive_train_1 = numpyro.infer.Predictive(model, posterior_samples_1)(
        jax.random.PRNGKey(3),x)

prior_predictive_1 = numpyro.infer.Predictive(model, num_samples=500)(
        jax.random.PRNGKey(3),xv)

In [22]:
!pip install arviz
import arviz



In [20]:
def accuracy(pred, data):
  """
  Calculate accuracy of predicted labels (integers).

  pred: predictions, ndarray[sample_index, chain_index, data_index, logits]
  data: actual data (digit), ndarray[data_index]

  Prediction is taken as most common predicted value.
  Returns accuracy (#correct/#total).
  """
  n=data.shape[0]
  correct=0
  total=0
  for i in range(0, n):
      # Get most common prediction value from logits
      pred_i=int(jnp.argmax(jnp.sum(pred[:,i,:],0)))
      # Compare prediction with data
      if int(data[i])==int(pred_i):
          correct+=1.0
      total+=1.0
  # Return fractional accuracy
  return correct/total

In [39]:
# Diagnostics from Arviz
#mcmc.print_summary()
#data = arviz.from_numpyro(mcmc, prior=prior_predictive, posterior_predictive=posterior_predictive_test)
#summary = arviz.summary(data)
#print(summary)

#arviz.plot_trace(posterior_samples)

# Diagnostics from Pyro
#report = mcmc.summary()

# Accuracy on test set
logits = posterior_predictive_test['logits_f']
print("Success posterior test = %.3f" % accuracy(logits, yv))

# Accuracy on training set
logits = posterior_predictive_train['logits_f']
print("Success posterior training = %.3f" % accuracy(logits, y))

logits = prior_predictive['logits_f']
print("Success prior = %.3f" % accuracy(logits, yv))

print("Posterior test diagnostics:")
numpyro.diagnostics.print_summary(posterior_predictive_test)

Success posterior test = 0.883
Success posterior training = 1.000
Success prior = 0.153
Posterior test diagnostics:

                 mean       std    median      5.0%     95.0%     n_eff     r_hat
logits_f[0]     -1.58     14.07     -2.78    -25.45     21.06  32888.18      1.02
logits_f[1]     -2.41     14.37     -3.82    -25.99     22.14  35610.87      1.02
logits_f[2]      0.79     12.20      0.41    -19.77     20.37  35679.58      1.02
logits_f[3]     -0.70     12.15     -1.12    -21.07     18.76  55123.87      1.02
logits_f[4]     -0.09     12.85     -0.63    -21.44     21.13  53809.85      1.02
logits_f[5]      0.45     11.49      0.11    -18.09     19.76  15910.72      1.02
logits_f[6]     -1.48     13.71     -2.69    -25.15     20.04  57032.00      1.01
logits_f[7]     -0.30     14.89     -1.68    -24.96     24.53  46335.20      1.01
logits_f[8]      2.30     11.57      1.89    -16.87     21.25  55068.99      1.02
logits_f[9]      1.70     13.08      1.02    -20.42     22.78  

In [30]:
# Accuracy on test set
logits = posterior_predictive_test_1['logits']

print("Success posterior test = %.3f" % accuracy(logits, yv))

# Accuracy on training set
logits = posterior_predictive_train_1['logits']
print("Success posterior training = %.3f" % accuracy(logits, y))

logits = prior_predictive_1['logits']
print("Success prior = %.3f" % accuracy(logits, yv))

print("Posterior test diagnostics:")
numpyro.diagnostics.print_summary(posterior_predictive_test_1)

Success posterior test = 0.907
Success posterior training = 1.000
Success prior = 0.137
Posterior test diagnostics:

                mean       std    median      5.0%     95.0%     n_eff     r_hat
 logits[0]     -1.24     14.47     -2.54    -24.94     23.42  48835.61      1.01
 logits[1]     -1.09     13.87     -2.30    -24.18     22.60  51045.04      1.01
 logits[2]     -0.75     14.53     -1.18    -24.81     23.15  40331.78      1.01
 logits[3]      0.54     12.72      0.07    -20.02     22.08  53181.66      1.01
 logits[4]     -1.25     14.62     -1.79    -24.92     23.47  53077.15      1.02
 logits[5]      1.49     11.55      0.93    -17.03     20.62  55902.17      1.01
 logits[6]     -2.27     15.51     -2.96    -28.84     23.40  57791.70      1.01
 logits[7]     -1.20     16.22     -2.54    -27.04     26.90  48126.08      1.01
 logits[8]      3.42     12.11      2.86    -16.38     23.59  54406.70      1.01
 logits[9]      3.62     12.36      3.14    -16.94     23.82  30721.54   