In [1]:
import argparse
import os
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import torch
from torchvision import transforms, datasets

from jax.numpy.fft import irfft, rfft, fft, ifft

from jax import grad, jit, vmap
import jax.numpy as jnp
import jax.random as random
import jax

!pip install numpyro
import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import numpyro.contrib.module as module

!pip install -q flax
from flax import linen as nn



print(torch.backends.cudnn.version())


matplotlib.use("Agg")  # noqa: E402
np.random.seed(0)
key = numpyro.prng_key()
numpyro.set_host_device_count(4)

None


In [2]:
print(jax.devices(backend='gpu'))

[gpu(id=0)]


In [3]:
def numpy_collate(batch):
  if isinstance(batch[0], np.ndarray):
    return np.stack(batch)
  elif isinstance(batch[0], (tuple,list)):
    transposed = zip(*batch)
    return [numpy_collate(samples) for samples in transposed]
  else:
    return np.array(batch)


class NumpyLoader(torch.utils.data.DataLoader):
  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    super(self.__class__, self).__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

class FlattenAndCast(object):
  def __call__(self, pic):
    return np.ravel(np.array(pic, dtype=jnp.float32))

In [4]:
def get_mnist(n, m):
    """
    Download MNIST and return train and evaluation sets.
    """
    mnist = datasets.MNIST('data',
        train=True,
        download=True,
        transform=FlattenAndCast())
    mnist = list(mnist)
    # One batch with all of mnist
    train_loader = NumpyLoader(mnist, batch_size=len(mnist), num_workers=0)
    x, y = list(train_loader)[0]
    # Normalize
    x = (x - x.mean()) / x.std()
    # Train and test set
    train_x, train_y = x[0:n], y[0:n]
    val_x, val_y = x[n:n+m], y[n:n+m]
    return train_x, train_y, val_x, val_y

In [5]:
class flax_CNN(nn.Module):
    @nn.compact   
    def __call__(self, x):
        #print(x.shape)
        x_length = x.shape[0]
        x = nn.Conv(features = 8, kernel_size = (5,5),strides = (1,1), 
                   padding = (2,2), 
                   use_bias= False)(x)
        x = nn.max_pool(x,window_shape=(2,2),strides = (2,2))
        all_len = len(x.flatten())
        x = x.reshape((x_length, int((all_len/x_length))))

        return x


In [6]:
_x, _y, _xv, _yv = get_mnist(50000,10000)

# Training
N=1000
# Test
M=200

# Get the training and test data from the MNIST global variables
x, y, xv, yv = _x[0:N], _y[0:N], _xv[N:N+M], _yv[N:N+M]


In [66]:
dim1 = 256
dim2 = 128
key = random.PRNGKey(0)
conv = flax_CNN() #from (200,28,28) to (200, 1568)
conv.init(key,x.reshape(x.shape[0],28,28,1))

rng_key = random.PRNGKey(1)

tab = conv.tabulate(jax.random.PRNGKey(0), x.reshape((-1,28,28,1)))
print(conv.kernel)


AttributeError: "flax_CNN" object has no attribute "kernel". If "kernel" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.

In [8]:
def circ_matmul(x, w, output_size):
    # xw = fft(fft(w)*ifft(x)).real
    # Note the use of the n argument to get right output shape
    xw = irfft(jnp.conj(rfft(w)) * rfft(x), n=output_size)
    #xw = irfft(torch.conj(rfft(w)) * rfft(x), n=w.shape[0])

    return xw

In [9]:
# the non-linearity we use in our neural network
def nonlin(x):
    return jnp.tanh(x)

In [10]:
def circ_conv(x, w,input_size, output_size):
    out = irfft(jnp.conj(rfft(w)) * rfft(x), n=output_size)
    print(out.shape)
    return out
    

In [11]:
out = circ_conv(x[0], jnp.ones([1, 784]), 784, 1586)
print(out)

(1, 1586)
[[8.829421 8.829421 8.829421 ... 8.829421 8.829421 8.829421]]


