## Importing Libraries


In [1]:
import tensorflow as tf
import numpy as np

## Saving and loading weights
- Used when the model's architecture, configuration and optimizer setup are already available in code, so only the weight values (and optimizer state) need to be saved and restored.
- If the model was built by subclassing `tf.keras.Model`, its weights must be loaded into an instance of the same subclass; they can't be loaded into a Functional or Sequential API model.



In [2]:
def data_generator(x,y,batch_size,epochs):
  """Build a batched, repeated, prefetched ``tf.data`` pipeline from in-memory arrays.

  Args:
    x: Image array. Assumed to hold uint8 pixel values in [0, 255] (as returned by
      ``tf.keras.datasets.cifar10.load_data``); it is scaled to [0, 1] as float32.
    y: Label array, cast to int32.
    batch_size: Number of examples per batch.
    epochs: Number of times the batched dataset is repeated. ``Model.fit`` already
      iterates the dataset once per epoch, so pass 1 when the result is handed to
      ``fit(epochs=...)``, otherwise every fit-epoch covers the data ``epochs`` times.

  Returns:
    A ``tf.data.Dataset`` yielding ``(images, labels)`` batches.
  """
  # Cast to float32 *before* scaling: ``x / 255.0`` on a NumPy uint8 array first
  # materialises a float64 copy (~1.2 GB for CIFAR-10's 50,000 training images).
  # The float32 result is bit-identical either way.
  images = tf.cast(x, tf.float32) / 255.0
  labels = tf.cast(y, tf.int32)
  dataset = tf.data.Dataset.from_tensor_slices((images, labels))
  dataset = dataset.batch(batch_size)
  dataset = dataset.repeat(epochs)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  return dataset

class Basic_model(tf.keras.Model):
  """Small CNN classifier: Conv2D -> BatchNormalization -> Flatten -> Dense(softmax).

  ``compile`` and ``train_step`` are overridden so that training runs a
  hand-written ``tf.GradientTape`` step using the optimizer and loss given to
  ``compile``.
  """

  def __init__(self):
    super().__init__()
    self.kernel_init = tf.keras.initializers.glorot_normal()
    self.conv = tf.keras.layers.Conv2D(64, 3, kernel_initializer=self.kernel_init)
    self.bn = tf.keras.layers.BatchNormalization()
    self.flatten = tf.keras.layers.Flatten()
    # 10 softmax units -> per-class probabilities (CIFAR-10 has 10 classes).
    self.out = tf.keras.layers.Dense(10, activation='softmax')

  def call(self, input_tensor, training=None):
    """Forward pass.

    Args:
      input_tensor: Batch of images.
      training: Forwarded to BatchNormalization so it uses batch statistics while
        training and its moving averages at inference. ``None`` lets Keras infer
        the mode from how the model is invoked.

    Returns:
      Class probabilities from the softmax output layer.
    """
    x = self.conv(input_tensor)
    x = self.bn(x, training=training)
    x = self.flatten(x)
    return self.out(x)

  def compile(self, optimizer, loss, **kwargs):
    """Configure the model for training.

    The optimizer and loss are forwarded to ``tf.keras.Model.compile`` so Keras's
    own bookkeeping (``evaluate``, the saved training config) sees the real
    values instead of its defaults. Extra keyword arguments such as ``metrics``
    are forwarded unchanged.
    """
    super().compile(optimizer=optimizer, loss=loss, **kwargs)
    # Explicit references used by ``train_step``.
    self.optimizer = optimizer
    self.loss = loss

  def train_step(self, input_data):
    """Run one optimisation step on an ``(inputs, targets)`` batch.

    Returns:
      ``{"loss": loss_value}`` for the batch, which ``fit`` reports in its logs.
    """
    inputs, targets = input_data
    with tf.GradientTape() as tape:
      y_pred = self(inputs, training=True)
      loss_val = self.loss(targets, y_pred)
    grads = tape.gradient(loss_val, self.trainable_variables)
    self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
    return {"loss": loss_val}


