In [1]:
# Experiment: how tf.train.Checkpoint saves a subclassed Keras model and restores it into a fresh, not-yet-built copy

import numpy as np
import tensorflow as tf
from tensorflow.keras import Sequential, layers
# Compact array printing for the weight dumps below: 3 decimals, no scientific notation.
np.set_printoptions(suppress=True, precision=3)

In [2]:
x = tf.ones(shape=(6, 1))  # dummy input batch: 6 rows, 1 feature column

In [None]:
class CFM(tf.keras.Model):
    """Composite model: encoder + latent dynamics predictor + observer.

    NOTE(review): this cell was never executed ("In [None]") and cannot run as
    written: ``self.get_positive_pairs`` and ``self.action_keys`` are used but
    not defined anywhere visible — TODO confirm where they are meant to come from.
    """

    def __init__(self):
        super().__init__()
        # Sub-models stored as attributes are tracked by Keras, so their variables
        # are saved/restored by tf.train.Checkpoint under these attribute names.
        # NOTE(review): `Encoder` is defined twice in this notebook; whichever
        # definition is bound to that name at construction time is the one used.
        self.encoder = Encoder()
        self.dynamics = LocallyLinearPredictor()
        self.observer = Observer()

    def call(self, example):
        """Run one forward pass on ``example``.

        Args:
            example: presumably a dict of tensors holding the observation features
                and one entry per key in ``self.action_keys`` — TODO confirm.

        Returns:
            A dict containing ``'z_pos'`` plus the entries of the dicts returned by
            ``self.dynamics`` and ``self.observer`` (later entries overwrite earlier
            ones that share a key).
        """
        observation, observation_pos = self.get_positive_pairs(example)

        # forward pass
        z, z_pos = self.encoder(observation), self.encoder(observation_pos)  # b x z_dim
        pred_inputs = {
            'z': z['z'],
            'z_pos': z_pos['z'],
        }
        # Copy the action tensors through; the dynamics model reads inputs[k] for
        # its own action_keys.
        for k in self.action_keys:
            pred_inputs[k] = example[k]
        z_seq = self.dynamics(pred_inputs)
        # NOTE(review): the observer receives the same pred_inputs as the dynamics
        # model, not the predicted z_seq — confirm this is intended.
        y_seq = self.observer(pred_inputs)

        output = {
            'z_pos': z_pos['z'],
        }
        output.update(z_seq)
        output.update(y_seq)
        return output



class Encoder(tf.keras.Model):
    """Convolutional encoder mapping image observations to a latent vector ``z``.

    NOTE(review): ``self.z_dim`` and ``self.obs_keys`` are read but never set in
    visible code; ``z_dim`` must already be available when ``__init__`` runs
    (e.g. as a class attribute) — TODO confirm where they come from.
    NOTE(review): a second, unrelated ``Encoder`` class is defined in a later
    cell; whichever definition is bound to the name last is what models use.
    """

    def __init__(self):
        super().__init__()
        # The attribute name (typo included) is kept as-is: Keras tracks this
        # sub-model under its attribute name, so it is part of the checkpoint key.
        self.encoder_sequtial = Sequential([
            layers.Conv2D(filters=64, kernel_size=3),
            layers.LeakyReLU(0.2),
            layers.Conv2D(filters=64, kernel_size=4, strides=2),
            layers.LeakyReLU(0.2),
            # 64 x 32 x 32
            layers.Conv2D(filters=64, kernel_size=3, strides=1),
            layers.LeakyReLU(0.2),
            layers.Conv2D(filters=128, kernel_size=4, strides=2),
            layers.LeakyReLU(0.2),
            # 128 x 16 x 16
            layers.Conv2D(filters=256, kernel_size=4, strides=2),
            layers.LeakyReLU(0.2),
            # Option 1: 256 x 8 x 8
            layers.Conv2D(filters=256, kernel_size=4, strides=2),
            layers.LeakyReLU(0.2),
            # 256 x 4 x 4
        ], name='encoder_sequential')
        # Projects the flattened conv features to the latent dimension.
        self.out = layers.Dense(self.z_dim)

    # @tf.function
    def call(self, observation):
        """Encode an observation dict into a latent vector.

        Args:
            observation: mapping from each key in ``self.obs_keys`` to an
                image-like tensor whose last three axes are assumed to be
                (H, W, C); the tensors are concatenated along the last axis.

        Returns:
            ``{'z': z}`` where ``z`` keeps the leading (batch-like) axes of the
            input and has a last axis of size ``self.z_dim``.
        """
        o = tf.concat([observation[k] for k in self.obs_keys], axis=-1)
        h = self.encoder_sequtial(o)
        # Flatten the trailing (H, W, C) axes but keep every leading axis, so
        # inputs shaped [batch, time, H, W, C] are treated as all-batch axes.
        # Leading sizes come from the dynamic shape because their static size may
        # be None (unknown batch) under tf.function / Keras graph execution, and
        # tf.reshape cannot accept None. The flattened size is kept static when
        # known so the Dense layer below still sees a defined last dimension.
        flat_dim = h.shape[-3:].num_elements()
        if flat_dim is None:
            flat_dim = -1
        h = tf.reshape(h, tf.concat([tf.shape(h)[:-3], [flat_dim]], axis=0))
        z = self.out(h)
        return {'z': z}

