In [1]:
import importlib.util
import itertools
import os

import numpy as np
import tensorflow as tf

In [2]:
def load_module(path, name=None):
    """Import a Python source file by path and return the executed module.

    Args:
        path: Filesystem path to a Python source file (e.g. ``"arch_baseline.py"``).
        name: Name to give the module (``module.__name__``). Defaults to the
            file's stem, e.g. ``"arch_baseline"`` for ``"arch_baseline.py"``.
            The module is NOT registered in ``sys.modules``.

    Returns:
        The module object after its top-level code has run.

    Raises:
        ImportError: If no import loader handles ``path`` (for example the file
            does not have a Python source suffix).
        Any exception raised while executing the file (e.g. FileNotFoundError
            if it does not exist, or errors in its top-level code).
    """
    if name is None:
        name = os.path.splitext(os.path.basename(path))[0]
    spec = importlib.util.spec_from_file_location(name, path)
    # spec_from_file_location returns None when no loader matches the suffix;
    # without this check the failure surfaces as an opaque AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(
            "cannot load module from {!r}: no loader for this file type".format(path)
        )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

def load_arch(arch_path):
    """Instantiate the ``CNN`` class from ``arch_path`` and build it in a new graph.

    Args:
        arch_path: Path to a Python file defining a ``CNN`` class that has a
            ``create_architecture()`` method.

    Returns:
        ``(graph, net)``: a fresh ``tf.Graph`` in which ``net.create_architecture()``
        was called, and the ``CNN`` instance itself.
    """
    # NOTE(review): CNN() is constructed before (outside) the new graph's
    # context. If its __init__ creates TF ops or variables, they land in the
    # global default graph rather than in `graph`. Confirm against the arch files.
    net = load_module(arch_path).CNN()

    graph = tf.Graph()
    with graph.as_default():
        net.create_architecture()

    return graph, net

In [3]:
# Build each architecture in its own tf.Graph and keep only the graph.
# Indexing [0] instead of unpacking into `_` avoids assigning IPython's `_`
# last-output variable; once user code assigns `_`, IPython stops updating it.
baseline = load_arch("arch_baseline.py")[0]
invariant = load_arch("arch_invariant_a.py")[0]

In [4]:
def count(graph):
    """Return the total number of scalar parameters in ``graph``'s trainable variables.

    Args:
        graph: A ``tf.Graph``; only its ``TRAINABLE_VARIABLES`` collection is read.

    Returns:
        int: Sum over trainable variables of the product of their static shape
        dimensions (0 when the collection is empty).
    """
    variables = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    # np.prod([]) is the float 1.0, so one scalar variable would make the whole
    # total a float; request an integer product and return a plain int.
    return sum(int(np.prod(v.get_shape().as_list(), dtype=np.int64)) for v in variables)

def weights_and_filter(graph, names=('F', 'W')):
    """List selected trainable variables of ``graph`` with their sizes, smallest first.

    A variable is selected when its leaf name is one of ``names``. The leaf name
    is the part after the last ``/`` and before the ``:<index>`` suffix, e.g.
    ``F`` in ``nn/conv-3-16/F:0``. Matching the leaf avoids picking up variables
    whose scope path merely contains an ``F`` or ``W`` character.

    Args:
        graph: A ``tf.Graph``; only its ``TRAINABLE_VARIABLES`` collection is read.
        names: Leaf variable names to keep. The default ``('F', 'W')`` matches
            the ``.../F:0`` and ``.../W:0`` variables in this notebook's output.
            Presumably these are conv filters and FC weights; confirm against
            the arch files.

    Returns:
        list of ``(num_params, shape, name)`` tuples sorted ascending (by
        parameter count, then shape, then full variable name).
    """
    variables = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    entries = []
    for v in variables:
        leaf = v.name.rsplit('/', 1)[-1].split(':', 1)[0]
        if leaf not in names:
            continue
        shape = v.get_shape().as_list()
        # Integer product so scalars give 1, not 1.0.
        entries.append((int(np.prod(shape, dtype=np.int64)), shape, v.name))
    return sorted(entries)

In [5]:
# Total trainable parameter count as (baseline, invariant).
tuple(map(count, (baseline, invariant)))

(15618341, 4913561)

In [6]:
# Side-by-side (baseline, invariant) filter/weight tensors, each side sorted by
# parameter count. zip_longest pads the shorter side with None so that no layer
# is silently dropped when the two architectures have different depths.
list(itertools.zip_longest(weights_and_filter(baseline), weights_and_filter(invariant)))

[((1728, [6, 6, 3, 16], 'nn/conv-3-16/F:0'),
  (192, [4, 4, 3, 4], 'nn/conv-3-8x4/F:0')),
 ((2304, [3, 3, 16, 16], 'nn/conv-16-16/F:0'),
  (1152, [3, 3, 32, 4], 'nn/conv-8x4-8x4/F:0')),
 ((4608, [3, 3, 16, 32], 'nn/conv-16-32/F:0'),
  (4608, [3, 3, 64, 8], 'nn/conv-8x8-8x8/F:0')),
 ((9216, [3, 3, 32, 32], 'nn/conv-32-32/F:0'),
  (6400, [5, 5, 32, 8], 'nn/conv-8x4-8x8/F:0')),
 ((18432, [3, 3, 32, 64], 'nn/conv-32-64/F:0'),
  (9472, [256, 37], 'nn/fc-8x256-37/W:0')),
 ((36864, [3, 3, 64, 64], 'nn/conv-64-64/F:0'),
  (18432, [3, 3, 128, 16], 'nn/conv-8x16-8x16/F:0')),
 ((36864, [3, 3, 64, 64], 'nn/conv-64-64_1/F:0'),
  (25600, [5, 5, 64, 16], 'nn/conv-8x8-8x16/F:0')),
 ((37888, [1024, 37], 'nn/fc-1024-37/W:0'),
  (73728, [3, 3, 256, 32], 'nn/conv-8x32-8x32/F:0')),
 ((73728, [3, 3, 64, 128], 'nn/conv-64-128/F:0'),
  (102400, [5, 5, 128, 32], 'nn/conv-8x16-8x32/F:0')),
 ((147456, [3, 3, 128, 128], 'nn/conv-128-128/F:0'),
  (147456, [3, 3, 64, 256], 'nn/conv-8x64-256/F:0')),
 ((294912, [3, 3