<a href="https://colab.research.google.com/github/PetchMa/deeplearning_fundamentals/blob/main/CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Build CNN Model

In this notebook we once again try to build a CNN model from scratch using JAX and other packages, with as little help as possible, as a means of learning explicitly how these algorithms work.

In [1]:
import jax
import numpy as np
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax import jit, vmap, grad, pmap,value_and_grad

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader 
import numpy as np

# Convolutional Layers
Convolutional networks are special because of their convolutional layers, which slide kernels around images and "aggregate" data from adjacent pixels. We will implement this from scratch in the function below. Note that this is absolutely disgustingly slow because of the implementation; in reality there exist smarter stride tricks which I've yet to fully understand. However, the function below captures the core mechanics of the CNN model.

Besides just the convolutional layer we also have a pooling layer that reduces the dimensionality of the data.

Also keep in mind that this uses Python lists in the convolutional layer. The reason is that JAX is purely functional: its arrays are immutable and cannot have values assigned in place. So to store the data I can't update entries like you normally would in a NumPy array. Yes, I am aware this is quite disgusting, but at least it works...

In [2]:
@jit
def conv2d(image, filter):
  """Valid (no padding, stride 1) 2-D cross-correlation of a single image.

  Args:
    image: (H, W, Cin) array. A bare (H, W) image is treated as Cin = 1.
    filter: either (kh, kw, Cout), i.e. one 2-D kernel per output channel,
      applied to every input channel with the results summed over channels;
      or (kh, kw, Cin, Cout), a separate kernel per input channel.

  Returns:
    (H - kh + 1, W - kw + 1, Cout) array with
    out[i, j, o] = sum_{a, b, c} image[i + a, j + b, c] * k[a, b, c, o].
  """
  if image.ndim == 2:
    # Allow a bare (H, W) image by giving it a single input channel.
    image = image[:, :, None]
  kh, kw = filter.shape[0], filter.shape[1]
  kernel = filter
  if kernel.ndim == 3:
    # A (kh, kw, Cin) window times a bare (kh, kw) kernel slice broadcasts
    # along the wrong axes: kernel rows line up with window columns. So give
    # the kernel an explicit input-channel axis instead. Every input channel
    # is weighted by the same 2-D kernel and the results are summed.
    kernel = jnp.broadcast_to(kernel[:, :, None, :],
                              (kh, kw, image.shape[2], kernel.shape[2]))

  # Height and width of the output image.
  Hout = image.shape[0] - kh + 1
  Wout = image.shape[1] - kw + 1
  if Hout <= 0 or Wout <= 0:
    # Kernel larger than the image: nothing to slide over.
    return jnp.zeros((max(Hout, 0), max(Wout, 0), kernel.shape[-1]),
                     dtype=image.dtype)

  # JAX arrays are immutable, so collect the output pixels in Python lists
  # and stack them at the end instead of assigning into a preallocated array.
  rows = []
  for i in range(Hout):
    row = []
    for j in range(Wout):
      window = image[i:i + kh, j:j + kw, :]
      # Contract the (kh, kw, Cin) window with the first three kernel axes.
      # This yields all Cout output channels for this pixel at once.
      row.append(jnp.tensordot(window, kernel, axes=3))
    rows.append(jnp.stack(row))
  return jnp.stack(rows)

# @jax.jit
# def conv2d(image, filter):
#   # Height and width of output image
#   Hout = image.shape[0] - filter.shape[0] + 1
#   Wout = image.shape[1] - filter.shape[1] + 1
#   print([Hout, Wout, filter.shape[2]])
#   output = np.zeros([Hout, Wout, filter.shape[2]])
  
#   # loops through the h index
#   for i in range(Hout):
#     # loops through the w index
#     rows = []
#     for j in range(Wout):
#       # loops through the depth of the filters
#       depth  = []
#       for cout in range(filter.shape[2]):
#         print( jnp.multiply(image[ i:i+filter.shape[0], j:j+filter.shape[1], :], filter[:,:,cout]).sum())
#         output[i,j,cout] = jnp.multiply(image[ i:i+filter.shape[0], j:j+filter.shape[1], :], filter[:,:,cout]).sum().astype(float)
#   return jnp.array(output)

