# In [None]:
# Adapted from TensorFlow Addons' `sigmoid_focal_crossentropy` (v0.14.0):
# https://github.com/tensorflow/addons/blob/v0.14.0/tensorflow_addons/losses/focal_loss.py#L26-L81

import tensorflow as tf
import tensorflow.keras.backend as K

class FocalLoss:
    """Binary (sigmoid) focal loss, summed over every element of the batch.

    Computes ``alpha_t * (1 - p_t) ** gamma * CE(y_true, y_pred)`` element-wise
    and returns the sum of all elements as a scalar.

    Args:
        alpha: Weight applied to positive targets; negative targets are
            weighted by ``1 - alpha``.
        gamma: Focusing exponent applied to ``(1 - p_t)``; larger values
            down-weight well-classified examples more strongly.
        from_logits: If True, ``y_pred`` holds raw (pre-sigmoid) logits;
            otherwise it holds probabilities.
    """

    def __init__(self, alpha = 0.25, gamma = 2.0, from_logits = False):
        self.alpha = alpha
        self.gamma = gamma
        self.from_logits = from_logits

    def __call__(self, y_true, y_pred):
        """Return the summed focal loss as a float64 scalar tensor.

        Args:
            y_true: Binary targets (expected to be 0/1), broadcast-compatible
                with ``y_pred``.
            y_pred: Predicted probabilities, or logits when ``from_logits``.
        """
        # All arithmetic is done in float64 (kept from the original class).
        y_true = tf.cast(y_true, tf.float64)
        y_pred = tf.cast(y_pred, tf.float64)

        if self.from_logits:
            # Numerically stable element-wise cross-entropy straight from the
            # logits. Logits must NOT be clipped; they are only mapped to
            # probabilities for the p_t / modulating-factor terms.
            ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true, logits=y_pred)
            prob = tf.sigmoid(y_pred)
        else:
            # Keep probabilities away from exactly 0 and 1 so the logs are finite.
            epsilon = K.epsilon()
            prob = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
            # Element-wise log-loss. tf.keras.losses.binary_crossentropy would
            # average over the last axis, and that reduced tensor would then
            # broadcast against the per-element factors below.
            ce = -(y_true * tf.math.log(prob) + (1. - y_true) * tf.math.log(1. - prob))

        # p_t: probability the model assigns to the true class.
        p_t = y_true * prob + (1. - y_true) * (1. - prob)
        alpha_factor = y_true * self.alpha + (1. - y_true) * (1. - self.alpha)
        modulating_factor = tf.pow(1. - p_t, self.gamma)

        # alpha_t * (1 - p_t) ** gamma * (-log(p_t)), same shape as the inputs.
        loss = alpha_factor * modulating_factor * ce

        return tf.reduce_sum(loss)
    
    
import numpy as np

# Toy batch of three samples, each with a single binary label: shape (3, 1).
y_true = np.array([1., 1., 0.]).reshape(-1, 1)
y_pred = np.array([0.97, 0.61, 0.4]).reshape(-1, 1)

# Default focal-loss settings; the last expression is displayed by the notebook.
fl = FocalLoss()
fl(y_true, y_pred)