# CKA
Compute centered kernel alignment (CKA) between the layer activations of two models, then plot and save the resulting CKA matrix.


In [4]:
import numpy as np
import torch
import sys
import os

from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from util_cka import cka, gram_linear

sys.path.append('../')
from sample_batch_data import get_data_info, get_batch
from signal_propagation import get_activation
from set_config import generate_variant

## Functions for CKA computation

In [5]:
def _select_token(activation, idx):
    """Slice out token ``idx``: dim 2 for 3-D input, dims 2 and 3 for 4-D input."""
    if activation.ndim == 3:
        return activation[:, :, idx]
    if activation.ndim == 4:
        return activation[:, :, idx, idx]
    # Any other rank is passed through untouched, as before.
    return activation


def _as_numpy(activation):
    """Return ``activation`` as a NumPy array, moving torch tensors to host memory first."""
    if isinstance(activation, torch.Tensor):
        # .cpu() makes this work for activations captured on a CUDA device.
        return activation.detach().cpu().numpy()
    return np.asarray(activation)


def compute_cka(activation_1, activation_2, reward_state_action, timestep=-1):
    """Compute debiased linear CKA for the return-to-go, state, or action token.

    The model input is an interleaved sequence
    ``[..., return-to-go, state, action]``, i.e. three tokens per timestep, so
    the requested token sits at ``3 * timestep + offset`` with offset 0, 1, 2
    for return-to-go, state, action. With the default ``timestep=-1`` this
    selects indices -3, -2 and -1 respectively.

    Args:
        activation_1 (torch.Tensor or np.ndarray): Activation of the first
            model. The token axis is dim 2 for 3-D input; 4-D input is indexed
            at the same token on dims 2 and 3.
        activation_2 (torch.Tensor or np.ndarray): Activation of the second
            model, indexed the same way.
        reward_state_action (str): "reward" (return-to-go), "state", or "action".
        timestep (int, optional): Timestep whose token is used for the CKA
            computation. Negative values count from the end. Defaults to -1.

    Returns:
        Scalar CKA value as returned by ``cka``.

    Raises:
        ValueError: If ``reward_state_action`` is not one of the three
            supported names.
    """
    token_offset = {"reward": 0, "state": 1, "action": 2}
    if reward_state_action not in token_offset:
        raise ValueError(
            "Specify either 'reward', 'state', or 'action'; "
            f"got {reward_state_action!r}."
        )
    idx = 3 * timestep + token_offset[reward_state_action]

    activation_1 = _select_token(activation_1, idx)
    activation_2 = _select_token(activation_2, idx)

    return cka(
        gram_linear(_as_numpy(activation_1)),
        gram_linear(_as_numpy(activation_2)),
        debiased=True,
    )


def plot_cka(
    path_to_save_figure,
    cka_matrix,
    reward_state_action,
    model1,
    model2,
    env_name,
    dataset_name,
    seed,
    epoch1,
    epoch2,
    block=False,
):
    """Plot a CKA heatmap, save it as a PDF, and show it.

    Each call draws on its own figure so that repeated calls do not stack
    heatmaps and colorbars on top of each other.

    Args:
        path_to_save_figure (str): Directory in which the PDF is written.
        cka_matrix (np.ndarray[dim, dim]): CKA values; rows belong to model1
            (y axis), columns to model2 (x axis).
        reward_state_action (str): "reward" (return-to-go), "state", or "action".
        model1 (str): 'gpt2', 'igpt', or 'dt'. 'dt' is labelled "random init".
        model2 (str): 'gpt2', 'igpt', or 'dt'. 'dt' is labelled "random init".
        env_name (str): 'hopper', 'halfcheetah', or 'walker2d'.
        dataset_name (str): 'medium'.
        seed (int): 666.
        epoch1 (int): 0 or 40.
        epoch2 (int): 0 or 40.
        block (bool, optional): If the activation is that of Transformer block, set this True.
            Changes the axis labels and the file name prefix. Defaults to False.
    """
    # Style must be set before the figure is created for it to take effect.
    sns.set_style("ticks")
    sns.set_context("paper", 1.5, {"lines.linewidth": 2})

    fig, ax = plt.subplots()
    sns.heatmap(cka_matrix, vmin=0, vmax=1, square=True, ax=ax)
    # Put layer 0 at the bottom so depth increases upwards.
    ax.invert_yaxis()

    label1 = "random init" if model1 == "dt" else model1
    label2 = "random init" if model2 == "dt" else model2
    unit = "Block" if block else "Layers"
    ax.set_xlabel(f"{label2.upper()} {unit}")
    ax.set_ylabel(f"{label1.upper()} {unit}")
    fig.tight_layout()

    prefix = "cka_block" if block else "cka"
    fig.savefig(
        f"{path_to_save_figure}/{prefix}_{epoch1}_{epoch2}_{model1}{model2}_{env_name}_{dataset_name}_{seed}_{reward_state_action}.pdf"
    )
    plt.show()
    # In non-interactive runs nothing else releases the figure; in interactive
    # mode the window is left open for the user.
    if not plt.isinteractive():
        plt.close(fig)

