## Init

In [None]:
import torch
from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoModelForTokenClassification
from tabulate import tabulate
from tqdm import tqdm, trange
from copy import deepcopy
import numpy as np
from collections import Counter

## Helper Functions

In [None]:
# ASCII letters and digits. `top_tokens(only_alnum=True)` ignores every vocabulary
# entry containing a character outside this set (after stripping marker characters).
ALNUM_CHARSET = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789')

def convert_to_tokens(indices, tokenizer, extended=False, extra_values_pos=None, strip=True):
    """Map token ids to printable token strings.

    Args:
        indices: iterable of token ids (Python ints, NumPy ints or 0-d tensors).
        tokenizer: object providing ``convert_ids_to_tokens`` and ``__len__``.
        extended: when True, ids at or beyond ``len(tokenizer)`` are rendered as
            placeholders instead of being looked up: ``[pos<i>]`` for ids below
            ``extra_values_pos`` and ``[val<i>]`` for ids at or above it.
        extra_values_pos: first id of the appended "extra values" range. ``None``
            means there is no such range, so every out-of-vocabulary id is
            rendered as ``[pos<i>]``.
        strip: when True, drop a leading ``Ġ`` marker and prefix every other
            token (placeholders included) with ``#``.

    Returns:
        list[str] with one entry per index.
    """
    # Normalise to plain ints so placeholders never embed a tensor repr such as "tensor(3)".
    ids = [int(idx) for idx in indices]
    if extended:
        vocab_size = len(tokenizer)
        res = []
        for idx in ids:
            if idx < vocab_size:
                res.append(tokenizer.convert_ids_to_tokens([idx])[0])
            elif extra_values_pos is None or idx < extra_values_pos:
                res.append(f"[pos{idx - vocab_size}]")
            else:
                res.append(f"[val{idx - extra_values_pos}]")
    else:
        res = tokenizer.convert_ids_to_tokens(ids)
    if strip:
        # startswith() is safe for empty token strings, unlike indexing tok[0].
        res = [tok[1:] if tok.startswith('Ġ') else "#" + tok for tok in res]
    return res


def top_tokens(v, k=100, tokenizer=None, only_alnum=False, only_ascii=True, with_values=False,
               exclude_brackets=False, extended=True, extra_values=None, only_from_list=None):
    """Return the ``k`` highest-scoring tokens of a score vector indexed by token id.

    Args:
        v: 1-D tensor of scores; position ``i`` is the score of token id ``i``.
            The input is never modified.
        k: number of entries to return; clamped to the number of available scores.
        tokenizer: tokenizer exposing ``vocab`` (token -> id); defaults to the
            notebook-level ``my_tokenizer``.
        only_alnum: ignore tokens containing characters outside ``ALNUM_CHARSET``
            (after stripping ``Ġ▁[] ``).
        only_ascii: ignore tokens that are not ASCII (after stripping ``Ġ▁``).
        with_values: when True, return ``(token, score)`` pairs instead of tokens.
        exclude_brackets: see the NOTE(review) in the body.
        extended: forwarded to ``convert_to_tokens``.
        extra_values: optional 1-D tensor of extra scores appended after ``v``;
            their positions are rendered by ``convert_to_tokens``.
        only_from_list: optional collection of lower-cased tokens; when non-empty,
            tokens whose stripped, lower-cased form is not in it are ignored.

    Returns:
        list of token strings, or list of ``(token, score)`` tuples when
        ``with_values`` is True.
    """
    if tokenizer is None:
        tokenizer = my_tokenizer
    # Read the vocabulary once instead of once per filter.
    vocab = tokenizer.vocab
    # A set gives O(1) membership tests and also makes one-shot iterators safe to reuse.
    allowed = set(only_from_list) if only_from_list else None

    ignored = set()
    for tok, idx in vocab.items():
        if only_ascii and not tok.strip('Ġ▁').isascii():
            ignored.add(idx)
        elif only_alnum and not (set(tok.strip('Ġ▁[] ')) <= ALNUM_CHARSET):
            ignored.add(idx)
        elif allowed is not None and tok.strip('Ġ▁ ').lower() not in allowed:
            ignored.add(idx)
    if exclude_brackets:
        # NOTE(review): this *narrows* the ignore set to non-alphanumeric tokens (an
        # intersection, not a union). Behaviour is kept unchanged — confirm the intent.
        ignored &= {idx for tok, idx in vocab.items() if not (tok.isascii() and tok.isalnum())}

    # Work on a detached copy so the caller's tensor is never mutated.
    scores = v.detach().clone()
    if ignored:
        scores[list(ignored)] = -np.inf
    extra_values_pos = len(scores)
    if extra_values is not None:
        scores = torch.cat([scores, extra_values])
    # torch.topk raises when k exceeds the number of elements.
    k = min(k, len(scores))
    values, indices = torch.topk(scores, k=k)
    res = convert_to_tokens(indices, tokenizer, extended=extended, extra_values_pos=extra_values_pos)
    if with_values:
        res = list(zip(res, values.cpu().numpy()))
    return res

