In [26]:
# SPP Layer Implementation in JAX-FLAX
# Author: Goktug Guvercin

!pip install flax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [27]:
import jax
import flax
import jax.numpy as jnp

from math import floor

In [28]:
def spatial_pyramid_pool(feature_maps, pool_levels, data_format="cl", verbose=False):
  """Max-pool one stack of feature maps at several pyramid levels and flatten.

  For every ``level`` in ``pool_levels`` the maps are max-pooled with a window
  of ``floor(height / level) x floor(width / level)`` and a stride equal to
  that window (non-overlapping windows). The pooled maps of each level are
  flattened and the levels are concatenated, in order, into one 1-D vector.

  Args:
    feature_maps: a single (unbatched) stack of feature maps, expected to be
      3-D: (height, width, channels) when ``data_format == "cl"``, otherwise
      (channels, height, width).
    pool_levels: sequence of positive numbers; each one is the factor by which
      the map height and width are divided to obtain the pooling window.
    data_format: ``"cl"`` for channel-last input; any other value is treated
      as channel-first.
    verbose: when True, print the pooled shape and the feature count of every
      level.

  Returns:
    A 1-D array holding the concatenated features of all levels, or ``None``
    when any level is <= 0 or larger than the map height or width.
    Features are always flattened from the channel-last pooled maps, whatever
    the input layout was.

  Note:
    Because the window size is floored, a map whose height or width is not
    divisible by a level can yield more than ``level x level`` bins (a 10x10
    map at level 4 gives 5x5), so the output length is not guaranteed to be
    the same for inputs of different spatial size.
  """

  if data_format == "cl": # channel last
    height, width = feature_maps.shape[0:2]
  else: # channel first
    height, width = feature_maps.shape[1:3]
    # flax's max_pool slides its window over the leading axes and treats the
    # last axis as features, so move the maps to (height, width, channels)
    # before pooling; otherwise the window would run over (channels, height).
    feature_maps = jnp.transpose(feature_maps, (1, 2, 0))

  # pool levels cannot be greater than map dimensions
  # pool levels cannot be negative value or zero
  if any(level <= 0 or level > height or level > width for level in pool_levels):
    return None

  features = []
  for level in pool_levels:

    filter_height = floor(height / level)
    filter_width = floor(width / level)

    window_shape = (filter_height, filter_width)
    # Non-overlapping windows: the stride along each axis equals the window
    # extent along that same axis (height stride must not reuse the width).
    strides = (filter_height, filter_width)

    pooled_maps = flax.linen.max_pool(feature_maps, window_shape, strides)
    # The shape is static, so flatten directly instead of computing the
    # element count on device first.
    feature_vector = pooled_maps.reshape(-1)
    features.append(feature_vector)

    if verbose:
      print("Pool Level: ", level)
      print("Shape of pooled maps: ", pooled_maps.shape)
      print("Number of features in pooled maps: ", feature_vector.shape)
      print()

  return jnp.concatenate(features)

"""
* The SPP layer first performs max pooling with a kernel whose size is the same
as the size of the feature maps extracted by the convolutional backbone. It then
repeats the same operation, halving the kernel size at each level.

Level 0: Kernel Shape = Map Shape
Level 1: Kernel Shape = Map Shape / 2
Level 2: Kernel Shape = Map Shape / 4

pool_levels gives, for each level, the factor by which the map shape is divided
to obtain the kernel shape.

* One disadvantage of the SPP layer is that if the shape of your feature maps is
not divisible by the spatial-pyramid pooling levels, rounding needs to be
performed. In this case, the SPP layer cannot guarantee a fixed number of
features for images of two different sizes.

When verbose is activated, we see that the numbers of features at pool levels 1
and 2 are equal for the two feature maps of different sizes (512 and 2048).
However, they are not equal at pool level 4, because 10x10 is not divisible by 4
while 16x16 is.
"""

# Demo: pool two feature-map stacks of different spatial size with the same
# pyramid levels and compare the per-level feature counts.
key = jax.random.PRNGKey(seed=37)
# JAX PRNG keys must not be reused across draws; derive one sub-key per array.
key_small, key_large = jax.random.split(key)
feature_maps = jax.random.normal(key_small, (10, 10, 512))
feature_maps2 = jax.random.normal(key_large, (16, 16, 512))
pool_levels = [1, 2, 4]


features = spatial_pyramid_pool(feature_maps, pool_levels, verbose=True)
print()
features2 = spatial_pyramid_pool(feature_maps2, pool_levels, verbose=True)





Pool Level:  1
Shape of pooled maps:  (1, 1, 512)
Number of features in pooled maps:  (512,)

Pool Level:  2
Shape of pooled maps:  (2, 2, 512)
Number of features in pooled maps:  (2048,)

Pool Level:  4
Shape of pooled maps:  (5, 5, 512)
Number of features in pooled maps:  (12800,)


Pool Level:  1
Shape of pooled maps:  (1, 1, 512)
Number of features in pooled maps:  (512,)

Pool Level:  2
Shape of pooled maps:  (2, 2, 512)
Number of features in pooled maps:  (2048,)

Pool Level:  4
Shape of pooled maps:  (4, 4, 512)
Number of features in pooled maps:  (8192,)

