# Interpretability Study

Credits
- [PyTorch](https://github.com/pytorch/pytorch)
- [pytorch-gradcam](https://github.com/vickyliin/gradcam_plus_plus-pytorch)
- [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3)
- [RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo)

### Import utils

In [None]:
from scripts.interpretability_utils import *

### Create and wrap the game env

In [None]:
# Environment configuration: one Pong environment with a fixed seed.
game = 'PongNoFrameskip-v4'
n_envs, seed = 1, 0
# Build the Atari vec env and wrap it so observations are stacks of 4 frames (n_stack=4).
vec_env = VecFrameStack(
    make_atari_env(game, n_envs=n_envs, seed=seed),
    n_stack=4,
)

In [None]:
# The transpose wrapper is wanted only for image observation spaces that are
# not already channels-first; non-image spaces never get it.
observation_space = vec_env.observation_space
if is_image_space(observation_space):
    wrap_with_vectranspose = not is_image_space_channels_first(observation_space)
else:
    wrap_with_vectranspose = False

In [None]:
# Apply VecTransposeImage only when wrap_with_vectranspose is set;
# otherwise leave vec_env untouched and emit a warning.
if not wrap_with_vectranspose:
    warnings.warn("VecTransposeImage is not applied.")
else:
    vec_env = VecTransposeImage(vec_env)
    print("VecTransposeImage is applied.")

In [None]:
# Ask the first underlying (fully unwrapped) env for its action names and keep
# an independent copy of the returned list.
first_env = vec_env.unwrapped.envs[0]
action_meanings = deepcopy(first_env.unwrapped.get_action_meanings())
print(f"action meanings: {action_meanings}")

### Re-generate 10 random observations (optional)

You can generate new random observations by running the cell below repeatedly.

In [None]:
# my_obs_list = get_random_obs(1)

# for idx, obs in enumerate(my_obs_list):
#     frames = obs.squeeze()
#     fig = plt.figure(idx, figsize=(10,4))
#     for pos, frame in enumerate(frames):
#         ax = fig.add_subplot(1, 4, pos+1)
#         plt.imshow(frame, cmap='gray', vmin=0, vmax=255)
#         plt.axis('off')
# plt.show()

You can save the observations as a `.npy` file. 

**Note: saving under an existing filename overwrites that file. Use a different filename!**

In [None]:
# with open('gradcam/pong_new_1.npy', 'wb') as f:
#     np.save(f, my_obs_list[0])

### Load all saved observation files

In [None]:
# Preview the 10 saved observations: one figure per file, one grayscale
# subplot per frame of the (squeezed) observation.
for idx in range(10):
    # Read the array first so the file handle is released before any plotting.
    with open(f'gradcam/pong_{idx+1}.npy', 'rb') as f:
        obs = np.load(f)
    frames = obs.squeeze()
    fig = plt.figure(idx, figsize=(10,4))
    for pos, frame in enumerate(frames):
        # The column count follows the number of frames instead of a hard-coded 4,
        # so an observation with a different stack size does not raise here.
        ax = fig.add_subplot(1, len(frames), pos+1)
        ax.imshow(frame, cmap='gray', vmin=0, vmax=255)
        ax.axis('off')
    # plt.savefig(f"gradcam/pong_{idx+1}.pdf", dpi=300)
    plt.show()

### Define parameters

In [None]:
# Device passed to the model loader, and the pieces used to build checkpoint paths.
DEVICE = "cpu"
GAME_VERSION = "NoFrameskip-v4"
CHECKPOINT_EXT = "model_checkpoint_3000000_steps"
# Game name with the version suffix removed ('Pong' for 'PongNoFrameskip-v4').
game_no_version = game.replace(GAME_VERSION, '')
# Self-attention type (sat) -> seed that appears in that model's checkpoint filename.
sat_to_seed_mapping = dict(CWCA=42, NA=42, SWA=42, CWRA=1234, CWRCA=0)
# Order in which the models are evaluated; put winner agent at the leftmost.
sat_list = ['CWCA', 'NA', 'SWA', 'CWRA', 'CWRCA']

### Generate saliency maps (target layer c1)

target_layer = policy.features_extractor.c1

deterministic = True

In [None]:
# Grad-CAM for every saved observation and every self-attention variant,
# with features_extractor.c1 as the target layer.
# Every *_all list is flat and observation-major: all models (in sat_list order)
# for the first observation, then all models for the second, and so on.
heatmaps_all = [] # gradcam
actions_all = [] # actions selected by the agent (depends on deterministic parameter)
values_all = [] # value_net output
logits_all = [] # action_net outputs
scores_all = [] # logits selected based on actions
gradients_all = [] # gradients of the target layer
activations_all = [] # feed-forward output of the target layer
# self_attn_layer outputs; only models that return attended feature maps add an
# entry, so this list can be shorter than the others
attended_feature_maps_all = []

h = 84
w = 84
deterministic = True

# each row holds all heatmaps for an obs
# each column is an sat model
for obs_idx in range(10):
    # load observation .npy file
    with open(f'gradcam/pong_{obs_idx+1}.npy', 'rb') as f:
        obs = np.load(f)
    # convert observation to torch tensor
    obs_th = th.FloatTensor(obs)
    heatmaps_per_obs = []
    actions_per_obs = []
    values_per_obs = []
    logits_per_obs = []
    scores_per_obs = []
    gradients_per_obs = []
    activations_per_obs = []
    attended_feature_maps_per_obs = []
    for sat in sat_list:
        # seed that is part of this model's checkpoint filename; indexing instead of
        # .get() fails fast with a KeyError rather than building a "..._None_..." path
        seed_checkpoint = sat_to_seed_mapping[sat]
        # PPO.load is a classmethod that returns a new model rebuilt from the checkpoint,
        # so the PPO(...) instance (and the eval()-built policy_kwargs) that used to be
        # created here first was discarded and never influenced the loaded model.
        model_updated = PPO.load(f"gradcam/{game_no_version}_{sat}_{seed_checkpoint}_{CHECKPOINT_EXT}", device=DEVICE, print_system_info=False)
        # get the policy and the target layer
        policy = model_updated.policy
        target_layer = policy.features_extractor.c1
        # instantiate a GradCAM object
        gradcam = GradCAM(policy, target_layer, sat=sat)
        # call the forward() in GradCAM
        saliency_map, action, value, logit, score, gradients, activations, attended_feature_maps = gradcam(obs_th, deterministic=deterministic)
        # convert saliency map to heatmap
        heatmap = convert_to_heatmap(saliency_map)
        heatmaps_per_obs.append(heatmap)
        actions_per_obs.append(action)
        values_per_obs.append(value)
        logits_per_obs.append(logit)
        scores_per_obs.append(score)
        gradients_per_obs.append(gradients)
        activations_per_obs.append(activations)
        # preprocess attended_feature_maps before converting to heatmap style
        if attended_feature_maps is not None:
            # collapse dim 1 into a single map (presumably the channel dim — TODO confirm)
            attended_feature_maps_sum = attended_feature_maps.sum(1, keepdim=True)
            # upsample to (h, w), i.e. 84x84
            attended_feature_maps_upsample = F.interpolate(attended_feature_maps_sum, size=(h,w), mode='bilinear', align_corners=False)
            attended_feature_maps_upsample_min = attended_feature_maps_upsample.min()
            attended_feature_maps_upsample_max = attended_feature_maps_upsample.max()
            # min-max normalise to [0, 1]; a constant map has zero range, so divide by 1
            # instead of 0 to get all zeros rather than NaNs
            value_range = attended_feature_maps_upsample_max - attended_feature_maps_upsample_min
            value_range = th.where(value_range > 0, value_range, th.ones_like(value_range))
            attended_feature_maps_norm = ((attended_feature_maps_upsample - attended_feature_maps_upsample_min) / value_range).detach()
            # convert to heatmap style (color)
            attended_feature_maps_color = convert_to_heatmap(attended_feature_maps_norm)
            attended_feature_maps_per_obs.append(attended_feature_maps_color)
    heatmaps_all.extend(heatmaps_per_obs)
    actions_all.extend(actions_per_obs)
    values_all.extend(values_per_obs)
    logits_all.extend(logits_per_obs)
    scores_all.extend(scores_per_obs)
    gradients_all.extend(gradients_per_obs)
    activations_all.extend(activations_per_obs)
    attended_feature_maps_all.extend(attended_feature_maps_per_obs)

#### Heatmaps

In [None]:
# Tile the heatmaps with one image per model in each row, so every grid row is one
# observation. The row length comes from sat_list instead of a hard-coded 5, which
# keeps the layout aligned if the model list changes.
grid_image_heatmap = make_grid(heatmaps_all, nrow=len(sat_list))
grid_image_heatmap_PIL = transforms.ToPILImage()(grid_image_heatmap)

In [None]:
# Display the heatmap grid as the cell output.
grid_image_heatmap_PIL

In [None]:
# Save the c1 heatmap grid to a PDF file in the gradcam/ directory.
grid_image_heatmap_PIL.save("gradcam/heatmap_c1_deterministic_true.pdf")

#### Attended feature maps

In [None]:
# Tile the coloured attended feature maps, 4 images per row, and convert to PIL.
# NOTE(review): nrow=4 presumably equals the number of models that return attended
# feature maps (one fewer than the 5 entries of sat_list) — confirm.
attended_to_pil = transforms.ToPILImage()
grid_image_attended_feature_map = make_grid(attended_feature_maps_all, nrow=4)
grid_image_attended_feature_map_PIL = attended_to_pil(grid_image_attended_feature_map)

In [None]:
# Display the attended-feature-map grid as the cell output.
grid_image_attended_feature_map_PIL

In [None]:
# Save the attended-feature-map grid to a PDF file in the gradcam/ directory.
grid_image_attended_feature_map_PIL.save("gradcam/attended_feature_map_c1_deterministic_true.pdf")

In [None]:
# Display every selected action (flat list: one entry per observation/model pair,
# models in sat_list order within each observation).
actions_all

In [None]:
# Actions of the CWCA model across all observations. The flat list repeats the
# sat_list order for each observation, so start at CWCA's position and step by the
# number of models (replaces the hard-coded slice [0::5]).
actions_CWCA = actions_all[sat_list.index('CWCA')::len(sat_list)]
actions_CWCA

In [None]:
# Actions of the CWRCA model across all observations. The flat list repeats the
# sat_list order for each observation, so start at CWRCA's position and step by the
# number of models (replaces the hard-coded slice [4::5]).
actions_CWRCA = actions_all[sat_list.index('CWRCA')::len(sat_list)]
actions_CWRCA

In [None]:
# Display every value_net output (flat list: one entry per observation/model pair,
# models in sat_list order within each observation).
values_all

In [None]:
# Value estimates of the CWCA model across all observations. The flat list repeats
# the sat_list order for each observation, so start at CWCA's position and step by
# the number of models (replaces the hard-coded slice [0::5]).
values_CWCA = values_all[sat_list.index('CWCA')::len(sat_list)]
values_CWCA

In [None]:
# Value estimates of the CWRCA model across all observations. The flat list repeats
# the sat_list order for each observation, so start at CWRCA's position and step by
# the number of models (replaces the hard-coded slice [4::5]).
values_CWRCA = values_all[sat_list.index('CWRCA')::len(sat_list)]
values_CWRCA

In [None]:
# Display every action_net output (flat list: one entry per observation/model pair,
# models in sat_list order within each observation).
logits_all

In [None]:
# Action logits of the CWCA model across all observations. The flat list repeats the
# sat_list order for each observation, so start at CWCA's position and step by the
# number of models (replaces the hard-coded slice [0::5]).
logits_CWCA = logits_all[sat_list.index('CWCA')::len(sat_list)]
logits_CWCA

In [None]:
# Action logits of the CWRCA model across all observations. The flat list repeats the
# sat_list order for each observation, so start at CWRCA's position and step by the
# number of models (replaces the hard-coded slice [4::5]).
logits_CWRCA = logits_all[sat_list.index('CWRCA')::len(sat_list)]
logits_CWRCA

### Generate saliency maps (target layer c2)

target_layer = policy.features_extractor.c2

deterministic = True

In [None]:
# Grad-CAM for every saved observation and every self-attention variant,
# with features_extractor.c2 as the target layer.
# Every *_all list is flat and observation-major: all models (in sat_list order)
# for the first observation, then all models for the second, and so on.
heatmaps_all = [] # gradcam
actions_all = [] # actions selected by the agent (depends on deterministic parameter)
values_all = [] # value_net output
logits_all = [] # action_net outputs
scores_all = [] # logits selected based on actions
gradients_all = [] # gradients of the target layer
activations_all = [] # feed-forward output of the target layer
# self_attn_layer outputs; only models that return attended feature maps add an
# entry, so this list can be shorter than the others
attended_feature_maps_all = []

h = 84
w = 84
deterministic = True

# each row holds all heatmaps for an obs
# each column is an sat model
for obs_idx in range(10):
    # load observation .npy file
    with open(f'gradcam/pong_{obs_idx+1}.npy', 'rb') as f:
        obs = np.load(f)
    # convert observation to torch tensor
    obs_th = th.FloatTensor(obs)
    heatmaps_per_obs = []
    actions_per_obs = []
    values_per_obs = []
    logits_per_obs = []
    scores_per_obs = []
    gradients_per_obs = []
    activations_per_obs = []
    attended_feature_maps_per_obs = []
    for sat in sat_list:
        # seed that is part of this model's checkpoint filename; indexing instead of
        # .get() fails fast with a KeyError rather than building a "..._None_..." path
        seed_checkpoint = sat_to_seed_mapping[sat]
        # PPO.load is a classmethod that returns a new model rebuilt from the checkpoint,
        # so the PPO(...) instance (and the eval()-built policy_kwargs) that used to be
        # created here first was discarded and never influenced the loaded model.
        model_updated = PPO.load(f"gradcam/{game_no_version}_{sat}_{seed_checkpoint}_{CHECKPOINT_EXT}", device=DEVICE, print_system_info=False)
        # get the policy and the target layer
        policy = model_updated.policy
        target_layer = policy.features_extractor.c2
        # instantiate a GradCAM object
        gradcam = GradCAM(policy, target_layer, sat=sat)
        # call the forward() in GradCAM
        saliency_map, action, value, logit, score, gradients, activations, attended_feature_maps = gradcam(obs_th, deterministic=deterministic)
        # convert saliency map to heatmap
        heatmap = convert_to_heatmap(saliency_map)
        heatmaps_per_obs.append(heatmap)
        actions_per_obs.append(action)
        values_per_obs.append(value)
        logits_per_obs.append(logit)
        scores_per_obs.append(score)
        gradients_per_obs.append(gradients)
        activations_per_obs.append(activations)
        # preprocess attended_feature_maps before converting to heatmap style
        if attended_feature_maps is not None:
            # collapse dim 1 into a single map (presumably the channel dim — TODO confirm)
            attended_feature_maps_sum = attended_feature_maps.sum(1, keepdim=True)
            # upsample to (h, w), i.e. 84x84
            attended_feature_maps_upsample = F.interpolate(attended_feature_maps_sum, size=(h,w), mode='bilinear', align_corners=False)
            attended_feature_maps_upsample_min = attended_feature_maps_upsample.min()
            attended_feature_maps_upsample_max = attended_feature_maps_upsample.max()
            # min-max normalise to [0, 1]; a constant map has zero range, so divide by 1
            # instead of 0 to get all zeros rather than NaNs
            value_range = attended_feature_maps_upsample_max - attended_feature_maps_upsample_min
            value_range = th.where(value_range > 0, value_range, th.ones_like(value_range))
            attended_feature_maps_norm = ((attended_feature_maps_upsample - attended_feature_maps_upsample_min) / value_range).detach()
            # convert to heatmap style (color)
            attended_feature_maps_color = convert_to_heatmap(attended_feature_maps_norm)
            attended_feature_maps_per_obs.append(attended_feature_maps_color)
    heatmaps_all.extend(heatmaps_per_obs)
    actions_all.extend(actions_per_obs)
    values_all.extend(values_per_obs)
    logits_all.extend(logits_per_obs)
    scores_all.extend(scores_per_obs)
    gradients_all.extend(gradients_per_obs)
    activations_all.extend(activations_per_obs)
    attended_feature_maps_all.extend(attended_feature_maps_per_obs)

#### Heatmaps

In [None]:
# Tile the heatmaps with one image per model in each row, so every grid row is one
# observation. The row length comes from sat_list instead of a hard-coded 5, which
# keeps the layout aligned if the model list changes.
grid_image_heatmap = make_grid(heatmaps_all, nrow=len(sat_list))
grid_image_heatmap_PIL = transforms.ToPILImage()(grid_image_heatmap)

In [None]:
# Display the heatmap grid as the cell output.
grid_image_heatmap_PIL

In [None]:
# Save the c2 heatmap grid to a PDF file in the gradcam/ directory.
grid_image_heatmap_PIL.save("gradcam/heatmap_c2_deterministic_true.pdf")

In [None]:
# Display every selected action (flat list: one entry per observation/model pair,
# models in sat_list order within each observation).
actions_all

In [None]:
# Actions of the CWCA model across all observations. The flat list repeats the
# sat_list order for each observation, so start at CWCA's position and step by the
# number of models (replaces the hard-coded slice [0::5]).
actions_CWCA = actions_all[sat_list.index('CWCA')::len(sat_list)]
actions_CWCA

In [None]:
# Actions of the CWRCA model across all observations. The flat list repeats the
# sat_list order for each observation, so start at CWRCA's position and step by the
# number of models (replaces the hard-coded slice [4::5]).
actions_CWRCA = actions_all[sat_list.index('CWRCA')::len(sat_list)]
actions_CWRCA

In [None]:
# Display every value_net output (flat list: one entry per observation/model pair,
# models in sat_list order within each observation).
values_all

In [None]:
# Value estimates of the CWCA model across all observations. The flat list repeats
# the sat_list order for each observation, so start at CWCA's position and step by
# the number of models (replaces the hard-coded slice [0::5]).
values_CWCA = values_all[sat_list.index('CWCA')::len(sat_list)]
values_CWCA

In [None]:
# Value estimates of the CWRCA model across all observations. The flat list repeats
# the sat_list order for each observation, so start at CWRCA's position and step by
# the number of models (replaces the hard-coded slice [4::5]).
values_CWRCA = values_all[sat_list.index('CWRCA')::len(sat_list)]
values_CWRCA

In [None]:
# Display every action_net output (flat list: one entry per observation/model pair,
# models in sat_list order within each observation).
logits_all

In [None]:
# Action logits of the CWCA model across all observations. The flat list repeats the
# sat_list order for each observation, so start at CWCA's position and step by the
# number of models (replaces the hard-coded slice [0::5]).
logits_CWCA = logits_all[sat_list.index('CWCA')::len(sat_list)]
logits_CWCA

In [None]:
# Action logits of the CWRCA model across all observations. The flat list repeats the
# sat_list order for each observation, so start at CWRCA's position and step by the
# number of models (replaces the hard-coded slice [4::5]).
logits_CWRCA = logits_all[sat_list.index('CWRCA')::len(sat_list)]
logits_CWRCA

In [None]:
# Display every selected action (flat list: one entry per observation/model pair).
# NOTE(review): this cell repeats an earlier cell of this section verbatim — confirm
# whether it was meant for a different run (e.g. another layer or setting).
actions_all

In [None]:
# Actions of the CWCA model across all observations. The flat list repeats the
# sat_list order for each observation, so start at CWCA's position and step by the
# number of models (replaces the hard-coded slice [0::5]).
actions_CWCA = actions_all[sat_list.index('CWCA')::len(sat_list)]
actions_CWCA

In [None]:
# Actions of the CWRCA model across all observations. The flat list repeats the
# sat_list order for each observation, so start at CWRCA's position and step by the
# number of models (replaces the hard-coded slice [4::5]).
actions_CWRCA = actions_all[sat_list.index('CWRCA')::len(sat_list)]
actions_CWRCA

In [None]:
# Display every value_net output (flat list: one entry per observation/model pair).
# NOTE(review): this cell repeats an earlier cell of this section verbatim — confirm
# whether it was meant for a different run (e.g. another layer or setting).
values_all

In [None]:
# Value estimates of the CWCA model across all observations. The flat list repeats
# the sat_list order for each observation, so start at CWCA's position and step by
# the number of models (replaces the hard-coded slice [0::5]).
values_CWCA = values_all[sat_list.index('CWCA')::len(sat_list)]
values_CWCA

In [None]:
# Value estimates of the CWRCA model across all observations. The flat list repeats
# the sat_list order for each observation, so start at CWRCA's position and step by
# the number of models (replaces the hard-coded slice [4::5]).
values_CWRCA = values_all[sat_list.index('CWRCA')::len(sat_list)]
values_CWRCA

In [None]:
# Display every action_net output (flat list: one entry per observation/model pair).
# NOTE(review): this cell repeats an earlier cell of this section verbatim — confirm
# whether it was meant for a different run (e.g. another layer or setting).
logits_all

In [None]:
# Action logits of the CWCA model across all observations. The flat list repeats the
# sat_list order for each observation, so start at CWCA's position and step by the
# number of models (replaces the hard-coded slice [0::5]).
logits_CWCA = logits_all[sat_list.index('CWCA')::len(sat_list)]
logits_CWCA

In [None]:
# Action logits of the CWRCA model across all observations. The flat list repeats the
# sat_list order for each observation, so start at CWRCA's position and step by the
# number of models (replaces the hard-coded slice [4::5]).
logits_CWRCA = logits_all[sat_list.index('CWRCA')::len(sat_list)]
logits_CWRCA