class LocallyLinearPredictor(tf.keras.Model):
    """Predicts the next latent state with a state- and action-dependent linear map.

    An MLP takes ``concat(z, a)`` and outputs the ``z_dim * z_dim`` entries of a
    matrix ``A``; the prediction is ``z_next = A @ z``.

    NOTE(review): ``self.hparams``, ``self.z_dim`` and ``self.action_keys`` are
    read but never set in visible code — TODO confirm where they are provided.
    """

    def __init__(self):
        super().__init__()

        # Hidden ReLU layers, then a linear layer emitting the flattened matrix.
        my_layers = [layers.Dense(h, activation="relu") for h in self.hparams['fc_layer_sizes']]
        my_layers.append(layers.Dense(self.z_dim * self.z_dim, activation=None))

        self.dynamics_sequential = Sequential(my_layers, name='predictor_sequential')

    # @tf.function
    def call(self, inputs, **kwargs):
        """Predict the next latent state from ``inputs``.

        Args:
            inputs: dict with ``'z'`` (latent state, last axis ``z_dim``) and one
                tensor per key in ``self.action_keys``; actions are concatenated
                along the last axis.

        Returns:
            ``{'z': z_seq}`` where ``z_seq`` is ``z`` and the prediction
            ``z_next`` concatenated along axis 1.
        """
        a = tf.concat([inputs[k] for k in self.action_keys], axis=-1)

        z = inputs['z']
        x = tf.concat((z, a), axis=-1)
        linear_dynamics_params = self.dynamics_sequential(x)
        # Leading sizes come from the dynamic shape: their static size can be None
        # under graph execution (unknown batch), which tf.reshape cannot accept.
        matrix_shape = tf.concat([tf.shape(x)[:-1], [self.z_dim, self.z_dim]], axis=0)
        linear_dynamics_matrix = tf.reshape(linear_dynamics_params, matrix_shape)
        # A @ z, treating z as a column vector.
        z_next = tf.squeeze(tf.linalg.matmul(linear_dynamics_matrix, tf.expand_dims(z, axis=-1)), axis=-1)
        # NOTE(review): for z of shape (batch, z_dim) this concatenates features,
        # giving (batch, 2 * z_dim) rather than a length-2 sequence; if a time
        # axis was intended, tf.stack([z, z_next], axis=1) may be what is wanted
        # — TODO confirm against the loss.
        z_seq = tf.concat([z, z_next], axis=1)
        return {'z': z_seq}

    def compute_loss(self, dataset_element, outputs):
        """Not implemented; always raises NotImplementedError."""
        raise NotImplementedError()


class Observer(tf.keras.Model):
    """MLP decoding latent state features into a dict of observation features.

    Hidden layers are ReLU Dense layers sized by ``self.hparams['fc_layer_sizes']``,
    followed by a linear Dense output layer.

    NOTE(review): several names used here are not defined in visible code:
    ``self.hparams``, ``self.state_keys``, ``self.observation_features_description``,
    ``vector_to_dict``, and the bare name ``final_dim`` (a NameError unless a
    global of that name exists). Presumably ``final_dim`` should be the total
    size of the observation features — TODO confirm.
    """

    def __init__(self):
        super().__init__()
        my_layers = []
        for h in self.hparams['fc_layer_sizes']:
            my_layers.append(layers.Dense(h, activation="relu"))
        my_layers.append(layers.Dense(final_dim, activation=None))

        self.observer_sequential = Sequential(my_layers, name='observer_sequential')

    # @tf.function
    def call(self, inputs, **kwargs):
        """Concatenate ``inputs[k]`` for each state key (last axis) and decode.

        Returns:
            The result of ``vector_to_dict(self.observation_features_description, y)``
            — presumably a dict splitting ``y`` into named features; TODO confirm.
        """
        z = tf.concat([inputs[k] for k in self.state_keys], axis=-1)
        y = self.observer_sequential(z)
        y_dict = vector_to_dict(self.observation_features_description, y)
        return y_dict

    def compute_loss(self, dataset_element, outputs):
        """Not implemented; always raises NotImplementedError."""
        raise NotImplementedError()

