In [1]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, TrainingArguments, Trainer, EarlyStoppingCallback
from peft import AutoPeftModelForTokenClassification
from datasets import load_dataset
from glob import glob

In [2]:
dataset = load_dataset("clarin-knext/wsd_polish_datasets", trust_remote_code=True)

README.md:   0%|          | 0.00/11.0k [00:00<?, ?B/s]

wsd_polish_datasets.py:   0%|          | 0.00/7.94k [00:00<?, ?B/s]

sherlock_text.jsonl:   0%|          | 0.00/2.25M [00:00<?, ?B/s]

skladnica_text.jsonl:   0%|          | 0.00/29.0M [00:00<?, ?B/s]

wikiglex_text.jsonl:   0%|          | 0.00/12.0M [00:00<?, ?B/s]

emoglex_text.jsonl:   0%|          | 0.00/23.1M [00:00<?, ?B/s]

walenty_text.jsonl:   0%|          | 0.00/50.6M [00:00<?, ?B/s]

kpwr_text.jsonl:   0%|          | 0.00/57.4M [00:00<?, ?B/s]

kpwr-100_text.jsonl:   0%|          | 0.00/8.02M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [3]:
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'tokens', 'phrases', 'wsd'],
        num_rows: 7848
    })
})

In [4]:
dataset['train'].features['tokens'].feature['pos']

Value(dtype='string', id=None)

In [7]:
tokenizer_bert = AutoTokenizer.from_pretrained("allegro/herbert-base-cased")
model_bert = AutoModelForTokenClassification.from_pretrained("allegro/herbert-base-cased")

tokenizer_gpt = AutoTokenizer.from_pretrained("sdadas/polish-gpt2-medium", add_prefix_space=True)
tokenizer_gpt.pad_token = tokenizer_gpt.eos_token
model_gpt = AutoModelForTokenClassification.from_pretrained("sdadas/polish-gpt2-medium", pad_token_id=tokenizer_gpt.pad_token_id)

pytorch_model.bin:  64%|######4   | 419M/654M [00:00<?, ?B/s]

Some weights of BertForTokenClassification were not initialized from the model checkpoint at allegro/herbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tokenizer_config.json:   0%|          | 0.00/321 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.34M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/837 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.53G [00:00<?, ?B/s]

Some weights of GPT2ForTokenClassification were not initialized from the model checkpoint at sdadas/polish-gpt2-medium and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
sentence = dataset['train'][0]['text']

In [9]:
def get_embeddings(text, tokenizer, model, layer=-1):
    inputs = tokenizer(text, return_tensors="pt")
    outputs = model(**inputs,output_hidden_states=True)
    # print(outputs)
    # x = outputs.hidden_states[-1][:, 0, :].cpu().detach().numpy().squeeze()
    x = outputs.hidden_states[layer].cpu().detach().numpy().squeeze()  

    return x
    # return outputs.last_hidden_state.squeeze(0).detach()  # Embeddings for each token

# Example for a sentence
sentence = dataset['train'][0]['text']  # Replace with the correct column
bert_embeddings = get_embeddings(sentence, tokenizer_bert, model_bert, layer=-1)
gpt2_embeddings = get_embeddings(sentence, tokenizer_gpt, model_gpt, layer=-1)

In [10]:
from scipy.spatial.distance import cosine
import torch

def measure_anisotropy(embeddings):
    # Compute cosine similarities for pairs of embeddings
    cos_similarities = []
    num_samples = 1000  # Adjust for sampling efficiency

    for _ in range(num_samples):
        # print(embeddings.size)
        idx1, idx2 = torch.randint(0, embeddings.shape[0], (2,))
        emb1 = embeddings[idx1]
        # print(emb1)
        sim = 1 - cosine(embeddings[idx1], embeddings[idx2])
        cos_similarities.append(sim)
    
    return sum(cos_similarities) / len(cos_similarities)

