In [1]:
import matplotlib.pyplot as plt
import matplotlib.ticker as plticker
import numpy as np
import seaborn as sns
import pandas as pd
from IPython.display import display, HTML
# Inject a CSS rule that sets the notebook's `.container` element to 80% width,
# overriding the default via `!important` so the wide figures below have more room.
display(HTML("<style>.container { width:80% !important; }</style>"))

In [2]:
# Entropy-bonus values of the sweep. `plot_sparsity` asserts that axis 1 of the
# saved weight arrays has exactly len(ents) entries and titles one panel per value.
ents = [
    0.0, 0.01, 0.05,
    0.1, 0.4, 0.8,
]

# Clip values of the sweep (axis 2 of the saved weight arrays), shown as "eps=..."
# in the plot legends. Any value >= 1000 is labelled "eps=None" there, so the
# trailing 1e6 presumably stands for "no clipping" -- confirm against the training code.
clips = [
    0.005, 0.02, 0.08,
    0.2, 0.5, 0.8,
    1e6,
]

In [3]:
def _plot_weight_panels(W, ents, clips, std_scale, y_max):
    """Draw one panel per entropy bonus showing the sorted weight curves.

    Args:
        W: array of shape (n_seeds, n_ent, n_clip, n_values) holding the
            processed (abs, sorted, normalised, truncated) weights.
        ents: entropy-bonus values; ``ents[j]`` titles panel ``j``.
        clips: clip values; ``clips[c]`` labels curve ``c`` as ``eps=...``
            (values >= 1000 are labelled ``eps=None``).
        std_scale: scale of the standard-deviation band, forwarded to
            seaborn as ``errorbar=("sd", std_scale)``.
        y_max: upper y-limit of every panel (``None`` leaves it automatic).

    Raises:
        AssertionError: if there are more than six entropy settings, since
            the panels are laid out on a fixed 2x3 grid.
    """
    # Check the grid capacity once, before any figure is created.
    assert W.shape[1] <= 6

    plt.figure(figsize=(24, 15))
    indices = range(W.shape[-1])
    n_seeds = W.shape[0]
    for j in range(W.shape[1]):
        # One frame per clip value; every seed row carries the same label so
        # seaborn aggregates over seeds after the transpose below.
        frames = []
        for c in range(W.shape[2]):
            label = f"eps={clips[c] if clips[c]<1000 else None}"
            frames.append(pd.DataFrame(W[:, j, c, :], index=[label] * n_seeds, columns=indices))
        data_j = pd.concat(frames, axis=0)

        plt.subplot(2, 3, j + 1)
        sns.lineplot(data_j.T, errorbar=("sd", std_scale))
        plt.xlabel("Index")
        plt.ylabel("Weight value")
        plt.ylim(top=y_max)
        plt.title(f"Ent. bonus={ents[j]}")
    plt.show()