def main():
  """Train ``Basic_model`` for one epoch, save its weights, and restore them
  into a fresh ``Basic_model`` instance, checking both models agree."""
  batch_size = 128
  epochs = 1
  # TF-format weights are written as ``<prefix>.index`` / ``<prefix>.data-*``.
  # A bare directory prefix ('Basic_model_weights/') would produce hidden
  # '.index' files, so the checkpoint gets an explicit file stem.
  weights_path = 'Basic_model_weights/weights'

  (x_train, y_train), _ = tf.keras.datasets.cifar10.load_data()
  # ``fit`` already makes one pass over the dataset per epoch, so the pipeline
  # itself is not repeated (otherwise each epoch would cover the data
  # ``epochs`` times).
  train_dataset = data_generator(x_train, y_train, batch_size, 1)

  model = Basic_model()
  optimizer = tf.keras.optimizers.Adam()
  loss = tf.keras.losses.SparseCategoricalCrossentropy()
  model.compile(optimizer=optimizer, loss=loss)
  # The dataset is already batched, so ``batch_size`` is not passed to ``fit``.
  model.fit(train_dataset, epochs=epochs)

  # Saving model weights
  model.save_weights(weights_path)

  # Weights from a subclassed model are loaded into a new instance of the same
  # class. Its variables are created lazily, so the restore is deferred until
  # the first call. ``expect_partial`` silences warnings about optimizer slots,
  # which the un-compiled copy does not need.
  restored = Basic_model()
  restored.load_weights(weights_path).expect_partial()
  sample = next(iter(train_dataset))[0]
  np.testing.assert_allclose(model(sample).numpy(), restored(sample).numpy(),
                             rtol=1e-5, atol=1e-6)


if __name__ == '__main__':
  main()

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


## Saving and loading from checkpoint
- Used when we want to save the model periodically during training (e.g. at the end of each epoch), so the best or most recent weights can be restored later.


In [3]:
def data_generator(x,y,batch_size,epochs):
  """Build a batched, repeated, prefetched ``tf.data`` pipeline from in-memory arrays.

  Args:
    x: Image array. Assumed to hold uint8 pixel values in [0, 255] (as returned by
      ``tf.keras.datasets.cifar10.load_data``); it is scaled to [0, 1] as float32.
    y: Label array, cast to int32.
    batch_size: Number of examples per batch.
    epochs: Number of times the batched dataset is repeated. ``Model.fit`` already
      iterates the dataset once per epoch, so pass 1 when the result is handed to
      ``fit(epochs=...)``, otherwise every fit-epoch covers the data ``epochs`` times.

  Returns:
    A ``tf.data.Dataset`` yielding ``(images, labels)`` batches.
  """
  # Cast to float32 *before* scaling: ``x / 255.0`` on a NumPy uint8 array first
  # materialises a float64 copy (~1.2 GB for CIFAR-10's 50,000 training images).
  # The float32 result is bit-identical either way.
  images = tf.cast(x, tf.float32) / 255.0
  labels = tf.cast(y, tf.int32)
  dataset = tf.data.Dataset.from_tensor_slices((images, labels))
  dataset = dataset.batch(batch_size)
  dataset = dataset.repeat(epochs)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  return dataset

class Basic_model(tf.keras.Model):
  """Small CNN classifier: Conv2D -> BatchNormalization -> Flatten -> Dense(softmax).

  ``compile`` and ``train_step`` are overridden so that training runs a
  hand-written ``tf.GradientTape`` step using the optimizer and loss given to
  ``compile``.
  """

  def __init__(self):
    super().__init__()
    self.kernel_init = tf.keras.initializers.glorot_normal()
    self.conv = tf.keras.layers.Conv2D(64, 3, kernel_initializer=self.kernel_init)
    self.bn = tf.keras.layers.BatchNormalization()
    self.flatten = tf.keras.layers.Flatten()
    # 10 softmax units -> per-class probabilities (CIFAR-10 has 10 classes).
    self.out = tf.keras.layers.Dense(10, activation='softmax')

  def call(self, input_tensor, training=None):
    """Forward pass.

    Args:
      input_tensor: Batch of images.
      training: Forwarded to BatchNormalization so it uses batch statistics while
        training and its moving averages at inference. ``None`` lets Keras infer
        the mode from how the model is invoked.

    Returns:
      Class probabilities from the softmax output layer.
    """
    x = self.conv(input_tensor)
    x = self.bn(x, training=training)
    x = self.flatten(x)
    return self.out(x)

  def compile(self, optimizer, loss, **kwargs):
    """Configure the model for training.

    The optimizer and loss are forwarded to ``tf.keras.Model.compile`` so Keras's
    own bookkeeping (``evaluate``, the saved training config) sees the real
    values instead of its defaults. Extra keyword arguments such as ``metrics``
    are forwarded unchanged.
    """
    super().compile(optimizer=optimizer, loss=loss, **kwargs)
    # Explicit references used by ``train_step``.
    self.optimizer = optimizer
    self.loss = loss

  def train_step(self, input_data):
    """Run one optimisation step on an ``(inputs, targets)`` batch.

    Returns:
      ``{"loss": loss_value}`` for the batch, which ``fit`` reports in its logs.
    """
    inputs, targets = input_data
    with tf.GradientTape() as tape:
      y_pred = self(inputs, training=True)
      loss_val = self.loss(targets, y_pred)
    grads = tape.gradient(loss_val, self.trainable_variables)
    self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
    return {"loss": loss_val}


