In [38]:
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = 'flax-community/papuGaPT2'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

tokenizer_papuga = AutoTokenizer.from_pretrained(model_name)
model_papuga = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# The embedding matrix for PapugaGPT2 is at:
# model_papuga.transformer.wte.weight (shape: [vocab_size, embedding_dim])
papuga_emb_matrix = model_papuga.transformer.wte.weight.detach().cpu().numpy()

def get_word_embedding_papuga_non_contextual(word: str):
    """
    Returns a single vector representing the word by averaging
    the embeddings of all tokens that make up the word.
    """
    # Encode the word
    # PapuGaGPT2 expects a leading space for 'proper' tokenization, so let's add one:
    input_ids = tokenizer_papuga(" "+word, add_special_tokens=False)['input_ids']
    
    # Retrieve each token's embedding from the embedding matrix
    token_vectors = []
    for tid in input_ids:
        token_vectors.append(papuga_emb_matrix[tid])
    
    # Average them to get a single embedding
    # return np.mean(token_vectors, axis=0)
    
    # return the max value of each dimension
    return np.max(token_vectors, axis=0)

# Example usage:
words = ["matematyka", "długopis", "bławatki", "szarlotka"]
for w in words:
    emb = get_word_embedding_papuga_non_contextual(w)
    print(f"Word: {w}, Emb.shape={emb.shape}")


Using device: cuda
Word: matematyka, Emb.shape=(768,)
Word: długopis, Emb.shape=(768,)
Word: bławatki, Emb.shape=(768,)
Word: szarlotka, Emb.shape=(768,)


In [44]:
import torch
from transformers import AutoTokenizer, AutoModel

model_name_herbert = "allegro/herbert-base-cased"
tokenizer_herbert = AutoTokenizer.from_pretrained(model_name_herbert)
model_herbert = AutoModel.from_pretrained(model_name_herbert).to(device)

def get_word_embedding_bert_contextual(word: str):
    """
    Gets the 'contextual' BERT-based embedding for a single word.
    By default, we use the last hidden state and average over the tokens that form the word.
    """
    # Option 1: Just feed the single word (might be suboptimal, no real context).
    # Option 2: Insert the word in a minimal sentence, e.g. "To jest WORD."
    # We'll show option 1 for simplicity:
    input_ids = tokenizer_herbert.encode(word, add_special_tokens=True)
    # Convert to tensors
    input_ids_tensor = torch.tensor([input_ids]).to(device)
    
    with torch.no_grad():
        outputs = model_herbert(input_ids_tensor)
    
    # 'outputs' is a BaseModelOutput with the first element: last_hidden_state shape => [batch_size, seq_len, hidden_dim]
    last_hidden_state = outputs.last_hidden_state.squeeze(0)  # shape => [seq_len, hidden_dim]
    
    # The word may be split into multiple subwords (HerBERT uses WordPiece).
    # Typically, [CLS] is the first token. We'll skip that in the averaging.
    # Also skip [SEP] if present at the end.
    # So we'll average over indices [1 : seq_len-1] ignoring special tokens
    # If the word is subword-split, this includes all subword tokens.

    # indices 1 to len(input_ids)-2 to skip [CLS] and [SEP].
    subword_vectors = last_hidden_state[1:-1]
    # subword_vectors = last_hidden_state
    if subword_vectors.shape[0] == 0:
        # edge case if the word is extremely short, fallback
        subword_vectors = last_hidden_state

    # average pooling
    word_vector = torch.mean(subword_vectors, dim=0)

    return word_vector.cpu().numpy()

# Example usage:
words_bert = ["długopis", "szarlotka", "bławatki"]
for w in words_bert:
    emb_bert = get_word_embedding_bert_contextual(w)
    print(f"BERT Word: {w}, Emb.shape={emb_bert.shape}")

Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.sso.sso_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.sso.sso_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BERT Word: długopis, Emb.shape=(768,)
BERT Word: szarlotka, Emb.shape=(768,)
BERT Word: bławatki, Emb.shape=(768,)


In [2]:
import random

# Polish diacritics mapping
polish_diacritics_map = {
    'ą': 'a', 'ć': 'c', 'ę': 'e', 'ł': 'l', 
    'ń': 'n', 'ó': 'o', 'ś': 's', 'ź': 'z', 'ż': 'z'
}