def plot_sparsity(env_name, ents, clips, std_scale=0.1, y_maxr=None, y_maxp=None, tick=1, k=500):
    """Load the saved weight matrices of one environment and plot their sorted magnitudes.

    Reads ``./shared/{env_name}_Wr.npy`` (representation weights) and
    ``./shared/{env_name}_Wp.npy`` (logits weights). Each array is made
    absolute, flattened over its last two axes, sorted in ascending order,
    divided by its per-(seed, ent, clip) maximum and truncated to the first
    ``k`` entries. Only the representation weights are plotted; the logits
    weights are processed and returned but their plot is currently disabled.

    Args:
        env_name: environment name used to build the two file paths.
        ents: entropy-bonus values; its length must equal axis 1 of both arrays.
        clips: clip values; its length must equal axis 2 of both arrays.
        std_scale: scale of the standard-deviation band in the plot.
        y_maxr: upper y-limit for the representation-weight panels.
        y_maxp: upper y-limit intended for the logits-weight panels; unused
            while that plot is disabled (kept for backward compatibility).
        tick: intended x-axis major tick spacing; currently unused (kept for
            backward compatibility).
        k: number of sorted entries kept per (seed, ent, clip) combination.

    Returns:
        Tuple ``(Wr, Wp)`` of processed arrays with shape
        (n_seeds, n_ent, n_clip, min(k, d_out * d_in)).

    Raises:
        AssertionError: if the arrays are not 5-dimensional, if their sweep
            axes do not match ``(len(ents), len(clips))``, or if there are
            more than six entropy settings.
    """
    with open(f"./shared/{env_name}_Wr.npy", 'rb') as f:
        Wr = np.load(f)
    with open(f"./shared/{env_name}_Wp.npy", 'rb') as f:
        Wp = np.load(f)

    assert len(Wr.shape) == len(Wp.shape) == 5  # (n_seeds, n_ent, n_clip, d_out, d_in)
    assert Wr.shape[1:3] == Wp.shape[1:3] == (len(ents), len(clips))
    print("Wr shape:", Wr.shape)
    print("Wp shape:", Wp.shape)

    # Magnitudes only, with each (d_out, d_in) matrix flattened into one vector.
    Wr = np.abs(Wr)
    Wp = np.abs(Wp)
    Wr = np.reshape(Wr, (*Wr.shape[:-2], -1))
    Wp = np.reshape(Wp, (*Wp.shape[:-2], -1))
    assert Wr.shape[1:3] == Wp.shape[1:3] == (len(ents), len(clips))

    # Ascending sort, so the `k` entries kept below are the SMALLEST magnitudes.
    # NOTE(review): the original applied a no-op `[..., ::1]` slice here; if a
    # descending order (`[..., ::-1]`) was intended, the plots would show the
    # largest magnitudes instead -- confirm which is wanted.
    Wr = np.sort(Wr, axis=-1)
    Wp = np.sort(Wp, axis=-1)

    print("Max Wr:", np.max(np.mean(Wr, axis=0)))
    print("Max Wp:", np.max(np.mean(Wp, axis=0)))

    # Normalise each vector by its own maximum (computed before truncation).
    Wr = Wr / np.max(Wr, axis=-1, keepdims=True)
    Wp = Wp / np.max(Wp, axis=-1, keepdims=True)

    # Keep the first k sorted entries of BOTH arrays. The original assigned the
    # truncated logits array to a misspelled local (`wp`), so the returned Wp
    # was never truncated.
    Wr = Wr[..., :k]
    Wp = Wp[..., :k]

    print("\nAbs values of representation weight matrix")
    _plot_weight_panels(Wr, ents, clips, std_scale, y_maxr)

    # The logits-weight plot is disabled; Wp is only returned. To enable it,
    # draw the same panels for Wp using `y_maxp` as the y-limit.

    return Wr, Wp

## 1. Acrobot

In [None]:
# Process and plot the weights stored under ./shared/Acrobot-v1_W{r,p}.npy.
Wr, Wp = plot_sparsity(env_name="Acrobot-v1", ents=ents, clips=clips)

Wr shape: (8, 6, 7, 64, 64)
Wp shape: (8, 6, 7, 64, 3)
Max Wr: 2.3311603
Max Wp: 0.7949325

Abs values of representation weight matrix


## 2. Asterix

In [None]:
# Process and plot the weights stored under ./shared/Asterix-MinAtar_W{r,p}.npy.
Wr, Wp = plot_sparsity(env_name="Asterix-MinAtar", ents=ents, clips=clips)

## 3. Breakout

In [None]:
# Process and plot the weights stored under ./shared/Breakout-MinAtar_W{r,p}.npy.
Wr, Wp = plot_sparsity(env_name="Breakout-MinAtar", ents=ents, clips=clips)

## 4. CartPole

In [None]:
# Process and plot the weights stored under ./shared/CartPole-v1_W{r,p}.npy.
Wr, Wp = plot_sparsity(env_name="CartPole-v1", ents=ents, clips=clips)

## 5. Freeway

In [None]:
# Process and plot the weights stored under ./shared/Freeway-MinAtar_W{r,p}.npy.
Wr, Wp = plot_sparsity(env_name="Freeway-MinAtar", ents=ents, clips=clips)

## 6. MountainCar

In [None]:
# Process and plot the weights stored under ./shared/MountainCar-v0_W{r,p}.npy.
Wr, Wp = plot_sparsity(env_name="MountainCar-v0", ents=ents, clips=clips)

## 7. SpaceInvaders

In [None]:
# Process and plot the weights stored under ./shared/SpaceInvaders-MinAtar_W{r,p}.npy.
Wr, Wp = plot_sparsity(env_name="SpaceInvaders-MinAtar", ents=ents, clips=clips)