In [12]:
def model_circulant_all(x, y=None, dim1 = dim1, dim2 = dim2):
    w1 = numpyro.sample("w1", dist.Normal(0,1).expand([28*28]).to_event(1))
    b1 = numpyro.sample("b1", dist.Normal(0,1).expand([dim1]).to_event(1))

    w2 = numpyro.sample("w2", dist.Normal(0,1).expand([dim1]).to_event(1))
    b2 = numpyro.sample("b2", dist.Normal(0,1).expand([dim2]).to_event(1))


    w3 = numpyro.sample("w3", dist.Normal(0,1).expand([dim2,10]).to_event(2))
    b3 = numpyro.sample("b3", dist.Normal(0,1).expand([10]).to_event(1))    

    # Convolution
    # conv_numpyro = module.random_flax_module("conv", conv, dist.Normal(0, 1), input_shape=((x.shape[0],28,28,1)), apply_rng=[key])
    # init_variables = conv.init(key, x.reshape((x.shape[0],28,28,1)))
    # cx = nonlin(conv.apply(init_variables,x.reshape((x.shape[0],28,28,1))))
    
    # w_conv = numpyro.sample("w_conv", dist.Normal(0,1).expand([28*28]))
    # cx = nonlin(circ_matmul(x, w_conv, 28*28))
    #print(shape)

    # Layer 1: dim1
    h1 = circ_matmul(x.reshape(x.shape[0],28*28), w1, w1.shape[0])
    h1 = nonlin(h1[:, 0:dim1] + b1)
    # Layer 2: dim2
    h2 = circ_matmul(h1, w2, w2.shape[0])
    h2 = nonlin(h2[:, 0:dim2] + b2)

    # Layer 3: dim=10 (logits)
    h3 = jnp.matmul(h2,w3) + b3
    # Register the logits for easy prediction
    numpyro.deterministic("logits", h3)

    # Likelihood

    with numpyro.plate("labels", x.shape[0]):
        y_obs = numpyro.sample("y_obs", dist.CategoricalLogits(logits=h3),
                               obs=y, rng_key=key)

In [13]:
def model_circulant_weight(x, y=None, dim1 = dim1, dim2 = dim2):
    w1 = numpyro.sample("w1", dist.Normal(0,1).expand([28*28*2]).to_event(1))
    b1 = numpyro.sample("b1", dist.Normal(0,1).expand([dim1]).to_event(1))

    w2 = numpyro.sample("w2", dist.Normal(0,1).expand([dim1]).to_event(1))
    b2 = numpyro.sample("b2", dist.Normal(0,1).expand([dim2]).to_event(1))


    w3 = numpyro.sample("w3", dist.Normal(0,1).expand([dim2,10]).to_event(2))
    b3 = numpyro.sample("b3", dist.Normal(0,1).expand([10]).to_event(1))



    # Convolution
    # conv_numpyro = module.random_flax_module("conv", conv, dist.Normal(0, 1), input_shape=((x.shape[0],28,28,1)), apply_rng=[key])
    init_variables = conv.init(key, x.reshape((x.shape[0],28,28,1)))
    cx = nonlin(conv.apply(init_variables,x.reshape((x.shape[0],28,28,1))))

    # w_conv = numpyro.sample("w_conv", dist.Normal(0,1).expand([28*28]))
    # cx = nonlin(circ_matmul(x, w_conv, 28*28))
    #print(shape)

    # Layer 1: dim1
    h1 = circ_matmul(cx, w1, w1.shape[0])
    h1 = nonlin(h1[:, 0:dim1] + b1)
    # Layer 2: dim2
    h2 = circ_matmul(h1, w2, w2.shape[0])
    h2 = nonlin(h2[:, 0:dim2] + b2)

    # Layer 3: dim=10 (logits)
    h3 = jnp.matmul(h2,w3) + b3
    # Register the logits for easy prediction
    numpyro.deterministic("logits", h3)

    # Likelihood

    with numpyro.plate("labels", x.shape[0]):
        y_obs = numpyro.sample("y_obs", dist.CategoricalLogits(logits=h3),
                               obs=y, rng_key=key)

