In [None]:
# default_exp callbacks.core

# Callback Core

> Base callback class

In [None]:
#hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
# export

from mrl.imports import *
from mrl.core import *
from mrl.torch_imports import *
from mrl.torch_core import *

## Callbacks

The training cycle in MRL is built around the Callback system. Rather than trying to explicitly define every training cycle variant, MRL defines a series of events (see `Events`) that occur during training, and Callbacks allow users to easily hook into those events. The result is an extremely flexible framework that can adapt to most generative design challenges.

Callbacks use the `__call__` function to organize events. The call function is passed an event name, like `compute_reward`. If the Callback has an attribute (typically a method) whose name matches the event name, that attribute is called.

Callbacks have access to the training environment (see `Environment`), and through it to the model/agent, the training buffer, the training log, other callbacks and all other aspects of the training state.

In [None]:
# export

class Callback():
    '''
    Callback

    Base class for all callbacks. Calling a callback with an event
    name (e.g. `cb('compute_reward')`) runs the attribute of that
    name if the callback defines one, and records how long the call
    took in `self.event_timelog`.

    Inputs:

    - `name str`: name returned by `repr(callback)`

    - `order int`: stored as `self.order` (presumably used to sort
      callbacks within an `Environment` - confirm against caller)
    '''
    def __init__(self, name='base_callback', order=10):
        self.order = order
        self.name = name
        # event name -> list of elapsed times in seconds, one entry per call
        self.event_timelog = defaultdict(list)

    def __call__(self, event_name):
        '''
        Runs the handler named `event_name` if it exists and returns
        its output (None when this callback has no such handler).
        The elapsed time is logged for every call, including events
        this callback does not handle.
        '''
        # perf_counter is monotonic, so durations are not skewed by
        # system clock adjustments the way time.time() can be
        start = time.perf_counter()
        handler = getattr(self, event_name, None)
        output = handler() if handler is not None else None
        self.event_timelog[event_name].append(time.perf_counter() - start)
        return output

    def __repr__(self):
        return self.name

    def _filter_batch(self, valids):
        '''
        Removes invalid items from the current batch, in place.

        Inputs:

        - `valids list`: one flag per item in
          `self.environment.batch_state.samples`; items whose flag is
          falsy are dropped from `samples`, `sources` and `latent_data`

        Raises `ValueError` if `valids` and `samples` differ in length.
        '''
        # Cast to bool so 0/1 ints (or floats) act as a mask; left as ints,
        # `latents[...]` below would do integer (fancy) indexing instead.
        valids = np.asarray(valids, dtype=bool)
        batch_state = self.environment.batch_state
        samples = batch_state.samples

        if len(valids) != len(samples):
            raise ValueError(f'valids has {len(valids)} entries but the '
                             f'batch has {len(samples)} samples')

        if valids.all():
            # nothing to remove; also covers an empty batch
            return

        sources = np.array(batch_state.sources)

        batch_state.samples = [s for s, keep in zip(samples, valids) if keep]
        batch_state.sources = [s for s, keep in zip(batch_state.sources, valids) if keep]
        # Assumes latent_data[source] holds one row per sample drawn from
        # `source`, in batch order, so the source's slice of `valids`
        # masks exactly its rows - TODO confirm against the samplers.
        batch_state.latent_data = {
            source: latents[valids[sources == source]]
            for source, latents in batch_state.latent_data.items()
        }

    def plot_dict(self, data_dict, cols=4, smooth=True):
        '''
        Plots every entry of `data_dict` on its own subplot, laid out
        in a grid with `cols` columns; unused grid cells are hidden.

        `smooth` is accepted for API compatibility but is currently unused.
        '''
        metrics = list(data_dict.keys())
        if not metrics:
            # plt.subplots cannot build a grid with zero rows
            return

        rows = int(np.ceil(len(metrics) / cols))
        # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid,
        # where plt.subplots would otherwise return a bare Axes
        fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows),
                                 squeeze=False)

        for i, ax in enumerate(axes.flat):
            if i < len(metrics):
                ax.plot(np.stack(data_dict[metrics[i]]))
                ax.set_title(metrics[i])
            else:
                ax.axis('off')

    def plot_time(self, cols=4, smooth=True):
        '''Plots the per-event timing history in `self.event_timelog`'''
        self.plot_dict(self.event_timelog, cols=cols, smooth=smooth)

    def save(self, filename):
        '''
        Serializes the whole callback object to `filename` with
        `torch.save`. Note: every attribute is pickled too, including
        any attached objects such as `self.environment` if set.
        '''
        torch.save(self, filename)

In [None]:
# export

class Event():
    '''
    Event

    Base class for events. Subclasses set `event_name` to the
    event's string name (e.g. `'compute_reward'`).

    Inputs:

    - `event_name Optional[str]`: name of the event; defaults to None
      so existing subclasses can assign it themselves
    '''
    def __init__(self, event_name=None):
        self.event_name = event_name

