In [None]:
# default_exp callbacks.log

# Logging

> Callbacks for logging data

In [None]:
#hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
# export

from mrl.imports import *
from mrl.core import *
from mrl.callbacks.core import *
from mrl.torch_imports import *

In [None]:
# export

def log_to_df(log, keys=None):
    """Turn a per-batch log into a single `pd.DataFrame` with one row per item.

    `log` maps a name to a list holding one entry per batch, where each entry
    is a sized sequence of per-item values. Every selected key becomes a
    column (its per-batch entries are passed through `flatten_list_of_lists`),
    and a `batch` column records the index of the batch each row came from.

    Args:
        log: mapping of name -> list of per-batch sequences.
        keys: names to turn into columns. Defaults to all keys of `log`.
            The first key is the one used to work out the batch sizes.

    Returns:
        DataFrame with a `batch` column followed by one column per key.
        When there are no keys or no logged batches the frame has no rows.
    """
    if keys is None:
        keys = list(log.keys())

    # Seed `batch` up front so the column exists even when nothing was logged
    output_dict = {'batch': []}

    # Guard against an empty log / empty `keys`, where `keys[0]` would raise
    if len(keys) > 0:
        # Batch sizes are read from the first key only; the remaining keys
        # are expected to have matching lengths (pandas raises otherwise)
        for batch, item in enumerate(log[keys[0]]):
            output_dict['batch'].extend([batch] * len(item))

    for key in keys:
        output_dict[key] = flatten_list_of_lists(log[key])

    return pd.DataFrame(output_dict)

In [None]:
# export

class Log(Callback):
    """Callback that records per-batch data, tracks metrics and prints a progress table.

    State kept on the instance:

    - `metrics`: metric name -> list of values appended via `update_metric`.
    - `batch_log`: log name -> list with one entry per batch, read from `batch_state`.
    - `df`: DataFrame of logged rows whose `samples` value was not already in `df`.
    - `unique_samples`: sample -> reward, used to re-use rewards of samples seen before.
    - `timelog`: `defaultdict(list)`; this class only reads it (in `plot_timelog`).
    """
    def __init__(self):
        super().__init__(name='log', order=100)

        self.pbar = None          # optional progress bar; when None, rows are printed to stdout
        self.iterations = 0       # number of batches processed by `report_batch`
        self.metrics = {}

        self.batch_log = {}
        self.timelog = defaultdict(list)

        self.report = 1           # a row is printed every `report` iterations
        self.unique_samples = {}  # sample -> reward

        self.add_metric('rewards')
        self.add_log('samples')
        self.add_log('sources')
        self.add_log('rewards')

        self.log_df = None

    def setup(self):
        "Create `df` with one column per name currently registered in `batch_log`."
        self.df = pd.DataFrame(self.batch_log)

    def before_train(self):
        "Print the table header: `iterations` followed by every metric name."
        cols = ['iterations'] + list(self.metrics.keys())
        if self.pbar is None:
            print('\t'.join(cols))
        else:
            # assumes `pbar` exposes `write(items, table=True)` (fastprogress-style) — TODO confirm
            self.pbar.write(cols, table=True)

    def add_metric(self, name):
        "Register metric `name` with an empty history; existing metrics are left untouched."
        if name not in self.metrics:
            self.metrics[name] = []

    def add_log(self, name):
        "Register `name` as a `batch_state` key to log each batch; existing logs are left untouched."
        if name not in self.batch_log:
            self.batch_log[name] = []

    def update_metric(self, name, value):
        "Append `value` to the history of metric `name` (raises `KeyError` if it was never added)."
        self.metrics[name].append(value)

    def update_log(self):
        "Append the current batch to `batch_log` and add rows with unseen samples to `df`."
        batch_state = self.environment.batch_state
        update_dict = {}

        for key in self.batch_log.keys():
            items = batch_state[key]
            if isinstance(items, torch.Tensor):
                items = items.detach().cpu().numpy()
            self.batch_log[key].append(items)
            update_dict[key] = items

        new_df = pd.DataFrame(update_dict)
        # Keep only the rows whose sample is not already stored in `df`
        repeats = new_df['samples'].isin(self.df['samples'])
        new_df = new_df[~repeats]

        # `DataFrame.append` was removed in pandas 2.0, so `pd.concat` is used.
        # Empty frames are skipped rather than concatenated so they cannot
        # influence the resulting column dtypes.
        if not new_df.empty:
            if self.df.empty:
                self.df = new_df
            else:
                self.df = pd.concat([self.df, new_df])

    def before_compute_reward(self):
        """Mark which samples were already scored and restore their stored reward.

        Sets `batch_state.prescored` to a list of booleans (one per sample) and,
        for samples found in `unique_samples`, writes the stored reward into
        `batch_state.rewards` at the same position.
        """
        batch_state = self.environment.batch_state
        batch_state.prescored = []

        for i, sample in enumerate(batch_state.samples):
            seen = sample in self.unique_samples
            batch_state.prescored.append(seen)
            if seen:
                batch_state.rewards[i] = torch.tensor(self.unique_samples[sample])

    def after_compute_reward(self):
        "Store the reward of every sample that is not yet in `unique_samples`."
        batch_state = self.environment.batch_state
        samples = batch_state.samples
        rewards = batch_state.rewards.detach().cpu().numpy()
        for i in range(len(samples)):
            # First reward seen for a sample wins; later ones are ignored
            if samples[i] not in self.unique_samples:
                self.unique_samples[samples[i]] = rewards[i]

    def report_batch(self):
        "Print the latest value of each metric every `report` iterations, then advance `iterations`."
        if self.iterations % self.report == 0:
            outputs = [f'{self.iterations}']

            for history in self.metrics.values():
                # A metric that was registered but never updated is shown as
                # nan instead of raising IndexError on `history[-1]`
                val = history[-1] if len(history) > 0 else float('nan')

                # Plain ints are printed as-is, everything else with 3 decimals
                if type(val) is int:
                    outputs.append(f'{val}')
                else:
                    outputs.append(f'{val:.3f}')

            if self.pbar is None:
                print('\t'.join(outputs))
            else:
                self.pbar.write(outputs, table=True)

        self.iterations += 1

    def after_batch(self):
        "Log the finished batch, then report it."
        self.update_log()
        self.report_batch()

    def get_df(self):
        "Return the full `batch_log` (repeats included) as a DataFrame via `log_to_df`."
        return log_to_df(self.batch_log)

    # NOTE(review): `plot_dict` is not defined in this class; presumably it is
    # provided by `Callback` or added elsewhere — confirm before relying on these.
    def plot_metrics(self, cols=4, smooth=True):
        "Plot `metrics` through `plot_dict`."
        self.plot_dict(self.metrics, cols=cols, smooth=smooth)

    def plot_timelog(self, cols=4, smooth=True):
        "Plot `timelog` through `plot_dict`."
        self.plot_dict(self.timelog, cols=cols, smooth=smooth)