In [14]:
def model_full_weight(x, y=None):
  w1_full = numpyro.sample("w1", dist.Normal(0,1).expand([2*28*28,dim1]).to_event(1)) #weight matrix dimension: (indim(1568), h1dim(dim1 128))
  b1_full = numpyro.sample("b1", dist.Normal(0,1).expand([dim1]).to_event(1)) # bias dimension: (128,)

  w2_full = numpyro.sample("w2", dist.Normal(0,1).expand([dim1,dim2]).to_event(1)) #weight matrix dimension: (h1dim(dim1), h2dim(dim2))
  b2_full = numpyro.sample("b2", dist.Normal(0,1).expand([dim2]).to_event(1)) # bias dimension: (dim2,)

  w3_full = numpyro.sample("w3", dist.Normal(0,1).expand([dim2,10]).to_event(2)) #weight matrix dimension: (h2dim(dim2), outdim(10))
  b3_full = numpyro.sample("b3", dist.Normal(0,1).expand([10]).to_event(1)) # bias dimension: (outdim(10))



  # conv_numpyro = module.random_flax_module("conv_full", conv, dist.Normal(0, 1), input_shape=((x.shape[0],28,28,1)))
  # cx = nonlin(conv.apply(x.reshape((-1,28,28,1))))
  init_variables = conv.init(key, x.reshape((x.shape[0],28,28,1)))
  cx = nonlin(conv.apply(init_variables,x.reshape((x.shape[0],28,28,1))))

  h1 = nonlin(jnp.matmul(cx,w1_full) + b1_full)

  h2 = nonlin(jnp.matmul(h1,w2_full) + b2_full)

  h3 = jnp.matmul(h2,w3_full) + b3_full

  # Register the logits for easy prediction
  numpyro.deterministic("logits_full", h3)


  # Likelihood

  with numpyro.plate("labels_full", x.shape[0]):
      y_obs_f = numpyro.sample("y_obs_full", dist.CategoricalLogits(logits = h3), obs = y, rng_key=key)




In [15]:
# The point estimate model common neural network

class CNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
                in_channels=1,
                out_channels=8,
                kernel_size=5,
                stride=1,
                padding=2,
                bias = False
            )
        # Lift to Pyro
        self.maxp = torch.nn.MaxPool2d(2)
        self.fc = torch.nn.Sequential(torch.nn.Linear(in_features=2*28*28, out_features=dim1), torch.nn.Tanh(), 
                                      torch.nn.Linear(in_features= dim1, out_features=dim2), torch.nn.Tanh(), 
                                      torch.nn.Linear(in_features = dim2, out_features=10), torch.nn.Tanh())

    def forward(self, x):
        cx = self.conv(x)
        px = self.maxp(cx)
        fx = torch.flatten(px, 1)
        x_output = self.fc(fx)
        return x_output


In [16]:
jax.local_device_count()

1

In [17]:
# train the only circulant model
kernel_all_cir = NUTS(model_circulant_all,
              target_accept_prob = 0.8,
              max_tree_depth = 12
              )

mcmc_all_cir = MCMC(kernel_all_cir,
            num_samples = 100,
            num_warmup = 50,
            num_chains = 2,
            progress_bar = True)