## Extract Weights

In [None]:
# Load the tokenizer with add_prefix_space=True and reuse the EOS token as the padding token.
tokenizer_gpt = AutoTokenizer.from_pretrained("sdadas/polish-gpt2-medium", add_prefix_space=True)
tokenizer_gpt.pad_token = tokenizer_gpt.eos_token
# NOTE(review): this loads the checkpoint with a token-classification head; neither
# `model_gpt` nor `tokenizer_gpt` is referenced by the later cells visible here —
# confirm this cell is still needed.
model_gpt = AutoModelForTokenClassification.from_pretrained("sdadas/polish-gpt2-medium", pad_token_id=tokenizer_gpt.pad_token_id)


Some weights of GPT2ForTokenClassification were not initialized from the model checkpoint at sdadas/polish-gpt2-medium and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Causal-LM checkpoint whose weights are inspected in the rest of the notebook.
model = AutoModelForCausalLM.from_pretrained("sdadas/polish-gpt2-medium")
# `my_tokenizer` is the default tokenizer picked up by top_tokens() when none is passed.
tokenizer = my_tokenizer = AutoTokenizer.from_pretrained("sdadas/polish-gpt2-medium")
# Transposed output-embedding matrix; later cells score tokens with `vector @ emb`.
# presumably shaped (hidden_dim, vocab) — TODO confirm
emb = model.get_output_embeddings().weight.data.T.detach()

# Architecture sizes read from the model config.
num_layers = model.config.n_layer
num_heads = model.config.n_head
hidden_dim = model.config.n_embd
head_size = hidden_dim // num_heads

# Feed-forward weights of every layer, concatenated along dim 0:
# K comes from the (transposed) mlp.c_fc weights, V from the mlp.c_proj weights.
K = torch.cat([model.get_parameter(f"transformer.h.{j}.mlp.c_fc.weight").T
                           for j in range(num_layers)]).detach()
V = torch.cat([model.get_parameter(f"transformer.h.{j}.mlp.c_proj.weight")
                           for j in range(num_layers)]).detach()

# attn.c_attn weights of every layer, concatenated along dim 0 and split into three
# equal chunks along the last dim (named query / key / value here).
W_Q, W_K, W_V = torch.cat([model.get_parameter(f"transformer.h.{j}.attn.c_attn.weight")
                           for j in range(num_layers)]).detach().chunk(3, dim=-1)
# attn.c_proj (attention output projection) weights of every layer, concatenated along dim 0.
W_O = torch.cat([model.get_parameter(f"transformer.h.{j}.attn.c_proj.weight")
                           for j in range(num_layers)]).detach()


In [None]:
# Split the layer-concatenated FF weights back into one (d_int, hidden_dim) slab per layer.
K_heads = K.reshape(num_layers, -1, hidden_dim)
V_heads = V.reshape(num_layers, -1, hidden_dim)
# Number of FF units per layer (the size inferred for the -1 dimension above).
d_int = K_heads.shape[1]

# Per-head views of the attention weights.
# Q/K/V: (num_layers, num_heads, hidden_dim, head_size) after the permute.
W_Q_heads = W_Q.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
W_K_heads = W_K.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
W_V_heads = W_V.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
# O: (num_layers, num_heads, head_size, hidden_dim).
W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)