In [None]:
# export

class Setup(Event):
    '''
    Setup

    Fired after an `Environment` has been created. Use this
    step to do things like set attributes or add logging terms
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'setup'

In [None]:
# export

class BeforeTrain(Event):
    '''
    BeforeTrain

    Fired by `Environment.fit` before the first batch runs
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'before_train'

In [None]:
# export

class BuildBuffer(Event):
    '''
    BuildBuffer

    Used to add new samples to the Buffer
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'build_buffer'

In [None]:
# export

class FilterBuffer(Event):
    '''
    FilterBuffer

    Screens the items added to the buffer during
    `build_buffer` and removes those that do not
    meet the filter criteria
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'filter_buffer'

In [None]:
# export

class AfterBuildBuffer(Event):
    '''
    AfterBuildBuffer

    Fired once the buffer has been filtered, before the
    next batch begins. Useful for evaluating metrics and
    statistics about how the buffer was built
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'after_build_buffer'

In [None]:
# export

class BeforeBatch(Event):
    '''
    BeforeBatch

    Fired before the next batch is sampled
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'before_batch'

In [None]:
# export

class SampleBatch(Event):
    '''
    SampleBatch

    Produces the samples that are added
    to the next batch
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'sample_batch'

In [None]:
# export

class BeforeFilterBatch(Event):
    '''
    BeforeFilterBatch

    Fired just before the current batch is filtered
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'before_filter_batch'

In [None]:
# export

class FilterBatch(Event):
    '''
    FilterBatch

    Screens the items in the current batch and removes
    those that do not meet the filter criteria
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'filter_batch'

In [None]:
# export

class AfterSample(Event):
    '''
    AfterSample

    Fired after a batch has been sampled and filtered.
    Useful for logging statistics about the latest batch
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'after_sample'

In [None]:
# export

class BeforeComputeReward(Event):
    '''
    BeforeComputeReward

    Fired before rewards are computed for the current
    batch. Use it to generate any inputs the reward
    computation requires
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'before_compute_reward'

In [None]:
# export

class ComputeReward(Event):
    '''
    ComputeReward

    Computes rewards for the current batch

    All rewards should be added to `self.environment.batch_state.rewards`
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'compute_reward'

In [None]:
# export

class AfterComputeReward(Event):
    '''
    AfterComputeReward

    Fired once all rewards have been computed.
    Useful for logging stats and metrics about
    the rewards of the current batch
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'after_compute_reward'

In [None]:
# export

class RewardModification(Event):
    '''
    RewardModification

    Used to modify rewards before they feed into the model's
    loss. Reward modifications capture changes to rewards that
    depend on the current training cycle, such as "give a score
    bonus to new samples that haven't been seen before" or
    "penalize the score of samples that have occurred in the
    last 5 batches".

    These modifications are kept separate from the core reward
    for logging purposes. Samples are logged together with their
    rewards, and those logged scores are referenced later when
    samples are drawn from the log, so the logged score must not
    include any "batch context" adjustments.

    All reward modifications should be
    applied to `self.environment.batch_state.rewards`
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'reward_modification'

In [None]:
# export

class GetModelOutputs(Event):
    '''
    GetModelOutputs

    Generates any model-derived outputs
    needed to compute the loss
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'get_model_outputs'

In [None]:
# export

class AfterGetModelOutputs(Event):
    '''
    AfterGetModelOutputs

    Fired after `get_model_outputs`. Use it for
    any processing required before the loss
    is computed
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'after_get_model_outputs'

In [None]:
# export

class ComputeLoss(Event):
    '''
    ComputeLoss

    Computes loss values

    All loss values should be added to
    `self.environment.batch_state.loss`
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'compute_loss'

In [None]:
# export

class ZeroGrad(Event):
    '''
    ZeroGrad

    Zeros the gradients of every optimizer
    involved in the fit cycle

    `loss.backward()` is called after zero grad
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'zero_grad'

In [None]:
# export

class BeforeStep(Event):
    '''
    BeforeStep

    Hook for any processing needed after
    `loss.backward()` but before `opt.step()`,
    e.g. gradient clipping
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'before_step'

In [None]:
# export

class Step(Event):
    '''
    Step

    Steps all optimizers
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'step'

In [None]:
# export