%time mcmc_all_cir.run(random.PRNGKey(0), x,y)

  mcmc_all_cir = MCMC(kernel_all_cir,
sample: 100%|██████████| 150/150 [05:56<00:00,  2.38s/it, 1023 steps of size 3.85e-03. acc. prob=0.72]
sample: 100%|██████████| 150/150 [08:25<00:00,  3.37s/it, 1023 steps of size 3.46e-03. acc. prob=0.81]


CPU times: user 11min 51s, sys: 2min 35s, total: 14min 27s
Wall time: 14min 30s


In [18]:
# train the circulant matrix model

kernel = NUTS(model_circulant_weight,
              target_accept_prob = 0.8,
              max_tree_depth = 12
              )

mcmc = MCMC(kernel,
            num_samples = 100,
            num_warmup = 50,
            num_chains = 2,
            progress_bar = True)

%time mcmc.run(random.PRNGKey(0), x,y)

  mcmc = MCMC(kernel,
sample: 100%|██████████| 150/150 [11:27<00:00,  4.58s/it, 2047 steps of size 2.38e-03. acc. prob=0.96]
sample: 100%|██████████| 150/150 [07:21<00:00,  2.94s/it, 1023 steps of size 4.31e-03. acc. prob=0.71]


CPU times: user 15min 45s, sys: 3min 6s, total: 18min 51s
Wall time: 18min 52s


In [19]:
# Train the full weight matrix model
kernel_all_full = NUTS(model_full_weight,
              target_accept_prob = 0.8,
              max_tree_depth = 12
              )

mcmc_all_full = MCMC(kernel_all_full,
            num_samples = 100,
            num_warmup = 50,
            num_chains = 2,
            progress_bar = True)

%time mcmc_all_full.run(random.PRNGKey(1), x, y)

  mcmc_all_full = MCMC(kernel_all_full,
sample: 100%|██████████| 150/150 [23:11<00:00,  9.28s/it, 4095 steps of size 1.75e-03. acc. prob=0.61]
sample: 100%|██████████| 150/150 [23:19<00:00,  9.33s/it, 4095 steps of size 8.66e-04. acc. prob=0.81]


CPU times: user 39min 28s, sys: 7min 5s, total: 46min 34s
Wall time: 46min 34s


In [20]:
# Train the point estimation mmodel

def get_mnist_torch(n, m):
    """
    Download MNIST and return train and evaluation sets.
    """
    img_to_tensor = transforms.ToTensor()
    mnist = datasets.MNIST('data',
        train=True,
        download=True,
        transform=img_to_tensor)
    mnist = list(mnist)
    # One batch with all of mnist
    train_loader = torch.utils.data.DataLoader(mnist,
        batch_size=len(mnist),
        shuffle=True)
    # x = images tensor, y = labels tensor 
    x, y = list(train_loader)[0]
    # Flatten images
    x = x.view(-1, 28*28)
    # Normalize
    x = (x - x.mean()) / x.std()
    # Train and test set
    train_x, train_y = x[0:n], y[0:n]
    val_x, val_y = x[n:n+m], y[n:n+m]
    return train_x, train_y, val_x, val_y

#evice = torch.device.device
#print(device)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train_point_estimation(model, dataloader, criterion, optimizer, num_epochs):
    model.to(device)
    loss_list = []
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for _, (data, label) in enumerate(dataloader):

            data = data.to(device)
            labels = label.to(device)
            model.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # Print the average loss at every epoch
        average_loss = epoch_loss / (len(dataloader))
        loss_list.append(average_loss)
        if (epoch + 1) % 10 == 0:
            print(f'Epoch: {epoch + 1}, Average Loss: {average_loss}')
    torch.save(model, "cnn.pth")
    return loss_list

point_estimate_model = CNN()
x_torch, y_torch, xv_torch, yv_torch = get_mnist_torch(N,M)

x_torch = x_torch.view((-1,1,28,28))
print(x_torch.shape)
xv_torch = xv_torch.view((-1,1,28,28))


train_set = torch.utils.data.TensorDataset(x_torch,y_torch)
test_set = torch.utils.data.TensorDataset(xv_torch,yv_torch)

train_loader = torch.utils.data.DataLoader(train_set, shuffle = True, batch_size = 1)

test_loader = torch.utils.data.DataLoader(test_set, shuffle = True, batch_size = 1)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(point_estimate_model.parameters(), lr = 0.001)

#loss_list_point_estimation = train_point_estimation(point_estimate_model,train_loader, criterion, optimizer, 100)




torch.Size([1000, 1, 28, 28])


In [21]:
# get posterior and prediction of the all circular model

posterior_samples_all_cir = mcmc_all_cir.get_samples()


posterior_predictive_test_all_cir = numpyro.infer.Predictive(model_circulant_all, posterior_samples_all_cir)(
    jax.random.PRNGKey(3),xv)

posterior_predictive_train_all_cir = numpyro.infer.Predictive(model_circulant_all, posterior_samples_all_cir)(
    jax.random.PRNGKey(3),x)

prior_predictive_all_cir = numpyro.infer.Predictive(model_circulant_all, num_samples=500)(
    jax.random.PRNGKey(3),xv)

In [22]:
# get posterior and prediction of the circular weight matrix model

posterior_samples = mcmc.get_samples()


posterior_predictive_test = numpyro.infer.Predictive(model_circulant_weight, posterior_samples)(
        jax.random.PRNGKey(3),xv)

posterior_predictive_train = numpyro.infer.Predictive(model_circulant_weight, posterior_samples)(
        jax.random.PRNGKey(3),x)

prior_predictive = numpyro.infer.Predictive(model_circulant_weight, num_samples=500)(
        jax.random.PRNGKey(3),xv)


In [23]:
# get posterior and prediction of the full weight matrix model

posterior_samples_all_full = mcmc_all_full.get_samples()


posterior_predictive_test_all_full = numpyro.infer.Predictive(model_full_weight, posterior_samples_all_full)(
        jax.random.PRNGKey(3),xv)

posterior_predictive_train_all_full = numpyro.infer.Predictive(model_full_weight, posterior_samples_all_full)(
        jax.random.PRNGKey(3),x)

prior_predictive_all_full = numpyro.infer.Predictive(model_full_weight, num_samples=500)(
        jax.random.PRNGKey(3),xv)

In [24]:
import arviz as az
az.style.use("arviz-doc")

In [25]:
def accuracy(pred, data):
    """
    Calculate accuracy of predicted labels (integers).

    pred: predictions, ndarray[sample_index, chain_index, data_index, logits]
    data: actual data (digit), ndarray[data_index]

    Prediction is taken as most common predicted value.
    Returns accuracy (#correct/#total).
    """
    n=data.shape[0]
    correct=0
    total=0
    for i in range(0, n):
        # Get most common prediction value from logits
        pred_i=int(jnp.argmax(jnp.sum(pred[:,i,:],0)))
        # Compare prediction with data
        if int(data[i])==int(pred_i):
            correct+=1.0
        total+=1.0
    # Return fractional accuracy
    return correct/total

In [26]:
# summary of circulant matrix model
az.style.use("arviz-doc")



numpyro_data = az.from_numpyro(
    mcmc,
    prior=prior_predictive,
    posterior_predictive=posterior_predictive_train,

)
post = numpyro_data.posterior
print(post)
w_ess = az.ess(post,var_names=['w1','w2','w3'])
w_ess_lst = [w_ess.w1.values.flatten(),w_ess.w2.values.flatten(),w_ess.w3.values.flatten()]

fig, ax = plt.subplots()

ax.boxplot(w_ess_lst)
plt.savefig("w_ess.png")


fig2, ax2 = plt.subplots()
ax2.violinplot(w_ess_lst,
                  showmeans=False,
                  showmedians=True)
plt.savefig("w_ess_vio.png")



# w_ess_med = w_ess.median(dim = ['w1_dim_0','w2_dim_0','w3_dim_0','w3_dim_1'])
#w_ess_qutile = w_ess.quantile(q= , dim = ['w1_dim_0','w2_dim_0','w3_dim_0','w3_dim_1'])
print(w_ess)


# print(w_ess_lst)
# print(w_ess.w3.values.flatten())
# # print(w_ess_med)
# print(post)

<xarray.Dataset>
Dimensions:       (chain: 2, draw: 100, b1_dim_0: 256, b2_dim_0: 128,
                   b3_dim_0: 10, logits_dim_0: 1000, logits_dim_1: 10,
                   w1_dim_0: 1568, w2_dim_0: 256, w3_dim_0: 128, w3_dim_1: 10)
Coordinates:
  * chain         (chain) int64 0 1
  * draw          (draw) int64 0 1 2 3 4 5 6 7 8 ... 91 92 93 94 95 96 97 98 99
  * b1_dim_0      (b1_dim_0) int64 0 1 2 3 4 5 6 ... 249 250 251 252 253 254 255
  * b2_dim_0      (b2_dim_0) int64 0 1 2 3 4 5 6 ... 121 122 123 124 125 126 127
  * b3_dim_0      (b3_dim_0) int64 0 1 2 3 4 5 6 7 8 9
  * logits_dim_0  (logits_dim_0) int64 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
  * logits_dim_1  (logits_dim_1) int64 0 1 2 3 4 5 6 7 8 9
  * w1_dim_0      (w1_dim_0) int64 0 1 2 3 4 5 ... 1562 1563 1564 1565 1566 1567
  * w2_dim_0      (w2_dim_0) int64 0 1 2 3 4 5 6 ... 249 250 251 252 253 254 255
  * w3_dim_0      (w3_dim_0) int64 0 1 2 3 4 5 6 ... 121 122 123 124 125 126 127
  * w3_dim_1      (w3_dim_1) int64

In [27]:
def plot_diags(inferdata, model_name):
    plt.close('all')
    post = inferdata.posterior

    r_hat = az.rhat(post, var_names=['w1','w2','w3'])
    r_hat_lst = [r_hat.w1.values.flatten(),r_hat.w2.values.flatten(),r_hat.w3.values.flatten()]
    # print(r_hat)
    fig1, ax1 = plt.subplots()
    ax1.boxplot(r_hat_lst)
    plt.savefig("r_hat_{}.png".format(model_name))

    w_ess = az.ess(post,var_names=['w1','w2','w3'])
    w_ess_lst = [w_ess.w1.values.flatten(),w_ess.w2.values.flatten(),w_ess.w3.values.flatten()]


    fig, ax = plt.subplots()
    ax.boxplot(w_ess_lst)
    plt.savefig("w_ess_{}.png".format(model_name))


    fig2, ax2 = plt.subplots()
    ax2.violinplot(w_ess_lst,
               showmeans=False,
               showmedians=True)
    plt.savefig("w_ess_vio_{}.png".format(model_name))

    az.plot_ppc(inferdata, data_pairs={"obs": "obs"}, alpha=0.03, textsize=14)
    # plt.show()
    plt.savefig("ppc_{}.png".format(model_name))

    return w_ess_lst


In [28]:
plot_diags(numpyro_data,"cir_weight")


numpyro_data_all_cir = az.from_numpyro(
    mcmc_all_cir,
    prior=prior_predictive_all_cir,
    posterior_predictive=posterior_predictive_train_all_cir,

)

plot_diags(numpyro_data_all_cir, "all_cir")

numpyro_data_all_full = az.from_numpyro(
    mcmc_all_full,
    prior=prior_predictive_all_full,
    posterior_predictive=posterior_predictive_train_all_full,

)

plot_diags(numpyro_data_all_full, "all_full")

[array([155.77341835, 150.24534522, 182.84368972, ..., 151.19779148,
         88.35897442,  90.83058298]),
 array([138.77785544, 119.31192123,  91.81359846, ..., 262.17116148,
         85.92081197, 125.65311814]),
 array([ 93.16707823, 115.24246889,  74.42361195, ...,  11.5757246 ,
        185.69616671, 111.70697826])]

In [74]:

#summary = az.summary(numpyro_data)
#calculate the parameter number of full weight model
num_layer_1 = numpyro_data_all_full.posterior.w1.shape[2]*numpyro_data_all_full.posterior.w1.shape[3] + numpyro_data_all_full.posterior.b1.shape[2]
num_layer_2 = numpyro_data_all_full.posterior.w2.shape[2]*numpyro_data_all_full.posterior.w2.shape[3] + numpyro_data_all_full.posterior.b2.shape[2]
num_layer_3 = numpyro_data_all_full.posterior.w3.shape[2]*numpyro_data_all_full.posterior.w3.shape[3] + numpyro_data_all_full.posterior.b3.shape[2]
total_num = num_layer_1+num_layer_2+num_layer_3

def get_param_num(model, cnn=True, cir=False):
    if cir:
        num_layer_1 = model.posterior.w1.shape[2] + model.posterior.b1.shape[2]
        num_layer_2 = model.posterior.w2.shape[2] + model.posterior.b2.shape[2]
    else:
        num_layer_1 = model.posterior.w1.shape[2]*model.posterior.w1.shape[3] + model.posterior.b1.shape[2]
        num_layer_2 = model.posterior.w2.shape[2]*model.posterior.w2.shape[3] + model.posterior.b2.shape[2]
    num_layer_3 = model.posterior.w3.shape[2]*model.posterior.w3.shape[3] + model.posterior.b3.shape[2]
    total_num = num_layer_1+num_layer_2+num_layer_3
    if cnn:
        total_num = total_num + 200
    return total_num


In [72]:

# az.plot_ess(summary_data_circulant,var_names=('w1'), kind = 'evolution')
# plt.savefig("posterior_sample.png")
# az.plot_trace(summary_data_circulant,var_names=("w1"))
# plt.savefig("trace_w1.png")
# az.plot_trace(summary_data_circulant,var_names=("w2"))
# plt.savefig("trace_w2.png")
# az.plot_trace(summary_data_circulant,var_names=("w3"))
# plt.savefig("trace_w3.png")
az.plot_ppc(numpyro_data, data_pairs={"obs": "obs"}, alpha=0.03, textsize=14)
plt.show()
plt.savefig("ppc.png")

  plt.show()


In [77]:
# calculate the number of parameters, w and b

num_params = get_param_num(numpyro_data,cir=True)
num_params_all_cir = get_param_num(numpyro_data_all_cir, False , cir=True)
num_params_all_full = get_param_num(numpyro_data_all_full)

print('number of parameters of cnn and circulant matrix model: ', num_params)
print('number of parameters of cnn and full weight matrix model: ', num_params_all_full)
print('number of parameters of only circulant matrix model: ', num_params_all_cir)

number of parameters of cnn and circulant matrix model:  3698
number of parameters of cnn and full weight matrix model:  436050
number of parameters of only circulant matrix model:  2714


In [36]:
# Accuracy on test set
logits = posterior_predictive_test['logits']
print("Success posterior test = %.3f" % accuracy(logits, yv))

# Accuracy on training set
logits = posterior_predictive_train['logits']
print("Success posterior training = %.3f" % accuracy(logits, y))

logits = prior_predictive['logits']
print("Success prior = %.3f" % accuracy(logits, yv))

print("Posterior test diagnostics:")
numpyro.diagnostics.print_summary(posterior_samples)

Success posterior test = 0.920
Success posterior training = 1.000
Success prior = 0.105
Posterior test diagnostics:

                mean       std    median      5.0%     95.0%     n_eff     r_hat
        b1      0.00      1.01      0.00     -1.67      1.64  52074.41      1.00
        b2      0.00      1.00      0.01     -1.64      1.64  26165.44      1.00
        b3     -0.01      0.98      0.01     -1.70      1.47   2842.25      0.99
 logits[0]     -1.89     14.50     -2.90    -26.89     22.27 159085.98      1.01
 logits[1]     -1.02     14.54     -2.52    -25.09     24.73 150428.94      1.01
 logits[2]     -0.55     15.47     -0.95    -26.90     25.25 180240.09      1.01
 logits[3]      0.35     13.78     -0.53    -22.91     23.16 188373.73      1.01
 logits[4]     -0.69     15.20     -1.71    -25.55     25.27 170207.91      1.01
 logits[5]      2.64     11.93      2.16    -17.28     22.35 148855.59      1.01
 logits[6]     -1.71     15.14     -2.50    -27.54     23.56 196587.91   

In [37]:
# summary of full weight matrix model



# Accuracy on test set
logits = posterior_predictive_test_all_full['logits_full']

print("Success posterior test = %.3f" % accuracy(logits, yv))

# Accuracy on training set
logits = posterior_predictive_train_all_full['logits_full']
print("Success posterior training = %.3f" % accuracy(logits, y))

logits = prior_predictive_all_full['logits_full']
print("Success prior = %.3f" % accuracy(logits, yv))

# print("Posterior test diagnostics:")
# numpyro.diagnostics.print_summary(posterior_samples_1)

Success posterior test = 0.890
Success posterior training = 1.000
Success prior = 0.065


In [35]:
#