In [None]:
%%capture
# Machine Learning and Data Science Imports
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as backend
import collections
from tensorflow.keras import *
import tensorflow_addons as tfa
import pandas as pd; pd.options.mode.chained_assignment = None;
import numpy as np
import string

# Built In Imports
from kaggle_datasets import KaggleDatasets
from collections import Counter
from glob import glob
import random
import math
from tqdm.notebook import tqdm
import os


import matplotlib.pyplot as plt
import cv2

AUTO = tf.data.experimental.AUTOTUNE  # let tf.data pick parallelism / buffer sizes
def seed_it_all(seed=7):
    """Seed every RNG this notebook touches: PYTHONHASHSEED env var, `random`, NumPy and TensorFlow.

    Note: setting PYTHONHASHSEED at runtime does not change hashing in the
    already-running interpreter; it only affects processes started afterwards.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)
seed_it_all()


Load TPU

In [None]:
try:
    # TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
    TPU = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(TPU)
    tf.tpu.experimental.initialize_tpu_system(TPU)
    # Non-experimental alias; the experimental one is deprecated.
    strategy = tf.distribute.TPUStrategy(TPU)
    # Mixed precision is only enabled on the TPU path.
    keras.mixed_precision.set_global_policy("mixed_bfloat16")
except Exception as tpu_error:
    # Catch ordinary errors only (a bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit) and report why we are falling back.
    print(f"TPU unavailable ({tpu_error!r}); using the default distribution strategy.")
    TPU = None
    strategy = tf.distribute.get_strategy()
N_REPLICAS = strategy.num_replicas_in_sync
tf.config.optimizer.set_jit(True)
bs = 2  # per-replica batch size
BATCH_SIZE = bs * N_REPLICAS
TARGET_DTYPE = tf.bfloat16 if TPU is not None else tf.float32

# Data Pipeline

Data Module
- The data is hosted on Google Cloud Storage (GCS) so it can be streamed to the TPU.


In [None]:
def get_anchors(image_size,
                grid_sizes=(8, 16, 32, 64, 128),
                anchor_ratios=(0.5, 1, 2),
                anchor_scales=(2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)),
                size_multiplier=4.0):
    """Generate every anchor box for a square image, as (x1, y1, x2, y2) pixel coords.

    Args:
        image_size: side length of the (square) image in pixels.
        grid_sizes: stride of each feature level; one anchor set per level.
        anchor_ratios: aspect ratios; height/width of an anchor equals the ratio.
        anchor_scales: multiplicative scales applied to each level's base size.
        size_multiplier: base anchor size of a level = its grid size * this value.

    Returns:
        float32 tensor (total_anchors, 4). Ordering: level-major; within a level,
        grid points row by row (y slowest, x fastest); within a point, ratio
        slowest and scale fastest. With the defaults and image_size=1024 this is
        196416 anchors. Anchors near the border may have negative coordinates.
    """
    grid_sizes = np.asarray(grid_sizes, np.float32)
    anchor_sizes = grid_sizes * size_multiplier
    anchor_ratios = np.asarray(anchor_ratios, np.float32)
    anchor_scales = np.asarray(anchor_scales, np.float32)
    num_ratios = len(anchor_ratios)
    num_scales = len(anchor_scales)
    # One entry per (ratio, scale) pair, ratio varying slowest.
    sqrt_ratios = np.sqrt(np.repeat(anchor_ratios, num_scales))

    all_anchors = []
    for grid_size, anchor_size in zip(grid_sizes, anchor_sizes):
        # Zero-centred anchors (1, A, 4) with A = num_ratios * num_scales.
        base_sizes = np.tile(anchor_size * anchor_scales, num_ratios)
        anchors_list = np.zeros((num_ratios * num_scales, 4))
        anchors_list[:, 2] = base_sizes / sqrt_ratios
        anchors_list[:, 3] = base_sizes * sqrt_ratios
        anchors_list[:, 0::2] -= anchors_list[:, 2:3] * 0.5
        anchors_list[:, 1::2] -= anchors_list[:, 3:4] * 0.5
        anchors_list = np.expand_dims(anchors_list, axis=0)

        # Grid-cell centres, e.g. [4, 12, 20, ...] for stride 8 -> shifts (K, 1, 4).
        shift_pts = np.arange(grid_size // 2, image_size, grid_size, dtype=np.float32)
        num_pts = len(shift_pts)
        shift_list = np.zeros((num_pts ** 2, 4))
        shift_list[:, 0] = shift_list[:, 2] = np.tile(shift_pts, num_pts)    # x repeats fastest
        shift_list[:, 1] = shift_list[:, 3] = np.repeat(shift_pts, num_pts)  # y repeats slowest
        shift_list = np.expand_dims(shift_list, axis=1)

        # Broadcast (K, 1, 4) + (1, A, 4) -> (K, A, 4) -> (K * A, 4).
        all_anchors.append(np.reshape(anchors_list + shift_list, (-1, 4)))
    return tf.cast(np.concatenate(all_anchors, axis=0), tf.float32)

In [None]:
def compute_area(bbox):
    """Area of each box in `bbox` (rows start with x1, y1, x2, y2); negative extents count as 0. Returns (N,)."""
    width_height = tf.maximum(bbox[:, 2:4] - bbox[:, :2], 0.0)  # (N, 2)
    return width_height[:, 0] * width_height[:, 1]
def filter_bboxes(box):
    """Keep boxes with class >= 0 (column 4), x2 >= x1 and y2 >= y1, returned as (K, 1, C).

    If no box survives, a single dummy box [0, 0, -1, -1, 0] is returned in the
    same (1, 1, C) layout so that downstream code always sees at least one row.
    """
    num_cols = box.shape[-1]
    is_valid = tf.math.reduce_all([box[:, 4] >= 0,
                                   box[:, 2] >= box[:, 0],
                                   box[:, 3] >= box[:, 1]], axis=0)
    valid_idx = tf.where(is_valid)  # (K, 1)

    def _dummy_box():
        return tf.reshape([0., 0., -1., -1., 0.], (-1, 1, num_cols))

    def _valid_boxes():
        return tf.reshape(tf.gather(box, valid_idx), (-1, 1, num_cols))

    return tf.cond(tf.shape(valid_idx)[0] == 0, _dummy_box, _valid_boxes)
def compute_centers(bbox):
    """Midpoint (cx, cy) of each x1, y1, x2, y2 box; returns (N, 2)."""
    top_left, bottom_right = bbox[:, :2], bbox[:, 2:4]
    return 0.5 * (top_left + bottom_right)
def compute_iou(bbox, anchors):
    """Pairwise IoU between `bbox` (N, >=4) and `anchors` (M, >=4), both starting x1, y1, x2, y2.

    Returns an (M, N) tensor whose [m, n] entry is IoU(anchors[m], bbox[n]).
    1e-6 is added to both numerator and denominator, so disjoint pairs yield a
    tiny positive value and degenerate (zero-area) pairs yield 1 instead of 0/0.
    """
    eps = 1e-6
    bbox_area = tf.expand_dims(compute_area(bbox), axis = 0)       # (1, N)
    anchor_area = tf.expand_dims(compute_area(anchors), axis = 1)  # (M, 1)

    b = bbox[tf.newaxis, :, :]     # (1, N, C)
    a = anchors[:, tf.newaxis, :]  # (M, 1, C)
    overlap_min = tf.maximum(b[..., :2], a[..., :2])    # (M, N, 2)
    overlap_max = tf.minimum(b[..., 2:4], a[..., 2:4])  # (M, N, 2)
    overlap_wh = tf.maximum(overlap_max - overlap_min, 0.0)

    inter = tf.maximum(overlap_wh[..., 0] * overlap_wh[..., 1], 0.0)  # (M, N)
    union = bbox_area + anchor_area - inter                            # (M, N)
    return (inter + eps) / (union + eps)

In [None]:
def load_tfrecs(fold_idx, num_folds, num_epochs):
    """Collect TFRecord shard paths for one CV fold and estimate per-epoch sizes.

    Args:
        fold_idx: fold held out for validation.
        num_folds: total number of folds.
        num_epochs: number of pre-augmented epochs stored on disk.

    Returns:
        TRAIN_TFRECS: flat list with the first shard matching
            `{dir}/{epoch}_fold_{i}*` for every epoch, non-validation fold and train dir.
        VAL_TFRECS: list holding one list of `val_fold_{fold_idx}*` shard paths
            (empty if fold_idx is out of range); TFRecordDataset flattens it.
        TRAIN_NUMBER: VAL_NUMBER * 4 + number of records in the last train shard.
        VAL_NUMBER: number of validation records.

    Raises:
        ValueError: if no training shard matched any pattern.
    """
    # GCS paths need tf.io.gfile; local paths can use the stdlib glob.
    glob_fn = tf.io.gfile.glob if TPU is not None else glob

    VAL_TFRECS = []
    if 0 <= fold_idx < num_folds:
        VAL_TFRECS.append(glob_fn(f"{DataModule.val_dirs[0]}/val_fold_{fold_idx}*"))

    TRAIN_TFRECS = []
    for epoch in range(num_epochs):
        for i in range(num_folds):
            if i == fold_idx:
                continue
            for base_dir in DataModule.train_dirs:
                # Glob once per pattern: on GCS every glob is a remote listing.
                matches = glob_fn(f"{base_dir}/{epoch}_fold_{i}*")
                if matches:
                    TRAIN_TFRECS.append(matches[0])
    if not TRAIN_TFRECS:
        raise ValueError(f"No training TFRecords found for fold {fold_idx} in {DataModule.train_dirs}")

    VAL_NUMBER = sum(1 for _ in tf.data.TFRecordDataset(VAL_TFRECS))
    # NOTE(review): the hard-coded x4 assumes four regular training folds of the
    # same size as the validation fold, plus one extra fold whose size is read
    # from the last shard. Confirm if NUM_FOLDS or the fold layout changes.
    ext_NUMBER = sum(1 for _ in tf.data.TFRecordDataset(TRAIN_TFRECS[-1]))
    TRAIN_NUMBER = VAL_NUMBER * 4 + ext_NUMBER
    return TRAIN_TFRECS, VAL_TFRECS, TRAIN_NUMBER, VAL_NUMBER

        
def _resolve_dataset_dirs(dataset_names):
    """Map Kaggle dataset directory names to readable paths.

    With a TPU the data is read from Google Cloud Storage, so each name is
    resolved through KaggleDatasets().get_gcs_path. Otherwise the local
    ../input/<name> mount is used.
    """
    if TPU is not None:
        kaggle_datasets = KaggleDatasets()
        return [kaggle_datasets.get_gcs_path(name) for name in dataset_names]
    return [f"../input/{name}" for name in dataset_names]


class DataModule:
    """Dataset locations and data constants shared by the input pipeline."""
    # -------------Load Dataset----------------

    NUM_EPOCHS = 50 # 30 EPOCHS augmented.
    train_dirs = _resolve_dataset_dirs([
        'd/andrewshao05/folds-0-10-gwd',
        'd/andrewshao05/folds-10-20-gwd',
        'd/andrewshao05/folds-20-30-gwd',
        'd/andrewshao05/folds-30-40-gwd',
        'd/andrewshao05/folds-40-50-gwd',
        # Pseudo-labelled data (external data did not help at all)
        'pseudo-1-10',
        'folds-10-20',
        'pseudo-20-30',
        'folds-30-40',
        'pseudo-40-50'
    ])

    val_dirs = _resolve_dataset_dirs([
        'val-folds-gwd'
    ])

    SAVE_PATH = './'
    # -----------OTHER Data Params-------------------
    IMG_SHAPE = (1024, 1024, 3) # Large images - but EfficientDet + TPUs can handle them.
    # Anchor boxes are a constant for a fixed image size.
    ANCHORS = get_anchors(IMG_SHAPE[0]) # Negative Values are fine.
    # num Classes = 1. Wheat or No Wheat
    NUM_CLASSES = 1
    # Max Number of BBoxes == 516, but to be safe, 1024 cap.
    MAX_BBOXES = 1024
# Cross-validation / epoch configuration used to select TFRecord shards,
# followed by the shard lists and record counts for the chosen fold.
CUR_EPOCH = 0
FOLD_IDX = 0
NUM_FOLDS = 6  # 7 was also tried
NUM_EPOCHS = DataModule.NUM_EPOCHS
TRAIN_TFRECS, VAL_TFRECS, TRAIN_NUMBER, VAL_NUMBER = load_tfrecs(
    fold_idx=FOLD_IDX,
    num_folds=NUM_FOLDS,
    num_epochs=NUM_EPOCHS,
)


# Augmentations:

In [None]:
def display_images(images, bboxes):
    """Decode the positive anchors in `bboxes[0]`, draw them on `images` and show the result.

    Assumes (TODO confirm) `images` is a single eager (H, W, 3) tensor and
    `bboxes[0]` holds per-anchor regression targets aligned with
    DataModule.ANCHORS, with the objectness flag in column 4.
    """
    # Flatten to 1-D so a single positive still yields (1,) indices; tf.squeeze
    # would collapse it to a scalar and break gather/decode.
    positive_idx = tf.reshape(tf.where(tf.greater_equal(bboxes[0][:, 4], 1.0)), [-1])
    positives = tf.gather(bboxes[0], positive_idx)
    anchors = tf.gather(DataModule.ANCHORS, positive_idx)
    canvas = images.numpy()
    decoded = decode_bounding_box(positives, anchors).numpy()
    for x1, y1, x2, y2 in decoded[:, :4]:
        cv2.rectangle(canvas, (int(x1), int(y1)), (int(x2), int(y2)), 220, 3)
    plt.imshow(canvas)
    plt.show()
    

def decode_bounding_box(bboxes, anchors, image_size = 1024):
    """Invert the anchor encoding produced by `datasets_encode`.

    datasets_encode stores dx, dy = box_centre - anchor_centre and
    dw, dh = log(box_wh / anchor_wh). The previous decode added dx to the
    anchor's x1 and never subtracted half the width, so it only matched when
    the box and anchor widths were equal.

    Args:
        bboxes: (N, >=4) rows of [dx, dy, dw, dh, extra...].
        anchors: (N, 4) matching anchors as x1, y1, x2, y2.
        image_size: unused; kept for backward compatibility.

    Returns:
        (N, 4 + extra) tensor: x1, y1, x2, y2 followed by bboxes[:, 4:] unchanged
        (e.g. the objectness score, kept for NMS).
    """
    anchor_wh = anchors[:, 2:4] - anchors[:, :2]
    anchor_center = 0.5 * (anchors[:, :2] + anchors[:, 2:4])
    center = bboxes[:, :2] + anchor_center
    wh = tf.math.exp(bboxes[:, 2:4]) * anchor_wh
    top_left = center - 0.5 * wh
    return tf.concat([top_left, top_left + wh, bboxes[:, 4:]], axis = 1)
    

# Anchor Encoding Functions

In [None]:
# BBox Alignment Code
def datasets_encode(x, y, overlap_threshold = 0.5, ignore_threshold = 0.4):
    """Encode one image's ground-truth boxes into per-anchor training targets.

    Args:
        x: image tensor (H, W, 3).
        y: ground-truth boxes (N, 5) as x1, y1, x2, y2, class.
        overlap_threshold: anchors whose best IoU is >= this become positives (1).
        ignore_threshold: anchors whose best IoU lies in (ignore_threshold,
            overlap_threshold) are marked ignore (-1); all others are negatives (0).

    Returns:
        (x, (regression_list, classification_list)) where x is resized to
        DataModule.IMG_SHAPE if needed, regression_list is (A, 5) =
        [dx, dy, dw, dh, objectness] per anchor (dx/dy = centre offset, dw/dh =
        log size ratio; zero for non-positives) and classification_list is the
        (A,) objectness column.
    """
    target_h, target_w = DataModule.IMG_SHAPE[0], DataModule.IMG_SHAPE[1]
    # Rescale only when the image is NOT already at the training resolution.
    # (Previously this used `==`, which ran a no-op resize on every image and
    # never rescaled mismatched ones; x/y scale factors were also swapped.)
    if tuple(x.shape[:2]) != (target_h, target_w):
        src_h = tf.cast(tf.shape(x)[0], y.dtype)
        src_w = tf.cast(tf.shape(x)[1], y.dtype)
        x_scale = target_w / src_w  # x coordinates scale with width
        y_scale = target_h / src_h  # y coordinates scale with height
        y = y * tf.stack([x_scale, y_scale, x_scale, y_scale, tf.ones_like(x_scale)])
        x = tf.image.resize(x, [target_h, target_w])

    priors = DataModule.ANCHORS  # (A, 4) x1, y1, x2, y2
    zeros = tf.zeros((priors.shape[0],), tf.float32)
    anchor_wh = priors[:, 2:4] - priors[:, :2]  # (A, 2)

    # Valid ground-truth boxes (or one dummy box if none), (G, 5).
    box = tf.squeeze(filter_bboxes(y), axis = 1)
    iou = compute_iou(priors, box)                # (G, A)
    iou_max = tf.math.reduce_max(iou, axis=0)     # (A,)
    iou_max_idxs = tf.math.argmax(iou, axis=0)    # (A,)

    ignore_mask = tf.math.logical_and(iou_max > ignore_threshold, iou_max < overlap_threshold)
    assign_mask = iou_max >= overlap_threshold
    assignment_is_obj = tf.where(ignore_mask, -1., zeros)
    assignment_is_obj = tf.where(assign_mask, 1., assignment_is_obj)

    # Best-matching ground-truth box for every anchor, (A, 5).
    box_best = tf.gather(box, iou_max_idxs, axis=0)
    assigned_xy = compute_centers(box_best) - compute_centers(priors)          # (A, 2)
    # May be NaN/-inf for non-positive anchors; tf.where below discards those.
    assigned_wh = tf.math.log((box_best[:, 2:4] - box_best[:, :2]) / anchor_wh)  # (A, 2)

    regression_list = tf.stack([
        tf.where(assign_mask, assigned_xy[:, 0], zeros),
        tf.where(assign_mask, assigned_xy[:, 1], zeros),
        tf.where(assign_mask, assigned_wh[:, 0], zeros),
        tf.where(assign_mask, assigned_wh[:, 1], zeros),
        assignment_is_obj], axis=-1)  # (A, 5)

    classification_list = assignment_is_obj  # single class: binary objectness target
    return x, (regression_list, classification_list)
def batched_dataset_encode(images, bboxes, overlap_threshold = 0.5, ignore_threshold = 0.4):
  """Apply `datasets_encode` to every element of a statically-sized batch.

  Args:
    images: (B, H, W, C) image batch.
    bboxes: (B, N, num_channels) padded boxes; rows whose last column != 1.0 are padding.

  Returns:
    (images, (regression, classification)) stacked along a new batch axis:
    (B, H, W, C), (B, A, num_channels) and (B, A).
  """
  batch, height, width, channels = images.shape
  _, _, num_channels = bboxes.shape

  encoded_images, encoded_boxes, encoded_classes = [], [], []
  for b in range(batch):
    boxes_b = bboxes[b]
    # Drop padded rows: only rows whose last column is exactly 1.0 are real boxes.
    boxes_b = tf.boolean_mask(boxes_b, tf.equal(boxes_b[:, -1], 1.0), axis = 0)
    image_b, (reg_b, cls_b) = datasets_encode(images[b], boxes_b,
                                              overlap_threshold = overlap_threshold,
                                              ignore_threshold = ignore_threshold)
    encoded_images.append(image_b)
    encoded_boxes.append(reg_b)
    encoded_classes.append(cls_b)

  if not encoded_images:
    # Empty batch: return correctly-shaped empty tensors.
    num_anchors = len(DataModule.ANCHORS)
    return (tf.ones((0, height, width, channels), images.dtype),
            (tf.ones((0, num_anchors, num_channels), images.dtype),
             tf.ones((0, num_anchors), images.dtype)))
  return tf.stack(encoded_images), (tf.stack(encoded_boxes), tf.stack(encoded_classes))


In [None]:
class Augments:
  """Constants used by the (pre-computed) augmentation helpers."""
  # Side length in pixels of the square image; used to clip box corners.
  IMAGE_SIZE: int = 1024
  # Box-survival threshold. Per the original note, box size does not matter -
  # a box just must not lie completely outside the image.
  THRESH: float = 0.25
  # Border width in pixels (its usage is not visible in this section).
  BORDER: int = 100
  # Assumed maximum number of boxes per image, used as the padding length (tune per dataset).
  MAX_BBOX: int = 200


# Helpers

In [None]:
def load_image(example):
    """Parse one serialized TFRecord example into (image, bboxes).

    Returns:
        image: float32 tensor reshaped to DataModule.IMG_SHAPE, scaled to [0, 1].
        bboxes: float32 tensor (N, 5); the flat int64 'bboxes' feature is
            reshaped into rows of 5 values.
    """
    parsed = tf.io.parse_single_example(
        example,
        features={
            'image': tf.io.FixedLenFeature(shape=[], dtype=tf.string, default_value=''),
            # A 'classification' feature also exists in the records but only repeats the class column.
            'bboxes': tf.io.VarLenFeature(dtype = tf.int64),
        },
    )
    pixels = tf.io.decode_jpeg(parsed['image'], channels = 3)
    pixels = tf.reshape(tf.cast(pixels, tf.float32) / 255.0, DataModule.IMG_SHAPE)

    dense_boxes = tf.cast(tf.sparse.to_dense(parsed['bboxes']), tf.float32)
    return pixels, tf.reshape(dense_boxes, (-1, 5))


# COMBINATION AUGMENTS

In [None]:
def thresh_by_zero(bboxes):
  """Clip box corners into [0, Augments.IMAGE_SIZE - 1] and drop boxes whose clipped area is 0.

  Args:
    bboxes: (N, C) boxes; the first 4 columns are x1, y1, x2, y2 and the rest are carried through.

  Returns:
    The surviving boxes. NOTE: because tf.gather is called with the (M, 1)
    output of tf.where, the result has shape (M, 1, C), not (M, C).
  """
  upper = Augments.IMAGE_SIZE - 1
  clipped_corners = tf.clip_by_value(bboxes[:, :4], 0.0, upper)
  clipped = tf.concat([clipped_corners, bboxes[:, 4:]], axis = 1)  # (N, C)
  survivors = tf.where(tf.greater(compute_area(clipped), 0.0))       # (M, 1)
  return tf.gather(clipped, survivors, axis = 0)

In [None]:
def display_bbox_normal(image, bbox):
  """Draw every box of `bbox` whose last column is non-zero onto `image` (in place) and show it."""
  _num_boxes, _ = bbox.shape  # also enforces that bbox is 2-D
  drawable = (box for box in bbox if box[-1] != 0)
  for box in drawable:
    top_left = (int(box[0]), int(box[1]))
    bottom_right = (int(box[2]), int(box[3]))
    cv2.rectangle(image, top_left, bottom_right, 220, 3)
  plt.imshow(image)
  plt.show()
def convert_normal(bboxes, anchors, idx):
  """Decode the positive anchors of batch element `idx` back to corner boxes.

  Args:
    bboxes: (B, A, 5) regression targets [dx, dy, dw, dh, objectness], where
      dx/dy = box centre - anchor centre and dw/dh = log(box size / anchor size).
    anchors: (A, 4) anchors as x1, y1, x2, y2.
    idx: batch index to decode.

  Returns:
    (N, 4) boxes x1, y1, x2, y2 for the N anchors whose objectness == 1.0.
  """
  # Flatten tf.where's (N, 1) output instead of an unqualified tf.squeeze, which
  # collapsed the single-positive case to rank 1 and broke the slicing below.
  positive = tf.reshape(tf.where(tf.equal(bboxes[idx, :, -1], 1.0)), [-1])
  deltas = tf.gather(bboxes[idx], positive)[:, :-1]  # (N, 4): dx, dy, dw, dh
  valid_anchors = tf.gather(anchors, positive)       # (N, 4)

  anchor_centers = compute_centers(valid_anchors)
  anchor_wh = valid_anchors[:, 2:4] - valid_anchors[:, :2]

  centers = deltas[:, :2] + anchor_centers
  wh = tf.math.exp(deltas[:, 2:4]) * anchor_wh
  top_left = centers - wh / 2
  return tf.concat([top_left, top_left + wh], axis = 1)

In [None]:
def get_dfs(FOLD_IDX):
    """Build the (train, val) tf.data pipelines for one cross-validation fold.

    Both pipelines: read TFRecords -> parse with load_image -> batch
    (drop_remainder, as TPUs need static shapes) -> anchor-encode with
    batched_dataset_encode -> cast to TARGET_DTYPE -> prefetch. The training
    pipeline is additionally shuffled (buffer 128) and repeated indefinitely.
    The records are already augmented, so no augmentation happens here.

    Returns:
        (train_dataset, val_dataset)
    """
    train_files, val_files, _, _ = load_tfrecs(FOLD_IDX, NUM_FOLDS, DataModule.NUM_EPOCHS)

    options = tf.data.Options()
    options.experimental_deterministic = False

    def convert_dtypes(images, targets):
        bboxes, classification = targets
        return (tf.cast(images, dtype = TARGET_DTYPE),
                (tf.cast(bboxes, dtype = TARGET_DTYPE),
                 tf.cast(classification, dtype = TARGET_DTYPE)))

    def _build(files, is_training):
        # Shared pipeline; only shuffle/repeat differ between train and val.
        ds = tf.data.TFRecordDataset(files, num_parallel_reads = AUTO).with_options(options)
        ds = ds.map(load_image, num_parallel_calls = AUTO, deterministic = False)
        if is_training:
            ds = ds.shuffle(128)
        ds = ds.batch(BATCH_SIZE, drop_remainder = True)
        ds = ds.map(batched_dataset_encode, num_parallel_calls = AUTO, deterministic = False)
        if is_training:
            ds = ds.repeat()
        ds = ds.map(convert_dtypes)
        # prefetch must be last so every preceding step overlaps with training.
        return ds.prefetch(AUTO)

    return _build(train_files, True), _build(val_files, False)

# Standardize Bounding Boxes

In [None]:
def fix_bounding_boxes(bboxes, min_height = 0, max_height = 1024, min_width = 0, max_width = 1024):
  """Clip boxes to the given window and drop boxes that end up with zero area.

  Args:
    bboxes: (N, >=5) boxes x1, y1, x2, y2, obj (extra columns are discarded).
    min_height, max_height: clipping range for y coordinates.
    min_width, max_width: clipping range for x coordinates.

  Returns:
    (M, 5) tensor of clipped boxes with positive area.
  """
  x1 = tf.clip_by_value(bboxes[:, 0], min_width, max_width)
  y1 = tf.clip_by_value(bboxes[:, 1], min_height, max_height)
  x2 = tf.clip_by_value(bboxes[:, 2], min_width, max_width)
  y2 = tf.clip_by_value(bboxes[:, 3], min_height, max_height)
  obj = bboxes[:, 4]

  # Stack along axis 1 so each row is one box: (N, 5). The previous
  # expand_dims(axis=0) + concat(axis=0) built a (5, N) tensor, so the area
  # and the filtering below were computed on the wrong axis.
  clipped = tf.stack([x1, y1, x2, y2, obj], axis = 1)
  return tf.boolean_mask(clipped, compute_area(clipped) > 0)


# Non-Maximum Suppression (NMS)

In [None]:
def replace_index_bbox(bboxes, idx, new_bbox):
  """Return a copy of 2-D `bboxes` (N, C) with row `idx` replaced by `new_bbox` (C,).

  Implemented with slicing + concat, so `idx` follows Python slice semantics.
  """
  _, num_cols = bboxes.shape
  pieces = [bboxes[:idx], tf.expand_dims(new_bbox, axis = 0), bboxes[idx + 1:]]
  return tf.reshape(tf.concat(pieces, axis = 0), (-1, num_cols))
def replace_index_scores(scores, idx, new_score):
  """Return a copy of 1-D `scores` with element `idx` replaced by scalar `new_score`."""
  head, tail = scores[:idx], scores[idx + 1:]
  spliced = tf.concat([head, tf.expand_dims(new_score, axis = 0), tail], axis = 0)
  return tf.reshape(spliced, (-1, ))

In [None]:
def soft_nms_float(dets, labels, Nt, sigma, thresh, method):
    """
    Soft-NMS / NMS on float boxes.
    Based on: https://github.com/DocF/Soft-NMS/blob/master/soft_nms.py

    Runs eagerly: `dets` is converted to a NumPy array for the sequential loop.

    :param dets:   (N, 5) boxes [x1, y1, x2, y2, score] (tensor or array)
    :param labels: (N,) labels aligned with dets
    :param Nt:     IoU threshold for linear soft-NMS (1) and standard NMS (3)
    :param sigma:  Gaussian width for gaussian soft-NMS (2)
    :param thresh: boxes whose rescored confidence is <= thresh are dropped
    :param method: 1 - linear soft-NMS, 2 - gaussian soft-NMS, 3 - standard NMS
    :return: (bboxes, labels): kept boxes (K, 5) [x1, y1, x2, y2, rescored score]
             in selection order, and their labels (squeezed)

    Fixes over the previous TF version: IoU now divides by areas[pos:] (not the
    single areas[pos]); method 3 only zeroes boxes with IoU > Nt; kept boxes are
    selected by mask on the re-ordered array instead of by original indices.
    """
    boxes = np.array(dets, copy=True)  # (N, 5), re-ordered in place below
    num_boxes = boxes.shape[0]
    original_idx = np.arange(num_boxes)  # tracks where each row came from
    scores = boxes[:, 4].copy()
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    with np.errstate(divide='ignore', invalid='ignore'):
        for i in range(num_boxes - 1):
            pos = i + 1
            # Selection step: move the highest remaining score to position i.
            maxpos = pos + int(np.argmax(scores[pos:]))
            if scores[i] < scores[maxpos]:
                for arr in (boxes, original_idx, scores, areas):
                    arr[[i, maxpos]] = arr[[maxpos, i]]

            # IoU of box i against every later box.
            xx1 = np.maximum(boxes[i, 0], boxes[pos:, 0])
            yy1 = np.maximum(boxes[i, 1], boxes[pos:, 1])
            xx2 = np.minimum(boxes[i, 2], boxes[pos:, 2])
            yy2 = np.minimum(boxes[i, 3], boxes[pos:, 3])
            inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
            ovr = inter / (areas[i] + areas[pos:] - inter)

            if method == 1:    # linear
                weight = np.where(ovr > Nt, 1.0 - ovr, 1.0)
            elif method == 2:  # gaussian
                weight = np.exp(-(ovr * ovr) / sigma)
            else:              # original NMS
                weight = np.where(ovr > Nt, 0.0, 1.0)
            scores[pos:] = weight * scores[pos:]

    keep = scores > thresh
    kept_boxes = np.concatenate([boxes[:, :4], scores[:, None]], axis=1)[keep]
    # Labels were never re-ordered, so index them with the original positions.
    kept_labels = tf.gather(labels, original_idx[keep])
    return tf.convert_to_tensor(kept_boxes), tf.squeeze(kept_labels)


def nms_float(dets, labels, thresh):
    """
    Greedy (hard) NMS on float boxes; runs eagerly.

    :param dets: (N, 5) boxes. Order: x1, y1, x2, y2, score.
    :param labels: (N,) labels aligned with dets.
    :param thresh: IoU above which a lower-scoring box is suppressed.
    :return: (kept dets, kept labels) in descending score order; labels squeezed.
    """
    x1, y1, x2, y2, scores = (dets[:, k] for k in range(5))
    areas = (x2 - x1) * (y2 - y1)
    order = tf.argsort(scores)[::-1]

    keep = tf.zeros((0, ), dtype = tf.int32)
    while order.shape[0] > 0:
        i = order[0]
        keep = tf.concat([keep, tf.reshape(tf.cast(i, keep.dtype), (1, ))], axis = 0)
        rest = order[1:]

        xx1 = tf.maximum(x1[i], tf.gather(x1, rest))
        yy1 = tf.maximum(y1[i], tf.gather(y1, rest))
        xx2 = tf.minimum(x2[i], tf.gather(x2, rest))
        yy2 = tf.minimum(y2[i], tf.gather(y2, rest))

        w = tf.maximum(0.0, xx2 - xx1)
        h = tf.maximum(0.0, yy2 - yy1)
        inter = w * h
        ovr = inter / (areas[i] + tf.gather(areas, rest) - inter)
        # Flatten tf.where's (K, 1) output; gathering with it directly made
        # `order` 2-D after the first iteration and corrupted later indexing.
        survivors = tf.reshape(tf.where(ovr <= thresh), [-1])
        order = tf.gather(rest, survivors)

    dets = tf.gather(dets, keep)
    labels = tf.gather(labels, keep)
    return dets, tf.squeeze(labels)


def nms_method(boxes, labels, method=3, iou_thr=0.5, sigma=0.5, thresh=0.001, weights=None):
    """
    Run NMS / Soft-NMS independently for each label.

    :param boxes: (N, 5) tensor of boxes x1, y1, x2, y2, score.
    :param labels: (N,) tensor of labels aligned with boxes.
    :param method: 1 - linear soft-NMS, 2 - gaussian soft-NMS, 3 - standard NMS
    :param iou_thr: IoU value for boxes to be a match
    :param sigma: Sigma value for SoftNMS
    :param thresh: threshold for boxes to keep (important for SoftNMS)
    :param weights: accepted for API compatibility; currently unused.
    :return: (final_boxes (K, 5), final_labels (K,)) concatenated over labels.
    """
    labels_flat = tf.reshape(labels, [-1])
    # tf.unique returns (y, idx); iterate over the unique values only.
    unique_labels, _ = tf.unique(labels_flat)
    final_boxes = tf.zeros((0, boxes.shape[1]), dtype = boxes.dtype)
    final_labels = tf.zeros((0, ), dtype = labels.dtype)
    for label in unique_labels:
        mask = tf.equal(labels_flat, label)
        boxes_by_label = tf.boolean_mask(boxes, mask)
        labels_by_label = tf.boolean_mask(labels_flat, mask)

        if method != 3:
            kept_boxes, kept_labels = soft_nms_float(tf.identity(boxes_by_label), labels_by_label, Nt=iou_thr, sigma=sigma, thresh=thresh, method=method)
        else:
            # Use faster function
            kept_boxes, kept_labels = nms_float(boxes_by_label, labels_by_label, thresh=iou_thr)

        final_boxes = tf.concat([final_boxes, kept_boxes], axis = 0)
        # The NMS helpers squeeze labels (a single kept box becomes a scalar); restore 1-D.
        kept_labels = tf.reshape(tf.cast(kept_labels, dtype = final_labels.dtype), [-1])
        final_labels = tf.concat([final_labels, kept_labels], axis = 0)

    return final_boxes, final_labels


def nms(boxes, labels, iou_thr=0.5, weights=None):
    """
    Shortcut for standard (hard) NMS: `nms_method` with method=3.

    :param boxes: (N, 5) boxes x1, y1, x2, y2, score
    :param labels: (N,) labels aligned with boxes
    :param iou_thr: IoU above which a lower-scoring box is suppressed
    :param weights: forwarded to nms_method
    :return: (boxes, labels) kept after NMS
    """
    standard_nms = 3
    return nms_method(boxes, labels, method=standard_nms, iou_thr=iou_thr, weights=weights)


def soft_nms(boxes, labels, method=2, iou_thr=0.5, sigma=0.5, thresh=0.001, weights=None):
    """
    Shortcut for Soft-NMS; defaults to the gaussian variant (method=2).

    :param boxes: (N, 5) boxes x1, y1, x2, y2, score
    :param labels: (N,) labels aligned with boxes
    :param method: 1 - linear soft-NMS, 2 - gaussian soft-NMS, 3 - standard NMS
    :param iou_thr: IoU threshold used by the linear / standard variants
    :param sigma: gaussian width
    :param thresh: minimum rescored confidence to keep a box
    :param weights: forwarded to nms_method
    :return: (boxes, labels) kept after Soft-NMS
    """
    options = dict(method=method, iou_thr=iou_thr, sigma=sigma, thresh=thresh, weights=weights)
    return nms_method(boxes, labels, **options)

# Weighted Boxes Fusion (WBF)



In [None]:
def increment_index(indices, index):
  """Return a copy of 1-D `indices` with element `index` increased by one."""
  bumped = tf.expand_dims(indices[index] + 1, axis = 0)
  return tf.concat([indices[:index], bumped, indices[index + 1:]], axis = 0)
def replace_index_double_bbox(bboxes, first_index, second_index, new_value):
  """Return a copy of 3-D `bboxes` with bboxes[first_index][second_index] replaced by `new_value`.

  Both indices are squeezed first, so shape-(1,) index tensors are accepted.
  """
  row = tf.squeeze(first_index)
  col = tf.squeeze(second_index)
  updated_row = replace_index_bbox(bboxes[row], col, new_value)
  return tf.concat([bboxes[:row], tf.expand_dims(updated_row, axis = 0), bboxes[row + 1:]], axis = 0)


In [None]:
def compute_weighted_box(bboxes, i):
  """Fuse the boxes stored in cluster `i` into one confidence-weighted box.

  Args:
    bboxes: (L, L, 6) tensor of [x1, y1, x2, y2, conf, label]; bboxes[i] holds cluster i.
    i: cluster index.

  Returns:
    (6,) tensor: confidence-weighted mean corners, the summed confidence divided
    by L (all slots, including zero-confidence ones), and the label taken from
    bboxes[0, 0, 5].
  """
  group = bboxes[i]
  count = len(bboxes)
  weighted_corners = tf.zeros((4, ), dtype = bboxes.dtype)
  total_conf = tf.zeros((1, ), dtype = bboxes.dtype)
  # Accumulate sequentially (same summation order as a per-coordinate loop).
  for j in range(count):
    candidate = group[j]
    score = candidate[4]
    weighted_corners = weighted_corners + candidate[:4] * score
    total_conf = total_conf + score

  corners = weighted_corners / total_conf
  mean_conf = total_conf / tf.cast(count, dtype = total_conf.dtype)
  label = tf.ones_like(mean_conf) * bboxes[0, 0, 5]
  return tf.concat([corners, mean_conf, label], axis = 0)

In [None]:
def find_matching_box(boxes, new_box, match_iou):
    """Return the index of the occupied box that best overlaps `new_box`, or -1.

    Args:
        boxes: (L, C) candidate boxes; rows whose column 4 (confidence) is 0 are treated as empty and skipped.
        new_box: (C,) box to match.
        match_iou: the best IoU must be strictly greater than this to count as a match.

    Returns:
        int32 scalar. NOTE(review): the index is into the non-empty rows only, which
        equals the row index in `boxes` only if occupied rows come first - confirm with the caller.
    """
    occupied = tf.boolean_mask(boxes, tf.not_equal(boxes[:, 4], 0.0))  # (K, C)
    ious = compute_iou(occupied, tf.expand_dims(new_box, axis = 0))     # (1, K)
    best_iou = tf.reduce_max(ious)
    if best_iou > match_iou:
        best_idx = tf.squeeze(tf.argmax(ious, axis = -1))
        return tf.cast(best_idx, dtype = tf.int32)
    else:
        return tf.constant(-1, dtype = tf.int32)
@tf.function
def weighted_boxes_fusion(boxes, labels, weights=None, iou_thr=0.55, skip_box_thr=0.0, conf_type='avg', allows_overflow=False):
    '''
    Weighted Boxes Fusion (WBF) of box predictions from several sources, done separately per label.

    :param boxes: Tensor(M, N, 5) -- M prediction sources (e.g. models) with N boxes each;
        columns are x1, y1, x2, y2, confidence. Order of boxes: x1, y1, x2, y2; normalized
        [0; 1] coordinates are assumed (not checked here).
    :param labels: Tensor(M, N) -- label of each box; appended to the boxes as a 6th column.
    :param weights: per-source weights. Only their sum is used (to rescale fused confidences);
        they are not applied per box. Default: None, which means weight == 1 for each source.
    :param iou_thr: a box joins an existing cluster when its IoU with that cluster's fused box
        exceeds this value.
    :param skip_box_thr: NOTE(review): currently unused -- low-confidence boxes are NOT skipped.
    :param conf_type: NOTE(review): currently unused -- confidence is always averaged.
    :param allows_overflow: NOTE(review): currently unused.
    :return: boxes: Tensor(M * N, 5) -- fused boxes (x1, y1, x2, y2, score), packed at the front
        and grouped by label in order of first appearance; the remaining rows are zero padding.
    :return: labels: Tensor(M * N,) -- label of each row (0 for padding rows).
    '''

    if weights is None:
        weights = tf.ones(len(boxes))
    SUM_OF_WEIGHTS = tf.cast(tf.reduce_sum(weights), dtype = boxes.dtype)
    # Upper bound on the number of fused boxes (every input box in its own cluster);
    # the output buffer is zero-initialized and filled from the front.
    N = boxes.shape[0] * boxes.shape[1]
    final_bounding_boxes = tf.zeros((N, 6), dtype = boxes.dtype)

    # Append the labels to the bounding boxes as column 5 and flatten the sources.
    expanded_labels = tf.expand_dims(labels, axis = 2)
    boxes = tf.concat([boxes, expanded_labels], axis = 2)

    boxes = tf.reshape(boxes, (-1, 6))
    unique_labels = tf.unique(tf.reshape(labels, (-1, )))[0]
    labels = tf.reshape(labels, (-1, ))
    # Write cursor into final_bounding_boxes.
    bounding_idx = tf.zeros((1, ), dtype = tf.int64)
    for i in range(len(unique_labels)):
      cur_lbl = unique_labels[i]
      # Select Bboxes with this lbl
      boolean_mask = tf.equal(labels, cur_lbl)
      cur_boxes = tf.boolean_mask(boxes, boolean_mask, axis = 0)

      # new_boxes[c, s]: s-th member box of cluster c (all-zero row = empty slot).
      new_boxes = tf.zeros((len(cur_boxes), len(cur_boxes), 6), dtype = cur_boxes.dtype)
      # weighted_boxes[c]: current fused box of cluster c.
      weighted_boxes = tf.zeros((len(cur_boxes), 6), dtype = cur_boxes.dtype)
      # cur_idx[c]: number of members in cluster c, i.e. its next free slot.
      cur_idx = tf.zeros((len(cur_boxes), ), dtype = tf.int32)
      # Number of clusters created so far (also the index of the next new cluster).
      cur_filled_idx = tf.zeros((1, ), dtype = tf.int32)
    
      # Clusterize boxes: add each box to the best-matching cluster, or open a new one.
      for j in range(0, len(cur_boxes)):
          bounding_box = tf.identity(cur_boxes[j])
          index = find_matching_box(weighted_boxes, bounding_box, iou_thr)
          
          if index != -1:
              new_boxes = replace_index_double_bbox(new_boxes, index, tf.gather(cur_idx, index), bounding_box)
              # Compute New Weighted Boxes.
              weighted_box = compute_weighted_box(new_boxes, index)
              weighted_boxes = replace_index_bbox(weighted_boxes, index, weighted_box)

              cur_idx = increment_index(cur_idx, index)
          else:
              # No match: the box becomes slot 0 of a new cluster.
              idx = tf.squeeze(cur_filled_idx)
              new_boxes = replace_index_double_bbox(new_boxes, cur_filled_idx, cur_idx[idx], bounding_box)
              weighted_boxes = replace_index_bbox(weighted_boxes, idx, bounding_box)
              cur_idx = increment_index(cur_idx, idx)
              cur_filled_idx = cur_filled_idx + 1
      # Find Number of Bounding Boxes
      num_weighted = cur_filled_idx  # number of clusters created for this label
      # Rescale confidence: score * min(sum(weights), cluster size) / sum(weights),
      # which lowers the score of clusters supported by fewer boxes.
      # NOTE(review): this loop reuses `i`, shadowing the outer label-loop variable.
      for i in range(tf.squeeze(num_weighted)):
          clustered_boxes = new_boxes[i] # (NUM_SLOTS, 6)
          num_boxes = tf.cast(tf.not_equal(clustered_boxes[:, 4], 0.0), dtype = clustered_boxes.dtype)
          num_boxes = tf.reduce_sum(num_boxes)

          weighted_score = weighted_boxes[i, 4] * tf.minimum(SUM_OF_WEIGHTS, num_boxes) / SUM_OF_WEIGHTS
          x1 = tf.expand_dims(weighted_boxes[i, 0], axis = 0)
          y1 = tf.expand_dims(weighted_boxes[i, 1], axis = 0)
          x2 = tf.expand_dims(weighted_boxes[i, 2], axis = 0)
          y2 = tf.expand_dims(weighted_boxes[i, 3], axis = 0)
          lbl = tf.expand_dims(weighted_boxes[i, 5], axis = 0)
          weighted_score = tf.expand_dims(weighted_score, axis = 0)

          bbox = tf.concat([x1, y1, x2, y2, weighted_score, lbl], axis = 0)
          bbox = tf.expand_dims(bbox, axis = 0)

          # Replace row i of weighted_boxes with the rescaled box.
          begin = weighted_boxes[:i]
          end = weighted_boxes[i + 1:]
          weighted_boxes = tf.concat([begin, bbox, end], axis = 0)
      # Keep only the clusters that were actually created.
      weighted_boxes = weighted_boxes[:tf.squeeze(num_weighted)]
      num_bounding_boxes = len(weighted_boxes)
      # NOTE(review): `cur_idx` is reused here as the write cursor (scalar, int64),
      # unrelated to the per-cluster member counts above.
      cur_idx = tf.squeeze(bounding_idx)

      # Splice this label's fused boxes into final_bounding_boxes at the cursor,
      # overwriting the same number of (zero) rows.
      begin = final_bounding_boxes[:cur_idx]
      end = final_bounding_boxes[tf.cast(cur_idx, dtype =num_bounding_boxes.dtype)  + num_bounding_boxes:]
      middle = weighted_boxes


      bounding_idx = bounding_idx + tf.cast(num_bounding_boxes, dtype = bounding_idx.dtype)
      if cur_idx == 0:
        final_bounding_boxes = tf.concat([middle, end], axis = 0)
      elif cur_idx + tf.cast(num_bounding_boxes, dtype = cur_idx.dtype) >= tf.cast(len(final_bounding_boxes), dtype = cur_idx.dtype):
        final_bounding_boxes = tf.concat([begin, middle], axis = 0)
      else:
        final_bounding_boxes = tf.concat([begin, middle, end], axis = 0)
      # Restore the static (N, 6) shape after the data-dependent concat.
      final_bounding_boxes = tf.reshape(final_bounding_boxes, (N, 6))
    # Split the packed buffer into (x1, y1, x2, y2, score) boxes and labels.
    x1 = tf.expand_dims(final_bounding_boxes[:, 0], axis = 1)
    y1 = tf.expand_dims(final_bounding_boxes[:, 1], axis = 1)
    x2 = tf.expand_dims(final_bounding_boxes[:, 2], axis = 1)
    y2 = tf.expand_dims(final_bounding_boxes[:, 3], axis = 1)
    obj = tf.expand_dims(final_bounding_boxes[:, 4], axis = 1)
    labels = final_bounding_boxes[:, 5]

    boxes = tf.concat([x1, y1, x2, y2, obj], axis = 1)
    return boxes, labels


# Metric Computation: "Accuracy" (actually TP / (TP + FP + FN), not precision)

In [None]:
class Accuracy(keras.metrics.Metric):
  """Box-matching detection score TP / (TP + FP + FN), reported as "accuracy".

  Each ground-truth box is matched to the first still-unmatched prediction whose
  IoU with it exceeds `threshold` (TP); a ground truth without such a prediction
  is a FN; predictions left unmatched are FPs. Class labels are not compared.
  """
  def __init__(self, name = 'accuracy', threshold = 0.5, **kwargs):
    super().__init__(name = name, **kwargs)
    self.threshold = threshold
    # float32 counters created through add_weight:
    # - bfloat16 (TARGET_DTYPE on TPU) has an 8-bit mantissa, so a bfloat16
    #   counter stops changing under `+= 1` once it reaches 256;
    # - add_weight gives the SUM aggregation a metric needs under a distribution
    #   strategy, which a bare tf.Variable lacks.
    self.tp = self.add_weight(name = 'tp', initializer = 'zeros', dtype = tf.float32)
    self.fp = self.add_weight(name = 'fp', initializer = 'zeros', dtype = tf.float32)
    self.fn = self.add_weight(name = 'fn', initializer = 'zeros', dtype = tf.float32)

  def reset_state(self):
    """Zero the TP/FP/FN counters."""
    for counter in (self.tp, self.fp, self.fn):
      counter.assign(0.0)

  def reset_states(self):
    """Deprecated Keras name for `reset_state`; kept for existing callers."""
    self.reset_state()

  @tf.function
  def update_state(self, GT_bbox, GT_class, pred_bbox, pred_class):
    """Accumulate TP/FP/FN over a batch, one sample at a time.

    Args:
      GT_bbox: Tensor(B, N, 5) -- 4 box coordinates followed by the class column.
      GT_class: not used; per-sample classes are read from GT_bbox's last column.
      pred_bbox: Tensor(B, K, 5) -- 4 box coordinates followed by one extra
        column, which is split off before matching.
      pred_class: not used; kept for interface compatibility.
    """
    B = GT_bbox.shape[0]
    for b in range(B):
      sample_gt = GT_bbox[b]      # (N, 5)
      sample_pred = pred_bbox[b]  # (K, 5)
      # TODO: threshold predictions (e.g. by score) before matching; every
      # prediction row that stays unmatched is counted as a false positive.
      # matching_method works on a single sample with 4-column boxes.
      self.matching_method(sample_gt[:, :-1], sample_gt[:, -1:], sample_pred[:, :-1], sample_pred[:, -1:])

  def matching_method(self, GT_bbox, GT_class, pred_bbox, pred_class, threshold = None):
    """Greedily match one sample's ground truths to predictions and update counters.

    Args:
      GT_bbox: Tensor(M, 4) -- ground-truth x1, y1, x2, y2.
      GT_class: per-ground-truth class; currently not used for matching.
      pred_bbox: Tensor(N, 4) -- predicted boxes (score column already removed).
      pred_class: per-prediction class; currently not used for matching.
      threshold: IoU a prediction must exceed to match; defaults to self.threshold.
    """
    if threshold is None:
      threshold = self.threshold
    M, _ = GT_bbox.shape
    N, _ = pred_bbox.shape
    unmatched = tf.constant(-1, dtype = pred_bbox.dtype)
    # -1 marks a prediction that is still free, 1 one that is already matched.
    # Kept in pred_bbox's dtype so every comparison below uses a single dtype.
    already_matched = tf.ones((N,), dtype = pred_bbox.dtype) * unmatched
    for i in range(M):
      ground_truth = tf.expand_dims(GT_bbox[i], axis = 0) # (1, 4)
      iou_computation = tf.squeeze(compute_iou(pred_bbox, ground_truth)) # (N, )
      # Zero the IoU of predictions that were already matched.
      still_free = tf.cast(tf.equal(already_matched, unmatched), dtype = iou_computation.dtype)
      iou_computation = iou_computation * still_free
      # First prediction above the IoU threshold (argmax returns the first maximum,
      # index 0 when nothing qualifies -- hence the check below).
      thresholded = iou_computation > threshold
      index = tf.argmax(tf.cast(thresholded, dtype = tf.float32))
      if thresholded[index]:
        # Found one: count a TP and mark that prediction as used.
        self.tp.assign_add(1.0)
        already_matched = replace_index_scores(already_matched, index, tf.constant(1.0, dtype = pred_bbox.dtype))
      else:
        # Nothing found, so false negative.
        self.fn.assign_add(1.0)
    # Every prediction never matched is a false positive.
    num_false_positives = tf.reduce_sum(tf.cast(tf.equal(already_matched, unmatched), dtype = tf.float32))
    self.fp.assign_add(num_false_positives)

  def result(self):
    """Return (TP + eps) / (TP + FP + FN + eps); 1.0 before anything is counted."""
    eps = 1e-10
    return (self.tp + eps) / (self.tp + self.fp + self.fn + eps)

# Model Definitions
- Code copied from a GitHub repo and slightly modified for easier training (outside of Keras' built-in training functions).
- Similar to the official AutoML EfficientDet.

# Metric:
- Thankfully, it's a really simple metric: "accuracy", computed as:
- ACC = TP / (TP + FN + FP)

EfficientNet Backbone (B4)

In [None]:
def EfficientNetBN(n, input_tensor=None, input_shape=None, **kwargs):
    """Build the EfficientNet-B{n} backbone and return its multi-scale feature maps.

    Args:
        n: variant index 0-7 (B0-B7); selects the width/depth coefficients below.
        input_tensor: optional Keras tensor used as the image input.
        input_shape: shape of a new `layers.Input` when `input_tensor` is None.
        **kwargs: forwarded to the inner builder (`drop_connect_rate`,
            `depth_divisor`, `freeze_bn`); other keys are accepted and ignored.

    Returns:
        List of 5 tensors: the output of every stage that is followed by a
        stride-2 stage, plus the output of the last stage.
    """
    CONV_KERNEL_INITIALIZER = {
        'class_name': 'VarianceScaling',
        'config': {
            'scale': 2.0,
            'mode': 'fan_out',
            # EfficientNet actually uses an untruncated normal distribution for
            # initializing conv layers, but keras.initializers.VarianceScaling use
            # a truncated distribution.
            # We decided against a custom initializer for better serializability.
            'distribution': 'normal'
        }
    }

    def get_swish(**kwargs):
        """Return the swish activation x * sigmoid(x).

        Keyword arguments are accepted and ignored so the builder can forward its
        own **kwargs here without raising TypeError.
        """
        def swish(x):
            return x * tf.math.sigmoid(x)
        return swish


    def get_dropout():
        """Return a Dropout subclass whose `noise_shape` may contain None entries,
        which are filled in from the runtime shape of the input."""
        class FixedDropout(layers.Dropout):
            def _get_noise_shape(self, inputs):
                if self.noise_shape is None:
                    return self.noise_shape
                symbolic_shape = tf.shape(inputs)
                noise_shape = [symbolic_shape[axis] if (shape is None) else shape for axis, shape in enumerate(self.noise_shape)]
                return tuple(noise_shape)
        return FixedDropout


    def round_filters(filters, width_coefficient, depth_divisor):
        """Scale a filter count by width_coefficient, rounded to a multiple of
        depth_divisor and never more than 10% below the scaled value."""
        filters *= width_coefficient
        new_filters = int(filters + depth_divisor / 2) // depth_divisor * depth_divisor
        new_filters = max(depth_divisor, new_filters)
        if new_filters < 0.9 * filters:
            new_filters += depth_divisor
        return int(new_filters)


    def round_repeats(repeats, depth_coefficient):
        """Scale a block repeat count by depth_coefficient, rounding up."""
        return int(math.ceil(depth_coefficient * repeats))


    def mb_conv_block(inputs, block_args, activation, drop_rate=None, prefix='', freeze_bn=False):
        """Mobile inverted bottleneck block.

        Optional 1x1 expansion -> depthwise conv -> optional squeeze-and-excitation
        -> 1x1 projection, plus an identity skip when the stride is 1 and input and
        output filters match. The skip branch applies dropout with noise_shape
        (None, 1, 1, 1), i.e. whole samples are dropped. `freeze_bn` is not used.
        """
        has_se = (block_args.se_ratio is not None) and (0 < block_args.se_ratio <= 1)
        bn_axis = 3 

        Dropout = get_dropout()

        filters = block_args.input_filters * block_args.expand_ratio
        if block_args.expand_ratio != 1:
            x = layers.Conv2D(filters, 1, padding='same', use_bias=False, kernel_initializer=CONV_KERNEL_INITIALIZER, name=prefix + 'expand_conv')(inputs)
            x = layers.BatchNormalization(axis=bn_axis, name=prefix + 'expand_bn')(x)
            x = layers.Activation(activation, name=prefix + 'expand_activation')(x)
        else:
            x = inputs

        # Depthwise Convolution
        x = layers.DepthwiseConv2D(block_args.kernel_size, strides=block_args.strides, padding='same', use_bias=False, depthwise_initializer=CONV_KERNEL_INITIALIZER, name=prefix + 'dwconv')(x)
        x = layers.BatchNormalization(axis=bn_axis, name=prefix + 'bn')(x)
        x = layers.Activation(activation, name=prefix + 'activation')(x)

        # Squeeze and Excitation phase
        if has_se:
            num_reduced_filters = max(1, int(block_args.input_filters * block_args.se_ratio))
            se_tensor = layers.GlobalAveragePooling2D(name=prefix + 'se_squeeze')(x)

            target_shape = (1, 1, filters) if backend.image_data_format() == 'channels_last' else (filters, 1, 1)
            se_tensor = layers.Reshape(target_shape, name=prefix + 'se_reshape')(se_tensor)
            se_tensor = layers.Conv2D(num_reduced_filters, 1, activation=activation, padding='same', use_bias=True, kernel_initializer=CONV_KERNEL_INITIALIZER, name=prefix + 'se_reduce')(se_tensor)
            se_tensor = layers.Conv2D(filters, 1, activation='sigmoid', padding='same', use_bias=True, kernel_initializer=CONV_KERNEL_INITIALIZER, name=prefix + 'se_expand')(se_tensor)
            # Broadcasting in `multiply` handles the (1, 1, C) gate under tf.keras.
            x = layers.multiply([x, se_tensor], name=prefix + 'se_excite')

        # Output phase
        x = layers.Conv2D(block_args.output_filters, 1, padding='same', use_bias=False, kernel_initializer=CONV_KERNEL_INITIALIZER, name=prefix + 'project_conv')(x)
        x = layers.BatchNormalization(axis=bn_axis, name=prefix + 'project_bn')(x)
        if block_args.id_skip and all(s == 1 for s in block_args.strides) and block_args.input_filters == block_args.output_filters:
            if drop_rate and (drop_rate > 0):
                x = Dropout(drop_rate, noise_shape=(None, 1, 1, 1), name=prefix + 'drop')(x)
            x = layers.add([x, inputs], name=prefix + 'add')
        return x


    def EfficientNet(width_coefficient, depth_coefficient, drop_connect_rate=0.2, depth_divisor=8, input_tensor=None, input_shape=None, freeze_bn=False, **kwargs):
        """Assemble stem + MBConv stages and collect the multi-scale features."""
        BlockArgs = collections.namedtuple('BlockArgs', [
            'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
            'expand_ratio', 'id_skip', 'strides', 'se_ratio'
        ])
        blocks_args = [
            BlockArgs(kernel_size=3, num_repeat=1, input_filters=32, output_filters=16, expand_ratio=1, id_skip=True, strides=[1, 1], se_ratio=0.25),
            BlockArgs(kernel_size=3, num_repeat=2, input_filters=16, output_filters=24, expand_ratio=6, id_skip=True, strides=[2, 2], se_ratio=0.25),
            BlockArgs(kernel_size=5, num_repeat=2, input_filters=24, output_filters=40, expand_ratio=6, id_skip=True, strides=[2, 2], se_ratio=0.25),
            BlockArgs(kernel_size=3, num_repeat=3, input_filters=40, output_filters=80, expand_ratio=6, id_skip=True, strides=[2, 2], se_ratio=0.25),
            BlockArgs(kernel_size=5, num_repeat=3, input_filters=80, output_filters=112, expand_ratio=6, id_skip=True, strides=[1, 1], se_ratio=0.25),
            BlockArgs(kernel_size=5, num_repeat=4, input_filters=112, output_filters=192, expand_ratio=6, id_skip=True, strides=[2, 2], se_ratio=0.25),
            BlockArgs(kernel_size=3, num_repeat=1, input_filters=192, output_filters=320, expand_ratio=6, id_skip=True, strides=[1, 1], se_ratio=0.25)
        ]
        
        features = []

        img_input = layers.Input(shape=input_shape) if (input_tensor is None) else (input_tensor)

        bn_axis = 3 
        activation = get_swish(**kwargs)

        # Build stem
        x = img_input

        x = layers.Conv2D(round_filters(32, width_coefficient, depth_divisor), 3, strides=(2, 2), padding='same', use_bias=False, kernel_initializer=CONV_KERNEL_INITIALIZER, name='stem_conv')(x)
        x = layers.BatchNormalization(axis=bn_axis, name='stem_bn')(x)
        x = layers.Activation(activation, name='stem_activation')(x)

        # Build blocks
        # Count the blocks *after* depth scaling, because block_num below counts
        # scaled repeats; with the unscaled count the drop rate overshoots
        # drop_connect_rate whenever depth_coefficient > 1.
        num_blocks_total = sum(round_repeats(block_args.num_repeat, depth_coefficient) for block_args in blocks_args)
        block_num = 0
        for idx, block_args in enumerate(blocks_args):
            assert block_args.num_repeat > 0
            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(block_args.input_filters, width_coefficient, depth_divisor),
                output_filters=round_filters(block_args.output_filters, width_coefficient, depth_divisor),
                num_repeat=round_repeats(block_args.num_repeat, depth_coefficient))

            # The first block needs to take care of stride and filter size increase.
            drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
            x = mb_conv_block(x, block_args, activation=activation, drop_rate=drop_rate, prefix='block{}a_'.format(idx + 1), freeze_bn=freeze_bn)
            block_num += 1
            if block_args.num_repeat > 1:
                # pylint: disable=protected-access
                block_args = block_args._replace(
                    input_filters=block_args.output_filters, strides=[1, 1])
                # pylint: enable=protected-access
                for bidx in range(block_args.num_repeat - 1):
                    drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
                    block_prefix = 'block{}{}_'.format(idx + 1, string.ascii_lowercase[bidx + 1])
                    x = mb_conv_block(x, block_args, activation=activation, drop_rate=drop_rate, prefix=block_prefix, freeze_bn=freeze_bn)
                    block_num += 1
            # Keep the last feature map of every resolution: emitted just before a
            # stride-2 stage, and after the final stage.
            if idx < len(blocks_args) - 1 and blocks_args[idx + 1].strides[0] == 2:
                features.append(x)
            elif idx == len(blocks_args) - 1:
                features.append(x)
        return features

    
    # Compound-scaling coefficients for B0..B7 (default_resolution is informational only).
    parms = [
        { "width_coefficient" : 1.0, "depth_coefficient" : 1.0, "default_resolution" : 224},
        { "width_coefficient" : 1.0, "depth_coefficient" : 1.1, "default_resolution" : 240},
        { "width_coefficient" : 1.1, "depth_coefficient" : 1.2, "default_resolution" : 260},
        { "width_coefficient" : 1.2, "depth_coefficient" : 1.4, "default_resolution" : 300},
        { "width_coefficient" : 1.4, "depth_coefficient" : 1.8, "default_resolution" : 380},
        { "width_coefficient" : 1.6, "depth_coefficient" : 2.2, "default_resolution" : 456},
        { "width_coefficient" : 1.8, "depth_coefficient" : 2.6, "default_resolution" : 528},
        { "width_coefficient" : 2.0, "depth_coefficient" : 3.1, "default_resolution" : 600},
    ][n]
    return EfficientNet(parms['width_coefficient'], parms['depth_coefficient'], input_tensor=input_tensor, input_shape=input_shape, **kwargs)

#print(EfficientNetBN(7, input_shape=(600, 600, 3)))

In [None]:

# Batch-normalization settings shared by the BiFPN convolutions
# (passed as BatchNormalization(momentum=MOMENTUM, epsilon=EPSILON)).
MOMENTUM, EPSILON = 0.99, 1e-3

class wBiFPNAdd(layers.Layer):
    """Fast normalized fusion: sum_k(relu(w_k) * x_k) / (sum_k relu(w_k) + epsilon).

    One trainable scalar weight per input, initialized to 1 / num_inputs and
    stored in float32; it is cast to the dtype of the first input at call time.
    """

    def __init__(self, epsilon=1e-4, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        n_inputs = len(input_shape)
        self.w = self.add_weight(
            name=self.name,
            shape=(n_inputs,),
            initializer=initializers.constant(1 / n_inputs),
            trainable=True,
            dtype=tf.float32,
        )

    def call(self, inputs, **kwargs):
        fusion_w = tf.cast(activations.relu(self.w), dtype=inputs[0].dtype)
        weighted_inputs = [fusion_w[k] * tensor for k, tensor in enumerate(inputs)]
        fused = tf.reduce_sum(weighted_inputs, axis=0)
        return fused / (tf.reduce_sum(fusion_w) + self.epsilon)

    def compute_output_shape(self, input_shape):
        # Fusion is element-wise, so the output matches any single input.
        return input_shape[0]

    def get_config(self):
        config = super().get_config()
        config['epsilon'] = self.epsilon
        return config
    
    
def SeparableConvBlock(num_channels, kernel_size, strides, name, freeze_bn=False):
    """Return a callable applying SeparableConv2D ('same' padding, with bias) then BatchNormalization.

    Both layers are created once here, so every call of the returned function
    reuses the same weights. `freeze_bn` is accepted but not used.
    """
    separable = layers.SeparableConv2D(num_channels, kernel_size=kernel_size, strides=strides, padding='same',
                                       use_bias=True, name=f'{name}/conv')
    batch_norm = layers.BatchNormalization(momentum=MOMENTUM, epsilon=EPSILON, name=f'{name}/bn')

    def block(*args, **kwargs):
        return batch_norm(separable(*args, **kwargs))

    return block


def _wbifpn_node(inputs, num_channels, id, node):
    """One weighted-BiFPN fusion node: wBiFPNAdd -> swish -> SeparableConvBlock.

    Layers are named 'fpn_cells/cell_{id}/fnode{node}/...'; the conv block's
    'op_after_combine' suffix is node + 5 (fnode0 -> 5, ..., fnode7 -> 12).
    """
    x = wBiFPNAdd(name=f'fpn_cells/cell_{id}/fnode{node}/add')(inputs)
    x = layers.Activation(lambda t: tf.nn.swish(t))(x)
    return SeparableConvBlock(num_channels=num_channels, kernel_size=3, strides=1,
                              name=f'fpn_cells/cell_{id}/fnode{node}/op_after_combine{node + 5}')(x)


def _wbifpn_fuse(P3_in, P4_td_in, P4_bu_in, P5_td_in, P5_bu_in, P6_in, P7_in, num_channels, id):
    """Top-down then bottom-up weighted fusion shared by every BiFPN cell.

    P4 and P5 take separate inputs for their top-down (`*_td_in`) and bottom-up
    (`*_bu_in`) nodes because the first cell resamples them with different 1x1 convs.
    Returns [P3_out, P4_out, P5_out, P6_out, P7_out].
    """
    # UpSampling2D follows the global mixed-precision policy; a hard-coded float16
    # is not the TPU type (bfloat16) and truncates float32 features elsewhere.
    # Top-down pathway: upsample the coarser level and fuse it into the finer one.
    P6_td = _wbifpn_node([P6_in, layers.UpSampling2D()(P7_in)], num_channels, id, 0)
    P5_td = _wbifpn_node([P5_td_in, layers.UpSampling2D()(P6_td)], num_channels, id, 1)
    P4_td = _wbifpn_node([P4_td_in, layers.UpSampling2D()(P5_td)], num_channels, id, 2)
    P3_out = _wbifpn_node([P3_in, layers.UpSampling2D()(P4_td)], num_channels, id, 3)

    # Bottom-up pathway: downsample the finer output and fuse it with this level's
    # input and (for P4-P6) its top-down intermediate.
    P3_D = layers.MaxPooling2D(pool_size=3, strides=2, padding='same')(P3_out)
    P4_out = _wbifpn_node([P4_bu_in, P4_td, P3_D], num_channels, id, 4)
    P4_D = layers.MaxPooling2D(pool_size=3, strides=2, padding='same')(P4_out)
    P5_out = _wbifpn_node([P5_bu_in, P5_td, P4_D], num_channels, id, 5)
    P5_D = layers.MaxPooling2D(pool_size=3, strides=2, padding='same')(P5_out)
    P6_out = _wbifpn_node([P6_in, P6_td, P5_D], num_channels, id, 6)
    P6_D = layers.MaxPooling2D(pool_size=3, strides=2, padding='same')(P6_out)
    P7_out = _wbifpn_node([P7_in, P6_D], num_channels, id, 7)
    return [P3_out, P4_out, P5_out, P6_out, P7_out]


def build_wBiFPN(features, num_channels, id, freeze_bn=False):
    """Build one weighted BiFPN cell.

    Args:
        features: for id == 0, five backbone feature maps of which the last three
            (C3, C4, C5) are used; otherwise the five outputs [P3..P7] of the
            previous cell.
        num_channels: channel width of every BiFPN level.
        id: cell index; used in layer names and to select first-cell resampling.
        freeze_bn: accepted for API compatibility; not used.

    Returns:
        [P3_out, P4_out, P5_out, P6_out, P7_out].
    """
    if id == 0:
        _, _, C3, C4, C5 = features

        # The first BiFPN cell must downsample and reduce channels to obtain
        # P3_in .. P7_in from the backbone features.
        #-------------------- downsampling and channel reduction --------------------#
        P3_in = C3
        
        P3_in = layers.Conv2D(num_channels, kernel_size=1, padding='same',
                              name=f'fpn_cells/cell_{id}/fnode3/resample_0_0_8/conv2d')(P3_in)
        P3_in = layers.BatchNormalization(momentum=MOMENTUM, epsilon=EPSILON,
                                          name=f'fpn_cells/cell_{id}/fnode3/resample_0_0_8/bn')(P3_in)

        P4_in = C4
        P4_in_1 = layers.Conv2D(num_channels, kernel_size=1, padding='same',
                                name=f'fpn_cells/cell_{id}/fnode2/resample_0_1_7/conv2d')(P4_in)
        P4_in_1 = layers.BatchNormalization(momentum=MOMENTUM, epsilon=EPSILON,
                                            name=f'fpn_cells/cell_{id}/fnode2/resample_0_1_7/bn')(P4_in_1)
        P4_in_2 = layers.Conv2D(num_channels, kernel_size=1, padding='same',
                                name=f'fpn_cells/cell_{id}/fnode4/resample_0_1_9/conv2d')(P4_in)
        P4_in_2 = layers.BatchNormalization(momentum=MOMENTUM, epsilon=EPSILON,
                                            name=f'fpn_cells/cell_{id}/fnode4/resample_0_1_9/bn')(P4_in_2)

        P5_in = C5
        P5_in_1 = layers.Conv2D(num_channels, kernel_size=1, padding='same',
                                name=f'fpn_cells/cell_{id}/fnode1/resample_0_2_6/conv2d')(P5_in)
        P5_in_1 = layers.BatchNormalization(momentum=MOMENTUM, epsilon=EPSILON,
                                            name=f'fpn_cells/cell_{id}/fnode1/resample_0_2_6/bn')(P5_in_1)
        P5_in_2 = layers.Conv2D(num_channels, kernel_size=1, padding='same',
                                name=f'fpn_cells/cell_{id}/fnode5/resample_0_2_10/conv2d')(P5_in)
        P5_in_2 = layers.BatchNormalization(momentum=MOMENTUM, epsilon=EPSILON,
                                            name=f'fpn_cells/cell_{id}/fnode5/resample_0_2_10/bn')(P5_in_2)

        P6_in = layers.Conv2D(num_channels, kernel_size=1, padding='same', name='resample_p6/conv2d')(C5)
        P6_in = layers.BatchNormalization(momentum=MOMENTUM, epsilon=EPSILON, name='resample_p6/bn')(P6_in)
        P6_in = layers.MaxPooling2D(pool_size=3, strides=2, padding='same', name='resample_p6/maxpool')(P6_in)

        P7_in = layers.MaxPooling2D(pool_size=3, strides=2, padding='same', name='resample_p7/maxpool')(P6_in)
        #----------------- BiFPN top-down / bottom-up fusion pass ------------------#
        return _wbifpn_fuse(P3_in, P4_in_1, P4_in_2, P5_in_1, P5_in_2, P6_in, P7_in, num_channels, id)

    # Later cells reuse the previous cell's outputs for both pathways.
    P3_in, P4_in, P5_in, P6_in, P7_in = features
    return _wbifpn_fuse(P3_in, P4_in, P4_in, P5_in, P5_in, P6_in, P7_in, num_channels, id)
def _bifpn_resample(x, num_channels, prefix):
    """Project `x` to `num_channels` with a 1x1 conv + BatchNorm named `{prefix}/conv2d` / `{prefix}/bn`."""
    x = layers.Conv2D(num_channels, kernel_size=1, padding='same', name=f'{prefix}/conv2d')(x)
    return layers.BatchNormalization(momentum=MOMENTUM, epsilon=EPSILON, name=f'{prefix}/bn')(x)


def _bifpn_fuse(P3_in, P4_in_td, P5_in_td, P4_in_bu, P5_in_bu, P6_in, P7_in, num_channels, cell_id):
    """Run the top-down then bottom-up pass of one BiFPN cell with plain (unweighted) `Add` fusion.

    `P4_in_td` / `P5_in_td` feed the top-down nodes and `P4_in_bu` / `P5_in_bu` the
    bottom-up nodes; outside the first cell they are the same tensors.
    Returns `[P3_out, P4_out, P5_out, P6_out, P7_out]`.
    """
    prefix = f'fpn_cells/cell_{cell_id}'

    def fuse(node, op_idx, inputs):
        # Sum the inputs, apply swish, then a 3x3 separable conv block.
        x = layers.Add(name=f'{prefix}/{node}/add')(inputs)
        x = layers.Activation(lambda t: tf.nn.swish(t))(x)
        return SeparableConvBlock(num_channels=num_channels, kernel_size=3, strides=1,
                                  name=f'{prefix}/{node}/op_after_combine{op_idx}')(x)

    def upsample(x):
        # No explicit dtype: follow the global mixed-precision policy (bfloat16 on
        # TPU, float32 otherwise) instead of forcing a float16 round-trip.
        return layers.UpSampling2D()(x)

    def downsample(x):
        return layers.MaxPooling2D(pool_size=3, strides=2, padding='same')(x)

    # Top-down path: P7 -> P3.
    P6_td = fuse('fnode0', 5, [P6_in, upsample(P7_in)])
    P5_td = fuse('fnode1', 6, [P5_in_td, upsample(P6_td)])
    P4_td = fuse('fnode2', 7, [P4_in_td, upsample(P5_td)])
    P3_out = fuse('fnode3', 8, [P3_in, upsample(P4_td)])

    # Bottom-up path: P3 -> P7.
    P4_out = fuse('fnode4', 9, [P4_in_bu, P4_td, downsample(P3_out)])
    P5_out = fuse('fnode5', 10, [P5_in_bu, P5_td, downsample(P4_out)])
    P6_out = fuse('fnode6', 11, [P6_in, P6_td, downsample(P5_out)])
    P7_out = fuse('fnode7', 12, [P7_in, downsample(P6_out)])
    return [P3_out, P4_out, P5_out, P6_out, P7_out]


def build_BiFPN(features, num_channels, id, freeze_bn=False):
    """Build one BiFPN cell that fuses levels by plain summation.

    Args:
        features: for `id == 0`, the backbone outputs `(_, _, C3, C4, C5)`; otherwise
            the `[P3, P4, P5, P6, P7]` list returned by the previous cell.
        num_channels: channel width of every BiFPN node.
        id: index of this cell. It is embedded in the layer names, which must stay
            stable so `load_weights(by_name=True)` matches pretrained checkpoints.
        freeze_bn: accepted for signature parity with `build_wBiFPN`; not used here.

    Returns:
        `[P3_out, P4_out, P5_out, P6_out, P7_out]`.
    """
    if id == 0:
        # First cell: project backbone features to `num_channels` and derive P6/P7
        # by max-pool down-sampling.
        _, _, C3, C4, C5 = features
        P3_in = _bifpn_resample(C3, num_channels, f'fpn_cells/cell_{id}/fnode3/resample_0_0_8')
        # C4 and C5 each get two independent projections: one for the top-down node
        # (fnode2 / fnode1) and one for the bottom-up node (fnode4 / fnode5).
        P4_in_td = _bifpn_resample(C4, num_channels, f'fpn_cells/cell_{id}/fnode2/resample_0_1_7')
        P4_in_bu = _bifpn_resample(C4, num_channels, f'fpn_cells/cell_{id}/fnode4/resample_0_1_9')
        P5_in_td = _bifpn_resample(C5, num_channels, f'fpn_cells/cell_{id}/fnode1/resample_0_2_6')
        P5_in_bu = _bifpn_resample(C5, num_channels, f'fpn_cells/cell_{id}/fnode5/resample_0_2_10')

        P6_in = _bifpn_resample(C5, num_channels, 'resample_p6')
        P6_in = layers.MaxPooling2D(pool_size=3, strides=2, padding='same', name='resample_p6/maxpool')(P6_in)
        P7_in = layers.MaxPooling2D(pool_size=3, strides=2, padding='same', name='resample_p7/maxpool')(P6_in)
    else:
        # Later cells feed the previous cell's outputs into both paths.
        P3_in, P4_in_td, P5_in_td, P6_in, P7_in = features
        P4_in_bu, P5_in_bu = P4_in_td, P5_in_td

    return _bifpn_fuse(P3_in, P4_in_td, P5_in_td, P4_in_bu, P5_in_bu, P6_in, P7_in,
                       num_channels, id)


class PriorProbability(initializers.Initializer):
    """Constant bias initializer whose sigmoid equals `probability`.

    Every element is set to `-log((1 - p) / p)`, so `sigmoid(bias) == p`. The class
    head uses it so initial foreground scores start near `p`.
    """

    def __init__(self, probability=0.01):
        # The formula below is undefined at p <= 0 or p >= 1; fail early, not at build time.
        if not 0.0 < probability < 1.0:
            raise ValueError(f'probability must be in (0, 1), got {probability!r}')
        self.probability = probability

    def get_config(self):
        return {'probability': self.probability}

    def __call__(self, shape, dtype=None):
        # set bias to -log((1 - p)/p) for foreground
        bias = -math.log((1 - self.probability) / self.probability)
        # Honour the dtype Keras asks for. TF DTypes expose their NumPy equivalent via
        # `as_numpy_dtype`; dtype=None keeps NumPy's default float64, as before.
        np_dtype = np.dtype(getattr(dtype, 'as_numpy_dtype', dtype))
        return np.full(shape, bias, dtype=np_dtype)

    
class BoxNet(layers.Layer):
    """Box-regression head shared across pyramid levels.

    `depth` separable conv blocks (conv -> per-level BatchNorm -> swish) are followed
    by a separable conv that predicts `num_anchors * 4` values per location. These are
    reshaped to `(batch, -1, 4)`.

    Called as `box_net([feature, level])`, where `level` selects which of the five
    per-level BatchNorm layers to use (0..4 as enumerated by `Efficientdet`).
    `freeze_bn` is accepted but currently unused. NOTE(review): confirm whether the
    head BNs were meant to be frozen.
    """

    def __init__(self, width, depth, num_anchors=9, freeze_bn=False, name='box_net', **kwargs):
        # Forward `name` (and kwargs) so the layer keeps a stable name across rebuilds.
        # Otherwise Keras auto-names a second instance `box_net_1`, and HDF5
        # `load_weights(by_name=True)` silently skips it.
        super().__init__(name=name, **kwargs)

        self.width = width
        self.depth = depth
        self.num_anchors = num_anchors
        options = {
            'kernel_size': 3,
            'strides': 1,
            'padding': 'same',
            'bias_initializer': 'zeros',
            'depthwise_initializer': initializers.VarianceScaling(),
            'pointwise_initializer': initializers.VarianceScaling(),
        }

        # Attribute order is kept unchanged: it fixes the weight order used by HDF5 loading.
        self.convs = [layers.SeparableConv2D(filters=width, name=f'{name}/box-{i}', **options) for i in range(depth)]
        self.head = layers.SeparableConv2D(filters=num_anchors * 4, name=f'{name}/box-predict', **options)

        # One BatchNorm per (conv depth, pyramid level P3..P7).
        self.bns = [
            [layers.BatchNormalization(momentum=MOMENTUM, epsilon=EPSILON, name=f'{name}/box-{i}-bn-{j}') for j in
             range(3, 8)]
            for i in range(depth)]

        # Despite the attribute name, this applies swish, not ReLU.
        self.relu = layers.Lambda(lambda x: tf.nn.swish(x))
        self.reshape = layers.Reshape((-1, 4))

    def call(self, inputs):
        feature, level = inputs
        for i in range(self.depth):
            feature = self.convs[i](feature)
            feature = self.bns[i][level](feature)
            feature = self.relu(feature)
        outputs = self.head(feature)
        outputs = self.reshape(outputs)
        return outputs


class ClassNet(layers.Layer):
    """Classification head shared across pyramid levels.

    `depth` separable conv blocks (conv -> per-level BatchNorm -> swish) are followed
    by a separable conv predicting `num_classes * num_anchors` logits per location.
    Its bias is initialised with `PriorProbability(0.01)`. Logits are reshaped to
    `(batch, -1, num_classes)` and passed through a sigmoid.

    Called as `class_net([feature, level])`, where `level` selects one of the five
    per-level BatchNorm layers. `freeze_bn` is accepted but currently unused.
    NOTE(review): confirm whether the head BNs were meant to be frozen.
    """

    def __init__(self, width, depth, num_classes=20, num_anchors=9, freeze_bn=False, name='class_net', **kwargs):
        # Forward `name` (and kwargs) so the layer keeps a stable name across rebuilds.
        # Otherwise Keras auto-names a second instance `class_net_1`, and HDF5
        # `load_weights(by_name=True)` silently skips it.
        super().__init__(name=name, **kwargs)

        self.width = width
        self.depth = depth
        self.num_classes = num_classes
        self.num_anchors = num_anchors
        options = {
            'kernel_size': 3,
            'strides': 1,
            'padding': 'same',
            'depthwise_initializer': initializers.VarianceScaling(),
            'pointwise_initializer': initializers.VarianceScaling(),
        }

        # Attribute order is kept unchanged: it fixes the weight order used by HDF5 loading.
        self.convs = [layers.SeparableConv2D(filters=width, bias_initializer='zeros', name=f'{name}/class-{i}', **options) for i in range(depth)]
        self.head = layers.SeparableConv2D(filters=num_classes * num_anchors, bias_initializer=PriorProbability(probability=0.01), name=f'{name}/class-predict', **options)
        # One BatchNorm per (conv depth, pyramid level P3..P7).
        self.bns = [[layers.BatchNormalization(momentum=MOMENTUM, epsilon=EPSILON, name=f'{name}/class-{i}-bn-{j}') for j in range(3, 8)] for i in range(depth)]
        # Despite the attribute name, this applies swish, not ReLU.
        self.relu = layers.Lambda(lambda x: tf.nn.swish(x))
        self.reshape = layers.Reshape((-1, num_classes))
        self.activation = layers.Activation('sigmoid')

    def call(self, inputs):
        feature, level = inputs
        for i in range(self.depth):
            feature = self.convs[i](feature)
            feature = self.bns[i][level](feature)
            feature = self.relu(feature)
        outputs = self.head(feature)
        outputs = self.reshape(outputs)
        outputs = self.activation(outputs)
        return outputs

def Efficientdet(phi, num_classes=20, num_anchors=9, freeze_bn=True, image_size=None):
    """Build an EfficientDet-D{phi} detector as a Keras functional model.

    Args:
        phi: compound-scaling coefficient, 0..7.
        num_classes: number of classes predicted by the class head.
        num_anchors: anchors per feature-map location.
        freeze_bn: forwarded to the backbone, BiFPN and heads. It defaults to True
            because the per-replica batch size here (2) is too small for BatchNorm
            statistics.
        image_size: optional square input side. Defaults to the per-phi table below,
            which this notebook caps at 1024 px.

    Returns:
        `models.Model` with outputs `[regression, classification]`. `regression`
        holds per-anchor box regressions (4 values). `classification` holds
        per-anchor sigmoid class scores. Both are concatenated over levels P3..P7.

    Raises:
        ValueError: if `phi` is not in 0..7.
    """
    if phi not in range(8):
        raise ValueError(f'phi must be in 0..7, got {phi!r}')
    fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384]
    fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8]
    box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5]
    # One entry per phi; the table previously had only 7, so phi=7 raised IndexError.
    image_sizes = [512, 640, 768, 896, 1024, 1024, 1024, 1024]
    if image_size is None:
        image_size = image_sizes[phi]

    image_input = layers.Input((image_size, image_size, 3), name='input', dtype=TARGET_DTYPE)
    features = EfficientNetBN(phi, input_tensor=image_input, freeze_bn=freeze_bn)

    # phi < 6 uses weighted fusion (build_wBiFPN); larger models use plain-sum build_BiFPN.
    build_cell = build_wBiFPN if phi < 6 else build_BiFPN
    fpn_features = features
    for i in range(fpn_cell_repeats[phi]):
        fpn_features = build_cell(fpn_features, fpn_num_filters[phi], i, freeze_bn=freeze_bn)

    box_net = BoxNet(fpn_num_filters[phi], box_class_repeats[phi], num_anchors=num_anchors, freeze_bn=freeze_bn, name='box_net')
    class_net = ClassNet(fpn_num_filters[phi], box_class_repeats[phi], num_classes=num_classes, num_anchors=num_anchors, freeze_bn=freeze_bn, name='class_net')

    # The same heads are applied to every level; `i` selects the per-level BatchNorm.
    classification = [class_net([feature, i]) for i, feature in enumerate(fpn_features)]
    classification = layers.Concatenate(axis=1, name='classification')(classification)
    regression = [box_net([feature, i]) for i, feature in enumerate(fpn_features)]
    regression = layers.Concatenate(axis=1, name='regression')(regression)

    model = models.Model(inputs=[image_input], outputs=[regression, classification], name='efficientdet')

    return model

# Loss Functions (Focal and Smooth L1)

In [None]:
# Classification Loss
def focal_loss(y_true, y_pred):
    """Sigmoid focal loss (alpha=0.25, gamma=2) normalised by the number of positive anchors.

    `y_true` holds the per-anchor state (-1 = ignore, 0 = background, 1 = object).
    `y_pred` holds the predicted foreground probabilities; the two are compared
    element-wise. Ignored anchors contribute to neither term. Both inputs are cast to
    float32, and the divisor is the positive-anchor count, clamped to at least 1.
    """
    alpha, gamma = 0.25, 2.0
    anchor_state = tf.cast(y_true, dtype=tf.float32)
    probs = tf.cast(y_pred, dtype=tf.float32)

    def select(mask):
        return tf.boolean_mask(anchor_state, mask), tf.boolean_mask(probs, mask)

    # Object anchors: every selected label is exactly 1, so the weight collapses to
    # alpha * (1 - p) ** gamma.
    is_object = tf.equal(anchor_state, tf.constant(1.0, dtype=anchor_state.dtype))
    obj_labels, obj_probs = select(is_object)
    obj_weight = alpha * (1 - obj_probs) ** gamma
    obj_loss = tf.reduce_sum(obj_weight * backend.binary_crossentropy(obj_labels, obj_probs))

    # Background anchors: the weight is (1 - alpha) * p ** gamma.
    is_background = tf.equal(anchor_state, tf.constant(0.0, dtype=anchor_state.dtype))
    bg_labels, bg_probs = select(is_background)
    bg_weight = (1 - alpha) * bg_probs ** gamma
    bg_loss = tf.reduce_sum(bg_weight * backend.binary_crossentropy(bg_labels, bg_probs))

    num_positive = tf.cast(tf.shape(tf.where(tf.equal(anchor_state, 1)))[0], tf.float32)
    num_positive = tf.maximum(num_positive, 1.0)

    return (obj_loss + bg_loss) / num_positive

# Regression Loss

def smooth_l1(y_true, y_pred, sigma =3.0):
    """Smooth-L1 box-regression loss averaged over positive anchors and the 4 box coordinates.

    The last channel of `y_true` is the anchor state; only anchors with state 1
    contribute. The remaining channels are the regression targets compared with
    `y_pred`. Per coordinate, with `d = |pred - target|`:
        0.5 * (sigma * d) ** 2     if d <= 1 / sigma**2
        d - 0.5 / sigma**2         otherwise
    The sum is divided by max(#positives, 1) and then by 4.
    """
    truth = tf.cast(y_true, tf.float32)
    preds = tf.cast(y_pred, tf.float32)
    sigma_sq = sigma ** 2

    positive_idx = tf.where(tf.equal(truth[:, :, -1], 1))
    pred_pos = tf.gather_nd(preds, positive_idx)
    target_pos = tf.gather_nd(truth[:, :, :-1], positive_idx)

    abs_diff = tf.abs(pred_pos - target_pos)
    quadratic = 0.5 * sigma_sq * tf.math.pow(abs_diff, 2)
    linear = abs_diff - 0.5 / sigma_sq
    per_coord = tf.where(abs_diff <= (1.0 / sigma_sq), quadratic, linear)

    num_positive = tf.cast(tf.maximum(tf.shape(positive_idx)[0], 1), tf.float32)
    return tf.reduce_sum(per_coord) / num_positive / 4



# Build the Model

In [None]:
class ModelConfig:
    """Detector variant built by this notebook."""
    phi = 4          # EfficientDet compound coefficient (D4)
    num_classes = 1  # size of the classification head

In [None]:
def load_model(weights_path='../input/pretrainedmodelwheat/model_32_0.1533_0.0336_0.1197_0.00006.h5'):
  """Build the detector described by `ModelConfig` and load pretrained weights by layer name.

  Args:
    weights_path: HDF5 checkpoint to load. The default is the pretrained wheat
      checkpoint this notebook was written against.

  Returns:
    The Keras model. `skip_mismatch=False` means a shape mismatch between the
    checkpoint and `ModelConfig` raises instead of being silently ignored.
  """
  # Use ModelConfig instead of hard-coding (4, 1) so the two cannot drift apart.
  model = Efficientdet(ModelConfig.phi, num_classes=ModelConfig.num_classes)
  model.load_weights(weights_path, by_name=True, skip_mismatch=False)
  return model

In [None]:
# The detector is built, and its pretrained weights loaded, inside `strategy.scope()`
# by `prepare_model()` below. Building an extra copy here would put it outside the TPU
# scope. It can also make Keras auto-rename un-named layers of the later copy
# (e.g. `box_net_1`), which `load_weights(by_name=True)` would then skip.
model = None

# Create LR Scheduler

In [None]:
class ParamScheduler:
    """Base class for a stateful schedule that walks from `start` to `end` over `num_iter` steps.

    Subclasses provide `func(start, end, pct)`. Each `step()` advances an internal
    counter (0 on the first call) and evaluates `func` at `pct = idx / num_iter`, so
    `num_iter + 1` calls cover pct 0..1. After that, `is_complete()` is True.
    """

    def __init__(self, start, end, num_iter):
        self.start, self.end, self.num_iter = start, end, num_iter
        self.idx = -1

    def step(self):
        self.idx = self.idx + 1
        pct = self.idx / self.num_iter
        return self.func(self.start, self.end, pct)

    def reset(self):
        self.idx = -1

    def is_complete(self):
        return self.num_iter <= self.idx

class CosineScheduler(ParamScheduler):
    """Half-cosine interpolation: returns `start_val` at pct=0 and `end_val` at pct=1."""
    def func(self, start_val, end_val, pct):
        half_span = (start_val - end_val) / 2
        return end_val + half_span * (np.cos(np.pi * pct) + 1)
class ConstantScheduler(ParamScheduler):
    """Schedule that returns `init_lr` on every step.

    It does not call `ParamScheduler.__init__`; it keeps its own counter `steps` and
    reports completion once it has been stepped `num_steps + 1` times.
    """
    def __init__(self, init_lr, num_steps):
        self.init_lr, self.num_steps = init_lr, num_steps
        self.steps = -1

    def step(self):
        self.steps = self.steps + 1
        return self.init_lr

    def reset(self):
        self.steps = -1

    def is_complete(self):
        return self.num_steps <= self.steps
class OneCycleScheduler(keras.callbacks.Callback):
    """Three-phase learning-rate schedule, queried once per epoch via `step`.

    The schedule advances once per training iteration through:
      1. cosine warm-up from `max_lr / 25` to `max_lr` over `warm_steps * num_iter` iterations,
      2. constant `max_lr` for `peak_steps * num_iter` iterations,
      3. cosine decay from `max_lr` to `max_lr / (25 * 1e4)` for the remainder,
    where `num_iter = int(0.8 * total_steps)`. Once every phase is exhausted the rate
    is held at `min_lr`, and every returned value is clamped to at least `min_lr`.

    `step(epoch)` is meant for `keras.callbacks.LearningRateScheduler`. Each call
    advances one epoch's worth of iterations and returns the rate reached at the end.
    `init_lr` and `momentums` are stored but not used by the schedule.
    """

    def __init__(self, init_lr, max_lr, min_lr, warm_steps, peak_steps, total_steps, steps_per_epoch=None):
        super().__init__()
        momentums = (0.95, 0.85)
        start_div = 25.
        pct_start = warm_steps
        pct_climax = peak_steps  # fraction of training spent at the peak
        verbose = True
        sched = CosineScheduler
        end_div = None
        self.pct_climax = pct_climax
        self.max_lr, self.momentums, self.start_div, self.pct_start, self.verbose, self.sched, self.end_div = max_lr, momentums, start_div, pct_start, verbose, sched, end_div
        if self.end_div is None:
            self.end_div = start_div * 1e4
        self.logs = {}
        self.min_lr = min_lr
        self.init_lr = init_lr
        # Iterations per `step()` call; None falls back to TRAIN_NUMBER // BATCH_SIZE.
        self.steps_per_epoch = steps_per_epoch

        self.start_lr = self.max_lr / self.start_div
        self.end_lr = self.max_lr / self.end_div
        self.num_iter = int(total_steps * 0.8)  # Pad the Steps a bit to make sure no overflow.
        self.num_iter_1 = int(self.pct_start * self.num_iter)
        self.num_iter_2 = int(self.pct_climax * self.num_iter)
        self.num_iter_3 = self.num_iter - self.num_iter_1 - self.num_iter_2

        self.lr_scheds = (self.sched(self.start_lr, self.max_lr, self.num_iter_1), ConstantScheduler(self.max_lr, self.num_iter_2), self.sched(self.max_lr, self.end_lr, self.num_iter_3))
        self.sched_idx = 0

    def optimizer_params_step(self):
        """Advance the schedule by one iteration; return the clamped float32 learning rate."""
        next_lr = self.min_lr
        while self.sched_idx < len(self.lr_scheds):
            current = self.lr_scheds[self.sched_idx]
            if not current.is_complete():
                try:
                    next_lr = current.step()
                    break
                except ZeroDivisionError:
                    # A zero-length phase (e.g. warm_steps == 0) has nothing to emit.
                    pass
            # Move to the next phase as soon as this one is exhausted, rather than at the
            # next epoch boundary, so no phase overshoots its end value mid-epoch.
            self.sched_idx += 1
        next_lr = tf.maximum(next_lr, self.min_lr)
        return tf.cast(next_lr, tf.float32)

    def step(self, idx):
        """Keras `LearningRateScheduler` hook. `idx` is the epoch index (unused)."""
        steps = self.steps_per_epoch if self.steps_per_epoch is not None else TRAIN_NUMBER // BATCH_SIZE
        lr = None
        # Always advance at least once so a value is returned even if `steps` is 0.
        for _ in range(max(1, steps)):
            lr = self.optimizer_params_step()
        return lr

# Gradient-Accumulating AdamW

In [None]:
class GradAccAdam():
    """Accumulate gradients over `grad_acc_steps` calls, then apply them with AdamW.

    Each incoming gradient is pre-divided by `grad_acc_steps`, so the applied update
    is the mean of the accumulated gradients. Weight decay comes from
    `TrainingConfig.weight_decay`.

    If `prev_optim_path` is given, optimizer weights are restored from
    `f'{prev_optim_path}optimizer_last.npy'`. One all-zero update runs first so the
    optimizer has created its variables before `set_weights` overwrites them.
    """
    def __init__(self, model, learning_rate, grad_acc_steps, prev_optim_path = None):
        self.learning_rate = learning_rate
        self.grad_acc_steps = grad_acc_steps

        self.weight_decay = TrainingConfig.weight_decay
        self.optimizer = tfa.optimizers.AdamW(learning_rate = self.learning_rate, weight_decay = self.weight_decay)

        self.PrevModelPath = prev_optim_path
        if self.PrevModelPath:
            # SECURITY: allow_pickle=True unpickles arbitrary objects; only load files you wrote.
            self.opt_weights = np.load(f'{self.PrevModelPath}optimizer_last.npy', allow_pickle = True)

            trainable_weights = model.trainable_weights

            # Optimizer slot variables are created on the first apply_gradients call,
            # so apply a zero update before restoring the saved weights.
            zero_grads = [tf.zeros_like(w) for w in trainable_weights]
            @tf.function
            def f():
                self.optimizer.apply_gradients(zip(zero_grads, trainable_weights))
            strategy.run(f)
            self.optimizer.set_weights(self.opt_weights)
            print("Loaded Weights")

        self.gradients = None
        self.cur_grad_acc = 0

    def apply_gradients(self, gradients, variables):
        """Accumulate `gradients`; on every `grad_acc_steps`-th call, apply the sum to `variables`.

        `None` entries (variables that received no gradient) are carried through and
        skipped when applying, instead of failing on the division.
        """
        scale = tf.constant(float(self.grad_acc_steps))
        scaled = [None if g is None else g / scale for g in gradients]
        if self.gradients is None:
            self.gradients = scaled
        else:
            merged = []
            for acc, new in zip(self.gradients, scaled):
                if acc is None:
                    merged.append(new)
                elif new is None:
                    merged.append(acc)
                else:
                    merged.append(acc + new)
            self.gradients = merged
        self.cur_grad_acc += 1
        if self.cur_grad_acc == self.grad_acc_steps:
            grads_and_vars = [(g, v) for g, v in zip(self.gradients, variables) if g is not None]
            self.optimizer.apply_gradients(grads_and_vars)
            self.gradients = None
            self.cur_grad_acc = 0

# Training Config

In [None]:
class TrainingConfig:
    """Optimisation hyper-parameters shared by the optimizer and LR-scheduler cells."""

    # Learning rates: the AdamW starting rate, the one-cycle peak, and the floor.
    learning_rate = 1e-5
    max_lr = 1e-4
    min_lr = 1e-9

    # Fractions of the scheduled iterations spent warming up and holding the peak.
    WARM_STEPS = 0.1
    PEAK_STEPS = 0.1

    weight_decay = 0.

    # Derived from the notebook-level TRAIN_NUMBER / BATCH_SIZE / NUM_EPOCHS globals.
    STEPS_PER_EPOCH = TRAIN_NUMBER // BATCH_SIZE
    TOTAL_STEPS = STEPS_PER_EPOCH * NUM_EPOCHS

# Create Model under TPU Strategy

In [None]:
def save_states(model_path):
  """Save optimizer and model weights under `DataModule.SAVE_PATH`.

  Writes `{SAVE_PATH}{model_path}_optim.npy` (optimizer weights) and
  `{SAVE_PATH}{model_path}_model.h5` (model weights). Relies on the notebook globals
  `model` and `optimizer`. `optimizer` presumably is a `GradAccAdam` wrapper, since
  its `.optimizer` is the Keras optimizer — confirm before use.
  """
  save_path = f"{DataModule.SAVE_PATH}{model_path}"

  # Optimizer weights have differing shapes, so pack them into a 1-D object array
  # explicitly. Letting NumPy infer an array from a ragged list raises on NumPy >= 1.24.
  opt_weights = optimizer.optimizer.get_weights()
  packed = np.empty(len(opt_weights), dtype=object)
  for i, weight in enumerate(opt_weights):
    packed[i] = weight
  np.save(f"{save_path}_optim.npy", packed, allow_pickle=True)

  model.save_weights(f"{save_path}_model.h5")

In [None]:
def prepare_model():
  """Build the model, a combined loss function and the LR scheduler under `strategy.scope()`.

  Returns:
    `(model, scheduler, loss_fn)`. `loss_fn(y_true_bbox, y_true_cls, y_pred_bbox,
    y_pred_cls)` returns the smooth-L1 box loss plus the focal classification loss.
  """
  with strategy.scope():
    print('-------------CREATING MODEL -------------------')
    model = load_model()
    print('-----------------CREATING LOSS_FN-------------')
    def loss_fn(y_true_bbox, y_true_cls, y_pred_bbox, y_pred_cls):
      y_true_bbox = tf.cast(y_true_bbox, tf.float32)
      y_true_cls = tf.cast(y_true_cls, tf.float32)
      y_pred_bbox = tf.cast(y_pred_bbox, tf.float32)
      y_pred_cls = tf.cast(y_pred_cls, tf.float32)

      # Each loss must see its own head: box targets/predictions go to smooth_l1,
      # anchor states/class probabilities to focal_loss.
      loss_bbox = smooth_l1(y_true_bbox, y_pred_bbox)
      loss_cls = focal_loss(y_true_cls, y_pred_cls)
      return loss_bbox + loss_cls
    print('-----------------CREATING SCHEDULER------------------')
    scheduler = OneCycleScheduler(TrainingConfig.learning_rate, TrainingConfig.max_lr, TrainingConfig.min_lr, TrainingConfig.WARM_STEPS, TrainingConfig.PEAK_STEPS, TrainingConfig.TOTAL_STEPS)
  return model, scheduler, loss_fn



# Stat Logger

In [None]:
class StatLogger():
  """Track the best validation loss/accuracy, checkpoint on improvement and log one line per epoch.

  `metrics` maps 'train_loss', 'val_loss' and optionally 'val_acc' to Keras-style
  metrics, read via `.result().numpy().item()`. When `metrics` is not given, the
  notebook-level `metrics` global is used instead.
  """
  def __init__(self, metrics=None):
    self.best_loss = float('inf')
    self.best_accuracy = 0.0
    self.EPOCH = 0
    self.metrics = metrics

  def update_val(self):
    """Read the metrics at the end of a validation loop, save improved checkpoints and print a summary."""
    tracked = self.metrics if self.metrics is not None else metrics

    def read(key):
      return tracked[key].result().numpy().item()

    train_loss = read('train_loss')
    val_loss = read('val_loss')
    # Accuracy is optional (it is not currently computed); skip its checkpoint when absent.
    val_acc = read('val_acc') if 'val_acc' in tracked else None

    if val_loss <= self.best_loss:
      self.best_loss = val_loss
      save_states('loss')
    if val_acc is not None and val_acc >= self.best_accuracy:
      self.best_accuracy = val_acc
      save_states('acc')

    print(f"E: {self.EPOCH}, BL: {self.best_loss}, BA: {self.best_accuracy}, TL: {train_loss} VL: {val_loss}, VA: {val_acc}")
    self.EPOCH += 1


# Train the Model on TPUs

In [None]:
# Use Model.Fit 
def train_model(fold_idx):
    """Compile the global `model` with AdamW and fit it on fold `fold_idx`.

    Reads the notebook globals `model` and `scheduler` (set by `prepare_model`), plus
    `get_dfs`, `NUM_EPOCHS`, `TRAIN_NUMBER`, `VAL_NUMBER` and `BATCH_SIZE`. The per-epoch
    learning rate comes from `scheduler.step`. The best weights, judged by
    ModelCheckpoint's default monitor, go to `./best_model_{fold_idx}.h5`.
    """
    callbacks = [
        tf.keras.callbacks.LearningRateScheduler(scheduler.step, verbose=1),
        keras.callbacks.ModelCheckpoint(
            f'./best_model_{fold_idx}.h5',
            save_best_only=True,
            save_weights_only=True,
        ),
    ]

    adamw = tfa.optimizers.AdamW(
        weight_decay=TrainingConfig.weight_decay,
        learning_rate=TrainingConfig.learning_rate,
    )
    # Keys match the names of the model's two output layers.
    per_output_loss = {'regression': smooth_l1, 'classification': focal_loss}
    model.compile(optimizer=adamw, loss=per_output_loss)

    train_ds, val_ds = get_dfs(fold_idx)
    model.fit(
        train_ds,
        epochs=NUM_EPOCHS,
        callbacks=callbacks,
        steps_per_epoch=TRAIN_NUMBER // BATCH_SIZE,
        validation_data=val_ds,
        validation_steps=VAL_NUMBER // BATCH_SIZE,
    )

In [None]:
# Cross-validation folds to train in this run (each is passed to get_dfs via train_model).
FOLDS_TO_TRAIN = [0]

In [None]:
for FOLD_IDX in FOLDS_TO_TRAIN:
    # prepare_model() builds everything inside strategy.scope(); train_model() then
    # reads `model` and `scheduler` from these module-level names.
    model, scheduler, loss_fn = prepare_model()
    train_model(FOLD_IDX)