# CNN-ViT Training
This notebook illustrates how to build and train a CNN-ViT model. It is configured to run on a TPU-hosted runtime on Google Colab.

# Preliminaries

Install required packages.

In [None]:
# Installs the project package directly from its GitHub repository; no branch, tag or commit is pinned in the URL.
!pip install git+https://github.com/Microsatellites-and-Space-Microsystems/pose_estimation_domain_gap --quiet

Provide access to Google Drive.

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

Set network name and directories.

In [None]:
import os

network_name = 'my_first_CNN_ViT'

# Glob patterns matching the training and validation TFRecord files
train_dataset_path = 'gs://.../*.record'
validation_dataset_path = 'gs://.../*.record'

# The trained weights are exported to Google Drive as <network_name>.h5
google_drive_base_dir = '/content/gdrive/MyDrive/'
weights_export_dir = f'{google_drive_base_dir}{network_name}.h5'

# Checkpoints are stored in a per-network folder
checkpoint_dir = f'gs://.../{network_name}/training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

Set seeds.

In [None]:
import tensorflow as tf
import numpy as np
import random as rnd

# Fix the seeds of Python's `random` module, NumPy's global generator and TensorFlow's global seed.
rnd.seed(2)
np.random.seed(3)
tf.random.set_seed(1)

Initialize the TPU.

In [None]:
import tensorflow as tf

try:
  # Called without arguments, the resolver tries to detect the TPU attached to this runtime.
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError as err:
  # RuntimeError is still a BaseException subclass, but unlike a bare BaseException
  # it is caught by ordinary `except Exception` handlers; the original cause is chained.
  raise RuntimeError(
      'ERROR: Not connected to a TPU runtime; select a TPU hardware accelerator '
      'in the Colab runtime settings and run this cell again.'
  ) from err

print('Connection to TPU server successful!')

# Connect to the TPU cluster, initialize it and create the distribution strategy
# used later to build and compile the model.
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
tpu_strategy = tf.distribute.TPUStrategy(tpu)

In [None]:
#A convinent way to provide access to Google Cloud Platform is to create a service account https://cloud.google.com/iam/docs/creating-managing-service-account-keys#iam-service-account-keys-create-console linked to the project
#The procedure will download a .json file 
#Replace the fields below with the information contained in the file

#If using TPU, it is also necessary to enable the TPU service account (service-[project_number]@cloud-tpu.iam.gserviceaccount.com) as an IAM user for the project

import json

data_all={
  "type": "service_account",
  "project_id": ,
  "private_key_id": ,
  "private_key": "-----BEGIN PRIVATE KEY-----\n...==\n-----END PRIVATE KEY-----\n",
  "client_email": "",
  "client_id": "",
  "auth_uri": "https://accounts.google.com/o/oauth2/auth",
  "token_uri": "https://oauth2.googleapis.com/token",
  "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
  "client_x509_cert_url": ""
}

parsed = json.dumps(data_all)

with open('/content/.config/application_default_credentials.json', 'w') as f:
  f.write(parsed)
!gcloud auth activate-service-account --key-file '/content/.config/application_default_credentials.json'

#Alternatively

#!gcloud auth login
#!gcloud config set project 'myproject' #set the project id here

#from google.colab import auth
#auth.authenticate_user()

# Dataset processing



In [None]:
import tensorflow_addons as tfa

#Load TFRecords files
def load_tf_records(filepath):
    """Create a TFRecord dataset from every file matching a glob pattern.

    Args:
        filepath: Glob pattern of the TFRecord files, expanded with
            tf.io.gfile.glob.

    Returns:
        A tf.data.TFRecordDataset over the matched files, read with an
        autotuned number of parallel readers and with the
        `experimental_deterministic` option set to True.
    """
    record_files = tf.io.gfile.glob(filepath)

    # The option is set to True, i.e. deterministic element order is requested.
    read_options = tf.data.Options()
    read_options.experimental_deterministic = True

    records = tf.data.TFRecordDataset(
        record_files,
        num_parallel_reads=tf.data.experimental.AUTOTUNE,
    )
    return records.with_options(read_options)


