In [23]:
import itertools

import torch
import transformers
import datasets
import promptsource.templates
import numpy.typing as npt
import numpy as np
from tqdm import tqdm

from typing import Dict, Any, List, Union


# Inference only: turn autograd off globally so the forward passes below do
# not record graphs. The trailing ";" suppresses the notebook's display of the
# returned context-manager object.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7f2c341c3670>

In [2]:
# --- Compute -----------------------------------------------------------------
device = "cuda:0"

# --- Model and token selection -------------------------------------------------
model_name = "EleutherAI/pythia-70m-deduped"
# When True, the extraction loop drops position 0 of every hidden-state tensor.
remove_sos_token = False
# Sequence index of the token whose representation is analysed (-1 = last).
token_pos = -1

# --- Dataset and promptsource template -----------------------------------------
dataset_name, dataset_config = "winogrande", "winogrande_xs"
dataset_split = "validation"
prompt_template = "fill in the blank"

def create_prompt(example: Dict[str, Any]) -> str:
    """Render one dataset row with the module-level promptsource ``template``.

    ``template`` is looked up when the function is called, so it only has to
    exist by the first call (it is assigned later in this cell). Only the first
    element of ``template.apply(example)`` is returned. This is presumably the
    rendered input prompt, with any target text dropped; confirm against the
    promptsource ``Template.apply`` docs.
    """
    rendered = template.apply(example)
    return rendered[0]


# Look up the promptsource templates for this dataset/config and pick one by name.
templates = promptsource.templates.DatasetTemplates(
    dataset_name, dataset_config
)
template = templates[prompt_template]
# Without `split=`, load_dataset returns a DatasetDict (asserted below);
# keep only the split under study.
dataset = datasets.load_dataset(dataset_name, dataset_config)
assert isinstance(dataset, datasets.dataset_dict.DatasetDict)
dataset = dataset[dataset_split]
# The extraction loop maps this over the dataset rows to build prompt strings.
prompt_creator = create_prompt

Found cached dataset winogrande (/root/.cache/huggingface/datasets/winogrande/winogrande_xs/1.1.0/a826c3d3506aefe0e9e9390dcb53271070536586bab95849876b2c1743df56e2)


  0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
# Load tokenizer and causal LM inside a `torch.device` context, which makes
# `device` the default for tensors created without an explicit device.
with torch.device(device):
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name,
        # "auto" presumably keeps the dtype stored with the checkpoint; the
        # hidden states printed later in this notebook are float16.
        # NOTE(review): float16 overflows above ~65504, so norms or Gram
        # matrices computed from these representations can hit inf. Consider
        # torch.float (below) while debugging the inf/NaN errors.
        torch_dtype="auto",
        # torch_dtype=torch.float,
    )


In [18]:
with torch.device(device):
    # One entry per prompt: the tuple of hidden-state tensors returned by the model.
    all_representations = []
    prompts = [prompt_creator(example) for example in dataset]  # type:ignore
    for prompt in tqdm(prompts):
        toks = tokenizer(prompt, return_tensors="pt")
        input_ids = toks["input_ids"].to(torch.device(device))  # type:ignore
        # hidden_states is a tuple whose elements have shape (1, n_tokens, dim);
        # the first element is presumably the embedding output.
        out = model(input_ids=input_ids, output_hidden_states=True).hidden_states
        assert isinstance(out[0], torch.Tensor)
        if remove_sos_token:
            # Some models prepend a start-of-sentence token whose
            # representation is usually irrelevant, so drop position 0.
            out = tuple(hidden[:, 1:, :] for hidden in out)
        # Store on the CPU so device memory is not held for every prompt.
        out = tuple(hidden.to("cpu") for hidden in out)
        all_representations.append(out)

100%|██████████| 1267/1267 [02:38<00:00,  7.98it/s]


In [64]:
# `out` is left over from the extraction loop: the CPU hidden-state tuple of
# the last prompt processed.
out

(tensor([[[-2.7103e-03, -6.1737e-02, -6.5063e-02,  ..., -6.8963e-05,
            3.5461e-02,  3.5919e-02],
          [ 5.1918e-03, -1.5465e-02, -4.8180e-03,  ..., -1.8982e-02,
            3.5309e-02, -2.7817e-02],
          [ 3.0766e-03,  1.0941e-02,  1.9121e-04,  ...,  2.8191e-03,
            6.4453e-02, -1.6129e-02],
          ...,
          [ 7.9155e-04, -6.3705e-04, -1.1139e-03,  ..., -3.6144e-03,
            8.0505e-02, -1.0635e-02],
          [-1.6052e-02, -1.7929e-02, -4.4281e-02,  ..., -2.3865e-02,
            4.3915e-02, -1.0042e-03],
          [-1.2383e-02, -1.4671e-02,  1.4946e-02,  ..., -2.0737e-02,
           -5.3101e-02, -1.8280e-02]]], dtype=torch.float16),
 tensor([[[ 1.0469,  0.3872, -0.1362,  ...,  0.4907,  0.0220, -0.9336],
          [ 0.5278, -0.2140,  0.2418,  ..., -0.0556, -0.0890, -0.5308],
          [-0.1445, -0.1400, -0.1644,  ..., -0.0718,  0.0792, -0.3521],
          ...,
          [-0.2489,  0.1942,  0.0184,  ..., -0.1066,  0.0500,  0.0839],
          [ 0.04

In [6]:
# Hidden-state entry 3 of prompt 0, all token positions; the trailing [0]
# drops the batch dimension of size 1.
all_representations[0][3][0]


tensor([[ 0.9805, -0.0605,  0.5303,  ...,  2.1523,  0.6240, -0.3872],
        [ 0.1099, -0.4536,  0.2367,  ...,  1.0400,  1.4150, -0.5562],
        [-0.0522, -0.3398,  0.1909,  ...,  0.7363,  0.8804, -0.7617],
        ...,
        [-0.0394,  0.1534, -0.4124,  ...,  0.6953, -0.6968,  0.2260],
        [ 0.2607, -0.6787, -0.0059,  ...,  0.2876, -0.5625, -0.0554],
        [ 0.6807, -0.7148,  0.9131,  ..., -0.6611,  0.4644,  0.1693]],
       dtype=torch.float16)

In [7]:
# Same hidden-state entry for prompt 9.
# NOTE(review): the leading rows equal those of prompt 0 and only the trailing
# rows differ. This is presumably a shared prompt prefix, whose states are
# identical under causal attention; confirm against `prompts`.
all_representations[9][3][0]


tensor([[ 0.9805, -0.0605,  0.5303,  ...,  2.1523,  0.6240, -0.3872],
        [ 0.1099, -0.4536,  0.2367,  ...,  1.0400,  1.4150, -0.5562],
        [-0.0522, -0.3398,  0.1909,  ...,  0.7363,  0.8804, -0.7617],
        ...,
        [-0.2395,  0.3196, -0.4863,  ...,  0.6670, -0.3704, -0.0927],
        [ 0.0042, -0.5381, -0.1251,  ...,  0.3706, -0.2317, -0.2263],
        [ 0.4961, -0.6665,  0.7998,  ..., -0.6533,  0.7349, -0.0464]],
       dtype=torch.float16)

In [8]:
# Compare the last-token representation of hidden-state entry 3 for prompts
# 0 and 3, taken directly from the per-prompt tuples.
r1 = all_representations[0][3]
r2 = all_representations[3][3]

print(r1.size(), r2.size())
last_token_diff = r1[0, -1, :] - r2[0, -1, :]
# Only a cell's final expression is displayed, so print the count explicitly
# (previously it was computed and silently discarded).
print("identical entries:", (last_token_diff == 0).sum().item())
last_token_diff

torch.Size([1, 41, 512]) torch.Size([1, 47, 512])


tensor([ 1.3770e-01, -5.1270e-02,  1.7139e-01, -4.3213e-02, -5.8105e-02,
         2.3975e-01,  1.0669e-01,  1.7334e-02, -7.8613e-02,  1.2988e-01,
        -8.5083e-02,  4.4006e-02,  1.9531e-01,  9.9487e-02, -1.6846e-01,
         4.0527e-02, -6.5918e-02, -2.0020e-02,  6.2256e-03, -2.5635e-02,
        -5.8472e-02,  3.7012e-01,  4.1504e-03, -7.8125e-03,  9.2285e-02,
        -4.9194e-02, -4.9438e-02,  1.2793e-01,  1.4807e-01, -1.3745e-01,
        -4.6875e-02, -1.4868e-01, -8.7280e-02, -3.2959e-03,  1.7261e-01,
        -2.1729e-02, -5.5664e-02,  2.8076e-02, -1.0620e-02,  1.5015e-02,
         1.0986e-01, -8.3984e-02,  3.2379e-02, -5.5298e-02, -1.3428e-02,
         2.4902e-02,  9.7168e-02, -5.4297e-01, -2.2754e-01, -1.0229e-01,
         6.8604e-02,  9.5947e-02, -1.8408e-01,  8.7646e-02,  3.3691e-02,
         1.2183e-01,  1.0938e-01,  6.8359e-02,  4.5410e-02,  4.1748e-02,
         1.8909e-01, -5.9082e-02, -2.3804e-02,  5.6152e-02, -1.1865e-01,
         2.4841e-02,  7.1533e-02, -6.5430e-02, -1.9

In [19]:
reps_per_input = all_representations
# For each prompt, stack its hidden-state tuple to
# (n_hidden_states, 1, n_tokens, dim) and keep position `token_pos`, giving
# (n_hidden_states, 1, dim). Concatenating along the prompt axis yields
# r with shape (n_hidden_states, n_prompts, dim).
r = torch.cat(
    tuple(
        torch.stack(layer_states, dim=0)[:, :, token_pos, :]
        for layer_states in reps_per_input
    ),
    dim=1,
)

In [20]:
# Cross-check on the stacked tensor: r[3, 0] - r[3, 3] should reproduce the
# raw comparison of hidden-state entry 3 for prompts 0 and 3 at `token_pos`.
r1 = r[3, 0]
r2 = r[3, 3]

print(r1.size(), r2.size())
last_token_diff = r1 - r2
# Only a cell's final expression is displayed, so print the count explicitly
# (previously it was computed and silently discarded).
print("identical entries:", (last_token_diff == 0).sum().item())
last_token_diff

torch.Size([512]) torch.Size([512])


tensor([ 1.3770e-01, -5.1270e-02,  1.7139e-01, -4.3213e-02, -5.8105e-02,
         2.3975e-01,  1.0669e-01,  1.7334e-02, -7.8613e-02,  1.2988e-01,
        -8.5083e-02,  4.4006e-02,  1.9531e-01,  9.9487e-02, -1.6846e-01,
         4.0527e-02, -6.5918e-02, -2.0020e-02,  6.2256e-03, -2.5635e-02,
        -5.8472e-02,  3.7012e-01,  4.1504e-03, -7.8125e-03,  9.2285e-02,
        -4.9194e-02, -4.9438e-02,  1.2793e-01,  1.4807e-01, -1.3745e-01,
        -4.6875e-02, -1.4868e-01, -8.7280e-02, -3.2959e-03,  1.7261e-01,
        -2.1729e-02, -5.5664e-02,  2.8076e-02, -1.0620e-02,  1.5015e-02,
         1.0986e-01, -8.3984e-02,  3.2379e-02, -5.5298e-02, -1.3428e-02,
         2.4902e-02,  9.7168e-02, -5.4297e-01, -2.2754e-01, -1.0229e-01,
         6.8604e-02,  9.5947e-02, -1.8408e-01,  8.7646e-02,  3.3691e-02,
         1.2183e-01,  1.0938e-01,  6.8359e-02,  4.5410e-02,  4.1748e-02,
         1.8909e-01, -5.9082e-02, -2.3804e-02,  5.6152e-02, -1.1865e-01,
         2.4841e-02,  7.1533e-02, -6.5430e-02, -1.9

In [11]:
# Sanity check of column-centering on hidden-state entry 3 (n_prompts x dim):
# after subtracting the column means, the column means should be ~0.
# NOTE(review): the ~1e-4 residuals shown are presumably float16 rounding.
r1 = r[3]
(r1 - r1.mean(dim=0, keepdim=True)).mean(dim=0)

tensor([-7.3254e-05, -2.4414e-04,  4.8828e-04, -2.4438e-05, -6.1035e-05,
         0.0000e+00, -2.4438e-05,  1.0985e-04, -9.7632e-05,  9.7632e-05,
         2.4438e-05, -1.5259e-05, -3.9053e-04,  4.2737e-05,  2.4414e-04,
        -3.4189e-04,  0.0000e+00,  2.6846e-04,  1.5259e-05, -1.2219e-05,
        -1.8299e-05,  0.0000e+00,  3.6597e-05, -6.8378e-04, -1.4651e-04,
         0.0000e+00,  0.0000e+00,  4.8816e-05,  1.4651e-04,  1.3423e-04,
        -4.8816e-05,  7.3254e-05, -2.4438e-05,  2.4438e-05, -1.2207e-04,
         2.4438e-05,  0.0000e+00,  1.2207e-04,  0.0000e+00,  1.2207e-04,
        -2.4438e-05,  1.4651e-04, -1.5497e-06, -6.1035e-05, -1.2219e-05,
        -4.8816e-05, -7.3254e-05, -1.7095e-04, -3.6597e-05,  9.7632e-05,
         1.2207e-04,  4.2737e-05,  1.8299e-05, -7.3254e-05,  0.0000e+00,
        -3.6597e-05,  2.9302e-04, -3.4189e-04,  4.8816e-05,  1.4651e-04,
        -9.7632e-05,  1.4651e-04,  6.1035e-05, -4.8816e-05,  0.0000e+00,
         2.4438e-05,  1.2207e-04,  2.4438e-05, -6.1

In [12]:
# Dry run of an all-pairs layer loop (presumably mirroring the main script):
# only prints the (layer of rep 1, layer of rep 2) index pairs.
n_layers1 = n_layers2 = r.size(0)
layer_pairs = itertools.product(range(n_layers1), range(n_layers2))
for rep1_layer_idx, rep2_layer_idx in tqdm(layer_pairs, total=n_layers1 * n_layers2):
    print(rep1_layer_idx, rep2_layer_idx)

100%|██████████| 49/49 [00:00<00:00, 94797.46it/s]

0 0
0 1
0 2
0 3
0 4
0 5
0 6
1 0
1 1
1 2
1 3
1 4
1 5
1 6
2 0
2 1
2 2
2 3
2 4
2 5
2 6
3 0
3 1
3 2
3 3
3 4
3 5
3 6
4 0
4 1
4 2
4 3
4 4
4 5
4 6
5 0
5 1
5 2
5 3
5 4
5 5
5 6
6 0
6 1
6 2
6 3
6 4
6 5
6 6





The main code raises `ValueError: array must not contain infs or NaNs` when comparing layer 0 with layer 6.

In [37]:
def check_nans_and_infs(r):
    """Print the number of infinite and NaN entries of ``r``.

    Accepts a torch tensor or a numpy array. Any other type raises
    ``ValueError`` whose message is the offending type.
    """
    if isinstance(r, torch.Tensor):
        checks = (("infs", torch.isinf), ("nans", torch.isnan))
    elif isinstance(r, np.ndarray):
        checks = (("infs", np.isinf), ("nans", np.isnan))
    else:
        raise ValueError(f"{type(r)}")
    for label, predicate in checks:
        print(label, predicate(r).sum())
# Check the stacked torch representations `r` before any numpy preprocessing.
check_nans_and_infs(r)

infs tensor(0)
nans tensor(0)


Initially there are no infs or NaNs, so they must be introduced by the preprocessing.

In [69]:
def to_numpy_if_needed(*args: Union[torch.Tensor, npt.NDArray]) -> List[npt.NDArray]:
    """Return every argument as a numpy array, preserving order.

    numpy arrays are passed through unchanged (the same object). Tensors are
    detached and moved to the CPU before conversion. A plain ``Tensor.numpy()``
    raises for tensors that require grad or live on an accelerator. For CPU
    tensors without grad, the result still shares memory with the tensor, just
    like ``Tensor.numpy()``.
    """

    def convert(x: Union[torch.Tensor, npt.NDArray]) -> npt.NDArray:
        if isinstance(x, np.ndarray):
            return x
        return x.detach().cpu().numpy()

    return [convert(x) for x in args]

def center_columns(R: npt.NDArray) -> npt.NDArray:
    """Return a new array in which each column of ``R`` has its mean subtracted."""
    column_means = R.mean(axis=0)
    return R - column_means[np.newaxis, :]


def normalize_matrix_norm(R: npt.NDArray) -> npt.NDArray:
    """Scale ``R`` to unit Frobenius norm.

    The norm is computed in (at least) float64. For float16 input, numpy
    would otherwise return the squared norm as float16. That value overflows
    to inf once it exceeds ~65504, which silently turns the result into zeros.
    Floating-point inputs keep their dtype in the returned array.

    A matrix with norm 0 (e.g. constant representations after
    ``center_columns``, as for layer 0 here) has no direction to normalise.
    Instead of dividing 0 by 0 and returning NaNs, a copy of ``R`` is
    returned unchanged.
    """
    wide_dtype = np.result_type(R.dtype, np.float64)
    norm = np.linalg.norm(np.asarray(R, dtype=wide_dtype), ord="fro")
    if norm == 0:
        return R.copy()
    normalized = R / norm
    if np.issubdtype(R.dtype, np.inexact):
        # Dividing by a float64 scalar may upcast depending on numpy's
        # promotion rules; keep the caller's floating dtype.
        normalized = normalized.astype(R.dtype, copy=False)
    return normalized





# Per layer: count exact zeros in the raw and in the column-centred last-token
# representations, and check the centred matrix for inf/NaN.
# r is (n_hidden_states, n_prompts, dim). The numpy conversion does not depend
# on the layer, so it is done once before the loop (it was checked earlier and
# seemed fine, i.e. it introduces no inf/NaN).
r_all_np = to_numpy_if_needed(r)[0]
for layer in range(r.size(0)):
    print(layer)
    print("centering")
    # Raw last-token representations of this layer: (n_prompts, dim).
    r_np = r_all_np[layer]
    print("zeros", (r_np == 0).sum())
    # Keep the uncentred matrix (same object, not a copy) for inspection after
    # the loop; center_columns returns a new array.
    r_np_nonzeroed = r_np
    r_np = center_columns(r_np)
    n_zeros = (r_np == 0).sum()
    print("zeros", n_zeros)
    print("zeros share", n_zeros / r_np.size)
    check_nans_and_infs(r_np)

# Normalisation is not run here. Layer 0 centres to an all-zero matrix (zeros
# share 1.0 in the output below), whose Frobenius norm is 0, so it cannot be
# scaled to unit norm.
# print("normalize")
# print(r_np.shape)
# print((r_np == 0).sum())
# r_np = normalize_matrix_norm(r_np)
# check_nans_and_infs(r_np)


0
centering
zeros 0
zeros 648704
zeros share 1.0
infs 0
nans 0
1
centering
zeros 13
zeros 1678
zeros share 0.0025866959352801896
infs 0
nans 0
2
centering
zeros 89
zeros 1355
zeros share 0.002088780090765588
infs 0
nans 0
3
centering
zeros 68
zeros 820
zeros share 0.0012640588003157065
infs 0
nans 0
4
centering
zeros 183
zeros 712
zeros share 0.0010975730071033937
infs 0
nans 0
5
centering
zeros 137
zeros 359
zeros share 0.0005534111089187056
infs 0
nans 0
6
centering
zeros 0
zeros 4369
zeros share 0.006734966949486977
infs 0
nans 0


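Layer-0 normalization probably fails because layer 0 is the embedding output and all prompts end with the same token.
Hence the final-token representation is identical for every prompt, and centering turns every column into zeros. The Frobenius norm used for normalization is then 0, so normalizing divides 0 by 0 and produces NaNs.

_This may not be the whole story: in the main script the failure occurs with layer 6, whose centred matrix is mostly non-zero (only about 0.7% zeros above). One candidate to check there is float16 overflow (largest finite float16 ≈ 65504) in norms or matrix products computed from these float16 representations._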

In [57]:
# Uncentred matrix kept by the per-layer loop.
# NOTE(review): the output shown came from an earlier run (In [57], before the
# loop cell In [69]). Its rows are all identical and equal the last row of
# out[0] above, i.e. layer 0 (embedding output). This is consistent with the
# all-zero centred layer-0 matrix. Re-running now would show the last loop
# layer instead.
r_np_nonzeroed

array([[-0.01238 , -0.01467 ,  0.014946, ..., -0.02074 , -0.0531  ,
        -0.01828 ],
       [-0.01238 , -0.01467 ,  0.014946, ..., -0.02074 , -0.0531  ,
        -0.01828 ],
       [-0.01238 , -0.01467 ,  0.014946, ..., -0.02074 , -0.0531  ,
        -0.01828 ],
       ...,
       [-0.01238 , -0.01467 ,  0.014946, ..., -0.02074 , -0.0531  ,
        -0.01828 ],
       [-0.01238 , -0.01467 ,  0.014946, ..., -0.02074 , -0.0531  ,
        -0.01828 ],
       [-0.01238 , -0.01467 ,  0.014946, ..., -0.02074 , -0.0531  ,
        -0.01828 ]], dtype=float16)