# Multi-Epoch Training Using Gradient Tape

## Progress Metrics

In [None]:
class Mean(metrics.Metric):
    """Running arithmetic mean of the values passed to `update_state`.

    Two weights back the metric: `total` accumulates the sum of the submitted
    values and `count` holds how many times `update_state` has been called.

    Args:
        name: Metric name; also used to suffix the names of the two weights.
        **kwargs: Forwarded unchanged to `metrics.Metric.__init__`.
    """

    def __init__(self, name="mean", **kwargs):
        super().__init__(name=name, **kwargs)
        self.total = self.add_weight(name='total_{}'.format(name), initializer="zeros")
        self.count = self.add_weight(name='count_{}'.format(name), initializer="zeros")

    def update_state(self, result):
        """Add one observation `result` to the running sum and bump the counter."""
        self.total.assign_add(result)
        self.count.assign_add(1.0)

    def result(self):
        """Return `total / count`, or 0 when nothing has been recorded yet."""
        # A freshly created or freshly reset metric has count == 0; a plain
        # division would produce NaN, which then leaks into progress output
        # and recorded history. divide_no_nan returns 0 in that case instead.
        return tf.math.divide_no_nan(self.total, self.count)

    def reset_state(self):
        """Zero both accumulators so the next epoch starts from scratch."""
        self.total.assign(0.0)
        self.count.assign(0.0)

class History(object):
    """Bundle of running `Mean` metrics plus a per-epoch record of their values.

    `metrics` maps a metric name to its live `Mean` accumulator; `history` maps
    the same names to the list of values recorded by `record_and_reset`.
    Names beginning with ``val_`` are treated as validation metrics.
    """

    def __init__(self):
        # Insertion order matters: it is the order in which the metrics are
        # reported and in which the history dictionary is laid out.
        tracked = (
            'learning_rate',
            'loss', 'val_loss',
            'yx_loss', 'val_yx_loss',
            'hw_loss', 'val_hw_loss',
            'iou', 'val_iou',
            'positive_iou', 'val_positive_iou',
            'negative_iou', 'val_negative_iou',
        )
        self.metrics = {key: Mean(name=key) for key in tracked}
        self.history = {key: [] for key in self.metrics}

    @property
    def metric_names(self):
        """Names of every tracked metric, in registration order."""
        return list(self.metrics)

    @property
    def training_metrics_names(self):
        """Names of the metrics that do not carry the ``val_`` prefix."""
        return [key for key in self.metrics if not key.startswith('val_')]

    @property
    def training_metrics(self):
        """Current `(name, value)` pairs for the non-validation metrics."""
        pairs = []
        for key in self.training_metrics_names:
            pairs.append((key, self.metrics[key].result()))
        return pairs

    @property
    def metric_values(self):
        """Current `(name, value)` pairs for every tracked metric."""
        pairs = []
        for key, tracker in self.metrics.items():
            pairs.append((key, tracker.result()))
        return pairs

    def train_step(self, yx_loss, hw_loss, iou, positive_iou, negative_iou):
        """Fold one training batch into the training metrics and return them.

        The combined ``loss`` metric receives ``yx_loss + hw_loss``.
        """
        batch_values = {
            'loss': yx_loss + hw_loss,
            'yx_loss': yx_loss,
            'hw_loss': hw_loss,
            'iou': iou,
            'positive_iou': positive_iou,
            'negative_iou': negative_iou,
        }
        for key, value in batch_values.items():
            self.metrics[key].update_state(value)

        return self.training_metrics

    def val_step(self, yx_loss, hw_loss, iou, positive_iou, negative_iou):
        """Fold one validation batch into the ``val_*`` metrics.

        Returns the current values of all metrics, training ones included.
        """
        batch_values = {
            'val_loss': yx_loss + hw_loss,
            'val_yx_loss': yx_loss,
            'val_hw_loss': hw_loss,
            'val_iou': iou,
            'val_positive_iou': positive_iou,
            'val_negative_iou': negative_iou,
        }
        for key, value in batch_values.items():
            self.metrics[key].update_state(value)

        return self.metric_values

    def learning_rate(self, lr_value):
        """Feed `lr_value` into the ``learning_rate`` metric."""
        self.metrics['learning_rate'].update_state(lr_value)

    def epoch(self):
        """Close out an epoch: record every metric, reset it, return the values.

        The returned `(name, value)` pairs are captured before any reset.
        """
        snapshot = self.metric_values

        for key in self.metrics:
            self.record_and_reset(key)

        return snapshot

    def record_and_reset(self, name):
        """Append the current value of metric `name` to its history, then reset it."""
        tracker = self.metrics[name]
        self.history[name].append(tracker.result().numpy())
        tracker.reset_state()

## Training Loop