In [None]:
# NOTE(review): despite the name, this is the transpose of `emb`, not a computed
# (pseudo-)inverse — presumably used as a stand-in for one; confirm.
emb_inv = emb.T

## Interpretation

#### Alternative I: No Token List

In [None]:
# An empty set is falsy, so top_tokens(only_from_list=tokens_list) and
# get_top_entries(tokens_list=tokens_list) apply no token-list filtering.
tokens_list = set()

In [None]:
!pip install datasets



#### Alternative II: Load Token List from a Dataset (`clarin-knext/wsd_polish_datasets`)

In [None]:
from datasets import load_dataset

In [None]:
# imdb = load_dataset('imdb')['train']['text']
# NOTE(review): despite the variable name, the active line loads the texts of the
# `clarin-knext/wsd_polish_datasets` train split, not IMDB.
# NOTE(review): trust_remote_code=True presumably lets the dataset repository's loading
# script execute locally — confirm the source is trusted.
imdb = load_dataset("clarin-knext/wsd_polish_datasets", trust_remote_code=True)['train']['text']

In [None]:
# None keeps every token seen in the corpus; an integer keeps only that many of the
# tokens occurring in the most texts (see the next cell).
max_tokens_num = None

In [None]:
# Collect the tokens that actually occur in the corpus.
if max_tokens_num is None:
    tokens_list = set()
    for txt in tqdm(imdb):
        # In-place update avoids rebuilding the whole set for every text.
        tokens_list.update(tokenizer.tokenize(txt))
else:
    # Count, for each token, the number of texts it appears in.
    tokens_list = Counter()
    for txt in tqdm(imdb):
        tokens_list.update(set(tokenizer.tokenize(txt)))
    # Materialise as a list: a lazy `map` object could only be consumed once.
    tokens_list = [tok for tok, _ in tokens_list.most_common(max_tokens_num)]


100%|██████████| 7848/7848 [00:15<00:00, 490.85it/s]


In [None]:
# Normalise tokens for matching: strip the marker characters and lower-case,
# de-duplicating through a set.
tokens_list = {tok.strip('Ġ▁').lower() for tok in tokens_list}

### FF Keys & Values

In [None]:
# Show the sizes that bound the (layer, FF-unit) indices chosen in the next cell.
num_layers, d_int
# (14, 1117) policja

(24, 4096)

In [None]:
# Pick a feed-forward unit: i1 is the layer index, i2 the unit index within that layer.
# To inspect a specific unit instead of a random one, replace the next line with
# e.g. `i1, i2 = 23, 907` (i2 must be smaller than d_int).
i1, i2 = np.random.randint(num_layers), np.random.randint(d_int)

print(i1, i2)
# Project the unit's key and value vectors onto the vocabulary and list the top tokens side by side.
print(tabulate([*zip(
    top_tokens((K_heads[i1, i2]) @ emb, k=30, only_from_list=tokens_list, only_alnum=False),
    top_tokens((V_heads[i1, i2]) @ emb, k=30, only_from_list=tokens_list, only_alnum=False),
    # top_tokens((-K_heads[i1, i2]) @ emb, k=200, only_from_list=tokens_list),
    # top_tokens((-V_heads[i1, i2]) @ emb, k=200, only_from_list=tokens_list),
)], headers=['K', 'V']))  # append '-K', '-V' when enabling the two negated columns above

18 1126
K           V
----------  --------------
#szczu      #zek
puszek      #asi
schud       #atio
#bary       #day
diecie      #dka
dieta       #zku
nagie       duchem
nawil       #alo
diete       #chanie
warzywa     #FF
#ontari     #chaj
#szycie     #gaze
wegetaria   troche
brzuch      naturalny
muszli      #dae
skarpetki   fizj
opakowanie  bior
moczu       #chom
kilogram    Fizy
#szeniami   publicznym
cera        intelektualnie
wychu       materii
ampu        #zej
wap         #forma
#szku       #dall
pokarmu     zachci
bielizny    figura
#szen       niee
skarpe      #zka
#szczem     #dac


