In [75]:
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
from itertools import combinations
from collections import Counter
from ampligraph.latent_features import ConvE, DistMult, save_model, restore_model
from ampligraph.discovery import discover_facts, find_nearest_neighbours, query_topn

# Seed handed to the embedding models and to fact discovery.
RANDOM_SEED = 17
# Embedding size passed to the models as `k`.
EMB_DIM = 10
# Number of leading articles used to build the triple dataset.
N = 100_000

In [2]:
# Load the article dump and discard `year`; `_id`, `references` and `authors` remain.
articles = pd.read_json('../EDA/articles_tts.json').drop(columns=['year'])

In [3]:
# Summarise the loaded frame: column dtypes, non-null counts and memory usage.
articles.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 1084405 entries, 0 to 1084404
Data columns (total 3 columns):
 #   Column      Non-Null Count    Dtype 
---  ------      --------------    ----- 
 0   _id         1084405 non-null  object
 1   references  1084405 non-null  object
 2   authors     1084405 non-null  object
dtypes: object(3)
memory usage: 33.1+ MB


In [4]:
# Peek at the first rows to see how `references` and `authors` are structured.
articles.head()

Unnamed: 0,_id,references,authors
0,53e99784b7602d9701f3e151,"[53e99cf5b7602d97025ace63, 557e8a7a6fee0fe990c...","[{'_id': '53f46797dabfaeb22f542630', 'name': '..."
1,53e99784b7602d9701f3e15d,"[53e9a8a9b7602d97031f6bb9, 599c7b6b601a182cd27...","[{'_id': '53f43b03dabfaedce555bf2a', 'name': '..."
2,53e99784b7602d9701f3f411,"[53e9adbdb7602d97037be8a2, 53e9bb53b7602d97047...","[{'_id': '548a2e3ddabfae9b40134fbc', 'name': '..."
3,53e99792b7602d9701f5af1a,"[53e9b3dab7602d9703ec7ddf, 53e9a3edb7602d9702d...","[{'_id': '5631df8845cedb3399f3e752', 'name': '..."
4,53e99792b7602d9701f5b0a5,"[53e9bdceb7602d9704a7ee58, 53e9b5a8b7602d97040...","[{'_id': '53f431addabfaedd74d6d650', 'name': '..."


# Creating dataset

The dataset consists of triples `(entity1, relation, entity2)`, where relation $\in$ {`REFERENCE`, `AUTHOR`, `COAUTHOR`}.

In [58]:
# Build the knowledge-graph triples from the first N articles.
# Each article contributes:
#   [article, 'REFERENCE', cited_article]  for every entry in `references`
#   [author,  'AUTHOR',    article]        for every distinct author id
#   [author1, 'COAUTHOR',  author2]        once per unordered author pair
articles_subset = articles.iloc[:N]
triples = []

coauthors = set()  # author pairs already emitted, to exclude repeated coauthors
rows = zip(
    articles_subset['_id'],
    articles_subset['references'],
    articles_subset['authors'],
)
# `total` is taken from the slice itself so the bar stays correct when the
# frame holds fewer than N rows.
for article_id, references, article_authors in tqdm(rows, total=len(articles_subset)):
    triples.extend([article_id, 'REFERENCE', ref] for ref in references)

    # Sorting gives every pair a canonical (smaller_id, larger_id) form. The set
    # drops ids listed twice on one article, which would otherwise yield
    # duplicate AUTHOR rows and an (a, COAUTHOR, a) self-loop.
    authors = sorted({author['_id'] for author in article_authors})
    triples.extend([author, 'AUTHOR', article_id] for author in authors)

    # `combinations` over a sorted, duplicate-free list yields unique pairs in
    # lexicographic order, so the triple order is deterministic and does not
    # depend on set iteration order (which can differ between kernel sessions).
    curr_coauthors = [pair for pair in combinations(authors, 2) if pair not in coauthors]
    triples.extend([author1, 'COAUTHOR', author2] for author1, author2 in curr_coauthors)

    coauthors.update(curr_coauthors)

print(len(triples))
# reshape keeps the (n, 3) layout even when no triples were produced.
dataset = np.array(triples).reshape(-1, 3)
print(Counter(dataset[:, 1]))

  0%|          | 0/100000 [00:00<?, ?it/s]

1596416
Counter({'REFERENCE': 942667, 'COAUTHOR': 367061, 'AUTHOR': 286688})