In [None]:
@tf.function
def train_step(model, x, y):
    """Run one optimisation step on a single batch.

    The forward pass, both loss terms and the IoU metrics are evaluated under
    a gradient tape; the sum of the two loss terms is then differentiated with
    respect to the model's trainable weights and applied through
    `model.optimizer`.

    Returns:
        The tuple `(yx_loss, hw_loss, iou, positive_iou, negative_iou)`.
    """
    with tf.GradientTape() as tape:
        predictions = model(x, training=True)
        yx_loss, hw_loss = model.loss(y, predictions)
        total_loss = yx_loss + hw_loss
        iou, positive_iou, negative_iou = compute_iou_metric(y, predictions)

    # Read the weight list only after the forward pass has run, as before.
    weights = model.trainable_weights
    gradients = tape.gradient(total_loss, weights)
    model.optimizer.apply_gradients(zip(gradients, weights))

    return yx_loss, hw_loss, iou, positive_iou, negative_iou

@tf.function
def val_step(model, x, y):
    """Evaluate one validation batch without updating any weights.

    The model is called with `training=False`; the two loss terms and the IoU
    metrics are computed from its output.

    Returns:
        The tuple `(yx_loss, hw_loss, iou, positive_iou, negative_iou)`.
    """
    predictions = model(x, training=False)
    yx_loss, hw_loss = model.loss(y, predictions)
    iou, positive_iou, negative_iou = compute_iou_metric(y, predictions)

    loss_terms = (yx_loss, hw_loss)
    iou_terms = (iou, positive_iou, negative_iou)
    return loss_terms + iou_terms

def train(model, tds, vds, epochs=100, max_batches=2):
    """Train `model` with a custom loop, saving a checkpoint after every epoch.

    Each epoch runs `train_step` over the training batches, records the
    learning rate, runs `val_step` over the validation batches (when a
    validation dataset is given), pushes the epoch's metrics into the
    `History`, and writes a checkpoint under ``./sequence_of_bboxes``
    (the three most recent are kept).

    Args:
        model: Model exposing `optimizer`, as consumed by `train_step` and
            `val_step`.
        tds: Batched training dataset yielding `(x, y)` pairs.
        vds: Batched validation dataset yielding `(x, y)` pairs, or None to
            skip validation. With no validation data the ``val_*`` metrics
            receive no updates for the epoch.
        epochs: Number of epochs to run.
        max_batches: Upper bound on the batches drawn from each dataset per
            epoch. The default of 2 preserves the previously hard-coded
            smoke-test truncation; pass None to iterate the full datasets.

    Returns:
        The `History` instance holding the per-epoch metric values.
    """
    # Record progress
    ckpt = tf.train.Checkpoint(optimizer=model.optimizer, model=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt, './sequence_of_bboxes', max_to_keep=3)
    history = History()

    tds = tds.prefetch(buffer_size=tf.data.AUTOTUNE)
    if vds is not None:
        vds = vds.prefetch(buffer_size=tf.data.AUTOTUNE)
    if max_batches is not None:
        tds = tds.take(max_batches)
        if vds is not None:
            vds = vds.take(max_batches)

    # Store the initial learning rate as the first history entry.
    history.learning_rate(LEARNING_RATE)
    history.record_and_reset('learning_rate')

    for epoch in range(epochs):
        # Track training progress
        print("\nEpoch {}/{}".format(epoch + 1, epochs))
        p_bar = utils.Progbar(STEPS_PER_EPOCH, stateful_metrics=history.metric_names)

        # Stays 0 if the dataset yields nothing; otherwise ends as the batch count.
        steps = 0
        for steps, (x, y) in enumerate(tds, start=1):
            yx_loss, hw_loss, iou, positive_iou, negative_iou = train_step(model, x, y)

            p_bar.update(steps, values=history.train_step(yx_loss, hw_loss, iou, positive_iou, negative_iou))

        # Record learning rates
        history.learning_rate(model.optimizer.lr((epoch + 1)*STEPS_PER_EPOCH))

        # Validation is optional: iterating None would raise TypeError.
        if vds is not None:
            for x, y in vds:
                yx_loss, hw_loss, iou, positive_iou, negative_iou = val_step(model, x, y)

                history.val_step(yx_loss, hw_loss, iou, positive_iou, negative_iou)

        # Display metrics at the end of each epoch.
        p_bar.update(steps, values=history.epoch())

        # Save Checkpoint
        print('\nSaved Checkpoint: {}'.format(ckpt_manager.save()))

    return history

# Run configuration for this notebook cell: a short run with very small batches.
EPOCHS = 2
# EPOCHS = 50
BATCH_SIZE = 2

# Training batches use BATCH_SIZE.
tds = train_prep_ds.batch(BATCH_SIZE)
# vds = val_prep_ds.batch(256).cache()
# NOTE(review): the validation batch size is hard-coded to 2 here rather than
# taken from BATCH_SIZE; the batched validation dataset is also cached.
vds = val_prep_ds.batch(2).cache()

# Launch training; `train` returns the History object it filled in.
hist = train(model, tds, vds, epochs=EPOCHS)
# Trailing expression so the notebook displays the recorded per-metric lists.
hist.history