## GPT-2
W oparciu o fragmenty kodu załączone do pracy:
```
@misc{transformers_in_embedding_space,
  doi = {10.48550/ARXIV.2209.02535},
  url = {https://arxiv.org/abs/2209.02535},
  author = {Dar, Guy and Geva, Mor and Gupta, Ankit and Berant, Jonathan},
  title = {Analyzing Transformers in Embedding Space},
  publisher = {arXiv},
  year = {2022},
  copyright = {Creative Commons Attribution 4.0 International}
}
```

In [1]:
import sys

# Extend the import search path with an extra site-packages directory.
# NOTE(review): presumably specific to the machine/cluster this notebook was
# run on — confirm before reusing elsewhere.
_EXTRA_SITE_PACKAGES = (
    "/net/software/v1/software/Python-bundle-PyPI/2023.10-GCCcore-13.2.0"
    "/lib/python3.11/site-packages"
)
sys.path.append(_EXTRA_SITE_PACKAGES)

In [2]:
import torch
from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from tabulate import tabulate
from tqdm import tqdm, trange
from copy import deepcopy
import numpy as np
from collections import Counter


# Characters accepted by the `only_alnum` token filter: ASCII letters and digits.
ALNUM_CHARSET = set(
    'abcdefghijklmnopqrstuvwxyz'
    'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    '0123456789'
)

def convert_to_tokens(indices, tokenizer, extended=False, extra_values_pos=None, strip=True):
    """Turn vocabulary indices into printable token strings.

    Args:
        indices: Iterable of token ids.
        tokenizer: Object exposing ``convert_ids_to_tokens``; ``len(tokenizer)``
            is also used when ``extended`` is true.
        extended: When true, ids at or beyond ``len(tokenizer)`` are rendered
            as placeholders instead of being looked up: ``[pos{i}]`` for ids
            below ``extra_values_pos`` and ``[val{i}]`` for the rest.
        extra_values_pos: First id that is rendered as a ``[val…]``
            placeholder. Must be set if ``extended`` is true and any id falls
            outside the tokenizer's range.
        strip: When true, a leading ``'Ġ'`` marker is removed and every string
            that does not start with it (placeholders included) is prefixed
            with ``"#"``.

    Returns:
        A list of strings, one per input index.
    """
    if extended:
        res = [tokenizer.convert_ids_to_tokens([idx])[0] if idx < len(tokenizer) else
               (f"[pos{idx-len(tokenizer)}]" if idx < extra_values_pos else f"[val{idx-extra_values_pos}]")
               for idx in indices]
    else:
        res = tokenizer.convert_ids_to_tokens(indices)
    if strip:
        # startswith() instead of x[0] so that an empty token string does not
        # raise IndexError; it is rendered as a bare "#".
        res = [x[1:] if x.startswith('Ġ') else "#" + x for x in res]
    return res


