In [1]:
import argparse
import os
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from jax.numpy.fft import irfft, rfft, fft, ifft

from jax import grad, jit, vmap
import jax.numpy as jnp
import jax.random as random
import jax

!pip install numpyro
import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import numpyro.contrib.module as module

!pip install -q flax
from flax import linen as nn

import torch
from torchvision import transforms, datasets



matplotlib.use("Agg")  # noqa: E402
np.random.seed(0)
key = numpyro.prng_key()
numpyro.set_host_device_count(4)



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print(jax.devices(backend='gpu'))

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]


In [3]:
def numpy_collate(batch):
  if isinstance(batch[0], np.ndarray):
    return np.stack(batch)
  elif isinstance(batch[0], (tuple,list)):
    transposed = zip(*batch)
    return [numpy_collate(samples) for samples in transposed]
  else:
    return np.array(batch)


class NumpyLoader(torch.utils.data.DataLoader):
  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    super(self.__class__, self).__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

class FlattenAndCast(object):
  def __call__(self, pic):
    return np.ravel(np.array(pic, dtype=jnp.float32))

In [4]:
def get_mnist(n, m):
    """
    Download MNIST and return train and evaluation sets.
    """
    mnist = datasets.MNIST('data',
        train=True,
        download=True,
        transform=FlattenAndCast())
    mnist = list(mnist)
    # One batch with all of mnist
    train_loader = NumpyLoader(mnist, batch_size=len(mnist), num_workers=0)
    x, y = list(train_loader)[0]
    # Normalize
    x = (x - x.mean()) / x.std()
    # Train and test set
    train_x, train_y = x[0:n], y[0:n]
    val_x, val_y = x[n:n+m], y[n:n+m]
    return train_x, train_y, val_x, val_y

In [5]:
class flax_CNN(nn.Module):
    @nn.compact   
    def __call__(self, x):
        #print(x.shape)
        x_length = x.shape[0]
        x = nn.Conv(features = 8, kernel_size = (5,5),strides = (1,1), 
                   padding = (2,2), 
                   use_bias= False)(x)
        #print(x.shape[0])
        x = nn.max_pool(x,window_shape=(2,2),strides = (2,2))
        all_len = len(x.flatten())
        x = x.reshape((x_length, int((all_len/x_length)) ))

        return x


In [6]:
_x, _y, _xv, _yv = get_mnist(50000,10000)

# Training
N=1000 
# Test
M=300

# Get the training and test data from the MNIST global variables
x, y, xv, yv = _x[0:N], _y[0:N], _xv[N:N+M], _yv[N:N+M]


In [7]:
dim1 = 128
dim2 = 32
key = random.PRNGKey(0)
conv = flax_CNN() #from (200,28,28) to (200, 1568)

rng_key = random.PRNGKey(1)

print(conv.tabulate(jax.random.PRNGKey(0), x.reshape((-1,28,28,1))))



