# Visualizations

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec

from utils import PRIORS
%matplotlib inline

### Methods

In [None]:
def plot_alpha_mask(masks, inputs, fig_size, num_samples=5, slot_plot=False, save=False):
    """Plot (and optionally save) per-slot mask curves for random eval samples.

    One subplot row is drawn per sampled batch element; within a row, one
    solid line per slot shows that slot's mask value at every timestep.

    Args:
        masks: array indexed as ``[batch, slot, timestep, 0]``. A 3-D
            ``[batch, slot, timestep]`` array is also accepted and is given a
            trailing singleton axis.
        inputs: array indexed as ``[batch, timestep]``; the sampled row is
            cast to int and used as x-tick labels unless ``slot_plot`` is set.
        fig_size: ``(width, height)`` forwarded to ``plt.figure``.
        num_samples: number of batch elements to draw (sampled with
            replacement from the whole batch).
        slot_plot: when True, the x-ticks are left at matplotlib's defaults
            instead of being labelled with ``inputs``.
        save: when True, the figure is written to
            ``mask-plt-<idx>-<idx>-...png`` at 300 dpi.
    """
    masks = np.asarray(masks)
    if masks.ndim == 3:
        # Allow [batch, slot, timestep] inputs by adding the trailing axis
        # that the indexing below expects.
        masks = masks[..., np.newaxis]
    batch_size, K, L = masks.shape[0], masks.shape[1], masks.shape[2]
    # Sample from the actual batch instead of assuming a batch size of 32.
    idxs = np.random.randint(batch_size, size=num_samples)
    fig = plt.figure(figsize=fig_size)
    colors = [(0.9, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0), (1.0, 0.5, 0.0), 
              (0.5, 0.0, 1.0), (1.0, 0.0, 1.0), (1.0, 1.0, 0.0)]

    for row, sample_idx in enumerate(idxs):
        plt.subplot(num_samples, 1, row + 1)
        if not slot_plot:
            plt.xticks(np.arange(L), inputs[sample_idx, :].astype(int))
        for s_k in range(K):
            # plot alpha_masks; cycle the palette so more slots than colors
            # does not raise an IndexError
            plt.plot(range(L), masks[sample_idx, s_k, :, 0],
                     color=colors[s_k % len(colors)], linestyle='solid')

    plt.tight_layout(pad=0.5, w_pad=0.5)
    if save:
        # Save before show(): once the figure has been shown it may be
        # released, and a later savefig would write an empty image.
        fig.savefig("mask-plt-" + "-".join(map(str, idxs)) + ".png", dpi=300)
    plt.show()
                
        
def plot_attention_head(inputs, outputs, attention):
    """Draw one attention matrix on the current axes with token tick labels.

    Only the first element along the leading axis of each argument is used.

    Args:
        inputs: indexable whose first element supplies the x-axis labels.
        outputs: indexable whose first element supplies the y-axis labels.
        attention: indexable whose first element is the matrix passed to
            ``matshow``.
    """
    ax = plt.gca()
    ax.matshow(attention[0])

    # Pin the tick positions first: tick labels are only well-defined once
    # the locations are fixed, and setting labels before ticks makes
    # matplotlib emit a warning.
    ax.set_xticks(range(len(inputs[0])))
    ax.set_yticks(range(len(outputs[0])))

    ax.set_xticklabels(inputs[0])
    ax.set_yticklabels(outputs[0])
    

