In [None]:
%load_ext blackcellmagic

# Decorator for counting words

In [None]:
def word_count(func, *args, **kwargs):
    """Wrap ``func`` so that calling the wrapper also counts the words it returns.

    Args:
        func: Callable returning an iterable of hashable words.
        *args: Positional arguments bound to ``func`` when the wrapper is built.
        **kwargs: Keyword arguments bound to ``func`` when the wrapper is built.

    Returns:
        A zero-argument callable. Calling it returns ``(words, counter)``,
        where ``words`` is the result of ``func(*args, **kwargs)`` and
        ``counter`` is a dict mapping each word to its number of occurrences.
    """
    from collections import Counter

    def wrapper():
        # Call func exactly once so the returned words and the counter always
        # describe the same result (and side effects are not repeated).
        words = func(*args, **kwargs)
        if iter(words) is words:
            # One-shot iterator (e.g. a generator): materialise it, otherwise
            # counting would exhaust what the caller gets back.
            words = list(words)
        return words, dict(Counter(words))

    return wrapper

# Per parameter effective learning rate in ADAM

## The weight update step in ADAM:

$\theta_t = \theta_{t-1} - lr_t * m_t / (\sqrt{v_t} + \epsilon)$

where

$lr_t = \mathrm{learning\_rate} * \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$

$m_t = \beta_1 * m_{t-1} + (1 - \beta_1) * g$

$v_t = \beta_2 * v_{t-1} + (1 - \beta_2) * g^2$

The question is: does Adam affect the effective learning rate in P and S differently? In other words, we want to check whether $m_t / (\sqrt{v_t} + \epsilon)$ is similar in PHO and SEM.

Let's start with the OP-only and OS-only runs, since they are simpler: each weight is associated with only one optimizer instead of n optimizers.

- PHO (during OP) weight and biases are w_hop_oh, bias_hop, w_hop_hp, w_pc, bias_cpp, w_cp, bias_p
- SEM (during OS) weight and biases are W_hos_oh, bias_hos, w_hos_hs, w_sc, bias_css, w_cs, bias_s

Steps:
1. Get $m_t / (\sqrt{v_t} + \epsilon)$ for every weight
2. Average? Consider sparsity...
3. Compare  

In [None]:
import tensorflow as tf
import meta, modeling
import seaborn as sns
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
from helper import stitch_fig
import pandas as pd
import altair as alt

### Optimizer content in OP

