In [None]:
# default_exp callbacks.core

# Callback Core

> Base `Callback` class, the `Event` name registry, and the `SettrDict`/`BatchState` state containers

In [None]:
#hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
# export

from mrl.imports import *
from mrl.core import *
from mrl.torch_imports import *
from mrl.torch_core import *

In [None]:
# export

class Callback():
    """Base class for callbacks dispatched by event name.

    Subclasses define methods named after events. Calling the callback with
    an event name runs the method of that name if the attribute exists.
    Every dispatch, including one for an event the callback does not
    implement, appends its duration in seconds to
    `event_timelog[event_name]`.

    Args:
        name: Identifier returned as the callback's `repr`.
        order: Stored on the instance and not used by this class; presumably
            a sort key for whoever runs the callbacks (confirm with caller).
    """
    def __init__(self, name='base_callback', order=10):
        self.order = order
        self.name = name
        # event name -> list of per-call durations (seconds)
        self.event_timelog = defaultdict(list)

    def __call__(self, event_name):
        """Run the handler for `event_name` (if any) and record its duration.

        Returns the handler's return value, or None when the callback has no
        attribute named `event_name`.
        """
        # perf_counter is monotonic and high-resolution. time.time() can jump
        # with system clock adjustments and is coarse for short intervals.
        start = time.perf_counter()
        handler = getattr(self, event_name, None)
        output = handler() if handler is not None else None
        self.event_timelog[event_name].append(time.perf_counter() - start)
        return output

    def __repr__(self):
        return self.name

    def plot_dict(self, data_dict, cols=4, smooth=True):
        """Line-plot each entry of `data_dict` on its own subplot in a grid.

        Args:
            data_dict: Mapping of metric name to a sequence of values. Each
                sequence is passed through `np.stack` before plotting.
            cols: Number of subplot columns; rows are added as needed and
                unused trailing axes are hidden.
            smooth: Accepted for API compatibility; currently unused.

        Draws nothing when `data_dict` is empty.
        """
        metrics = list(data_dict.keys())
        if not metrics:
            # plt.subplots rejects a zero-row grid; there is nothing to draw.
            return

        rows = int(np.ceil(len(metrics) / cols))
        # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid. Without
        # it, plt.subplots returns a bare Axes that has no `.flat`.
        fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows), squeeze=False)

        for i, ax in enumerate(axes.flat):
            if i < len(metrics):
                ax.plot(np.stack(data_dict[metrics[i]]))
                ax.set_title(metrics[i])
            else:
                ax.axis('off')

    def plot_time(self, cols=4, smooth=True):
        """Plot the per-event durations recorded in `event_timelog`."""
        self.plot_dict(self.event_timelog, cols=cols, smooth=smooth)

In [None]:
# Instantiate with the default name/order; the repr is just the name.
Callback(name='base_callback', order=10)

base_callback

In [None]:
# export

class Event():
    """Registry of event-name strings.

    Each attribute's value is its own name (e.g. `Event().step == 'step'`).
    Presumably these strings are what gets passed to `Callback.__call__`,
    which dispatches to the method of the same name -- confirm with caller.
    """
    def __init__(self):
        event_names = (
            'setup',
            'before_train',
            'build_buffer',
            'after_build_buffer',
            'before_batch',
            'sample_batch',
            'after_sample',
            'get_model_outputs',
            'compute_reward',
            'after_compute_reward',
            'compute_loss',
            'zero_grad',
            'before_step',
            'step',
            'after_batch',
            'after_train',
        )
        # Every event attribute holds its own name as a string.
        for event_name in event_names:
            setattr(self, event_name, event_name)
        
        
# class Event():
#     def __init__(self):
#         self.setup = 'setup'
#         self.before_train = 'before_train'
#         self.build_buffer = 'build_buffer'
#         self.filter_buffer = 'filter_buffer'
#         self.after_build_buffer = 'after_build_buffer'
#         self.score_buffer = 'compute_buffer_reward'
#         self.before_batch = 'before_batch'
#         self.sample_batch = 'sample_batch'
#         self.filter_batch = 'filter_batch'
#         self.after_sample = 'after_sample'
#         self.compute_reward = 'compute_batch_reward'
#         self.after_compute_reward = 'after_compute_reward'
#         self.reward_modification = 'reward_modification'
#         self.get_model_outputs = 'get_model_outputs'
#         self.compute_loss = 'compute_loss'
#         self.zero_grad = 'zero_grad'
#         self.before_step = 'before_step'
#         self.step = 'step'
#         self.after_batch = 'after_batch'
#         self.after_train = 'after_train'

In [None]:
# export

class SettrDict(dict):
    """`dict` whose keys are mirrored as instance attributes.

    `d['k'] = v` and `d.k = v` both store `v` under the key and as the
    attribute. Deleting through either interface removes both. Keys must be
    strings (attribute names). Any other key raises `TypeError` and leaves
    the dict unchanged.

    Note: inherited `dict` methods such as `update`, `setdefault` or `pop`
    bypass the mirroring. Use `update_from_dict` to bulk-assign.
    """
    def __init__(self, *args, **kwargs):
        super().__init__()
        # Route any initial items through __setitem__ so attributes stay in sync.
        for k, v in dict(*args, **kwargs).items():
            self[k] = v

    def __setitem__(self, key, item):
        # Set the attribute first: it is the step that rejects non-string
        # keys. A failure then cannot leave a key without its attribute.
        super().__setattr__(key, item)
        super().__setitem__(key, item)

    def __setattr__(self, key, item):
        super().__setattr__(key, item)
        super().__setitem__(key, item)

    def __delitem__(self, key):
        super().__delitem__(key)
        # Drop the mirrored attribute too, so `d.key` does not return stale data.
        self.__dict__.pop(key, None)

    def __delattr__(self, key):
        super().__delattr__(key)
        super().pop(key, None)

    def update_from_dict(self, update_dict):
        """Assign every item of `update_dict` via `__setitem__` (keeps attributes in sync)."""
        for k, v in update_dict.items():
            self[k] = v
        
class BatchState(SettrDict):
    """Batch-state container whose fields are reachable as keys and attributes.

    Initial fields:
        samples: empty list
        sources: empty list
        rewards: scalar zero tensor
        loss: scalar zero tensor created with `requires_grad=True`
        latent_data: empty dict
    """
    def __init__(self):
        super().__init__()
        # NOTE(review): `to_device` arrives via the mrl star imports; presumably
        # it moves the tensor to the configured default device -- confirm.
        zero_rewards = to_device(torch.tensor(0.))
        zero_loss = to_device(torch.tensor(0., requires_grad=True))
        self.update_from_dict({
            'samples': [],
            'sources': [],
            'rewards': zero_rewards,
            'loss': zero_loss,
            'latent_data': {},
        })

In [None]:
# hide
# nbdev: write the cells marked for export out to the library modules
# (with no argument this presumably converts every notebook in the project -- confirm).
from nbdev.export import notebook2script
notebook2script()