In [1]:
# Testing how TF Checkpoints work

import tensorflow as tf
from tensorflow.keras import layers, Sequential
import numpy as np
np.set_printoptions(precision=3, suppress=True)

In [2]:
# A 6 x 1 tensor of ones, used below as the dummy input the models are called on.
x = tf.ones(shape=(6, 1))

In [3]:
class CFM(tf.keras.Model):
    """Composite model: an encoder feeding a dynamics head and an observer head.

    The three sub-models are stored as the attributes ``encoder``,
    ``dynamics`` and ``observer``.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.dynamics = LocallyLinearPredictor()
        self.observer = Observer()

    def call(self, x):
        """Run the forward pass.

        Args:
            x: Input tensor handed to the encoder.

        Returns:
            A tuple ``(z_seq, y_seq)``: the dynamics-head output and the
            observer-head output, both computed from the encoding of ``x``.
        """
        # Encode once and reuse the result: both heads consume the encoding of
        # the same input, so a second encoder pass would only repeat the same
        # computation (assuming the encoder is deterministic -- no dropout or
        # noise layers are visible in this file).
        z = self.encoder(x)  # b x z_dim
        z_pos = z
        z_seq = self.dynamics(z)
        y_seq = self.observer(z_pos)
        return z_seq, y_seq

class Encoder(tf.keras.Model):
    """Convolutional encoder: six Conv2D + LeakyReLU(0.2) stages, then a Dense layer."""

    def __init__(self):
        super().__init__()
        # (filters, kernel_size, strides) of each convolution, in order.
        # The trailing shape notes are carried over as written; they depend on
        # the input size and padding, which this file does not pin down --
        # TODO confirm.
        conv_specs = [
            (64, 3, 1),
            (64, 4, 2),   # 64 x 32 x 32
            (64, 3, 1),
            (128, 4, 2),  # 128 x 16 x 16
            (256, 4, 2),  # Option 1: 256 x 8 x 8
            (256, 4, 2),  # 256 x 4 x 4
        ]
        stack = []
        for n_filters, kernel, stride in conv_specs:
            stack.append(layers.Conv2D(filters=n_filters, kernel_size=kernel, strides=stride))
            stack.append(layers.LeakyReLU(0.2))
        # The attribute spelling ("sequtial") is kept as-is because the
        # attribute name is part of how this object is addressed from outside.
        self.encoder_sequtial = Sequential(stack, name='encoder_sequential')
        self.z_dim = 32
        self.out = layers.Dense(self.z_dim)

    # @tf.function
    def call(self, x):
        """Apply the convolutional stack, then the Dense output layer.

        Args:
            x: Input tensor passed to the convolutional stack.

        Returns:
            The output of ``self.out`` applied to the convolutional features.
        """
        features = self.encoder_sequtial(x)
        # NOTE(review): there is no Flatten between the conv stack and the
        # Dense layer, so Dense acts on the last axis only and any remaining
        # spatial axes are kept. CFM.call annotates the encoding as
        # "b x z_dim" -- confirm whether a Flatten is intended here.
        return self.out(features)

class LocallyLinearPredictor(tf.keras.Model):
    """MLP with two 32-unit ReLU layers and a linear ``z_dim * z_dim`` output layer."""

    def __init__(self):
        super().__init__()
        self.z_dim = 32

        # Hidden layers are created first, then the output layer, so the
        # layer creation order matches the order inside the Sequential.
        hidden = [layers.Dense(units, activation="relu") for units in (32, 32)]
        head = layers.Dense(self.z_dim * self.z_dim, activation=None)

        self.dynamics_sequential = Sequential(hidden + [head], name='predictor_sequential')

    # @tf.function
    def call(self, x):
        """Return the output of the dense stack applied to ``x``."""
        return self.dynamics_sequential(x)


class Observer(tf.keras.Model):
    """MLP with two 32-unit ReLU layers and a linear 16-unit output layer."""

    def __init__(self):
        super().__init__()
        hidden_layers = [layers.Dense(width, activation="relu") for width in (32, 32)]
        final_dim = 16
        output_layer = layers.Dense(final_dim, activation=None)

        self.observer_sequential = Sequential(
            [*hidden_layers, output_layer],
            name='observer_sequential',
        )

    # @tf.function
    def call(self, x):
        """Return the output of the dense stack applied to ``x``."""
        return self.observer_sequential(x)


In [4]:
class MyModel(tf.keras.layers.Layer):
    """Minimal wrapper layer that holds a single ``Encoder`` and forwards to it."""

    def __init__(self):
        # super().__init__() already binds the instance. Passing ``self``
        # again would hand the layer object to the base constructor as its
        # first argument (``trainable`` in TF 2.x); Keras versions with
        # keyword-only constructor arguments reject it.
        super().__init__()
        self.encoder = Encoder()

    def call(self, x):
        """Return the encoder output for ``x``."""
        return self.encoder(x)
        
class Encoder(tf.keras.layers.Layer):
    """Small encoder: two stacked 10-unit Dense layers.

    This class reuses the name ``Encoder``; once this cell has run, it replaces
    any earlier definition bound to that name.
    """

    def __init__(self):
        # super().__init__() already binds the instance. Passing ``self``
        # again would hand the layer object to the base constructor as its
        # first argument (``trainable`` in TF 2.x); Keras versions with
        # keyword-only constructor arguments reject it.
        super().__init__()
        self.s = tf.keras.Sequential([tf.keras.layers.Dense(10), tf.keras.layers.Dense(10)])

    def call(self, x):
        """Return the output of the two Dense layers applied to ``x``."""
        return self.s(x)

In [5]:
# Alternative model under test (disabled):
# model = MyModel()
# Instantiate the composite model. `Encoder` is defined twice in this notebook;
# when the cells run top to bottom, CFM() picks up the later definition.
model = CFM()
# Call the model once on the dummy input before the checkpoint is saved.
_ = model(x)

In [6]:
# Track the model under the checkpoint key "kwd" and save with the path prefix
# "my-ckpt". The restore cell further down reads "my-ckpt-1".
ckpt = tf.train.Checkpoint(kwd=model)
ckpt.save("my-ckpt")

'ckpt-1'

In [7]:
# Print the first entry along axis 0 and the shape of every weight array of
# the saved model, as a reference for comparison after restoring.
w = model.get_weights()
for w_i in w:
    print(w_i[0], w_i.shape)

[-0.315 -0.362 -0.403  0.395 -0.529 -0.439 -0.302  0.34  -0.134 -0.554] (1, 10)
0.0 (10,)
[-0.529  0.374 -0.391  0.094  0.437 -0.17  -0.141 -0.031 -0.21   0.52 ] (10, 10)
0.0 (10,)
[ 0.059 -0.183 -0.298  0.036  0.229 -0.105  0.265  0.365 -0.113 -0.013
 -0.193 -0.326  0.301  0.343  0.342  0.12   0.294  0.186 -0.038  0.062
 -0.154 -0.333  0.064  0.004 -0.268  0.041  0.223 -0.012  0.371 -0.158
 -0.15  -0.357] (10, 32)
0.0 (32,)
[-0.204 -0.065  0.051  0.077 -0.087  0.113  0.2    0.167 -0.129 -0.15
  0.036 -0.118  0.236 -0.067 -0.109  0.249  0.28   0.286 -0.297  0.175
  0.24   0.172 -0.237  0.08   0.161 -0.041 -0.161  0.283 -0.006 -0.114
 -0.056 -0.031] (32, 32)
0.0 (32,)
[-0.037  0.001 -0.067 ... -0.044  0.012  0.043] (32, 1024)
0.0 (1024,)
[-0.188 -0.223 -0.267 -0.232  0.072  0.02   0.37  -0.155  0.259 -0.035
 -0.253  0.068 -0.354  0.191  0.033 -0.162 -0.309 -0.256 -0.22  -0.261
 -0.311  0.268  0.203  0.126  0.158  0.358 -0.128 -0.183  0.318 -0.215
  0.289 -0.001] (10, 32)
0.0 (32,)
[ 0.0

In [8]:
# Alternative model under test (disabled):
# model2 = MyModel()
# A second, freshly constructed model that the checkpoint is restored into.
model2 = CFM()

In [9]:
# Wrap the fresh model under the same key ("kwd") that was used when saving,
# then request the restore. model2 has not been called yet at this point, so
# its layer variables presumably do not exist yet; TensorFlow documents that
# restoration is deferred for such variables and applied once they are created.
ckpt2 = tf.train.Checkpoint(kwd=model2)
status = ckpt2.restore("my-ckpt-1")
# First call of model2 on the dummy input.
_ = model2(x)

In [10]:
# Print the first entry along axis 0 and the shape of every restored weight array.
w = model2.get_weights()
for w_i in w:
    print(w_i[0], w_i.shape)

# Check the round trip programmatically instead of relying on a visual
# comparison of the printed rows: every restored array must equal the
# corresponding array of the model that was saved.
w_saved = model.get_weights()
assert len(w_saved) == len(w), "saved and restored models expose a different number of weights"
for w_expected, w_actual in zip(w_saved, w):
    np.testing.assert_array_equal(w_actual, w_expected)

[-0.315 -0.362 -0.403  0.395 -0.529 -0.439 -0.302  0.34  -0.134 -0.554] (1, 10)
0.0 (10,)
[-0.529  0.374 -0.391  0.094  0.437 -0.17  -0.141 -0.031 -0.21   0.52 ] (10, 10)
0.0 (10,)
[ 0.059 -0.183 -0.298  0.036  0.229 -0.105  0.265  0.365 -0.113 -0.013
 -0.193 -0.326  0.301  0.343  0.342  0.12   0.294  0.186 -0.038  0.062
 -0.154 -0.333  0.064  0.004 -0.268  0.041  0.223 -0.012  0.371 -0.158
 -0.15  -0.357] (10, 32)
0.0 (32,)
[-0.204 -0.065  0.051  0.077 -0.087  0.113  0.2    0.167 -0.129 -0.15
  0.036 -0.118  0.236 -0.067 -0.109  0.249  0.28   0.286 -0.297  0.175
  0.24   0.172 -0.237  0.08   0.161 -0.041 -0.161  0.283 -0.006 -0.114
 -0.056 -0.031] (32, 32)
0.0 (32,)
[-0.037  0.001 -0.067 ... -0.044  0.012  0.043] (32, 1024)
0.0 (1024,)
[-0.188 -0.223 -0.267 -0.232  0.072  0.02   0.37  -0.155  0.259 -0.035
 -0.253  0.068 -0.354  0.191  0.033 -0.162 -0.309 -0.256 -0.22  -0.261
 -0.311  0.268  0.203  0.126  0.158  0.358 -0.128 -0.183  0.318 -0.215
  0.289 -0.001] (10, 32)
0.0 (32,)
[ 0.0

In [11]:
# Check that the restore matched completely: per the TensorFlow documentation
# this raises if parts of the checkpoint or of the tracked objects were left
# unmatched. The displayed result is the load-status object itself.
status.assert_consumed()

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f07001486a0>