def top_tokens(v, k=100, tokenizer=None, only_alnum=False, only_ascii=True, with_values=False, 
               exclude_brackets=False, extended=True, extra_values=None, only_from_list=None):
    """Return the ``k`` highest-scoring tokens of a score vector.

    Args:
        v: 1-D tensor of scores indexed by token id; it is copied, not modified.
        k: Number of entries to return (passed to ``torch.topk``).
        tokenizer: Tokenizer whose ``vocab`` mapping (token string -> id) is
            used for filtering and whose ids are converted back to strings.
        only_alnum: Ignore tokens containing characters outside
            ``ALNUM_CHARSET`` (after stripping ``'Ġ▁[] '``).
        only_ascii: Ignore tokens that are not ASCII (after stripping ``'Ġ▁'``).
        with_values: Return ``(token, score)`` pairs instead of bare tokens.
        exclude_brackets: Restrict the ignore set collected so far to tokens
            that are not purely ASCII-alphanumeric.
            NOTE(review): this intersects (shrinks) the ignore set rather than
            extending it — confirm that this is the intended behaviour.
        extended: Forwarded to ``convert_to_tokens``.
        extra_values: Optional tensor appended to the scores after masking, so
            its entries are never filtered out.
        only_from_list: If truthy, ignore tokens whose stripped, lower-cased
            text is not in this collection.

    Returns:
        A list of token strings, or of ``(token, score)`` tuples when
        ``with_values`` is true.
    """
    scores = deepcopy(v)

    banned = set()
    if only_ascii:
        for token, token_id in tokenizer.vocab.items():
            if not token.strip('Ġ▁').isascii():
                banned.add(token_id)
    if only_alnum:
        for token, token_id in tokenizer.vocab.items():
            if not set(token.strip('Ġ▁[] ')).issubset(ALNUM_CHARSET):
                banned.add(token_id)
    if only_from_list:
        for token, token_id in tokenizer.vocab.items():
            if token.strip('Ġ▁ ').lower() not in only_from_list:
                banned.add(token_id)
    if exclude_brackets:
        not_plain_alnum = {token_id for token, token_id in tokenizer.vocab.items()
                           if not (token.isascii() and token.isalnum())}
        banned &= not_plain_alnum

    # Masked entries can never be picked by topk.
    scores[list(banned)] = -np.inf

    # Ids from this position onwards refer to `extra_values`, not the vocabulary.
    first_extra_id = len(scores)
    if extra_values is not None:
        scores = torch.cat([scores, extra_values])

    values, indices = torch.topk(scores, k=k)
    tokens = convert_to_tokens(indices, tokenizer, extended=extended, extra_values_pos=first_extra_id)
    if not with_values:
        return tokens
    return list(zip(tokens, values.cpu().numpy()))

In [3]:
def approx_topk(mat, min_k=500, max_k=250_000, th0=10, max_iters=10, verbose=False):
    """Find the positions of approximately the largest entries of ``mat``.

    Instead of an exact top-k, a threshold is searched such that the number of
    finite entries strictly above it lies in ``[min_k, max_k]``. The threshold
    starts at ``th0`` and is doubled while more than ``max_k`` entries pass;
    if fewer than ``min_k`` pass afterwards, it is refined by bisection.

    Args:
        mat: Tensor to search.
        min_k: Lower bound on the desired number of returned positions.
        max_k: Upper bound on the desired number of returned positions.
        th0: Initial threshold. Must be positive whenever it has to be raised.
        max_iters: Maximum number of bisection steps. The result may fall
            outside ``[min_k, max_k]`` if the budget is exhausted.
        verbose: Print the entry count after every threshold evaluation.

    Returns:
        A list of index lists (one per selected entry), as produced by
        ``torch.nonzero(...).tolist()``.

    Raises:
        ValueError: If the threshold must be raised but is not positive, so
            doubling it could never make progress.
    """
    th_max = float("inf")  # excludes +inf entries from the selection

    def _count_above(th):
        # count_nonzero avoids materialising a full index tensor just to count.
        return int(torch.count_nonzero((mat > th) & (mat < th_max)))

    left, right = 0, th0
    while True:
        actual_k = _count_above(right)
        if verbose:
            print(f"one more iteration. {actual_k}")
        if actual_k <= max_k:
            break
        if right <= 0:
            raise ValueError("th0 must be positive to raise the threshold by doubling")
        left, right = right, right * 2

    # Default to the last evaluated threshold so `th` is always bound, even
    # when max_iters == 0.
    th = right
    if actual_k < min_k:
        for _ in range(max_iters):
            th = (left + right) / 2
            actual_k = _count_above(th)
            if verbose:
                print(f"one more iteration. {actual_k}")
            if min_k <= actual_k <= max_k:
                break
            if actual_k > max_k:
                left = th   # too many entries pass: raise the threshold
            else:
                right = th  # too few entries pass: lower the threshold
    return torch.nonzero((mat > th) & (mat < th_max)).tolist()