bert_anisotropy = measure_anisotropy(bert_embeddings)
gpt2_anisotropy = measure_anisotropy(gpt2_embeddings)
print("BERT Anisotropy:", bert_anisotropy)
print("GPT-2 Anisotropy:", gpt2_anisotropy)

BERT Anisotropy: 0.7217383639671769
GPT-2 Anisotropy: 0.24481897037399522


In [11]:
num_layers_bert = len(model_bert.bert.encoder.layer)
num_layers_bert

12

In [12]:
num_layers_gpt = len(model_gpt.transformer.h)
num_layers_gpt

24

In [13]:
bert_anisotropies = []
for i in range(num_layers_bert):
    bert_embeddings = get_embeddings(sentence, tokenizer_bert, model_bert, layer=i)
    anisotropy = measure_anisotropy(bert_embeddings)
    bert_anisotropies.append(anisotropy)

gpt_anisotropies = []
for i in range(num_layers_gpt):
    gpt_embeddings = get_embeddings(sentence, tokenizer_gpt, model_gpt, layer=i)
    anisotropy = measure_anisotropy(gpt_embeddings)
    gpt_anisotropies.append(anisotropy)

In [14]:
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd

data = {
    'Layer': list(range(num_layers_bert)) + list(range(num_layers_gpt)),
    'Anisotropy': bert_anisotropies + gpt_anisotropies,
    'Model': ['BERT'] * num_layers_bert + ['GPT-2'] * num_layers_gpt
}

df = pd.DataFrame(data)

fig = go.Figure()

fig.add_trace(go.Scatter(x=df[df['Model'] == 'BERT']['Layer'], 
                         y=df[df['Model'] == 'BERT']['Anisotropy'],
                         mode='markers+lines',
                         name='BERT',
                         line=dict(shape='linear', dash='dot')))

fig.add_trace(go.Scatter(x=df[df['Model'] == 'GPT-2']['Layer'], 
                         y=df[df['Model'] == 'GPT-2']['Anisotropy'],
                         mode='markers+lines',
                         name='GPT-2',
                         line=dict(shape='linear', dash='dot')))


fig.update_layout(
    title="Anisotropy Comparison: BERT vs GPT-2",
    xaxis_title="Layer Number",
    yaxis_title="Anisotropy Value",
    legend_title="Model",
)

fig.show()


In [15]:
sentences = dataset['train'][:10]['text']  # Replace with the correct column

In [16]:
joined = " ".join(sentences)

In [17]:
def context_specificity(token, dataset, tokenizer, model, layer=-1):
    embeddings = []
    texts = dataset['text']
    for example in texts:
        try:
            inputs = tokenizer(example, return_tensors="pt")
            outputs = model(**inputs, output_hidden_states=True).hidden_states[layer].squeeze(0).detach()
        except:
            continue

        # Get index of token in the text
        token_id = tokenizer.convert_tokens_to_ids(token)
        if token_id in inputs['input_ids']:
            token_index = (inputs['input_ids'] == token_id).nonzero(as_tuple=True)[1]
            embeddings.append(outputs[token_index].mean(0))  # Averaging over token occurrences

    # Compute average cosine similarity between each pair of embeddings
    cos_similarities = []
    for i in range(len(embeddings)):
        for j in range(i + 1, len(embeddings)):
            sim = 1 - cosine(embeddings[i], embeddings[j])
            cos_similarities.append(sim)

    return sum(cos_similarities) / len(cos_similarities) if cos_similarities else None

# Example usage
bert_context_specificity = context_specificity("nie", dataset['train'][:100], tokenizer_bert, model_bert)
gpt2_context_specificity = context_specificity("nie", dataset['train'][:100], tokenizer_gpt, model_gpt)

print("BERT Context-Specificity:", bert_context_specificity)
print("GPT-2 Context-Specificity:", gpt2_context_specificity)


Token indices sequence length is longer than the specified maximum sequence length for this model (1781 > 512). Running this sequence through the model will result in indexing errors