def pooling(mat, ksize, method='max', pad=False):
    """Non-overlapping pooling over the first two axes of ``mat``.

    Args:
      mat: array of shape (H, W, ...); any trailing axes are kept as-is.
      ksize: (ky, kx) pooling window, also used as the stride. Must be a
        hashable tuple because it is a static argument under ``jit``.
      method: 'max' for max pooling; any other value gives mean pooling.
      pad: if True, H and W are padded up to a multiple of the window with
        NaN so partial edge windows are pooled. NaNs are ignored by
        nanmax/nanmean. If False, trailing rows/columns that do not fill a
        whole window are dropped.

    Returns:
      Array of shape (ny, nx, ...), where ny/nx are the number of windows.
    """
    m, n = mat.shape[:2]
    ky, kx = ksize

    if pad:
        # Ceiling division on plain Python ints (shapes are static under jit).
        ny = -(-m // ky)
        nx = -(-n // kx)
        size = (ny * ky, nx * kx) + mat.shape[2:]
        # JAX arrays are immutable: write the data in with .at[].set rather
        # than item assignment.
        mat_pad = jnp.full(size, jnp.nan).at[:m, :n, ...].set(mat)
    else:
        ny = m // ky
        nx = n // kx
        mat_pad = mat[:ny * ky, :nx * kx, ...]

    # Split each spatial axis into (window index, offset within window).
    windows = mat_pad.reshape((ny, ky, nx, kx) + mat.shape[2:])

    if method == 'max':
        return jnp.nanmax(windows, axis=(1, 3))
    return jnp.nanmean(windows, axis=(1, 3))


# ksize/method/pad drive Python-level shape arithmetic and branching, so they
# must be static (hashable, known at trace time) rather than traced values.
pooling = jit(pooling, static_argnames=('ksize', 'method', 'pad'))

# Initialize Model
Now that we have some idea of what this special layer is, we can start initializing the weights of the model. We want to randomly sample the filter weights, and then randomly sample the weights and biases of the fully connected neural network. Thus we have the following: 

We create the data structure like this: a dictionary which contains the filters of the convolutional layers, together with the weights and biases of the fully connected layers!

However in order to string these layers together we need to make sure the shapes match!! Thus we need to do the following:

In [3]:
def calc_output_pooling_shape(image, filter):
  """Shape after a 'valid' convolution followed by a 2x2 pooling.

  image: input shape (H, W, ...); only H and W are read.
  filter: kernel shape; filter[0], filter[1] are the kernel height/width and
    filter[2] becomes the depth of the result.
  Returns (conv_h // 2, conv_w // 2, filter[2]) with conv_h = H - kh + 1 and
  conv_w = W - kw + 1.
  """
  kernel_h, kernel_w, depth = filter[0], filter[1], filter[2]
  conv_h, conv_w = image[0] - kernel_h + 1, image[1] - kernel_w + 1
  return (conv_h // 2, conv_w // 2, depth)


def init_conv_model(filters, layer_widths, img_shape, parent_key, scale =0.01):
  """Randomly initialise the CNN parameters.

  Args:
    filters: list of conv kernel shapes, e.g. (kh, kw, cout). Each kernel is
      sampled as scale * N(0, 1).
    layer_widths: widths of the fully connected layers. The first dense layer
      maps the flattened conv output to layer_widths[0]. Each later layer maps
      layer_widths[k] to layer_widths[k + 1].
    img_shape: input image shape (H, W, ...), used to work out the size of the
      flattened conv output via calc_output_pooling_shape.
    parent_key: JAX PRNG key. It is split; no derived key is used twice.
    scale: multiplier (standard deviation) for every sampled weight and bias.

  Returns:
    {'conv_weights': [kernel, ...],
     'full_connected_weights': [[W, b], ...]} with W of shape (out, in).
  """
  # Split once into independent streams for the conv and dense parts, so no
  # key is consumed twice. Reusing a JAX PRNG key gives correlated samples.
  conv_key, dense_key = jax.random.split(parent_key)

  # This first part is the convolutional layers.
  conv_layers = []
  image_shape = img_shape
  for curr_filter, kernel_key in zip(filters, jax.random.split(conv_key, num=len(filters))):
    conv_layers.append(scale * jax.random.normal(kernel_key, shape=curr_filter))
    image_shape = calc_output_pooling_shape(image_shape, curr_filter)

  # Size of the flattened feature map fed to the first dense layer. With no
  # conv layers this is simply the number of input values.
  flatten_dimension = int(np.prod(image_shape))

  # One key per dense layer: the projection from the flattened features,
  # plus the len(layer_widths) - 1 layers between consecutive widths.
  dense_keys = jax.random.split(dense_key, num=len(layer_widths))
  in_widths = [flatten_dimension] + list(layer_widths[:-1])
  fully_connected = []
  for in_width, out_width, key in zip(in_widths, layer_widths, dense_keys):
    weight_key, bias_key = jax.random.split(key)
    fully_connected.append(
        [scale * jax.random.normal(weight_key, shape=(out_width, in_width)),
         scale * jax.random.normal(bias_key, shape=(out_width,))]
    )

  return {'conv_weights': conv_layers, 'full_connected_weights': fully_connected}

filters = [(3,3,16), (3,3,64)]
layers = [784, 512, 256, 10]
rng = jax.random.PRNGKey(seed=0)


convolutional_model_weights = init_conv_model(filters, layers, (28,28), rng, scale =0.01)
# Print the shape of every parameter array. jax.tree_map is deprecated (and
# removed in newer JAX); jax.tree_util.tree_map is the stable spelling.
print(jax.tree_util.tree_map(lambda x: x.shape, convolutional_model_weights))

{'conv_weights': [(3, 3, 16), (3, 3, 64)], 'full_connected_weights': [[(784, 1600), (784,)], [(512, 784), (512,)], [(256, 512), (256,)], [(10, 256), (10,)]]}


# Feedforward Neural Network

Now we need to make the neural network "alive" by implementing each of the feed forward processes with the initialized model weights that we have. We get the following:

In [None]:

def feedforward(params, img):
  """Run the CNN on a single image and return log-probabilities.

  Each conv kernel in params['conv_weights'] is applied with conv2d. This is
  followed by 2x2 max pooling (pooling with pad=False) and a ReLU. The
  feature map is then flattened and passed through
  params['full_connected_weights']. Every dense layer except the last is
  followed by a ReLU. The final logits are normalised with logsumexp, i.e.
  the output is a log-softmax.
  """
  kernels = params['conv_weights']
  dense = params['full_connected_weights']

  features = img
  for kernel in kernels:
    pooled = pooling(conv2d(features, kernel), (2, 2), method='max', pad=False)
    features = jax.nn.relu(pooled)

  # Flatten the final feature map into a single vector.
  activation = jnp.ravel(features)
  for weight, bias in dense[:-1]:
    activation = jax.nn.relu(jnp.dot(weight, activation) + bias)

  out_weight, out_bias = dense[-1]
  logits = jnp.dot(out_weight, activation) + out_bias
  return logits - logsumexp(logits)

 
filters, layers = [(3,3,3), (3,3,16)], [784, 512, 256, 10]
rng = jax.random.PRNGKey(seed=0)
convolutional_model_weights = init_conv_model(filters, layers, (28,28), rng, scale=0.01)

# vmap turns the single-image feedforward into a batched one: in_axes=(None, 0)
# shares the same params across the batch and maps over axis 0 of the images.
batched_cnn_predict = vmap(feedforward, in_axes=(None, 0))

# Smoke test on a random batch of 16 single-channel 28x28 "images".
dummy_imgs_flat = np.random.randn(16, 28,28,1)
predictions = batched_cnn_predict(convolutional_model_weights, dummy_imgs_flat)

# Loss Function
We once again use the categorical cross entropy loss which is the following: 

In [None]:
def loss_fn(params, x, y):
    """Multi-class cross-entropy, summed over the batch.

    params: model parameters as built by init_conv_model.
    x: batch of images stacked along axis 0.
    y: one-hot targets with the same shape as the predictions.

    batched_cnn_predict returns log-probabilities (feedforward ends with a
    log-softmax), so -sum(log_probs * one_hot) is the cross-entropy.
    """
    log_probs = batched_cnn_predict(params, x)
    # Use jnp, not np: under value_and_grad these values are JAX tracers.
    return -jnp.sum(log_probs * y)

# def loss_fn(params, x, y):
#   predictions = batched_cnn_predict(params, imgs)
#   return jnp.mean((predictions - y) ** 2)

# Data Loading
This is once again some uninteresting data loading and preprocessing. Each image is scaled down to lie between 0 and 1 (by dividing by its largest pixel value) instead of the raw 0–255 grayscale range.
# Preprocess Data
We need to massage the data a bit: renormalize the pixel values and reshape each image to the shape we desire (height × width × 1 channel).

In [None]:
# data loading
def custom_transform(x):
  """Scale an image into [0, 1] by its own maximum and add a channel axis.

  x: a 2-D image; presumably a PIL image handed over by torchvision's MNIST
     (anything np.asarray accepts works).
  Returns a float32 array of shape (H, W, 1). An all-zero image is returned
  unscaled (zeros) instead of producing NaNs from 0 / 0.
  """
  # Convert explicitly rather than relying on PIL-image / NumPy-scalar
  # operator fallbacks.
  pixels = np.asarray(x, dtype=np.float32)
  peak = pixels.max()
  if peak > 0:
    pixels = pixels / peak
  return np.expand_dims(pixels, axis=2)
# Both splits share one download directory; only the `train` flag differs.
train_dataset = MNIST(
    root='train_mnist',
    train=True,
    transform=custom_transform,
    download=True,
)
test_dataset = MNIST(
    root='train_mnist',
    train=False,
    transform=custom_transform,
    download=True,
)

In [None]:
def custom_collate_fn(batch):
  """Collate a list of (image, label, ...) samples for DataLoader.

  Returns (images, labels): the images stacked with np.stack along a new
  leading axis, and the labels as a NumPy array.
  """
  # Regroup per-sample tuples into per-field columns.
  columns = [*zip(*batch)]
  label_column = columns[1]
  image_column = columns[0]
  return np.stack(image_column), np.array(label_column)


train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    collate_fn=custom_collate_fn,
)
# Pull one batch to sanity-check the collated shapes.
batch_data = next(iter(train_loader))
imgs, labels = batch_data
print(imgs.shape)

# Training
Now we get to the fun part of training this, this will be the exact same procedure as the MLP training process! The script is slow af and so I only ran just to see that the loss is decreasing. The concept is correct, and in the real world nobody actually does this anymore lmao

In [None]:
num_epochs = 10

def update(params, imgs, gt_lbls, lr = 0.01):
  """One plain gradient-descent step on loss_fn.

  Returns (loss, new_params), where new_params = params - lr * grad is applied
  leaf by leaf across the params pytree.
  """
  loss, grads = value_and_grad(loss_fn)(params, imgs, gt_lbls)
  # jax.tree_multimap was removed from JAX. jax.tree_util.tree_map accepts
  # several trees and maps over their leaves in lockstep.
  return loss, jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)



convolutional_model_weights = init_conv_model(filters, layers, (28,28), rng, scale =0.1)


for epochs in range(num_epochs):
  for count, (imgs, lbls) in enumerate(train_loader):
    gt_labels = jax.nn.one_hot(lbls, len(MNIST.classes))
    # Feed the updated parameters back into the next step. Assigning them to
    # a different name would restart every step from the initial weights.
    loss, convolutional_model_weights = update(convolutional_model_weights, imgs, gt_labels, lr=0.1)
    print(loss)
  # Deliberately stop after the first epoch: only enough training is run to
  # check that the loss decreases.
  break

# Keep the old result name bound to the trained parameters.
MLP_params = convolutional_model_weights