## Function for running CKA computation

In [6]:
def run_cka(
    path_to_dataset,
    path_to_model_checkpoint,
    path_to_save_cka,
    path_to_save_figure,
    seed=666,
    model1='gpt2',
    model2='gpt2',
    epoch1=40,
    epoch2=40,
    env_name_list=None,
    block=False,
    no_context=False,
    device="cpu"
    ):
    """Compute CKA and save it as array and fig.

    For every environment, the activations of both models are collected on the
    same batch, and one CKA matrix per input token type ('action', 'state',
    'reward') is saved as ``.npy`` and plotted via ``plot_cka``.

    Args:
        path_to_dataset (str): Dataset location forwarded to ``get_batch``.
        path_to_model_checkpoint (str): Path to load model checkpoint.
        path_to_save_cka (str): Path to save CKA matrix as np.array.
        path_to_save_figure (str): Path to save figure of CKA heatmap.
        seed (int, optional): Random seed. Defaults to 666.
        model1 (str, optional): 'gpt2', 'igpt', or 'dt'. Defaults to 'gpt2'.
        model2 (str, optional): 'gpt2', 'igpt', or 'dt'. Defaults to 'gpt2'.
        epoch1 (int, optional): 0 or 40. Defaults to 40.
        epoch2 (int, optional): 0 or 40. Defaults to 40.
        env_name_list (list, optional): environment name list.
            Defaults to ['hopper', 'halfcheetah', 'walker2d'] when None.
        block (bool, optional): If True, compute CKA for transformer block. Defaults to False.
        no_context (bool, optional): If True, compute CKA of K=1. Defaults to False.
        device (str): cuda or cpu

    Returns:
        None. Results are written to ``path_to_save_cka`` and ``path_to_save_figure``.
    """
    # A None sentinel avoids sharing one mutable default list across calls.
    if env_name_list is None:
        env_name_list = ['hopper', 'halfcheetah', 'walker2d']

    os.makedirs(path_to_save_cka, exist_ok=True)
    os.makedirs(path_to_save_figure, exist_ok=True)

    dataset_name = 'medium'
    device = torch.device(device)
    reward_state_action_list = ['action', 'state', 'reward']
    file_prefix = 'cka_block' if block else 'cka'

    def build_variant(epoch, model_name, env_name):
        """Build the config for one model, pointing at the K=1 checkpoint if requested."""
        variant = generate_variant(epoch, path_to_model_checkpoint, model_name, env_name, seed, dataset_name)
        if no_context:
            # Epoch 0 means "no checkpoint": the model keeps its initial weights.
            variant['load_checkpoint'] = False if epoch==0 else f'{path_to_model_checkpoint}/{model_name}_medium_{env_name}_{seed}_K1/model_{epoch}.pt'
        return variant

    def is_compared(layer_name):
        """In block mode keep only block outputs (e.g. transformer.h[0].mlp.dropout)."""
        return (not block) or (('dropout' in layer_name) and ('mlp' in layer_name))

    for env_name in env_name_list:

        torch.manual_seed(seed)

        # The batch is sampled once, from model1's config, and fed to both models.
        variant_1 = build_variant(epoch1, model1, env_name)
        state_dim, act_dim, max_ep_len, scale = get_data_info(variant_1)
        states, actions, rewards, dones, rtg, timesteps, attention_mask = get_batch(variant_1, state_dim, act_dim, max_ep_len, scale, device, path_to_dataset)

        # Get activations of model1 and model2, respectively.
        activation_1 = get_activation(variant_1, state_dim, act_dim, max_ep_len, states, actions, rewards, rtg, timesteps, attention_mask, device)
        variant_2 = build_variant(epoch2, model2, env_name)
        activation_2 = get_activation(variant_2, state_dim, act_dim, max_ep_len, states, actions, rewards, rtg, timesteps, attention_mask, device)

        layers_1 = [act for key, act in activation_1.items() if is_compared(key)]
        layers_2 = [act for key, act in activation_2.items() if is_compared(key)]

        for reward_state_action in reward_state_action_list:
            # Rows follow model1's layers, columns follow model2's layers.
            cka_matrix = np.array([
                [compute_cka(act_1, act_2, reward_state_action, timestep=-1) for act_2 in layers_2]
                for act_1 in tqdm(layers_1)
            ])

            np.save(f'{path_to_save_cka}/{file_prefix}_{epoch1}_{epoch2}_{model1}{model2}_{env_name}_{dataset_name}_{seed}_{reward_state_action}.npy', cka_matrix)
            plot_cka(path_to_save_figure, cka_matrix, reward_state_action, model1, model2, env_name, dataset_name, seed, epoch1, epoch2, block)

