# Importing libraries

In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import csv
import time
import os
from datetime import datetime

In [None]:
import tensorflow as tf
from tensorflow.keras import regularizers
# Enable on-demand GPU memory allocation instead of reserving all GPU memory up front.
gpus = tf.config.experimental.list_physical_devices('GPU')
if not gpus:
    print("No compatible GPUs found")
else:
    try:
        # Memory growth has to be configured identically for every GPU.
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        # Memory growth can only be set before the GPUs have been initialized.
        print(err)


# Common constants

In [None]:
# Optimizer hyper-parameters.
learningRate = 0.0004
momentum = 0.9
batch_size = 16

# Weight decay, applied through Keras L2 kernel regularizers with coefficient
# weightDecay / learningRate.
# NOTE(review): Keras' l2(c) adds c*sum(w**2) to the loss, so plain SGD shrinks the
# weights by 2*lr*c = 2*weightDecay per step; confirm this scaling is the intended one.
weightDecay = 0.0
L2penalty = weightDecay/learningRate

In [None]:
# Shared keyword arguments for every Conv2D layer: stride 1, 'SAME' padding, ReLU,
# L2-regularised kernels initialised from N(0, 0.01^2). Bias settings stay at Keras defaults.
convOptions = dict(
    strides=1,
    padding='SAME',
    activation=tf.nn.relu,
    kernel_regularizer=regularizers.l2(L2penalty),
    kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.01, seed=None),
)

# Shared keyword arguments for every Conv2DTranspose layer: stride 2 with 'SAME' padding
# doubles the spatial size; unlike convOptions, the biases are L2-regularised as well.
convTransOptions = dict(
    strides=(2, 2),
    padding='SAME',
    activation=tf.nn.relu,
    kernel_regularizer=regularizers.l2(L2penalty),
    bias_regularizer=regularizers.l2(L2penalty),
    kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.01, seed=None),
)

# 2x2 max pooling with stride 2 halves each spatial dimension.
maxPoolOptions = dict(pool_size=2, strides=2, padding='SAME')

### PC specifics, paths etc.

In [None]:
# Machine-specific dataset and output locations; edit these when running on another PC.
MASKS_PATH = '/qarr/studia/magister/datasets/FlickrLogos-v2/classes/masks/'
INPUT_PATH = '/qarr/studia/magister/datasets/FlickrLogos-v2/classes/jpg/'
# Name under which the trained model is saved (and later loaded back); set to None to skip saving.
modelSaveName = "batch_normalized"
# One timestamp shared by the log and checkpoint directories, so a run's two directories match.
runTimestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
# `or ""` keeps the log path valid when modelSaveName is None (concatenating None raises TypeError).
LOGS_PATH = "/qarr/studia/magister/models/logs/" + (modelSaveName or "") + runTimestamp
MODEL_CHECKPOINT_PATH = "/qarr/studia/magister/models/" + runTimestamp

# Defining the model architecture

Firstly, clear the session to remove any lingering variables/models in memory from eager execution.

In [None]:
# Reset Keras' global state so that re-running the model-definition cells below starts from scratch.
tf.keras.backend.clear_session()

Defining the first part of the transcoder: the encoder, which reduces the spatial dimensions of the input in the VGG manner. Each of its five stages is two convolution layers with 3x3 kernels, followed by batch normalization and a MaxPool layer.

In [None]:
# Encoder input: the 256x256 RGB target image in which the logo is searched for.
LInputTarget = tf.keras.Input(dtype = tf.float32, shape = [256, 256, 3], name = 'Target')
layersEncoder = [LInputTarget]
# Output of every MaxPool stage (128, 64, 32, 16 and 8 pixels wide); the decoder
# consumes them in reverse order as skip connections.
transcoderInputs = []
with tf.name_scope("Encoder"):
    # Integers are Conv2D filter counts; None marks a BatchNormalization + MaxPool stage.
    filtersNumber=[64, 64, None, 128, 128, None, 256, 256, None, 512, 512, None, 512, 512, None]
    for fn in filtersNumber:
        if fn is None:
            layersEncoder.append(tf.keras.layers.BatchNormalization()(layersEncoder[-1]))
            layersEncoder.append(tf.keras.layers.MaxPool2D(**maxPoolOptions)(layersEncoder[-1]))
            transcoderInputs.append(layersEncoder[-1])
        else:
            layersEncoder.append(tf.keras.layers.Conv2D(filters = fn, kernel_size=3, **convOptions)(layersEncoder[-1]))
    encoderOutput = layersEncoder[-1]
    
    # Stand-alone encoder model, used here only for inspection (summary / plot).
    modelEncoder = tf.keras.Model(
        inputs=LInputTarget, 
        outputs=encoderOutput,
        name="Encoder model"
    )

In [None]:
# Layer-by-layer output shapes and parameter counts of the encoder.
modelEncoder.summary()

In [None]:
# Render the encoder graph (with tensor shapes) to unet_encoder.png.
tf.keras.utils.plot_model(modelEncoder, "unet_encoder.png", show_shapes=True)

Defining the conditional branch, which encodes the input query with a VGG-like stack down to 1x1x512. This vector is then tiled and correlated with the consecutive steps of the decoder part.

In [None]:
# Conditional-branch input: the 64x64 RGB query (a logo crop).
LInputQuery  = tf.keras.Input(dtype = tf.float32, shape = [64, 64, 3], name = 'Query')
layersConditionalEncoder = [LInputQuery]
with tf.name_scope("Conditional"):
    # Integers are Conv2D filter counts; None marks a BatchNormalization + MaxPool stage.
    # The five pooling stages take 64x64 down to 2x2.
    filtersNumber=[32, 32, None, 64, 64, None, 128, None, 256, None, 512, None]
    for fn in filtersNumber:
        if fn is None:
            layersConditionalEncoder.append(tf.keras.layers.BatchNormalization()(layersConditionalEncoder[-1]))
            layersConditionalEncoder.append(
                tf.keras.layers.MaxPool2D(**maxPoolOptions)(layersConditionalEncoder[-1])
            )
        else:
            layersConditionalEncoder.append(
                tf.keras.layers.Conv2D(filters = fn, kernel_size=3, **convOptions)(layersConditionalEncoder[-1])
            )
    # Todo replace with fully connected x2
    # 2x2 convolution with stride 2 (convOptions without its own 'strides' entry):
    # reduces the 2x2 map to the 1x1x512 latent vector.
    layersConditionalEncoder.append(
                tf.keras.layers.Conv2D(filters = 512, 
                                       kernel_size=2, 
                                       strides=2, 
                                       **{k:v for k,v in convOptions.items() if k != 'strides'}
                                      )(layersConditionalEncoder[-1])
            )
    #layersConditionalEncoder.append(tf.keras.layers.BatchNormalization()(layersConditionalEncoder[-1]))
    conditionalEncoderOutput = layersConditionalEncoder[-1]

    # Stand-alone conditional model, used here only for inspection (summary / plot).
    modelConditional = tf.keras.Model(
        inputs=LInputQuery,
        outputs=conditionalEncoderOutput, 
        name="Latent Representation Encoder"
    )

In [None]:
# Layer-by-layer output shapes and parameter counts of the conditional branch.
modelConditional.summary()

In [None]:
# Render the conditional-branch graph to its own file; "unet_encoder.png" is already
# used for the encoder plot and would be overwritten.
tf.keras.utils.plot_model(modelConditional, "unet_conditional.png", show_shapes=True)

Defining the decoder branch, which combines the tiled conditional-branch output with the consecutive steps of the encoder part. This lets it detect the pattern learned by the conditional branch at each scale of the encoded resolution.

Defining a custom Softmax layer, because the current TensorFlow 2.3.1 implementation has a bug.

In [None]:
from tensorflow.python.keras.utils import tf_utils
class Softmax(tf.keras.layers.Layer):
    """Softmax activation layer, used below with a list of (spatial) axes.

    Replacement for `tf.keras.layers.Softmax` that delegates directly to
    `tf.keras.activations.softmax(inputs, axis=self.axis)`.
    """

    def __init__(self, axis=-1, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis

    def call(self, inputs):
        return tf.keras.activations.softmax(inputs, axis=self.axis)

    def get_config(self):
        # Base-layer config first, then this layer's own `axis` entry.
        return {**super().get_config(), 'axis': self.axis}

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        # Softmax does not change the tensor shape.
        return input_shape

In [None]:
with tf.name_scope("Transcoder"):
    layersTransDecoder = []
    # Output of every Conv2DTranspose stage; each is merged into the next, higher-resolution stage.
    upsampledLayers = []
    # UpSampling2D factors that tile the 1x1 conditional vector up to the resolution of
    # each encoder MaxPool output, deepest (8x8) first.
    tiles = [(8,8), (16,16), (32, 32), (64, 64), (128,128)]
    # Per stage: (1x1 merge conv, 3x3 conv, 3x3 conv, transposed conv) filter counts.
    # None = no merge conv: the first stage has no previous upsampled output to merge with.
    filters = [(None, 512, 512, 512), 
               (512, 512, 512, 512), 
               (256, 256, 256, 256), 
               (128, 128, 128, 128),
               (64, 64, 64, 64)]
    
    for tile, encodedInput, fs in zip(tiles, reversed(transcoderInputs), filters):
        # Tiling output from conditional encoder
        layersTransDecoder.append(tf.keras.layers.UpSampling2D(size=tile)(conditionalEncoderOutput))
        # Concatenating tiled output with reverse order of encoder MaxPool layers
        layersTransDecoder.append(tf.keras.layers.Concatenate()([layersTransDecoder[-1], encodedInput]))
        # Reducing the concatenation's channel count with a 1x1 conv (if configured) and
        # joining it with the previous stage's upsampled result
        if fs[0] is not None:
            layersTransDecoder.append(tf.keras.layers.Conv2D(fs[0], kernel_size=1, **convOptions)(layersTransDecoder[-1]))
            layersTransDecoder.append(tf.keras.layers.Concatenate()([layersTransDecoder[-1], upsampledLayers[-1]]))
        # Transdecoding encoded values with Conv2D layers
        layersTransDecoder.append(tf.keras.layers.Conv2D(fs[1], kernel_size=3, **convOptions)(layersTransDecoder[-1]))
        layersTransDecoder.append(tf.keras.layers.Conv2D(fs[2], kernel_size=3, **convOptions)(layersTransDecoder[-1]))
        
        layersTransDecoder.append(tf.keras.layers.BatchNormalization()(layersTransDecoder[-1]))
        # Upsampling (x2) with transposed convolution filters, saving the layer for next cycle merging
        layersTransDecoder.append(tf.keras.layers.Conv2DTranspose(fs[3], kernel_size=3, **convTransOptions)(layersTransDecoder[-1]))
        layersTransDecoder.append(tf.keras.layers.BatchNormalization()(layersTransDecoder[-1]))
        upsampledLayers.append(layersTransDecoder[-1])
    
    layersTransDecoder.append(Softmax(axis=[1,2])(layersTransDecoder[-1])) #Experimental: softmax over both spatial axes
    # The output layer must be linear: the model is trained with
    # BinaryCrossentropy(from_logits=True), and a ReLU here would clamp every logit
    # to >= 0, i.e. every predicted probability to >= 0.5.
    outputOptions = dict(convOptions, activation=None)
    unetOutput = tf.keras.layers.Conv2D(filters = 1, kernel_size=3, **outputOptions, name="Output")(layersTransDecoder[-1])
    #unetOutput = Softmax(axis=[1,2])(unetOutput)
    layersTransDecoder.append(unetOutput)

    modelUnet = tf.keras.Model(inputs=[LInputTarget, LInputQuery], outputs=[unetOutput], name="Unet")

In [None]:
# Layer-by-layer output shapes and parameter counts of the full two-input model.
modelUnet.summary()

In [None]:
# Render the full model graph (with tensor shapes) to unet_model.png.
tf.keras.utils.plot_model(modelUnet, "unet_model.png", show_shapes=True)

### Compiling the model with SGD optimizer and BinaryCrossentropy loss

Weight decay is to be implemented using L2 regularization.

In [None]:
def custom_bce_loss(y_true, y_pred, ratio = 0.5):
    """Class-balanced binary cross-entropy computed from probabilities.

    Pixels whose target is >= 0.5 ("white") jointly receive `ratio` of the total
    weight, and the remaining ("black") pixels receive `1 - ratio`, regardless of
    how many pixels of each kind the batch contains.

    Args:
        y_true: target masks, values presumably in [0, 1].
        y_pred: predicted probabilities (not logits), same shape as y_true.
        ratio: share of the total weight assigned to white pixels.

    Returns:
        Element-wise weighted cross-entropy (non-negative), same shape as y_pred;
        the Keras loss machinery performs the reduction.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    dtype = y_pred.dtype.base_dtype
    # Cast so that y_true/y_pred arithmetic never mixes dtypes.
    y_true = tf.cast(tf.convert_to_tensor(y_true), dtype)
    epsilon = tf.keras.backend.epsilon()

    # tf.size also works when the batch dimension is unknown (None) while tracing,
    # where np.prod(y_true.shape) fails.
    pixels = tf.cast(tf.size(y_true), dtype)
    mask = tf.cast(tf.greater_equal(y_true, 0.5), dtype)
    whites = tf.reduce_sum(mask)  # number of white pixels (mask is 0/1)
    whites_weight = ratio * pixels / (whites + epsilon)
    blacks_weight = (1 - ratio) * pixels / (pixels - whites + epsilon)
    # whites_weight where mask == 1, blacks_weight where mask == 0.
    weights = mask * (whites_weight - blacks_weight) + blacks_weight

    # Cross entropy from probabilities; clipping keeps both logs finite.
    y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
    bce = y_true * tf.math.log(y_pred + epsilon)
    bce += (1 - y_true) * tf.math.log(1 - y_pred + epsilon)
    return -bce * weights

In [None]:
# SGD with momentum. Weight decay is not configured on the optimizer; it comes from the
# L2 kernel regularizers in convOptions / convTransOptions (weightDecay is currently 0.0).
optimizer = tf.keras.optimizers.SGD(
    learning_rate=learningRate,
    momentum=momentum,
    nesterov=False,
    name="SGD",
)

# from_logits=True: the loss applies the sigmoid to the raw network output itself.
# custom_bce_loss (defined above) is an alternative that expects probabilities instead.
modelUnet.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
)

# Preparing the dataset

In [None]:
# One class per sub-directory of INPUT_PATH, except the 'no-logo' class.
# Sorted: os.listdir order is arbitrary, and class indices feed the seeded
# train/validation split, which would otherwise differ between machines.
classes = sorted(
    o for o in os.listdir(INPUT_PATH)
    if o != 'no-logo' and os.path.isdir(os.path.join(INPUT_PATH, o))
)

In [None]:
# Class (brand) directory names found under INPUT_PATH, 'no-logo' excluded.
print(classes)

## Loading pictures into numpy arrays, slicing and resizing

In [None]:
images = dict()   # class name -> list of 256x256x3 float32 images scaled by 1/255
targets = dict()  # class name -> list of 256x256 float32 masks scaled by 1/255
queries = dict()  # class name -> list of 64x64x3 float32 logo crops, one per image
start_time = time.time()

def rescale(nparray, scale=255.0):
    """Return `nparray` converted to float32 and divided by `scale`."""
    return np.array(nparray, dtype=np.float32)/scale

def _imread_checked(path, flags=cv2.IMREAD_COLOR):
    """cv2.imread that raises FileNotFoundError instead of silently returning None."""
    img = cv2.imread(path, flags)
    if img is None:
        raise FileNotFoundError(f'cv2.imread could not read {path}')
    return img

for c in classes:
    root_input = os.path.join(INPUT_PATH, c)
    root_masks = os.path.join(MASKS_PATH, c)
    images[c] = list()
    targets[c] = list()
    queries[c] = list()

    # Sorted: os.listdir order is arbitrary, and the image index within a class is
    # what the seeded train/validation split refers to.
    for f in sorted(os.listdir(root_input)):
        img = cv2.cvtColor(_imread_checked(os.path.join(root_input, f)), cv2.COLOR_BGR2RGB)
        mask = _imread_checked(os.path.join(root_masks, f'{f}.mask.merged.png'), cv2.IMREAD_GRAYSCALE)

        with open(os.path.join(root_masks, f'{f}.bboxes.txt')) as csvfile:
            bboxread = csv.reader(csvfile, delimiter=' ')
            next(bboxread)  # skip the header line
            bboxes = list(bboxread)

        # Take only the first bounding box available in the image as its query.
        if bboxes:
            x, y, w, h = [int(i) for i in bboxes[0]]
            imgslice = cv2.resize(img[y:y+h, x:x+w], dsize=(64, 64), interpolation=cv2.INTER_CUBIC)
            queries[c].append(rescale(imgslice, 255.0))

        img = cv2.resize(img, dsize=(256, 256), interpolation=cv2.INTER_CUBIC)
        mask = cv2.resize(mask, dsize=(256, 256), interpolation=cv2.INTER_CUBIC)

        images[c].append(rescale(img, 255.0))
        targets[c].append(rescale(mask, 255.0))

    # Query i must come from image i: the split below skips (i, i) pairs as self-matches,
    # and a missing query would shift every later index.
    if len(queries[c]) != len(images[c]):
        raise ValueError(f'{c}: {len(images[c])} images but {len(queries[c])} queries '
                         f'(some image has no bounding box)')


end_time = time.time()

print(f'Time taken: {end_time-start_time} seconds')


## Viewing some loaded examples

In [None]:
# Show the first loaded image of the first class (nothing if no class was loaded).
firstClassImages = next(iter(images.values()), None)
if firstClassImages is not None:
    plt.imshow(firstClassImages[0])

In [None]:
# Show the first loaded mask of the first class in grayscale (nothing if none loaded).
firstClassTargets = next(iter(targets.values()), None)
if firstClassTargets is not None:
    plt.imshow(firstClassTargets[0], cmap='gist_gray')

In [None]:
# Show the first loaded query crop of the first class (nothing if none loaded).
firstClassQueries = next(iter(queries.values()), None)
if firstClassQueries is not None:
    plt.imshow(firstClassQueries[0])

### Analyzing the dataset size

In [None]:
# Pair counts exclude an image paired with its own query (the split skips those),
# matching all_cases = nclasses*nlogos*(nlogos-1) used below.
for c in classes:
    print(f'{c:>12}: {len(images[c])} logos: {len(queries[c]):<3} pairs: {len(images[c])*(len(queries[c])-1)}')
total_images = sum(len(images[c]) for c in classes)
total_queries = sum(len(queries[c]) for c in classes)
total_pairs = sum(len(images[c])*(len(queries[c])-1) for c in classes)
print(f'{"total":<12}: {total_images} logos: {total_queries} pairs: {total_pairs}')

Comparison of size between a static dataset generated from triplets and triplets generated dynamically from the pairs held in memory. Given the size of the triplet set, it is infeasible to generate this dataset and save it on an SSD in order to speed up training.

In [None]:
# Memory (GB) of a pre-generated triplet dataset vs. keeping only the loaded arrays in memory.
est_classes = len(classes)
est_images_per_class = sum(len(images[c]) for c in classes) // est_classes
# Elements per sample: query 64x64x3 + image 256x256x3 + mask 256x256x1.
est_sample_elements = 64*64*3 + 256*256*4
# rescale() stores everything as float32 (4 bytes), not 8-byte floats.
est_bytes_per_element = np.dtype(np.float32).itemsize
est_static_gb = est_classes*est_images_per_class*(est_images_per_class-1)*est_sample_elements*est_bytes_per_element/1e9
est_dynamic_gb = est_classes*est_images_per_class*est_sample_elements*est_bytes_per_element/1e9
print(f"{est_static_gb} GB vs {est_dynamic_gb} GB")

Considering triplets count in the case of taking all samples from given image, not only one per image:
```
      adidas: 70 logos: 120 pairs: 8400
        aldi: 70 logos: 106 pairs: 7420
       apple: 70 logos: 76 pairs: 5320
       becks: 70 logos: 100 pairs: 7000
         bmw: 70 logos: 74 pairs: 5180
   carlsberg: 70 logos: 108 pairs: 7560
      chimay: 70 logos: 112 pairs: 7840
    cocacola: 70 logos: 130 pairs: 9100
      corona: 70 logos: 83 pairs: 5810
         dhl: 70 logos: 123 pairs: 8610
    erdinger: 70 logos: 105 pairs: 7350
        esso: 70 logos: 87 pairs: 6090
       fedex: 70 logos: 94 pairs: 6580
     ferrari: 70 logos: 73 pairs: 5110
        ford: 70 logos: 76 pairs: 5320
     fosters: 70 logos: 98 pairs: 6860
      google: 70 logos: 83 pairs: 5810
     guiness: 70 logos: 98 pairs: 6860
    heineken: 70 logos: 103 pairs: 7210
          hp: 70 logos: 112 pairs: 7840
       milka: 70 logos: 197 pairs: 13790
      nvidia: 70 logos: 114 pairs: 7980
    paulaner: 70 logos: 102 pairs: 7140
       pepsi: 70 logos: 178 pairs: 12460
 rittersport: 70 logos: 204 pairs: 14280
       shell: 70 logos: 96 pairs: 6720
      singha: 70 logos: 83 pairs: 5810
   starbucks: 70 logos: 95 pairs: 6650
stellaartois: 70 logos: 87 pairs: 6090
      texaco: 70 logos: 88 pairs: 6160
    tsingtao: 70 logos: 109 pairs: 7630
         ups: 70 logos: 90 pairs: 6300
total       : 2240 logos: 3404 pairs: 238280
```

# Defining the triplets dataset generator from binary pairs

Additional gymnastics to:
 - fix the random seed to get reproducible results
 - reserve 2 unique images per class and all pairs within those for validation set,
 - fill the validation dataset to meet the assumed ratio to train data

In [None]:
rnd = np.random.RandomState(13371337)

nclasses = len(classes)
nlogos = sum([len(images[c]) for c in classes])//nclasses
# The index bookkeeping below assumes every class holds exactly nlogos images.
assert all(len(images[c]) == nlogos for c in classes), "all classes must contain the same number of images"
# Ordered (image, query) pairs over all classes, excluding an image paired with its own query.
all_cases = nclasses*nlogos*(nlogos-1)
# About 10% of all pairs, rounded down to a whole number of batches.
valid_cases = ((all_cases//batch_size)//10)*batch_size
valid_unique_n = 2
# Pairs involving at least one of the valid_unique_n reserved images of a class, summed
# over classes: 2*((nlogos-1) + (nlogos-2) + ...) per class.
valid_unique_pairs = nclasses*2*np.sum(range(nlogos-1, nlogos-valid_unique_n-1, -1))
assert valid_cases >= valid_unique_pairs, "validation set too small to hold all pairs of the reserved images"

# int32 indices: int8 overflows as soon as a class holds more than 127 images.
train_data_permutations = np.zeros((all_cases-valid_cases, 3), dtype=np.int32)
valid_data_permutations = np.zeros((valid_cases, 3), dtype=np.int32)

# Training-slot positions at which one extra pair is diverted to validation, to fill it up
# to valid_cases. One more than needed is drawn so the largest can be overwritten with an
# unreachable sentinel, keeping skips[skipIt] valid after the last real skip.
skips = np.sort(rnd.choice(all_cases-valid_cases, size=valid_cases-valid_unique_pairs+1, replace=False))
skips[-1] = all_cases

trainIt = 0
validIt = 0
skipIt = 0
for c_i in range(nclasses):
    # Images of this class reserved exclusively for validation.
    valid_unique = rnd.choice(nlogos, size=valid_unique_n, replace=False)
    for n_i in range(nlogos):
        for l_i in range(nlogos):
            if n_i == l_i:
                continue
            if n_i in valid_unique or l_i in valid_unique:
                valid_data_permutations[validIt] = (c_i, n_i, l_i)
                validIt += 1
            elif skips[skipIt] == trainIt:
                # Diverted pair: trainIt is not advanced, so the next pair takes this slot.
                valid_data_permutations[validIt] = (c_i, n_i, l_i)
                validIt += 1
                skipIt += 1
            else:
                train_data_permutations[trainIt] = (c_i, n_i, l_i)
                trainIt += 1

# Every slot of both arrays must have been filled exactly once.
assert trainIt == len(train_data_permutations) and validIt == len(valid_data_permutations)

train_data_permutations = rnd.permutation(train_data_permutations)
valid_data_permutations = rnd.permutation(valid_data_permutations)

Decimating the training and validation sets in order to iterate faster for debugging purposes; to be removed in the final approach.

In [None]:
# Keep the first 1/factor of each split (1 = keep everything), truncated to whole batches.
factor = 1
train_keep = (len(train_data_permutations) // factor // batch_size) * batch_size
valid_keep = (len(valid_data_permutations) // factor // batch_size) * batch_size
train_data_permutations = train_data_permutations[:train_keep]
valid_data_permutations = valid_data_permutations[:valid_keep]

Viewing the data shape to make sure it is right after the decimation.

In [None]:
def describe(x):
    """Return the shape of `x` as a string, recursing into containers without `.shape`.

    Nested sequences are rendered as bracketed, comma-separated lists of their
    elements' descriptions, e.g. "[(2, 3), [(4,)]]".
    """
    try:
        shape = x.shape
    except AttributeError:
        parts = ', '.join(describe(item) for item in x)
        return '[' + parts + ']'
    return str(shape)

# Only a cell's last expression is displayed automatically, so print each value explicitly.
print(describe(train_data_permutations))
print(describe(valid_data_permutations))
print(all_cases)

### Defining the generator function

In [None]:
def dataset_permutations_generator(batch_size, data_permutations, repeat=True, shuffle=True):
    """Yield ((images, queries), targets) batches for the given index triplets.

    Args:
        batch_size: number of samples per yielded batch.
        data_permutations: rows of (class_index, image_index, query_index) indexing the
            module-level `classes`, `images`, `queries` and `targets`.
        repeat: if True, loop over the data forever; otherwise stop after one pass.
        shuffle: if True, reshuffle the rows (global NumPy RNG) before every pass.

    Yields:
        ((images, queries), targets) with shapes (batch_size, 256, 256, 3),
        (batch_size, 64, 64, 3) and (batch_size, 256, 256, 1). Samples left over at the
        end of a pass carry into the next pass; with repeat=False they are dropped.
    """
    # Nothing can ever be yielded; without this guard repeat=True would spin forever.
    if len(data_permutations) == 0:
        return
    s = 0
    outimage = []
    outquery = []
    outtarget = []
    loop = True
    while loop:
        if shuffle:
            data_permutations = np.random.permutation(data_permutations)
        for class_number, image_number, query_number in data_permutations:
            c = classes[class_number]
            outimage.append(images[c][image_number])
            outquery.append(queries[c][query_number])
            outtarget.append(targets[c][image_number])
            s += 1
            if s >= batch_size:
                s = 0
                yield (np.reshape(outimage, (batch_size, 256, 256, 3)),
                       np.reshape(outquery, (batch_size, 64, 64, 3))
                      ), np.reshape(outtarget, (batch_size, 256, 256, 1))
                outimage = []
                outquery = []
                outtarget = []
        # Decide after each pass whether to continue; this must be inside the while
        # loop, otherwise repeat=False never stops the generator.
        loop = repeat

### Testing the generator

In [None]:
# Draw one unshuffled single-sample batch to sanity-check the generator output.
testgen = dataset_permutations_generator(batch_size=1, data_permutations=train_data_permutations, shuffle=False)
tdat = next(testgen)

In [None]:
# Show the target image of the test batch's only sample.
(tdatImages, tdatQueries), tdatTargets = tdat
plt.imshow(tdatImages[0])

## Defining the validation and training datasets as tf.data.Dataset using generators

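I've been forced to use a batch size smaller than 32 due to the memory limits of my GPU (Nvidia GeForce GTX 1070 with 8 GB RAM). A value of 24 was tried (the commented-out cell below), and `batch_size` is currently set to 16 in the common constants.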

In [None]:
#batch_size = 24

In [None]:
# Validation batches, produced by the (shuffling, repeating) permutation generator.
unetValidDataset = (
    tf.data.Dataset.from_generator(
        dataset_permutations_generator,
        output_types=((tf.float32, tf.float32), tf.float32),
        output_shapes=(((batch_size, 256, 256, 3), (batch_size, 64, 64, 3)),
                       (batch_size, 256, 256, 1)),
        args=[batch_size, valid_data_permutations],
    )
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [None]:
# Training batches, produced by the (shuffling, repeating) permutation generator.
unetTrainDataset = (
    tf.data.Dataset.from_generator(
        dataset_permutations_generator,
        output_types=((tf.float32, tf.float32), tf.float32),
        output_shapes=(((batch_size, 256, 256, 3), (batch_size, 64, 64, 3)),
                       (batch_size, 256, 256, 1)),
        args=[batch_size, train_data_permutations],
    )
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Training the model
Creating callbacks for data visualisation and model saving. The checkpoint callback can be enabled by uncommenting it in the `fit` method call. A custom callback is also defined to log gradients on a selected example batch and view them in TensorBoard.

In [None]:
logRawGradients = False

class GradientsLoggerTBCallback(tf.keras.callbacks.TensorBoard):
    """TensorBoard callback that additionally logs per-variable gradient histograms.

    At the beginning of every `histogram_freq`-th epoch, gradients of the compiled loss
    are computed on one fixed reference batch and written as histograms to the
    training writer.

    Args:
        gradient_reference: one `(inputs, targets)` batch used for every gradient computation.
        logRawGradients: if True, also keep every gradient list in `self.gradient_logs`.
        *args, **kwargs: forwarded to `tf.keras.callbacks.TensorBoard`.
    """
    def __init__(self, gradient_reference, logRawGradients = True, *args, **kwargs):
        super(GradientsLoggerTBCallback, self).__init__(*args, **kwargs)
        self._gradient_ref = gradient_reference
        self.gradient_logs = []
        self.logRawGradients = logRawGradients
        # Counter of on_epoch_begin calls (starts at 1). Deliberately not called `_epoch`:
        # NOTE(review): the TensorBoard base class keeps its own private `_epoch`
        # (at least in TF 2.2/2.3), which this counter used to overwrite.
        self._epochs_begun = 1
        self.once = True  # currently unused in this notebook

    def _get_gradient(self):
        """Return d(compiled loss)/d(trainable variables) on the reference batch."""
        with tf.GradientTape() as tape:
            y_pred = self.model(self._gradient_ref[0], training=True)  # Forward pass
            loss = self.model.compiled_loss(y_true=self._gradient_ref[1], y_pred=y_pred)

        # Compute gradients
        trainable_vars = self.model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        return gradients

    def _log_gradients(self, epoch):
        """Compute gradients on the reference batch and write one histogram per variable."""
        # TensorBoard's private writer API changed after TF 2.2. Compare (major, minor)
        # numerically: as strings, "10" <= "2" is True, which sent TF 2.10+ down the old branch.
        major, minor = (int(part) for part in tf.version.VERSION.split('.')[:2])
        if (major, minor) <= (2, 2):
            writer = self._get_writer(self._train_run_name)
        else:
            writer = self._train_writer
        gradients = self._get_gradient()
        if self.logRawGradients:
            self.gradient_logs.append(gradients)

        with writer.as_default():
            # Getting names from model.trainable_weights
            for weights, grads in zip(self.model.trainable_weights, gradients):
                tf.summary.histogram(
                    weights.name.replace(':', '_') + '_grads', data=grads, step=epoch)
            writer.flush()

    def on_epoch_begin(self, epoch, logs=None):
        """Run TensorBoard's epoch-begin hook, then log gradients every histogram_freq epochs."""
        # Must forward to the *begin* hook: calling on_epoch_end here ran TensorBoard's
        # end-of-epoch logging a second time, at the start of every epoch.
        super(GradientsLoggerTBCallback, self).on_epoch_begin(epoch, logs=logs)

        self._epochs_begun += 1
        if self.histogram_freq and epoch % self.histogram_freq == 0:
            self._log_gradients(epoch)


# Create a TensorBoard callback whose gradient reference is one (shuffled) validation batch.
tboard_callback = GradientsLoggerTBCallback(list(unetValidDataset.take(1))[0],
                                            logRawGradients = logRawGradients,
                                            log_dir = LOGS_PATH,
                                            histogram_freq = 1)
                                            #profile_batch = '1,3')
#tboard_callback = tf.keras.callbacks.TensorBoard(log_dir = LOGS_PATH, histogram_freq = 1)

In [None]:
# Save the full model to MODEL_CHECKPOINT_PATH whenever val_loss reaches a new minimum,
# checking once per epoch.
checkpointSettings = dict(
    monitor="val_loss",
    mode="min",
    save_best_only=True,
    save_weights_only=False,
    save_freq="epoch",
    options=None,
)
callbackCheckpoint = tf.keras.callbacks.ModelCheckpoint(MODEL_CHECKPOINT_PATH, **checkpointSettings)

### Calling the fit method

Adjust epochs parameter before calling.

In [None]:
# Whole batches per epoch; batch_size is not passed to fit because the generator
# already yields complete batches.
trainSteps = len(train_data_permutations) // batch_size
validSteps = len(valid_data_permutations) // batch_size
try:
    modelUnet.fit(
        unetTrainDataset,
        epochs=5,
        steps_per_epoch=trainSteps,
        validation_data=unetValidDataset,
        validation_steps=validSteps,
        callbacks=[tboard_callback],  # add callbackCheckpoint here to keep the best model
    )
except KeyboardInterrupt:
    # Interrupting the kernel stops training but keeps the weights learned so far.
    print("Interrupted")

In [None]:
# Persist the trained model; setting modelSaveName to None disables saving.
if modelSaveName is not None:
    modelUnet.save(filepath=modelSaveName)

# Viewing and analysing the results 

In [None]:
# A one-batch view of the training dataset (not yet materialised).
example = unetTrainDataset.take(count=1)

In [None]:
# Materialise the batch first and predict on exactly that batch. Iterating the
# dataset twice (once in predict, once in as_numpy_iterator) re-runs the shuffling
# generator, so the prediction and the displayed inputs came from different samples.
example = list(example.as_numpy_iterator())
# from_tensors keeps the ((images, queries), targets) structure that predict expects.
example_result = modelUnet.predict(tf.data.Dataset.from_tensors(example[0]))

In [None]:
# Left to right: query, target image, ground-truth mask, predicted mask (first sample).
fig, axes = plt.subplots(1, 4)
exampleInputs, exampleTargets = example[0]
exampleImages, exampleQueries = exampleInputs
axes[0].imshow(exampleQueries[0])
axes[1].imshow(exampleImages[0])
axes[2].imshow(np.reshape(exampleTargets[0], (256, 256)), cmap='gist_gray')
axes[3].imshow(np.reshape(example_result[0], (256, 256)), cmap='gist_gray')

Viewing the mean and variance of each gradient recorded at the start of the first epoch, i.e. for the initial weights before any training step.

In [None]:
# Mean and variance of every gradient in the first list recorded by the callback,
# labelled with the matching variable name (same order as modelUnet.trainable_weights).
if logRawGradients and tboard_callback.gradient_logs:
    firstGradients = tboard_callback.gradient_logs[0]
    weightNames = [tw.name for tw in modelUnet.trainable_weights]
    print('\n'.join(
        f"{name} {grad.shape}:\n gradient mean: {np.mean(grad)} variance: {np.var(grad)}"
        for grad, name in zip(firstGradients, weightNames)
    ))

# Loading the saved net 

In [None]:
# Reload the saved model. The custom layer class is also named "Softmax", so without
# custom_objects Keras may resolve the saved class name to its built-in
# tf.keras.layers.Softmax, which is exactly the implementation this notebook avoids.
if modelSaveName is not None:
    modelUnet = tf.keras.models.load_model(modelSaveName, custom_objects={'Softmax': Softmax})