def get_top_entries(tmp, all_high_pos, only_ascii=False, only_alnum=False, exclude_same=False, exclude_fuzzy=False, tokens_list=None, tokenizer=None, reverse_list=False, top_n=50):
    """Decode the strongest (row token, column token) pairs of a token-space matrix.

    Args:
        tmp: 2-D tensor indexed by (token id, token id).
        all_high_pos: Candidate positions as ``[row, col]`` pairs, e.g. the
            output of ``approx_topk``.
        only_ascii: Keep a pair only if both decoded tokens are ASCII
            (after stripping ``'Ġ▁'``).
        only_alnum: Keep a pair only if both decoded tokens are alphanumeric
            (after stripping ``'Ġ▁ '``).
        exclude_same: Drop pairs whose tokens are equal after lower-casing and
            trimming whitespace.
        exclude_fuzzy: Drop pairs judged equal by ``_fuzzy_eq``.
            NOTE(review): ``_fuzzy_eq`` is not defined in the visible part of
            this file; it must be in scope before this flag is used.
        tokens_list: If truthy, keep a pair only if both normalised tokens are
            contained in this collection.
        tokenizer: Object exposing ``decode(token_id)``.
        reverse_list: Sort ascending by value instead of descending.
        top_n: Maximum number of pairs to return.

    Returns:
        A list of at most ``top_n`` ``(row_token, col_token)`` string tuples,
        ordered by the corresponding value in ``tmp``.
    """
    decoded = {}

    def _text(token_id):
        # The same ids recur across many pairs and filters; decode each once.
        token_id = int(token_id)
        if token_id not in decoded:
            decoded[token_id] = tokenizer.decode(token_id)
        return decoded[token_id]

    def _norm(token_id):
        return _text(token_id).lower().strip()

    remaining_pos = all_high_pos
    if only_ascii:
        remaining_pos = [
            pos for pos in remaining_pos
            if _text(pos[0]).strip('Ġ▁').isascii() and _text(pos[1]).strip('Ġ▁').isascii()
        ]
    if only_alnum:
        remaining_pos = [
            pos for pos in remaining_pos
            if _text(pos[0]).strip('Ġ▁ ').isalnum() and _text(pos[1]).strip('Ġ▁ ').isalnum()
        ]
    if exclude_same:
        remaining_pos = [pos for pos in remaining_pos if _norm(pos[0]) != _norm(pos[1])]
    if exclude_fuzzy:
        remaining_pos = [
            pos for pos in remaining_pos
            if not _fuzzy_eq(_norm(pos[0]), _norm(pos[1]))
        ]
    if tokens_list:
        remaining_pos = [
            pos for pos in remaining_pos
            if (_text(pos[0]).strip('Ġ▁').lower().strip() in tokens_list
                and _text(pos[1]).strip('Ġ▁').lower().strip() in tokens_list)
        ]

    if len(remaining_pos) == 0:
        return []

    # Explicit (rows, cols) indexing; indexing with a single list of sequences
    # relies on legacy "list treated as tuple" behaviour.
    rows, cols = zip(*remaining_pos)
    pos_val = tmp[list(rows), list(cols)]
    order = torch.argsort(pos_val if reverse_list else -pos_val)[:top_n]
    return [(_text(remaining_pos[i][0]), _text(remaining_pos[i][1])) for i in order.tolist()]

In [4]:
# Load the checkpoint analysed in this notebook together with its tokenizer.
model = AutoModelForCausalLM.from_pretrained("sdadas/polish-gpt2-medium")
tokenizer = AutoTokenizer.from_pretrained("sdadas/polish-gpt2-medium")

# Transposed output-embedding weights.
# NOTE(review): presumably (hidden_dim, vocab_size) — confirm against the model.
emb = model.get_output_embeddings().weight.data.T.detach()

# Architecture sizes read from the model config.
num_layers = model.config.n_layer
num_heads = model.config.n_head
hidden_dim = model.config.n_embd
head_size = hidden_dim // num_heads


def _stack_layers(name_template, transpose=False):
    """Concatenate one named parameter from every transformer block along dim 0."""
    per_layer = []
    for j in range(num_layers):
        weight = model.get_parameter(name_template.format(j=j))
        per_layer.append(weight.T if transpose else weight)
    return torch.cat(per_layer).detach()