In [12]:
# Log this notebook session in to the Hugging Face Hub; the call renders an
# interactive login widget as the cell output.
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [20]:
# Load the pickled trajectories for one environment/dataset pair.
# NOTE(review): the path is an absolute, user-specific location; it has to be
# adjusted on any other machine.
# NOTE(security): pickle.load can execute arbitrary code while deserialising,
# so only open dataset files that come from a trusted source.
import pickle
env_name = 'hopper'
dataset='expert-v2'
with open(f'/home/alexrgilbert/repos/pre-training-different-modality-offline-rl/can-wikipedia-help-offline-rl/data/{env_name}-{dataset}.pkl','rb') as f:
    trajectories = pickle.load(f)

In [21]:
import numpy as np

# Gather, per trajectory, its observations, its length and its summed reward.
# (A "delayed" reward mode that would move all reward to the final step of a
# trajectory is disabled here.)
states = [traj["observations"] for traj in trajectories]
traj_lens = np.array([len(observations) for observations in states])
returns = np.array([traj["rewards"].sum() for traj in trajectories])

# Per-dimension statistics used for input normalization; the small epsilon
# keeps the standard deviation strictly positive.
states = np.concatenate(states, axis=0)
state_mean = np.mean(states, axis=0)
state_std = np.std(states, axis=0) + 1e-6

num_timesteps = sum(traj_lens)

banner = "=" * 50
print(banner)
print(f"Starting new experiment: {env_name} {dataset}")
print(f"{len(traj_lens)} trajectories, {num_timesteps} timesteps found")
print(f"Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}")
print(f"Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}")
print(banner)

Starting new experiment: hopper expert-v2
1027 trajectories, 999494 timesteps found
Average return: 3511.36, std: 328.59
Max return: 3759.08, min: 1645.28


In [34]:
# Scratch probe for an environment time limit. `e` is not defined in any
# visible cell, and the recorded run of this cell raised AttributeError.
e.time_limit

AttributeError: 'OfflineHopperEnv' object has no attribute 'time_limit'

In [31]:
# Scratch inspection of `max_len`, which is defined outside the visible cells.
max_len

240