### Attention Weights Interpretation

In [None]:
def approx_topk(mat, min_k=500, max_k=250_000, th0=10, max_iters=10, verbose=False):
    """Return the positions of roughly the largest entries of ``mat``.

    A threshold is searched such that the number of finite entries strictly above
    it lies in ``[min_k, max_k]``: the upper bound is doubled from ``th0`` until at
    most ``max_k`` entries remain, then the interval is bisected for at most
    ``max_iters`` steps.

    Args:
        mat: tensor to scan.
        min_k: smallest acceptable number of selected entries.
        max_k: largest acceptable number of selected entries.
        th0: initial upper threshold; must be positive if it has to be doubled.
        max_iters: maximum number of bisection steps. If the search does not
            converge, the last threshold tried is used, so the result may fall
            outside ``[min_k, max_k]``.
        verbose: print the entry count after every probe.

    Returns:
        list of index lists (one per selected entry), as produced by
        ``torch.nonzero(...).tolist()``.

    Raises:
        ValueError: if the threshold must grow but is not positive.
    """
    th_max = np.inf

    def _selected(th):
        return (mat > th) & (mat < th_max)

    def _count(th):
        # Summing the boolean mask avoids materialising every matching index just to count.
        return int(_selected(th).sum())

    left, right = 0, th0
    while True:
        actual_k = _count(right)
        if verbose:
            print(f"one more iteration. {actual_k}")
        if actual_k <= max_k:
            break
        if right <= 0:
            # Doubling a non-positive threshold never raises it; fail instead of looping forever.
            raise ValueError("th0 must be positive when more than max_k entries exceed it")
        left, right = right, right * 2

    # `right` always selects at most max_k entries; it is also the fallback when max_iters == 0.
    th = right
    if not (min_k <= actual_k <= max_k):
        for _ in range(max_iters):
            th = (left + right) / 2
            actual_k = _count(th)
            if verbose:
                print(f"one more iteration. {actual_k}")
            if min_k <= actual_k <= max_k:
                break
            if actual_k > max_k:
                left = th
            else:
                right = th
    return torch.nonzero(_selected(th)).tolist()

def get_top_entries(tmp, all_high_pos, only_ascii=False, only_alnum=False, exclude_same=False, exclude_fuzzy=False,
                    tokens_list=None, reverse_list=None, top_n=50):
    """Decode the strongest token pairs among pre-selected matrix positions.

    Args:
        tmp: 2-D tensor indexed by (token id, token id).
        all_high_pos: list of ``[row, col]`` positions to consider (e.g. the
            output of ``approx_topk``).
        only_ascii: keep a pair only if both decoded tokens are ASCII.
        only_alnum: keep a pair only if both decoded tokens are alphanumeric.
        exclude_same: drop pairs whose tokens are equal ignoring case and
            surrounding whitespace.
        exclude_fuzzy: drop pairs for which ``_fuzzy_eq`` returns a truthy value.
        tokens_list: optional collection; when non-empty, both normalised tokens
            must be contained in it.
        reverse_list: sort ascending (smallest value first) instead of
            descending. ``None`` falls back to the notebook-level global
            ``reverse_list`` (False if that is undefined), which is how the
            function was originally configured.
        top_n: maximum number of pairs returned.

    Returns:
        list of ``(token, token)`` string tuples, best first; empty when no
        position survives the filters.
    """
    if reverse_list is None:
        # Backward compatibility: earlier cells configure this through a global flag.
        reverse_list = globals().get('reverse_list', False)

    decoded = {}

    def _decode(idx):
        # Each id is decoded once; the filters below look the same ids up many times.
        idx = int(idx)
        if idx not in decoded:
            decoded[idx] = tokenizer.decode(idx)
        return decoded[idx]

    def _keep(pos):
        first, second = _decode(pos[0]), _decode(pos[1])
        if only_ascii and not (first.strip('Ġ▁').isascii() and second.strip('Ġ▁').isascii()):
            return False
        if only_alnum and not (first.strip('Ġ▁ ').isalnum() and second.strip('Ġ▁ ').isalnum()):
            return False
        low_first, low_second = first.lower().strip(), second.lower().strip()
        if exclude_same and low_first == low_second:
            return False
        # NOTE(review): `_fuzzy_eq` is not defined in the part of the notebook visible
        # here; exclude_fuzzy=True needs it to exist at call time.
        if exclude_fuzzy and _fuzzy_eq(low_first, low_second):
            return False
        if tokens_list and not ((first.strip('Ġ▁').lower().strip() in tokens_list) and
                                (second.strip('Ġ▁').lower().strip() in tokens_list)):
            return False
        return True

    remaining_pos = [pos for pos in all_high_pos if _keep(pos)]
    if not remaining_pos:
        return []

    rows, cols = zip(*remaining_pos)
    pos_val = tmp[list(rows), list(cols)]
    # Negating the values turns the ascending argsort into a descending ranking.
    order = torch.argsort(pos_val if reverse_list else -pos_val)[:top_n].tolist()
    return [(_decode(remaining_pos[i][0]), _decode(remaining_pos[i][1])) for i in order]