def plot_self_attn_weights(inputs, outputs, attention_heads, save=False):
    """Plot every attention head of one randomly chosen batch element.

    Heads are laid out on a grid four columns wide; each panel is drawn by
    ``plot_attention_head`` and labelled ``Head <n>``.

    Args:
        inputs: array whose first axis is the batch; the sampled row labels
            the x-axis of every panel.
        outputs: array whose first axis is the batch; the sampled row labels
            the y-axis of every panel.
        attention_heads: array indexed as ``[batch, head, ...]``.
        save: when True, the figure is written to
            ``attn-weights-plt-<idx>.png`` at 300 dpi.
    """
    n_cols = 4
    fig = plt.figure(figsize=(25, 25))
    # Kept as a 1-element array so the indexed arguments retain the leading
    # axis that plot_attention_head strips with [0].
    idx = np.random.randint(inputs.shape[0], size=1)
    num_heads = attention_heads.shape[1]
    # Ceiling division: exactly enough rows, with no empty trailing row when
    # num_heads is a multiple of n_cols.
    n_rows = -(-num_heads // n_cols)
    for h in range(num_heads):
        ax = fig.add_subplot(n_rows, n_cols, h+1)
        plot_attention_head(inputs[idx], outputs[idx], attention_heads[idx, h])
        ax.set_xlabel(f'Head {h+1}')

    plt.tight_layout(pad=0.5, w_pad=0.2, h_pad=0.2)
    if save:
        # Save before show() so the written image is not empty; use the
        # scalar index so the file name carries no array brackets.
        fig.savefig("attn-weights-plt-" + str(int(idx[0])) + ".png", dpi=300)
    plt.show()


def plot_bw_fw_access(masks, slot_attn_weights, inputs, fig_size, num_samples=1, 
                      pad_token=5, dataset_id="craft", dataset_fname="makeallf", 
                      save=False):
    """Plot and save qualitative viz of BW-FW access -- alpha-masks v/s slot-attn coeffs.

    A single batch element is drawn at random. For each slot one panel shows
    the binarized alpha-mask (solid, lower band) and the binarized
    slot-attention coefficients restricted to timesteps before the mask's
    first "on" index and from its last "on" index onwards (dotted, upper band).

    Args:
        masks: array indexed as ``[batch, slot, timestep, 0]``.
        slot_attn_weights: array indexed as ``[batch, timestep, slot]``.
        inputs: array indexed as ``[batch, timestep]`` holding action ids;
            values are cast to int and looked up in the dataset vocabulary.
        fig_size: ``(width, height)`` forwarded to ``plt.figure``.
        num_samples: number of random indices drawn; only the first one is
            plotted.
        pad_token: action id that marks padding; the plot stops at its first
            occurrence (or spans the whole sequence if it never occurs).
        dataset_id: ``"craft"`` or ``"minigrid"``; selects the tick-label
            vocabulary.
        dataset_fname: tag embedded in the output file name.
        save: when True, the figure is written to
            ``bw-fw-access-<dataset_fname>-<idx>.png`` at 300 dpi.

    Raises:
        ValueError: if ``dataset_id`` is not one of the supported datasets.
    """
    colors = [(0.9, 0.0, 0.0), (0.0, 0.5, 0.0), (0.0, 0.0, 1.0), (1.0, 0.65, 0.0), 
              (0.5, 0.0, 1.0), (1.0, 0.0, 1.0), (1.0, 1.0, 0.0)]
    CRAFT_VOCAB = [' \u2193 ', ' \u2191 ', ' \u2190 ', ' \u2192 ', ' u ', ' D ']
    MGRID_VOCAB = [' \u2190 ', ' \u2192 ', ' \u2191 ', ' PICK ', ' DROP ', ' TOGL ', ' DONE ']
    vocabs = {"craft": CRAFT_VOCAB, "minigrid": MGRID_VOCAB}
    if dataset_id not in vocabs:
        # Fail with a clear message instead of a NameError further down.
        raise ValueError(
            "unknown dataset_id %r; expected one of %s" % (dataset_id, sorted(vocabs)))
    vocab = vocabs[dataset_id]

    bsz = inputs.shape[0]
    idx = np.random.randint(bsz, size=num_samples)[0]
    K = masks.shape[1]
    # y-axis min/max ranges for mask & SA coeffs
    mask_min, mask_max = 0.0, 0.4
    sa_min, sa_max = 0.6, 1.0
    # convert the sampled action sequence to int
    tokens = inputs[idx, :].astype(int)
    # The plotted length ends at the first pad token; a sequence without any
    # padding is plotted in full.
    pad_positions = np.flatnonzero(tokens == pad_token)
    non_padded_L = int(pad_positions[0]) if pad_positions.size else tokens.shape[0]
    # Label only the plotted timesteps so labels and tick positions match.
    action_tokens = [vocab[a] for a in tokens[:non_padded_L]]
    timesteps = np.arange(non_padded_L)

    plt.figure(figsize=fig_size)
    # plot K separate panels of slot_attn v/s alpha_masks
    for s in range(K):
        plt.subplot(K, 1, s+1)
        plt.xticks(timesteps, action_tokens)
        plt.yticks([mask_min, mask_max, sa_min, sa_max], ['off', 'on', 'off', 'on'])
        plt.ylim([0.0, 1.1])
        # binarize values
        mask = (masks[idx, s, :, 0] > 0.8).astype(float)
        slot_attn = (slot_attn_weights[idx, :, s] > 0.8).astype(float)
        # indices where alpha-mask is "on"
        mask_on_idxs = np.flatnonzero(mask)
        # With no "on" index both slices below are empty, so no coefficient
        # is kept.
        alpha_min, alpha_max = 0, non_padded_L
        if mask_on_idxs.size:
            # "filter" out only past and future timesteps of slot_attn_coeffs;
            # both bounds are clipped to the plotted length so the slice
            # assignments below always have matching shapes.
            alpha_min = min(int(mask_on_idxs[0]), non_padded_L)
            alpha_max = min(int(mask_on_idxs[-1]), non_padded_L)
        filtered_sa_coeffs = np.zeros((non_padded_L,))
        # "past" indices
        filtered_sa_coeffs[:alpha_min] = slot_attn[:alpha_min]
        # "future" indices
        # NOTE(review): this range starts AT the last "on" index, whereas the
        # "past" range excludes the first one -- confirm this is intended.
        filtered_sa_coeffs[alpha_max:non_padded_L] = slot_attn[alpha_max:non_padded_L]
        # re-scale mask into [mask_min, mask_max] and SA coeffs into
        # [sa_min, sa_max] so the two curves occupy separate bands
        norm_mask = mask * mask_max
        norm_filtered_sa_coeffs = (filtered_sa_coeffs * (sa_max - sa_min)) + sa_min
        # plot lines -- alpha_mask & slot_attn; cycle the palette for K > 7
        color = colors[s % len(colors)]
        plt.plot(timesteps, norm_mask[0:non_padded_L], color=color, 
                 linestyle='solid', linewidth=2.0)
        plt.axhline(y=0.5, color='k', linestyle='-', linewidth=1.0)
        plt.plot(timesteps, norm_filtered_sa_coeffs, color=color, 
                 linestyle='dotted', linewidth=2.0)

    plt.tight_layout()
    if save:
        plt.savefig("bw-fw-access-" + dataset_fname + "-" + str(idx) + ".png", dpi=300)

#### Example: Viz decoder alpha masks

In [None]:
# Locate the evaluation log for the chosen step of the craft-skills run.
dataset = "craft-skills/"
f_dir = "hcrmc/logs/"
eval_step = 9
f_name = "".join(("eval_logs_", str(eval_step), ".npz"))
# NOTE: allow_pickle=True lets numpy unpickle object arrays -- only point
# this at log files you trust.
data = np.load("".join((dataset, f_dir, f_name)), allow_pickle=True)
# Draw the decoder alpha masks, labelling the x-axis with the logged actions.
plot_alpha_mask(masks=data['masks'], inputs=data['actions'], fig_size=(10, 10))


#### Example: Viz self-attention of Transformer Encoder

In [None]:
# Locate the evaluation log (fill in the dataset/log directories and step).
dataset = "" 
f_dir = ""
eval_step = -1
f_name = "".join(("eval_logs_", str(eval_step), ".npz"))
# NOTE: allow_pickle=True lets numpy unpickle object arrays -- only point
# this at log files you trust.
data = np.load("".join((dataset, f_dir, f_name)), allow_pickle=True)
# The logged actions label both axes of every encoder attention map.
plot_self_attn_weights(inputs=data['actions'].astype(int),
                       outputs=data['actions'].astype(int),
                       attention_heads=data['enc_attn_weights'])


#### Example: Viz slot-attention weights 

In [None]:
# load data logs
dataset = ""
f_dir = ""
eval_step = -1
f_name = "eval_logs_" + str(eval_step) + ".npz"
# NOTE: allow_pickle=True lets numpy unpickle object arrays -- only point
# this at log files you trust.
data = np.load(dataset + f_dir + f_name, allow_pickle=True)
# plot_alpha_mask draws its own random sample indices, so this list only
# determines how many samples are plotted.
idxs = [3, 5, 7, 11, 13]
# plot slot-attn coeffs
# Swap the last two axes to get [batch, slot, timestep], then append a
# singleton axis to match the 4-index lookup plot_alpha_mask performs.
# The figure size goes in the third positional slot; the sample count is
# passed by keyword.
plot_alpha_mask(np.transpose(data['slot_attn_weights'], axes=[0, 2, 1])[..., np.newaxis],
                data['actions'], (10, 8), num_samples=len(idxs), slot_plot=True)

#### Example: viz BW-FW Access -- slot_attn v/s alpha_mask

In [None]:
# Locate the evaluation log (fill in the root/log directories and step).
root_dir = ""
logs_dir = ""
eval_step = -1
f_name = "".join(("eval_logs_", str(eval_step), ".npz"))
# NOTE: allow_pickle=True lets numpy unpickle object arrays -- only point
# this at log files you trust.
data = np.load("".join((root_dir, logs_dir, f_name)), allow_pickle=True)
# Plot and save the BW-FW access figure (alpha-masks vs. slot-attn coeffs)
# using the minigrid vocabulary.
plot_bw_fw_access(
    masks=data['masks'],
    slot_attn_weights=data['slot_attn_weights'],
    inputs=data['actions'],
    fig_size=(9, 6),
    pad_token=6,
    dataset_id="minigrid",
    dataset_fname="ulpkp",
    save=True,
)


#### Histogram plots of empirical prior dist.

In [None]:
# Bar chart of the empirical prior stored under the "keycorridor-s4r3" key.
# presumably PRIORS maps an environment name to a sequence of probabilities,
# one per sub-routine count -- verify against utils.PRIORS.
prior = PRIORS["keycorridor-s4r3"]
# The x-axis is 1-based: position i is the i-th entry of `prior`.
plt.xlim([1, len(prior)])
plt.xticks(np.arange(1, len(prior)+1, 1))
plt.xlabel('Number of sub-routines')
plt.ylabel('Halting probability')
plt.title('Empirical prior distribution')
plt.bar(list(range(1, len(prior)+1, 1)), prior, width=0.2)
plt.grid(True)
# Written to the current working directory at 300 dpi.
plt.savefig("emp-prior-kcs4r3.png", dpi=300)