#Define TFRecord structure

def tf_records_file_features_description():
    """Describe the features stored in each TFRecord example.

    Returns:
        A dict mapping feature keys to scalar tf.io.FixedLenFeature specs:
        image metadata (int64 / string), the encoded image (string) and the
        float32 X/Y coordinates of the 11 keypoints A-I, L and M.
    """
    image_feature_description = {
        'image/actual_channels': tf.io.FixedLenFeature([], tf.int64),
        'image/height': tf.io.FixedLenFeature([], tf.int64),
        'image/width': tf.io.FixedLenFeature([], tf.int64),
        'image/filename': tf.io.FixedLenFeature([], tf.string),
        'image/encoded': tf.io.FixedLenFeature([], tf.string),
        'image/format': tf.io.FixedLenFeature([], tf.string),
    }

    # Keypoint labels: for each keypoint, an X entry followed by a Y entry.
    for keypoint_id in 'ABCDEFGHILM':
        for axis in ('X', 'Y'):
            key = f'image/object/kpts/{axis}_{keypoint_id}'
            image_feature_description[key] = tf.io.FixedLenFeature([], tf.float32)

    return image_feature_description

#Decode JPEG and resize (we will cache the output)
def decode_dataset(example_proto,target_image_height,target_image_width):
    """Parse one serialized example, decode its JPEG image and resize it.

    Args:
        example_proto: Serialized tf.train.Example read from a TFRecord file.
        target_image_height: Height of the resized image.
        target_image_width: Width of the resized image.

    Returns:
        A tuple (image, features): the bilinearly resized image and the full
        dict of parsed features (which still holds the keypoint labels).
    """
    parsed_example = tf.io.parse_single_example(
        example_proto, tf_records_file_features_description()
    )

    # channels=0 keeps the number of channels stored in the JPEG itself.
    decoded = tf.io.decode_jpeg(parsed_example['image/encoded'], channels=0)

    target_size = [target_image_height, target_image_width]
    resized = tf.image.resize(
        decoded,
        target_size,
        method=tf.image.ResizeMethod.BILINEAR,
        antialias=False,
    )

    return resized, parsed_example

def apply_augmentations(raw_image, features,target_image_height,target_image_width):
    """Randomly rotate a training sample and apply pixel-level augmentations.

    The image and its 11 keypoints are rotated by the same random angle drawn
    uniformly from [-pi, pi). Keypoints are rotated about the centre of the
    original image (size read from `features`) and then divided by the
    original width / height. The image is converted to RGB, augmented by
    pixel_level_augment, clipped to [0, 255] and rescaled as (x - 127) / 128.

    Args:
        raw_image: Decoded and resized image; it is passed to
            tf.image.grayscale_to_rgb, so a single channel is expected.
        features: Parsed example features (image size and keypoint labels).
        target_image_height: Height of `raw_image`.
        target_image_width: Width of `raw_image`.

    Returns:
        A tuple (image, labels) where labels is {'kpts_regressor': [...]}, a
        list of 22 values ordered X_A, Y_A, X_B, Y_B, ..., X_M, Y_M.
    """
    keypoint_ids = 'ABCDEFGHILM'

    # Size of the original (pre-resize) image, as stored in the record
    image_height = tf.cast(features['image/height'], dtype=tf.float32)
    image_width = tf.cast(features['image/width'], dtype=tf.float32)

    # Principal point, taken at the image centre
    cx = image_width / 2.0
    cy = image_height / 2.0

    # Keypoints expressed relative to the principal point
    centered_kpts = [
        (features[f'image/object/kpts/X_{kpt}'] - cx,
         features[f'image/object/kpts/Y_{kpt}'] - cy)
        for kpt in keypoint_ids
    ]

    rotation_angle = tf.random.uniform(
        shape=[], minval=tf.constant(-np.pi), maxval=tf.constant(np.pi), seed=5000
    )

    # Rotation matrix applied to the centred keypoints
    cos = tf.cos(rotation_angle)
    sin = tf.sin(rotation_angle)
    R = tf.reshape([cos, sin, -sin, cos], [2, 2])

    rotated_kpts = [
        rotate_and_normalize_landmarks(R, x, y, cx, cy, image_height, image_width)
        for x, y in centered_kpts
    ]
    # Flatten the (x, y) pairs into X_A, Y_A, X_B, Y_B, ...
    output_kpts = [coordinate for point in rotated_kpts for coordinate in point]

    # Rotate the image by the same angle, then convert it to RGB
    image = tfa.image.rotate(raw_image, rotation_angle)
    image = tf.image.grayscale_to_rgb(image)

    # Pixel-level augmentations: edit the function pixel_level_augment
    image = pixel_level_augment(image, target_image_height, target_image_width)
    image = tf.clip_by_value(image, 0, 255)

    # Rescale and fix the static shape
    image = (image - 127.00) / 128.00
    image = tf.reshape(image, [target_image_height, target_image_width, 3])

    return image, {'kpts_regressor': output_kpts}