In [30]:
# Create a gym environment for ad-hoc inspection.
import gym
# NOTE(review): d4rl is not referenced by name here; presumably it is imported
# for its side effect of registering offline-RL environments with gym — confirm.
import d4rl
env = gym.make('CartPole-v1')

In [27]:
# Show the action space of the current `env`.
# NOTE(review): this cell's execution count (27) is lower than that of the cell
# creating the CartPole env (30), so the recorded output presumably belongs to
# a different environment — confirm before relying on it.
env.action_space

Box(-1.0, 1.0, (2,), float32)

In [4]:
import sys
import torch
from transformers import TransfoXLConfig, TransfoXLModel
from tqdm._tqdm_notebook import tqdm
from pprint import pformat
from collections import defaultdict

# Build two Transformer-XL models: one constructed from a default config
# (random weights) and one loaded from the "transfo-xl-wt103" checkpoint.
config = TransfoXLConfig()
# NOTE(review): the name `random` would shadow the stdlib module of the same
# name if that module is ever imported into this namespace.
random = TransfoXLModel(config)
pretrained = TransfoXLModel.from_pretrained("transfo-xl-wt103")


def get_activation(model):
    """Unfinished draft of an activation extractor taking a ready-made model.

    Registers forward hooks that store each hooked module's detached output in
    a local dict, runs forward passes, and builds an ordered copy of the
    captured activations.

    NOTE(review): this draft cannot run as written and is replaced by the
    later ``get_activation(variant, ...)`` definition of the same name:
      * the forward calls read ``states``, ``actions``, ``rewards``, ``rtg``,
        ``timesteps`` and ``attention_mask``, none of which are parameters;
      * ``parameters`` is never created inside the function;
      * the first half walks ``model.layers`` while the second half walks
        ``model.transformer.h``, i.e. two different model layouts;
      * nothing is returned, so ``activation_sorted`` is discarded.

    Args:
        model: Module to instrument; switched to eval mode first.
    """

    model.eval()

    # Filled by the hooks: "<layer index>.<submodule name>" -> detached output.
    activation = {}

    def extract_activation(name):
        def hook(model, input, output):
            activation[name] = output.detach()

        return hook

    def apply_hook(module,layer_id,name):
        module.register_forward_hook(extract_activation(f"{layer_id}.{name}"))

    for i,layer in enumerate(model.layers):
        apply_hook(layer.dec_attn.layer_norm,i,'dec_attn.layer_norm')
        apply_hook(layer.dec_attn.o_net,i,'dec_attn.o_net')
        apply_hook(layer.dec_attn.qkv_net,i,'dec_attn.qkv_net')
        apply_hook(layer.dec_attn.r_net,i,'dec_attn.r_net')
        apply_hook(layer.pos_ff.CoreNet[0],i,'pos_ff.CoreNet.LinearExpand')
        apply_hook(layer.pos_ff.CoreNet[1],i,'pos_ff.CoreNet.ReLU')
        apply_hook(layer.pos_ff.CoreNet[2],i,'pos_ff.CoreNet.DropoutExpand')
        apply_hook(layer.pos_ff.CoreNet[3],i,'pos_ff.CoreNet.LinearBottleneck')
        apply_hook(layer.pos_ff.CoreNet[4],i,'pos_ff.CoreNet.DropoutBottleneck')

        # NOTE(review): this forward call is indented into the loop, so it runs
        # once per layer; it presumably belongs after the loop — confirm.
        _, _, _, _ = model.forward(
        states,
        actions,
        rewards,
        rtg[:, :-1],
        timesteps,
        attention_mask=attention_mask,
    )


    # Debug dump of each module's directly-owned parameters and their shapes.
    for mn,m in model.named_modules():
        for pn,p in m.named_parameters(recurse=False):
            
            parameters[mn].append(f'{pn} ({p.shape})')
            assert pretrained
    print(pformat(parameters))

    # NOTE(review): only the ln_1 hook is inside this loop. Every registration
    # below is dedented and therefore applies to the last block_id only.
    for block_id in range(len(model.transformer.h)):
        model.transformer.h[block_id].ln_1.register_forward_hook(
            extract_activation(f"{block_id}.ln_1")
        )
    model.transformer.h[block_id].attn.c_attn.register_forward_hook(
        extract_activation(f"{block_id}.attn.c_attn")
    )
    model.transformer.h[block_id].attn.c_proj.register_forward_hook(
        extract_activation(f"{block_id}.attn.c_proj")
    )
    model.transformer.h[block_id].attn.attn_dropout.register_forward_hook(
        extract_activation(f"{block_id}.attn.attn_dropout")
    )
    model.transformer.h[block_id].attn.resid_dropout.register_forward_hook(
        extract_activation(f"{block_id}.attn.resid_dropout")
    )
    model.transformer.h[block_id].ln_2.register_forward_hook(
        extract_activation(f"{block_id}.ln_2")
    )
    model.transformer.h[block_id].mlp.c_fc.register_forward_hook(
        extract_activation(f"{block_id}.mlp.c_fc")
    )
    model.transformer.h[block_id].mlp.c_proj.register_forward_hook(
        extract_activation(f"{block_id}.mlp.c_proj")
    )
    # The bare except silently skips the hook if registering it on mlp.act fails.
    try:
        model.transformer.h[block_id].mlp.act.register_forward_hook(
            extract_activation(f"{block_id}.mlp.act")
        )
    except:
        pass
    model.transformer.h[block_id].mlp.dropout.register_forward_hook(
        extract_activation(f"{block_id}.mlp.dropout")
    )

    _, _, _, _ = model.forward(
        states,
        actions,
        rewards,
        rtg[:, :-1],
        timesteps,
        attention_mask=attention_mask,
    )

    # Re-key the captured activations block by block in a fixed submodule
    # order; entries that were never captured are skipped.
    activation_sorted = {}
    block_name_list = [
        "ln_1",
        "attn.c_attn",
        "attn.c_proj",
        "attn.resid_dropout",
        "ln_2",
        "mlp.c_fc",
        "mlp.c_proj",
        "mlp.act",
        "mlp.dropout",
    ]
    for block_id in range(len(model.transformer.h)):
        for block_name in block_name_list:
            try:
                activation_sorted[f"{block_id}.{block_name}"] = activation[
                    f"{block_id}.{block_name}"
                ]
            except:
                pass

