In [None]:
# SPP Layer Implementation in JAX-FLAX
# Author: Goktug Guvercin

!pip install flax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting flax
  Downloading flax-0.5.1-py3-none-any.whl (197 kB)
[K     |████████████████████████████████| 197 kB 4.5 MB/s 
[?25hCollecting rich~=11.1.0
  Downloading rich-11.1.0-py3-none-any.whl (216 kB)
[K     |████████████████████████████████| 216 kB 35.4 MB/s 
Collecting optax
  Downloading optax-0.1.2-py3-none-any.whl (140 kB)
[K     |████████████████████████████████| 140 kB 48.1 MB/s 
Collecting colorama<0.5.0,>=0.4.0
  Downloading colorama-0.4.4-py2.py3-none-any.whl (16 kB)
Collecting commonmark<0.10.0,>=0.9.0
  Downloading commonmark-0.9.1-py2.py3-none-any.whl (51 kB)
[K     |████████████████████████████████| 51 kB 6.3 MB/s 
Collecting chex>=0.0.4
  Downloading chex-0.1.3-py3-none-any.whl (72 kB)
[K     |████████████████████████████████| 72 kB 653 kB/s 
Installing collected packages: commonmark, colorama, chex, rich, optax, flax
Successfully installed chex-0.1.3 colorama-0

In [None]:
import jax
import flax
import jax.numpy as jnp

from math import floor

In [None]:
"""
* The SPP layer first performs max pooling with a kernel whose size is the same
as the size of the feature maps extracted by the convolutional backbone. Then, it
repeats the same operation, halving the kernel size at each level.

Level 0: Kernel Shape = Map Shape
Level 1: Kernel Shape = Map Shape / 2
Level 2: Kernel Shape = Map Shape / 4

pool_levels lists the divisors applied to the map shape to obtain the kernel
shape at each level (e.g. [1, 2, 4] for the three levels above).

"""


def spatial_pyramid_pool(feature_maps, pool_levels, data_format="cl", verbose=False):

  """
  Apply spatial pyramid max pooling to a single set of feature maps and
  return all pooled values concatenated into one flat feature vector.

  For every level in pool_levels the maps are max-pooled with a window of
  shape (floor(height / level), floor(width / level)) and a stride equal to
  that window, so the windows do not overlap.

  Parameters:
  ----------

  * feature_maps: 3-dimensional maps extracted by cnn backbone (no batch axis);
                  laid out as (height, width, channels) for "cl" or
                  (channels, height, width) for "cf"
  * pool_levels: a list of integers; each one is the divisor applied to the
                 spatial size of the maps to obtain the spatial size of the
                 pool filter for that level
  * data_format: a string value (cl: channel-last or cf: channel first)
  * verbose: boolean value;
             True: It prints pool level, shape of pooled maps and its number of features
             False: It prints nothing

  Returns:
  -------

  * 1-dimensional array with the flattened pooled maps of every level, in the
    order of pool_levels. For "cf" input the pooled maps are flattened in
    channel-first layout.
  * None if data_format is not "cl" or "cf", if pool_levels is empty, or if
    any pool level is <= 0 or greater than the map height or width.
  """

  if data_format == "cl": # channel last
    channel_first = False
    height, width = feature_maps.shape[0:2]
  elif data_format == "cf": # channel first
    channel_first = True
    height, width = feature_maps.shape[1:3]
  else:
    return None

  levels = list(pool_levels)

  # nothing to pool: follow the invalid-argument convention instead of letting
  # the final concatenate fail on an empty list
  if not levels:
    return None

  # pool levels cannot be greater than map dimensions
  # pool levels cannot be negative value or zero
  # (shapes and levels are static values, so plain Python comparisons suffice)
  if any(level <= 0 or level > height or level > width for level in levels):
    return None

  # flax.linen.max_pool pools over the middle axes and expects the channel
  # axis last, so channel-first maps are moved to channel-last for pooling
  if channel_first:
    maps_channel_last = jnp.moveaxis(feature_maps, 0, -1)
  else:
    maps_channel_last = feature_maps

  features = []
  for level in levels:

    filter_height = floor(height / level)
    filter_width = floor(width / level)

    window_shape = (filter_height, filter_width)
    # stride equals the window along each axis (non-overlapping windows);
    # the vertical stride must use the filter height, not the filter width
    strides = (filter_height, filter_width)

    pooled_maps = flax.linen.max_pool(maps_channel_last, window_shape, strides)

    # restore the caller's layout so the flattened order matches the input format
    if channel_first:
      pooled_maps = jnp.moveaxis(pooled_maps, -1, 0)

    # -1 lets reshape infer the length from the static shape
    feature_vector = pooled_maps.reshape(-1)
    features.append(feature_vector)

    if verbose:
      print("Pool Level: ", level)
      print("Shape of pooled maps: ", pooled_maps.shape)
      print("Number of features in pooled maps: ", feature_vector.shape)
      print()

  features = jnp.concatenate(features)
  return features

In [None]:
"""
One disadvantage of the SPP layer is that if the shape of your feature maps is
not divisible by the proportions of the spatial-pyramid pooling levels, rounding
needs to be performed. In this case, the SPP layer cannot guarantee a fixed
number of features for images of two different sizes.

When verbose is activated, we see that the numbers of features at levels 1 and 2
are the same (512 and 2048) for the two example feature maps of different sizes
given below. However, we notice that they are not equal for pool level 4,
because 10x10 is not divisible by 4 while 16x16 is.
"""

key = jax.random.PRNGKey(seed=37)
# Split the seed key so each example is sampled from its own random stream;
# reusing one JAX key for several sampling calls yields correlated samples.
key_small, key_large = jax.random.split(key)
feature_maps = jax.random.normal(key_small, (10, 10, 512))
feature_maps2 = jax.random.normal(key_large, (16, 16, 512))
pool_levels = [1, 2, 4]


# Same pool levels on a 10x10 and a 16x16 map; verbose output shows the pooled
# shape and feature count per level for each map.
features = spatial_pyramid_pool(feature_maps, pool_levels, verbose=True)
print()
features2 = spatial_pyramid_pool(feature_maps2, pool_levels, verbose=True)

Pool Level:  1
Shape of pooled maps:  (1, 1, 512)
Number of features in pooled maps:  (512,)

Pool Level:  2
Shape of pooled maps:  (2, 2, 512)
Number of features in pooled maps:  (2048,)

Pool Level:  4
Shape of pooled maps:  (5, 5, 512)
Number of features in pooled maps:  (12800,)


Pool Level:  1
Shape of pooled maps:  (1, 1, 512)
Number of features in pooled maps:  (512,)

Pool Level:  2
Shape of pooled maps:  (2, 2, 512)
Number of features in pooled maps:  (2048,)

Pool Level:  4
Shape of pooled maps:  (4, 4, 512)
Number of features in pooled maps:  (8192,)



In [None]:
"""
spatial_pyramid_pool() is a function that imitates the pyramid pooling operation
at the requested pool levels, but it is applicable to only one set of feature
maps at a time. In other words, it does not work for a batch of maps. To
construct this auto-batching system, we can use vmap() in JAX.

* spatial_pyramid_pool() takes 4 arguments, but we want to vectorize only over
  the feature maps, which is the first argument. Hence, the "in_axes" argument
  of vmap becomes (x, None, None, None), where x is the batch axis of the input
  feature maps; the batch axis is 0, so x is equal to 0.

* The output of spatial_pyramid_pool() is a 1-dimensional vector, so the
  "out_axes" argument of vmap becomes x, where x is the batch axis of the
  produced output; the batch axis is 0, so x is equal to 0.

"""
# Only the first argument (the feature maps) is mapped over its leading batch
# axis; pool_levels, data_format and verbose are shared by every example.
spp_layer = jax.vmap(spatial_pyramid_pool, in_axes=(0, None, None, None), out_axes=0)
# Derive a fresh key for the batch instead of reusing `key` directly, and bind
# the batch to its own name only, so the single-example `feature_maps` defined
# earlier is not silently overwritten.
batch_key = jax.random.fold_in(key, 1)
batched_feature_maps = jax.random.normal(batch_key, (2, 16, 16, 512))
batched_features = spp_layer(batched_feature_maps, pool_levels, "cl", True)
print(batched_features.shape)

Pool Level:  1
Shape of pooled maps:  (32, 1, 16)
Number of features in pooled maps:  (512,)

Pool Level:  2
Shape of pooled maps:  (64, 2, 16)
Number of features in pooled maps:  (2048,)

Pool Level:  4
Shape of pooled maps:  (128, 4, 16)
Number of features in pooled maps:  (8192,)

(2, 10752)
