In [1]:
import importlib.util
import itertools

import numpy as np
import tensorflow as tf

In [2]:
def load_module(path):
    """Execute a Python source file at ``path`` and return it as a module.

    The module is always created under the fixed name ``"module.name"``.
    This function does not add it to ``sys.modules``, so every call builds
    and executes a fresh module object.

    Parameters
    ----------
    path : str
        Filesystem location of the source file to load.

    Returns
    -------
    module
        The executed module object.

    Raises
    ------
    ImportError
        If no loader can be determined for ``path`` (for example, the file
        suffix is not one Python recognises as a module).
    """
    spec = importlib.util.spec_from_file_location("module.name", path)
    # spec_from_file_location returns None when it cannot choose a loader for
    # the path; fail with a clear message instead of an AttributeError on None.
    if spec is None or spec.loader is None:
        raise ImportError(
            "cannot load a Python module from {!r}".format(path), path=path
        )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

def load_arch(arch_path):
    """Build the network defined in ``arch_path`` inside its own TF graph.

    The file is loaded with ``load_module``; it must expose a ``CNN`` class
    that can be constructed without arguments and has a
    ``create_architecture()`` method.

    Parameters
    ----------
    arch_path : str
        Path to the architecture definition file.

    Returns
    -------
    tuple
        ``(graph, network)``: the new ``tf.Graph`` and the ``CNN`` instance
        whose ``create_architecture()`` was called with that graph as default.
    """
    arch_module = load_module(arch_path)
    # The CNN object is instantiated before the dedicated graph exists; only
    # create_architecture() runs with the new graph set as the default.
    network = arch_module.CNN()

    graph = tf.Graph()
    with graph.as_default():
        network.create_architecture()

    return graph, network

In [6]:
# Only the graphs are used in the cells below. Index the returned pair rather
# than unpacking into `_`: assigning `_` at cell level shadows IPython's
# last-output variable and keeps the unused network object referenced.
baseline = load_arch("arch_baseline.py")[0]
invariant = load_arch("arch_invariant_c.py")[0]

In [7]:
def count(graph):
    """Return the total number of elements across a graph's trainable variables.

    Parameters
    ----------
    graph : tf.Graph
        Graph whose ``TRAINABLE_VARIABLES`` collection is inspected.

    Returns
    -------
    int
        Sum over all trainable variables of the product of their shape
        dimensions; ``0`` when the collection is empty.
    """
    variables = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    shapes = [variable.get_shape().as_list() for variable in variables]
    # np.prod([]) is the float 1.0, so a scalar variable (shape []) would turn
    # the whole total into a float. Force an integer dtype and convert to a
    # plain int so the result is always an exact Python integer.
    return sum(int(np.prod(shape, dtype=np.int64)) for shape in shapes)

def weights_and_filter(graph):
    """List the trainable variables whose name contains ``'F'`` or ``'W'``.

    Parameters
    ----------
    graph : tf.Graph
        Graph whose ``TRAINABLE_VARIABLES`` collection is inspected.

    Returns
    -------
    list of tuple
        ``(element_count, shape_list, name)`` for every matching variable,
        sorted ascending by that tuple (element count first).
    """
    rows = []
    for variable in graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
        name = variable.name
        # Plain substring test on the full variable name.
        if 'F' not in name and 'W' not in name:
            continue
        dims = variable.get_shape().as_list()
        rows.append((np.prod(dims), dims, name))
    rows.sort()
    return rows

In [8]:
# Total trainable-parameter count per architecture: (baseline, invariant).
tuple(count(graph) for graph in (baseline, invariant))

(15618341, 1351898)

In [None]:
# Pair the weight/filter tensors of the two architectures position by position
# (each list is sorted by element count). zip_longest pads the shorter list
# with None, so tensors present in only one architecture are not silently
# dropped the way plain zip would drop them.
list(itertools.zip_longest(weights_and_filter(baseline), weights_and_filter(invariant)))