## GPT-2
W oparciu o fragmenty kodu załączone do pracy:
```
@misc{transformers_in_embedding_space,
  doi = {10.48550/ARXIV.2209.02535},
  url = {https://arxiv.org/abs/2209.02535},
  author = {Dar, Guy and Geva, Mor and Gupta, Ankit and Berant, Jonathan},
  title = {Analyzing Transformers in Embedding Space},
  publisher = {arXiv},
  year = {2022},
  copyright = {Creative Commons Attribution 4.0 International}
}
```

In [85]:
import sys


sys.path.append(
    "/net/software/v1/software/Python-bundle-PyPI/2023.10-GCCcore-13.2.0/lib/python3.11/site-packages"
)

In [86]:
import torch
from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from tabulate import tabulate
from tqdm import tqdm, trange
from copy import deepcopy
import numpy as np
from collections import Counter


ALNUM_CHARSET = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")


def convert_to_tokens(
    indices, tokenizer, extended=False, extra_values_pos=None, strip=True
):
    if extended:
        res = [
            (
                tokenizer.convert_ids_to_tokens([idx])[0]
                if idx < len(tokenizer)
                else (
                    f"[pos{idx-len(tokenizer)}]"
                    if idx < extra_values_pos
                    else f"[val{idx-extra_values_pos}]"
                )
            )
            for idx in indices
        ]
    else:
        res = tokenizer.convert_ids_to_tokens(indices)
    if strip:
        res = list(map(lambda x: x[1:] if x[0] == "Ġ" else "#" + x, res))
    return res


def top_tokens(
    v,
    k=100,
    tokenizer=None,
    only_alnum=False,
    only_ascii=True,
    with_values=False,
    exclude_brackets=False,
    extended=True,
    extra_values=None,
    only_from_list=None,
):
    v = deepcopy(v)
    ignored_indices = []
    if only_ascii:
        ignored_indices.extend(
            [
                key
                for val, key in tokenizer.vocab.items()
                if not val.strip("Ġ▁").isascii()
            ]
        )
    if only_alnum:
        ignored_indices.extend(
            [
                key
                for val, key in tokenizer.vocab.items()
                if not (set(val.strip("Ġ▁[] ")) <= ALNUM_CHARSET)
            ]
        )
    if only_from_list:
        ignored_indices.extend(
            [
                key
                for val, key in tokenizer.vocab.items()
                if val.strip("Ġ▁ ").lower() not in only_from_list
            ]
        )
    if exclude_brackets:
        ignored_indices = set(ignored_indices).intersection(
            {
                key
                for val, key in tokenizer.vocab.items()
                if not (val.isascii() and val.isalnum())
            }
        )
        ignored_indices = list(ignored_indices)

    ignored_indices = list(set(ignored_indices))
    v[ignored_indices] = -np.inf
    extra_values_pos = len(v)
    if extra_values is not None:
        v = torch.cat([v, extra_values])
    values, indices = torch.topk(v, k=k)
    res = convert_to_tokens(
        indices, tokenizer, extended=extended, extra_values_pos=extra_values_pos
    )
    if with_values:
        res = list(zip(res, values.cpu().numpy()))
    return res

In [87]:
def approx_topk(mat, min_k=500, max_k=250_000, th0=10, max_iters=10, verbose=False):
    _get_actual_k = lambda th, th_max: torch.nonzero((mat > th) & (mat < th_max)).shape[
        0
    ]
    th_max = np.inf
    left, right = 0, th0
    while True:
        actual_k = _get_actual_k(right, th_max)
        if verbose:
            print(f"one more iteration. {actual_k}")
        if actual_k <= max_k:
            break
        left, right = right, right * 2
    if min_k <= actual_k <= max_k:
        th = right
    else:
        for _ in range(max_iters):
            mid = (left + right) / 2
            actual_k = _get_actual_k(mid, th_max)
            if verbose:
                print(f"one more iteration. {actual_k}")
            if min_k <= actual_k <= max_k:
                break
            if actual_k > max_k:
                left = mid
            else:
                right = mid
        th = mid
    return torch.nonzero((mat > th) & (mat < th_max)).tolist()