In [None]:
# export

class StatsCallback(Callback):
    """Base callback that reads one attribute from `batch_state`, optionally for one source only.

    Args:
        batch_attribute: key used to look the values up in `batch_state`.
        grabname: when not None, only entries whose `batch_state.sources`
            value equals `grabname` are returned.
        name: callback name forwarded to `Callback`.
        order: callback order forwarded to `Callback`.
    """
    def __init__(self, batch_attribute, grabname=None, name='stats', order=20):
        super().__init__(name=name, order=order)
        self.grabname = grabname
        self.batch_attribute = batch_attribute

    def get_values(self):
        """Return the values of `batch_attribute`, filtered by `grabname` if set.

        Tensors are detached and converted to numpy arrays. When `grabname`
        is set the result is always a numpy array; otherwise non-tensor
        values are returned as stored.
        """
        batch_state = self.environment.batch_state
        values = batch_state[self.batch_attribute]

        # Convert before masking so the boolean mask is always applied to a
        # numpy array, whatever container the attribute is stored in
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu().numpy()

        if self.grabname is not None:
            # `sources` is only needed (and only read) when filtering
            source_mask = np.array(batch_state.sources) == self.grabname
            values = np.asarray(values)[source_mask]

        return values
    
class MaxCallback(StatsCallback):
    """Logs the maximum of a `batch_state` attribute as a metric after rewards are computed.

    The metric is named `{batch_attribute}_max`, or
    `{batch_attribute}_{grabname}_max` when `grabname` is given.

    Args:
        batch_attribute: key of the values in `batch_state`.
        grabname: source name to restrict the values to, or None for all.
        order: callback order, forwarded to `StatsCallback`.
    """
    def __init__(self, batch_attribute, grabname, order=20):

        if grabname is None:
            name = f'{batch_attribute}_max'
        else:
            name = f'{batch_attribute}_{grabname}_max'

        # `order` must be forwarded, otherwise the caller's value is ignored
        super().__init__(batch_attribute, grabname, name=name, order=order)

    def setup(self):
        "Register this callback's metric on the environment log."
        self.environment.log.add_metric(self.name)

    def after_compute_reward(self):
        "Append the max of the current values to the metric (nan when there are no values)."
        values = self.get_values()

        # `max()` of an empty array raises ValueError; this happens when no
        # entry of the batch matches `grabname`. Log nan to keep one metric
        # entry per batch.
        if len(values) == 0:
            value = float('nan')
        else:
            value = values.max()

        self.environment.log.update_metric(self.name, value)
        
class PercentileCallback(StatsCallback):
    """Logs a percentile of a `batch_state` attribute as a metric after rewards are computed.

    The metric is named `{batch_attribute}_p{percentile}`, or
    `{batch_attribute}_{grabname}_p{percentile}` when `grabname` is given.

    Args:
        batch_attribute: key of the values in `batch_state`.
        grabname: source name to restrict the values to, or None for all.
        percentile: percentile passed to `np.percentile`.
        order: callback order, forwarded to `StatsCallback`.
    """
    def __init__(self, batch_attribute, grabname, percentile, order=20):

        if grabname is None:
            name = f'{batch_attribute}_p{percentile}'
        else:
            name = f'{batch_attribute}_{grabname}_p{percentile}'

        # `order` must be forwarded, otherwise the caller's value is ignored
        super().__init__(batch_attribute, grabname, name=name, order=order)
        self.percentile = percentile

    def setup(self):
        "Register this callback's metric on the environment log."
        self.environment.log.add_metric(self.name)

    def after_compute_reward(self):
        "Append the percentile of the current values to the metric (nan when there are no values)."
        values = self.get_values()

        # No entry of the batch matched `grabname`: log nan explicitly rather
        # than asking numpy for a percentile of nothing
        if len(values) == 0:
            value = float('nan')
        else:
            value = np.percentile(values, self.percentile)

        self.environment.log.update_metric(self.name, value)
        

In [None]:
# hide
from nbdev.export import notebook2script; notebook2script()