## Loss functions
This notebook collects the custom loss functions I need while building models for this research. If there end up being too many, I'll split them into separate notebooks.

In [None]:
import sys
import tensorflow as tf
from tensorflow.keras.losses import CategoricalCrossentropy

In [None]:
def refresh(obj):
    """
    Clear the state of a loss function / metric object, if it holds onto any.

    ``obj.reset_state()`` is tried first; ``obj.reset_states()`` (the legacy
    Keras spelling) is only tried when the first name is missing or fails.
    An object exposing both names is therefore reset, and reported, exactly
    once. Objects with neither method are left untouched.

    Parameters
    ----------
    obj : object
        Loss or metric object that may expose ``reset_state`` / ``reset_states``

    Returns
    -------
    None
    """
    for method_name in ('reset_state', 'reset_states'):
        try:
            getattr(obj, method_name)()
        except Exception:
            # Best effort: a missing or failing reset must not abort the loss
            # computation. Catching Exception (not a bare except) still lets
            # KeyboardInterrupt / SystemExit propagate.
            continue
        print('reset state')
        return

### CategoricalCrossEntropy for each point in a series

In [None]:
def catcrossentropy_per_pt(y_true, y_pred):
    """
    A metric for a series of datapoints, each of which needs classification.

    The categorical cross entropy is evaluated for every datapoint of every
    series and averaged over all of them. The computation is a single
    vectorised op, so it does not need the series length as a Python int and
    works both eagerly and when traced into a graph.

    Parameters
    ----------
    y_true: tensorflow tensor of shape (batch size, series length, num_categories)
        The true values to compare with. For each datapoint in the series,
        the category information should be one-hot encoded
    y_pred: tensorflow tensor of shape (batch size, series length, num_categories)
        The predicted values. For each datapoint in the series,
        the category information should be expressed in probabilities (fractions of 1)

    Returns
    -------
    loss: tensorflow tensor of shape (1,)
        The categorical cross entropy averaged over every datapoint of every series

    """
    # Reduces only the last (category) axis: one loss value per
    # (series, datapoint) pair.
    per_point = tf.keras.losses.categorical_crossentropy(y_true, y_pred)

    # Every datapoint position holds the same number of series, so the mean
    # over all pairs equals the average of the per-position batch means.
    mean_loss = tf.reduce_mean(per_point)

    # Keep the (1,) shape callers already receive.
    return tf.reshape(mean_loss, (1,))

### Factory function that builds a loss for masked sequential data

In [None]:
def gen_loss_per_pt(loss_fn=CategoricalCrossentropy(), mask_layer=None):
    """
    A factory for series' of datapoints: returns a loss function which
    applies loss_fn to each datapoint position and can take a mask of the
    inputs into account.

    Parameters
    ----------
    loss_fn : loss-type class object
        loss function to use per each data point. It is called as
        ``loss_fn(y_true_step, y_pred_step)`` and may return either a scalar
        or one value per series; both are handled.
    mask_layer : layer.Masking object
        masking layer used to throw out padding datapoints. Only its
        ``compute_mask`` method is used. If None, no masking is applied.

    Returns
    -------
    loss_per_pt : function
        The generated loss function taking into account the values of loss_fn and mask_layer
    """

    def loss_per_pt(y_true, y_pred):
        """
        A metric for a series of datapoints, each of which needs its own separate evaluation.

        Parameters
        ----------
        y_true: tensorflow tensor of shape (batch size, series length, num_categories)
            The true values to compare with. For each datapoint in the series,
            the category information should be one-hot encoded
            SHOULD BE MASKED AS PER MASK_LAYER'S EXPECTATIONS
        y_pred: tensorflow tensor of shape (batch size, series length, num_categories)
            The predicted values. For each datapoint in the series,
            the category information should be expressed in probabilities (fractions of 1)

        Returns
        -------
        loss: tensorflow tensor of shape (1,)
            The loss of every datapoint position, summed and then divided by
            the series length (padding positions included in the divisor)

        """
        # Prefer the static shape: it is a plain int whenever the series
        # length is known, including while the loss is being traced.
        n_points = y_true.shape[1]
        if n_points is None:
            # Unknown static length: the dynamic shape can only be turned
            # into a Python int when executing eagerly.
            n_points = tf.shape(y_true)[1]
        n_points = int(n_points)

        # assumes compute_mask returns a boolean (batch size, series length)
        # mask, as keras Masking does -- TODO confirm for custom layers
        mask = None if mask_layer is None else mask_layer.compute_mask(y_true)

        losses = tf.zeros(shape=(1,))
        for i in range(n_points):  # loop over every datapoint in the series
            y_t = y_true[:, i, :]
            y_p = y_pred[:, i, :]
            if mask is not None:
                # drop the series whose datapoint at this position is padding
                keep = mask[:, i]
                y_t = tf.boolean_mask(y_t, keep)
                y_p = tf.boolean_mask(y_p, keep)
            refresh(loss_fn)
            loss = loss_fn(y_t, y_p)
            # reduce_sum is a no-op on a scalar, sums a per-series vector and
            # gives 0 for an empty one; len() raises on scalar tensors and an
            # empty vector would otherwise broadcast the total to shape (0,)
            losses += tf.reduce_sum(loss)

        # normalize by number of datapoints
        losses /= n_points

        return losses

    return loss_per_pt