# Suspected causes of too-low accuracy in the Lesion model
Vocabulary: 
- broadcasting: when adding tensors with different dimensions, for example input_p (dim = [batch_size, 250]) and bias_p (dim = [1, 250]), bias_p will automatically "broadcast" to (batch_size, 250) during the elementwise addition

When input information has not yet reached a layer, the layer will assume the first axis (batch_size) = 1.
However, the batch_size won't matter, since in tf.matmul the entries along the batch axis are independent of each other.


In [None]:
import tensorflow as tf
import pandas as pd
import altair as alt
import os, random
import meta, modeling, data_wrangling, evaluate
from importlib import reload
reload(modeling)

In [None]:
class SemanticDiagnosis:
    """A diagnostic bundle for troubleshooting activation and input in the semantic layer.

    Typical use: construct with a model code name, call `eval` to run the
    model on a test set, call `set_target_word` to pick one word, then call
    `plot_diagnosis` to inspect that word's semantic units.
    """

    # Number of output time ticks requested from the model in `eval` and
    # unpacked per word in `make_semantic_output_diagnostic_df`.
    OUTPUT_TICKS = 13

    # Key in `y_pred` -> column name in the diagnostic table (order is the
    # order in which the columns are merged).
    _INPUT_COLUMNS = {
        'input_hps_hs': 'PS',
        'input_css_cs': 'CS',
        'input_sem_ss': 'SS',
        'input_hos_hs': 'OS',
        'sem': 'act_sem',
    }

    def __init__(self, code_name: str):
        """Load the model configuration stored under models/<code_name>/model_config.json."""
        self.cfg = meta.ModelConfig.from_json(
            os.path.join("models", code_name, "model_config.json")
        )

    def eval(self, testset_name: str, task: str, epoch: int):
        """Run the model at `epoch` on a test set and collect evaluation results.

        Args:
            testset_name: base name of the pickled test set under dataset/testsets/.
            task: task name forwarded to `evaluate.TestSet.eval`.
            epoch: epoch number used to format the saved-weights path.

        Sets `testset_package`, `model`, `y_pred`, `testset` and `df` on the instance.
        """
        self.testset_package = data_wrangling.load_testset(
            os.path.join('dataset', 'testsets', f"{testset_name}.pkl.gz")
        )

        # Force the config to OUTPUT_TICKS output ticks before building the model,
        # so the per-tick unpacking below can rely on that many ticks.
        self.cfg.output_ticks = self.OUTPUT_TICKS
        self.model = modeling.MyModel(self.cfg)
        self.model.load_weights(self.cfg.saved_weights_fstring.format(epoch=epoch))
        # NOTE(review): the model is always run with the "triangle" task active,
        # regardless of `task`; `task` only reaches evaluate.TestSet.eval below.
        # Confirm this is intended.
        self.model.set_active_task("triangle")
        self.y_pred = self.model([self.testset_package['ort']] * self.cfg.n_timesteps)

        # Get data for evaluate object
        self.testset = evaluate.TestSet(self.cfg)
        self.df = self.testset.eval(testset_name, task)

    def set_target_word(self, word: str):
        """Build and cache (as `word_df`) the diagnostic table for `word`."""
        self.word_df = self.make_semantic_output_diagnostic_df(word)

    @property
    def all_outputs(self) -> list:
        """Names of all outputs available in `y_pred`."""
        return list(self.y_pred.keys())

    @property
    def all_weights(self) -> list:
        """Names of all weights in the loaded model."""
        return [w.name for w in self.model.weights]

    @property
    def all_words(self) -> list:
        """Unique words present in the evaluation results."""
        return list(self.df.word.unique())

    def make_semantic_output_diagnostic_df(self, target_word: str) -> pd.DataFrame:
        """Return all semantic-related inputs and activations for one word.

        The result has one row per (unit, timetick) with the columns
        word, unit, timetick, target_act, act_sem, bias, OS, PS, CS, SS.

        Raises:
            ValueError: if `target_word` is not in the test set's item list.
        """
        target_word_idx = self.testset_package['item'].index(target_word)

        # Time-invariant elements: target pattern of the word and the semantic bias.
        df_time_invar = pd.DataFrame.from_dict(
            {
                'target_act': self.testset_package['sem'][target_word_idx, :].numpy(),
                'bias': [
                    w.numpy() for w in self.model.weights if w.name.endswith('bias_s:0')
                ][0],
            }
        )
        df_time_invar['unit'] = df_time_invar.index
        df_time_invar['word'] = target_word

        # Time-varying elements: one column per input pathway / activation,
        # indexed as y_pred[name][timetick, word, unit].
        df_time_varying = None
        for input_name, column in self._INPUT_COLUMNS.items():
            source = self.y_pred[input_name]
            step_frames = []
            for t in range(self.OUTPUT_TICKS):
                step_df = pd.DataFrame.from_dict(
                    {column: source[t, target_word_idx, :].numpy()}
                )
                step_df['timetick'] = t
                step_df['unit'] = step_df.index
                step_frames.append(step_df)

            # Concatenate once per input instead of growing a frame inside the loop.
            this_input_df = pd.concat(step_frames)

            if df_time_varying is None:
                df_time_varying = this_input_df
            else:
                df_time_varying = pd.merge(
                    df_time_varying, this_input_df, on=['timetick', 'unit']
                )

        # Merge and export
        df = df_time_varying.merge(df_time_invar, on='unit', how='left')
        return df[['word', 'unit', 'timetick', 'target_act', 'act_sem', 'bias', 'OS', 'PS', 'CS', 'SS']]

    def plot_diagnosis(self, target_act: int) -> "alt.Chart":
        """Plot an all-in-one diagnosis for units whose target activation equals `target_act`.

        Requires `set_target_word` to have been called first (reads `word_df`).
        Returns three charts side by side: mean input per pathway, input per
        unit, and activation per unit. When more than 20 units match, 20 of
        them are drawn at random (unseeded, so the selection varies per call).
        """
        df = self.word_df.loc[self.word_df.target_act == target_act]
        df = df.melt(
            id_vars=['word', 'unit', 'timetick', 'target_act'],
            value_vars=['act_sem', 'CS', 'OS', 'PS', 'SS', 'bias'],
        )
        sel_units = list(df.unit.unique())

        # Random sample if there are too many nodes
        if len(sel_units) > 20:
            sample_units = random.sample(sel_units, 20)
            df = df.loc[df.unit.isin(sample_units)]

        # Titles reflect the requested target value rather than a fixed "1".
        title_suffix = f"(with target node {target_act})"

        sel_var = alt.selection_multi(fields=["variable"], bind="legend")
        sel_unit = alt.selection_multi(fields=["unit"], bind="legend")

        plot_input_pathway = (
            alt.Chart(df.loc[df.variable != "act_sem"])
            .mark_line(point=True)
            .encode(
                y="mean(value):Q",
                x="timetick",
                color="variable",
                opacity=alt.condition(sel_var, alt.value(1), alt.value(0.1)),
            )
            .add_selection(sel_var)
        )

        plot_input_unit = (
            plot_input_pathway.encode(
                color="unit:N",
                opacity=alt.condition(sel_unit, alt.value(1), alt.value(0.1)),
            )
            .transform_filter(sel_var)
            .add_selection(sel_unit)
        )

        plot_act_unit = (
            alt.Chart(df.loc[df.variable == "act_sem"])
            .mark_line(point=True)
            .encode(
                y=alt.Y("value:Q", scale=alt.Scale(domain=(0, 1))),
                x="timetick",
                color=alt.Color("unit:N", legend=None),
                opacity=alt.condition(sel_unit, alt.value(1), alt.value(0.1)),
                tooltip=["unit"],
            )
            .add_selection(sel_unit)
            .properties(title=f"Activation time course in each unit {title_suffix}")
        )

        return (
            plot_input_pathway.properties(
                title=f"Input time course in each pathway {title_suffix}"
            )
            | plot_input_unit.properties(
                title=f"Input time course in each unit {title_suffix}"
            )
        ).resolve_scale(color="independent", y="shared") | plot_act_unit.resolve_scale(y="independent")