def rotate_and_normalize_landmarks(R,xp,yp,cx,cy,image_height,image_width):
    """Rotate a centred keypoint and express it in normalized image coordinates.

    Args:
        R: 2x2 rotation matrix.
        xp, yp: Keypoint coordinates relative to the point (cx, cy).
        cx, cy: Offset added back to the rotated coordinates.
        image_height, image_width: Values the y / x coordinates are divided by.

    Returns:
        A tuple (x, y) with the rotated keypoint divided by the image width
        and height respectively.
    """
    rotated = tf.tensordot(R, tf.stack([xp, yp]), axes=1)

    x_normalized = (rotated[0] + cx) / image_width
    y_normalized = (rotated[1] + cy) / image_height

    return x_normalized, y_normalized


def pixel_level_augment(image,target_image_height,target_image_width): 
    """Apply the random pixel-level augmentation chain to an RGB image.

    Exactly one of equalize / invert / posterize / solarize is always applied
    (chosen by a uniform integer draw in [0, 4)). Brightness, contrast, blur
    and Gaussian noise are then each applied when an independent uniform draw
    in [0, 1) is below 0.5.

    Args:
        image: Image to augment.
        target_image_height: Image height, forwarded to add_gauss_noise.
        target_image_width: Image width, forwarded to add_gauss_noise.

    Returns:
        The augmented image.
    """

    def with_half_probability(current, seed, transform):
        # Apply `transform` when a uniform draw in [0, 1) falls below 0.5.
        draw = tf.random.uniform([], minval=0, maxval=1, seed=seed)
        return tf.cond(tf.less(draw, 0.5),
                       lambda: transform(current),
                       lambda: current)

    selector = tf.random.uniform([], maxval=4, dtype=tf.int32, seed=32)
    image = tf.case(
        [(tf.equal(selector, 0), lambda: equalize(image)),
         (tf.equal(selector, 1), lambda: invert(image)),
         (tf.equal(selector, 2), lambda: posterize(image))],
        default=lambda: solarize(image))

    image = with_half_probability(
        image, 49, lambda img: brightness(img, max_delta=0.5))
    image = with_half_probability(
        image, 76, lambda img: contrast(img, 0.1, 1.5))
    image = with_half_probability(
        image, 37, lambda img: blurring(img, sigma=1))
    image = with_half_probability(
        image, 42,
        lambda img: add_gauss_noise(img, target_image_height, target_image_width))

    return image

def brightness(image, max_delta):
  """Randomly shift the image brightness with tf.image.random_brightness (op seed 1).

  NOTE(review): the surrounding pipeline appears to work on 0-255 float
  images, for which a small max_delta would be a barely visible shift —
  confirm the intended scale.
  """
  return tf.image.random_brightness(image, max_delta, seed=1)

def contrast(image, min,max):
  """Randomly scale the image contrast by a factor drawn between `min` and `max` (op seed 2)."""
  return tf.image.random_contrast(image, lower=min, upper=max, seed=2)

