## Categorical Focal Loss
In this assignment we will implement a categorical focal loss function with "L1" and "L2" regularization for multi-class classification problems.\
Focal loss has several applications in problems with imbalanced datasets, such as object detection.
You can learn more about this loss function here:
https://medium.com/swlh/focal-loss-what-why-and-how-df6735f26616

In [None]:
import tensorflow as tf

## Focal Loss Formula:
$$
FL(y_{true}, y_{pred}) = - \alpha \cdot y_{true} \cdot (1 - y_{pred})^{\gamma} \cdot \log(y_{pred}) \\
l1(y_{true}, y_{pred}) = \sum |y_{pred}| \\
l2(y_{true}, y_{pred}) = \sum (y_{pred})^2 \\
\text{total-loss} = FL + l1_w \cdot l1 + l2_w \cdot l2
$$

In [None]:
class CategoricalFocalLoss(tf.keras.losses.Loss):
    """Categorical focal loss with L1 and L2 penalties on the predictions.

    For each sample the loss is computed over the last (class) axis as::

        FL    = -sum(alpha * y_true * (1 - y_pred) ** gamma * log(y_pred))
        l1    = sum(|y_pred|)
        l2    = sum(y_pred ** 2)
        total = FL + l1_weight * l1 + l2_weight * l2

    The per-sample values are returned from ``call``; the base ``Loss`` class
    applies the configured reduction.

    Args:
        alpha: Scalar weight applied to the focal term.
        gamma: Focusing exponent; larger values down-weight examples that are
            already predicted with high probability.
        l1: Weight of the L1 penalty on ``y_pred``.
        l2: Weight of the L2 penalty on ``y_pred``.
        **kwargs: Forwarded to ``tf.keras.losses.Loss`` (e.g. ``name``,
            ``reduction``).
    """

    # Lower/upper clipping margin that keeps log(y_pred) finite when a
    # predicted probability is exactly 0 or 1.
    _EPSILON = 1e-7

    def __init__(self, alpha=0.25, gamma=2, l1=0.01, l2=0.01, **kwargs):
        super(CategoricalFocalLoss, self).__init__(**kwargs)
        self.alpha = alpha
        self.gamma = gamma
        self.l1 = l1
        self.l2 = l2

    def call(self, y_true, y_pred):
        """Return the per-sample focal loss plus L1/L2 penalties.

        Args:
            y_true: One-hot encoded targets; assumed to have the same shape
                as ``y_pred`` with classes on the last axis.
            y_pred: Predicted class probabilities (e.g. softmax output).

        Returns:
            A tensor with the last (class) axis reduced away.
        """
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)

        # Clip only for the focal term so log() never sees 0.
        clipped = tf.clip_by_value(y_pred, self._EPSILON, 1.0 - self._EPSILON)
        modulating = tf.pow(1.0 - clipped, tf.cast(self.gamma, y_pred.dtype))
        focal = -self.alpha * y_true * modulating * tf.math.log(clipped)
        focal = tf.reduce_sum(focal, axis=-1)

        l1_penalty = tf.reduce_sum(tf.abs(y_pred), axis=-1)
        l2_penalty = tf.reduce_sum(tf.square(y_pred), axis=-1)

        return focal + self.l1 * l1_penalty + self.l2 * l2_penalty

    def get_config(self):
        """Return the constructor arguments so the loss can be re-created."""
        config = super(CategoricalFocalLoss, self).get_config()
        config.update({
            'alpha': self.alpha,
            'gamma': self.gamma,
            'l1': self.l1,
            'l2': self.l2,
        })
        return config

In [None]:

def build_model(dense_units, input_shape=(224, 224) + (3,), num_classes=2):
    """Build a small convolutional image classifier.

    The network stacks three Conv2D + MaxPooling2D stages (16, 32 and 64
    filters), flattens the result, and finishes with a ReLU dense layer and a
    softmax output layer.

    Args:
        dense_units: Number of units in the hidden dense layer.
        input_shape: Shape of one input image, without the batch dimension.
        num_classes: Number of units in the softmax output layer. Defaults
            to 2, the value that was previously hard-coded.

    Returns:
        An uncompiled ``tf.keras.models.Sequential`` model.
    """
    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=input_shape),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(dense_units, activation='relu'),
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ])
    return model

In [None]:
import tensorflow_datasets as tfds
# Load the training split of cats_vs_dogs, using data/ as the TFDS data directory.
dataset = tfds.load(
    'cats_vs_dogs',
    split=tfds.Split.TRAIN,
    data_dir='data/',
)

# Build the classifier with dense_units=256.
model = build_model(dense_units=256)

# Configure training: Adam optimizer, the custom focal loss, accuracy metric.
model.compile(
    optimizer='adam',
    loss=CategoricalFocalLoss(),
    metrics=['accuracy'],
)

# Define preprocessing function
def preprocess(features):
    """Convert one raw dataset example into an (image, label) training pair.

    Args:
        features: Mapping that is indexed with 'image' and 'label'.

    Returns:
        A 2-tuple of the image resized to 224x224, cast to float32 and
        divided by 255, and the label one-hot encoded with depth 2 and cast
        to float32.
    """
    resized = tf.image.resize(features['image'], (224, 224))
    scaled_image = tf.cast(resized, tf.float32) / 255.
    one_hot_label = tf.cast(tf.one_hot(features['label'], depth=2), tf.float32)
    return scaled_image, one_hot_label

# Run every example through preprocess, then group the results into batches of 32.
dataset = dataset.map(preprocess)
dataset = dataset.batch(32)

# Fit the model on the batched dataset for 10 epochs.
model.fit(dataset, epochs=10)