#### $W_{VO}$ Interpretation

Choose **layer** and **head** here:

In [None]:
# Show the sizes that bound the (layer, head) indices chosen in the next cell.
num_layers, num_heads

(24, 16)

In [None]:
# Pick a random (layer, head) pair; i1 indexes layers, i2 indexes heads.
i1, i2 = np.random.randint(num_layers), np.random.randint(num_heads)
# i1, i2 = 24, 9
# NOTE(review): layer indices run from 0 to num_layers - 1, so 24 would be out of range
# for the 24-layer model reported above.
i1, i2

(23, 12)

In [None]:
# Value and output projection slices of head i2 in layer i1.
W_V_tmp, W_O_tmp = W_V_heads[i1, i2, :], W_O_heads[i1, i2]
# Combined value-output map of the head, sandwiched between emb_inv and emb.
# NOTE(review): the result is presumably vocabulary x vocabulary and therefore very
# memory heavy — confirm before running on limited hardware.
tmp = (emb_inv @ (W_V_tmp @ W_O_tmp) @ emb)

In [None]:
# Collect the positions of the largest entries of `tmp`, starting the threshold search at 1.
all_high_pos = approx_topk(tmp, th0=1, verbose=True) # torch.nonzero((tmp > th) & (tmp < th_max)).tolist()

one more iteration. 0
one more iteration. 0
one more iteration. 31
one more iteration. 11149


In [None]:
# Filter / ordering options for get_top_entries().
# Note: `reverse_list` is not passed as an argument below; get_top_entries() reads it as a global.
exclude_same = False
reverse_list = False
only_ascii = True
only_alnum = False

In [None]:
# Decode the strongest token pairs of the W_VO interaction matrix; tokens_list=None disables the token-list filter.
get_top_entries(tmp, all_high_pos, only_ascii=only_ascii, only_alnum=only_alnum,
                exclude_same=exclude_same, tokens_list=None)

[('zwykopem', 'dzaj'),
 ('zwykopem', 'duj'),
 ('zwykopem', 'gaj'),
 ('zwykopem', 'wiaj'),
 ('zwykopem', 'laj'),
 ('emy', 'dzaj'),
 ('zwykopem', 'minaj'),
 ('zwykopem', 'aj'),
 ('mady', ' no'),
 ('zwykopem', 'kuj'),
 ('miona', ' I'),
 ('zwykopem', 'kaj'),
 ('naugu', ' I'),
 ('zwykopem', ' szukaj'),
 ('kia', ' no'),
 ('naczej', ' I'),
 ('zwykopem', 'nuj'),
 ('%', ' 2'),
 ('%,', ' 2'),
 ('zwykopem', 'chaj'),
 ('zwykopem', 'raj'),
 ('Query', 'j'),
 ('%.', ' 2'),
 ('emy', 'gaj'),
 ('zwykopem', ' moj'),
 ('zwykopem', 'zuj'),
 ('emy', 'waj'),
 ('zwykopem', 'daj'),
 ('\x01', 'laj'),
 ('%,', ' 3'),
 ('zwykopem', 'czaj'),
 ('emy', 'kuj'),
 ('zwykopem', 'tuj'),
 ('%', ' 3'),
 ('\x01', 'wiaj'),
 ('zwykopem', ' j'),
 ('lle', 'c'),
 ('%,', ' 4'),
 ('%.', ' 3'),
 ('zwykopem', ' patrz'),
 ('\x01', 'dzaj'),
 ('zwykopem', ' ej'),
 ('Query', 'gaj'),
 ('DF', ' 2'),
 ('emy', 'duj'),
 ('zwykopem', 'ruj'),
 ('\x01', ' j'),
 ('ngu', ' np'),
 ('zwykopem', 'caj'),
 ('%.', ' 4')]