class AfterBatch(Event):
    '''
    AfterBatch

    Fired after `step`. Useful for computing
    batch statistics and cleaning up values
    before the next batch
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'after_batch'

In [None]:
# export

class AfterTrain(Event):
    '''
    AfterTrain

    Fired once every batch step
    has been completed
    '''
    def __init__(self):
        super().__init__()
        self.event_name = 'after_train'

In [None]:
# export        
        
class Events():
    '''
    Events

    Registry of every training-cycle event. Each event object is
    available as an attribute named after its `event_name`
    (e.g. `events.compute_reward`), and `event_names` lists all
    names in the order the events are registered below.

    Calling the registry with an event name prints that event's
    documentation, e.g. `Events()('compute_reward')`.
    '''
    def __init__(self):
        # Single source of truth: both the attributes and `event_names`
        # are derived from this list, so the two cannot drift apart.
        events = [
            Setup(),
            BeforeTrain(),
            BuildBuffer(),
            FilterBuffer(),
            AfterBuildBuffer(),
            BeforeBatch(),
            SampleBatch(),
            BeforeFilterBatch(),
            FilterBatch(),
            AfterSample(),
            BeforeComputeReward(),
            ComputeReward(),
            AfterComputeReward(),
            RewardModification(),
            GetModelOutputs(),
            AfterGetModelOutputs(),
            ComputeLoss(),
            ZeroGrad(),
            BeforeStep(),
            Step(),
            AfterBatch(),
            AfterTrain(),
        ]
        for event in events:
            setattr(self, event.event_name, event)

        self.event_names = [event.event_name for event in events]

    def __call__(self, event_name):
        '''Prints the docstring of `event_name`; unknown names are ignored'''
        # Only look up registered events, so attribute names such as
        # 'event_names' don't print unrelated documentation.
        if event_name in self.event_names:
            print(getattr(self, event_name).__doc__)

In [None]:
# export

class SettrDict(dict):
    '''
    SettrDict

    A `dict` whose string keys are mirrored as instance attributes,
    so `d['x'] = 1` and `d.x = 1` are interchangeable.

    Item/attribute assignment, deletion and the bulk-mutation methods
    (`update`, `setdefault`, `pop`, `popitem`, `clear`, `|=`) are all
    overridden to keep the two views in sync. Non-string keys are
    stored in the dict only, since they cannot be attribute names.

    Accepts the same constructor arguments as `dict`.
    '''
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.update(*args, **kwargs)

    def __setitem__(self, key, item):
        super().__setitem__(key, item)
        if isinstance(key, str):
            super().__setattr__(key, item)

    def __setattr__(self, key, item):
        self[key] = item

    def __delitem__(self, key):
        super().__delitem__(key)
        if isinstance(key, str):
            self.__dict__.pop(key, None)

    def __delattr__(self, key):
        if key in self:
            del self[key]
        else:
            super().__delattr__(key)

    def update(self, *args, **kwargs):
        # dict.update writes straight to storage and would skip __setitem__
        for key, item in dict(*args, **kwargs).items():
            self[key] = item

    def __ior__(self, other):
        self.update(other)
        return self

    def setdefault(self, key, default=None):
        if key not in self:
            self[key] = default
        return self[key]

    def pop(self, key, *default):
        item = super().pop(key, *default)
        if isinstance(key, str):
            self.__dict__.pop(key, None)
        return item

    def popitem(self):
        key, item = super().popitem()
        if isinstance(key, str):
            self.__dict__.pop(key, None)
        return key, item

    def clear(self):
        for key in self:
            if isinstance(key, str):
                self.__dict__.pop(key, None)
        super().clear()

    def update_from_dict(self, update_dict):
        '''Sets every key/value pair of `update_dict` on this object'''
        for k, v in update_dict.items():
            self[k] = v

## Batch State

The `BatchState` class is used by an `Environment` to track values generated or computed during a batch. Every batch, the old `BatchState` is deleted and a new `BatchState` is created.

Attributes in `BatchState` can be set or accessed with a key like a dictionary or as an attribute. `BatchState` can hold any arbitrary value during a batch. However, it was designed for the use case where every attribute is either a single value or a list/container with length equal to the current batch size.

### Rewards

`BatchState` holds the `rewards` value for a batch. All reward functions should ultimately add their reward value to `BatchState.rewards`. See `Reward` for more information.

### Loss

`BatchState` holds the `loss` value for a batch. This is the value that will be backpropagated during the optimizer update. All loss functions should ultimately add their value to `BatchState.loss`. See `Loss` for more information.

In [None]:
# export
        
class BatchState(SettrDict):
    '''
    BatchState

    Container for the values produced during a single batch.
    It starts with empty `samples`, `sources` and `latent_data`,
    a zero `rewards` tensor, and a zero `loss` tensor created with
    `requires_grad=True` so loss terms can be added onto it and
    backpropagated. Both tensors are passed through `to_device`.
    '''
    def __init__(self):
        super().__init__()

        self['samples'] = []
        self['sources'] = []
        self['rewards'] = to_device(torch.tensor(0.))
        self['loss'] = to_device(torch.tensor(0., requires_grad=True))
        self['latent_data'] = {}

In [None]:
# hide
from nbdev.export import notebook2script; notebook2script()