In [None]:
import pickle
import torch 
import json
import pandas as pd
from ResearchGraphDataset import * 

In [None]:
# Load the paper table and keep only publications from 2020 onwards.
df = pd.read_json('df.json')
df = df.loc[df['publication_year'] >= 2020]

# Split definition file.
# NOTE(review): pickle.load can execute code embedded in the file; only load
# pickles that come from a trusted source.
with open('splits_New.pkl', 'rb') as splits_file:
    splits = pickle.load(splits_file)

# Best samples obtained earlier.
with open('train_samples_2023_Original.pkl', 'rb') as samples_file:
    best_sample = pickle.load(samples_file)

# Build the dataset from the filtered table, then attach the loaded splits to it.
dataset = ResearchGraphDataset(df, splits, max_authors=2)
dataset.splits = splits

  _torch_pytree._register_pytree_node(
Processing papers: 100%|██████████| 12209/12209 [00:01<00:00, 6290.17it/s]
Processing collaborations: 12209it [00:01, 11373.10it/s]


In [13]:
# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [41]:
# Keys of samples whose feature extraction raised; the extraction loop appends to it.
error_list = list()

In [None]:
def get_neighbors(graph: dgl.DGLGraph, target_pair: tuple) -> tuple:
    """Collect the 2-hop in-neighbourhood node IDs for both nodes of a pair.

    Args:
        graph: Graph to search.
        target_pair: ``(a, b)`` node IDs of ``graph`` (graph indices, not
            author IDs).

    Returns:
        A 2-tuple ``(a_neighbors, b_neighbors)``. Each element is a list of
        the node IDs found in the 2-hop in-subgraph of one node, with the
        other node of the pair removed. The lists are ordered by the length
        of each ID's string form and then by that string.

    NOTE(review): the seed node itself is not removed from its own list;
    confirm that this is intended.
    """
    a, b = target_pair

    def _two_hop_ids(seed, excluded) -> list:
        # Index [0] selects the subgraph from khop_in_subgraph's result;
        # ndata[dgl.NID] gives the IDs of its nodes in the original graph.
        subgraph = dgl.khop_in_subgraph(graph, seed, k=2)[0]
        node_ids = set(subgraph.ndata[dgl.NID].tolist()) - {excluded}
        return sorted(node_ids, key=lambda x: (len(str(x)), str(x)))

    return _two_hop_ids(a, b), _two_hop_ids(b, a)

def generate_focused_graph_desc(graph: dgl.DGLGraph, target_pair: tuple) -> str:
    """Describe how two authors relate to each other in ``graph``.

    Relies on the module-level ``dataset`` for the author-id/index mapping
    and the per-author collaborator metadata.

    Args:
        graph: Graph whose edges carry a ``'year'`` feature.
        target_pair: ``(a_id, b_id)`` author IDs (not graph node indices).

    Returns:
        Two text fragments joined by ``" | "``: the value stored for ``b_id``
        among ``a_id``'s collaborators (0 when absent), and the shortest-path
        hop count from a to b, or an "unreachable" message when the computed
        distance is not positive.

    Raises:
        KeyError: If either author ID is missing from the mapping, or
            ``a_id`` is missing from ``dataset.author_metadata``.
    """
    a_id, b_id = target_pair

    # The mapping is requested for the most recent edge year in the graph.
    current_year = graph.edata['year'].max().item()
    _, author_id_to_idx, _ = dataset.generate_authorid2idx(current_year)
    a_idx, b_idx = author_id_to_idx[a_id], author_id_to_idx[b_id]

    # Collaboration history: look b up among a's collaborators. A missing
    # entry defaults to 0, so a single format covers both cases.
    collab = dataset.author_metadata[a_id]['collaborators'].get(b_id, 0)
    desc = [f"历史合作: 与对方合作{collab}篇"]

    # Distances from a to every node. A value of 0 or less is reported as
    # unreachable (DGL documents -1 for unreachable nodes).
    dist_vec = dgl.shortest_dist(graph, root=a_idx)
    hops = int(dist_vec[b_idx].item())
    if hops > 0:
        desc.append(f"与对方路径距离: {hops}跳")
    else:
        desc.append("与对方路径不可达")
    return " | ".join(desc)


def build_agent_context(graph: dgl.DGLGraph, agent_id: str, current_year: int = None) -> dict:
    """Build the self-description context of one author (agent).

    Relies on the module-level ``dataset`` for the author-id/index mapping
    and the author's raw text.

    Args:
        graph: Graph whose nodes carry the ``'degree_cent'``,
            ``'constraint'``, ``'paper_count'`` and ``'citations'`` features
            and whose edges carry a ``'year'`` feature.
        agent_id: Author ID (not a graph node index).
        current_year: Year used to look up the author-id/index mapping.
            Optional; when omitted, the latest edge year in ``graph`` is used.

    Returns:
        A dict with the keys ``'degree_centrality'``, ``'constraint'``,
        ``'papers_num'`` and ``'citations'`` (scalars read from the node
        features) and ``'text'`` (the author's ``'raw_text'`` from
        ``dataset.get_raw_data``).

    Raises:
        KeyError: If ``agent_id`` is not in the mapping for the chosen year.
    """
    # Callers in this notebook pass only (graph, agent_id); fall back to the
    # graph's latest edge year in that case.
    if current_year is None:
        current_year = graph.edata['year'].max().item()
    _, author_id_to_idx, _ = dataset.generate_authorid2idx(current_year)
    agent_idx = author_id_to_idx[agent_id]

    node_features = graph.ndata
    return {
        'degree_centrality': node_features['degree_cent'][agent_idx].item(),
        'constraint': node_features['constraint'][agent_idx].item(),
        'papers_num': node_features['paper_count'][agent_idx].item(),
        'citations': node_features['citations'][agent_idx].item(),
        'text': dataset.get_raw_data(agent_id)['raw_text'],
    }


In [None]:
# Select the 2023 entry of the dataset's splits and the author-id/index lookups
# generated for 2023.
# NOTE(review): this rebinds `splits`, which the loading cell bound to the full
# pickled split object; after this cell the name refers to the 2023 entry only.
splits = dataset.splits[2023]
active_authors,author_id_to_idx,idx_to_author_id = dataset.generate_authorid2idx(2023)
# Training graph of the 2023 split, passed through the dataset's
# _add_topological_features step (presumably adds the node features read by
# build_agent_context -- TODO confirm).
train_g = dataset._add_topological_features(splits['train'])

In [None]:
# Extract one feature record per unique (src, dst) pair in best_sample.
data = {}
current_year = 2023

active_authors, author_id_to_idx, idx_to_author_id = dataset.generate_authorid2idx(current_year)

# Latest edge year in the training graph, passed explicitly so the call matches
# build_agent_context's three-parameter signature.
graph_year = train_g.edata['year'].max().item()

# Reason for each failed key, so the entries in error_list can be diagnosed.
error_details = {}

for src_nodes, dst_nodes, _ in tqdm(best_sample):
    # .item() yields plain Python numbers, so no device transfer is needed.
    # The label tensor is not used for feature extraction.
    key = (src_nodes.item(), dst_nodes.item())
    if key in data:
        continue

    try:
        a_idx, b_idx = int(key[0]), int(key[1])
        a_id, b_id = idx_to_author_id[a_idx], idx_to_author_id[b_idx]
        a_ctx = build_agent_context(train_g, a_id, graph_year)
        b_ctx = build_agent_context(train_g, b_id, graph_year)
        a_neighbors, b_neighbors = get_neighbors(train_g, key)

        # Key order is kept stable because it determines the column order
        # of the exported JSON.
        info = {
            'dist': generate_focused_graph_desc(train_g, (a_id, b_id)),
            'a_degree_centrality': a_ctx['degree_centrality'],
            'b_degree_centrality': b_ctx['degree_centrality'],
            'a_constraint': a_ctx['constraint'],
            'b_constraint': b_ctx['constraint'],
            'a_papers_num': a_ctx['papers_num'],
            'b_papers_num': b_ctx['papers_num'],
            'a_citations': a_ctx['citations'],
            'b_citations': b_ctx['citations'],
            'a_neighbors': a_neighbors,
            'b_neighbors': b_neighbors,
            'a_text': a_ctx['text'],
            'b_text': b_ctx['text'],
        }
    except Exception as exc:
        # Best effort: skip the pair but record why it failed. Catching
        # Exception (not a bare except) keeps KeyboardInterrupt working
        # during this long-running loop.
        error_list.append(key)
        error_details[key] = repr(exc)
        continue

    data[key] = info

100%|██████████| 2000/2000 [1:14:24<00:00,  2.23s/it]


In [54]:
# Pivot {pair: {feature: value}} into {feature: {str(pair): value}} and write
# the result to data.json. Tuple keys are stringified because JSON object keys
# must be strings.
new_data = {}
for pair_key, features in data.items():
    pair_label = str(pair_key)
    for feature_name, feature_value in features.items():
        if feature_name not in new_data:
            new_data[feature_name] = {}
        new_data[feature_name][pair_label] = feature_value

with open('data.json', "w") as out_file:
    json.dump(new_data, out_file)