BERT Context-Specificity: 0.8382445823264426
GPT-2 Context-Specificity: 0.24329707611341647


In [18]:
bert_context = []
for i in range(num_layers_bert):
    context = context_specificity("nie", dataset['train'][:100], tokenizer_bert, model_bert, layer=i)
    bert_context.append(context)

gpt_context = []
for i in range(num_layers_gpt):
    context = context_specificity("nie", dataset['train'][:100], tokenizer_gpt, model_gpt, layer=i)
    gpt_context.append(context)

In [None]:
data = {
    'Layer': list(range(num_layers_bert)) + list(range(num_layers_gpt)),
    'Context': bert_context + gpt_context,
    'Model': ['BERT'] * num_layers_bert + ['GPT-2'] * num_layers_gpt
}

df = pd.DataFrame(data)

fig = go.Figure()

fig.add_trace(go.Scatter(x=df[df['Model'] == 'BERT']['Layer'], 
                         y=df[df['Model'] == 'BERT']['Context'],
                         mode='markers+lines',
                         name='BERT',
                         line=dict(shape='linear', dash='dot')))

fig.add_trace(go.Scatter(x=df[df['Model'] == 'GPT-2']['Layer'], 
                         y=df[df['Model'] == 'GPT-2']['Context'],
                         mode='markers+lines',
                         name='GPT-2',
                         line=dict(shape='linear', dash='dot')))


fig.update_layout(
    title="Context-Specificity Comparison: BERT vs GPT-2",
    xaxis_title="Layer Index",
    yaxis_title="Context-Specificity Value",
    legend_title="Model",
)

fig.show()

# Parameter Projection


In [27]:
import torch
from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from tabulate import tabulate
from tqdm import tqdm, trange
from copy import deepcopy
import numpy as np
from collections import Counter

Helper function

In [28]:
ALNUM_CHARSET = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789')

def convert_to_tokens(indices, tokenizer, extended=False, extra_values_pos=None, strip=True):
    if extended:
        res = [tokenizer.convert_ids_to_tokens([idx])[0] if idx < len(tokenizer) else 
               (f"[pos{idx-len(tokenizer)}]" if idx < extra_values_pos else f"[val{idx-extra_values_pos}]") 
               for idx in indices]
    else:
        res = tokenizer.convert_ids_to_tokens(indices)
    if strip:
        res = list(map(lambda x: x[1:] if x[0] == 'Ġ' else "#" + x, res))
    return res


def top_tokens(v, k=100, tokenizer=None, only_alnum=False, only_ascii=True, with_values=False, 
               exclude_brackets=False, extended=True, extra_values=None, only_from_list=None):
    if tokenizer is None:
        tokenizer = my_tokenizer
    v = deepcopy(v)
    ignored_indices = []
    if only_ascii:
        ignored_indices.extend([key for val, key in tokenizer.vocab.items() if not val.strip('Ġ▁').isascii()])
    if only_alnum: 
        ignored_indices.extend([key for val, key in tokenizer.vocab.items() if not (set(val.strip('Ġ▁[] ')) <= ALNUM_CHARSET)])
    if only_from_list:
        ignored_indices.extend([key for val, key in tokenizer.vocab.items() if val.strip('Ġ▁ ').lower() not in only_from_list])
    if exclude_brackets:
        ignored_indices = set(ignored_indices).intersection(
            {key for val, key in tokenizer.vocab.items() if not (val.isascii() and val.isalnum())})
        ignored_indices = list(ignored_indices)
        
    ignored_indices = list(set(ignored_indices))
    v[ignored_indices] = -np.inf
    extra_values_pos = len(v)
    if extra_values is not None:
        v = torch.cat([v, extra_values])
    values, indices = torch.topk(v, k=k)
    res = convert_to_tokens(indices, tokenizer, extended=extended, extra_values_pos=extra_values_pos)
    if with_values:
        res = list(zip(res, values.cpu().numpy()))
    return res