def get_top_entries(
    tmp,
    all_high_pos,
    only_ascii=False,
    only_alnum=False,
    exclude_same=False,
    exclude_fuzzy=False,
    tokens_list=None,
    tokenizer=None,
    reverse_list=False,
):
    remaining_pos = all_high_pos
    if only_ascii:
        remaining_pos = [
            *filter(
                lambda x: (
                    tokenizer.decode(x[0]).strip("Ġ▁").isascii()
                    and tokenizer.decode(x[1]).strip("Ġ▁").isascii()
                ),
                remaining_pos,
            )
        ]
    if only_alnum:
        remaining_pos = [
            *filter(
                lambda x: (
                    tokenizer.decode(x[0]).strip("Ġ▁ ").isalnum()
                    and tokenizer.decode(x[1]).strip("Ġ▁ ").isalnum()
                ),
                remaining_pos,
            )
        ]
    if exclude_same:
        remaining_pos = [
            *filter(
                lambda x: tokenizer.decode(x[0]).lower().strip()
                != tokenizer.decode(x[1]).lower().strip(),
                remaining_pos,
            )
        ]
    if exclude_fuzzy:
        remaining_pos = [
            *filter(
                lambda x: not _fuzzy_eq(
                    tokenizer.decode(x[0]).lower().strip(),
                    tokenizer.decode(x[1]).lower().strip(),
                ),
                remaining_pos,
            )
        ]
    if tokens_list:
        remaining_pos = [
            *filter(
                lambda x: (
                    (tokenizer.decode(x[0]).strip("Ġ▁").lower().strip() in tokens_list)
                    and (
                        tokenizer.decode(x[1]).strip("Ġ▁").lower().strip()
                        in tokens_list
                    )
                ),
                remaining_pos,
            )
        ]

    pos_val = tmp[[*zip(*remaining_pos)]]
    good_cells = [
        *map(lambda x: (tokenizer.decode(x[0]), tokenizer.decode(x[1])), remaining_pos)
    ]
    good_tokens = list(map(lambda x: Counter(x).most_common(), zip(*good_cells)))
    remaining_pos_best = np.array(remaining_pos)[
        torch.argsort(pos_val if reverse_list else -pos_val)[:50]
    ]
    good_cells_best = [
        *map(
            lambda x: (tokenizer.decode(x[0]), tokenizer.decode(x[1])),
            remaining_pos_best,
        )
    ]
    # good_cells[:100]
    # list(zip(good_tokens[0], good_tokens[1]))
    return good_cells_best

In [88]:
import os
from pathlib import Path


group_storage = Path(os.getenv("PLG_GROUPS_STORAGE")) / "plggaigraphicsk46"


model = AutoModelForCausalLM.from_pretrained(
    "sdadas/polish-gpt2-medium", cache_dir=group_storage / ".cache"
)
tokenizer = AutoTokenizer.from_pretrained("sdadas/polish-gpt2-medium")
emb = model.get_output_embeddings().weight.data.T.detach()

num_layers = model.config.n_layer
num_heads = model.config.n_head
hidden_dim = model.config.n_embd
head_size = hidden_dim // num_heads

K = torch.cat(
    [
        model.get_parameter(f"transformer.h.{j}.mlp.c_fc.weight").T
        for j in range(num_layers)
    ]
).detach()
V = torch.cat(
    [
        model.get_parameter(f"transformer.h.{j}.mlp.c_proj.weight")
        for j in range(num_layers)
    ]
).detach()

W_Q, W_K, W_V = (
    torch.cat(
        [
            model.get_parameter(f"transformer.h.{j}.attn.c_attn.weight")
            for j in range(num_layers)
        ]
    )
    .detach()
    .chunk(3, dim=-1)
)
W_O = torch.cat(
    [
        model.get_parameter(f"transformer.h.{j}.attn.c_proj.weight")
        for j in range(num_layers)
    ]
).detach()

K_heads = K.reshape(num_layers, -1, hidden_dim)
V_heads = V.reshape(num_layers, -1, hidden_dim)
d_int = K_heads.shape[1]

W_Q_heads = W_Q.reshape(num_layers, hidden_dim, num_heads, head_size).permute(
    0, 2, 1, 3
)
W_K_heads = W_K.reshape(num_layers, hidden_dim, num_heads, head_size).permute(
    0, 2, 1, 3
)
W_V_heads = W_V.reshape(num_layers, hidden_dim, num_heads, head_size).permute(
    0, 2, 1, 3
)
W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)

emb_inv = emb.T

### Interpretacja $W_{VO}$

