In [None]:
from sentence_transformers import SentenceTransformer
from tqdm.notebook import tqdm
from collections import defaultdict

from src.datasets.who_is_who import WhoIsWhoDataset
from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import NodeType, EdgeType
from src.shared.neo_to_pyg import GraphSampling
from src.shared import config

In [None]:
# Open the 'homogeneous-graph' database, create one vector index per embedding
# property on publication nodes, then parse the WhoIsWho data.
db = DatabaseWrapper(database='homogeneous-graph')
for index_name, emb_property in (
    ('title_index', 'title_emb'),
    ('abstract_index', 'abstract_emb'),
    ('venue_index', 'venue_emb'),
):
    # NOTE(review): 32 is presumably the embedding dimension (the model
    # checkpoint name ends in "-32dim") — confirm against DatabaseWrapper.
    db.create_vector_index(index_name, NodeType.PUBLICATION, emb_property, 32)
data = WhoIsWhoDataset.parse_data()

In [None]:
# Load the sentence-embedding model from its local checkpoint directory and
# request the 'mps' device for it.
model_dir = '../data/models/scibert_scivocab_uncased_sentence_transformer-32dim'
model = SentenceTransformer(model_dir, device='mps')

In [None]:
def process_batch(batch):
    """Embed the queued publication nodes, merge them, and reset the queue.

    Works on the list stored under ``batch[NodeType.PUBLICATION]``. If that
    list is empty the function returns immediately. Otherwise the ``title``,
    ``abstract`` and ``venue`` values of all queued nodes are encoded with the
    module-level ``model`` (one ``encode`` call per field), each node receives
    ``title_emb``, ``abstract_emb`` and ``venue_emb`` entries, the nodes are
    handed to ``db.merge_nodes`` and the queue entry is replaced by a new
    empty list.

    Args:
        batch: Mapping from node type to a list of node dicts. Both the node
            dicts and the mapping entry are modified in place.

    Returns:
        None.
    """
    publications = batch[NodeType.PUBLICATION]
    if not publications:
        return

    # (source field, target embedding key) pairs; one encode call per field.
    field_pairs = (
        ('title', 'title_emb'),
        ('abstract', 'abstract_emb'),
        ('venue', 'venue_emb'),
    )
    encoded = [
        (emb_key, model.encode([pub[field] for pub in publications]))
        for field, emb_key in field_pairs
    ]

    # Row idx of every encoded result belongs to the idx-th queued node.
    for idx, pub in enumerate(publications):
        for emb_key, vectors in encoded:
            pub[emb_key] = vectors[idx]

    db.merge_nodes(NodeType.PUBLICATION, publications)
    batch[NodeType.PUBLICATION] = []

In [None]:
# Queue publication nodes from `data` and flush them to the database in
# fixed-size batches via process_batch.
batch_nodes = defaultdict(list)
max_iterations = 10000  # upper bound on processed entries; None means no cap
batch_size = 1000       # number of queued publications that triggers a flush

# Size the progress bar by the number of entries that will actually be
# visited, so it completes even when the dataset is smaller than the cap and
# still has a known total when no cap is set.
total_entries = len(data)
if max_iterations is not None:
    total_entries = min(max_iterations, total_entries)

current_iteration = 0
with tqdm(total=total_entries) as pbar:
    # The mapping key is not used; every field needed comes from `values`.
    for _entry_key, values in data.items():
        if max_iterations is not None and current_iteration >= max_iterations:
            break
        current_iteration += 1
        pbar.update(1)

        paper_node = {
            'id': values['id'],
            'title': values['title'],
            'abstract': values['abstract'],
            'year': values['year'],
            'venue': values['venue'],
        }
        batch_nodes[NodeType.PUBLICATION].append(paper_node)

        # process_batch empties the queue, so the length restarts from zero
        # after every flush.
        if len(batch_nodes[NodeType.PUBLICATION]) >= batch_size:
            process_batch(batch_nodes)

# Flush whatever is left in the queue (process_batch ignores an empty queue).
process_batch(batch_nodes)

In [None]:
# Write the ground-truth author name onto every publication listed under an
# author's 'normal_data' entry in the parsed training data.
true_author_data = WhoIsWhoDataset.parse_train()
with tqdm(total=len(true_author_data), desc="Merging WhoIsWho train_author.json") as pbar:
    for author_id, values in true_author_data.items():
        author_name = values['name']
        # NOTE(review): properties are written one publication at a time. A
        # batched variant using db.merge_properties_batch was sketched here
        # earlier but left disabled — confirm that method before switching.
        for pub_id in values['normal_data']:
            db.merge_properties(NodeType.PUBLICATION, pub_id, {'true_author': author_name})
        # Advance once per author so the bar matches its total.
        pbar.update(1)

In [None]:
# Print the value db.count_nodes reports for publication nodes.
publication_count = db.count_nodes(NodeType.PUBLICATION)
print(publication_count)

In [None]:
# Build a GraphSampling instance limited to publication nodes, with an empty
# edge specification and the three embedding properties.
embedding_properties = ['abstract_emb', 'title_emb', 'venue_emb']
gs = GraphSampling(
    node_spec=[NodeType.PUBLICATION],
    edge_spec=[],
    node_properties=embedding_properties,
)

In [None]:
# Ask the sampler for publication nodes (n=1000) including the embedding
# properties and 'true_author', then print only the first result as a spot check.
sample_properties = ['abstract_emb', 'title_emb', 'venue_emb', 'true_author']
nodes = gs.random_nodes(
    node_type=NodeType.PUBLICATION,
    node_properties=sample_properties,
    n=1000,
)

# A sentinel distinguishes "no nodes returned" from any real node value.
_no_node = object()
first_node = next(iter(nodes), _no_node)
if first_node is not _no_node:
    print(first_node)