In [None]:
# Load the 'Refrac_5M_fix' model config, run the epoch-400 weights on the
# 'train_r100' test set, and build the diagnostic table for the word 'staffs'.
sem = SemanticDiagnosis(code_name='Refrac_5M_fix')
sem.eval(testset_name='train_r100', task='triangle', epoch=400)
sem.set_target_word('staffs')

In [None]:
# Diagnosis restricted to semantic units whose target activation is 1.
sem.plot_diagnosis(target_act=1)

In [None]:
# Diagnosis restricted to semantic units whose target activation is 0.
sem.plot_diagnosis(target_act=0)

In [None]:
def mean_unit_input(tensor, mask=None):
    """Average `tensor` over its last two axes, optionally after masking.

    When `mask` is given, `tensor` is first multiplied elementwise by it.
    Masked-out entries are not excluded: they enter the averages as zeros.
    The mean is taken over the last axis, then over the last axis of that result.
    """
    masked = tensor if mask is None else tf.multiply(tensor, mask)
    # Two successive means, innermost axis first.
    inner_mean = tf.reduce_mean(masked, axis=-1)
    return tf.reduce_mean(inner_mean, axis=-1)



In [None]:

# Mean input of each pathway, masked by the target pattern (so only entries
# whose target is non-zero contribute), plotted as one line per pathway.
# NOTE(review): `ps`, `cs`, `ss`, `op` and `train100` are not defined in the
# visible cells, and `os` would resolve to the imported `os` module unless it
# was rebound to a tensor elsewhere — confirm before re-running.
results_dict = {k: mean_unit_input(globals()[k], train100['sem']) for k in ('os', 'ps', 'cs', 'ss')}
results_dict['op'] = mean_unit_input(op, train100['pho'])
dol_df = pd.DataFrame.from_dict(results_dict)
dol_df.plot(title="one target")

In [None]:
# Same summary as the "one target" cell, but masked with (1 - target).
# NOTE(review): relies on the same notebook globals (`os`, `ps`, `cs`, `ss`,
# `op`, `train100`) that are not defined in the visible cells — confirm.
results_dict = {k: mean_unit_input(globals()[k], 1-train100['sem']) for k in ('os', 'ps', 'cs', 'ss')}
results_dict['op'] = mean_unit_input(op, 1-train100['pho'])
dol_df = pd.DataFrame.from_dict(results_dict)
dol_df.plot(title="zero target")

In [None]:
# Pick the bias weights by exact variable name and print their means.
# NOTE(review): `model` is not defined in the visible cells (the class keeps
# its model as `sem.model`) — confirm which model this refers to.
bias_s = [w for w in model.weights if w.name == "my_model/bias_s:0"]
bias_p = [w for w in model.weights if w.name == "my_model/bias_p:0"]
print(f"mean bias S: {tf.reduce_mean(bias_s)}; P: {tf.reduce_mean(bias_p)}")

- The P bias is roughly neutral
- Positive bias in S??? This is counterintuitive

In [None]:
tf.