In [32]:
import pickle
import numpy as np
import pandas as pd
import networkx as nx
from tqdm import tqdm
from glob import glob
import torch

In [33]:
def load_data():
    """
    Reads the pickled subgraph ids and the pickled event index from disk.

    Every entry of the subgraph id list is rewritten in place as
    ``(entry[0], list(entry[1]))`` and the event index is inverted with
    ``reverse_event_index`` before being returned.

    :return: tuple ``(subgraph_ids, event_index)``
    """
    khops_path = '../../data/batch/B_recent_10_merged_khops.pkl'
    index_path = '../../data/preprocessed/event_index.pkl'

    with open(khops_path, 'rb') as handle:
        subgraph_ids = pickle.load(handle)

    # turn the second element of every entry into a plain list (in place)
    for pos, entry in enumerate(subgraph_ids):
        subgraph_ids[pos] = (entry[0], list(entry[1]))

    with open(index_path, 'rb') as handle:
        file_to_events = pickle.load(handle)

    return subgraph_ids, reverse_event_index(file_to_events)

In [34]:
def reverse_event_index(event_index):
    """
    Inverts a ``file -> event ids`` mapping into an ``event id -> file`` mapping.

    :param event_index: dict mapping a file name to an iterable of event ids
    :return: dict mapping each event id to its file name; when an id is listed
        under several files, the file iterated last wins
    """
    return {
        event_id: file_name
        for file_name, event_ids in event_index.items()
        for event_id in event_ids
    }

In [37]:
def save_batch(graphs, start_index):
    """
    Pickles every graph of a batch into its own file.

    Files are written to ``../../data/graphs/batches/`` and named
    ``batch_<index>.pkl`` where ``<index>`` is ``start_index`` plus the graph's
    position in ``graphs``, zero-padded to five characters.

    :param graphs: sequence of objects to pickle
    :param start_index: index used for the first graph of the batch
    :return: None
    """
    out_dir = '../../data/graphs/batches/'
    for offset, graph in enumerate(graphs):
        padded = str(start_index + offset).zfill(5)
        with open(out_dir + f'batch_{padded}.pkl', 'wb') as handle:
            pickle.dump(graph, handle)

def batch_generate(subgraph_ids, event_index, batch_size):
    """
    Generates subgraphs for training and saves them to disk batch by batch
    :param subgraph_ids: list of tuples (target_ids, neighbor_ids)
    :param event_index: maps event ids -> file names
    :param batch_size: number of graphs to construct before they are saved; must be >= 1
    :return: None (graphs are written to disk by save_batch)
    :raises ValueError: if batch_size is smaller than 1
    """
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer')

    # concept embeddings are shared by all subgraphs, so load them once up front
    concept_llms = load_concept_llm_files()

    # walk over subgraph_ids in slices of batch_size; `start` is also the
    # running index of the first graph in the batch, used for the file names
    for start in range(0, len(subgraph_ids), batch_size):
        batch = subgraph_ids[start:start + batch_size]

        graphs = []
        for target_ids, neighbor_ids in tqdm(batch, desc='Generating batches'):
            # NOTE: there used to be a `return graphs` here, which ended the whole
            # function after the first subgraph so nothing was ever saved
            graphs.append(generate_subgraph(target_ids, neighbor_ids, event_index, concept_llms))

        # save batch
        save_batch(graphs, start)

In [36]:
def get_files_to_idx(t_ids, n_ids, event_index):
    """
    Groups event ids by the file that contains them
    :param t_ids: list of target event ids
    :param n_ids: list of neighbor event ids
    :param event_index: maps event ids -> file names
    :return: dict mapping file name -> set of event ids (targets and neighbors)
        stored in that file; ids missing from event_index are left out
    """
    file_to_idx = {}
    for ids in (t_ids, n_ids):
        for eid in ids:
            # skip events not in the dataset
            if eid not in event_index:
                continue

            file_to_idx.setdefault(event_index[eid], set()).add(eid)

    # the mapping was previously built but never returned, so callers got None
    return file_to_idx
            
def load_concept_llm_files():
    """
    Loads every pickled concept-embedding file into memory.

    NOTE(review): the result is keyed by file path, while add_event indexes
    ``concept_llms`` by concept id -- confirm whether the per-file contents
    are meant to be merged into one mapping.

    :return: dict mapping each matched file path -> unpickled content of that file
    """
    def _read_pickle(path):
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    pattern = '../../data/text/concept_embeds/concept_embeds_*.pkl'
    return {path: _read_pickle(path) for path in glob(pattern)}
            
