# Adversarial Training LRNs

This notebook illustrates our method for adversarial training of LRNs. It is configured to run on a hosted TPU runtime on Google Colab.

# Preliminaries

Install required packages.

In [None]:
# Install the repository's package; the `models_and_layers` modules imported further down presumably come from it.
# NOTE(review): append @<tag-or-commit> to the URL so re-runs install exactly the same code.
!pip install git+https://github.com/Microsatellites-and-Space-Microsystems/pose_estimation_domain_gap --quiet

Provide access to Google Drive.

In [None]:
# Mount Google Drive at /content/gdrive; google_drive_base_dir in the next cell points inside this mount.
from google.colab import drive
drive.mount('/content/gdrive')

Set network name and dataset directories.

In [None]:
import os

# Base name used for the exported weights file and the checkpoint sub-directory.
network_name='my_first_LRN'

#Directories to train and validation datasets
# Glob patterns over TFRecord files in a Cloud Storage bucket; replace the '...' placeholders.
train_dataset_path='gs://.../*.record'
validation_dataset_path='gs://.../*.record'

#Directory for saving trained weights
google_drive_base_dir='/content/gdrive/MyDrive/'
# Despite the name, this is the full path of the .h5 weights file, not a directory.
weights_export_dir=google_drive_base_dir+network_name+'.h5'

#Directory for checkpoints
# Cloud Storage location; replace the '...' placeholder with the bucket name.
checkpoint_dir = 'gs://.../'+network_name+'/training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

Set seeds.

In [None]:
import tensorflow as tf
import numpy as np
import random as rnd

# Seed Python's, NumPy's and TensorFlow's global RNGs. The TensorFlow ops later in the
# notebook also pass op-level `seed=` values, which are combined with this global seed.
rnd.seed(242)
np.random.seed(312)
tf.random.set_seed(112)

Initialize the TPU.

In [None]:
try:
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection
  print('Connection to TPU server successfull!')
except ValueError as err:
  # Raise an ordinary Exception subclass (not a bare BaseException) and chain the
  # resolver's original error so the actual cause stays visible in the traceback.
  raise RuntimeError('ERROR: Not connected to a TPU runtime; select a TPU under Runtime > Change runtime type, then re-run the notebook from the top.') from err

# Connect to the TPU, initialize it and create the distribution strategy whose
# scope() the model-building cells use.
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
tpu_strategy = tf.distribute.TPUStrategy(tpu)

To train the NNs on Cloud TPUs, the dataset must be stored in a Google Cloud Storage bucket, and the TPU must be given access to that bucket.

In [None]:
#A convinent way to provide access to Google Cloud Platform is to create a service account https://cloud.google.com/iam/docs/creating-managing-service-account-keys#iam-service-account-keys-create-console linked to the project
#The procedure will download a .json file 
#Replace the fields below with the information contained in the file

#If using TPU, it is also necessary to enable the TPU service account (service-[project_number]@cloud-tpu.iam.gserviceaccount.com) as an IAM user for the project

import json

data_all={
  "type": "service_account",
  "project_id": ,
  "private_key_id": ,
  "private_key": "-----BEGIN PRIVATE KEY-----\n...==\n-----END PRIVATE KEY-----\n",
  "client_email": "",
  "client_id": "",
  "auth_uri": "https://accounts.google.com/o/oauth2/auth",
  "token_uri": "https://oauth2.googleapis.com/token",
  "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
  "client_x509_cert_url": ""
}

parsed = json.dumps(data_all)

with open('/content/.config/application_default_credentials.json', 'w') as f:
  f.write(parsed)
!gcloud auth activate-service-account --key-file '/content/.config/application_default_credentials.json'

#Alternatively

#!gcloud auth login
#!gcloud config set project 'myproject' #set the project id here

#from google.colab import auth
#auth.authenticate_user()

# Initialize the Swin-based LRN

Initialize the NN encoder. To try different backbones modify the imported model in the first line of the following cell.

In [None]:
from models_and_layers.tfswin import SwinTransformerTiny224 as transformerEncoder
with tpu_strategy.scope():

  def get_encoder(input_shape):
    """Build the Swin-based encoder as a standalone Keras model.

    Maps an (input_shape, input_shape, 3) image to one pooled feature vector: the
    `transformerEncoder` backbone imported above, without its classification head,
    followed by global average pooling.
    """
    image_in = tf.keras.layers.Input(shape=(input_shape, input_shape, 3))
    backbone_out = transformerEncoder(include_top=False)(image_in)
    pooled = tf.keras.layers.GlobalAveragePooling2D(name='avg_pool')(backbone_out)
    return tf.keras.models.Model(inputs=image_in, outputs=pooled)

Initialize discriminator and regressor heads.

In [None]:
class Discriminator(tf.keras.Model):
  """Domain head: an MLP mapping encoder features to a single output value.

  Layers: Dense(4*hidden_dim, gelu) -> Dense(hidden_dim, gelu) -> Dense(1, no activation).
  The last layer has no activation, so the output is presumably a logit consumed by a
  from_logits loss (TODO confirm in the loss cell).

  Args:
    hidden_dim: width of the encoder features; sets the hidden layer sizes.
    name: Keras name of the model. Defaults to 'discriminator', the key used for this
      output in the dataset label dicts.
    **kwargs: forwarded to tf.keras.Model.
  """

  def __init__(self, hidden_dim, name='discriminator', **kwargs):
    # An explicit name keeps the model output called 'discriminator' even when the class is
    # instantiated again in the same session (re-running a cell, or building the other
    # backbone variant). Keras would otherwise auto-name later instances 'discriminator_1', ...
    super().__init__(name=name, **kwargs)

    self.hidden_dim = hidden_dim
    self.basic_layers = tf.keras.Sequential(
      [tf.keras.layers.Dense(self.hidden_dim*4,activation='gelu',kernel_initializer=tf.keras.initializers.GlorotUniform(seed=1509)),
       tf.keras.layers.Dense(self.hidden_dim,activation='gelu',kernel_initializer=tf.keras.initializers.GlorotUniform(seed=9)),
       tf.keras.layers.Dense(1,name='cls',kernel_initializer=tf.keras.initializers.GlorotUniform(seed=3412)),
      ]
    )

  def call(self, x):
    return self.basic_layers(x)

class Regressor(tf.keras.Model):
  """Keypoint head: an MLP mapping encoder features to 2*num_keypoints coordinates.

  Layers: Dense(4*hidden_dim, gelu) -> Dense(hidden_dim, gelu) -> Dense(2*num_keypoints, linear).

  Args:
    hidden_dim: width of the encoder features; sets the hidden layer sizes.
    num_keypoints: number of keypoints; the output holds one (x, y) pair per keypoint.
    name: Keras name of the model. Defaults to 'regressor', the key used for this output
      in the dataset label dicts.
    **kwargs: forwarded to tf.keras.Model.
  """

  def __init__(self, hidden_dim, num_keypoints, name='regressor', **kwargs):
    # Explicit name for the same reason as in Discriminator: avoid 'regressor_1' on re-instantiation.
    super().__init__(name=name, **kwargs)

    self.hidden_dim = hidden_dim
    self.num_keypoints = num_keypoints
    self.basic_layers = tf.keras.Sequential(
      [tf.keras.layers.Dense(hidden_dim*4,activation='gelu',name='kpts1',kernel_initializer=tf.keras.initializers.GlorotUniform(seed=121)),
       tf.keras.layers.Dense(hidden_dim,activation='gelu',name='kpts2',kernel_initializer=tf.keras.initializers.GlorotUniform(seed=432)),
       tf.keras.layers.Dense(num_keypoints*2,activation='linear',name='kpts',kernel_initializer=tf.keras.initializers.GlorotUniform(seed=3454)),
      ]
    )

  def call(self, x):
    return self.basic_layers(x)

Build the model. 

In [None]:
input_shape = 224  #Assumed squared
hidden_dim=768     #Output size of the encoder
num_keypoints = 11 #Satellite's keypoints

# Both heads read the encoder's pooled feature vector. Building inside the TPU strategy
# scope creates the variables as distributed (replicated) variables.
with tpu_strategy.scope(): 
  encoder=get_encoder(input_shape)
  discriminator = Discriminator(hidden_dim)(encoder.output)
  regressor=Regressor(hidden_dim,num_keypoints)(encoder.output)
  # One image input, two outputs: [domain prediction, keypoint coordinates].
  network=tf.keras.models.Model([encoder.input], [discriminator,regressor])



Visualize the NN details.

In [None]:
# Print a layer-by-layer summary of the assembled network.
network.summary()

In [None]:
# Render the network graph with tensor shapes (Keras needs pydot and graphviz for this).
tf.keras.utils.plot_model(network,show_shapes=True)

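# Initialize the EfficientNet-based LRN

This section is an alternative to the Swin-based LRN above. Both sections define `network`, so the section run last determines the model that is trained.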

In [None]:
from models_and_layers.efficientnet import EfficientNetV1B5

class get_encoder(tf.keras.Model):
  """EfficientNetV1-B5 encoder producing one `hidden_dim`-wide feature vector per image.

  Pipeline: ImageNet-pretrained EfficientNetV1B5 without classifier (num_classes=0)
  -> 1x1 convolution to `hidden_dim` channels -> global average pooling.

  Args:
    hidden_dim: number of output features (matches the heads' expected input width).
    input_shape: side length of the square (input_shape, input_shape, 3) input image.
  """

  def __init__(self, hidden_dim, input_shape):
    super().__init__()

    self.hidden_dim = hidden_dim
    # The encoder does not use the number of keypoints, so it no longer reads (or stores)
    # the notebook-global `num_keypoints`, which was a hidden cross-cell dependency.
    self.basic_layers = tf.keras.Sequential([
      EfficientNetV1B5(num_classes=0,input_shape=(input_shape,input_shape,3),pretrained="imagenet"),
      tf.keras.layers.Conv2D(self.hidden_dim,1,kernel_initializer=tf.keras.initializers.GlorotUniform(seed=231)),
      tf.keras.layers.GlobalAveragePooling2D(name='avg_pool')])

  def call(self, x):
    return self.basic_layers(x)

In [None]:
class Discriminator(tf.keras.Model):
  """Domain head: an MLP mapping encoder features to a single output value.

  Same architecture as the Swin-section head: Dense(4*hidden_dim, gelu) ->
  Dense(hidden_dim, gelu) -> Dense(1, no activation). The output is presumably a logit
  consumed by a from_logits loss (TODO confirm in the loss cell).

  Args:
    hidden_dim: width of the encoder features; sets the hidden layer sizes.
    name: Keras name of the model. Defaults to 'discriminator', the key used for this
      output in the dataset label dicts.
    **kwargs: forwarded to tf.keras.Model.
  """

  def __init__(self, hidden_dim, name='discriminator', **kwargs):
    # An explicit name keeps the model output called 'discriminator' even if the Swin section
    # (or this cell) already created an instance in this session. Keras would otherwise
    # auto-name this one 'discriminator_1', which no longer matches the label dict key.
    super().__init__(name=name, **kwargs)

    self.hidden_dim = hidden_dim
    self.basic_layers = tf.keras.Sequential(
      [tf.keras.layers.Dense(self.hidden_dim*4,activation='gelu',kernel_initializer=tf.keras.initializers.GlorotUniform(seed=1509)),
       tf.keras.layers.Dense(self.hidden_dim,activation='gelu',kernel_initializer=tf.keras.initializers.GlorotUniform(seed=9)),
       tf.keras.layers.Dense(1,name='cls',kernel_initializer=tf.keras.initializers.GlorotUniform(seed=3412)),
      ]
    )

  def call(self, x):
    return self.basic_layers(x)

class Regressor(tf.keras.Model):
  """Keypoint head: an MLP mapping encoder features to 2*num_keypoints coordinates.

  Layers: Dense(4*hidden_dim, gelu) -> Dense(hidden_dim, gelu) -> Dense(2*num_keypoints, linear).

  Args:
    hidden_dim: width of the encoder features; sets the hidden layer sizes.
    num_keypoints: number of keypoints; the output holds one (x, y) pair per keypoint.
    name: Keras name of the model. Defaults to 'regressor', the key used for this output
      in the dataset label dicts.
    **kwargs: forwarded to tf.keras.Model.
  """

  def __init__(self, hidden_dim, num_keypoints, name='regressor', **kwargs):
    # Explicit name for the same reason as in Discriminator: avoid 'regressor_1' on re-instantiation.
    super().__init__(name=name, **kwargs)

    self.hidden_dim = hidden_dim
    self.num_keypoints = num_keypoints
    self.basic_layers = tf.keras.Sequential(
      [tf.keras.layers.Dense(hidden_dim*4,activation='gelu',name='kpts1',kernel_initializer=tf.keras.initializers.GlorotUniform(seed=121)),
       tf.keras.layers.Dense(hidden_dim,activation='gelu',name='kpts2',kernel_initializer=tf.keras.initializers.GlorotUniform(seed=432)),
       tf.keras.layers.Dense(num_keypoints*2,activation='linear',name='kpts',kernel_initializer=tf.keras.initializers.GlorotUniform(seed=3454)),
      ]
    )

  def call(self, x):
    return self.basic_layers(x)

In [None]:
# Build the model (EfficientNet variant):

input_shape = 224  #Assumed squared
hidden_dim=768     #Output size of the encoder
num_keypoints = 11 #Satellite's keypoints
input=tf.keras.layers.Input(shape=(input_shape, input_shape, 3))

with tpu_strategy.scope():
  encoder=get_encoder(hidden_dim,input_shape)
  # Run the backbone once and feed the same features to both heads. Calling
  # encoder(input) separately for each head adds two backbone nodes to the functional
  # graph, so the (weight-shared) backbone would be computed twice per forward pass.
  features = encoder(input)
  discriminator = Discriminator(hidden_dim)(features)
  regressor=Regressor(hidden_dim,num_keypoints)(features)
  # One image input, two outputs: [domain prediction, keypoint coordinates].
  network=tf.keras.models.Model([input], [discriminator,regressor])


Visualize NN details.

In [None]:
# Print a layer-by-layer summary of the assembled network.
network.summary()

In [None]:
# Render the network graph with tensor shapes (Keras needs pydot and graphviz for this).
tf.keras.utils.plot_model(network,show_shapes=True)

# Dataset processing
The following cells contain all the functions to preprocess the dataset.

In [None]:
import tensorflow_addons as tfa

#Load TFRecords files

def load_tf_records(filepath):
    """Create a TFRecordDataset over every file matching the glob pattern `filepath`.

    Files are read in parallel, but the dataset option `experimental_deterministic = True`
    keeps the element order reproducible. File names are sorted, so that order does not
    depend on the listing order returned by the filesystem.

    Args:
        filepath: glob pattern, e.g. 'gs://bucket/train/*.record'.

    Returns:
        A tf.data.Dataset of serialized records.

    Raises:
        FileNotFoundError: if the pattern matches no file. Without this check, a
            placeholder or mistyped path would silently produce an empty dataset.
    """
    # Keep a deterministic order despite parallel reads (reproducibility alongside the seeds).
    options = tf.data.Options()
    options.experimental_deterministic = True

    filenames = sorted(tf.io.gfile.glob(filepath))
    if not filenames:
        raise FileNotFoundError('No TFRecord files match %r' % (filepath,))

    dataset = tf.data.TFRecordDataset(filenames,num_parallel_reads=tf.data.experimental.AUTOTUNE)
    return dataset.with_options(options)

#Define TFRecord structure

def tf_records_file_features_description():
    """Return the parsing spec for one serialized record.

    Every entry is a scalar FixedLenFeature:
      * image metadata: actual_channels, height, dataset_class, width (int64);
      * filename, encoded image bytes and format (string);
      * GT bounding box xmin, xmax, ymin, ymax (float32);
      * X and Y coordinates of the 11 keypoints A-I, L, M (float32).
    """
    int_keys = ('image/actual_channels', 'image/height', 'image/dataset_class', 'image/width')
    string_keys = ('image/filename', 'image/encoded', 'image/format')
    bbox_keys = tuple('image/object/bbox/' + side for side in ('xmin', 'xmax', 'ymin', 'ymax'))
    # Keypoint letters skip J and K; for every letter the X entry precedes the Y entry.
    keypoint_keys = tuple('image/object/kpts/' + axis + '_' + letter
                          for letter in 'ABCDEFGHILM'
                          for axis in ('X', 'Y'))

    spec = {key: tf.io.FixedLenFeature([], tf.int64) for key in int_keys}
    spec.update((key, tf.io.FixedLenFeature([], tf.string)) for key in string_keys)
    spec.update((key, tf.io.FixedLenFeature([], tf.float32)) for key in bbox_keys + keypoint_keys)
    return spec

#Decode JPEG first (we will cache the output)

def decode_dataset(example_proto,image_size):
    """Parse one serialized record and decode its JPEG payload.

    `image_size` is accepted so all map functions share a call shape; it is not used here.

    Returns:
        (decoded uint8 image tensor, dict of parsed features)
    """
    parsed_features = tf.io.parse_single_example(
        example_proto, tf_records_file_features_description())

    # channels=0 keeps whatever channel count is stored in the JPEG itself.
    decoded = tf.io.decode_jpeg(parsed_features['image/encoded'], channels=0)
    return decoded, parsed_features

#Apply augmentations during training

def apply_augmentations(raw_image, features,target_image_size):
    """Randomly rotate, crop and photometrically augment one training sample.

    Steps:
      1. Rotate the image by a uniform random angle in [-pi, pi] about the image centre.
      2. Rotate the GT bbox corners the same way and take their axis-aligned bounds.
      3. Build a square crop around it, randomly enlarged/shifted by scale_and_translate.
      4. Crop, resize to target_image_size x target_image_size, apply pixel-level
         augmentations (dataset_class 1 only), clip to [0, 255], convert to RGB, rescale to [-1, 1].
      5. Rotate the keypoints with the same matrix and normalise them by the crop side.

    Args:
        raw_image: decoded image from decode_dataset. Assumed single-channel, since it
            later goes through tf.image.grayscale_to_rgb (TODO confirm for every dataset).
        features: parsed feature dict (see tf_records_file_features_description).
        target_image_size: output side length in pixels.

    Returns:
        (image, labels): labels['discriminator'] is 0.0 for dataset_class 1 and 1.0
        otherwise. labels['regressor'] holds the 22 normalised keypoint coordinates
        followed by the raw dataset_class (23 values), which is used later to split the
        dataset by domain.
    """

    #Recover image features
    image_height=tf.cast(features['image/height'],dtype=tf.float32)
    image_width=tf.cast(features['image/width'],dtype=tf.float32)
    dataset_class=tf.cast(features['image/dataset_class'],dtype=tf.float32)

    xmin=features['image/object/bbox/xmin']
    ymin=features['image/object/bbox/ymin']
    xmax=features['image/object/bbox/xmax']
    ymax=features['image/object/bbox/ymax']

    #Principal point
    # (taken as the image centre; the rotation below is about this point)
    cx=image_width/2.0
    cy=image_height/2.0

    #The image will be rotated wrt the center point 
    # Keypoints are moved to centre-relative coordinates so R can be applied directly.
    X_A=features['image/object/kpts/X_A']-cx
    Y_A=features['image/object/kpts/Y_A']-cy
    X_B=features['image/object/kpts/X_B']-cx
    Y_B=features['image/object/kpts/Y_B']-cy
    X_C=features['image/object/kpts/X_C']-cx
    Y_C=features['image/object/kpts/Y_C']-cy
    X_D=features['image/object/kpts/X_D']-cx
    Y_D=features['image/object/kpts/Y_D']-cy
    X_E=features['image/object/kpts/X_E']-cx
    Y_E=features['image/object/kpts/Y_E']-cy
    X_F=features['image/object/kpts/X_F']-cx
    Y_F=features['image/object/kpts/Y_F']-cy
    X_G=features['image/object/kpts/X_G']-cx
    Y_G=features['image/object/kpts/Y_G']-cy
    X_H=features['image/object/kpts/X_H']-cx
    Y_H=features['image/object/kpts/Y_H']-cy
    X_I=features['image/object/kpts/X_I']-cx
    Y_I=features['image/object/kpts/Y_I']-cy
    X_L=features['image/object/kpts/X_L']-cx
    Y_L=features['image/object/kpts/Y_L']-cy
    X_M=features['image/object/kpts/X_M']-cx
    Y_M=features['image/object/kpts/Y_M']-cy

    #Random rotation angle
    rotation_angle= tf.random.uniform(
        shape=[], minval=tf.constant(-np.pi), maxval=tf.constant(np.pi),seed=5000
    )
    
    #Rotation matrix
    # R = [[cos, sin], [-sin, cos]] acts on centre-relative (x, y). It must move points in
    # the same direction as tfa.image.rotate moves pixels below, so keypoints stay on the image.
    cos = tf.cos(rotation_angle)
    sin = tf.sin(rotation_angle)
    R=tf.reshape([cos, sin, -sin,cos],[2,2])
    
    #Rotate the bounding box
    # Columns are the 4 corners (xmin,ymin), (xmin,ymax), (xmax,ymin), (xmax,ymax), centre-relative.
    q=tf.matmul(R,tf.reshape([xmin-cx,xmin-cx,xmax-cx,xmax-cx,
                              ymin-cy,ymax-cy,ymin-cy,ymax-cy],[2,4]))

    #tl = top left, bl = bottom left, tr = top right, br = bottom right
    xtl=q[0,0]
    ytl=q[1,0]
    xbl=q[0,1]
    ybl=q[1,1]
    xtr=q[0,2]
    ytr=q[1,2]
    xbr=q[0,3]
    ybr=q[1,3]
    
    #Recover rotated bbox coordinates in original image frame
    # Axis-aligned bounds of the rotated corners, shifted back to absolute pixel coordinates.
    
    xmin_rotated=tf.reduce_min([xtl,xbl,xtr,xbr])+cx
    xmax_rotated=tf.reduce_max([xtl,xbl,xtr,xbr])+cx
    ymin_rotated=tf.reduce_min([ytl,ybl,ytr,ybr])+cy
    ymax_rotated=tf.reduce_max([ytl,ybl,ytr,ybr])+cy      

    #Clip the values between 0 and the original image dimensions
    xmin_rotated=tf.maximum(xmin_rotated,0.0)
    xmax_rotated=tf.minimum(xmax_rotated,image_width)
    ymin_rotated=tf.maximum(ymin_rotated,0.0)
    ymax_rotated=tf.minimum(ymax_rotated,image_height)

    #Reconstruct the rotated (and clipped) GT bbox
    #Width, height and center coordinates
    gt_w=xmax_rotated-xmin_rotated
    gt_h=ymax_rotated-ymin_rotated
    x_c=(xmax_rotated+xmin_rotated)/2.0
    y_c=(ymax_rotated+ymin_rotated)/2.0

    #The cropping region is made square to avoid image distortion
    #In general we take the largest size
    bbox_side=tf.cond(tf.greater(gt_w,gt_h),lambda: gt_w, lambda: gt_h)
    
    #Consider the case where the bbox is greater than the image height
    # The side is capped at image_height and the box is centred vertically on the image.
    bbox_side=tf.cond(tf.greater(bbox_side,image_height),lambda: image_height,lambda: bbox_side)
    y_c=tf.cond(tf.equal(bbox_side,image_height),lambda: cy, lambda: y_c)
    
    #Move the bbox vertices
    # Only when the side equals image_height: rebuild a square of that side around (x_c, y_c).
    ymin_rotated=tf.cond(tf.equal(bbox_side,image_height),lambda: y_c-image_height/2,lambda: ymin_rotated)
    ymax_rotated=tf.cond(tf.equal(bbox_side,image_height),lambda: y_c+image_height/2,lambda: ymax_rotated)
    xmin_rotated=tf.cond(tf.equal(bbox_side,image_height),lambda: x_c-image_height/2,lambda: xmin_rotated)
    xmax_rotated=tf.cond(tf.equal(bbox_side,image_height),lambda: x_c+image_height/2,lambda: xmax_rotated)

    #The bounding box is randomly re-scaled and translated
    [xmin_new, ymin_new, bbox_side]=scale_and_translate(xmin_rotated,xmax_rotated,ymin_rotated,ymax_rotated,bbox_side,x_c,y_c,target_image_size,image_height,image_width)
    
    #Crop to bounding box requires integers
    ymin_new=tf.floor(ymin_new)
    xmin_new=tf.floor(xmin_new)
    bbox_side=tf.floor(bbox_side)
    
    # Shrink the side if the square would extend past the bottom or right image border.
    bbox_side=tf.cond(tf.greater(ymin_new+bbox_side,image_height),lambda: tf.floor(image_height-ymin_new), lambda: bbox_side)
    bbox_side=tf.cond(tf.greater(xmin_new+bbox_side,image_width),lambda: tf.floor(image_width-xmin_new), lambda: bbox_side)
    
    #Rotate image
    image=tfa.image.rotate(raw_image, rotation_angle)

    #Crop image
    image=tf.image.crop_to_bounding_box(image,
                                        tf.cast(ymin_new,tf.int32),
                                        tf.cast(xmin_new,tf.int32),
                                        tf.cast(bbox_side,tf.int32),
                                        tf.cast(bbox_side,tf.int32),
                                        )
    
    #Resize image to fit CNN requirements
    # The output is float32, still in the [0, 255] intensity range.
    image=tf.image.resize(image,
                          [target_image_size,target_image_size],
                          method=tf.image.ResizeMethod.BILINEAR,
                          antialias=False
    )
    
   
    #Rotation of all keypoints and coordinates normalization
    # Each keypoint is rotated by R, moved into crop coordinates and divided by the crop side.
    [X_A,Y_A] = rotate_and_normalize_landmarks(R,X_A,Y_A,cx,cy,xmin_new,ymin_new,bbox_side)
    [X_B,Y_B] = rotate_and_normalize_landmarks(R,X_B,Y_B,cx,cy,xmin_new,ymin_new,bbox_side)
    [X_C,Y_C] = rotate_and_normalize_landmarks(R,X_C,Y_C,cx,cy,xmin_new,ymin_new,bbox_side)
    [X_D,Y_D] = rotate_and_normalize_landmarks(R,X_D,Y_D,cx,cy,xmin_new,ymin_new,bbox_side)
    [X_E,Y_E] = rotate_and_normalize_landmarks(R,X_E,Y_E,cx,cy,xmin_new,ymin_new,bbox_side)
    [X_F,Y_F] = rotate_and_normalize_landmarks(R,X_F,Y_F,cx,cy,xmin_new,ymin_new,bbox_side)
    [X_G,Y_G] = rotate_and_normalize_landmarks(R,X_G,Y_G,cx,cy,xmin_new,ymin_new,bbox_side)
    [X_H,Y_H] = rotate_and_normalize_landmarks(R,X_H,Y_H,cx,cy,xmin_new,ymin_new,bbox_side)
    [X_I,Y_I] = rotate_and_normalize_landmarks(R,X_I,Y_I,cx,cy,xmin_new,ymin_new,bbox_side)
    [X_L,Y_L] = rotate_and_normalize_landmarks(R,X_L,Y_L,cx,cy,xmin_new,ymin_new,bbox_side)
    [X_M,Y_M] = rotate_and_normalize_landmarks(R,X_M,Y_M,cx,cy,xmin_new,ymin_new,bbox_side)
    
    #We will use the last entry of this vector to interleave the dataset
    # (the training pipeline filters on it: 1 synthetic, 2 lightbox, 3 sunlamp)
    output_data = [X_A, Y_A, X_B, Y_B, X_C, Y_C, X_D, Y_D, X_E,Y_E,X_F,Y_F, X_G, Y_G, X_H, Y_H, X_I, Y_I, X_L, Y_L, X_M, Y_M, dataset_class]
    output_data = tf.reshape(output_data,[23])

    #Further augmentations: edit the function pixel_level_augment
    #Augment only synthetic images
    image = tf.cond(tf.equal(dataset_class,1.0), lambda: pixel_level_augment(image,target_image_size), lambda: image)

    #Augment all images
    #image = pixel_level_augment(image,target_image_size)

    image = tf.clip_by_value(image,0,255)
    
    #To RGB
    image=tf.image.grayscale_to_rgb(image)

    #Rescale
    # [0, 255] -> [-1, 1]
    image = (image / 127.5) - 1.0

    image = tf.reshape(image, [target_image_size, target_image_size, 3])

    #We make no difference here between sunlamp and lightbox, we consider only synthetic / real
    # Discriminator target: 0.0 for synthetic (class 1), 1.0 for any real class.
    dataset_class=tf.cond(tf.equal(dataset_class,1.0), lambda: 0.0, lambda: 1.0)

    return image, {'discriminator': dataset_class, 'regressor': output_data}

def rotate_and_normalize_landmarks(R,xp,yp,cx,cy,xmin_new, ymin_new,bbox_side):
    """Rotate one keypoint and express it in normalised crop coordinates.

    Args:
        R: 2x2 rotation matrix.
        xp, yp: keypoint coordinates relative to the rotation centre (cx, cy).
        cx, cy: rotation centre in absolute image coordinates.
        xmin_new, ymin_new: top-left corner of the crop in absolute coordinates.
        bbox_side: side length of the square crop.

    Returns:
        (x, y) of the rotated keypoint measured from the crop corner, divided by bbox_side.
    """
    rotated = tf.tensordot(R, tf.stack([xp, yp]), axes=1)
    x_in_crop = rotated[0] + cx - xmin_new
    y_in_crop = rotated[1] + cy - ymin_new
    return x_in_crop / bbox_side, y_in_crop / bbox_side

def scale_and_translate(xmin_rotated,xmax_rotated,ymin_rotated,ymax_rotated,bbox_side_old,x_c,y_c,target_image_size,image_height,image_width):
    """Randomly enlarge and shift the square crop around the rotated GT bbox.

    Args:
        xmin_rotated, xmax_rotated, ymin_rotated, ymax_rotated: rotated (and clipped) bbox bounds.
        bbox_side_old: side of the square crop before scaling.
        x_c, y_c: centre of the bbox.
        target_image_size: minimum crop side (the network input size).
        image_height, image_width: image dimensions.

    Returns:
        (xmin_new, ymin_new, bbox_side_new): top-left corner of the new square crop, clipped
        at 0, and its side. The caller still has to keep the crop inside the right/bottom borders.
    """
    #Random bbox scaling factor 0-15%. In case the bounding box is already as large as the image height we avoid increasing it further
    #also, the maximum scaling factor should not lead to a bbox greater than the image height.
    #If bbox size is equal to image height, ymin is already at 0 and ymax is at image height
    ss = tf.random.uniform( shape=[], minval=1.00, maxval=tf.minimum(image_height/bbox_side_old,1.15),seed=1089 )
    scaling_factor=tf.cond(tf.equal(bbox_side_old,image_height), lambda:1.00, lambda: ss)

    bbox_side_new=bbox_side_old*scaling_factor

    #Enlarge bbox if smaller than target image size
    bbox_side_new=tf.cond(tf.less(bbox_side_new,target_image_size),lambda: tf.constant(target_image_size,dtype=tf.float32),lambda: bbox_side_new)

    #Default translation is between -0.1 and +0.1 (in units of the crop side), but the bounds are
    #tightened so that the crop stays inside [0, image_width] and still contains [xmin_rotated, xmax_rotated].
    h_default=tf.random.uniform(shape=[], minval=tf.reduce_max([-0.1,0.5-x_c/bbox_side_new,(xmax_rotated-x_c)/bbox_side_new-0.5]),
                                maxval=tf.reduce_min([0.1, (image_width-x_c)/bbox_side_new-0.5,(xmin_rotated-x_c)/bbox_side_new+0.5]),seed=15)

    # No shift when not scaled; crops touching the left/right border are pinned to that border.
    h_shift_factor=tf.case([(tf.equal(scaling_factor,1.00),lambda:0.0),
                            (tf.equal(xmin_rotated,0.0),lambda: 0.5-x_c/bbox_side_new),
                            (tf.equal(xmax_rotated,image_width),lambda: (image_width-x_c)/bbox_side_new-0.5)],
                           default=lambda: h_default)

    # Same bounds vertically: stay inside [0, image_height] and contain [ymin_rotated, ymax_rotated].
    v_default=tf.random.uniform(shape=[], minval=tf.reduce_max([-0.1,0.5-y_c/bbox_side_new,(ymax_rotated-y_c)/bbox_side_new-0.5]),
                                maxval=tf.reduce_min([0.1, (image_height-y_c)/bbox_side_new-0.5,(ymin_rotated-y_c)/bbox_side_new+0.5]),seed=24)

    v_shift_factor=tf.case([(tf.logical_and(tf.equal(ymax_rotated,image_height),tf.equal(ymin_rotated,0.0)),lambda: 0.0),
                            (tf.equal(ymin_rotated,0.0),lambda: 0.5-y_c/bbox_side_new),
                            (tf.equal(ymax_rotated,image_height),lambda: (image_height-y_c)/bbox_side_new-0.5)],
                           default=lambda: v_default)

    # Only the top-left corner is needed by the caller (the crop is square with side bbox_side_new).
    xmin_new=(x_c-bbox_side_new/2.0+h_shift_factor*bbox_side_new)
    ymin_new=(y_c-bbox_side_new/2.0+v_shift_factor*bbox_side_new)

    xmin_new=tf.maximum(xmin_new,0.0)
    ymin_new=tf.maximum(ymin_new,0.0)

    return xmin_new, ymin_new, bbox_side_new


def pixel_level_augment(image,target_image_size,max_pixel_value=255.0):
    """Randomly apply brightness, contrast, blur and noise, each with probability 0.5.

    The augmentation strengths (brightness max_delta=0.2, and the 0.047 noise std in
    add_gauss_noise) are on a [0, 1] intensity scale. The images this receives from
    apply_augmentations are float tensors still in [0, 255], which is why they are clipped
    to [0, 255] right afterwards. On that scale the brightness shift and the noise are
    below 0.1% of the range, i.e. no-ops. The image is therefore mapped to [0, 1] for the
    augmentations and back at the end. Contrast and Gaussian blur are linear in the pixel
    values, so the rescaling does not affect them.

    Args:
        image: float image tensor with values in [0, max_pixel_value].
        target_image_size: spatial size forwarded to add_gauss_noise.
        max_pixel_value: upper end of the input intensity range. Pass 1.0 for images that
            are already in [0, 1]; that also reproduces the former unscaled behavior.

    Returns:
        The augmented image on the same scale as the input (not clipped).
    """
    image = image / max_pixel_value

    prob_brightness = tf.random.uniform([],minval=0,maxval=1,seed=432)
    image = tf.cond(tf.less(prob_brightness,0.5), lambda: tf.image.random_brightness(image, max_delta=0.2,seed=1), lambda: image)

    prob_contrast = tf.random.uniform([],minval=0,maxval=1,seed=6736)
    image = tf.cond(tf.less(prob_contrast,0.5), lambda: tf.image.random_contrast(image,0.1,1.5,seed=2), lambda: image)

    prob_blur = tf.random.uniform([],minval=0,maxval=1,seed=787)
    image = tf.cond(tf.less(prob_blur,0.5), lambda: tfa.image.gaussian_filter2d(image, sigma=1), lambda: image)

    prob_noise = tf.random.uniform([],minval=0,maxval=1,seed=782)
    image = tf.cond(tf.less(prob_noise,0.5), lambda: add_gauss_noise(image,target_image_size), lambda: image)

    return image * max_pixel_value

def add_gauss_noise(image,target_image_size):
    """Return `image` plus zero-mean Gaussian noise with standard deviation 0.047.

    The noise tensor takes the image's own shape, so any spatial size or channel count is
    supported. `target_image_size` is no longer needed for that and is kept only so existing
    call sites stay valid. For (target_image_size, target_image_size, 1) images the noise
    shape is the same as the previously hard-coded one.
    """
    mean = 0
    std = 0.047

    gauss = tf.random.normal(tf.shape(image), mean, std, seed=957)

    return image + gauss

def map_validation_dataset(image, features, target_image_size):
    """Deterministic validation preprocessing: crop around the GT bbox, pad to square, resize.

    The GT bbox is enlarged by 15% (each side at least `target_image_size`), clipped to the
    image borders and zero-padded to a square (pad_rows / pad_cols). It is then resized to
    target_image_size x target_image_size, converted to RGB and rescaled from [0, 255] to
    [-1, 1]. Keypoints are expressed in the padded crop's frame and normalized by its side.

    Args:
        image: decoded image from decode_dataset. Assumed single-channel, since it goes
            through tf.image.grayscale_to_rgb (TODO confirm).
        features: parsed feature dict (see tf_records_file_features_description).
        target_image_size: output side length in pixels.

    Returns:
        (image, {'discriminator': 0.0 if dataset_class == 1 else 1.0,
                 'regressor': 22 normalized coordinates X_A, Y_A, ..., X_M, Y_M})
        NOTE(review): unlike apply_augmentations, 'regressor' has no trailing dataset_class
        entry (22 vs 23 values); confirm the loss/metrics handle both widths.
    """
    dataset_class = tf.cast(features['image/dataset_class'], dtype=tf.float32)
    image_height = tf.cast(features['image/height'], dtype=tf.float32)
    image_width = tf.cast(features['image/width'], dtype=tf.float32)

    xmin = features['image/object/bbox/xmin']
    ymin = features['image/object/bbox/ymin']
    xmax = features['image/object/bbox/xmax']
    ymax = features['image/object/bbox/ymax']

    # Bbox center
    xc = (xmax + xmin) / 2
    yc = (ymax + ymin) / 2

    # Enlarge the bbox by 15%, but never below the network input size
    bbox_w = tf.reduce_max([(xmax - xmin) * 1.15, target_image_size])
    bbox_h = tf.reduce_max([(ymax - ymin) * 1.15, target_image_size])

    # Top-left corner, clipped to the image and snapped to whole pixels
    xmin = tf.math.floor(tf.reduce_max([xc - bbox_w / 2, 0]))
    ymin = tf.math.floor(tf.reduce_max([yc - bbox_h / 2, 0]))

    # Keep the crop inside the image
    bbox_h = tf.math.floor(tf.reduce_min([bbox_h, image_height - ymin]))
    bbox_w = tf.math.floor(tf.reduce_min([bbox_w, image_width - xmin]))

    cropped_img = tf.image.crop_to_bounding_box(image,
                                                tf.cast(ymin, tf.int32),
                                                tf.cast(xmin, tf.int32),
                                                tf.cast(bbox_h, tf.int32),
                                                tf.cast(bbox_w, tf.int32))

    # tf.shape already returns int32
    cropped_img_shape = tf.shape(cropped_img)
    rows = cropped_img_shape[0]
    cols = cropped_img_shape[1]

    # Zero-pad the shorter side so the crop becomes square; xmin/ymin shift with the padding
    cropped_img, xmin, ymin = tf.cond(tf.math.less(rows, cols),
                                      lambda: pad_rows(cropped_img, cols, rows, xmin, ymin),
                                      lambda: pad_cols(cropped_img, cols, rows, xmin, ymin))

    image = tf.image.resize(cropped_img,
                            [target_image_size, target_image_size],
                            method=tf.image.ResizeMethod.BILINEAR,
                            antialias=False)
    image = tf.image.grayscale_to_rgb(image)
    image = tf.reshape(image, [target_image_size, target_image_size, 3])

    # Side of the padded square = longer side of the unpadded crop. Rows and cols are used
    # explicitly so the channel dimension of the shape never takes part in the max.
    bbox_side = tf.cast(tf.maximum(rows, cols), tf.float32)
    xmin = tf.cast(xmin, tf.float32)
    ymin = tf.cast(ymin, tf.float32)

    # Keypoint letters skip J and K; X precedes Y for every keypoint.
    output_data = []
    for kpt in 'ABCDEFGHILM':
        output_data.append((features['image/object/kpts/X_' + kpt] - xmin) / bbox_side)
        output_data.append((features['image/object/kpts/Y_' + kpt] - ymin) / bbox_side)

    # [0, 255] -> [-1, 1]
    image = (image / 127.5) - 1.0

    # Discriminator target: 0.0 for synthetic (class 1), 1.0 for any real class
    dataset_class = tf.cond(tf.equal(dataset_class, 1.0), lambda: 0.0, lambda: 1.0)

    return image, {'discriminator': dataset_class, 'regressor': output_data}

def pad_rows(cropped_img,cols,rows,xmin,ymin):
    """Zero-pad `cropped_img` above and below so it becomes a cols x cols square.

    Args:
        cropped_img: (rows, cols, C) image with rows <= cols; any dtype and channel count.
        cols, rows: int32 dimensions of cropped_img.
        xmin, ymin: float32 top-left corner of the crop in the original image.

    Returns:
        (padded image, xmin unchanged, ymin moved up by the number of rows added on top)
    """
    rows_to_pad_up = (cols - rows) // 2
    rows_to_pad_down = cols - rows - rows_to_pad_up

    # tf.pad fills with zeros in the input's own dtype and keeps its channel count
    # (the padding blocks were previously hard-coded as single-channel uint8).
    padded = tf.pad(cropped_img, [[rows_to_pad_up, rows_to_pad_down], [0, 0], [0, 0]])
    ymin = ymin - tf.cast(rows_to_pad_up, tf.float32)

    return padded, xmin, ymin

def pad_cols(cropped_img,cols,rows,xmin,ymin):
    """Zero-pad `cropped_img` left and right so it becomes a rows x rows square.

    Args:
        cropped_img: (rows, cols, C) image with cols <= rows; any dtype and channel count.
        cols, rows: int32 dimensions of cropped_img.
        xmin, ymin: float32 top-left corner of the crop in the original image.

    Returns:
        (padded image, xmin moved left by the number of columns added on the left, ymin unchanged)
    """
    cols_to_pad_left = (rows - cols) // 2
    cols_to_pad_right = rows - cols - cols_to_pad_left

    # tf.pad fills with zeros in the input's own dtype and keeps its channel count
    # (the padding blocks were previously hard-coded as single-channel uint8).
    padded = tf.pad(cropped_img, [[0, 0], [cols_to_pad_left, cols_to_pad_right], [0, 0]])
    xmin = xmin - tf.cast(cols_to_pad_left, tf.float32)

    return padded, xmin, ymin

# (Optional) Visualize the dataset
Use the following cells to visualize the dataset.

In [None]:
image_size = 224
AUTO=tf.data.AUTOTUNE

# Train dataset preparation
# Decode + random augmentations, with no .cache(), filtering or batching
# (presumably to keep this preview light; the training section builds the full pipeline).

all_data_record=load_tf_records(train_dataset_path).map(lambda x: decode_dataset(x, image_size), num_parallel_calls=AUTO,deterministic=False).map(lambda x,y: apply_augmentations(x,y,image_size),num_parallel_calls=AUTO,deterministic=False)

# Validation dataset preparation
# Deterministic crop/pad/resize per sample, not batched.
validation_dataset=load_tf_records(validation_dataset_path).map(lambda x: decode_dataset(x, image_size), num_parallel_calls=AUTO,deterministic=False).map(lambda x, y: map_validation_dataset(x,y, image_size), num_parallel_calls=AUTO)


In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

# Indices of the x (even) and y (odd) entries of the 22-value keypoint vector.
even=np.arange(0,22,2)
odd=np.arange(1,22,2)

for image, label in all_data_record.take(10):
  # Drop the trailing dataset-class entry and scale normalized coordinates to pixels.
  keypoint_vector = label['regressor'][:-1]
  label1 = np.reshape(keypoint_vector, 22) * image_size

  fig, ax = plt.subplots()
  ax.imshow((image + 1) * 0.5)  # undo the [-1, 1] rescaling for display
  ax.plot(label1[even], label1[odd], '.')
  plt.show()

  print(label)

# Training pipeline
Dataset preprocessing.

In [None]:
image_size = input_shape
AUTO=tf.data.AUTOTUNE
batch_size=48
epochs = 40

# Samples per epoch: synthetic images fill half of every batch, real images
# (lightbox and sunlamp in equal parts) the other half.
synthetic_samples = 15992*3
real_samples_per_domain = synthetic_samples//2  # 23988

# Train dataset preparation
# NOTE(review): .cache() keeps every decoded full-resolution image in host memory. It sits
# before the random augmentations, so each epoch still draws new augmentations.
all_data_record=load_tf_records(train_dataset_path).map(lambda x: decode_dataset(x, image_size), num_parallel_calls=AUTO,deterministic=False).cache().map(lambda x,y: apply_augmentations(x,y,image_size),num_parallel_calls=AUTO,deterministic=False)

# The last entry of the 'regressor' label is the raw dataset class: 1 synthetic, 2 lightbox, 3 sunlamp.
@tf.function()
def get_synthetic(ds):
  return ds.filter(lambda x, y: tf.equal(y['regressor'][-1],1.))

@tf.function()
def get_lightbox(ds):
  return ds.filter(lambda x, y: tf.equal(y['regressor'][-1],2.))

@tf.function()
def get_sunlamp(ds):
  return ds.filter(lambda x, y: tf.equal(y['regressor'][-1],3.))

synthetic_ds = get_synthetic(all_data_record).repeat().take(synthetic_samples) # few images repeated
lightbox_ds = get_lightbox(all_data_record).repeat().take(real_samples_per_domain) # repeat as many times as needed
sunlamp_ds = get_sunlamp(all_data_record).repeat().take(real_samples_per_domain) # repeat as many times as needed

# Alternate lightbox and sunlamp samples before shuffling. A plain
# lightbox_ds.concatenate(sunlamp_ds) followed by a 1000-element shuffle would serve only
# lightbox images for the first half of every epoch and only sunlamp images for the second.
real_domain_choice = tf.data.Dataset.range(2).repeat()
real_ds = tf.data.experimental.choose_from_datasets([lightbox_ds, sunlamp_ds], real_domain_choice).shuffle(1000, seed=1) # Small shuffle buffer size for performance reasons

# Batch the datasets for correct interleaving ratio:
synthetic_ds_batched = synthetic_ds.batch(3)
real_ds_batched = real_ds.batch(3)

# Populate batch 50-50 synthetic-real: groups of 3 synthetic followed by 3 real images
train_ds=tf.data.Dataset.zip((synthetic_ds_batched, real_ds_batched)).map(lambda x, y: (tf.concat((x[0], y[0]), axis=0), {'discriminator': tf.concat((x[1]['discriminator'],y[1]['discriminator']),axis=0), 'regressor': tf.concat((x[1]['regressor'],y[1]['regressor']),axis=0)}),num_parallel_calls=AUTO).unbatch().prefetch(AUTO)

# Prefetch the final batched dataset too, so input batching overlaps with training steps.
train_dataset = train_ds.batch(batch_size,drop_remainder=True).repeat().prefetch(AUTO)

# Validation dataset
validation_dataset=load_tf_records(validation_dataset_path).map(lambda x: decode_dataset(x, image_size), num_parallel_calls=AUTO,deterministic=False).map(lambda x, y: map_validation_dataset(x,y, image_size), num_parallel_calls=AUTO).batch(batch_size,drop_remainder=True).cache().repeat().prefetch(AUTO)

# Distribute datasets on TPU
train_ds_distributed=tpu_strategy.experimental_distribute_dataset(train_dataset)
valid_ds_distributed=tpu_strategy.experimental_distribute_dataset(validation_dataset)

dataset_size = synthetic_samples*2
validation_size = 11994  # NOTE(review): must match the number of validation records

steps_per_epoch=dataset_size//batch_size
validation_steps=validation_size//batch_size

# Create train dataset iterator
train_iterator = iter(train_ds_distributed)
valid_iterator = iter(valid_ds_distributed)


Define losses and metrics.

In [None]:
GLOBAL_BATCH_SIZE = batch_size
# Number of synthetic images at the start of each replica's batch. The training
# pipeline places 3 synthetic then 3 real images on every replica
# (global batch size 48 -> 6 images per replica).
synthetic_images_per_replica = 3

### Losses:
with tpu_strategy.scope():
  # Per-example binary cross-entropy (no reduction), so that the average over the
  # global batch can be taken explicitly, as required under a distribution strategy.
  # Reference: https://www.tensorflow.org/tutorials/distribute/custom_training#define_the_loss_function
  # from_logits=True because no sigmoid is applied to the discriminator's output layer.
  cls_lossobj = tf.keras.losses.BinaryCrossentropy(
      reduction=tf.keras.losses.Reduction.NONE,
      from_logits=True)

  def discriminator_loss(labels, predictions):
    """Domain-classification loss of one replica, divided by the global batch size.

    Summing the returned values across replicas gives the mean over the global batch.
    """
    per_example_loss = cls_lossobj(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

  def mae_loss(y_true, y_pred):
    """Keypoint mean absolute error on the synthetic images of one replica's batch.

    Only the first `synthetic_images_per_replica` rows are used, so real images
    contribute only to the discriminator loss. Only the first num_keypoints*2
    columns are compared; y_true also carries the dataset class, which was used
    to interleave synthetic and real images.
    """
    y_true = y_true[:synthetic_images_per_replica, :num_keypoints*2]
    y_pred = y_pred[:synthetic_images_per_replica, :num_keypoints*2]

    # Absolute error for each coordinate, averaged per example
    per_example_loss = tf.math.reduce_mean(tf.math.abs(y_pred - y_true), axis=-1)

    # Normalize by the number of synthetic images across all replicas, not by a hard-coded replica count.
    synthetic_images_per_batch = synthetic_images_per_replica * tpu_strategy.num_replicas_in_sync

    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=synthetic_images_per_batch)

  def val_mae_loss(y_true, y_pred):
    """Keypoint mean absolute error over the whole validation batch of one replica.

    NOTE(review): this compares all columns of y_true and y_pred; it assumes the
    validation mapping leaves no dataset-class column in y_true — confirm against
    map_validation_dataset.
    """
    # Absolute error for each coordinate, averaged per example
    per_example_loss = tf.math.reduce_mean(tf.math.abs(y_pred - y_true), axis=-1)

    # In the validation loss we average over the entire batch since all images are synthetic
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

### Metrics
with tpu_strategy.scope():

  # Running means of the training losses; the training loop resets them after every epoch.
  encoder_loss_tracker = tf.keras.metrics.Mean(name="encoder_loss")
  bbox_loss_tracker = tf.keras.metrics.Mean(name="bbox_loss")
  cls_loss_tracker = tf.keras.metrics.Mean(name="cls_loss")

  # Running means of the same losses on the validation set.
  encoder_val_loss_tracker = tf.keras.metrics.Mean(name="encoder_val_loss")
  bbox_val_loss_tracker = tf.keras.metrics.Mean(name="bbox_val_loss")
  cls_val_loss_tracker = tf.keras.metrics.Mean(name="cls_val_loss")

  # Discriminator accuracy. The discriminator emits logits (from_logits=True in
  # its loss), so the decision threshold is logit 0, i.e. probability 0.5.
  # See https://www.tensorflow.org/hub/tutorials/tf2_text_classification and
  # https://github.com/tensorflow/tensorflow/issues/41413
  cls_accuracy_tracker = tf.keras.metrics.BinaryAccuracy(threshold=0.0, name='cls_accuracy')
  cls_val_accuracy_tracker = tf.keras.metrics.BinaryAccuracy(threshold=0.0, name='cls_val_accuracy')

Optimizers and checkpoint manager.

In [None]:
import time
K = tf.keras.backend
import tensorflow_addons as tfa

logger = tf.get_logger()

total_steps = steps_per_epoch*epochs

# Hyperparameters shared by the three optimizers.
LEARNING_RATE = 5e-5
WEIGHT_DECAY = 1e-8


def _make_adamw(learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY, decay_steps=total_steps):
  """Return an AdamW optimizer whose learning rate follows a cosine decay.

  Each optimizer counts its own iterations. All three are stepped once per
  training step, so decay_steps=total_steps spans the whole run.
  """
  return tfa.optimizers.AdamW(
      weight_decay=weight_decay,
      learning_rate=tf.keras.optimizers.schedules.CosineDecay(learning_rate, decay_steps),
  )


with tpu_strategy.scope():

  # Adversarial weight. It starts at 0.0 and is overwritten at the start of every epoch by the training loop.
  gamma = K.variable(0.0)

  # One optimizer per network element: encoder, regression head, discrimination head.
  optimizer_encoder = _make_adamw()
  optimizer_regressor = _make_adamw()
  optimizer_discriminator = _make_adamw()

  # Checkpoint manager
  checkpoint = tf.train.Checkpoint(
      epoch=tf.Variable(-1),  # incremented at the start of each epoch, so -1 means "no epoch run yet"
      optimizer_encoder=optimizer_encoder,
      optimizer_regressor=optimizer_regressor,
      optimizer_discriminator=optimizer_discriminator,
      network=network,
      )

  manager = tf.train.CheckpointManager(checkpoint, checkpoint_prefix, max_to_keep=3)

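Custom training loop. Note: the code is configured for Swin Transformer backbones. To train the EfficientNet-based model, change the layer indices used for the discriminator and regression heads in the cell below (each decreases by one, as noted in the inline comments).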

In [None]:
# Indices in `network` of the sub-models updated by each optimizer.
# Configured for Swin Transformer backbones; for EfficientNet set
# DISCRIMINATOR_LAYER_INDEX = 2 and REGRESSOR_LAYER_INDEX = 3.
ENCODER_LAYER_INDEX = 1
DISCRIMINATOR_LAYER_INDEX = 3
REGRESSOR_LAYER_INDEX = 4

# Extra scale applied, on top of gamma, to the discriminator loss in the encoder objective.
ADVERSARIAL_LOSS_SCALE = 0.01


@tf.function
def train_step(iterator, steps_per_epoch):
  """Run `steps_per_epoch` distributed training steps, pulling batches from `iterator`.

  The encoder minimizes bbox_loss - gamma*ADVERSARIAL_LOSS_SCALE*cls_loss. The
  discrimination head minimizes cls_loss and the regression head minimizes
  bbox_loss, each with its own optimizer.
  """
  def step_fn(data):
    images, y_true = data
    per_replica_batch = GLOBAL_BATCH_SIZE // tpu_strategy.num_replicas_in_sync
    y_truecls = tf.reshape(y_true['discriminator'], [per_replica_batch, 1])

    # Persistent: three separate gradients are taken from the same forward pass.
    with tf.GradientTape(persistent=True) as tape:
      # Forward pass
      y_pred_cls, y_pred_bbox = network(images, training=True)

      # Compute losses
      cls_loss = discriminator_loss(y_truecls, y_pred_cls)
      bbox_loss = mae_loss(y_true['regressor'], y_pred_bbox)
      encoder_loss = bbox_loss - cls_loss * gamma * ADVERSARIAL_LOSS_SCALE

    # Accumulate losses. The per-replica losses are divided by the global batch,
    # so multiply by the replica count to track per-replica means. Reference:
    # https://www.tensorflow.org/guide/tpu#improving_performance_with_multiple_steps_inside_tffunction
    bbox_loss_tracker.update_state(bbox_loss * tpu_strategy.num_replicas_in_sync)
    cls_loss_tracker.update_state(cls_loss * tpu_strategy.num_replicas_in_sync)
    encoder_loss_tracker.update_state(encoder_loss * tpu_strategy.num_replicas_in_sync)

    cls_accuracy_tracker.update_state(y_truecls, y_pred_cls)

    trainable_vars_encoder = network.get_layer(index=ENCODER_LAYER_INDEX).trainable_variables
    trainable_vars_discriminator = network.get_layer(index=DISCRIMINATOR_LAYER_INDEX).trainable_variables
    trainable_vars_regressor = network.get_layer(index=REGRESSOR_LAYER_INDEX).trainable_variables

    # Compute every gradient from the recorded forward pass before touching any
    # weights, then release the persistent tape.
    gradients_encoder = tape.gradient(encoder_loss, trainable_vars_encoder)
    gradients_discriminator = tape.gradient(cls_loss, trainable_vars_discriminator)
    gradients_regressor = tape.gradient(bbox_loss, trainable_vars_regressor)
    del tape

    # Update the weights of the encoder, the discrimination head and the regression head
    optimizer_encoder.apply_gradients(list(zip(gradients_encoder, trainable_vars_encoder)))
    optimizer_discriminator.apply_gradients(list(zip(gradients_discriminator, trainable_vars_discriminator)))
    optimizer_regressor.apply_gradients(list(zip(gradients_regressor, trainable_vars_regressor)))

  for _ in tf.range(steps_per_epoch):
    tpu_strategy.run(step_fn, args=(next(iterator),))


@tf.function
def valid_step(data_iter, validation_steps):
  """Run `validation_steps` distributed evaluation steps, updating the validation trackers."""
  def valid_step_fn(data):
    images, y_true = data
    per_replica_batch = GLOBAL_BATCH_SIZE // tpu_strategy.num_replicas_in_sync
    y_truecls = tf.reshape(y_true['discriminator'], [per_replica_batch, 1])

    y_pred_cls, y_pred_bbox = network(images, training=False)

    # Compute losses
    cls_val_loss = discriminator_loss(y_truecls, y_pred_cls)
    bbox_val_loss = val_mae_loss(y_true['regressor'], y_pred_bbox)
    encoder_val_loss = bbox_val_loss - cls_val_loss * gamma * ADVERSARIAL_LOSS_SCALE

    # Accumulate losses (rescaled to per-replica means, as in training)
    bbox_val_loss_tracker.update_state(bbox_val_loss * tpu_strategy.num_replicas_in_sync)
    cls_val_loss_tracker.update_state(cls_val_loss * tpu_strategy.num_replicas_in_sync)
    encoder_val_loss_tracker.update_state(encoder_val_loss * tpu_strategy.num_replicas_in_sync)

    # Compute and accumulate accuracy
    cls_val_accuracy_tracker.update_state(y_truecls, y_pred_cls)

  for _ in tf.range(validation_steps):
    tpu_strategy.run(valid_step_fn, args=(next(data_iter),))

In [None]:
def restore_model(manager):
  """Restore the global `checkpoint` in place from the manager's latest save, if any.

  Returns the (possibly restored) global `checkpoint` object.
  """
  with tpu_strategy.scope():
    if manager.latest_checkpoint:
      checkpoint.restore(manager.latest_checkpoint)
      print("Restored from {}".format(manager.latest_checkpoint))
    else:
      print("Initializing from scratch.")
    return checkpoint


def _reset_trackers():
  """Reset all training and validation loss/accuracy accumulators."""
  for tracker in (encoder_loss_tracker, bbox_loss_tracker, cls_loss_tracker, cls_accuracy_tracker,
                  encoder_val_loss_tracker, bbox_val_loss_tracker, cls_val_loss_tracker,
                  cls_val_accuracy_tracker):
    tracker.reset_states()


def train_model(eventually_restored_checkpoint, manager):
  """Train and validate epoch by epoch until `epochs` epochs have been completed.

  Args:
    eventually_restored_checkpoint: the tf.train.Checkpoint holding `epoch`, the
      network and the optimizers, possibly restored by `restore_model`. Training
      resumes after its stored epoch.
    manager: tf.train.CheckpointManager used to save a checkpoint every 5 epochs.
  """
  ckpt = eventually_restored_checkpoint

  # `epoch` is 0-based and starts at -1, so this runs while fewer than `epochs`
  # epochs are done. Read it from the checkpoint passed in, not from the global one.
  while ckpt.epoch.numpy() + 1 < epochs:
    epoch_start_time = time.time()
    ckpt.epoch.assign_add(1)
    epoch = ckpt.epoch.numpy()
    print('Epoch: {}, initial learning rate: {}'.format(epoch+1, round(optimizer_encoder._decayed_lr(tf.float32).numpy(),10)))

    # Update gamma: 2/(1+exp(-10*p)) - 1 with training progress p = (epoch+1)/epochs.
    # NOTE(review): because p uses epoch+1, gamma is already > 0 during the first
    # epoch (the initial value 0.0 is overwritten here) — confirm this is intended.
    K.set_value(gamma, 2.0/(1.0+tf.math.exp(-10.0*(epoch+1)/epochs))-1.0)

    # train step
    train_step(train_iterator, steps_per_epoch)

    print('Current step: {}, encoder loss: {}, kpts loss: {}, cls loss: {}, cls_accuracy: {}\n'.format(
      optimizer_encoder.iterations.numpy(),
      round(float(encoder_loss_tracker.result()), 4),
      round(float(bbox_loss_tracker.result()), 4),
      round(float(cls_loss_tracker.result()), 4),
      round(float(cls_accuracy_tracker.result()),4)))

    # validation step
    valid_step(valid_iterator, tf.convert_to_tensor(validation_steps))
    print('Validation data - encoder loss: {}, kpts loss: {}, cls loss: {}, cls_accuracy: {}\n'.format(
      round(float(encoder_val_loss_tracker.result()), 4),
      round(float(bbox_val_loss_tracker.result()), 4),
      round(float(cls_val_loss_tracker.result()), 4),
      round(float(cls_val_accuracy_tracker.result()),4)))

    # Export checkpoints every 5 epochs
    if (epoch+1) % 5 == 0:
      print('\n Saving checkpoint...\n')

      with tpu_strategy.scope():
        manager.save()

    # Reset the training and validation accumulators for the next epoch
    _reset_trackers()

    print('Epoch time: {:.1f} s\n'.format(time.time() - epoch_start_time))

Resume training from the latest checkpoint (optional).

In [None]:
### TO RESTORE FROM CHECKPOINT RUN THIS CELL BEFORE TRAINING

# Unbind the names created in earlier cells. The objects themselves are still
# tracked by `checkpoint`, which restore_model restores and returns.
del network, optimizer_regressor, optimizer_encoder, optimizer_discriminator

checkpoint = restore_model(manager)

# Re-bind the global names to the objects tracked by the checkpoint.
network, optimizer_regressor, optimizer_encoder, optimizer_discriminator = (
    checkpoint.network,
    checkpoint.optimizer_regressor,
    checkpoint.optimizer_encoder,
    checkpoint.optimizer_discriminator,
)


Start training and export weights.

In [None]:
# Train, resuming after the epoch stored in `checkpoint`, then write the final
# weights to `weights_export_dir` (an .h5 file on Google Drive).
train_model(eventually_restored_checkpoint=checkpoint, manager=manager)
network.save_weights(weights_export_dir)