# parameters = defaultdict(list)
# for mn,m in pretrained.named_modules():
#     for pn,p in m.named_parameters(recurse=False):
#         parameters[mn].append(f'{pn} ({p.shape})')
#         assert pretrained
# print(pformat(parameters))




def get_activation(
    variant,
    state_dim,
    act_dim,
    max_ep_len,
    states,
    actions,
    rewards,
    rtg,
    timesteps,
    attention_mask,
    device,
):
    """Get activation of a model.

    Builds a ``DecisionTransformer`` from ``variant``, optionally loads a
    checkpoint, runs one forward pass on the given batch, and returns the
    outputs of selected submodules of every transformer block.

    NOTE(review): ``DecisionTransformer`` is not imported in the visible part
    of this file, and this definition replaces the ``get_activation`` imported
    from ``signal_propagation`` — confirm both are intended.

    Args:
        variant (dict): arguments. Reads "K", "embed_dim", "n_layer",
            "n_head", "activation_function", "dropout" and "load_checkpoint".
        state_dim (int): dimension of state.
        act_dim (int): dimension of action.
        max_ep_len (int): forwarded to the model as ``max_ep_len``.
        states (torch.Tensor): a batch of states.
        actions (torch.Tensor): a batch of actions.
        rewards (torch.Tensor): a batch of rewards.
        rtg (torch.Tensor): a batch of return-to-go; its last step along
            dim 1 is dropped before the forward pass.
        timesteps (torch.Tensor): a batch of timesteps.
        attention_mask (torch.Tensor): Mask for causal Transformer.
        device (torch.device): torch.device("cuda" if torch.cuda.is_available() else "cpu").

    Returns:
        dict: {layer_name: activation, ...} keyed "<block index>.<submodule>",
        ordered block by block.
    """
    # Fixed seed so that a model without a checkpoint gets the same initial weights every call.
    torch.manual_seed(0)
    model = DecisionTransformer(
        args=variant,
        state_dim=state_dim,
        act_dim=act_dim,
        max_length=variant["K"],
        max_ep_len=max_ep_len,
        hidden_size=variant["embed_dim"],
        n_layer=variant["n_layer"],
        n_head=variant["n_head"],
        n_inner=4 * variant["embed_dim"],
        activation_function=variant["activation_function"],
        n_positions=1024,
        resid_pdrop=variant["dropout"],
        attn_pdrop=0.1,
    ).to(device)
    if variant["load_checkpoint"]:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        state_dict = torch.load(variant["load_checkpoint"], map_location=device)
        model.load_state_dict(state_dict)
        print(f"Loaded from {variant['load_checkpoint']}")

    model.eval()

    # Filled by the hooks: "<block index>.<submodule>" -> detached output.
    activation = {}

    def extract_activation(name):
        def hook(module, inputs, output):
            activation[name] = output.detach()

        return hook

    hooked_names = (
        "ln_1",
        "attn.c_attn",
        "attn.c_proj",
        "attn.attn_dropout",
        "attn.resid_dropout",
        "ln_2",
        "mlp.c_fc",
        "mlp.c_proj",
        "mlp.act",
        "mlp.dropout",
    )
    # Not every MLP exposes its activation as a hookable module.
    optional_names = {"mlp.act"}

    for block_id, transformer_block in enumerate(model.transformer.h):
        for name in hooked_names:
            try:
                submodule = transformer_block
                for attribute in name.split("."):
                    submodule = getattr(submodule, attribute)
                submodule.register_forward_hook(
                    extract_activation(f"{block_id}.{name}")
                )
            except AttributeError:
                # Only the optional hook may be missing; anything else is a real error.
                if name not in optional_names:
                    raise

    # The hooks detach their outputs, so no autograd graph is needed.
    with torch.no_grad():
        model.forward(
            states,
            actions,
            rewards,
            rtg[:, :-1],
            timesteps,
            attention_mask=attention_mask,
        )

    # "attn.attn_dropout" is hooked above but, as before, not part of the result.
    block_name_list = [
        "ln_1",
        "attn.c_attn",
        "attn.c_proj",
        "attn.resid_dropout",
        "ln_2",
        "mlp.c_fc",
        "mlp.c_proj",
        "mlp.act",
        "mlp.dropout",
    ]
    activation_sorted = {}
    for block_id in range(len(model.transformer.h)):
        for block_name in block_name_list:
            key = f"{block_id}.{block_name}"
            # Skip submodules that were not hooked or never ran.
            if key in activation:
                activation_sorted[key] = activation[key]

    return activation_sorted

