In [1]:
import numpy as np
from sentence_transformers import SentenceTransformer
from tqdm.notebook import tqdm
from collections import defaultdict
import re

from src.datasets.who_is_who import WhoIsWhoDataset
from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import NodeType, EdgeType
from src.shared.graph_sampling import GraphSampling

  from tqdm.autonotebook import tqdm, trange


## Configuration

In [2]:
max_iterations = 50
add_true_authors = False
link_true_authors = True
link_co_authors = True
link_title = True
link_abstract = True
link_venue = True
link_org = True

In [3]:

db = DatabaseWrapper(database='small-graph')
db.delete_all_nodes()

2024-12-04 17:10:27,567 - DatabaseWrapper - INFO - Connecting to the database ...
2024-12-04 17:10:27,567 - DatabaseWrapper - INFO - Database ready.
2024-12-04 17:10:27,570 - DatabaseWrapper - INFO - Deleted all nodes.


In [4]:
data = WhoIsWhoDataset.parse_data()
train_data = WhoIsWhoDataset.parse_train()

co_author_overlap_threshold = 0.25
link_title_threshold = 0.7
link_abstract_threshold = 0.7
link_venue_threshold = 0.7
link_org_threshold = 0.9

link_title_k = 8
link_abstract_k = 8
link_venue_k = 8
link_org_k = 8

In [5]:
model = SentenceTransformer(
    'sentence-transformers/all-MiniLM-L6-v2',
    device='cuda'
)
print(f"Model dim: {model.get_sentence_embedding_dimension()}")



Model dim: 384


## Add publication nodes to the graph database

In [6]:
def process_batch(batch):
    if not batch[NodeType.PUBLICATION]:
        return
    title_embs = model.encode(
        [node['title'] for node in batch[NodeType.PUBLICATION]]
    )
    abstract_embs = model.encode(
        [node['abstract'] for node in batch[NodeType.PUBLICATION]]
    )
    venue_embs = model.encode(
        [node['venue'] for node in batch[NodeType.PUBLICATION]]
    )
    org_embs = model.encode(
        [node['org'] for node in batch[NodeType.PUBLICATION]]
    )
    for i, node in enumerate(batch[NodeType.PUBLICATION]):
        node['title_emb'] = title_embs[i]
        node['abstract_emb'] = abstract_embs[i]
        node['venue_emb'] = venue_embs[i]
        node['org_emb'] = org_embs[i]
        # vertically stack the embeddings
        node['feature_vec'] = list(title_embs[i]) + list(abstract_embs[i])
    db.merge_nodes(NodeType.PUBLICATION, batch[NodeType.PUBLICATION])
    batch[NodeType.PUBLICATION] = []

In [7]:
batch_nodes = defaultdict(list)
current_iteration = 0
authors_in_graph = set()

with tqdm(total=max_iterations) as pbar:
    for author_id, values in train_data.items():
        if max_iterations is not None and current_iteration >= max_iterations:
            break
            
        authors_in_graph.add(author_id)
        
        papers = values.get('normal_data', [])
        papers.extend(values.get('outliers', []))
        
        
        current_iteration += 1
        
        for paper_id in papers:
            values = data[paper_id]
            authors = values.get('authors', [])
            org = ''
            if len(authors) > 0 and 'org' in authors[0]:
                org = authors[0].get('org', '')
            paper_node = {
                'id': values['id'],
                'title': values['title'],
                'abstract': values['abstract'],
                'year': values['year'],
                'venue': values['venue'],
                'org': org
            }
            batch_nodes[NodeType.PUBLICATION].append(paper_node)
            
            if len(batch_nodes[NodeType.PUBLICATION]) % 1000 == 0:
                process_batch(batch_nodes)
                
        pbar.update(1)
        