Extract Weights GPT

In [62]:
# tokenizer_gpt = AutoTokenizer.from_pretrained(
#     "sdadas/polish-gpt2-medium", add_prefix_space=True
# )
# tokenizer_gpt.pad_token = tokenizer_gpt.eos_token
# model_gpt = AutoModelForCausalLM.from_pretrained(
#     "sdadas/polish-gpt2-medium", pad_token_id=tokenizer_gpt.pad_token_id
# )


model = AutoModelForCausalLM.from_pretrained("sdadas/polish-gpt2-medium")
tokenizer = my_tokenizer = AutoTokenizer.from_pretrained("sdadas/polish-gpt2-medium")
emb = model.get_output_embeddings().weight.data.T.detach()
model.config

GPT2Config {
  "_name_or_path": "sdadas/polish-gpt2-medium",
  "activation_function": "gelu_fast",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 0,
  "embd_pdrop": 0.1,
  "eos_token_id": 2,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_embd": 1024,
  "n_head": 16,
  "n_inner": 4096,
  "n_layer": 24,
  "n_positions": 2048,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "tokenizer_class": "GPT2TokenizerFast",
  "torch_dtype": "float32",
  "transformers_version": "4.45.2",
  "use_cache": true,
  "vocab_size": 51200
}

In [None]:

num_layers = model.config.n_layer
num_heads = model.config.n_head
hidden_dim = model.config.n_embd
head_size = hidden_dim // num_heads

K = torch.cat(
    [
        model.get_parameter(f"transformer.h.{j}.mlp.c_fc.weight").T
        for j in range(num_layers)
    ]
).detach()
V = torch.cat(
    [
        model.get_parameter(f"transformer.h.{j}.mlp.c_proj.weight")
        for j in range(num_layers)
    ]
).detach()

W_Q, W_K, W_V = (
    torch.cat(
        [
            model.get_parameter(f"transformer.h.{j}.attn.c_attn.weight")
            for j in range(num_layers)
        ]
    )
    .detach()
    .chunk(3, dim=-1)
)
W_O = torch.cat(
    [
        model.get_parameter(f"transformer.h.{j}.attn.c_proj.weight")
        for j in range(num_layers)
    ]
).detach()

In [30]:
K_heads = K.reshape(num_layers, -1, hidden_dim)
V_heads = V.reshape(num_layers, -1, hidden_dim)
d_int = K_heads.shape[1]

W_Q_heads = W_Q.reshape(num_layers, hidden_dim, num_heads, head_size).permute(
    0, 2, 1, 3
)
W_K_heads = W_K.reshape(num_layers, hidden_dim, num_heads, head_size).permute(
    0, 2, 1, 3
)
W_V_heads = W_V.reshape(num_layers, hidden_dim, num_heads, head_size).permute(
    0, 2, 1, 3
)
W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)

In [31]:
emb_inv = emb.T

In [32]:
tokens_list = set()

In [33]:
i1, i2 = 23, 907

In [34]:
print(i1, i2)

23 907


In [36]:
print(
    tabulate(
        [
            *zip(
                top_tokens(
                    (K_heads[i1, i2]) @ emb,
                    k=30,
                    only_from_list=tokens_list,
                    only_alnum=False,
                ),
                top_tokens(
                    (V_heads[i1, i2]) @ emb,
                    k=30,
                    only_from_list=tokens_list,
                    only_alnum=False,
                ),
                top_tokens((-K_heads[i1, i2]) @ emb, k=200, only_from_list=tokens_list),
                top_tokens((-V_heads[i1, i2]) @ emb, k=200, only_from_list=tokens_list),
            )
        ],
        headers=["K", "V", "-K", "-V"],
    )
)