Some weights of the model checkpoint at transfo-xl-wt103 were not used when initializing TransfoXLModel: ['crit.out_projs.0', 'crit.out_layers.0.weight', 'crit.out_layers.2.weight', 'crit.out_projs.2', 'crit.out_layers.3.bias', 'crit.out_layers.0.bias', 'crit.out_projs.1', 'crit.cluster_weight', 'crit.out_layers.3.weight', 'crit.out_projs.3', 'crit.out_layers.2.bias', 'crit.out_layers.1.weight', 'crit.cluster_bias', 'crit.out_layers.1.bias']
- This IS expected if you are initializing TransfoXLModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TransfoXLModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


KeyboardInterrupt: 

In [5]:
from transformers import TransfoXLConfig, TransfoXLModel

# Initializing a Transformer XL configuration
configuration = TransfoXLConfig()

# Initializing a model (with random weights) from the configuration
model = TransfoXLModel(configuration)

# Map every module of `pretrained` to the parameters it owns directly
# (recurse=False keeps children's parameters out of the parent's entry).
# NOTE(review): the listing walks `pretrained`, not the `model` built above —
# confirm which one is meant.
parameters = defaultdict(list)
for mn, m in pretrained.named_modules():
    for pn, p in m.named_parameters(recurse=False):
        parameters[mn].append(f'{pn} ({p.shape})')
