Step 1: Set up the environment for reproducibility

In [21]:
import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np
import random
import os
from dotenv import load_dotenv

# Cargamos las variables de entorno
load_dotenv(dotenv_path='./../variables.env')

# set the seed
seed = int(os.getenv('SEED'))
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

Step 2: Compute SynFlow scores

In [22]:
def linearize(model):
  """ 
    Get the absolute value of the weights of the model and return the signs of the weights 
  """
  signs = []
  for layer in model.layers:
    if len(layer.get_weights()) > 0:
      weights = layer.get_weights()
      layer_signs = [np.sign(w) for w in weights]
      new_weights = [np.abs(w) for w in weights]
      layer.set_weights(new_weights)
      signs.append(layer_signs)
  return signs

In [23]:
def compute_synflow_per_weight(model, input_shape):
  signs = linearize(model) # Get the signs of the weights and transform the absolute value of the weights

  inputs = np.ones((1,)+input_shape)
  with tf.GradientTape() as tape:
    trainable_variables = [tf.Variable(v) for v in model.trainable_variables]
    tape.watch(trainable_variables)
    outputs = model(inputs)
    loss = tf.reduce_sum(outputs)
  grads = tape.gradient(loss, model.trainable_variables)

  # Apply gradient clipping to avoid numerical instability
  grads = [tf.clip_by_value(grad, -1.0, 1.0) for grad in grads]

  # Diagnostic prints for gradients and weights
  for i, g in enumerate(grads):
    print(f"Gradient of layer {i}: {g.numpy()}")
    if np.isnan(g.numpy()).any():
      print(f"NaN detected in gradient of layer {i}: {g.numpy()}")
      # Pause and wait for user input
      print("Model summary:")
      model.summary()
      input("NaN detected. Press enter to continue...")
      return None
    
  for i, w in enumerate(model.trainable_variables):
    print(f"Weights of layer {i}: {w.numpy()}")
    if np.isnan(w.numpy()).any():
      print(f"NaN detected in weights of layer {i}: {w.numpy()}")
      # Pause and wait for user input
      print("Model summary:")
      model.summary()
      for i, g in enumerate(grads):
        print(f"Gradient of layer {i}: {g.numpy()}")
      input("NaN detected. Press enter to continue...")
      return None
      
    synflow_scores = [np.abs(w.numpy() * g.numpy()) for w, g in zip(model.trainable_variables, grads)]

    return synflow_scores

In [24]:
# Example model with ones inizialitation
conv_args = {
  "activation": "relu",
  "padding": "same",
  "kernel_initializer": "ones",
  "bias_initializer": "ones"
}

model = tf.keras.Sequential([
  tf.keras.layers.Conv2D(32, kernel_size=(1, 1), input_shape=(64, 64, 1),  **conv_args),
  tf.keras.layers.Conv2D(64, (3, 3), **conv_args),
  tf.keras.layers.Conv2D(64, (3, 3), **conv_args),
  tf.keras.layers.Conv2D(64, (3, 3), **conv_args),
  tf.keras.layers.Conv2D(64, (3, 3), **conv_args),
  tf.keras.layers.Conv2D(64, (3, 3), **conv_args)
])

# initialize model weights
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.build()

# Define the input shape
input_shape = (64, 64, 1)

# Compute synflow scores
synflow_scores = compute_synflow_per_weight(model, input_shape)

# Print the synflow scores
for i, score in enumerate(synflow_scores):
  print(f"Synflow score of layer {i}: {score}")

# Print the synflow scores summed over all weights for each layer
total_synflow_scores = [np.sum(score) for score in synflow_scores]
for i, score_sum in enumerate(total_synflow_scores):
  print(f"Total synflow score of layer {i}: {score_sum}")

# Print the total synflow score
total_synflow_score = np.sum(total_synflow_scores)
print(f"Total synflow score: {total_synflow_score}")

Gradient of layer 0: [[[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
    1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]]
Gradient of layer 1: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1.]
Gradient of layer 2: [[[[1. 1. 1. ... 1. 1. 1.]
   [1. 1. 1. ... 1. 1. 1.]
   [1. 1. 1. ... 1. 1. 1.]
   ...
   [1. 1. 1. ... 1. 1. 1.]
   [1. 1. 1. ... 1. 1. 1.]
   [1. 1. 1. ... 1. 1. 1.]]

  [[1. 1. 1. ... 1. 1. 1.]
   [1. 1. 1. ... 1. 1. 1.]
   [1. 1. 1. ... 1. 1. 1.]
   ...
   [1. 1. 1. ... 1. 1. 1.]
   [1. 1. 1. ... 1. 1. 1.]
   [1. 1. 1. ... 1. 1. 1.]]

  [[1. 1. 1. ... 1. 1. 1.]
   [1. 1. 1. ... 1. 1. 1.]
   [1. 1. 1. ... 1. 1. 1.]
   ...
   [1. 1. 1. ... 1. 1. 1.]
   [1. 1. 1. ... 1. 1. 1.]
   [1. 1. 1. ... 1. 1. 1.]]]


 [[[1. 1. 1. ... 1. 1. 1.]
   [1. 1. 1. ... 1. 1. 1.]
   [1. 1. 1. ... 1. 1. 1.]
   ...
   [1. 1. 1. ... 1. 1. 1.]
   [1. 1. 1. ... 1. 1. 1.]
   [1. 1. 1. ... 1. 1. 1.]]

  [[1. 1. 1. ... 1. 1. 1.]
   [1. 1. 1. ... 