In [89]:
def generate_WVO_table_for_layer(layer, W_V_heads, W_O_heads, emb_inv, emb, tokenizer):
    table_data = []
    headers = ["Layer-Head"]

    for head in range(0, 12, 2):
        W_V_tmp, W_O_tmp = W_V_heads[layer, head, :], W_O_heads[layer, head]
        tmp = emb_inv @ (W_V_tmp @ W_O_tmp) @ emb

        all_high_pos = approx_topk(tmp, th0=1, verbose=False)

        entries = get_top_entries(
            tmp,
            all_high_pos,
            only_ascii=True,
            only_alnum=True,
            exclude_same=True,
            tokens_list=None,
            tokenizer=tokenizer,
        )

        headers.append(f"Layer {layer}, Head {head}")

        if len(table_data) < len(entries):
            table_data.extend(
                [[""] * len(headers) for _ in range(len(entries) - len(table_data))]
            )

        for i, (first, second) in enumerate(entries):
            if i >= len(table_data):
                table_data.append([""] * len(headers))
            table_data[i].append(f"{first}, {second}")

    table = tabulate(table_data, headers=headers, tablefmt="grid")
    return table

In [90]:
table = generate_WVO_table_for_layer(0, W_V_heads, W_O_heads, emb.T, emb, tokenizer)

print(table)

+----+--------------+----------------------------+-------------------+---------------------+-------------------+--------------------+--------------------+
|    | Layer-Head   | Layer 0, Head 0            | Layer 0, Head 2   | Layer 0, Head 4     | Layer 0, Head 6   | Layer 0, Head 8    | Layer 0, Head 10   |
|    |              | alam, alem                 | seksualnym, xie   | sie,  jeszcze       | Se, zani          | dziwszy, ium       | agora, terze       |
+----+--------------+----------------------------+-------------------+---------------------+-------------------+--------------------+--------------------+
|    |              | alam, lem                  | fonu, wisk        | sie,  bardzo        | Se, zy            | tum, atoli         | hre, tou           |
+----+--------------+----------------------------+-------------------+---------------------+-------------------+--------------------+--------------------+
|    |              | skakuje,  powiedziala      | anowska, rala     |

Dla warstwy 0 i głowy 0 obserwujemy najczęściej pary czasowników.

In [91]:
table = generate_WVO_table_for_layer(5, W_V_heads, W_O_heads, emb.T, emb, tokenizer)

print(table)

+--------------+-------------------+--------------------+-------------------+-------------------------+---------------------+-----------------------+
| Layer-Head   | Layer 5, Head 0   | Layer 5, Head 2    | Layer 5, Head 4   | Layer 5, Head 6         | Layer 5, Head 8     | Layer 5, Head 10      |
|              |                   | RM,  podpis        | RF, tto           | twoja, nowska           | hor, tic            | delu,  chuj           |
+--------------+-------------------+--------------------+-------------------+-------------------------+---------------------+-----------------------+
|              |                   | Gospodarstwa, ymet | RF, ionie         | eria, nowska            | Kinga, chem         | delu,  hehe           |
+--------------+-------------------+--------------------+-------------------+-------------------------+---------------------+-----------------------+
|              |                   | palenia,  etc      | Try, tto          | dae,  milczeli        

### Interpretacja $W_{QK}$

In [92]:
table = generate_WQK_table_for_layer(0, W_Q_heads, W_K_heads, emb.T, tokenizer)
print(table)

+----+--------------+--------------------+----------------------------+----------------------------+--------------------------+--------------------------+----------------------+
|    | Layer-Head   | Layer 0, Head 0    | Layer 0, Head 2            | Layer 0, Head 4            | Layer 0, Head 6          | Layer 0, Head 8          | Layer 0, Head 10     |
|    |              | in, daleka         | niebieskie, ustro          | TSZ, macicy                | reli, box                | miastach, sakra          | ssia, hr             |
+----+--------------+--------------------+----------------------------+----------------------------+--------------------------+--------------------------+----------------------+
|    |              | lit, daleka        | see, ustro                 | TSZ, wierciad              | decy, spon               | miastach, kota           | Ojczyzna, biurka     |
+----+--------------+--------------------+----------------------------+----------------------------+----------

In [109]:
table = generate_WQK_table_for_layer(5, W_Q_heads, W_K_heads, emb.T, tokenizer)
print(table)

+----+--------------+-------------------------+-----------------------+------------------------+---------------------+----------------------+-----------------------------+
|    | Layer-Head   | Layer 5, Head 0         | Layer 5, Head 2       | Layer 5, Head 4        | Layer 5, Head 6     | Layer 5, Head 8      | Layer 5, Head 10            |
|    |              | sakra, Frag             | cukrzy, dedy          | kota, sakra            | kota, sakra         | sakra, ac            | komputerowych, stosowaniu   |
+----+--------------+-------------------------+-----------------------+------------------------+---------------------+----------------------+-----------------------------+
|    |              | pretekstem, realizacji  | Wincentego, dedy      | sakra, kota            | kota, ac            | sygnaliz, ac         | komputerowych, senatorowie  |
+----+--------------+-------------------------+-----------------------+------------------------+---------------------+----------------------

