In [None]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from transformers import pipeline
# Fill-mask pipeline backed by BERT (uncased); the cells below call it to
# fill the [MASK] slot of each prompt.
unmasker = pipeline(task="fill-mask", model="bert-base-uncased")

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [None]:
import torch
from transformers import BertTokenizer, BertForMaskedLM
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

In [None]:
# Load the tokenizer and the masked-LM head from the same checkpoint.
# output_attentions=True makes the forward pass expose `outputs.attentions`,
# which the attention-based cell reads.
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForMaskedLM.from_pretrained(
    model_name,
    output_attentions=True,
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Отношения для стран

In [None]:
# Ask the model for its 150 likeliest fillers of the country slot and keep
# only the predicted token strings as candidate country names.
preds = unmasker(
    "The greatest country in the world is [MASK].",
    top_k=150,
)
countries = [candidate['token_str'] for candidate in preds]

In [None]:
# Sanity-check every candidate with the inverse prompt: a candidate goes to
# corr_countries when the model's single best filler is the word "country",
# and to wrong_countries otherwise.
corr_countries, wrong_countries = [], []
for country in countries:
    best = unmasker(f"{country} is the greatest [MASK] in the world.", top_k=1)[0]
    bucket = corr_countries if best['token_str'] == 'country' else wrong_countries
    bucket.append(country)

In [None]:
# For every candidate country collect the 15 likeliest "The <X> of <country> is"
# completions; each filled-in sentence is kept as a relation prompt.
# NOTE(review): this iterates over all `countries`, not the filtered
# `corr_countries` -- confirm that skipping the filter is intentional.
relations = []
for country in tqdm(countries):
    completions = unmasker(f"The [MASK] of {country} is", top_k=15)
    relations.extend(completion['sequence'] for completion in completions)

100%|██████████| 150/150 [00:15<00:00,  9.96it/s]


## Attention-based подход

In [None]:
# Split the relation prompts by comparing two entries of the last layer's
# attention (averaged over heads) that involve the [MASK] position.
# NOTE(review): tokenizer.tokenize() is used directly, so presumably no
# [CLS]/[SEP] tokens are added; under that assumption, for a prompt shaped
# like "the X of Y is" with single-token X and Y, position 1 is X and
# position 3 is Y -- confirm both the input format and the fixed positions.
wrong_relations = []
right_relations = []

for relation in relations:
    tokens = tokenizer.tokenize(relation + ' [MASK].')
    masked_index = tokens.index("[MASK]")
    tokens_tensor = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])

    # Inference only: no gradients are needed to read the attention maps.
    with torch.no_grad():
        last_layer_attention = model(tokens_tensor).attentions[-1]

    # First batch item, then the slice at `masked_index` along the last axis,
    # averaged over the leading axis.
    # assumes the (batch, heads, query, key) layout -- TODO confirm
    mean_attentions = last_layer_attention[0][:, :, masked_index].mean(dim=0)

    target = right_relations if mean_attentions[1] < mean_attentions[3] else wrong_relations
    target.append(relation)

In [None]:
# Display the relations that passed the attention comparison above.
right_relations

