In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import logging
from transformers import pipeline
from transformers import BertTokenizer, BertModel
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

In [2]:
from LamaTRExData import LamaTRExData#
from SentenceTypologyQueryResults import SentenceTypologyQueryResults
from TypologyEmbeddingComparison import TypologyEmbeddingComparison

In [3]:
from relation_templates.P19_place_of_birth import retrieve_all_prompts as p19
from relation_templates.P36_capital import retrieve_all_prompts as p36
from relation_templates.P101_field_of_work import retrieve_all_prompts as p101
from relation_templates.P178_developer import retrieve_all_prompts as p178

In [4]:
logging.set_verbosity_error()

In [5]:
MASK = "[MASK]"

In [6]:
relations = ["P19", "P36", "P101", "P178"]

In [7]:
TREx = LamaTRExData(relations = relations)
TREx.load()

In [8]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

In [9]:
def get_emebddings(sentence: str) -> torch.Tensor:
    input_ids = torch.tensor(tokenizer.encode(sentence)).unsqueeze(0)
    return model(input_ids)[0][0].mean(dim=0)
    

In [10]:
cos = nn.CosineSimilarity(dim=0, eps=1e-6)

In [11]:
comparison = TypologyEmbeddingComparison(relations, cos)

In [12]:
for relation in relations:
    for sub, obj in tqdm(TREx.data[relation]):
        query_simple, query_compound, query_complex, query_complex_compound  = p19(sub, obj)
        emb_simple = get_emebddings(query_simple)
        emb_compound = get_emebddings(query_compound)
        emb_complex = get_emebddings(query_complex)
        emb_complex_compound = get_emebddings(query_complex_compound)
        
        comparison.process_embeddings(relation, sub, obj, [emb_simple, emb_compound, emb_complex, emb_complex_compound])
        
        

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 944/944 [02:19<00:00,  6.76it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 703/703 [01:47<00:00,  6.53it/s]
100%|█████████████████████████████

In [17]:
comparison.show_similarity("P178",3,4)

Similarities for P178 type 3 vs type 4:
MessagePad, Apple: 0.8860700130462646 
macOS, Apple: 0.8108241558074951 
iOS, Apple: 0.8571710586547852 
Safari, Apple: 0.9198052287101746 
iLife, Apple: 0.9146302342414856 
iWork, Apple: 0.891785204410553 
PostScript, Adobe: 0.8557952046394348 
IBM AIX, IBM: 0.9306423664093018 
AppleTalk, Apple: 0.9078571200370789 
Applesoft BASIC, Microsoft: 0.9360917210578918 
Active Directory, Microsoft: 0.9208530187606812 
Windows Vista, Microsoft: 0.9389230608940125 
Internet Explorer, Microsoft: 0.938947856426239 
Outlook Express, Microsoft: 0.9330407381057739 
Microsoft Outlook, Microsoft: 0.9398163557052612 
B-17 Flying Fortress, Boeing: 0.9193887114524841 
DirectX, Microsoft: 0.9271332025527954 
Microsoft Windows, Microsoft: 0.9300569891929626 
Xbox 360, Microsoft: 0.9192625284194946 
Xbox One, Microsoft: 0.917088508605957 
Microsoft Internet Explorer 4, Microsoft: 0.928252100944519 
Itanium, Intel: 0.8875919580459595 
MS-DOS, Microsoft: 0.9419382214546

In [14]:
p101("John Calvin", "Theology")

["John Calvin's field of work was Theology.",
 'John Calvin was known in his field of work and his particular field was Theology',
 "People in his time knew Theology's field of work, which was John Calvin",
 'John Calvin was known in his discipline for the work he had done in his field and the subject of his work was Theology']

In [15]:
cos(get_emebddings("John Calvin's field of work was Theology."), get_emebddings('John Calvin was known in his discipline for the work he had done in his field and the subject of his work was Theology'))

tensor(0.8156, grad_fn=<DivBackward0>)

In [21]:
for relation in relations:
    for sub, obj in TREx.data[relation]:
        emb_sub = get_emebddings(sub)
        emb_obj = get_emebddings(obj)
        print(sub, obj, cos(emb_sub, emb_obj).item())

Allan Peiper Alexandra 0.6789600253105164
Anthony Barber Doncaster 0.6750724911689758
Paul Mounsey Scotland 0.540736734867096
Moe Koffman Toronto 0.5930782556533813
Kurt Schwertsik Vienna 0.49512773752212524
Claude Arrieu Paris 0.5313374996185303
Ryō Kase Yokohama 0.4694552421569824
Frans Floris I Antwerp 0.5467697381973267
Henry Heras Barcelona 0.5607268214225769
Daniele Franceschini Rome 0.4067355990409851
Yeung Sum Guangzhou 0.5820107460021973
Sajjad Ali Lahore 0.49277329444885254
Giovanni Maria Morandi Florence 0.4849843680858612
Brian Boyd Belfast 0.6825699806213379
Maurice de Vlaminck Paris 0.4220474064350128
Guy Deghy Budapest 0.5697110295295715
Ze'ev Jabotinsky Odessa 0.361272394657135
John Brumby Melbourne 0.505050003528595
Pierre Dupont Lyon 0.7195153832435608
Taras Kuzio Halifax 0.37455007433891296
Masako Natsume Tokyo 0.557748019695282
Henry Mayhew London 0.5011945962905884
Henri Bourassa Montreal 0.536745548248291
Fergus McMaster Queensland 0.600176990032196
Natalie Lowe S