In [1]:
import tensorflow as tf

In [2]:
from datetime import datetime
%load_ext tensorboard

In [71]:
# NOTE(review): this cell uses SimpleModule, which is defined in the cell below
# (executed as In [70], before this one as In [71]); run that cell first.
module = SimpleModule(in_features=3, out_features=1, name='test')
print(module([0.0288, -0.3256, 0.5925]))

tf.Tensor(0.8745173, shape=(), dtype=float32)


In [70]:
class SimpleModule(tf.Module):
    """Sigmoid of the (broadcast) element-wise product of the input and fixed weights.

    Args:
        in_features: Currently unused; the weight vector is hard-coded to three
            float32 elements, so inputs must broadcast against a length-3 vector.
        out_features: Currently unused; the module always returns a scalar.
        name: Optional module name forwarded to ``tf.Module``.
    """

    def __init__(self, in_features, out_features, name=None):
        super().__init__(name=name)
        # Fixed initial weights: a length-3 float32 vector.
        self.w = tf.Variable([-3.14, -2.31, 2.16])

    # Without an input_signature, tf.function treats Python lists as part of the
    # trace key and retraces for every distinct list value. Saving also exports
    # only the traces made so far. A generic float32 spec gives one reusable
    # graph and a general exported signature. shape=None keeps both a length-3
    # vector and a [1, 3] row (as produced by AnotherModule) valid.
    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    def __call__(self, x):
        """Return ``sigmoid(reduce_sum(x * w))`` as a scalar float32 tensor."""
        y = tf.reduce_sum(tf.multiply(x, self.w))
        return tf.nn.sigmoid(y)
    
class AnotherModule(tf.Module):
    """Linear map ``reshape(x, [1, in_features]) @ w`` with ``w`` initialised to an identity.

    Args:
        in_features: Number of input features; the input is reshaped to
            ``[1, in_features]`` before the matrix multiplication.
        out_features: Number of output features (columns of ``w``).
        name: Optional module name forwarded to ``tf.Module``.
    """

    def __init__(self, in_features, out_features, name=None):
        super().__init__(name=name)
        # Plain Python int; tf.Module does not track it as a variable.
        self.in_features = in_features
        # Identity-style initial weights sized from the constructor arguments.
        # For (3, 3) this is exactly the 3x3 float32 identity matrix.
        self.w = tf.Variable(tf.eye(in_features, out_features))

    def __call__(self, x):
        """Return a ``[1, out_features]`` tensor: the input as a row vector times ``w``."""
        y = tf.matmul(tf.reshape(x, [1, self.in_features]), self.w)
        return y

In [51]:
class ComplexModule(tf.Module):
    """Composition of ``AnotherModule(3, 3)`` followed by ``SimpleModule(3, 1)``.

    Args:
        name: Optional module name forwarded to ``tf.Module``.
    """

    def __init__(self, name=None):
        super().__init__(name=name)

        # Submodules stored as attributes are tracked by tf.Module, so their
        # variables appear in checkpoints under "first/..." and "second/...".
        self.first = AnotherModule(3, 3)
        self.second = SimpleModule(3, 1)

    # A fixed, generic input signature avoids retracing for every distinct
    # Python list passed in (Python values are part of tf.function's trace key).
    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    def __call__(self, x):
        """Apply ``first`` and then ``second``, returning ``second``'s output."""
        x = self.first(x)
        return self.second(x)

In [37]:
# Fresh model whose variables are checkpointed in the cells below.
my_module = ComplexModule(name=None)

In [38]:
# Reference output of the original model, compared against the restored model later.
print(
    my_module([0.0288, -0.3256, 0.5925])
)

tf.Tensor(0.8745173, shape=(), dtype=float32)


In [39]:
chkp_path = 'my_checkpoint'
# Register my_module under the key "model" and write its variables to disk.
# The returned path prefix is left as the last expression so the cell displays it.
checkpoint = tf.train.Checkpoint(model=my_module)
checkpoint.write(file_prefix=chkp_path)

'my_checkpoint'

In [41]:
# Inspect the (name, shape) pairs stored in the checkpoint written above.
tf.train.list_variables(ckpt_dir_or_file=chkp_path)

[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('model/first/w/.ATTRIBUTES/VARIABLE_VALUE', [3, 3]),
 ('model/second/w/.ATTRIBUTES/VARIABLE_VALUE', [3])]

In [43]:
new_model = ComplexModule()
new_checkpoint = tf.train.Checkpoint(model=new_model)
# Reuse chkp_path rather than repeating the literal path. restore() by itself
# does not verify that the variables were matched, so assert that every existing
# object of new_model found a value in the checkpoint. The assertion returns the
# load status, which remains the cell's displayed output.
new_checkpoint.restore(chkp_path).assert_existing_objects_matched()

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fc80027b880>

In [44]:
# Should match my_module's output above if the restore worked.
print(
    new_model([0.0288, -0.3256, 0.5925])
)

tf.Tensor(0.8745173, shape=(), dtype=float32)


In [52]:
# One timestamped sub-directory per run under logs/modules/.
current_time = datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = f'logs/modules/{current_time}'
writer = tf.summary.create_file_writer(log_dir)

# Tracing is switched on before the fresh model's first call, because that
# call is what traces the tf.function graph being exported.
another_model = ComplexModule()
tf.summary.trace_on(graph=True, profiler=True)
print(another_model([0.0288, -0.3256, 0.5925]))

# Export the collected graph/profile into log_dir for TensorBoard.
with writer.as_default():
    tf.summary.trace_export(
        name='test_module2',
        step=0,
        profiler_outdir=log_dir,
    )

tf.Tensor(0.8745173, shape=(), dtype=float32)


In [53]:
# Point TensorBoard at the parent directory so every timestamped run written
# under logs/modules/ is listed.
%tensorboard --logdir logs/modules

Reusing TensorBoard on port 6006 (pid 6815), started 0:49:24 ago. (Use '!kill 6815' to kill it.)

In [81]:
# Export the SimpleModule instance `module` as a SavedModel directory.
tf.saved_model.save(obj=module, export_dir='bdmi_model')

INFO:tensorflow:Assets written to: bdmi_model/assets


In [82]:
# Reload the exported SavedModel; no Python class definition is needed.
really_new_model = tf.saved_model.load(export_dir='bdmi_model')

In [83]:
# Compare the in-memory module against the reloaded SavedModel on one probe input.
probe = [2., 2., 2.]
original_out = module(probe)
print(original_out)
reloaded_out = really_new_model(probe)
print(reloaded_out)
# Fail loudly (instead of relying on eyeballing the prints) if the
# SavedModel round trip changed the result.
tf.debugging.assert_near(original_out, reloaded_out)

tf.Tensor(0.0013858676, shape=(), dtype=float32)
tf.Tensor(0.0013858676, shape=(), dtype=float32)