## Herbert

In [96]:
model2 = AutoModel.from_pretrained(
    "allegro/herbert-base-cased", cache_dir=group_storage / ".cache"
)
tokenizer2 = AutoTokenizer.from_pretrained("allegro/herbert-base-cased")

emb2 = model2.embeddings.word_embeddings.weight.data.T.detach()
emb2_inv = emb2.T
num_layers_bert = model2.config.num_hidden_layers
hidden_dim_bert = model2.config.hidden_size
num_heads_bert = model2.config.num_attention_heads
head_size_bert = hidden_dim_bert // num_heads_bert

Q_bert = torch.cat(
    [
        model2.get_parameter(f"encoder.layer.{j}.attention.self.query.weight").T
        for j in range(num_layers_bert)
    ]
).detach()
K_bert = torch.cat(
    [
        model2.get_parameter(f"encoder.layer.{j}.attention.self.key.weight").T
        for j in range(num_layers_bert)
    ]
).detach()
V_bert = torch.cat(
    [
        model2.get_parameter(f"encoder.layer.{j}.attention.self.value.weight").T
        for j in range(num_layers_bert)
    ]
).detach()
O_bert = torch.cat(
    [
        model2.get_parameter(f"encoder.layer.{j}.attention.output.dense.weight").T
        for j in range(num_layers_bert)
    ]
).detach()

Q_heads_bert = Q_bert.reshape(
    num_layers_bert, hidden_dim_bert, num_heads_bert, head_size_bert
).permute(0, 2, 1, 3)
K_heads_bert = K_bert.reshape(
    num_layers_bert, hidden_dim_bert, num_heads_bert, head_size_bert
).permute(0, 2, 1, 3)
V_heads_bert = V_bert.reshape(
    num_layers_bert, hidden_dim_bert, num_heads_bert, head_size_bert
).permute(0, 2, 1, 3)
O_heads_bert = O_bert.reshape(
    num_layers_bert, hidden_dim_bert, num_heads_bert, head_size_bert
).permute(0, 2, 1, 3)

Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.sso.sso_relationship.bias', 'cls.sso.sso_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [97]:
print(num_layers_bert, num_heads_bert)

12 12


### Interpretacja $W_{VO}$

In [104]:
from tabulate import tabulate


def generate_table_for_layer(layer, V_heads, O_heads, emb_inv, emb, tokenizer):
    table_data = []
    headers = ["Layer-Head"]

    for head in range(0, 12, 2):
        W_V_tmp, W_O_tmp = V_heads[layer, head, :], O_heads[layer, head]
        tmp2 = emb_inv @ (W_V_tmp @ W_O_tmp.T) @ emb
        all_high_pos2 = approx_topk(tmp2, th0=1, verbose=False)

        entries = get_top_entries(
            tmp2,
            all_high_pos2,
            only_ascii=True,
            only_alnum=True,
            exclude_same=True,
            tokens_list=None,
            tokenizer=tokenizer,
        )

        headers.append(f"Layer {layer}, Head {head}")

        if len(table_data) < len(entries):
            table_data.extend(
                [[""] * len(headers) for _ in range(len(entries) - len(table_data))]
            )

        for i, (first, second) in enumerate(entries):
            if i >= len(table_data):
                table_data.append([""] * len(headers))
            table_data[i].append(f"{first}, {second}")

    table = tabulate(table_data, headers=headers, tablefmt="grid")
    return table


layer = 0
result_table = generate_table_for_layer(
    layer, V_heads_bert, O_heads_bert, emb2.T, emb2, tokenizer2
)
print(result_table)

+----+--------------+-------------------+-------------------+--------------------+-------------------+----------------------+--------------------+
|    | Layer-Head   | Layer 0, Head 0   | Layer 0, Head 2   | Layer 0, Head 4    | Layer 0, Head 6   | Layer 0, Head 8      | Layer 0, Head 10   |
|    |              | dale, Wierzy      | cza, RZ           | wd, lada           | Zasada, chana     | Celtic, Brod         | saga, Fab          |
+----+--------------+-------------------+-------------------+--------------------+-------------------+----------------------+--------------------+
|    |              | trum, Republika   | oznacza, RZ       | BIG, pania         | jaka, sensu       | rion, Brod           | RAK, Mei           |
+----+--------------+-------------------+-------------------+--------------------+-------------------+----------------------+--------------------+
|    |              | platte, row       | rzy, RZ           | Metal, pania       | dzisiejsze, zdi   | rion, laborator