# MLP weights of all layers; the first projection is transposed before stacking.
K = _stack_layers("transformer.h.{j}.mlp.c_fc.weight", transpose=True)
V = _stack_layers("transformer.h.{j}.mlp.c_proj.weight")

# The fused attention projection is split into three equal parts along the last axis.
W_Q, W_K, W_V = _stack_layers("transformer.h.{j}.attn.c_attn.weight").chunk(3, dim=-1)
W_O = _stack_layers("transformer.h.{j}.attn.c_proj.weight")

# Per-layer views of the MLP weights: (num_layers, d_int, hidden_dim).
K_heads = K.reshape(num_layers, -1, hidden_dim)
V_heads = V.reshape(num_layers, -1, hidden_dim)
d_int = K_heads.shape[1]

# Per-head views of the attention weights:
#   Q/K/V -> (num_layers, num_heads, hidden_dim, head_size)
#   O     -> (num_layers, num_heads, head_size, hidden_dim)
W_Q_heads, W_K_heads, W_V_heads = (
    w.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
    for w in (W_Q, W_K, W_V)
)
W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)

# The transpose of `emb` is what the cells below use as `emb_inv`.
emb_inv = emb.T

### Interpretacja $W_{VO}$

In [5]:
layer, head = 0, 0

# Value and output projections of the selected attention head.
W_V_tmp = W_V_heads[layer, head, :]
W_O_tmp = W_O_heads[layer, head]

# Combined W_V @ W_O map of the head, sandwiched between emb_inv and emb so
# that both axes are indexed by token id.
tmp = emb_inv @ (W_V_tmp @ W_O_tmp) @ emb

# Candidate positions with the largest values.
all_high_pos = approx_topk(tmp, th0=1, verbose=False)

# Last expression of the cell: the decoded top pairs are displayed as output.
get_top_entries(
    tmp,
    all_high_pos,
    only_ascii=True,
    only_alnum=True,
    exclude_same=True,
    tokens_list=None,
    tokenizer=tokenizer,
)