def main():
  """Train ``Basic_model`` while checkpointing its weights whenever the training
  loss improves, then restore the most recent checkpoint."""
  batch_size = 128
  epochs = 2
  checkpoint_dir = 'Basic_model_checkpnt_/'

  (x_train, y_train), _ = tf.keras.datasets.cifar10.load_data()
  # ``fit`` already iterates the dataset once per epoch; repeating the pipeline
  # as well would make every epoch ``epochs`` passes over the data.
  train_dataset = data_generator(x_train, y_train, batch_size, 1)

  model = Basic_model()
  optimizer = tf.keras.optimizers.Adam()
  loss = tf.keras.losses.SparseCategoricalCrossentropy()
  model.compile(optimizer=optimizer, loss=loss)

  # ``save_weights_only=True`` writes TF-format weight checkpoints (and the
  # ``checkpoint`` index file that ``tf.train.latest_checkpoint`` reads) instead
  # of a full SavedModel. The keyword must be spelled exactly: TF 2.x's
  # ModelCheckpoint accepts **kwargs and ignores unknown names such as
  # ``save_weights``. Keras fills in ``{epoch:02d}`` so each file is named after
  # the epoch that produced it.
  model_save_checkpoint = tf.keras.callbacks.ModelCheckpoint(
      filepath=checkpoint_dir + 'ckpt_{epoch:02d}',
      monitor='loss',
      mode='min',
      save_best_only=True,
      save_weights_only=True,
  )
  # The dataset is already batched, so ``batch_size`` is not passed to ``fit``.
  model.fit(train_dataset, epochs=epochs, callbacks=[model_save_checkpoint])

  # For loading the model from the latest checkpoint. With save_best_only=True a
  # checkpoint is only written when the loss improves, so the latest one is also
  # the best seen so far.
  basic_model_latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
  if basic_model_latest_checkpoint is None:
    raise FileNotFoundError(f'No checkpoint found in {checkpoint_dir!r}')
  model.load_weights(basic_model_latest_checkpoint)


if __name__ == '__main__':
  main()


Epoch 1/2
Epoch 2/2


## Saving and loading from serialization
- Used when we want to save the whole model (its configuration, weights and optimizer state) as a single artifact that can be reloaded with `tf.keras.models.load_model`.


In [4]:
def data_generator(x,y,batch_size,epochs):
  """Build a batched, repeated, prefetched ``tf.data`` pipeline from in-memory arrays.

  Args:
    x: Image array. Assumed to hold uint8 pixel values in [0, 255] (as returned by
      ``tf.keras.datasets.cifar10.load_data``); it is scaled to [0, 1] as float32.
    y: Label array, cast to int32.
    batch_size: Number of examples per batch.
    epochs: Number of times the batched dataset is repeated. ``Model.fit`` already
      iterates the dataset once per epoch, so pass 1 when the result is handed to
      ``fit(epochs=...)``, otherwise every fit-epoch covers the data ``epochs`` times.

  Returns:
    A ``tf.data.Dataset`` yielding ``(images, labels)`` batches.
  """
  # Cast to float32 *before* scaling: ``x / 255.0`` on a NumPy uint8 array first
  # materialises a float64 copy (~1.2 GB for CIFAR-10's 50,000 training images).
  # The float32 result is bit-identical either way.
  images = tf.cast(x, tf.float32) / 255.0
  labels = tf.cast(y, tf.int32)
  dataset = tf.data.Dataset.from_tensor_slices((images, labels))
  dataset = dataset.batch(batch_size)
  dataset = dataset.repeat(epochs)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  return dataset