In [89]:
# Inspect the first 20 triples to check the [head, relation, tail] layout.
print(dataset[:20])

[['53e99784b7602d9701f3e151' 'REFERENCE' '53e99cf5b7602d97025ace63']
 ['53e99784b7602d9701f3e151' 'REFERENCE' '557e8a7a6fee0fe990caa63d']
 ['53e99784b7602d9701f3e151' 'REFERENCE' '53e9a96cb7602d97032c459a']
 ['53e99784b7602d9701f3e151' 'REFERENCE' '53e9b929b7602d9704515791']
 ['53e99784b7602d9701f3e151' 'REFERENCE' '557e59ebf6678c77ea222447']
 ['53f46797dabfaeb22f542630' 'AUTHOR' '53e99784b7602d9701f3e151']
 ['54328883dabfaeb4c6a8a699' 'AUTHOR' '53e99784b7602d9701f3e151']
 ['53f46797dabfaeb22f542630' 'COAUTHOR' '54328883dabfaeb4c6a8a699']
 ['53e99784b7602d9701f3e15d' 'REFERENCE' '53e9a8a9b7602d97031f6bb9']
 ['53e99784b7602d9701f3e15d' 'REFERENCE' '599c7b6b601a182cd27360da']
 ['53e99784b7602d9701f3e15d' 'REFERENCE' '53e9b443b7602d9703f3e52b']
 ['53e99784b7602d9701f3e15d' 'REFERENCE' '53e9a6a6b7602d9702fdc57e']
 ['53e99784b7602d9701f3e15d' 'REFERENCE' '599c7b6a601a182cd2735703']
 ['53e99784b7602d9701f3e15d' 'REFERENCE' '53e9aad9b7602d970345afea']
 ['53e99784b7602d9701f3e15d' 'REFERENCE' 

In [55]:
# Attempt to train ConvE on the full triple set.
# On this graph the recorded run failed with
# "NotImplementedError: ConvE not implemented when dealing with large graphs."
# The error is the point of this cell, so it is caught and displayed instead of
# aborting a top-to-bottom run of the notebook.
try:
    model = ConvE(
        k=EMB_DIM,
        epochs=3,
        optimizer='sgd',
        seed=RANDOM_SEED,
        low_memory=True,
        verbose=True,
    )
    model.fit(dataset)
except NotImplementedError as err:
    print('NotImplementedError:', err)



NotImplementedError: ConvE not implemented when dealing with large graphs.

### The dataset is too large to train `ConvE`.

In [None]:
# Rough size estimate, converted from bytes to GiB (2 ** 30 == 1024 ** 3).
# NOTE(review): 708010 equals the progress-bar total of the query_topn run
# further down, so it is presumably the number of entities; the trailing 4 is
# presumably bytes per float32 value. The meaning of 200 and 2 is not stated
# anywhere in this notebook — TODO confirm and name these constants.
708_010 * EMB_DIM * 200 * 2 * 4 / 2 ** 30

10.550171136856079

In [62]:
# Train DistMult on the same triples; unlike ConvE above, it fits on this graph.
# NOTE(review): the final loss in the recorded output (~1.386) is very close to
# 2 * ln(2) — worth checking that training actually moved the embeddings away
# from their initial state.
model = DistMult(k=EMB_DIM, epochs=10, optimizer='sgd', seed=RANDOM_SEED, verbose=True)
model.fit(dataset)



Average DistMult Loss:   1.386221: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:49<00:00,  4.96s/epoch]


In [76]:
# Write the fitted model to ./graph_model.pkl (restore_model, imported above, loads it back).
save_model(model, './graph_model.pkl')

## After the model is trained, we explore the knowledge in our graph

## Getting embeddings for entities

In [77]:
# Sample entities for the queries below: the first article and its first listed author.
paper = articles.loc[0, '_id']
author = articles.loc[0, 'authors'][0]['_id']

In [79]:
# Look up the learned vectors for the sample paper and author (one row per entity).
model.get_embeddings([paper, author])

array([[ 3.6442066e-03,  1.9322786e-03, -3.0478395e-03, -2.5843806e-04,
         9.0770680e-04, -2.9814788e-03,  2.2133968e-03, -7.9435494e-04,
        -1.8411485e-03, -4.1665314e-04],
       [-2.5265519e-05, -3.9500915e-03,  1.6737570e-03, -5.7361176e-04,
        -3.6262395e-04, -1.1760191e-03,  9.4050070e-04, -5.1189039e-04,
         1.1498800e-03, -1.3735156e-03]], dtype=float32)

### Discover new facts (we can restrict the search to specific relations if we want)

In [69]:
# Generate candidate triples with the 'random_uniform' strategy and let the
# model pick out the promising ones. The recorded run took many hours.
new_facts = discover_facts(
    dataset,
    model,
    top_n=10,
    max_candidates=500,
    strategy='random_uniform',
    seed=RANDOM_SEED,
)

    protocol. This may be unnecessary and will lead to a 'harder' task. Besides, it will lead to a much slower
    evaluation procedure. We recommended to set the 'corruption_entities' argument to a reasonably sized set
    of entities. The size of corruption_entities depends on your domain-specific task.


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [42:28<00:00,  5.10s/it]


    protocol. This may be unnecessary and will lead to a 'harder' task. Besides, it will lead to a much slower
    evaluation procedure. We recommended to set the 'corruption_entities' argument to a reasonably sized set
    of entities. The size of corruption_entities depends on your domain-specific task.


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [42:23<00:00,  5.09s/it]


    protocol. This may be unnecessary and will lead to a 'harder' task. Besides, it will lead to a much slower
    evaluation procedure. We recommended to set the 'corruption_entities' argument to a reasonably sized set
    of entities. The size of corruption_entities depends on your domain-specific task.


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [10:58:59<00:00, 79.08s/it]


In [71]:
# Display the result of discover_facts: a pair of arrays (both empty in the recorded run).
new_facts

(array([], shape=(0, 9), dtype=object), array([], dtype=float64))

### We need to specify a larger `max_candidates` in `discover_facts`; otherwise the result is empty.
However, that takes too much time: `max_candidates=500` already took ~1.5h, while the documentation examples usually set `max_candidates` to $\approx10^4-10^5$.

### Query the model with two elements of a triple and return the `top_n` results among all possible completions, ordered by the score predicted by the model.

In [51]:
# Ask for the 10 best-scoring completions of (paper, 'REFERENCE', ?).
# NOTE: the recorded run was interrupted by hand at about 20% progress after
# roughly 1h19m, so `topn` was never produced in that session.
topn = query_topn(
    model,
    top_n=10,
    head=paper,
    relation='REFERENCE',
)



 20%|████████████████████████▋                                                                                                 | 143298/708010 [1:19:11<5:12:05, 30.16it/s]


KeyboardInterrupt: 

### Return the nearest neighbors of entities.

#### For both an article and an author

In [82]:
# Find the 10 nearest entities (cosine distance in embedding space) for the
# sample paper and author; each entity's first neighbour is itself (distance 0).
neighbors, dist = find_nearest_neighbours(
    model,
    entities=[paper, author],
    n_neighbors=10,
    metric='cosine',
)
(neighbors, dist)

(array([['53e99784b7602d9701f3e151', '53e9a751b7602d970308a2d0',
         '53e997f1b7602d9701ff51ef', '53f45674dabfaee0d9bf468b',
         '53f43002dabfaeb22f42915e', '558af2a4e4b037c0875a1efb',
         '53e9bc9eb7602d970491f34e', '53e9bc00b7602d970485eb25',
         '53f434d6dabfaee2a1cd81b3', '573695d26e3b12023e4eb362'],
        ['53f46797dabfaeb22f542630', '53e99c3db7602d97024ec417',
         '53e9aed1b7602d97038fa72b', '53f43180dabfaedf4354abf7',
         '53e9b289b7602d9703d2fd95', '53e9b917b7602d9704501bd1',
         '53e9b47cb7602d9703f8079e', '558c6b9ee4b0cfb70a1d937e',
         '53e9af7bb7602d97039c3d84', '53e99b63b7602d97024068d9']],
       dtype='<U24'),
 array([[0.        , 0.04109454, 0.04743528, 0.04954237, 0.04962188,
         0.05202979, 0.052127  , 0.05613053, 0.05625552, 0.06243461],
        [0.        , 0.03694481, 0.03990865, 0.0442304 , 0.04476464,
         0.04872388, 0.05277181, 0.0559786 , 0.06058449, 0.06213945]],
       dtype=float32))