def remove_polish_diacritics(word: str) -> str:
    """
    Removes Polish diacritics by mapping each character
    to its unaccented counterpart.
    """
    output = []
    for ch in word:
        if ch in polish_diacritics_map:
            output.append(polish_diacritics_map[ch])
        else:
            output.append(ch)
    return "".join(output)

def random_swap(word: str) -> str:
    """
    Swaps two letters in the word.
    If the word has length < 2, return it as is.
    """
    w = list(word)
    if len(w) < 2:
        return word
    
    # pick two positions to swap
    i = random.randint(0, len(w) - 1)
    j = random.randint(0, len(w) - 1)
    
    # ensure we do something 
    if i != j:
        w[i], w[j] = w[j], w[i]
    return "".join(w)

def distort_word(word: str) -> str:
    """
    Applies both distortions:
    1) Removing Polish diacritics
    2) Random swapping of two letters
    ensuring the final word is *always* different in some way.
    """
    # remove diacritics
    no_diac = remove_polish_diacritics(word)
    # random swap
    distorted = random_swap(no_diac)
    return distorted

# Example
original_words = ["długopis", "krówka", "źdźbło", "pająk"]
for w in original_words:
    dist = distort_word(w)
    print(f"{w} -> {dist}")


długopis -> dulgopis
krówka -> krkwoa
źdźbło -> zdzblo
pająk -> paajk


In [5]:
clusters_txt = '''
piśmiennicze: pisak flamaster ołówek długopis pióro
małe_ssaki: mysz szczur chomik łasica kuna bóbr
okręty: niszczyciel lotniskowiec trałowiec krążownik pancernik fregata korweta
lekarze: lekarz pediatra ginekolog kardiolog internista geriatra
zupy: rosół żurek barszcz
uczucia: miłość przyjaźń nienawiść gniew smutek radość strach
działy_matematyki: algebra analiza topologia logika geometria 
budynki_sakralne: kościół bazylika kaplica katedra świątynia synagoga zbór
stopień_wojskowy: chorąży podporucznik porucznik kapitan major pułkownik generał podpułkownik
grzyby_jadalne: pieczarka borowik gąska kurka boczniak kania
prądy_filozoficzne: empiryzm stoicyzm racjonalizm egzystencjalizm marksizm romantyzm
religie: chrześcijaństwo buddyzm islam prawosławie protestantyzm kalwinizm luteranizm judaizm
dzieła_muzyczne: sonata synfonia koncert preludium fuga suita
cyfry: jedynka dwójka trójka czwórka piątka szóstka siódemka ósemka dziewiątka
owady: ważka biedronka żuk mrówka mucha osa pszczoła chrząszcz
broń_biała: miecz topór sztylet nóż siekiera
broń_palna: karabin pistolet rewolwer fuzja strzelba
komputery: komputer laptop kalkulator notebook
kolory: biel żółć czerwień błękit zieleń brąz czerń
duchowny: wikary biskup ksiądz proboszcz rabin pop arcybiskup kardynał pastor
ryby: karp śledź łosoś dorsz okoń sandacz szczupak płotka
napoje_mleczne: jogurt kefir maślanka
czynności_sportowe: bieganie skakanie pływanie maszerowanie marsz trucht
ubranie:  garnitur smoking frak żakiet marynarka koszula bluzka sweter sweterek sukienka kamizelka spódnica spodnie
mebel: krzesło fotel kanapa łóżko wersalka sofa stół stolik ława
przestępca: morderca zabójca gwałciciel złodziej bandyta kieszonkowiec łajdak łobuz
mięso_wędliny wieprzowina wołowina baranina cielęcina boczek baleron kiełbasa szynka schab karkówka dziczyzna
drzewo: dąb klon wiąz jesion świerk sosna modrzew platan buk cis jawor jarzębina akacja
źródło_światła: lampa latarka lampka żyrandol żarówka reflektor latarnia lampka
organ: wątroba płuco serce trzustka żołądek nerka macica jajowód nasieniowód prostata śledziona
oddziały: kompania pluton batalion brygada armia dywizja pułk
napój_alkoholowy: piwo wino wódka dżin nalewka bimber wiśniówka cydr koniak wiśniówka
kot_drapieżny: puma pantera lampart tygrys lew ryś żbik gepard jaguar
metal: żelazo złoto srebro miedź nikiel cyna cynk potas platyna chrom glin aluminium
samolot: samolot odrzutowiec awionetka bombowiec myśliwiec samolocik helikopter śmigłowiec
owoc: jabłko gruszka śliwka brzoskwinia cytryna pomarańcza grejpfrut porzeczka nektaryna
pościel: poduszka prześcieradło kołdra kołderka poduszeczka pierzyna koc kocyk pled
agd: lodówka kuchenka pralka zmywarka mikser sokowirówka piec piecyk piekarnik
'''