In [3]:
class MyModel(tf.keras.Model):
    """Minimal two-level Keras model used to inspect checkpoint save/restore.

    Wrapping an ``Encoder`` gives the checkpoint a nested object graph
    (model -> encoder -> layers) rather than a flat one.

    NOTE(review): ``Encoder`` is looked up when ``MyModel()`` is constructed; the
    notebook defines two classes with that name, so the most recently executed
    definition is the one used.
    """

    def __init__(self):
        # Pass no positional arguments: tf.keras.Model.__init__'s arguments are
        # for the functional API / layer options, and a stray `self` is either
        # ignored or bound to an unrelated parameter depending on Keras version.
        super().__init__()
        self.encoder = Encoder()

    def call(self, x):
        """Forward pass: delegate directly to the encoder."""
        return self.encoder(x)
        
class Encoder(tf.keras.Model):
    """Toy encoder: two stacked Dense(10) layers with the default (linear) activation.

    NOTE(review): this redefines the convolutional ``Encoder`` from the earlier
    cell under the same name; whichever definition ran last wins.
    """

    def __init__(self):
        # Pass no positional arguments: a stray `self` here is either ignored or
        # bound to an unrelated Model/Layer parameter depending on Keras version.
        super().__init__()
        self.s = tf.keras.Sequential([tf.keras.layers.Dense(10), tf.keras.layers.Dense(10)])

    def call(self, x):
        """Apply both Dense layers to ``x``; the output's last axis has size 10."""
        return self.s(x)
        
model = MyModel()
# Subclassed Keras models create their variables on the first call, so run one
# forward pass on the dummy input before checkpointing; `_` discards the output.
_ = model(x)

In [4]:
ckpt = tf.train.Checkpoint(kwd=model)
# save() appends an incrementing counter to the prefix ("ckpt-1", "ckpt-2", ...)
# and returns the resulting path; keep it and show it as the cell's output.
save_path = ckpt.save("ckpt")
save_path

'ckpt-1'

In [5]:
# Show the first row of each weight array (the first scalar for 1-D biases)
# together with the array's shape.
for weight in model.get_weights():
    print(weight[0], weight.shape)

[-0.362 -0.333  0.319 -0.388 -0.15   0.716  0.047 -0.155  0.442  0.03 ] (1, 10)
0.0 (10,)
[ 0.119 -0.291 -0.352 -0.115 -0.091  0.49  -0.371  0.082 -0.351  0.46 ] (10, 10)
0.0 (10,)


In [6]:
# A fresh model with the same structure that has not been called yet, so its
# Dense layers have no variables for the restore below to write into immediately.
model2 = MyModel()

In [7]:
# Attach the new model under the same key ("kwd") used when saving so the
# checkpoint's object graph lines up with it.
ckpt2 = tf.train.Checkpoint(kwd=model2)
# "ckpt-1" is the path returned by ckpt.save("ckpt") in the save cell.
status = ckpt2.restore("ckpt-1")
# Restoration into not-yet-created variables is deferred: the saved values are
# assigned when this first call builds the layers.
_ = model2(x)

In [8]:
restored_weights = model2.get_weights()
for w_i in restored_weights:
    print(w_i[0], w_i.shape)

# Verify the whole restore, not just the first rows printed above: every
# restored array must equal the corresponding array of the saved model.
saved_weights = model.get_weights()
assert len(restored_weights) == len(saved_weights), "weight count mismatch"
for saved, restored in zip(saved_weights, restored_weights):
    np.testing.assert_array_equal(restored, saved)

[-0.362 -0.333  0.319 -0.388 -0.15   0.716  0.047 -0.155  0.442  0.03 ] (1, 10)
0.0 (10,)
[ 0.119 -0.291 -0.352 -0.115 -0.091  0.49  -0.371  0.082 -0.351  0.46 ] (10, 10)
0.0 (10,)


In [9]:
# Raises AssertionError unless every value in the checkpoint was matched to a
# variable in model2; on success it returns the status object (shown below).
status.assert_consumed()

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f318c141a30>