K          V              -K          -V
---------  -------------  ----------  ---------
przycu     kody           dotychczas  #bot
zalog      #gory          rodzi       #lot
wylegi     #ei            do          #ush
#walifik   #zmy           przebie     #dzista
przesp     Apo            gatunku     #remont
pochowany  apokali        przez       #lee
#cket      #128           #ja         #wan
#ppe       ludy           rodzin      #ette
wep        #ords          zrazu       #up
sfinans    Cezary         lokalnie    #bul
#iss       archa          porywa      #puszczam
#CS        przy           two         spu
#gny       akcy           jeszcze     zamyka
#zwol      Homo           #jaw        odstawi
#-).       litera         na          #mont
lock       #pka           wymaga      #beki
skonfisk   Cezar          ty          #laks
#post      Benedykt       Drze        tap
postoju    narodem        pod         poby
erek       polityki       Nie         posto
#ionu      symbolu        rzuca  

In [37]:
def approx_topk(mat, min_k=500, max_k=250_000, th0=10, max_iters=10, verbose=False):
    _get_actual_k = lambda th, th_max: torch.nonzero((mat > th) & (mat < th_max)).shape[
        0
    ]
    th_max = np.inf
    left, right = 0, th0
    while True:
        actual_k = _get_actual_k(right, th_max)
        if verbose:
            print(f"one more iteration. {actual_k}")
        if actual_k <= max_k:
            break
        left, right = right, right * 2
    if min_k <= actual_k <= max_k:
        th = right
    else:
        for _ in range(max_iters):
            mid = (left + right) / 2
            actual_k = _get_actual_k(mid, th_max)
            if verbose:
                print(f"one more iteration. {actual_k}")
            if min_k <= actual_k <= max_k:
                break
            if actual_k > max_k:
                left = mid
            else:
                right = mid
        th = mid
    return torch.nonzero((mat > th) & (mat < th_max)).tolist()


def get_top_entries(
    tmp,
    all_high_pos,
    only_ascii=False,
    only_alnum=False,
    exclude_same=False,
    exclude_fuzzy=False,
    tokens_list=None,
):
    remaining_pos = all_high_pos
    if only_ascii:
        remaining_pos = [
            *filter(
                lambda x: (
                    tokenizer.decode(x[0]).strip("Ġ▁").isascii()
                    and tokenizer.decode(x[1]).strip("Ġ▁").isascii()
                ),
                remaining_pos,
            )
        ]
    if only_alnum:
        remaining_pos = [
            *filter(
                lambda x: (
                    tokenizer.decode(x[0]).strip("Ġ▁ ").isalnum()
                    and tokenizer.decode(x[1]).strip("Ġ▁ ").isalnum()
                ),
                remaining_pos,
            )
        ]
    if exclude_same:
        remaining_pos = [
            *filter(
                lambda x: tokenizer.decode(x[0]).lower().strip()
                != tokenizer.decode(x[1]).lower().strip(),
                remaining_pos,
            )
        ]
    if exclude_fuzzy:
        remaining_pos = [
            *filter(
                lambda x: not _fuzzy_eq(
                    tokenizer.decode(x[0]).lower().strip(),
                    tokenizer.decode(x[1]).lower().strip(),
                ),
                remaining_pos,
            )
        ]
    if tokens_list:
        remaining_pos = [
            *filter(
                lambda x: (
                    (tokenizer.decode(x[0]).strip("Ġ▁").lower().strip() in tokens_list)
                    and (
                        tokenizer.decode(x[1]).strip("Ġ▁").lower().strip()
                        in tokens_list
                    )
                ),
                remaining_pos,
            )
        ]

    pos_val = tmp[[*zip(*remaining_pos)]]
    good_cells = [
        *map(lambda x: (tokenizer.decode(x[0]), tokenizer.decode(x[1])), remaining_pos)
    ]
    good_tokens = list(map(lambda x: Counter(x).most_common(), zip(*good_cells)))
    remaining_pos_best = np.array(remaining_pos)[
        torch.argsort(pos_val if reverse_list else -pos_val)[:50]
    ]
    good_cells_best = [
        *map(
            lambda x: (tokenizer.decode(x[0]), tokenizer.decode(x[1])),
            remaining_pos_best,
        )
    ]
    # good_cells[:100]
    # list(zip(good_tokens[0], good_tokens[1]))
    return good_cells_best

