In [1]:
import os
import torch

from stark_qa import load_skb
from stark_neo4j_loading import insert_nodes, insert_relationships, insert_node_embeddings

from neo4j import GraphDatabase
from dotenv import load_dotenv
from tqdm import tqdm

%load_ext autoreload
%autoreload 2

#load neo4j credentials
load_dotenv('../db.env', override=True)
NEO4J_URI = os.getenv('NEO4J_URI')
NEO4J_USERNAME = os.getenv('NEO4J_USERNAME')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')
NEO4J_URI

'bolt://localhost:7687'

In [2]:
dataset_name = 'mag'
skb = load_skb(dataset_name, download_processed=True, root=None)

Loading from /Users/sbr/.cache/huggingface/hub/datasets--snap-stanford--stark/snapshots/88269e23e90587f99476c5dd74e235a0877e69be/skb/mag/processed!


In [3]:
# Load pre-generated openai text-embedding-ada-002 embeddings
# Get emb_download.py from https://github.com/snap-stanford/stark. see Readme for other ways to generate embeddings
!python emb_download.py --dataset mag --emb_dir .

Downloading...
From (original): https://drive.google.com/uc?id=1HSfUrSKBa7mJbECFbnKPQgd6HSsI8spT
From (redirected): https://drive.google.com/uc?id=1HSfUrSKBa7mJbECFbnKPQgd6HSsI8spT&confirm=t&uuid=6bd9e32b-7b2a-467d-99b6-f21171240769
To: /Users/sbr/GraphRAFT/data_loading/mag/text-embedding-ada-002/query/query_emb_dict.pt
100%|██████████████████████████████████████| 85.6M/85.6M [00:04<00:00, 17.3MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1oVdScsDRuEpCFXtWQcTAx7ycvOggWF17
From (redirected): https://drive.google.com/uc?id=1oVdScsDRuEpCFXtWQcTAx7ycvOggWF17&confirm=t&uuid=ec649d9d-fc49-490d-8e76-1573f32c421e
To: /Users/sbr/GraphRAFT/data_loading/mag/text-embedding-ada-002/doc/candidate_emb_dict.pt
100%|██████████████████████████████████████| 4.51G/4.51G [03:22<00:00, 22.2MB/s]


In [3]:
skb.edge_type_dict

{0: 'author___affiliated_with___institution',
 1: 'paper___cites___paper',
 2: 'paper___has_topic___field_of_study',
 3: 'author___writes___paper'}

In [17]:
dir(skb)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_build_sparse_adj',
 '_process_raw',
 'abstract_path',
 'candidate_ids',
 'candidate_types',
 'edge_index',
 'edge_type2id',
 'edge_type_dict',
 'edge_types',
 'get_all_paths',
 'get_candidate_ids',
 'get_doc_info',
 'get_edge_ids_by_type',
 'get_edge_type_by_id',
 'get_map',
 'get_neighbor_nodes',
 'get_node_ids_by_type',
 'get_node_ids_by_value',
 'get_node_type_by_id',
 'get_rel_info',
 'get_tuples',
 'graph_data_root',
 'is_rel_type',
 'k_hop_neighbor',
 'load_edge',
 'load_english_paper_text',
 'load_meta_data',
 'mag_mapping_dir',
 'mag_metadata_cache_dir',
 'merged_filt

In [6]:
skb.node_attr_dict

{'paper': ['title', 'abstract', 'publication date', 'venue'],
 'author': ['name'],
 'institution': ['name'],
 'field_of_study': ['name']}

In [11]:
emb_d = torch.load('./mag/text-embedding-ada-002/doc/candidate_emb_dict.pt', weights_only=False)

In [26]:
if dataset_name=='mag':
    #Rename attribute
    print("====== Preprocessing ======")
    for node_id, node in tqdm(skb.node_info.items()):
        if 'ConferenceSeriesId.1' in node.keys():
            node['ConferenceSeriesId1'] = node.pop('ConferenceSeriesId.1')
        
    with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
        insert_nodes(skb, driver=driver, key_property='new_id') 
        insert_relationships(skb=skb, driver=driver, dataset_name=dataset_name)
        
        abstract_embeddings = torch.load('./mag/text-embedding-ada-002/doc/candidate_emb_dict.pt', weights_only=False)
        insert_node_embeddings(embeddings=abstract_embeddings, embedding_name='text_embedding', driver=driver)




100%|██████████| 1872968/1872968 [00:00<00:00, 4481728.14it/s]




100%|██████████| 1104554/1104554 [00:35<00:00, 30692.24it/s]




100%|██████████| 8701/8701 [00:00<00:00, 22925.48it/s]




100%|██████████| 59469/59469 [00:01<00:00, 30146.38it/s]




100%|██████████| 700244/700244 [01:25<00:00, 8229.95it/s]




100%|██████████| 2058324/2058324 [00:33<00:00, 61063.67it/s]




100%|██████████| 9719488/9719488 [03:08<00:00, 51483.72it/s]




100%|██████████| 14486538/14486538 [04:01<00:00, 60109.80it/s]




100%|██████████| 13537766/13537766 [04:21<00:00, 51684.33it/s]




100%|██████████| 700244/700244 [31:54<00:00, 365.67it/s]


In [7]:
import torch

emb = torch.load('../data-loading/emb/amazon/text-embedding-ada-002/doc/candidate_emb_dict.pt')

from tqdm import tqdm

# format embedding records
emb_records = []
for k,v in tqdm(emb.items()):
  emb_records.append({"nodeId":k ,"textEmbedding": v.squeeze().tolist()})
emb_records[:10]



100%|██████████| 957192/957192 [01:33<00:00, 10266.44it/s]


[{'nodeId': 0,
  'textEmbedding': [-0.017573531717061996,
   -0.014251705259084702,
   -0.020346183329820633,
   -0.026507634669542313,
   0.009302452206611633,
   0.0024176998995244503,
   -0.0062686069868505,
   -0.024699382483959198,
   -0.013983815908432007,
   -0.04058521240949631,
   0.01153262984007597,
   0.0057897549122571945,
   0.01642160676419735,
   -0.013414551503956318,
   -0.008230895735323429,
   0.010836117900907993,
   0.016180507838726044,
   0.0014692047843709588,
   -0.026735341176390648,
   -0.0220204908400774,
   -0.020198844373226166,
   0.015242895111441612,
   -0.03929934278130531,
   0.020814990624785423,
   -0.02208746410906315,
   -0.0029752443078905344,
   0.04192465916275978,
   -0.02618616819381714,
   0.004510584287345409,
   0.0009752840851433575,
   0.010353917255997658,
   -0.005354435183107853,
   -0.026829103007912636,
   -0.04146924614906311,
   -0.013876659795641899,
   0.005089894402772188,
   0.014653538353741169,
   0.0024947181809693575,
   

In [27]:
from langchain_openai import OpenAIEmbeddings
embedding_model = OpenAIEmbeddings(model="text-embedding-ada-002")

import itertools

def chunked(it, size):
    it = iter(it)
    while True:
        chunk = tuple(itertools.islice(it, size))
        if not chunk:
            return
        yield chunk

In [36]:
from tqdm import tqdm
import pandas as pd

# create node_df
author_list = []
institution_list = []
field_of_study_list = []

authors = skb.get_node_ids_by_type('author')
for i in authors:
    author_list.append(skb[i].dictionary)

institutions = skb.get_node_ids_by_type('institution')
for i in institutions:
    institution_list.append(skb[i].dictionary)

fields_of_study = skb.get_node_ids_by_type('field_of_study')
for i in fields_of_study:
    field_of_study_list.append(skb[i].dictionary)


author_df = pd.DataFrame(author_list)
institution_df = pd.DataFrame(institution_list)
field_of_study_df = pd.DataFrame(field_of_study_list)

In [42]:
author_df

Unnamed: 0,id,mag_id,Rank,DisplayName,LastKnownAffiliationId,PaperCount,CitationCount,type,new_id
0,0,29650,16355,Alexander Wagenpfahl,25974101,16,982,author,0
1,1,39382,18658,Torben Jabben,31512782,8,21,author,1
2,2,41399,13895,Igor A. Abrikosov,102134673,409,7624,author,2
3,3,46586,15055,Lars Goerigk,165779595,47,6816,author,3
4,4,60751,16821,Yu. V. Kaletina,1313323035,45,105,author,4
...,...,...,...,...,...,...,...,...,...
1104549,1134644,3012560199,-1,-1,-1,-1,-1,author,1104549
1104550,1134645,3012560861,-1,-1,-1,-1,-1,author,1104550
1104551,1134646,3012560907,-1,-1,-1,-1,-1,author,1104551
1104552,1134647,3012561154,-1,-1,-1,-1,-1,author,1104552


In [50]:
author_names = author_df['DisplayName'].tolist()
author_names = [str(x) for x in author_names]
author_nodeids = author_df['id'].tolist()
author_names_dict = {i: name for i, name in zip(author_nodeids, author_names)}
author_names_dict

{0: 'Alexander Wagenpfahl',
 1: 'Torben Jabben',
 2: 'Igor A. Abrikosov',
 3: 'Lars Goerigk',
 4: 'Yu. V. Kaletina',
 5: 'Torsten Mayer-Gürr',
 6: 'Bao Shang-lian',
 7: 'Vedran Lekic',
 8: 'Nurettin Yamankaradeniz',
 9: 'Jutharatana Klinkaewnarong',
 10: 'Geva Arwas',
 11: 'Riet Labie',
 12: 'Sedat Alkoy',
 13: 'Bruce R. Gerratt',
 14: 'Fernando Fernández-Lázaro',
 15: 'Peter A. Gnoffo',
 16: 'David Trdlička',
 17: 'Francesc Rocadenbosch',
 18: 'Marta Mas-Torrent',
 19: 'Babak Nadjar Araabi',
 20: 'Thierry Giamarchi',
 21: 'David Rodriguez-Larrea',
 22: 'Thomas Quella',
 23: 'Mikael C. Rechtsman',
 24: 'Victor Vartanian',
 25: 'Naomi J. Halas',
 26: 'Guillem Pérez-Nadal',
 27: 'Jonathan Bergknoff',
 28: 'Consuelo Cid Tortuero',
 29: 'Severin N. Habisreutinger',
 30: 'Vikas Sudesh',
 32: 'Dmitriy A. Dikin',
 33: 'Shi Yuejiang',
 34: 'Eric Le Moal',
 35: 'Ugo Piomelli',
 36: 'Dominic de Lanauze',
 37: 'Arun M. Thalapillil',
 38: 'Gao Yuan-Ning',
 39: 'Hamed Alipour-Banaei',
 40: 'Leo P. 

In [51]:
author_embedding_dict = {}

for chunk in tqdm(chunked(author_names_dict.items(), 10000)):
    ids, names = zip(*chunk)
    name_embeddings = embedding_model.embed_documents(names)

    recs = [{'nodeId': id, 'textEmbedding': torch.tensor(name_embedding).unsqueeze(0)} for id, name_embedding in zip(ids, name_embeddings)]

    embeddings_dict = {id : name_embedding for id, name_embedding in zip(ids, name_embeddings)}
    author_embedding_dict.update(embeddings_dict)

111it [52:12, 28.22s/it]


In [52]:
len(author_embedding_dict)

1104554

In [54]:
import torch
torch.save(author_embedding_dict, 'emb/mag/text-embedding-ada-002/doc/author_embedding_dict.pt')

In [11]:
len(author_embedding[0])

1536

In [5]:
with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    author_embedding = torch.load('emb/mag/text-embedding-ada-002/doc/author_embedding_dict.pt', weights_only=False)
    insert_node_embeddings(embeddings=author_embedding, embedding_name='textEmbedding', driver=driver)



100%|██████████| 1104554/1104554 [57:55<00:00, 317.84it/s]


In [68]:
with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    insert_node_embeddings(embeddings=author_embedding, embedding_name='textEmbedding', driver=driver)



  2%|▏         | 22000/1104554 [01:02<51:06, 353.05it/s] 


KeyboardInterrupt: 

In [69]:
institution_df

Unnamed: 0,id,mag_id,Rank,DisplayName,PaperCount,CitationCount,type,new_id
0,0,9507,11616,Sangji University,1259,10636,institution,1104554
1,1,19722,11471,Ateneo de Manila University,1467,8557,institution,1104555
2,2,41870,11530,Instituto Militar de Engenharia,1289,9985,institution,1104556
3,3,43886,13361,Sapientia University,265,1173,institution,1104557
4,4,46017,10780,Kahramanmaraş Sütçü İmam University,1928,20112,institution,1104558
...,...,...,...,...,...,...,...,...
8696,8735,2961216182,-1,-1,-1,-1,institution,1113250
8697,8736,2972652528,-1,-1,-1,-1,institution,1113251
8698,8737,3004594783,-1,-1,-1,-1,institution,1113252
8699,8738,3005327000,-1,-1,-1,-1,institution,1113253


In [70]:
institution_names = institution_df['DisplayName'].tolist()
institution_names = [str(x) for x in institution_names]
institution_nodeids = institution_df['id'].tolist()
institution_names_dict = {i: name for i, name in zip(institution_nodeids, institution_names)}

In [71]:
institution_embedding_dict = {}

for chunk in tqdm(chunked(institution_names_dict.items(), 10000)):
    ids, names = zip(*chunk)
    name_embeddings = embedding_model.embed_documents(names)

    embeddings_dict = {id : name_embedding for id, name_embedding in zip(ids, name_embeddings)}
    institution_embedding_dict.update(embeddings_dict)

1it [00:24, 24.46s/it]


In [72]:
torch.save(institution_embedding_dict, 'emb/mag/text-embedding-ada-002/doc/institution_emb_dict.pt')

In [73]:
with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    institution_embedding = torch.load('emb/mag/text-embedding-ada-002/doc/institution_emb_dict.pt', weights_only=False)
    insert_node_embeddings(embeddings=institution_embedding, embedding_name='textEmbedding', driver=driver)



100%|██████████| 8701/8701 [00:24<00:00, 354.29it/s]


In [74]:
field_of_study_df

Unnamed: 0,id,mag_id,Rank,DisplayName,Level,PaperCount,CitationCount,type,new_id
0,0,4250,15252,Sign function,3,820,7533,field_of_study,1113255
1,1,12843,11980,Gravitational singularity,2,29654,338027,field_of_study,1113256
2,2,20288,11787,Superstring theory,3,6772,160417,field_of_study,1113257
3,3,37253,12134,Complete intersection,2,2428,27280,field_of_study,1113258
4,4,39854,13346,Torque converter,3,15356,91481,field_of_study,1113259
...,...,...,...,...,...,...,...,...,...
59464,59960,2994609373,-1,-1,-1,-1,-1,field_of_study,1172719
59465,59961,2994609795,-1,-1,-1,-1,-1,field_of_study,1172720
59466,59962,2994609957,-1,-1,-1,-1,-1,field_of_study,1172721
59467,59963,2994610206,-1,-1,-1,-1,-1,field_of_study,1172722


In [76]:
field_of_study_names = field_of_study_df['DisplayName'].tolist()
field_of_study_names = [str(x) for x in field_of_study_names]
field_of_study_nodeids = field_of_study_df['id'].tolist()
field_of_study_names_dict = {i: name for i, name in zip(field_of_study_nodeids, field_of_study_names)}

In [77]:
field_of_study_embedding_dict = {}

for chunk in tqdm(chunked(field_of_study_names_dict.items(), 10000)):
    ids, names = zip(*chunk)
    name_embeddings = embedding_model.embed_documents(names)

    embeddings_dict = {id : name_embedding for id, name_embedding in zip(ids, name_embeddings)}
    field_of_study_embedding_dict.update(embeddings_dict)

6it [02:50, 28.40s/it]


In [78]:
torch.save(field_of_study_embedding_dict, 'emb/mag/text-embedding-ada-002/doc/field_of_study_embedding_dict.pt')

In [80]:
type(field_of_study_embedding)

dict

In [6]:
with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    field_of_study_embedding = torch.load('emb/mag/text-embedding-ada-002/doc/field_of_study_embedding_dict.pt', weights_only=False)
    insert_node_embeddings(embeddings=field_of_study_embedding, embedding_name='textEmbedding', driver=driver)



100%|██████████| 59469/59469 [03:16<00:00, 302.98it/s]


In [8]:
abstract_embeddings = torch.load('./mag/text-embedding-ada-002/doc/candidate_emb_dict.pt', weights_only=False)
with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    insert_node_embeddings(embeddings=abstract_embeddings, embedding_name='textEmbedding', driver=driver)



100%|██████████| 700244/700244 [35:54<00:00, 325.08it/s]


In [16]:
# create vector index

with GraphDatabase.driver(NEO4J_URI,
                          auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
  driver.execute_query('''
  CREATE VECTOR INDEX text_embeddings IF NOT EXISTS FOR (n:_Entity_) ON (n.textEmbedding)
  OPTIONS {indexConfig: {
  `vector.dimensions`: toInteger($dimension),
  `vector.similarity_function`: 'cosine'
  }}''', parameters_={'dimension': 1536})
  driver.execute_query('CALL db.awaitIndex("text_embeddings", 500)')

ClientError: {code: Neo.ClientError.Procedure.ProcedureTimedOut} {message: Index on 'Index( id=4, name='text_embeddings', type='VECTOR', schema=(:_Entity_ {textEmbedding}), indexProvider='vector-2.0' )' did not come online within 300 SECONDS}

In [3]:
# Create reltype_emb_dict.py

In [4]:
import re

def format_rel_type(s):

  return re.sub('[^0-9A-Z]+', '_', s.upper())

In [17]:
from langchain_openai import OpenAIEmbeddings
embedding_model = OpenAIEmbeddings(model="text-embedding-ada-002")
reltype_emb = {format_rel_type(v): embedding_model.embed_query(v) for k,v in  skb.edge_type_dict.items()}
# dict_keys currently is: dict_keys(['AUTHOR_AFFILIATED_WITH_INSTITUTION', 'PAPER_CITES_PAPER', 'PAPER_HAS_TOPIC_FIELD_OF_STUDY', 'AUTHOR_WRITES_PAPER'])
# Remove the source node and target node type from the key names


In [7]:
reltype_emb.keys()

dict_keys(['AUTHOR_AFFILIATED_WITH_INSTITUTION', 'PAPER_CITES_PAPER', 'PAPER_HAS_TOPIC_FIELD_OF_STUDY', 'AUTHOR_WRITES_PAPER'])

In [19]:
# Mapping for renaming keys
key_mapping = {
    'AUTHOR_AFFILIATED_WITH_INSTITUTION': 'AFFILIATED_WITH',
    'PAPER_CITES_PAPER': 'CITES',
    'PAPER_HAS_TOPIC_FIELD_OF_STUDY': 'HAS_TOPIC',
    'AUTHOR_WRITES_PAPER': 'WRITES',
}


In [20]:
renamed_data = {key_mapping[k]: v for k, v in reltype_emb.items()}

In [21]:
renamed_data.keys()

dict_keys(['AFFILIATED_WITH', 'CITES', 'HAS_TOPIC', 'WRITES'])

In [22]:
import torch

torch.save(renamed_data, 'emb/mag/text-embedding-ada-002/doc/reltype_emb_dict.pt')