In [1]:
import math

In [2]:
def calculate_output_length(length_in, kernel_size, stride=1, padding=0, dilation=1):
  """Return the output length of a 1-D convolution over a sequence.

  Evaluates
      (length_in + 2*padding - dilation*(kernel_size - 1) - 1) // stride + 1
  with Python floor division. This is the same expression as the L_out formula
  in the PyTorch ``Conv1d`` documentation.

  Args:
    length_in: length of the input sequence.
    kernel_size: convolution kernel width.
    stride: step between kernel applications (default 1).
    padding: amount of padding added to each end (default 0).
    dilation: spacing between kernel taps (default 1).

  Returns:
    The output sequence length (an int for int inputs).
  """
  # Input span left after the (dilated) kernel's reach is subtracted.
  usable_span = length_in + 2 * padding - dilation * (kernel_size - 1) - 1
  return usable_span // stride + 1

In [3]:
def calculate_pool_length(w, f, s, padding=0, dilation=1):
  """Return the output length of a 1-D pooling window slid over a sequence.

  Uses the output-length formula documented for PyTorch ``MaxPool1d``
  (``ceil_mode=False``):
      floor((w + 2*padding - dilation*(f - 1) - 1) / s) + 1
  With the defaults this reduces to floor((w - f) / s) + 1, which is what the
  previous three-argument version computed.

  Args:
    w: input sequence length.
    f: pooling window size.
    s: pooling stride.
    padding: padding added to each end (default 0).
    dilation: spacing between window elements (default 1).

  Returns:
    The pooled sequence length as an int.
  """
  span = w + 2 * padding - dilation * (f - 1) - 1
  # Floor division keeps the arithmetic exact for ints. The old
  # math.floor((w - f) / s + 1) went through a float and lost precision once
  # lengths exceeded 2**53. math.floor keeps the result an int even when a
  # float length is passed, as the original did.
  return math.floor(span // s) + 1

In [5]:
def calculate_num_params_conv_layer(
  channels_in, kernel_width, kernel_height, channels_out
):
  """Return the number of trainable parameters of a convolution layer.

  Counts one weight per (input channel, kernel position, output channel)
  combination, plus one bias term per output channel.

  Args:
    channels_in: number of input channels.
    kernel_width: kernel extent along the width axis.
    kernel_height: kernel extent along the height axis (1 for a 1-D conv).
    channels_out: number of output channels (filters).

  Returns:
    Total weight count plus bias count.
  """
  weight_count = channels_in * kernel_width * kernel_height * channels_out
  bias_count = channels_out
  return weight_count + bias_count

In [21]:
# Previous 5-layer configuration, kept for reference only:
# kernel_sizes = [64, 32, 16,   8,   4]
# pool_sizes   = [ 8,  8,  2,   2,   4]
# strides      = [ 3,  3,  2,   2,   2]
# pool_strides = [ 8,  8,  2,   4,   2]
# num_filters  = [16, 32, 64, 128,   6]

# Per-layer configuration of a conv -> pool stack; entry i describes layer i.
kernel_sizes = [32, 16,   8,   4]
pool_sizes   = [ 8,  3,   3,   4]
strides      = [ 3,  3,   3,   2]
pool_strides = [ 8,  2,   4,   2]
num_filters  = [32, 64, 128,   6]

NUM_LAYERS = len(kernel_sizes)

# zip() below would silently drop layers if these lists disagreed in length.
assert all(
  len(layer_values) == NUM_LAYERS
  for layer_values in (pool_sizes, strides, pool_strides, num_filters)
), "all layer configuration lists must have NUM_LAYERS entries"


# Input sequence length. Original note: "--> 8 kHz". If that is the sample
# rate, this is 29.75 s of audio. NOTE(review): confirm.
INITIAL_AUDIO_NUM_FRAMES = 238000
compr_out_size = INITIAL_AUDIO_NUM_FRAMES

# Single input channel (presumably a mono waveform -- confirm).
channels_in = 1

tot_num_params = 0

for compr_id, (kernel_size, stride, pool_size, pool_stride, channels_out) in enumerate(
  zip(kernel_sizes, strides, pool_sizes, pool_strides, num_filters)
):

  # Convolution (no padding, no dilation) followed by pooling.
  compr_out_size = calculate_output_length(
    compr_out_size, kernel_size=kernel_size, stride=stride
  )

  compr_out_size = calculate_pool_length(compr_out_size, pool_size, pool_stride)

  # A non-positive length means this configuration cannot be built.
  if compr_out_size < 1:
    raise ValueError(
      f"layer {compr_id} reduces the sequence to length {compr_out_size}; "
      "use smaller kernels/strides or fewer layers"
    )

  # Only the conv layer is counted; the pooling step adds no parameters here.
  # The kernel is 1-D, so kernel_height is 1.
  num_params = calculate_num_params_conv_layer(
    channels_in=channels_in, kernel_width=kernel_size, kernel_height=1,
    channels_out=channels_out
  )

  # This layer's filters become the next layer's input channels.
  channels_in = channels_out

  print(f"compr_out_{compr_id}_out: {compr_out_size}, num_params: {num_params}")

  tot_num_params += num_params

print(f"\n\ntotal number of params: {tot_num_params}")


compr_out_0_out: 9915, num_params: 1056
compr_out_1_out: 1649, num_params: 32832
compr_out_2_out: 137, num_params: 65664
compr_out_3_out: 32, num_params: 3078


total number of params: 102630