def blurring(image,sigma):
  """Blur an image with a Gaussian filter.

  Args:
    image: Image tensor to blur.
    sigma: Standard deviation of the Gaussian kernel.

  Returns:
    The blurred image, filtered with the library's default filter shape.
  """
  # `sigma` must be passed by keyword: the second positional parameter of
  # tfa.image.gaussian_filter2d is `filter_shape`, so a positional sigma=1
  # selected a 1x1 kernel, which leaves the image unchanged.
  return tfa.image.gaussian_filter2d(image, sigma=sigma)

def noise(image, target_image_height,target_image_width):
  """Thin wrapper that forwards its arguments to add_gauss_noise."""
  noisy_image = add_gauss_noise(image, target_image_height, target_image_width)
  return noisy_image

def invert(image):
  """Return the absolute value of (255 - image), i.e. the negative of a 0-255 image."""
  return tf.abs(255 - image)

def add_gauss_noise(image, target_image_height,target_image_width):
  """Add zero-mean Gaussian noise with a random variance to an image.

  The variance is drawn uniformly from [0, 50); the noise tensor has shape
  (target_image_height, target_image_width, 3).

  Args:
    image: Image the noise is added to.
    target_image_height: Height of the noise tensor.
    target_image_width: Width of the noise tensor.

  Returns:
    The sum of the image and the sampled noise.
  """
  variance = tf.random.uniform([], minval=0, maxval=50, seed=52)
  noise_shape = [target_image_height, target_image_width, 3]

  gaussian_noise = tf.random.normal(noise_shape, 0, variance**0.5, seed=65)

  return image + gaussian_noise

