In [1]:
from collections import defaultdict
from itertools import islice

import torch
from sentence_transformers import SentenceTransformer
from tqdm.notebook import tqdm

from src.datasets.who_is_who import WhoIsWhoDataset
from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import NodeType, EdgeType
from src.shared import config

  from tqdm.autonotebook import tqdm, trange


In [2]:
# Connect to the graph database that receives the publication nodes.
db = DatabaseWrapper(database='homogeneous-graph')

# Register one vector index per embedding property stored on publication nodes.
# NOTE(review): 32 is presumably the embedding dimension (the model directory
# name ends in "-32dim") — confirm against DatabaseWrapper.create_vector_index.
for index_name, embedding_property in (
    ('title_index', 'title_emb'),
    ('abstract_index', 'abstract_emb'),
    ('venue_index', 'venue_emb'),
):
    db.create_vector_index(index_name, NodeType.PUBLICATION, embedding_property, 32)

# Parse the WhoIsWho dataset; the import loop iterates it with `.items()`.
data = WhoIsWhoDataset.parse_data()

2024-09-14 15:55:49,343 - DatabaseWrapper - INFO - Connecting to the database ...
2024-09-14 15:55:49,344 - DatabaseWrapper - INFO - Database ready.


In [3]:
# Location of the locally stored sentence-transformer checkpoint.
MODEL_PATH = '../../data/models/scibert_scivocab_uncased_sentence_transformer-32dim'

# Prefer Apple's Metal (MPS) backend, which is what this notebook was written
# for, but fall back to CUDA or the CPU so the notebook also runs on machines
# where MPS is not available instead of failing at model load time.
if torch.backends.mps.is_available():
    DEVICE = 'mps'
elif torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

model = SentenceTransformer(MODEL_PATH, device=DEVICE)

In [4]:
def process_batch(batch):
    """Embed the queued publication nodes and merge them into the database.

    For every node queued under ``batch[NodeType.PUBLICATION]`` the ``title``,
    ``abstract`` and ``venue`` values are encoded with the notebook-level
    ``model`` and stored on the node as ``title_emb``, ``abstract_emb`` and
    ``venue_emb``. The nodes are then passed to ``db.merge_nodes`` and the
    queue is replaced by a new empty list.

    Args:
        batch: Mapping from node type to a list of node dicts. It is modified
            in place: nodes gain the embedding keys and the publication queue
            is reset.

    Returns:
        None. Returns immediately when no publication nodes are queued.
    """
    publications = batch[NodeType.PUBLICATION]
    if not publications:
        return

    # (source text key, target embedding key), in the order they are encoded.
    embedded_fields = (
        ('title', 'title_emb'),
        ('abstract', 'abstract_emb'),
        ('venue', 'venue_emb'),
    )

    # Encode each field for the whole batch up front, before any node is
    # modified, so one encode call covers all nodes per field.
    encoded = {
        target_key: model.encode([pub[source_key] for pub in publications])
        for source_key, target_key in embedded_fields
    }

    # Attach the embedding at each position to the node at the same position.
    for position, pub in enumerate(publications):
        for _, target_key in embedded_fields:
            pub[target_key] = encoded[target_key][position]

    db.merge_nodes(NodeType.PUBLICATION, publications)
    # Rebind to a fresh list rather than clearing the one just handed to the db.
    batch[NodeType.PUBLICATION] = []

In [5]:
# Number of queued publications that triggers a flush to the database.
BATCH_SIZE = 1000
# Upper bound on the number of records to import; set to None to import all.
max_iterations = 1000
# Keys copied verbatim from each dataset record onto the publication node.
PUBLICATION_FIELDS = ('id', 'title', 'abstract', 'year', 'venue')

batch_nodes = defaultdict(list)
current_iteration = 0  # number of records consumed so far

# Size the progress bar by what will actually be iterated, so it reaches 100%
# even when the dataset holds fewer records than `max_iterations`.
total_records = (
    len(data) if max_iterations is None else min(len(data), max_iterations)
)

# Only the record values are needed; the dict keys were never used.
# `islice(..., None)` iterates everything, so `max_iterations = None` still works.
for values in tqdm(islice(data.values(), max_iterations), total=total_records):
    current_iteration += 1

    paper_node = {field: values[field] for field in PUBLICATION_FIELDS}
    batch_nodes[NodeType.PUBLICATION].append(paper_node)

    # `process_batch` resets the queue, so its length is the pending count.
    if len(batch_nodes[NodeType.PUBLICATION]) >= BATCH_SIZE:
        process_batch(batch_nodes)

# Flush the remainder that did not fill a complete batch (no-op when empty).
process_batch(batch_nodes)

  0%|          | 0/1000 [00:00<?, ?it/s]

  warn("Expected a result with a single record, "


In [6]:
# Sanity check: report how many publication nodes the database now holds.
publication_count = db.count_nodes(NodeType.PUBLICATION)
print(publication_count)

1000