#### $W_{QK}$ Interpretation

Choose **layer** and **head** here:

In [None]:
# Pick a random (layer, head) pair; uncomment the line below to pin a specific head.
i1, i2 = np.random.randint(num_layers), np.random.randint(num_heads)
# i1, i2 = 20, 13
i1, i2

(7, 6)

In [None]:
# Query and key projection slices of head i2 in layer i1.
W_Q_tmp, W_K_tmp = W_Q_heads[i1, i2, :], W_K_heads[i1, i2, :]
# Combined query-key map of the head, with emb_inv applied on both sides.
# NOTE(review): presumably vocabulary x vocabulary and therefore very memory heavy — confirm.
tmp2 = (emb_inv @ (W_Q_tmp @ W_K_tmp.T) @ emb_inv.T)

In [None]:
# Collect the positions of the largest entries of `tmp2`; this overwrites the W_VO positions computed earlier.
all_high_pos = approx_topk(tmp2, th0=1, verbose=True) # torch.nonzero((tmp2 > th2) & (tmp2 < th_max2)).tolist()

one more iteration. 0
one more iteration. 0
one more iteration. 0
one more iteration. 176
one more iteration. 574796
one more iteration. 6560


In [None]:
# Filter / ordering options for get_top_entries().
# Note: `reverse_list` is not passed as an argument below; get_top_entries() reads it as a global.
exclude_same = False
reverse_list = False
only_ascii = True
only_alnum = True

In [None]:
# Decode the strongest token pairs of the W_QK interaction matrix, restricted to tokens in tokens_list
# (no restriction applies if tokens_list is empty).
get_top_entries(tmp2, all_high_pos, only_ascii=only_ascii, only_alnum=only_alnum, exclude_same=exclude_same,
                tokens_list=tokens_list)

[('dex', 'loka'),
 ('polsce', 'tka'),
 ('logie', 'tka'),
 ('use', 'arian'),
 ('struktura', 'tka'),
 ('ink', 'mer'),
 ('erytory', 'tka'),
 ('ink', 'emon'),
 ('ink', 'ole'),
 ('ink', 'mera'),
 ('niusz', 'FE'),
 ('nku', 'mera'),
 ('gas', 'OS'),
 ('PS', 'lika'),
 ('use', 'ange'),
 ('ologie', 'tka'),
 ('zmat', 'ange'),
 ('odu', 'stolatek'),
 ('kina', 'ange'),
 ('ymu', 'ME'),
 (' Lewica', 'tka'),
 ('sin', 'czko'),
 ('struktury', 'tka'),
 ('uze', 'lika'),
 ('ktura', 'tka'),
 ('noty', 'MIE'),
 ('ins', 'ole'),
 ('cit', 'loka'),
 ('owz', 'Mistrz'),
 ('Budowa', 'tka'),
 ('sel', 'OS'),
 ('ins', 'mer'),
 ('php', 'rus'),
 ('linga', 'gali'),
 ('use', 'lika'),
 ('osci', 'tka'),
 ('szenko', 'siedli'),
 ('nizacji', 'tka'),
 ('RN', 'gaj'),
 ('BN', 'gdo'),
 ('Hen', 'ange'),
 ('kil', 'loka'),
 ('Np', 'tka'),
 ('pr', 'gdo'),
 ('RN', 'gdo'),
 ('kum', 'tka'),
 ('ect', 'siedli'),
 ('VD', 'tka'),
 ('ink', 'siedli'),
 ('lino', 'mera')]

