### Adapted from the RigL paper's parameter/FLOPs counting notebook:
https://github.com/google-research/rigl/blob/master/rigl/imagenet_resnet/colabs/Resnet_50_Param_Flops_Counting.ipynb

In [None]:
%%bash
# Download the official Keras ResNet50 implementation (whose model builder is used
# for the counting), the RigL library, and the google-research repo (for
# micronet_challenge.counting), then install aim.
# `%%bash` must be the first line of the cell for IPython to run it as a cell magic.
# Each fetch is grouped with { ...; } so the `mv` only runs together with its clone
# (`a || b && c` would run `c` even when `a` succeeds), making the cell safe to re-run.
test -d resnet50_keras || { git clone https://github.com/tensorflow/tpu tpu && mv tpu/models/experimental/resnet50_keras ./; }
test -d rigl || { git clone https://github.com/google-research/rigl rigl_repo && mv rigl_repo/rigl ./; }
test -d google_research || git clone https://github.com/google-research/google-research google_research
pip install aim==3.17

In [None]:
!tar xf /content/imagenet_exps.tar.xz

In [None]:
import numpy as np
import tensorflow as tf
from google_research.micronet_challenge import counting
from resnet50_keras import resnet_model as resnet_keras
from rigl import sparse_utils
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


In [None]:
# Start from an empty TF graph, then build the Keras ResNet-50 with a 1000-way classifier.
NUM_CLASSES = 1000
tf.compat.v1.reset_default_graph()
model = resnet_keras.ResNet50(NUM_CLASSES)

In [None]:
# Keep only the Conv2D and Dense layers of the model; this list is what the
# FLOP/size counting cells below pass to get_stats.
_COUNTED_LAYER_TYPES = (tf.keras.layers.Conv2D, tf.keras.layers.Dense)
masked_layers = [layer for layer in model.layers
                 if isinstance(layer, _COUNTED_LAYER_TYPES)]


In [None]:
import functools

PARAM_SIZE = 32  # bits per parameter; forwarded to get_stats as param_size
# Pre-bind the arguments shared by every call in this notebook (first/last layer
# names and the parameter width) so callers only supply sparsity settings.
get_stats = functools.partial(
    sparse_utils.get_stats,
    param_size=PARAM_SIZE,
    first_layer_name='conv1',
    last_layer_name='fc1000',
)
def print_stats(masked_layers, default_sparsity=0.8, method='erdos_renyi',
                custom_sparsities=None, is_debug=False, width=1., **kwargs):
  """Print FLOPs, model size and realised sparsity for `masked_layers`.

  Thin wrapper around the notebook-level `get_stats` partial: every argument
  is forwarded to it unchanged.

  Args:
    masked_layers: the Conv2D/Dense layers to count.
    default_sparsity: sparsity level forwarded to get_stats.
    method: sparsity distribution name forwarded to get_stats.
    custom_sparsities: per-layer overrides keyed by variable name
      (e.g. 'conv1/kernel:0'). None means no overrides.
    is_debug, width, **kwargs: forwarded to get_stats.
  """
  if custom_sparsities is None:
    # A fresh dict per call: a shared `{}` default would leak any mutation
    # made downstream into every later call.
    custom_sparsities = {}
  print('Method: %s, Sparsity: %f' % (method, default_sparsity))
  total_flops, total_param_bits, sparsity = get_stats(
      masked_layers, default_sparsity=default_sparsity, method=method,
      custom_sparsities=custom_sparsities, is_debug=is_debug, width=width, **kwargs)
  print('Total Flops: %.3f MFlops' % (total_flops/1e6))
  print('Total Size: %.3f Mbytes' % (total_param_bits/8e6))  # bits -> megabytes
  print('Real Sparsity: %.3f' % (sparsity))

## Loading sparsity from aim

In [None]:
# 50 layers (conv1, 48 bottleneck convolutions, fc1000) + 4 projection shortcuts
# ("branch1", one in the first block of each stage, for the skip connection): 54 kernels.
def _resnet50_kernel_names():
  """Return the ResNet-50 kernel variable names in the notebook's canonical order.

  Order: conv1, then for every bottleneck block branch2a/2b/2c (the first block
  of each stage additionally lists its branch1 projection right after branch2c),
  then the fc1000 classifier.
  """
  names = ['conv1/kernel:0']
  blocks_per_stage = (('2', 'abc'), ('3', 'abcd'), ('4', 'abcdef'), ('5', 'abc'))
  for stage, block_ids in blocks_per_stage:
    for block_id in block_ids:
      prefix = f'res{stage}{block_id}'
      names.extend(f'{prefix}_branch{branch}/kernel:0' for branch in ('2a', '2b', '2c'))
      if block_id == 'a':
        names.append(f'{prefix}_branch1/kernel:0')
  names.append('fc1000/kernel:0')
  return names


resnet_layers = _resnet50_kernel_names()

print(len(resnet_layers))

In [None]:
# Pull the per-layer live-neuron counts (whole training dataset) logged by one aim run.
import aim
import pandas as pd
from IPython.display import display

AIM_REPO_PATH = "/content/imagenet_exps"
run_hash = '9eaf384e01d64dac80a7d9f6'
query = (
    "'Live neurons in layer' in metric.name"
    " and '; whole training dataset' in metric.name"
    f" and run.hash=='{run_hash}'"
)
print(query)
df = aim.Repo(AIM_REPO_PATH).query_metrics(query).dataframe()
df.head(10)

In [None]:
# Add a per-row "sparsity" column: the fraction of a layer's neurons that are no
# longer live at that step, relative to the live count logged at step 0.
# For conv layers this neuron ratio is used as the layer's parameter sparsity.
# NOTE(review): with structured pruning, removing a layer's output channels also
# removes input channels of the next layer, so the true parameter density may be
# lower than this ratio suggests. Confirm the approximation is intended.
initial_live_neurons = (
    df.loc[df['step'] == 0]
    .drop_duplicates('metric.name')  # first step-0 row per layer, like .iloc[0]
    .set_index('metric.name')['value']
)
layers_without_step0 = sorted(set(df['metric.name']) - set(initial_live_neurons.index))
if layers_without_step0:
  raise ValueError(f'No step-0 neuron count logged for: {layers_without_step0}')
# Computed directly as floats. Seeding the column with int 0 and then writing floats
# through .loc triggers pandas' incompatible-dtype (upcast) deprecation.
df['sparsity'] = 1.0 - df['value'] / df['metric.name'].map(initial_live_neurons)

# Map every ResNet-50 kernel name to its logged sparsity at a given step.
def retrieve_sparsity(step, frame=None, layer_names=None):
  """Return {kernel name: sparsity} built from the aim metrics logged at `step`.

  The aim metric index skips projection shortcuts. Each name containing
  "branch1" is therefore mapped to the same index as the layer listed just
  before it (its block's branch2c). The last name (the fully connected layer)
  always gets sparsity 0, because it is never pruned with structured pruning.

  Args:
    step: training step to look up; it must have been logged in the aim run.
    frame: metrics DataFrame with 'metric.name', 'step' and 'sparsity'
      columns. Defaults to the notebook-level `df`.
    layer_names: ordered kernel names, classifier last. Defaults to the
      notebook-level `resnet_layers`.

  Raises:
    KeyError: if a required layer metric was not logged at `step`.
  """
  if frame is None:
    frame = df
  if layer_names is None:
    layer_names = resnet_layers
  # Filter the frame once per call instead of once per layer; keep the first
  # row per metric name, matching the original `.iloc[0]` lookup.
  at_step = frame.loc[frame['step'] == step].drop_duplicates('metric.name')
  sparsity_by_metric = at_step.set_index('metric.name')['sparsity']

  sparsities = {}
  projections_seen = 0  # offset between position in layer_names and aim index
  for i, layer_name in enumerate(layer_names[:-1]):
    if "branch1" in layer_name:
      projections_seen += 1  # projection reuses the preceding layer's index
    metric = f"Live neurons in layer {i - projections_seen}; whole training dataset"
    if metric not in sparsity_by_metric.index:
      raise KeyError(f'{metric!r} was not logged at step {step}')
    sparsities[layer_name] = sparsity_by_metric[metric]
  if layer_names:
    sparsities[layer_names[-1]] = 0  # fully connected layer is never pruned
  return sparsities

# Sanity check: per-layer sparsity at step 0 and at step 500000.
for probe_step in (0, 500000):
  print(retrieve_sparsity(probe_step))

# Pruning FLOPs
We calculate theoretical FLOPs for pruning: sparse FLOPs are counted from the moment pruning starts, using the per-layer sparsity logged at each pruning step.

In [None]:
# Training and inference FLOPs for the aim run selected by the hash above.
# Sparsity is treated as piecewise constant: every interval [step, step + pruning_freq)
# is charged at the per-layer sparsity logged at `step`. One training step costs
# 3x the forward FLOPs per example (the usual forward + backward estimate, same
# factor as the RigL cell below).
pruning_freq = 5000
batch_size = 256
total_steps = 500456

training_flops = 0
seq_flops = []
for step in range(0, total_steps, pruning_freq):
  step_sparsities = retrieve_sparsity(step)
  c_flops, _, _ = get_stats(
      masked_layers, default_sparsity=0.0, method='random', custom_sparsities=step_sparsities)
  seq_flops.append(c_flops)
  # The last interval is truncated at total_steps (456 steps with the values above).
  # Charging it here, rather than as a separate tail after the loop, avoids counting
  # an extra full interval when total_steps is a multiple of pruning_freq.
  interval_steps = min(pruning_freq, total_steps - step)
  training_flops += c_flops * 3 * interval_steps * batch_size
print(f"training flops:{training_flops}")
print(f"inference flops:{c_flops}")  # FLOPs at the final logged sparsity

In [None]:
# From the RigL notebook: FLOPs along a cubic gradual-pruning schedule. The first
# entry (step 0, before p_start, i.e. fully dense) is reported as the dense
# training and inference cost.
IMAGENET_TRAIN_EXAMPLES = 1281167
TRAINING_EPOCHS = 100

p_start, p_end, p_freq = 10000, 25000, 1000
target_sparsity = 0.8


def cubic_sparsity(step):
  """Sparsity at `step`: 0 before p_start, target_sparsity after p_end, cubic ramp in between."""
  if step < p_start:
    return 0.
  if step > p_end:
    return target_sparsity
  progress = (step - p_start) / float(p_end - p_start)
  return (1 - (1 - progress)**3) * target_sparsity


total_flops = []
for i in range(0, 32001, 1000):
  sparsity = cubic_sparsity(i)
  c_flops, _, _ = get_stats(
      masked_layers, default_sparsity=sparsity, method='random',
      custom_sparsities={'conv1/kernel:0': 0, 'fc1000/kernel:0': 0.0})
  total_flops.append(c_flops)
avg_flops = sum(total_flops) / len(total_flops)
dense_flops = total_flops[0]
print('Average Flops: ', avg_flops, avg_flops / dense_flops)
print('Training Flops: ', dense_flops * 3 * IMAGENET_TRAIN_EXAMPLES * TRAINING_EPOCHS)
print('Inference Flops: ', dense_flops)
print('Inference Flops: ', total_flops[0])