[('alam', 'alem'),
 ('alam', 'lem'),
 ('skakuje', ' powiedziala'),
 ('suwa', ' skierowali'),
 (' pytam', ' powiedziala'),
 ('puszczam', ' poczuli'),
 ('puszczam', 'dli'),
 (' dodaje', 'knal'),
 (' powinnam', 'ales'),
 (' pytam', 'stuje'),
 (' dostajemy', ' dowiedzieli'),
 ('puszczam', 'cil'),
 ('suwa', ' poczuli'),
 ('siadam', 'muje'),
 (' odpowiadam', ' odpowiedzieli'),
 ('suwa', ' powiedziala'),
 ('staje', ' stwierdzili'),
 ('suwa', ' postanawia'),
 ('suwa', ' stwierdzili'),
 ('lewam', 'stuje'),
 ('alam', 'glem'),
 ('lewam', ' poznali'),
 ('lewam', 'dujemy'),
 (' odpowiadam', 'alia'),
 ('staje', 'cila'),
 (' odpowiadam', ' powiedziala'),
 ('puszczam', ' poznali'),
 ('lewam', 'cil'),
 ('mawiam', 'duje'),
 ('siadam', 'stuje'),
 ('alam', 'tional'),
 (' udaje', 'lem'),
 ('staje', ' powiedziala'),
 (' wyjmuje', ' wymienili'),
 ('suwa', 'knal'),
 (' wyjmuje', ' usiedli'),
 (' chwyta', ' poczuli'),
 ('suwa', ' poznali'),
 ('puszczam', 'duje'),
 (' wyjmuje', ' powiedziala'),
 (' powinnam', '

Dla warstwy 0 i głowy 0 obserwujemy najczęściej pary czasowników.

### Interpretacja $W_{QK}$

In [6]:
layer, head = 4, 4

# Query and key projections of the selected attention head.
W_Q_tmp = W_Q_heads[layer, head, :]
W_K_tmp = W_K_heads[layer, head, :]

# Query-key interaction W_Q @ W_K^T of the head, expressed over token ids on
# both axes.
tmp2 = emb_inv @ (W_Q_tmp @ W_K_tmp.T) @ emb_inv.T

# Candidate positions with the largest values (progress is printed).
all_high_pos2 = approx_topk(tmp2, th0=1, verbose=True)

# Last expression of the cell: the decoded top pairs are displayed as output.
get_top_entries(
    tmp2,
    all_high_pos2,
    only_ascii=True,
    only_alnum=True,
    exclude_same=True,
    tokens_list=None,
    tokenizer=tokenizer,
)

one more iteration. 0
one more iteration. 0
one more iteration. 16
one more iteration. 92784


[('dora', 'ario'),
 ('nta', 'rion'),
 ('CJE', 'CJ'),
 ('dora', 'ces'),
 ('taria', 'lio'),
 ('ariuszy', 'ariusza'),
 ('jek', 'rion'),
 ('KP', 'PL'),
 ('ariusz', 'ariusza'),
 ('CJ', 'EL'),
 (' informacyjnych', ' RPO'),
 ('erb', 'fin'),
 ('dora', 'rion'),
 ('czyn', 'kup'),
 ('aryj', 'nowo'),
 ('eu', 'lio'),
 ('feld', 'dorf'),
 ('RD', 'AN'),
 ('PR', 'PL'),
 ('tari', 'lio'),
 ('pad', 'pada'),
 ('BN', 'EL'),
 (' informacyjnych', ' informacyjne'),
 ('data', 'dat'),
 ('dnika', 'kup'),
 ('czew', 'kup'),
 ('chta', 'cht'),
 ('pety', 'ceu'),
 ('zofre', 'lio'),
 ('chowie', 'dorf'),
 ('ariusza', 'ariusz'),
 ('DE', 'EL'),
 ('ariusze', 'ariusza'),
 ('ire', 'lio'),
 ('czyny', 'kup'),
 ('RS', 'PL'),
 ('ariuszy', 'ariuszem'),
 ('jnie', 'kup'),
 ('RD', 'IN'),
 ('ANE', 'EL'),
 ('lia', 'lio'),
 ('din', 'der'),
 ('jek', 'jer'),
 ('szak', 'eryk'),
 (' wojennej', ' wojennych'),
 ('MP', 'IN'),
 ('kamie', 'mek'),
 ('eryka', 'eryk'),
 ('nty', 'chno'),
 ('KP', 'EL')]

Pary zapytanie–klucz dla warstwy 4, głowy 4 to pasujące do siebie odmiany (np. "ariusze" i "ariusza").

In [7]:
layer, head = 8, 4

# Query and key projections of the selected attention head.
W_Q_tmp = W_Q_heads[layer, head, :]
W_K_tmp = W_K_heads[layer, head, :]

# Query-key interaction W_Q @ W_K^T of the head, expressed over token ids on
# both axes.
tmp2 = emb_inv @ (W_Q_tmp @ W_K_tmp.T) @ emb_inv.T

# Candidate positions with the largest values (progress is printed).
all_high_pos2 = approx_topk(tmp2, th0=1, verbose=True)

# Last expression of the cell: the decoded top pairs are displayed as output.
get_top_entries(
    tmp2,
    all_high_pos2,
    only_ascii=True,
    only_alnum=True,
    exclude_same=True,
    tokens_list=None,
    tokenizer=tokenizer,
)

one more iteration. 0
one more iteration. 12
one more iteration. 8791


[('dzkimi', 'unkami'),
 (' ktora', ' ktorej'),
 ('dzkimi', 'nkami'),
 ('ings', 'tti'),
 ('nina', 'anina'),
 (' narodowymi', 'stkami'),
 ('perze', 'czniku'),
 (' katolickim', 'leckim'),
 ('zwykopem', 'czko'),
 ('ksem', 'nkiem'),
 ('lotem', 'jazdem'),
 ('usem', 'jazdem'),
 ('lnej', 'niowej'),
 ('zin', 'wei'),
 ('zwykopem', 'kosz'),
 ('inga', 'essa'),
 ('OWIE', 'Wer'),
 ('towanym', 'tkowski'),
 ('padzie', 'locie'),
 (' katolickim', 'tyzmem'),
 ('perze', 'cianie'),
 (' rodzinami', 'niakami'),
 ('burg', 'hoe'),
 ('dziu', 'niczki'),
 ('dzkimi', 'onami'),
 (' sumieniu', 'twor'),
 (' prokuratorem', 'pieniem'),
 ('dzkimi', 'chami'),
 ('usem', 'rusem'),
 ('padzie', 'pisie'),
 ('zdni', 'niowej'),
 ('loty', 'roli'),
 (' rodzinna', 'arska'),
 ('osobowa', 'nikowa'),
 ('onny', 'twor'),
 ('jsku', 'niuk'),
 ('szkiem', 'cznemu'),
 (' religijnym', 'nickim'),
 ('mieniu', 'manie'),
 ('gnieniu', 'nalnie'),
 (' katolickim', 'tage'),
 ('dziu', 'niczek'),
 ('grzech', 'manie'),
 ('dniem', 'nkiem'),
 (' warta', 

In [8]:
layer, head = 23, 4

# Query and key projections of the selected attention head.
W_Q_tmp = W_Q_heads[layer, head, :]
W_K_tmp = W_K_heads[layer, head, :]

# Query-key interaction W_Q @ W_K^T of the head, expressed over token ids on
# both axes.
tmp2 = emb_inv @ (W_Q_tmp @ W_K_tmp.T) @ emb_inv.T

# Candidate positions with the largest values (progress is printed).
all_high_pos2 = approx_topk(tmp2, th0=1, verbose=True)

# Last expression of the cell: the decoded top pairs are displayed as output.
get_top_entries(
    tmp2,
    all_high_pos2,
    only_ascii=True,
    only_alnum=True,
    exclude_same=True,
    tokens_list=None,
    tokenizer=tokenizer,
)

one more iteration. 0
one more iteration. 0
one more iteration. 17667


[(' wnim', 'pokoju'),
 (' wdomu', 'pokoju'),
 (' iroz', 'pokoju'),
 (' wna', 'pokoju'),
 (' wnim', 'di'),
 (' konsumenta', ' konsumen'),
 (' wtym', 'pokoju'),
 (' wnim', 'cala'),
 (' znim', 'pokoju'),
 (' wtej', 'tara'),
 (' wtej', 'pokoju'),
 (' wnim', 'pew'),
 (' wnim', 'des'),
 (' itak', 'pokoju'),
 (' wnim', 'tym'),
 (' Prezydent', ' RP'),
 (' wnim', 'tara'),
 (' iwy', 'pokoju'),
 (' itak', 'cala'),
 (' wdomu', 'tara'),
 (' apotem', 'tara'),
 (' wmo', 'pokoju'),
 (' wnim', 'tale'),
 (' otym', 'wszyscy'),
 (' iprze', 'pokoju'),
 (' io', 'pokoju'),
 (' regulowane', ' reguluje'),
 (' wjego', 'pokoju'),
 (' wnim', 'wiel'),
 (' znim', 'pew'),
 (' wtej', 'pew'),
 (' itak', 'di'),
 (' miasteczka', ' nadmor'),
 (' Prezydenta', ' RP'),
 (' znich', 'di'),
 (' itak', 'tara'),
 (' wnim', 'jego'),
 ('finans', ' inwestycyjne'),
 (' konsument', ' konsumen'),
 (' iz', 'cala'),
 (' stowarzyszenia', ' zrzesza'),
 (' wnim', 'dia'),
 (' wna', 'tara'),
 (' znim', 'di'),
 (' biznesie', ' biznes'),
 (' i