In [None]:
!pip install transformers

In [None]:
from transformers import pipeline

In [None]:
import torch
from tqdm import tqdm
from collections import defaultdict

Функция извлекает отношения «страна — столица», «страна — валюта», «страна — официальный язык», «личность — страна» и «географический объект — страна». Необходимо передать модель и токенизатор.

In [None]:
def get_relations(model, tokenizer):
    """Extract factual relations from a masked language model via cloze prompts.

    Builds a ``fill-mask`` pipeline from ``model`` and ``tokenizer`` and probes
    it for five relation types: country -> capital, country -> official
    language, country -> official currency, (person, profession) -> country,
    and geographic object -> country. All values are token strings predicted
    by the model; thresholds below are heuristic score cut-offs.

    Args:
        model: model passed to ``transformers.pipeline`` for ``fill-mask``.
        tokenizer: matching tokenizer; its ``mask_token`` is used in prompts.

    Returns:
        dict with keys:
            'capital': {country: capital}
            'language': {country: language}
            'currency': {country: currency}
            'persons2countries': {(person, profession): predicted token}
            'located in': {place: country} for rivers, mountains and islands
                merged in that order (later groups win on key collisions).
    """
    unmasker = pipeline("fill-mask", model=model, tokenizer=tokenizer)
    masked_token = tokenizer.mask_token

    def get_capitals():
        """Map country -> capital, keeping only pairs confirmed in reverse."""
        candidates = unmasker(f"The capital of {masked_token}.", top_k=200)
        capitals = {}
        for country in [pred['token_str'] for pred in candidates]:
            best = unmasker(f"The capital of {country} is {masked_token}.")[0]
            capital = best['token_str']
            if best['score'] <= 0.1 or capital == country:
                continue
            # Reverse check: the capital alone must point back to the country.
            # Only the predicted token is used; the full forward sentence would
            # already contain the country name and make the check trivial.
            back = unmasker(f"{capital} is the capital of {masked_token}.")[0]
            if back['token_str'] == country:
                capitals[country] = capital
        return capitals

    def get_official(attribute, threshold):
        """Map country -> official ``attribute`` when the top fill scores > ``threshold``."""
        candidates = unmasker(f"The official {attribute} of {masked_token}.", top_k=150)
        relations = {}
        for pred in candidates:
            country = pred['token_str']
            best = unmasker(f"The official {attribute} of {country} is {masked_token}.", top_k=1)[0]
            if best['score'] > threshold:
                relations[country] = best['token_str']
        return relations

    def generate_persons(professions):
        """Map profession -> names the model ranks as the greatest in it.

        For each profession, candidate names are the top-200 fills of
        "The greatest <prof> in the world is [MASK]." Each is optionally
        extended by one more token when the model is confident (> 0.8) and
        the token is not "himself"/"herself". A name is kept when
        "<name> is the greatest [MASK] in the world" has the profession among
        its top-5 fills and its top fill scores > 0.1.
        """
        prompts = [f'The greatest {prof} in the world is {masked_token}.' for prof in professions]
        first_preds = unmasker(prompts, batch_size=16, top_k=200)
        corr_members = defaultdict(list)
        for category, candidates in tqdm(zip(professions, first_preds), total=len(professions)):
            # Re-mask right after each candidate name to try to get a second token.
            extended = [pr['sequence'][:-1] + f' {masked_token}.' for pr in candidates]
            names = []
            for ext in unmasker(extended, batch_size=32, top_k=1):
                best = ext[0]
                first_token = best['sequence'].split(' ')[-2]
                names.append(first_token)
                if best['score'] > 0.8 and best['token_str'] not in ('himself', 'herself'):
                    names.append(first_token + ' ' + best['token_str'])
            checks = [f'{name} is the greatest {masked_token} in the world' for name in names]
            # Pair each check with the name it was built from instead of
            # re-parsing the decoded sentence (splitting on ' is' truncates
            # names such as "chris isaak").
            for name, fills in zip(names, unmasker(checks, batch_size=32, top_k=5)):
                tokens = [fill['token_str'] for fill in fills]
                if category in tokens and fills[0]['score'] > 0.1:
                    corr_members[category].append(name)
        return corr_members

    def get_person_countries(persons):
        """Map (person, profession) -> token predicted in "<person> is a/an [MASK] <prof>".

        Both articles are tried; a prediction counts only if its score is
        >= 0.5, and the higher-scoring one wins (ties keep the "a" result).
        Persons whose name contains the profession string are skipped.
        """
        relations = {}
        for prof, members in persons.items():
            for member in members:
                if prof in member:
                    continue
                candidates = []
                for article in ('a', 'an'):
                    best = unmasker(f"{member} is {article} {masked_token} {prof}", top_k=1)[0]
                    if best['score'] >= 0.5:
                        candidates.append(best)
                if candidates:
                    # max() returns the first maximal item, so ties favour "a".
                    winner = max(candidates, key=lambda cand: cand['score'])
                    relations[(member, prof)] = winner['token_str']
        return relations

    def get_loc_relations(adj, type_loc, be):
        """Map place -> country for the model's "<adj> <type_loc>" candidates.

        Candidates are the top-100 fills of
        "The <adj> <type_loc> in the world <be> [MASK]." A place is kept only
        if "<place> <be> the <adj> [MASK] in the world." fills ``type_loc``
        back. Its country is the top fill of "<place> <be> located in [MASK]."
        when that fill scores > 0.1. ``be`` is the verb agreeing with
        ``type_loc`` ("is" / "are").
        """
        preds = unmasker(f'The {adj} {type_loc} in the world {be} {masked_token}.', top_k=100)
        places = [pred['token_str'] for pred in preds]
        corr_places = []
        for place in places:
            best = unmasker(f'{place} {be} the {adj} {masked_token} in the world.', top_k=1)[0]
            if best['token_str'] == type_loc:
                corr_places.append(place)
        loc_relations = {}
        for place in corr_places:
            best = unmasker(f"{place} {be} located in {masked_token}.", top_k=1)[0]
            if best['score'] > 0.1:
                loc_relations[place] = best['token_str']
        return loc_relations

    # Countries: capital, official language, official currency.
    all_relations = {
        'capital': get_capitals(),
        'language': get_official('language', 0.6),
        'currency': get_official('currency', 0.1),
    }

    # Person -> country: gather professions from two prompts (deduplicated,
    # order preserved), then find notable people and their predicted country.
    professions = [pred['token_str'] for pred in unmasker(f"He is a {masked_token} by profession.", top_k=20)]
    for pred in unmasker(f"He is an {masked_token}.", top_k=20):
        if pred['token_str'] not in professions:
            professions.append(pred['token_str'])
    all_relations['persons2countries'] = get_person_countries(generate_persons(professions))

    # Geographic object -> country.
    all_relations['located in'] = (
        get_loc_relations('longest', 'river', 'is')
        | get_loc_relations('highest', 'mountains', 'are')
        | get_loc_relations('biggest', 'island', 'is')
    )
    return all_relations