[3m                                flax_CNN Summary                                [0m
┏━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath  [0m[1m [0m┃[1m [0m[1mmodule  [0m[1m [0m┃[1m [0m[1minputs           [0m[1m [0m┃[1m [0m[1moutputs          [0m[1m [0m┃[1m [0m[1mparams          [0m[1m [0m┃
┡━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩
│        │ flax_CNN │ [2mfloat32[0m[1000,28,… │ [2mfloat32[0m[1000,156… │                  │
├────────┼──────────┼───────────────────┼───────────────────┼──────────────────┤
│ Conv_0 │ Conv     │ [2mfloat32[0m[1000,28,… │ [2mfloat32[0m[1000,28,… │ kernel:          │
│        │          │                   │                   │ [2mfloat32[0m[5,5,1,8] │
│        │          │                   │                   │                  │
│        │          │                   │                   │ [1m200 [0m[1;2m(800 B)[0m      │
├──

In [8]:
def circ_matmul(x, w):
    # xw = fft(fft(w)*ifft(x)).real
    # Note the use of the n argument to get right output shape
    xw = irfft(jnp.conj(rfft(w)) * rfft(x), n=w.shape[0])
    return xw

In [9]:
# the non-linearity we use in our neural network
def nonlin(x):
    return jnp.tanh(x)

In [10]:
# TODO: add numlayer parameter, add whether has cnn parameter, add whether use circulant multiply, add cnn parameters.calculate dimensions. 

def model_circulant_weight(x, y=None, dim1 = dim1, dim2 = dim2):
    w1 = numpyro.sample("w1", dist.Normal(0,1).expand([2*28*28]).to_event(1))
    b1 = numpyro.sample("b1", dist.Normal(0,1).expand([dim1]).to_event(1))

    w2 = numpyro.sample("w2", dist.Normal(0,1).expand([dim1]).to_event(1))
    b2 = numpyro.sample("b2", dist.Normal(0,1).expand([dim2]).to_event(1))


    w3 = numpyro.sample("w3", dist.Normal(0,1).expand([dim2]).to_event(2))
    b3 = numpyro.sample("b3", dist.Normal(0,1).expand([10]).to_event(1))    

    # Convolution
    conv_numpyro = module.random_flax_module("conv", conv, dist.Normal(0, 1), input_shape=((x.shape[0],28,28,1)))
    cx = nonlin(conv_numpyro(x.reshape((-1,28,28,1))))

    # Layer 1: dim1
    h1 = circ_matmul(cx, w1)
    h1 = nonlin(h1[:, 0:dim1] + b1)
    # Layer 2: dim2
    h2 = circ_matmul(h1, w2)
    h2 = nonlin(h2[:, 0:dim2] + b2)

    # Layer 3: dim=10 (logits)
    h3 = jnp.matmul(h2,w3) + b3
    # Register the logits for easy prediction
    numpyro.deterministic("logits", h3)

    # Likelihood
    with numpyro.plate("labels", x.shape[0]):
        y_obs = numpyro.sample("y_obs", dist.CategoricalLogits(logits=h3), 
                               obs=y, rng_key=key)

In [11]:
def model_full_weight(x, y=None):
  w1_full = numpyro.sample("w1_full", dist.Normal(0,1).expand([2*28*28,dim1]).to_event(1)) #weight matrix dimension: (indim(1568), h1dim(dim1 128))
  b1_full = numpyro.sample("b1_full", dist.Normal(0,1).expand([dim1]).to_event(1)) # bias dimension: (128,)

  w2_full = numpyro.sample("w2_full", dist.Normal(0,1).expand([dim1,dim2]).to_event(1)) #weight matrix dimension: (h1dim(dim1), h2dim(dim2))
  b2_full = numpyro.sample("b2_full", dist.Normal(0,1).expand([dim2]).to_event(1)) # bias dimension: (dim2,)

  w3_full = numpyro.sample("w3_full", dist.Normal(0,1).expand([dim2,10]).to_event(2)) #weight matrix dimension: (h2dim(dim2), outdim(10))
  b3_full = numpyro.sample("b3_full", dist.Normal(0,1).expand([10]).to_event(1)) # bias dimension: (outdim(10))  

  conv_numpyro = module.random_flax_module("conv_full", conv, dist.Normal(0, 1), input_shape=((x.shape[0],28,28,1)))
  cx = nonlin(conv_numpyro(x.reshape((-1,28,28,1))))


  h1 = nonlin(jnp.matmul(cx,w1_full) + b1_full)

  h2 = nonlin(jnp.matmul(h1,w2_full) + b2_full)

  h3 = jnp.matmul(h2,w3_full) + b3_full

  # Register the logits for easy prediction
  numpyro.deterministic("logits_full", h3)


  # Likelihood
  with numpyro.plate("labels_full", x.shape[0]):
      y_obs_f = numpyro.sample("y_obs_full", dist.CategoricalLogits(logits = h3), obs = y, rng_key=key)




In [19]:
# The point estimate model common neural network

class CNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
                in_channels=1,
                out_channels=8,
                kernel_size=5,
                stride=1,
                padding=2,
                bias = False
            )
        # Lift to Pyro
        self.maxp = torch.nn.MaxPool2d(2)
        self.fc = torch.nn.Sequential(torch.nn.Linear(in_features=2*28*28, out_features=dim1), torch.nn.Tanh(), 
                                      torch.nn.Linear(in_features= dim1, out_features=dim2), torch.nn.Tanh(), 
                                      torch.nn.Linear(in_features = dim2, out_features=10), torch.nn.Tanh())

    def forward(self, x):
        cx = self.conv(x)
        px = self.maxp(cx)
        fx = torch.flatten(px, 1)
        x_output = self.fc(fx)
        return x_output


In [20]:
jax.local_device_count()

1

In [None]:
# train the circulant matrix model
kernel = NUTS(model_circulant_weight,
              target_accept_prob = 0.8,
              max_tree_depth = 10
              )

mcmc = MCMC(kernel,
            num_samples = 100,
            num_warmup = 50,
            num_chains = 2,
            progress_bar = True)

mcmc.run(random.PRNGKey(0), x,y)

In [None]:
# Train the full weight matrix model
kernel_1 = NUTS(model_full_weight,
              target_accept_prob = 0.8,
              max_tree_depth = 10
              )

mcmc_1 = MCMC(kernel_1,
            num_samples = 100,
            num_warmup = 50,
            num_chains = 2,
            progress_bar = True)

mcmc_1.run(random.PRNGKey(0), x, y)

In [21]:
# Train the point estimation mmodel

def get_mnist_torch(n, m):
    """
    Download MNIST and return train and evaluation sets.
    """
    img_to_tensor = transforms.ToTensor()
    mnist = datasets.MNIST('data',
        train=True,
        download=True,
        transform=img_to_tensor)
    mnist = list(mnist)
    # One batch with all of mnist
    train_loader = torch.utils.data.DataLoader(mnist,
        batch_size=len(mnist),
        shuffle=True)
    # x = images tensor, y = labels tensor 
    x, y = list(train_loader)[0]
    # Flatten images
    x = x.view(-1, 28*28)
    # Normalize
    x = (x - x.mean()) / x.std()
    # Train and test set
    train_x, train_y = x[0:n], y[0:n]
    val_x, val_y = x[n:n+m], y[n:n+m]
    return train_x, train_y, val_x, val_y


def train_point_estimation(model, dataloader, criterion, optimizer, num_epochs):
    model.cuda()
    loss_list = []
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for _, (data, label) in enumerate(dataloader):

            data = data.cuda()
            label = label.cuda()
            model.zero_grad()
            outputs = model(data)
            labels = label.unsqueeze(-1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # Print the average loss at every epoch
        average_loss = epoch_loss / (len(dataloader))
        loss_list.append(average_loss)
        if (epoch + 1) % 10 == 0:
            print(f'Epoch: {epoch + 1}, Average Loss: {average_loss}')
    return loss_list

point_estimate_model = CNN()
x_torch, y_torch, xv_torch, yv_torch = get_mnist_torch(N,M)

train_set = torch.utils.data.TensorDataset(x_torch,y_torch)
test_set = torch.utils.data.TensorDataset(xv_torch,yv_torch)

train_loader = torch.utils.data.DataLoader(train_set, shuffle = True, batch_size = 1)

test_loader = torch.utils.data.DataLoader(test_set, shuffle = True, batch_size = 1)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(point_estimate_model.parameters(), lr = 0.0001)

train_point_estimation(point_estimate_model,train_loader, criterion, optimizer, 100)


    

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [8, 1, 5, 5], but got 2-dimensional input of size [1, 784] instead

In [None]:
# get posterior and prediction of the circular weight matrix model

posterior_samples = mcmc.get_samples()


posterior_predictive_test = numpyro.infer.Predictive(model_circulant_weight, posterior_samples)(
        jax.random.PRNGKey(3),xv)

posterior_predictive_train = numpyro.infer.Predictive(model_circulant_weight, posterior_samples)(
        jax.random.PRNGKey(3),x)

prior_predictive = numpyro.infer.Predictive(model_circulant_weight, num_samples=500)(
        jax.random.PRNGKey(3),xv)


In [None]:
# get posterior and prediction of the full weight matrix model

posterior_samples_1 = mcmc_1.get_samples()


posterior_predictive_test_1 = numpyro.infer.Predictive(model_full_weight, posterior_samples_1)(
        jax.random.PRNGKey(3),xv)

posterior_predictive_train_1 = numpyro.infer.Predictive(model_full_weight, posterior_samples_1)(
        jax.random.PRNGKey(3),x)

prior_predictive_1 = numpyro.infer.Predictive(model_full_weight, num_samples=500)(
        jax.random.PRNGKey(3),xv)

In [None]:
print(posterior_samples.keys())
print(posterior_samples_1.keys())
print(posterior_samples['w1'].shape)
print(posterior_samples_1['w1_full'].shape)
print(posterior_predictive_test_1.keys())

In [None]:
!pip install arviz
import arviz as az
az.style.use("arviz-doc")

In [None]:
def accuracy(pred, data):
    """
    Calculate accuracy of predicted labels (integers).

    pred: predictions, ndarray[sample_index, chain_index, data_index, logits]
    data: actual data (digit), ndarray[data_index]

    Prediction is taken as most common predicted value.
    Returns accuracy (#correct/#total).
    """
    n=data.shape[0]
    correct=0
    total=0
    for i in range(0, n):
        # Get most common prediction value from logits
        pred_i=int(jnp.argmax(jnp.sum(pred[:,i,:],0)))
        # Compare prediction with data
        if int(data[i])==int(pred_i):
            correct+=1.0
        total+=1.0
    # Return fractional accuracy
    return correct/total

In [None]:
# summary of circulant matrix model


#summary_data_circulant = arviz.from_numpyro(posterior=mcmc, prior=prior_predictive, posterior_predictive= posterior_predictive_test )
summary_data_circulant = az.convert_to_inference_data(posterior_samples)
az.plot_ess(summary_data_circulant,var_names=['w1'], kind = 'evolution')
plt.savefig("posterior_sample.png")

In [None]:
# Accuracy on test set
logits = posterior_predictive_test['logits']
print("Success posterior test = %.3f" % accuracy(logits, yv))

# Accuracy on training set
logits = posterior_predictive_train['logits']
print("Success posterior training = %.3f" % accuracy(logits, y))

logits = prior_predictive['logits']
print("Success prior = %.3f" % accuracy(logits, yv))

print("Posterior test diagnostics:")
numpyro.diagnostics.print_summary(posterior_samples)

In [None]:
# summary of full weight matrix model



# Accuracy on test set
logits = posterior_predictive_test_1['logits_full']

print("Success posterior test = %.3f" % accuracy(logits, yv))

# Accuracy on training set
logits = posterior_predictive_train_1['logits_full']
print("Success posterior training = %.3f" % accuracy(logits, y))

logits = prior_predictive_1['logits_full']
print("Success prior = %.3f" % accuracy(logits, yv))

print("Posterior test diagnostics:")
numpyro.diagnostics.print_summary(posterior_samples_1)

In [None]:
#