In [17]:
# Cell A: Extract unique words from the clusters
def get_unique_words_from_clusters(clusters_txt):
    words = []
    for line in clusters_txt.split('\n'):
        parts = line.split()
        if len(parts) < 2:
            continue
        # add the cluster name as well
        words.append(parts[0][:-1])
        # the cluster name is parts[0], actual words in parts[1:]
        cluster_words = parts[1:]
        for w in cluster_words:
            words.append(w)
    return words

# We assume "clusters_txt" is already defined in the notebook by the evaluation script.
all_words = get_unique_words_from_clusters(clusters_txt)
print(f"Total unique words in clusters: {len(all_words)}")
print(all_words[:30], "...")
all_words = list(set(all_words))  # remove duplicates


Total unique words in clusters: 328
['piśmiennicze', 'pisak', 'flamaster', 'ołówek', 'długopis', 'pióro', 'małe_ssaki', 'mysz', 'szczur', 'chomik', 'łasica', 'kuna', 'bóbr', 'okręty', 'niszczyciel', 'lotniskowiec', 'trałowiec', 'krążownik', 'pancernik', 'fregata', 'korweta', 'lekarze', 'lekarz', 'pediatra', 'ginekolog', 'kardiolog', 'internista', 'geriatra', 'zupy', 'rosół'] ...


In [39]:
# Cell B: Papuga on Original Words
with open("word_embedings_file_papuga_original.txt", "w", encoding="utf-8") as f:
    for w in all_words:
        emb_vec = get_word_embedding_papuga_non_contextual(w)
        # Convert to string
        vec_str = " ".join(map(str, emb_vec))
        f.write(f"{w} {vec_str}\n")

print("Papuga embeddings for ORIGINAL words have been saved to word_embedings_file.txt.")
print("Now re-run the evaluation script cell above/below to see the score (Papuga - original).")

Papuga embeddings for ORIGINAL words have been saved to word_embedings_file.txt.
Now re-run the evaluation script cell above/below to see the score (Papuga - original).


In [45]:
# Cell C: BERT on Original Words
with open("word_embedings_file_BERT_original.txt", "w", encoding="utf-8") as f:
    for w in all_words:
        emb_vec = get_word_embedding_bert_contextual(w)
        vec_str = " ".join(map(str, emb_vec))
        f.write(f"{w} {vec_str}\n")

print("BERT embeddings for ORIGINAL words have been saved to word_embedings_file.txt.")
print("Re-run the evaluation script cell to see the score (BERT - original).")

BERT embeddings for ORIGINAL words have been saved to word_embedings_file.txt.
Re-run the evaluation script cell to see the score (BERT - original).


In [40]:
# Cell D: Papuga on Distorted Words
with open("word_embedings_file_papuga_deformed.txt", "w", encoding="utf-8") as f:
    for w in all_words:
        w_dist = distort_word(w)  # e.g. remove diacritics + random swap
        emb_vec = get_word_embedding_papuga_non_contextual(w_dist)
        vec_str = " ".join(map(str, emb_vec))
        f.write(f"{w} {vec_str}\n")

print("Papuga embeddings for DISTORTED words have been saved to word_embedings_file.txt.")
print("Re-run the evaluation script cell to see the score (Papuga - distorted).")

Papuga embeddings for DISTORTED words have been saved to word_embedings_file.txt.
Re-run the evaluation script cell to see the score (Papuga - distorted).


In [46]:
# Cell E: BERT on Distorted Words
with open("word_embedings_file_BERT_deformed.txt", "w", encoding="utf-8") as f:
    for w in all_words:
        w_dist = distort_word(w)
        emb_vec = get_word_embedding_bert_contextual(w_dist)
        vec_str = " ".join(map(str, emb_vec))
        # Again, if you want to preserve the *original* key in the script, do:
        # f.write(f"{w} {vec_str}\n")
        # Otherwise, do:
        f.write(f"{w} {vec_str}\n")

print("BERT embeddings for DISTORTED words have been saved to word_embedings_file.txt.")
print("Re-run the evaluation script cell to see the score (BERT - distorted).")

BERT embeddings for DISTORTED words have been saved to word_embedings_file.txt.
Re-run the evaluation script cell to see the score (BERT - distorted).