## Plots

*We thank Ohad Rubin for the idea of providing plots for better visualizations!*

In [None]:
# (layer, FF-unit) whose key vector is plotted at the end of this section.
i1, i2 = 6, 2152

In [None]:
from sklearn.manifold import TSNE
import pandas as pd
import plotly.express as px

In [None]:
def _calc_df(vector, k, coef, normalized, tokenizer):
    """Build a DataFrame of 3-D t-SNE coordinates for the tokens best matching ``vector``.

    Args:
        vector: tensor multiplied with the notebook-level ``emb`` matrix to score tokens.
        k: number of top-scoring tokens to embed.
        coef: multiplier applied to the scores before ranking (e.g. -1 ranks the
            lowest raw scores first).
        normalized: when True, ``emb`` is passed through ``F.normalize(dim=-1)`` first.
        tokenizer: tokenizer used to label the points; ``None`` falls back to the
            notebook-level ``my_tokenizer`` (same convention as ``top_tokens``).

    Returns:
        pandas.DataFrame with columns ``x``, ``y``, ``z``, ``label`` and ``score``.
    """
    if tokenizer is None:
        # plot_embedding_space() defaults to tokenizer=None, which previously failed
        # inside convert_to_tokens().
        tokenizer = my_tokenizer
    mat = emb
    if normalized:
        # NOTE(review): dim=-1 normalises along the last axis of `emb`; if `emb` is
        # (hidden_dim, vocab) this is not a per-token normalisation — confirm the intent.
        mat = F.normalize(mat, dim=-1)
    dot = vector @ mat
    sol = torch.topk(dot * coef, k=k).indices # np.argsort(dot * coef)[-k:]
    # Convert explicitly: sklearn / pandas cannot consume CUDA or grad-tracking tensors.
    pattern = mat[:, sol].T.detach().cpu().numpy()
    scores = (coef * dot[sol]).detach().cpu().numpy()
    # labels = tokenizer.batch_decode(sol)
    labels = convert_to_tokens(sol, tokenizer=tokenizer)
    X_embedded = TSNE(n_components=3,
                  learning_rate=10,
                   init='pca',
                   perplexity=3).fit_transform(pattern)

    df = pd.DataFrame(dict(x=X_embedded.T[0], y=X_embedded.T[1], z=X_embedded.T[2], label=labels, score=scores))
    return df


def plot_embedding_space(vector, is_3d=False, add_text=False, k=100, coef=1, normalized=False, tokenizer=None):
    """Show an interactive scatter plot of the tokens best matching ``vector``.

    The point coordinates, labels and scores come from ``_calc_df``; points are
    coloured by score and the hover box shows coordinates, label and score.

    Args:
        vector: vector forwarded to ``_calc_df``.
        is_3d: draw a 3-D scatter (adds the ``z`` column) instead of a 2-D one.
        add_text: print each token label next to its point.
        k, coef, normalized, tokenizer: forwarded unchanged to ``_calc_df``.
    """
    points = _calc_df(vector, k=k, coef=coef, normalized=normalized, tokenizer=tokenizer)

    # Optional plotly arguments depending on the requested layout.
    extra_args = {}
    if add_text:
        extra_args['text'] = 'label'
    if is_3d:
        extra_args['z'] = 'z'
    make_scatter = px.scatter_3d if is_3d else px.scatter

    hover_lines = [
        "ColX: %{x}",
        "ColY: %{y}",
        "label: %{customdata[0]}",
        "score: %{customdata[1]}",
    ]

    fig = make_scatter(data_frame=points, x='x', y='y',
                       custom_data=["label", "score"],
                       color="score", size_max=1, **extra_args)
    fig.update_traces(hovertemplate="<br>".join(hover_lines))

    if add_text:
        fig.update_traces(textposition='middle right')
    fig.show()

In [None]:
# Plot the tokens best matching the key vector of FF unit i2 in layer i1 (2-D, unnormalised embeddings).
plot_embedding_space(K_heads[i1][i2], tokenizer=tokenizer, normalized=False)