In [None]:
class ELR:
    """Examine the effective learning rate scaling factor in ADAM.

    For a given weight the scaling factor is ``m / (sqrt(v) + epsilon)``, where
    ``m`` and ``v`` are the optimizer variables named ``<weight>/m:0`` and
    ``<weight>/v:0`` and ``epsilon`` is the optimizer's ``epsilon`` attribute.
    One Adam optimizer is created per task listed in ``cfg.task_names``.
    """

    def __init__(self, batch_name, code_name, tf_root='/home/jupyter/triangle_model'):
        """Build the model, the per-task optimizers and the checkpoint object.

        Args:
            batch_name: Folder name under ``models``.
            code_name: Folder name under ``models/<batch_name>``; it must
                contain ``model_config.json``.
            tf_root: Value assigned to ``cfg.tf_root`` after the config is loaded.
        """
        config_path = os.path.join('models', batch_name, code_name, 'model_config.json')
        self.cfg = meta.Config.from_json(config_path)
        self.cfg.tf_root = tf_root
        self.model = modeling.MyModel(self.cfg)
        self.model.build()
        self.optimizers = {
            task: tf.keras.optimizers.Adam(learning_rate=self.cfg.learning_rate)
            for task in self.cfg.task_names
        }
        self.ckpt = tf.train.Checkpoint(model=self.model, optimizers=self.optimizers)

    def restore(self, epoch: int):
        """Restore model and optimizers from the checkpoint saved at ``epoch``.

        The checkpoint is read from ``<cfg.weight_folder>/epoch-<epoch>``.
        """
        checkpoint_path = os.path.join(self.cfg.weight_folder, f'epoch-{epoch}')
        self.ckpt.restore(checkpoint_path).expect_partial()

    @staticmethod
    def cal_elr(m, v, e):
        """Calculate the effective learning rate scaling factor ``m / (sqrt(v) + e)``."""
        return m / (tf.sqrt(v) + e)

    def _find_slot(self, task: str, weight_name: str, slot: str):
        """Return the variable named ``<weight_name>/<slot>:0`` in the task's optimizer.

        Raises:
            KeyError: If the optimizer holds no variable with that name.
        """
        target = f"{weight_name}/{slot}:0"
        for variable in self.optimizers[task].weights:
            if variable.name == target:
                return variable
        raise KeyError(f"No optimizer variable named {target!r} for task {task!r}")

    def get_elr(self, task: str, weight_name: str):
        """Calculate the effective learning rate scaling factor from an optimizer.

        Args:
            task: Key into ``self.optimizers``.
            weight_name: Name of the weight whose ``m`` and ``v`` variables are used.

        Returns:
            A NumPy array of scaling factors with size-1 dimensions squeezed out.

        Raises:
            KeyError: If ``task`` is unknown or the optimizer has no ``m``/``v``
                variable for ``weight_name`` (instead of silently returning an
                empty array).
        """
        m = self._find_slot(task, weight_name, 'm')
        v = self._find_slot(task, weight_name, 'v')
        e = self.optimizers[task].epsilon
        return self.cal_elr(m, v, e).numpy().squeeze()

    @staticmethod
    def _cal_variance(x):
        """Return the variance over all elements of array ``x``."""
        return x.flatten().var()

    def variance(self, task: str):
        """Calculate the variance of effective learning rate scaling factor in each weight.

        Returns a dict keyed by the weight names in ``modeling.WEIGHTS_AND_BIASES[task]``.
        """
        return {
            name: self._cal_variance(self.get_elr(task, name))
            for name in modeling.WEIGHTS_AND_BIASES[task]
        }

    def mean(self, task: str):
        """Calculate the mean of effective learning rate scaling factor in each weight.

        Returns a dict keyed by the weight names in ``modeling.WEIGHTS_AND_BIASES[task]``.
        """
        return {
            name: self.get_elr(task, name).mean()
            for name in modeling.WEIGHTS_AND_BIASES[task]
        }

    def make_df(self, task: str, summary_function: str = 'variance'):
        """Summarise the scaling factor of every weight at every saved epoch.

        Each epoch in ``cfg.saved_epochs`` is restored in turn, so the instance
        is left holding the last saved epoch.

        Args:
            task: Task whose optimizer is examined.
            summary_function: ``'variance'`` selects :meth:`variance`; any other
                value selects :meth:`mean`. The string is also used as the name
                of the value column.

        Returns:
            A long-format DataFrame with columns ``index`` (the epoch),
            ``weight`` and ``<summary_function>``.
        """
        summarize = self.variance if summary_function == 'variance' else self.mean
        per_epoch = []
        for epoch in tqdm(self.cfg.saved_epochs):
            self.restore(epoch=epoch)
            per_epoch.append(pd.DataFrame(summarize(task), index=[epoch]))
        # DataFrame.append was removed in pandas 2.0; concatenate once instead.
        # pd.concat rejects an empty list, so fall back to an empty frame.
        df = pd.concat(per_epoch) if per_epoch else pd.DataFrame()
        return df.reset_index().melt(id_vars='index', var_name='weight', value_name=summary_function)

    def plot_elr(self, task: str, weight_name: str, ax=None):
        """Plot the effective learning rate scaling factor from an optimizer.

        Draws a KDE of the flattened scaling factors, labelled with
        ``weight_name``, and returns the axes that were drawn on.
        """
        elr = self.get_elr(task, weight_name).flatten()
        # Use the axes returned by kdeplot so the legend lands on the plotted
        # axes even when the caller supplies ``ax``.
        ax = sns.kdeplot(elr, label=f'{weight_name}', ax=ax)
        ax.legend()
        return ax

    def plot_all(self, task: str, ax=None):
        """Plot all effective learning rate scaling factors in each weight at current loaded epoch."""
        for weight_name in modeling.WEIGHTS_AND_BIASES[task]:
            self.plot_elr(task, weight_name, ax)

    def plot_all_over_epochs(self, task: str, xlim=None):
        """Plot all effective learning rate scaling factors in each weight over epochs.

        One figure per saved epoch is written to
        ``<cfg.plot_folder>/<task>/epoch_<epoch>.png``.

        Args:
            task: Task whose optimizer is examined.
            xlim: Optional x-axis limits passed to ``plt.xlim``.
        """
        output_folder = os.path.join(self.cfg.plot_folder, task)
        os.makedirs(output_folder, exist_ok=True)

        for epoch in tqdm(self.cfg.saved_epochs):
            self.restore(epoch=epoch)
            plt.clf()  # Start each epoch from an empty figure
            self.plot_all(task)  # KDE of every weight at this epoch
            if xlim is not None:
                plt.xlim(xlim)

            plt.title(f'Epoch {epoch}')
            plt.savefig(os.path.join(output_folder, f'epoch_{epoch}.png'))