process_batch(batch_nodes)
print("Number of authors in the graph:", len(authors_in_graph))
print("Number of publication nodes:", db.count_nodes(NodeType.PUBLICATION))

  0%|          | 0/50 [00:00<?, ?it/s]

  warn("Expected a result with a single record, "


Number of authors in the graph: 50
Number of publication nodes: 8865


## Add true author data to the nodes

In [8]:
def reverse_dict(author_dict):
    paper_to_author = {}
    for author_id, values in author_dict.items():
        normal_papers = values.get('normal_data', [])
        for paper_id in normal_papers:
            paper_to_author[paper_id] = author_id
    return paper_to_author

def add_true_authors(db: DatabaseWrapper, train_data):
    paper_id_to_author = reverse_dict(train_data)

    with tqdm(total=db.count_nodes(NodeType.PUBLICATION), desc="Merging WhoIsWho train_author.json") as pbar:
        for nodes in db.iter_nodes(NodeType.PUBLICATION, ['id']):
            for node in nodes:
                true_author_id = paper_id_to_author.get(node['id'], '')
                true_author_name = train_data.get(true_author_id, {}).get('name', '')
                
                db.merge_properties(
                    type=NodeType.PUBLICATION, 
                    node_id=node['id'], 
                    properties={'true_author_id': true_author_id, 'true_author_name': true_author_name}
                )
                pbar.update(1)

if add_true_authors:
    add_true_authors(db, train_data)

Merging WhoIsWho train_author.json:   0%|          | 0/8865 [00:00<?, ?it/s]

## Link publication nodes if they share the same true author

In [9]:
import random

def get_sample_size(num_papers):
    random_k = 3 + random.random() * 0.2
    return int(num_papers**0.5 * random_k)

def link_true_authors(db: DatabaseWrapper, train_data):
    node_type = NodeType.PUBLICATION
    edges_to_merge = []
    current_iteration = 0
    
    with tqdm(total=max_iterations, desc="Merging true-author edges") as pbar:
        for author_id, values in train_data.items():
            papers = values.get('normal_data', [])
            print(f"Number of papers: {len(papers)}")
            
            pbar.update(len(papers) + len(values.get('outliers', [])))
            if max_iterations is not None and current_iteration >= max_iterations:
                break 
            current_iteration += len(papers) + len(values.get('outliers', []))
            
            for i in range(len(papers)):
                for j in range(len(papers)):
                    if i == j:
                        continue
                    edges_to_merge.append([papers[i], papers[j]])
            
            print(f"Number of edges to merge: {len(edges_to_merge)}")
            # Randomly sample 10-40% of the edges_to_merge list
            sample_size = get_sample_size(len(edges_to_merge))
            print(f"Sample size: {sample_size}")
            edges_to_merge = random.sample(edges_to_merge, sample_size)
            print(f"Sampled number of edges to merge: {len(edges_to_merge)}")
             
            if len(edges_to_merge) > 1000:
                db.merge_edges(start_label=node_type, end_label=node_type, edge_type=EdgeType.SAME_AUTHOR, edges=edges_to_merge)
                edges_to_merge.clear()
            
            
            
        
        if edges_to_merge:
            db.merge_edges(start_label=node_type, end_label=node_type, edge_type=EdgeType.SAME_AUTHOR, edges=edges_to_merge)
            edges_to_merge.clear()
            
if link_true_authors:
    db.delete_edges(EdgeType.SAME_AUTHOR)
    link_true_authors(db, train_data)

Merging true-author edges:   0%|          | 0/50 [00:00<?, ?it/s]

Number of papers: 565
Number of edges to merge: 318660
Sample size: 1784
Sampled number of edges to merge: 1784
Number of papers: 40


## Link nodes based on co-author relationships

In [10]:
def link_co_author_network(db: DatabaseWrapper, node_type: NodeType):
    num_nodes = db.count_nodes(node_type)
    attrs = ['id']
    co_author_map = defaultdict(list)
    co_author_overlap = defaultdict(int)
        
    print(f"Linking {node_type.value} nodes based on co-authorship ...")
    with tqdm(total=num_nodes, desc=f"Progress {node_type.value} co-authorship") as pbar:
        for nodes in db.iter_nodes(node_type, attrs):
            for node in nodes:
                co_authors = [author["name"] for author in data[node['id']]['authors']]
                for author in co_authors:
                    name = author.strip()
                    name = re.sub(r'[^A-Za-z\s]', '', name)
                    name_parts = name.split()
                    if len(name_parts) == 0:
                        continue
                    
                    surname = name_parts[-1]
                    given_name_initial = (name_parts[0] if len(name_parts) > 1 else ' ')[0]
                    abbrev = f"{surname} {given_name_initial}"
                    co_author_map[abbrev].append(node['id'])
                pbar.update(1)
    
    for k, v in co_author_map.items():
        for i in range(len(v)):
            for j in range(i + 1, len(v)):
                co_author_overlap[(v[i], v[j])] += 1
                
    for k, v in co_author_overlap.items():
        total_num_authors = data[k[0]]["authors"] + data[k[1]]["authors"]
        co_author_overlap[k] = v / len(total_num_authors)
    
    print(f"Max. co-authors: {max(len(v) for v in co_author_map.values())}")
    print(f"Max. co-author overlap: {max(co_author_overlap.values())}")
    
    print("Number of co-author pairs:", len(co_author_overlap))
    print(f"Number of co-author pairs with overlap > {co_author_overlap_threshold}:", len([v for v in co_author_overlap.values() if v > co_author_overlap_threshold]))
    
    print("Merging edges ...")
    edges_to_merge = [[k[0], k[1], {'sim': v}] for k, v in co_author_overlap.items() if v > co_author_overlap_threshold]
    with tqdm(total=len(edges_to_merge), desc="Merging co-author edges") as pbar:
        for i in range(0, len(edges_to_merge), 1000):
            db.merge_edges_with_properties(start_label=node_type, end_label=node_type, edge_type=EdgeType.SIM_AUTHOR, edges=edges_to_merge[i:i+1000])
            pbar.update(1000)

In [11]:
if link_co_authors:
    link_co_author_network(db, NodeType.PUBLICATION)

Linking Publication nodes based on co-authorship ...


Progress Publication co-authorship:   0%|          | 0/8865 [00:00<?, ?it/s]

Max. co-authors: 797
Max. co-author overlap: 1.0
Number of co-author pairs: 1296615
Number of co-author pairs with overlap > 0.25: 46386
Merging edges ...


Merging co-author edges:   0%|          | 0/46386 [00:00<?, ?it/s]

## Link nodes based on cosine similarity of their embeddings

In [12]:
def link_node_attr_cosine(db: DatabaseWrapper, node_type: NodeType, vec_attr: str, edge_type: EdgeType, threshold: float = 0.7, filter_empty_original_attr: str = None, k: int = 8):
    num_nodes = db.count_nodes(node_type)
    edges = []
    attrs = ['id', vec_attr]
    if filter_empty_original_attr:
        attrs.append(filter_empty_original_attr)
        
    print(f"Linking {node_type.value} nodes based on {vec_attr} attribute ...")
    with tqdm(total=num_nodes, desc=f"Progress {node_type.value} {vec_attr}") as pbar:
        for nodes in db.iter_nodes(node_type, attrs):
            for node in nodes:
                if filter_empty_original_attr and not node[filter_empty_original_attr]:
                    pbar.update(1)
                    print(f"Skipping node {node['id']} because {filter_empty_original_attr} is empty")
                    continue

                similar_nodes = db.get_similar_nodes_vec(
                    node_type,
                    vec_attr,
                    node[vec_attr],
                    threshold,
                    k
                )
                for ix, row in similar_nodes.iterrows():
                    if row['id'] == node['id']:
                        continue
                    edges.append([node['id'], row['id']])
                    #db.merge_edge(node_type, node['id'], node_type, row['id'], edge_type, {"sim": row['sim']})
                if len(edges) > 1000:
                    print(f"Merging {len(edges)} edges ...")
                    db.merge_edges(start_label=node_type, end_label=node_type, edge_type=edge_type, edges=edges)
                    edges.clear()
                    
                pbar.update(1)
    if edges:
        db.merge_edges(start_label=node_type, end_label=node_type, edge_type=edge_type, edges=edges)

In [13]:
model_dim = model.get_sentence_embedding_dimension()

In [14]:
# Create vector index and link nodes based on cosine similarity
if link_title:
    db.create_vector_index('title_index', NodeType.PUBLICATION, 'title_emb', model_dim)
    link_node_attr_cosine(db, NodeType.PUBLICATION, 'title_emb', EdgeType.SIM_TITLE, threshold=link_title_threshold, k=link_title_k)

Linking Publication nodes based on title_emb attribute ...


Progress Publication title_emb:   0%|          | 0/8865 [00:00<?, ?it/s]

Merging 1002 edges ...
Merging 1003 edges ...
Merging 1005 edges ...
Merging 1002 edges ...
Merging 1001 edges ...
Merging 1004 edges ...
Merging 1001 edges ...
Merging 1006 edges ...
Merging 1001 edges ...
Merging 1002 edges ...
Merging 1002 edges ...
Merging 1002 edges ...
Merging 1001 edges ...
Merging 1002 edges ...
Merging 1001 edges ...


In [15]:
if link_abstract:
    db.create_vector_index('abstract_index', NodeType.PUBLICATION, 'abstract_emb', model_dim)
    link_node_attr_cosine(db, NodeType.PUBLICATION, 'abstract_emb', EdgeType.SIM_ABSTRACT, threshold=link_abstract_threshold, k=link_abstract_k)

Linking Publication nodes based on abstract_emb attribute ...


Progress Publication abstract_emb:   0%|          | 0/8865 [00:00<?, ?it/s]

Merging 1004 edges ...
Merging 1001 edges ...
Merging 1001 edges ...
Merging 1002 edges ...
Merging 1007 edges ...
Merging 1002 edges ...
Merging 1002 edges ...
Merging 1004 edges ...
Merging 1002 edges ...
Merging 1006 edges ...
Merging 1004 edges ...
Merging 1001 edges ...
Merging 1008 edges ...
Merging 1001 edges ...
Merging 1007 edges ...
Merging 1004 edges ...
Merging 1008 edges ...
Merging 1003 edges ...
Merging 1007 edges ...
Merging 1006 edges ...
Merging 1003 edges ...
Merging 1002 edges ...
Merging 1005 edges ...
Merging 1005 edges ...
Merging 1002 edges ...
Merging 1007 edges ...
Merging 1006 edges ...
Merging 1006 edges ...
Merging 1002 edges ...
Merging 1001 edges ...
Merging 1005 edges ...
Merging 1003 edges ...
Merging 1008 edges ...


In [16]:
if link_venue:
    db.create_vector_index('venue_index', NodeType.PUBLICATION, 'venue_emb', model_dim)
    link_node_attr_cosine(db, NodeType.PUBLICATION, 'venue_emb', EdgeType.SIM_VENUE, threshold=link_venue_threshold, k=link_venue_k)

Linking Publication nodes based on venue_emb attribute ...


Progress Publication venue_emb:   0%|          | 0/8865 [00:00<?, ?it/s]

Merging 1006 edges ...
Merging 1005 edges ...
Merging 1008 edges ...
Merging 1002 edges ...
Merging 1004 edges ...
Merging 1001 edges ...
Merging 1002 edges ...
Merging 1003 edges ...
Merging 1006 edges ...
Merging 1007 edges ...
Merging 1002 edges ...
Merging 1006 edges ...
Merging 1001 edges ...
Merging 1007 edges ...
Merging 1004 edges ...
Merging 1007 edges ...
Merging 1006 edges ...
Merging 1007 edges ...
Merging 1004 edges ...
Merging 1002 edges ...
Merging 1001 edges ...
Merging 1002 edges ...
Merging 1005 edges ...
Merging 1007 edges ...
Merging 1004 edges ...
Merging 1003 edges ...
Merging 1008 edges ...
Merging 1002 edges ...
Merging 1002 edges ...
Merging 1007 edges ...
Merging 1003 edges ...
Merging 1004 edges ...
Merging 1001 edges ...
Merging 1005 edges ...
Merging 1005 edges ...
Merging 1005 edges ...
Merging 1008 edges ...
Merging 1002 edges ...
Merging 1001 edges ...
Merging 1006 edges ...
Merging 1003 edges ...
Merging 1005 edges ...
Merging 1004 edges ...
Merging 100

In [17]:
if link_org:
    db.create_vector_index('org_index', NodeType.PUBLICATION, 'org_emb', model_dim)
    link_node_attr_cosine(db, NodeType.PUBLICATION, 'org_emb', EdgeType.SIM_ORG, threshold=link_org_threshold, k=link_org_k)

Linking Publication nodes based on org_emb attribute ...


Progress Publication org_emb:   0%|          | 0/8865 [00:00<?, ?it/s]

Merging 1001 edges ...
Merging 1005 edges ...
Merging 1002 edges ...
Merging 1002 edges ...
Merging 1005 edges ...
Merging 1007 edges ...
Merging 1002 edges ...
Merging 1008 edges ...
Merging 1008 edges ...
Merging 1001 edges ...
Merging 1004 edges ...
Merging 1004 edges ...
Merging 1002 edges ...
Merging 1005 edges ...
Merging 1001 edges ...
Merging 1001 edges ...
Merging 1006 edges ...
Merging 1005 edges ...
Merging 1006 edges ...
Merging 1007 edges ...
Merging 1008 edges ...
Merging 1001 edges ...
Merging 1007 edges ...
Merging 1004 edges ...
Merging 1004 edges ...
Merging 1005 edges ...
Merging 1003 edges ...
Merging 1006 edges ...
Merging 1008 edges ...
Merging 1001 edges ...
Merging 1003 edges ...
Merging 1004 edges ...
Merging 1007 edges ...
Merging 1008 edges ...
Merging 1007 edges ...
Merging 1001 edges ...
Merging 1006 edges ...
Merging 1007 edges ...
Merging 1007 edges ...
Merging 1007 edges ...
Merging 1003 edges ...
Merging 1002 edges ...
Merging 1008 edges ...
Merging 100

In [18]:
# Remove edges if according attribute is empty
if link_title:
    db.delete_edges_for_empty_attr(NodeType.PUBLICATION, EdgeType.SIM_TITLE, 'title')
    
if link_abstract:
    db.delete_edges_for_empty_attr(NodeType.PUBLICATION, EdgeType.SIM_ABSTRACT, 'abstract')
    
if link_venue:
    db.delete_edges_for_empty_attr(NodeType.PUBLICATION, EdgeType.SIM_VENUE, 'venue')

In [19]:
db.close()

2024-12-04 18:32:29,160 - DatabaseWrapper - INFO - Closing the database connection