def equalize(image):
  """source: https://github.com/tensorflow/models/blob/master/official/vision/ops/augment.py
  Histogram-equalize each of the three channels independently (TF version of PIL's Equalize).
  The result is returned as float32."""

  def build_lut(histogram, step):
    # Cumulative histogram shifted by step // 2 and normalized by step ...
    shifted_cdf = (tf.cumsum(histogram) + (step // 2)) // step
    # ... then shifted by one bin (prepending 0) and clipped to [0, 255].
    shifted_cdf = tf.concat([[0], shifted_cdf[:-1]], 0)
    return tf.clip_by_value(shifted_cdf, 0, 255)

  def scale_channel(im, c):
    """Equalize channel `c` of `im`."""
    channel = tf.cast(im[:, :, c], tf.int32)
    histogram = tf.histogram_fixed_width(channel, [0, 255], nbins=256)

    # The step is computed from the non-empty bins only.
    populated = tf.where(tf.not_equal(histogram, 0))
    populated_counts = tf.reshape(tf.gather(histogram, populated), [-1])
    step = (tf.reduce_sum(populated_counts) - populated_counts[-1]) // 255

    # A zero step leaves the channel untouched; otherwise every pixel is
    # mapped through the lookup table built from the full histogram.
    equalized = tf.cond(
        tf.equal(step, 0),
        lambda: channel,
        lambda: tf.gather(build_lut(histogram, step), channel),
    )
    return tf.cast(equalized, tf.float32)

  # Assumes 3 channels: each one is scaled independently, then restacked.
  return tf.stack([scale_channel(image, c) for c in range(3)], 2)

def solarize(image, threshold=128.):
  """source: https://github.com/tensorflow/models/blob/master/official/vision/ops/augment.py
  Keep pixels below `threshold` and replace every other pixel by 255 minus its value."""
  below_threshold = tf.less(image, threshold)
  inverted = 255. - image
  return tf.where(below_threshold, image, inverted)


def posterize(image,bits=2):
  """source: https://github.com/tensorflow/models/blob/master/official/vision/ops/augment.py
  Equivalent of PIL Posterize: keep only the `bits` most significant bits of each
  8-bit pixel value. The image is cast to uint8 for the bit shifts and returned as float32."""
  shift = 8 - bits
  as_uint8 = tf.cast(image, tf.uint8)
  # Shifting right then left clears the `shift` least significant bits.
  truncated = tf.bitwise.right_shift(as_uint8, shift)
  restored = tf.bitwise.left_shift(truncated, shift)
  return tf.cast(restored, tf.float32)


def map_validation_dataset(image, features, target_image_height,target_image_width):
    """Prepare one validation sample (no augmentation).

    Keypoints are divided by the original image width / height read from
    `features`; the image is converted to RGB, cast to float32 and rescaled
    as (x - 127) / 128.

    Args:
        image: Decoded and resized image; it is passed to
            tf.image.grayscale_to_rgb, so a single channel is expected.
        features: Parsed example features (image size and keypoint labels).
        target_image_height: Height of `image`.
        target_image_width: Width of `image`.

    Returns:
        A tuple (image, labels) where labels is {'kpts_regressor': [...]}, a
        list of 22 values ordered X_A, Y_A, X_B, Y_B, ..., X_M, Y_M.
    """
    image_height = tf.cast(features['image/height'], dtype=tf.float32)
    image_width = tf.cast(features['image/width'], dtype=tf.float32)

    # For every keypoint: X normalized by the width, then Y by the height
    output_kpts = [
        features[f'image/object/kpts/{axis}_{kpt}'] / extent
        for kpt in 'ABCDEFGHILM'
        for axis, extent in (('X', image_width), ('Y', image_height))
    ]

    image = tf.image.grayscale_to_rgb(image)
    image = tf.cast(image, tf.float32)

    # Rescale and fix the static shape
    image = (image - 127.00) / 128.00
    image = tf.reshape(image, [target_image_height, target_image_width, 3])

    return image, {'kpts_regressor': output_kpts}

# (Optional) Visualize the dataset
Use the following cells to visualize the dataset and check that everything is fine.

In [None]:
input_shape = [320, 512, 3]

height = input_shape[0]
width = input_shape[1]

AUTO = tf.data.AUTOTUNE

# Un-batched pipelines used only to inspect individual samples
train_dataset = (
    load_tf_records(train_dataset_path)
    .map(lambda x: decode_dataset(x, height, width), num_parallel_calls=AUTO)
    .map(lambda x, y: apply_augmentations(x, y, height, width), num_parallel_calls=AUTO)
)

validation_dataset = (
    load_tf_records(validation_dataset_path)
    .map(lambda x: decode_dataset(x, height, width), num_parallel_calls=AUTO)
    .map(lambda x, y: map_validation_dataset(x, y, height, width), num_parallel_calls=AUTO)
)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
# Positions of the x (even) and y (odd) coordinates inside the 22-value label vector
even = np.arange(0, 22, 2)
odd = np.arange(1, 22, 2)

for image, label in train_dataset.take(10):
  keypoints = label['kpts_regressor'].numpy()

  # Undo the (x - 127) / 128 rescaling and map to [0, 1] for display
  plt.imshow((image*128.0+127.0)/255.0)
  # Keypoints are normalized, so scale them back to pixel coordinates
  plt.plot(keypoints[even]*width, keypoints[odd]*height, '.')

  plt.show()
  print(label)

# Model building

Initialize the encoder (EfficientNet backbone + ViT).

In [None]:
from models_and_layers.efficientnet_lite import EfficientNetLiteB4
from models_and_layers.vit_layers import AddPositionEmbs, TransformerBlock

#Code adapted from https://github.com/faustomorales/vit-keras
#Licensed under Apache 2.0 license
#Removed classToken

def build_encoder(
    input_shape=(320, 512, 3),
    patch_size=4,
    num_layers=6,
    hidden_size=256,
    num_heads=8,
    mlp_dim=2048,
    dropout=0.1
):
    """Build the CNN + transformer encoder.

    The image goes through an EfficientNetLiteB4 backbone (imagenet weights,
    no top); the backbone output is embedded by a strided convolution,
    flattened into a token sequence, given position embeddings, processed by
    `num_layers` TransformerBlock layers, layer-normalized and averaged over
    the tokens.

    Args:
        input_shape: Size of the input images. Only the first two entries
            (height, width) are used; the channel count is fixed to 3.
        patch_size: Kernel size and stride of the embedding convolution
            applied to the backbone output ("valid" padding).
        num_layers: The number of transformer layers to use.
        hidden_size: The number of filters of the embedding convolution,
            i.e. the size of each token.
        num_heads: The number of transformer heads.
        mlp_dim: The `mlp_dim` argument passed to every TransformerBlock.
        dropout: The `dropout` argument passed to every TransformerBlock.

    Returns:
        A tf.keras Model mapping an image to the token-averaged encoding.
    """
    image_height, image_width = input_shape[0], input_shape[1]

    inputlayer = tf.keras.layers.Input(shape=(image_height, image_width, 3))

    # CNN backbone
    backbone = EfficientNetLiteB4(
        weights='imagenet',
        input_shape=(image_height, image_width, 3),
        include_top=False,
    )
    feature_map = backbone(inputlayer)

    # Patch embedding of the backbone output
    patch_embedding = tf.keras.layers.Conv2D(
        filters=hidden_size,
        kernel_size=patch_size,
        strides=patch_size,
        padding="valid",
        name="embedding",
        kernel_initializer=tf.keras.initializers.GlorotUniform(seed=1116),
    )
    tokens = patch_embedding(feature_map)

    # Flatten the two spatial dimensions into a single sequence axis
    num_tokens = tokens.shape[1] * tokens.shape[2]
    tokens = tf.keras.layers.Reshape((num_tokens, hidden_size))(tokens)

    tokens = AddPositionEmbs(name="Transformer/posembed_input")(tokens)
    for block_index in range(num_layers):
        # Each block returns a pair; only the first element is propagated.
        tokens, _ = TransformerBlock(
            num_heads=num_heads,
            mlp_dim=mlp_dim,
            dropout=dropout,
            name=f"Transformer/encoderblock_{block_index}",
        )(tokens)
    tokens = tf.keras.layers.LayerNormalization(
        epsilon=1e-6, name="Transformer/encoder_norm"
    )(tokens)

    # Average over the token axis to get one vector per image
    encoding = tf.keras.layers.GlobalAveragePooling1D()(tokens)

    return tf.keras.models.Model(inputs=inputlayer, outputs=encoding)

Initialize regression head.

In [None]:
class kpts_regressor(tf.keras.Model):
  """Regression head mapping an encoding to keypoint coordinates.

  The head is Dropout -> Dense(hidden_dim, gelu) -> Dropout ->
  Dense(hidden_dim // 2, gelu) -> Dense(2 * num_keypoints), i.e. one
  (x, y) pair per keypoint.

  Args:
    hidden_dim: Width of the first dense layer; the second one has half of it.
    num_keypoints: Number of keypoints to regress; the output has
      2 * num_keypoints values.
  """

  def __init__(self,hidden_dim,num_keypoints):
    super().__init__()

    self.hidden_dim = hidden_dim
    self.num_keypoints = num_keypoints
    self.basic_layers = tf.keras.Sequential(
        [
            tf.keras.layers.Dropout(0.1, seed=43),
            tf.keras.layers.Dense(
                self.hidden_dim,
                activation='gelu',
                kernel_initializer=tf.keras.initializers.GlorotUniform(seed=9001),
            ),
            tf.keras.layers.Dropout(0.1, seed=819),
            # Integer division: Dense expects an integer number of units.
            tf.keras.layers.Dense(
                self.hidden_dim // 2,
                activation='gelu',
                kernel_initializer=tf.keras.initializers.GlorotUniform(seed=901),
            ),
            # Output size follows num_keypoints instead of a hard-coded 22
            # (identical for the 11 keypoints used in this notebook).
            tf.keras.layers.Dense(
                2 * self.num_keypoints,
                kernel_initializer=tf.keras.initializers.GlorotUniform(seed=976),
                name='kpts',
            ),
        ]
    )

  def call(self, x):
    """Run the encoding `x` through the stack of layers."""
    return self.basic_layers(x)

Build the model.

In [None]:
#Vit tiny:
hidden_dim = 192
num_keypoints = 11
input_shape = [320, 512, 3]
inputlayer = tf.keras.layers.Input(shape=(input_shape[0], input_shape[1], 3))

# The model is built inside the TPU strategy scope
with tpu_strategy.scope():
  # Encoder output tensor for the notebook-level input layer
  test = build_encoder(
      input_shape=(320, 512, 3),
      patch_size=1,
      num_layers=1,
      hidden_size=hidden_dim,
      num_heads=3,
      mlp_dim=hidden_dim*3,
      dropout=0.1,
  )(inputlayer)
  encoder = tf.keras.models.Model([inputlayer], [test])

  # Regression head stacked on top of the encoder
  regressor_kpts = kpts_regressor(hidden_dim, num_keypoints)(encoder.output)
  network = tf.keras.models.Model([encoder.input], [regressor_kpts])

Visualize NN details.

In [None]:
network.summary()

In [None]:
tf.keras.utils.plot_model(network,show_shapes=True)

# Train the model

Dataset preprocessing.

In [None]:
batch_size=64
epochs = 80

height = input_shape[0]
width = input_shape[1]
AUTO=tf.data.AUTOTUNE

# Number of images in each dataset, used to derive the steps per epoch
num_train_images = 47966
num_validation_images = 2791 #sunlamp images

# Train dataset preparation
# Decoded and resized images are cached; shuffling and augmentation run after
# the cache so that they are applied afresh on every pass over the data.
all_train_record = (
    load_tf_records(train_dataset_path)
    .map(lambda x: decode_dataset(x, height, width), num_parallel_calls=AUTO)
    .cache()
    .shuffle(15000, seed=29)
    .map(lambda x, y: apply_augmentations(x, y, height, width), num_parallel_calls=AUTO)
)

# prefetch overlaps input preprocessing with the training step (it was
# already used for the validation pipeline below)
train_dataset = (
    all_train_record
    .batch(batch_size, drop_remainder=True)
    .repeat()
    .prefetch(AUTO)
)

# Validation dataset preparation: no augmentation, so whole batches are cached
test_dataset = (
    load_tf_records(validation_dataset_path)
    .map(lambda x: decode_dataset(x, height, width), num_parallel_calls=AUTO)
    .map(lambda x, y: map_validation_dataset(x, y, height, width), num_parallel_calls=AUTO)
    .batch(batch_size, drop_remainder=True)
    .cache()
    .repeat()
    .prefetch(AUTO)
)

# Floor division already yields whole steps (incomplete batches are dropped)
steps_per_epoch = num_train_images // batch_size
validation_steps = num_validation_images // batch_size

Compile the model.

In [None]:
# The cosine schedule decays the learning rate over the whole training run
total_steps = steps_per_epoch*epochs
base_lr = 1e-4

with tpu_strategy.scope():
  lr_schedule = tf.keras.optimizers.schedules.CosineDecay(base_lr, total_steps)
  optimizer = tfa.optimizers.AdamW(weight_decay=1e-8, learning_rate=lr_schedule)

  # The key matches the label key produced by the dataset mapping functions
  losses = {"kpts_regressor": 'MAE'}

  network.compile(optimizer=optimizer, loss=losses)

#Callbacks

#Learning rate callback
logger=tf.get_logger()
class LearningRateLoggingCallback(tf.keras.callbacks.Callback):
    """Log the optimizer's current learning rate at the end of every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Log the learning rate through the TensorFlow logger at INFO level.

        Args:
            epoch: Index of the epoch that just ended (unused).
            logs: Metrics dict supplied by Keras (unused); defaults to None
                rather than a mutable `{}` shared between calls.
        """
        # NOTE(review): _decayed_lr is a private optimizer method; confirm it
        # exists on the optimizer implementation in use.
        lr = self.model.optimizer._decayed_lr(tf.float32)
        # Let the logger format the message lazily instead of using `%` eagerly.
        logger.info("lr value = %s", lr)

#Backup and restore callback
# Training state is backed up to (and restored from) the checkpoint directory
backup_and_restore_callback = tf.keras.callbacks.BackupAndRestore(backup_dir=checkpoint_dir)

In [None]:
training_callbacks = [LearningRateLoggingCallback(), backup_and_restore_callback]

# Both datasets repeat indefinitely, so the epoch lengths are set explicitly
network.fit(
    train_dataset,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    validation_data=test_dataset,
    validation_steps=validation_steps,
    callbacks=training_callbacks,
    verbose=2,
)

# Export the trained weights to Google Drive
network.save_weights(weights_export_dir)