def load_files(files_to_idx):
    """
    Reads the source pickle and the LLM-embedding pickle of every listed file.

    Only the keys of ``files_to_idx`` are used; for each key ``name`` the files
    ``../../data/preprocessed/<name>.pkl`` and ``../../data/text/embedded/<name>.pkl``
    are unpickled, in that order.

    :param files_to_idx: maps file names -> event ids
    :return: two dictionaries keyed by file name, the first holding the source
        file contents and the second the LLM file contents
    """
    def _read_pickle(path):
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    sources = {}
    embeddings = {}
    for name in files_to_idx.keys():
        sources[name] = _read_pickle(f'../../data/preprocessed/{name}.pkl')
        embeddings[name] = _read_pickle(f'../../data/text/embedded/{name}.pkl')

    return sources, embeddings

def add_event(graph, event_id, e_type, all_nodes, src_file, llm_file, concept_llms):
    """
    Inserts one event node, its similar-event edges and its concept nodes into the graph
    :param graph: graph object that is mutated through add_node / add_edge
    :param event_id: event id to add
    :param e_type: type of event ('event' or 'event_target')
    :param all_nodes: collection of event ids that may be linked by 'similar' edges
    :param src_file: source file, indexed by event id
    :param llm_file: LLM file, indexed by event id (entries carry 'title' and 'summary')
    :param concept_llms: LLM embeddings for concepts, indexed by concept id
    :return: None
    """

    record = src_file[event_id]
    meta = record['info']
    article_total = meta['articleCounts']['total']
    date_value = meta['eventDate']
    event_concepts = meta['concepts']
    similar_events = record['similarEvents']

    embeddings = llm_file[event_id]
    title_vec = embeddings['title']
    summary_vec = embeddings['summary']

    # base features: [event date, title embedding..., summary embedding...]
    feature_array = np.concatenate([[date_value], title_vec, summary_vec])
    if e_type == 'event':
        # only plain events get the article count prepended; every other type
        # (e.g. 'event_target') keeps the count solely as node_target
        feature_array = np.concatenate([[article_total], feature_array])

    # event node: features plus the total article count as target
    graph.add_node(
        event_id,
        node_type=e_type,
        node_feature=torch.from_numpy(feature_array),
        node_target=torch.tensor([article_total]),
    )

    # 'similar' edges, restricted to events contained in all_nodes
    for entry in similar_events:
        other_id = entry['uri']
        if other_id in all_nodes:
            graph.add_edge(event_id, other_id, edge_type='similar')

    # concept nodes with an edge to the event and a self loop
    for entry in event_concepts:
        cid = entry['id']
        embedding = torch.from_numpy(concept_llms[cid])

        graph.add_node(cid, node_type='concept', node_feature=embedding)
        graph.add_edge(cid, event_id, edge_type='related')
        graph.add_edge(cid, cid, edge_type='concept_self')
    

def generate_subgraph(target_ids, neighbor_ids, event_index, concept_llms):
    """
    Generates a subgraph for training
    :param target_ids: iterable of target ids
    :param neighbor_ids: iterable of neighbor ids
    :param event_index: maps event ids -> file names
    :param concept_llms: LLM embeddings for concepts
    :return: nx.DiGraph holding the events found in event_index and their concepts
    """
    # get files to idx
    files_to_idx = get_files_to_idx(target_ids, neighbor_ids, event_index)
    src_files, llm_files = load_files(files_to_idx)

    # set lookup instead of scanning the target list for every event
    target_set = set(target_ids)

    # Only ids that were actually found in event_index get a node. Restricting
    # the 'similar' edge candidates to them keeps add_edge from implicitly
    # creating attribute-less nodes for events that were skipped.
    all_nodes = set()
    for ids in files_to_idx.values():
        all_nodes |= ids

    graph = nx.DiGraph()
    for file_name, ids in files_to_idx.items():
        src_file = src_files[file_name]
        llm_file = llm_files[file_name]

        for eid in ids:
            event_type = 'event_target' if eid in target_set else 'event'
            add_event(graph, eid, event_type, all_nodes, src_file, llm_file, concept_llms)

    # the graph was previously built but never returned, so callers received None
    return graph
    

In [38]:
def main():
    """
    Entry point: loads the subgraph ids and the event index, then generates
    the subgraphs in batches of 10.
    """
    loaded_ids, file_lookup = load_data()
    batch_generate(subgraph_ids=loaded_ids, event_index=file_lookup, batch_size=10)

In [None]:
# Run the pipeline: main() loads the data and calls batch_generate with a batch size of 10.
main()