In [105]:
result_table = generate_table_for_layer(
    layer=5,
    V_heads=V_heads_bert,
    O_heads=O_heads_bert,
    emb_inv=emb2.T,
    emb=emb2,
    tokenizer=tokenizer2,
)
print(result_table)

+----+--------------+-----------------------+---------------------+---------------------+-------------------+-------------------+--------------------+
|    | Layer-Head   | Layer 5, Head 0       | Layer 5, Head 2     | Layer 5, Head 4     | Layer 5, Head 6   | Layer 5, Head 8   | Layer 5, Head 10   |
|    |              | Much, ABS             | kierunku, refleks   | moim, owanego       | cek, Astra        | University, mona  | spekt, rzod        |
+----+--------------+-----------------------+---------------------+---------------------+-------------------+-------------------+--------------------+
|    |              | Popular, cau          | daleko, Niko        | faktu, oga          | 1922, Opel        | Ziem, mona        | spekt, zaka        |
+----+--------------+-----------------------+---------------------+---------------------+-------------------+-------------------+--------------------+
|    |              | Much, WAL             | krajobraz, pseudo   | pokazuje, oga       | 1925

### Interpretacja $W_{QK}$

In [106]:
from tabulate import tabulate


def generate_WQK_table_for_layer(layer, Q_heads, K_heads, emb_inv, tokenizer):
    table_data = []
    headers = ["Layer-Head"]

    for head in range(0, 12, 2):

        W_Q_tmp, W_K_tmp = Q_heads[layer, head, :], K_heads[layer, head, :]
        tmp2 = emb_inv @ (W_Q_tmp @ W_K_tmp.T) @ emb_inv.T

        all_high_pos2 = approx_topk(tmp2, th0=1, verbose=False)

        entries = get_top_entries(
            tmp2,
            all_high_pos2,
            only_ascii=True,
            only_alnum=True,
            exclude_same=True,
            tokens_list=None,
            tokenizer=tokenizer2,
        )

        headers.append(f"Layer {layer}, Head {head}")

        if len(table_data) < len(entries):
            table_data.extend(
                [[""] * len(headers) for _ in range(len(entries) - len(table_data))]
            )

        for i, (first, second) in enumerate(entries):
            if i >= len(table_data):
                table_data.append([""] * len(headers))
            table_data[i].append(f"{first}, {second}")

    table = tabulate(table_data, headers=headers, tablefmt="grid")
    return table

In [107]:
layer = 0
result_table = generate_WQK_table_for_layer(
    layer, Q_heads_bert, K_heads_bert, emb2.T, tokenizer2
)
print(result_table)

+----+--------------+-------------------+-------------------------+-------------------+-------------------+-----------------------+--------------------+
|    | Layer-Head   | Layer 0, Head 0   | Layer 0, Head 2         | Layer 0, Head 4   | Layer 0, Head 6   | Layer 0, Head 8       | Layer 0, Head 10   |
|    |              |                   | Sezon, Philadelphia     | ZO, sani          | Swa, spore        | zazwyczaj, pm         | tylu, Nie          |
+----+--------------+-------------------+-------------------------+-------------------+-------------------+-----------------------+--------------------+
|    |              |                   | zmaga, Philadelphia     | Rome, sezon       | Brown, lobby      | zazwyczaj, Volkswagen | cale, Nie          |
+----+--------------+-------------------+-------------------------+-------------------+-------------------+-----------------------+--------------------+
|    |              |                   | garso, Philadelphia     | Ano, Lady     

In [108]:
layer = 5
result_table = generate_WQK_table_for_layer(
    layer, Q_heads_bert, K_heads_bert, emb2.T, tokenizer2
)
print(result_table)

+----+--------------+------------------------+-------------------+-------------------+-------------------+-------------------+----------------------+
|    | Layer-Head   | Layer 5, Head 0        | Layer 5, Head 2   | Layer 5, Head 4   | Layer 5, Head 6   | Layer 5, Head 8   | Layer 5, Head 10     |
|    |              | bogate, yka            | marco, Tru        | 1928, 1996        | przyjaciel, jada  | ULI, powinny      | umieszcza, cego      |
+----+--------------+------------------------+-------------------+-------------------+-------------------+-------------------+----------------------+
|    |              | dodatkowy, nata        | ARKA, jur         | 1928, 1986        | udziela, jcie     | 391, dii          | schrieb, aha         |
+----+--------------+------------------------+-------------------+-------------------+-------------------+-------------------+----------------------+
|    |              | mej, poziomu           | Transport, snu    | 1928, 1992        | zamiar, stria