['the peerage of japan is',
 'the primate of japan is',
 'the commonwealth of australia is',
 'the parliament of australia is',
 'the order of australia is',
 'the flag of australia is',
 'the constitution of australia is',
 'the president of australia is',
 'the primate of australia is',
 'the pride of australia is',
 'the federation of australia is',
 'the constitution of india is',
 'the parliament of india is',
 'the anthem of india is',
 'the chancellor of india is',
 'the treasurer of india is',
 'the primate of india is',
 'the currency of brazil is',
 'the anthem of brazil is',
 'the primate of brazil is',
 'the economy of brazil is',
 'the constitution of argentina is',
 'the currency of argentina is',
 'the anthem of argentina is',
 'the government of france is',
 'the flag of germany is',
 'the constitution of germany is',
 'the chancellor of germany is',
 'the currency of germany is',
 'the state of germany is',
 'the parliament of germany is',
 'the anthem of germany is',


## Фильтрация по схожести предсказаний

In [None]:
import random

In [None]:
# Keep a relation prompt when (a) the model is confident about its top answer
# and (b) the answers change once the country is swapped for another one,
# i.e. the prediction actually depends on the country.

# Dedicated, seeded generator: the control countries (and therefore the
# surviving relations) are identical on every run.
control_rng = random.Random(0)

true_relations = []
for relation in tqdm(relations):
    preds = unmasker(relation + " [MASK].", top_k=5)
    tokens = [pred['token_str'] for pred in preds]
    # Only prompts whose best completion has probability above 0.4 are considered.
    if preds[0]['score'] > 0.4:
        # First three words of the prompt, e.g. "the capital of".
        stem = ' '.join(relation.split(' ')[:3])
        # Never draw a country that reproduces the original prompt: the two
        # top-5 lists would be identical and the relation always rejected.
        other_countries = [c for c in countries if stem + ' ' + c + ' is' != relation]
        if not other_countries:
            continue
        prompt2 = stem + ' ' + control_rng.choice(other_countries) + ' is [MASK].'
        preds = unmasker(prompt2, top_k=5)
        tokens2 = [pred['token_str'] for pred in preds]
        # Fewer than 3 shared answers among the two top-5 lists => keep.
        if len(set(tokens) & set(tokens2)) < 3:
            true_relations.append(relation)

100%|██████████| 2250/2250 [04:57<00:00,  7.56it/s]


In [None]:
# Display the relations that survived the prediction-similarity filter above.
true_relations

['the capital of china is',
 'the treasurer of australia is',
 'the anthem of russia is',
 'the capital of russia is',
 'the capital of argentina is',
 'the language of argentina is',
 'the capital of france is',
 'the currency of france is',
 'the currency of germany is',
 'the president of mexico is',
 'the treasurer of canada is',
 'the premier of canada is',
 'the capital of peru is',
 'the language of peru is',
 'the economy of peru is',
 'the president of indonesia is',
 'the cabinet of indonesia is',
 'the currency of indonesia is',
 'the capital of iran is',
 'the capital of egypt is',
 'the capital of afghanistan is',
 'the capital of bolivia is',
 'the economy of bolivia is',
 'the language of bolivia is',
 'the anthem of bolivia is',
 'the capital of spain is',
 'the language of spain is',
 'the capital of greece is',
 'the governor of ghana is',
 'the capital of vietnam is',
 'the governor of vietnam is',
 'the parliament of bangladesh is',
 'the capital of bangladesh is',


# Отношения для личностей

In [18]:
# Names used as random stand-ins when building control prompts for the
# person-relation search.
persons = [
    'barack obama',
    'elvis presley',
    'lewis carroll',
    'ernest hemingway',
    'thomas edison',
    'rihanna',
    'bob dylan',
    'taylor swift',
    'anderson cooper',
    'friedman',
]

In [None]:
import random
import gensim.downloader

In [None]:
# Word vectors fetched through gensim's downloader; used further down for
# token-to-token similarity checks.
glove_vectors = gensim.downloader.load("glove-twitter-25")



In [19]:
# Accumulates every accepted relation phrase across rec_func calls.
all_relations = []


def _as_batch(pipeline_output):
    """Return fill-mask output as one candidate list per prompt.

    For a batch holding exactly one prompt the fill-mask pipeline hands back
    a flat list of dicts instead of a list of lists; wrap that case so the
    caller can always iterate prompt by prompt.
    """
    if pipeline_output and isinstance(pipeline_output[0], dict):
        return [pipeline_output]
    return pipeline_output


def rec_func(phrases):
    """Grow "<person>, who ..." phrases one masked word at a time.

    Each round appends a [MASK] to every phrase and asks the model for its
    top 10 fillers. A phrase is expanded further only when its top-5 fillers
    are mutually similar under the GloVe vectors (more than 15 ordered token
    pairs, a token paired with itself included, with similarity above 0.8).
    Among those, the best completion is recorded in ``all_relations`` when
    fewer than 3 of the top-5 fillers also appear for the same clause
    attached to a randomly chosen name from ``persons``.

    The search stops when no phrase is left to expand or when the clause
    after " who " in the last phrase has reached three words.

    Args:
        phrases: list of sentences shaped like "<name>, who <words>." -- each
            ends with a period and contains " who ".

    Returns:
        The module-level ``all_relations`` list, which this function extends
        in place.
    """
    while phrases:
        # Depth limit: the clause after " who " is already three words long.
        if len(phrases[-1].split(' who ')[1].split(' ')) == 3:
            break

        next_phrases = []
        # Drop the final period and open a new slot for the model to fill.
        phrases_mask = [phrase[:-1] + ' [MASK].' for phrase in phrases]
        preds = _as_batch(unmasker(phrases_mask, batch_size=32, top_k=10))
        # Control prompts: same clause, random subject. maxsplit=1 keeps the
        # whole clause even when it contains a further comma.
        rand_phrases = [
            random.choice(persons) + ' ,' + phrase.split(',', 1)[1][:-1] + ' [MASK].'
            for phrase in phrases
        ]
        rand_preds = _as_batch(unmasker(rand_phrases, batch_size=32, top_k=10))

        for ind, pred in enumerate(preds):
            tokens = [var['token_str'] for var in pred[:5]]
            # Vocabulary membership is checked once per token, not once per pair.
            known = [token for token in tokens if token in glove_vectors]
            sim = 0
            for token in known:
                for token2 in known:
                    if glove_vectors.similarity(token, token2) > 0.8:
                        sim += 1
            if sim > 15:
                tokens_rand = [var['token_str'] for var in rand_preds[ind][:5]]
                if len(set(tokens) & set(tokens_rand)) < 3:
                    all_relations.append(pred[0]['sequence'])
                # All 10 completions of a coherent prompt seed the next round.
                for var in pred:
                    next_phrases.append(var['sequence'])

        phrases = next_phrases

    return all_relations

In [None]:
# Seed prompts: the 100 likeliest one-word continuations of "charles dickens, who".
preds = unmasker("charles dickens, who [MASK].", top_k=100)
# Keep continuations that start with a letter; the [:1] slice also copes with
# an empty token string.
phrases = [pred['sequence'] for pred in preds if pred['token_str'][:1].isalpha()]
# rec_func collects its results in the module-level all_relations list, so
# start from a clean list (re-running this cell must not duplicate entries)
# and read the results from there.
all_relations.clear()
rec_func(phrases)
relations = all_relations

['charles dickens, who wrote ulysses.',
 'charles dickens, who read it.',
 'charles dickens, who plays himself.',
 'charles dickens, who illustrated it.',
 'charles dickens, who writes fiction.',
 'charles dickens, who played himself.',
 'charles dickens, who knew him.',
 'charles dickens, who ed hon.',
 'charles dickens, who acted twice.',
 'charles dickens, who visited london.',
 'charles dickens, who reads it.',
 'charles dickens, who contributed illustrations.',
 'charles dickens, who knows what.',
 'charles dickens, who drew him.',
 'charles dickens, who is blind.',
 'charles dickens, who actor ).',
 'charles dickens, who playwrights.',
 'charles dickens, who jr vol.',
 'charles dickens, who hunted deer.',
 'charles dickens, who won twice.',
 'charles dickens, who novels?.',
 'charles dickens, who killed himself.',
 'charles dickens, who fell asleep.',
 'charles dickens, who laughs again.',
 'charles dickens, who does it.',
 'charles dickens, who defied.',
 'charles dickens, who d