In [None]:
# Load run task_effect_r0027 and collect, for every saved epoch, the mean
# scaling factor of each weight in the 'ort_pho' optimizer.
ort_pho = ELR(batch_name='task_effect', code_name='task_effect_r0027')
df = ort_pho.make_df('ort_pho', summary_function='mean')


In [None]:
# Line chart of the mean scaling factor per epoch ('index' holds the epoch),
# one coloured line per weight.
op_chart = alt.Chart(df).mark_line().encode(x='index', y='mean', color='weight')
op_chart

In [None]:
# Same chart with a LOESS smoother of 'mean' over 'index', fitted per weight.
op_chart.transform_loess('index', 'mean', groupby=['weight']).mark_line()

In [None]:
# Load run task_effect_r0028 and collect, for every saved epoch, the mean
# scaling factor of each weight in the 'ort_sem' optimizer.
ort_sem = ELR(batch_name='task_effect', code_name='task_effect_r0028')
os_df = ort_sem.make_df('ort_sem', summary_function='mean')

In [None]:
# Line chart of the mean scaling factor per epoch ('index' holds the epoch),
# one coloured line per weight.
os_chart = alt.Chart(os_df).mark_line().encode(x='index', y='mean', color='weight')
os_chart

In [None]:
# Same chart with a LOESS smoother of 'mean' over 'index', fitted per weight.
os_chart.transform_loess('index', 'mean', groupby=['weight']).mark_line()

In [None]:
# List the names of all variables held by the 'ort_pho' optimizer; these are
# the names ELR.get_elr matches against ("<weight>/m:0", "<weight>/v:0").
[x.name for x in ort_pho.optimizers['ort_pho'].weights]

### OP model

In [None]:
# OP model (run task_effect_r0027): combine the per-epoch KDE plots into one image.
m27 = ELR('task_effect', 'task_effect_r0027')
# The per-epoch PNGs read below are written by plot_all_over_epochs;
# uncomment the next line to (re)generate them.
# m27.plot_all_over_epochs('ort_pho', xlim=(-0.1, 0.1))
m27_pngs = [os.path.join(m27.cfg.plot_folder, 'ort_pho', f"epoch_{x}.png") for x in m27.cfg.saved_epochs]
m27_stitch = stitch_fig(m27_pngs, rows=10, columns=5)
# Build the output path with os.path.join, like the input paths above.
m27_stitch.save(os.path.join(m27.cfg.plot_folder, 'OP_ADAM_LR.png'))

### OS model

In [None]:
# OS model (run task_effect_r0028): combine the per-epoch KDE plots into one image.
m28 = ELR('task_effect', 'task_effect_r0028')
# The per-epoch PNGs read below are written by plot_all_over_epochs;
# uncomment the next line to (re)generate them.
# m28.plot_all_over_epochs('ort_sem', xlim=(-0.1, 0.1))
m28_pngs = [os.path.join(m28.cfg.plot_folder, 'ort_sem', f"epoch_{x}.png") for x in m28.cfg.saved_epochs]
m28_stitch = stitch_fig(m28_pngs, rows=10, columns=5)
# Build the output path with os.path.join, like the input paths above.
m28_stitch.save(os.path.join(m28.cfg.plot_folder, 'OS_ADAM_LR.png'))