Wvo Interpretation

In [50]:
i1, i2 = np.random.randint(num_layers), np.random.randint(num_heads)
i1, i2 = 21, 7
i1, i2

(21, 7)

In [51]:
W_V_tmp, W_O_tmp = W_V_heads[i1, i2, :], W_O_heads[i1, i2]
tmp = emb_inv @ (W_V_tmp @ W_O_tmp) @ emb

In [52]:
all_high_pos = approx_topk(
    tmp, th0=1, verbose=True
)  # torch.nonzero((tmp > th) & (tmp < th_max)).tolist()

one more iteration. 0
one more iteration. 0
one more iteration. 10
one more iteration. 5182


In [53]:
exclude_same = False
reverse_list = False
only_ascii = True
only_alnum = False

In [54]:
get_top_entries(
    tmp,
    all_high_pos,
    only_ascii=only_ascii,
    only_alnum=only_alnum,
    exclude_same=exclude_same,
    tokens_list=None,
)

[(' ranem', ' nad'),
 ('ornie', ' przez'),
 ('datek', ' nad'),
 ('arl', ' przed'),
 (' pewno', ' na'),
 ('spodziewanie', ' nad'),
 (' razu', ' od'),
 ('miernie', ' nad'),
 (' okazji', ' przy'),
 (' wsk', ' na'),
 (' czele', ' na'),
 ('orne', ' przez'),
 ('godziny', ' nad'),
 ('sione', ' przed'),
 (' ranem', 'nad'),
 ('natural', ' nad'),
 ('przewo', ' nad'),
 (' Duna', ' nad'),
 ('ktory', ' do'),
 ('miar', ' nad'),
 (' dobra', ' dla'),
 (' Jezi', ' nad'),
 ('spodzie', ' nad'),
 (' niedawna', ' do'),
 ('pisie', ' pod'),
 (' wygody', ' dla'),
 ('arcie', ' przed'),
 (' sumie', ' w'),
 ('miernie', 'nad'),
 ('tek', ' pod'),
 ('wcze', ' przed'),
 ('mier', ' nad'),
 ('hala', ' pod'),
 ('ornie', ' przeze'),
 (' uboczu', ' na'),
 (' podstawie', ' na'),
 (' ranem', 'Nad'),
 ('czesne', ' do'),
 ('granicznych', ' nad'),
 ('wiska', ' przez'),
 (' barkach', ' na'),
 ('ornie', 'przez'),
 (' odmiany', ' dla'),
 ('niego', ' przed'),
 (' plecami', ' za'),
 (' dobi', ' na'),
 ('przewodni', ' nad'),
 (' ko

In [56]:
i1, i2 = 18, 2
i1, i2
W_V_tmp, W_O_tmp = W_V_heads[i1, i2, :], W_O_heads[i1, i2]
tmp = emb_inv @ (W_V_tmp @ W_O_tmp) @ emb
all_high_pos = approx_topk(
    tmp, th0=1, verbose=True
)  # torch.nonzero((tmp > th) & (tmp < th_max)).tolist()
exclude_same = False
reverse_list = False
only_ascii = True
only_alnum = False
get_top_entries(
    tmp,
    all_high_pos,
    only_ascii=only_ascii,
    only_alnum=only_alnum,
    exclude_same=exclude_same,
    tokens_list=None,
)

one more iteration. 0
one more iteration. 0
one more iteration. 1738


[('was', '-'),
 ('was', '--'),
 ('lla', '-'),
 ('heim', '-'),
 ('lle', '-'),
 ('gonie', '-'),
 ('lla', '--'),
 ('ye', '-'),
 ('osobowej', '-'),
 ('go', '-'),
 ('has', '-'),
 ('zym', '-'),
 ('tino', '-'),
 ('kowiec', '-'),
 (' Stanu', '-'),
 ('zji', '-'),
 ('lowie', '-'),
 ('procent', '-'),
 ('lle', '-.'),
 ('stanu', '-'),
 ('lle', '--'),
 ('123', '-'),
 ('gos', '-'),
 ('head', '-'),
 ('fil', '--'),
 ('lla', '-.'),
 ('kowca', '-'),
 ('lah', '-'),
 ('lah', '--'),
 ('was', '"-'),
 ('fil', '-'),
 ('krotnie', '-'),
 ('dno', '-'),
 ('top', '-'),
 ('lla', '->'),
 ('sbur', '-'),
 ('czycy', '-'),
 ('dzkiego', '-'),
 ('pii', '-'),
 ('bak', '-'),
 ('lowie', '--'),
 ('lski', '-'),
 ('lit', '-'),
 ('has', '--'),
 ('ben', '-'),
 ('lla', '-)'),
 ('ls', '-'),
 ('tino', '--'),
 ('lli', '-'),
 ('zji', '--')]

BERT - test

In [96]:
# tokenizer_gpt = AutoTokenizer.from_pretrained(
#     "sdadas/polish-gpt2-medium", add_prefix_space=True
# )
# tokenizer_gpt.pad_token = tokenizer_gpt.eos_token
# model_gpt = AutoModelForCausalLM.from_pretrained(
#     "sdadas/polish-gpt2-medium", pad_token_id=tokenizer_gpt.pad_token_id
# )

model = AutoModelForCausalLM.from_pretrained("allegro/herbert-base-cased")
tokenizer = my_tokenizer = AutoTokenizer.from_pretrained("allegro/herbert-base-cased")
emb = model.get_output_embeddings().weight.data.T.detach()
num_layers = model.config.num_hidden_layers
num_heads = model.config.num_attention_heads
hidden_dim = model.config.hidden_size
head_size = hidden_dim // num_heads

K = torch.cat(
    [
        model.get_parameter(f"bert.encoder.layer.{j}.attention.self.key.weight").T
        for j in range(num_layers)
    ]
).detach()
V = torch.cat(
    [
        model.get_parameter(f"bert.encoder.layer.{j}.attention.self.value.weight")
        for j in range(num_layers)
    ]
).detach()

W_Q, W_K, W_V = (
    torch.cat(
        [
            model.get_parameter(f"bert.encoder.layer.{j}.attention.self.query.weight")
            for j in range(num_layers)
        ]
    )
    .detach()
    .chunk(3, dim=-1)
)
W_O = torch.cat(
    [
        model.get_parameter(f"bert.encoder.layer.{j}.attention.output.dense.weight")
        for j in range(num_layers)
    ]
).detach()

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`


In [100]:
hidden_dim

768

In [101]:
K.shape

torch.Size([9216, 768])

TODO:
- fix parameter projection for bert base model
- generate overall lists for both models
- compare the two models

In [106]:
assert K.shape[1] == hidden_dim, "K dimensions do not match hidden_dim"
assert V.shape[1] == hidden_dim, "V dimensions do not match hidden_dim"

K_heads = K.reshape(num_layers, -1, hidden_dim)
V_heads = V.reshape(num_layers, -1, hidden_dim)
d_int = K_heads.shape[1]

W_Q_heads = W_Q.reshape(num_layers, hidden_dim, num_heads, head_size).permute(
    0, 2, 1, 3
)
W_K_heads = W_K.reshape(num_layers, hidden_dim, num_heads, head_size).permute(
    0, 2, 1, 3
)
W_V_heads = W_V.reshape(num_layers, hidden_dim, num_heads, head_size).permute(
    0, 2, 1, 3
)
W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)

RuntimeError: shape '[12, 768, 12, 64]' is invalid for input of size 2359296