print(pformat(parameters))

defaultdict(<class 'list'>,
            {'layers.0.dec_attn': ['r_r_bias (torch.Size([16, 64]))',
                                   'r_w_bias (torch.Size([16, 64]))'],
             'layers.0.dec_attn.layer_norm': ['weight (torch.Size([1024]))',
                                              'bias (torch.Size([1024]))'],
             'layers.0.dec_attn.o_net': ['weight (torch.Size([1024, 1024]))'],
             'layers.0.dec_attn.qkv_net': ['weight (torch.Size([3072, 1024]))'],
             'layers.0.dec_attn.r_net': ['weight (torch.Size([1024, 1024]))'],
             'layers.0.pos_ff.CoreNet.0': ['weight (torch.Size([4096, 1024]))',
                                           'bias (torch.Size([4096]))'],
             'layers.0.pos_ff.CoreNet.3': ['weight (torch.Size([1024, 4096]))',
                                           'bias (torch.Size([1024]))'],
             'layers.0.pos_ff.layer_norm': ['weight (torch.Size([1024]))',
                                            'bias (torch.S

In [22]:
# Display the first layer of `model` to inspect its submodule structure.
model.layers[0]

RelPartialLearnableDecoderLayer(
  (dec_attn): RelPartialLearnableMultiHeadAttn(
    (qkv_net): Linear(in_features=1024, out_features=3072, bias=False)
    (drop): Dropout(p=0.1, inplace=False)
    (dropatt): Dropout(p=0.0, inplace=False)
    (o_net): Linear(in_features=1024, out_features=1024, bias=False)
    (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (r_net): Linear(in_features=1024, out_features=1024, bias=False)
  )
  (pos_ff): PositionwiseFF(
    (CoreNet): Sequential(
      (0): Linear(in_features=1024, out_features=4096, bias=True)
      (1): ReLU(inplace=True)
      (2): Dropout(p=0.1, inplace=False)
      (3): Linear(in_features=4096, out_features=1024, bias=True)
      (4): Dropout(p=0.1, inplace=False)
    )
    (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
)

In [15]:
# Print the qualified name of every submodule of the first layer; the layer
# itself comes first, under the empty name.
first_layer = model.layers[0]
for module_name, _module in first_layer.named_modules():
    print(module_name)


dec_attn
dec_attn.qkv_net
dec_attn.drop
dec_attn.dropatt
dec_attn.o_net
dec_attn.layer_norm
dec_attn.r_net
pos_ff
pos_ff.CoreNet
pos_ff.CoreNet.0
pos_ff.CoreNet.1
pos_ff.CoreNet.2
pos_ff.CoreNet.3
pos_ff.CoreNet.4
pos_ff.layer_norm


## Run

In [7]:
# Placeholder locations: replace these with real paths before running.
path_to_dataset = 'path_to_dataset'
path_to_model_checkpoint =  'path_to_model_checkpoint'
path_to_save_cka = 'path_to_save_cka'
path_to_save_figure = 'path_to_save_figure'

# run_cka writes the CKA matrices and figures to disk and returns nothing,
# so its result is not bound to a name.
run_cka(
    path_to_dataset,
    path_to_model_checkpoint,
    path_to_save_cka,
    path_to_save_figure,
    seed=666,
    model1='gpt2',
    model2='gpt2',
    epoch1=40,
    epoch2=40,
    env_name_list=['hopper', 'halfcheetah', 'walker2d'],
    block=False,
    no_context=False,
    device="cpu"
    )

  spec.namespace = self._ns


Exception: 
Missing path to your environment variable. 
Current values LD_LIBRARY_PATH=/home/alexrgilbert/.conda/envs/cs330/lib/python3.9/site-packages/cv2/../../lib64:
Please add following line to .bashrc:
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/alexrgilbert/.mujoco/mujoco210/bin