class Basic_model(tf.keras.Model):
  """Small CNN classifier: Conv2D -> BatchNormalization -> Flatten -> Dense(softmax).

  ``compile`` and ``train_step`` are overridden so that training runs a
  hand-written ``tf.GradientTape`` step using the optimizer and loss given to
  ``compile``.
  """

  def __init__(self):
    super().__init__()
    self.kernel_init = tf.keras.initializers.glorot_normal()
    self.conv = tf.keras.layers.Conv2D(64, 3, kernel_initializer=self.kernel_init)
    self.bn = tf.keras.layers.BatchNormalization()
    self.flatten = tf.keras.layers.Flatten()
    # 10 softmax units -> per-class probabilities (CIFAR-10 has 10 classes).
    self.out = tf.keras.layers.Dense(10, activation='softmax')

  def call(self, input_tensor, training=None):
    """Forward pass.

    Args:
      input_tensor: Batch of images.
      training: Forwarded to BatchNormalization so it uses batch statistics while
        training and its moving averages at inference. ``None`` lets Keras infer
        the mode from how the model is invoked.

    Returns:
      Class probabilities from the softmax output layer.
    """
    x = self.conv(input_tensor)
    x = self.bn(x, training=training)
    x = self.flatten(x)
    return self.out(x)

  def compile(self, optimizer, loss, **kwargs):
    """Configure the model for training.

    The optimizer and loss are forwarded to ``tf.keras.Model.compile`` so Keras's
    own bookkeeping (``evaluate``, the saved training config) sees the real
    values instead of its defaults. Extra keyword arguments such as ``metrics``
    are forwarded unchanged.
    """
    super().compile(optimizer=optimizer, loss=loss, **kwargs)
    # Explicit references used by ``train_step``.
    self.optimizer = optimizer
    self.loss = loss

  def train_step(self, input_data):
    """Run one optimisation step on an ``(inputs, targets)`` batch.

    Returns:
      ``{"loss": loss_value}`` for the batch, which ``fit`` reports in its logs.
    """
    inputs, targets = input_data
    with tf.GradientTape() as tape:
      y_pred = self(inputs, training=True)
      loss_val = self.loss(targets, y_pred)
    grads = tape.gradient(loss_val, self.trainable_variables)
    self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
    return {"loss": loss_val}
  



def main():
  """Train ``Basic_model``, save the complete model, reload it with
  ``tf.keras.models.load_model`` and check the reloaded copy predicts the same."""
  batch_size = 128
  epochs = 1
  save_dir = 'Basic_model_complete/'

  (x_train, y_train), _ = tf.keras.datasets.cifar10.load_data()
  # ``fit`` handles the epoch loop itself, so the pipeline is not repeated.
  train_dataset = data_generator(x_train, y_train, batch_size, 1)

  model = Basic_model()
  optimizer = tf.keras.optimizers.Adam()
  loss = tf.keras.losses.SparseCategoricalCrossentropy()
  model.compile(optimizer=optimizer, loss=loss)
  # The dataset is already batched, so ``batch_size`` is not passed to ``fit``.
  model.fit(train_dataset, epochs=epochs)
  model.save(save_dir)

  # When loading the model for inference, we just need to load from the saved directory.
  new_model = tf.keras.models.load_model(save_dir)
  new_model.summary()
  # The reloaded model must carry the trained weights: compare predictions on a
  # real batch rather than discarding ``get_weights()``.
  sample = next(iter(train_dataset))[0]
  np.testing.assert_allclose(model(sample).numpy(), new_model(sample).numpy(),
                             rtol=1e-5, atol=1e-6)
  print("optimizer",new_model.optimizer)


if __name__ == '__main__':
  main()

INFO:tensorflow:Assets written to: Basic_model_complete/assets
Model: "basic_model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_2 (Conv2D)            multiple                  1792      
_________________________________________________________________
batch_normalization_2 (Batch multiple                  256       
_________________________________________________________________
flatten_2 (Flatten)          multiple                  0         
_________________________________________________________________
dense_2 (Dense)              multiple                  576010    
Total params: 578,058
Trainable params: 577,930
Non-trainable params: 128
_________________________________________________________________
optimizer <tensorflow.python.keras.optimizer_v2.adam.Adam object at 0x7f5cca66e550>
