In [6]:
import json
import glob
import os
import pickle
from collections import defaultdict
import logging
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Set
import numpy as np
from pathlib import Path

class GNNDataBuilder:
    """
    Builds and saves GNN-ready data structure from academic papers.
    Node features include topic statistics and one-hot encoded hierarchical information.
    """
    
    def __init__(self, 
                 min_topic_score: float = 0.3,
                 data_folder: str = "raw_data",
                 processed_folder: str = "processed_data"):
        self.min_topic_score = min_topic_score
        self.data_folder = data_folder
        self.processed_folder = processed_folder
        self.logger = self._setup_logger()
        
        # Create directories
        os.makedirs(processed_folder, exist_ok=True)
        
        # Mapping dictionaries for encoding
        self.topic_to_idx = {}
        self.domain_to_idx = {}
        self.field_to_idx = {}
        self.subfield_to_idx = {}
        
        # Reverse mappings for decoding
        self.idx_to_topic = {}
        self.idx_to_domain = {}
        self.idx_to_field = {}
        self.idx_to_subfield = {}
        
        self.metadata = {
            'min_topic_score': min_topic_score,
            'version': '1.0',
            'timestamp': None,
            'stats': {
                'num_topics': 0,
                'num_domains': 0,
                'num_fields': 0,
                'num_subfields': 0
            }
        }

    def build_category_indices(self, papers: List[dict]):
        """Build indices for all categories (domains, fields, subfields, topics)"""
        self.logger.info(f"Starting to process {len(papers)} papers")
        
        # Collect unique categories
        domains = set()
        fields = set()
        subfields = set()
        topics = {}  # topic_id -> topic_info
        
        for paper in papers:
            for topic in paper.get('topics', []):
                if topic.get('score', 0) >= self.min_topic_score:
                    # Extract hierarchical information
                    domain_info = topic.get('domain', {})
                    field_info = topic.get('field', {})
                    subfield_info = topic.get('subfield', {})
                    
                    # Store complete URLs as identifiers
                    topic_id = topic['id']  # This is the full URL
                    domain_id = domain_info.get('id', '')
                    field_id = field_info.get('id', '')
                    subfield_id = subfield_info.get('id', '')
                    
                    domains.add((domain_id, domain_info.get('display_name', '')))
                    fields.add((field_id, field_info.get('display_name', '')))
                    subfields.add((subfield_id, subfield_info.get('display_name', '')))
                    
                    topics[topic_id] = {
                        'display_name': topic.get('display_name', ''),
                        'domain': domain_info.get('display_name', ''),
                        'field': field_info.get('display_name', ''),
                        'subfield': subfield_info.get('display_name', '')
                    }
            
        # Create indices using the full URLs
        self.domain_to_idx = {id_: idx for idx, (id_, _) in enumerate(sorted(domains))}
        self.field_to_idx = {id_: idx for idx, (id_, _) in enumerate(sorted(fields))}
        self.subfield_to_idx = {id_: idx for idx, (id_, _) in enumerate(sorted(subfields))}
        self.topic_to_idx = {id_: idx for idx, id_ in enumerate(sorted(topics.keys()))}
        
        # Create reverse mappings
        self.idx_to_domain = {idx: name for (_, name), idx in zip(sorted(domains), range(len(domains)))}
        self.idx_to_field = {idx: name for (_, name), idx in zip(sorted(fields), range(len(fields)))}
        self.idx_to_subfield = {idx: name for (_, name), idx in zip(sorted(subfields), range(len(subfields)))}
        self.idx_to_topic = {idx: {'id': id_, 'info': topics[id_]} 
                            for id_, idx in self.topic_to_idx.items()}
        
        # Update metadata
        self.metadata['stats'].update({
            'num_domains': len(domains),
            'num_fields': len(fields),
            'num_subfields': len(subfields),
            'num_topics': len(topics)
        })
        
        self.logger.info(f"Found {len(domains)} domains, {len(fields)} fields, "
                        f"{len(subfields)} subfields, {len(topics)} topics")
    
    def _setup_logger(self):
        """Setup logging configuration"""
        logger = logging.getLogger(__name__)
        logger.setLevel(logging.INFO)
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        return logger

    
    def build_gnn_data(self, papers: List[dict]) -> Dict:
        """
        Build complete GNN-ready data structure from papers
        
        Args:
            papers (List[dict]): List of papers with topic information
            
        Returns:
            Dict: Complete GNN data structure containing:
                - node_features: [num_topics x num_features] matrix 
                - edge_index: [2 x num_edges] matrix
                - edge_features: [num_edges x 6] matrix
                - metadata and mappings
        """
        # Get dimensions
        num_topics = len(self.topic_to_idx)
        num_domains = len(self.domain_to_idx)
        num_fields = len(self.field_to_idx)
        num_subfields = len(self.subfield_to_idx)
        
        self.logger.info(f"Building GNN data with dimensions: "
                        f"{num_topics} topics, {num_domains} domains, "
                        f"{num_fields} fields, {num_subfields} subfields")
        
        # Initialize node features
        total_features = 3 + num_domains + num_fields + num_subfields
        node_features = np.zeros((num_topics, total_features))
        
        # Track topic statistics
        topic_papers = defaultdict(int)  # topic_idx -> count
        topic_scores = defaultdict(float)  # topic_idx -> accumulated score
        topic_collabs = defaultdict(int)  # topic_idx -> collaboration count
        
        # Track edge information
        edge_dict = {}  # (idx1, idx2) -> {weight, papers, ...}
        
        # Process papers
        paper_count = 0
        valid_edges = 0
        
        for paper in papers:
            paper_count += 1
            if paper_count % 10000 == 0:
                self.logger.info(f"Processed {paper_count} papers...")
            
            # Get valid topics for this paper
            topics = [t for t in paper.get('topics', []) 
                     if t.get('score', 0) >= self.min_topic_score 
                     and t['id'] in self.topic_to_idx]
            
            # Update node statistics
            for topic in topics:
                topic_idx = self.topic_to_idx[topic['id']]
                topic_papers[topic_idx] += 1
                topic_scores[topic_idx] += topic['score']
            
            # Process topic pairs to create edges
            for i in range(len(topics)):
                for j in range(i + 1, len(topics)):
                    topic1 = topics[i]
                    topic2 = topics[j]
                    
                    # Get topic indices
                    idx1 = self.topic_to_idx[topic1['id']]
                    idx2 = self.topic_to_idx[topic2['id']]
                    
                    # Ensure consistent ordering
                    if idx1 > idx2:
                        idx1, idx2 = idx2, idx1
                        topic1, topic2 = topic2, topic1
                    
                    # Calculate edge weight
                    weight = topic1['score'] * topic2['score']
                    
                    # Update edge information
                    edge_key = (idx1, idx2)
                    if edge_key not in edge_dict:
                        edge_dict[edge_key] = {
                            'weight': weight,
                            'papers': {paper['id']},
                            'topic1': topic1,
                            'topic2': topic2
                        }
                        valid_edges += 1
                    else:
                        edge_dict[edge_key]['weight'] += weight
                        edge_dict[edge_key]['papers'].add(paper['id'])
                    
                    # Update collaboration counts
                    topic_collabs[idx1] += 1
                    topic_collabs[idx2] += 1
        
        self.logger.info(f"Found {valid_edges} unique topic connections")
        
        # Build node features
        self.logger.info("Building node features...")
        for topic_idx in range(num_topics):
            topic_info = self.idx_to_topic[topic_idx]['info']
            papers_count = topic_papers[topic_idx]
            
            # Basic features [papers, avg_score, total_collabs]
            node_features[topic_idx, 0] = papers_count
            node_features[topic_idx, 1] = (topic_scores[topic_idx] / papers_count 
                                         if papers_count > 0 else 0)
            node_features[topic_idx, 2] = topic_collabs[topic_idx]
            
            # One-hot encodings for hierarchical information
            domain_name = topic_info['domain']
            field_name = topic_info['field']
            subfield_name = topic_info['subfield']
            
            # Add domain one-hot
            for domain_id, idx in self.domain_to_idx.items():
                if domain_name in domain_id:
                    node_features[topic_idx, 3 + idx] = 1
                    break
            
            # Add field one-hot
            for field_id, idx in self.field_to_idx.items():
                if field_name in field_id:
                    node_features[topic_idx, 3 + num_domains + idx] = 1
                    break
            
            # Add subfield one-hot
            for subfield_id, idx in self.subfield_to_idx.items():
                if subfield_name in subfield_id:
                    node_features[topic_idx, 3 + num_domains + num_fields + idx] = 1
                    break
        
        # Build edge index and features
        self.logger.info("Building edge index and features...")
        edge_index = []
        edge_features = []
        
        for (idx1, idx2), edge_data in edge_dict.items():
            edge_index.append([idx1, idx2])
            
            # Get topic info
            topic1_info = self.idx_to_topic[idx1]['info']
            topic2_info = self.idx_to_topic[idx2]['info']
            
            # Calculate edge features
            weight = edge_data['weight']
            joint_papers = len(edge_data['papers'])
            
            # Compare hierarchical relationships
            same_domain = topic1_info['domain'] == topic2_info['domain']
            same_field = topic1_info['field'] == topic2_info['field']
            same_subfield = topic1_info['subfield'] == topic2_info['subfield']
            
            # Normalize weight by number of joint papers
            norm_weight = weight / joint_papers if joint_papers > 0 else 0
            
            edge_features.append([
                weight,               # Raw cooperation weight
                norm_weight,         # Weight normalized by joint papers
                joint_papers,        # Number of papers with both topics
                float(same_domain),  # Same domain indicator
                float(same_field),   # Same field indicator
                float(same_subfield) # Same subfield indicator
            ])
        
        # Convert to numpy arrays
        edge_index = np.array(edge_index).T  # Convert to [2 x num_edges] format
        edge_features = np.array(edge_features)
        
        self.logger.info(f"Created graph with {len(node_features)} nodes and {len(edge_index[0])} edges")
        
        # Return complete graph structure
        return {
                'node_features': node_features,
                'edge_index': edge_index,
                'edge_features': edge_features,
                'num_nodes': num_topics,
                'feature_info': {
                    'node_feature_dims': total_features,
                    'edge_feature_dims': edge_features.shape[1] if len(edge_features) > 0 else 0,
                    'node_feature_names': [
                        'papers', 'avg_score', 'total_collabs',
                        *[f'domain_{i}' for i in range(num_domains)],
                        *[f'field_{i}' for i in range(num_fields)],
                        *[f'subfield_{i}' for i in range(num_subfields)]
                    ],
                    'edge_feature_names': [
                        'weight', 
                        'normalized_weight', 
                        'joint_papers',
                        'same_domain', 
                        'same_field', 
                        'same_subfield'
                    ]
                },
                'mappings': {
                    'topic_to_idx': self.topic_to_idx,
                    'domain_to_idx': self.domain_to_idx,
                    'field_to_idx': self.field_to_idx,
                    'subfield_to_idx': self.subfield_to_idx,
                    'idx_to_topic': self.idx_to_topic,
                    'idx_to_domain': self.idx_to_domain,
                    'idx_to_field': self.idx_to_field,
                    'idx_to_subfield': self.idx_to_subfield
                }
            }

    def save_gnn_data(self, gnn_data: Dict, filename: str = 'gnn_data') -> None:
        """Save GNN data structure to disk and print summary"""
        self.metadata['timestamp'] = datetime.now().isoformat()
        
        save_data = {
            'gnn_data': gnn_data,
            'metadata': self.metadata
        }
        
        # Save as both pickle and npz
        pickle_path = Path(self.processed_folder) / f'{filename}.pkl'
        npz_path = Path(self.processed_folder) / f'{filename}.npz'
        
        # Save complete data as pickle
        with open(pickle_path, 'wb') as f:
            pickle.dump(save_data, f)
        
        # Save numerical data as npz
        np.savez(
            npz_path,
            node_features=gnn_data['node_features'],
            edge_index=gnn_data['edge_index'],
            edge_features=gnn_data['edge_features']
        )
        
        self.logger.info(f"Saved GNN data to {pickle_path} and {npz_path}")
        
        # Print correct summary
        print("\nGNN Data Summary:")
        print(f"Number of nodes (topics): {gnn_data['num_nodes']}")
        print(f"Number of edges: {len(gnn_data['edge_index'][0])}")  # Use edge_index length
        print(f"Node feature dimensions: {gnn_data['feature_info']['node_feature_dims']}")
        print(f"Edge feature dimensions: {gnn_data['feature_info']['edge_feature_dims']}")
        
        # Print additional statistics
        print("\nFeature Breakdowns:")
        print("Node features:")
        for i, name in enumerate(gnn_data['feature_info']['node_feature_names'][:3]):
            print(f"- {name}: mean = {np.mean(gnn_data['node_features'][:, i]):.2f}")
        
        print("\nEdge features:")
        for i, name in enumerate(gnn_data['feature_info']['edge_feature_names']):
            print(f"- {name}: mean = {np.mean(gnn_data['edge_features'][:, i]):.2f}")

    def load_gnn_data(self, filename: str = 'gnn_data') -> Optional[Dict]:
        """Load previously saved GNN data"""
        pickle_path = Path(self.processed_folder) / f'{filename}.pkl'
        
        if pickle_path.exists():
            try:
                with open(pickle_path, 'rb') as f:
                    data = pickle.load(f)
                
                if data['metadata']['min_topic_score'] == self.min_topic_score:
                    self.logger.info(f"Loaded GNN data from {pickle_path}")
                    return data['gnn_data']
                
            except Exception as e:
                self.logger.error(f"Error loading GNN data: {str(e)}")
        
        return None

    
def load_papers(data_folder: str) -> List[dict]:
    """Load papers from JSON files"""
    pattern = os.path.join(data_folder, "highly_cited_articles_2024*.json")
    files = sorted(glob.glob(pattern))
    
    if not files:
        raise FileNotFoundError(f"No JSON files found in {data_folder}")
    
    all_papers = []
    for file_path in files:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                # Check if this is a single paper or list of papers
                data = json.load(f)
                if isinstance(data, dict):
                    all_papers.append(data)
                elif isinstance(data, list):
                    all_papers.extend(data)
                else:
                    logging.warning(f"Unexpected data format in {file_path}")
                    
                logging.info(f"Loaded {len(all_papers)} papers from {file_path}")
        except json.JSONDecodeError as e:
            logging.error(f"Error decoding JSON from {file_path}: {str(e)}")
            continue
        except Exception as e:
            logging.error(f"Error reading {file_path}: {str(e)}")
            continue
    
    logging.info(f"Total papers loaded: {len(all_papers)}")
    return all_papers
    
# Example usage
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    
    try:
        # Initialize builder
        gnn_builder = GNNDataBuilder(min_topic_score=0.3)
        
        # Load papers
        papers = load_papers("openalex_topics_results")
        
        # Build indices
        gnn_builder.build_category_indices(papers)
        
        # Build GNN data
        gnn_data = gnn_builder.build_gnn_data(papers)
        
        # Save data
        gnn_builder.save_gnn_data(gnn_data)
        
        # Print summary
        print("\nGNN Data Summary:")
        print(f"Number of nodes (topics): {gnn_data['num_nodes']}")
        print(f"Number of edges: {len(gnn_data['edge_index'])}")
        print(f"Node feature dimensions: {gnn_data['feature_info']['node_feature_dims']}")
        print(f"Edge feature dimensions: {gnn_data['feature_info']['edge_feature_dims']}")


        
    except Exception as e:
        logging.error(f"Error in processing: {str(e)}")
        raise


Graph Structure Analysis:
Total nodes: 3402
Total edges: 30143

Top 5 Most Connected Topics:
- Gut microbiota and health
  Connections: 158
  Field: Biochemistry, Genetics and Molecular Biology
  Subfield: Molecular Biology
- Artificial Intelligence in Healthcare and Education
  Connections: 152
  Field: Medicine
  Subfield: Health Informatics
- Machine Learning in Materials Science
  Connections: 141
  Field: Materials Science
  Subfield: Materials Chemistry
- Epigenetics and DNA Methylation
  Connections: 129
  Field: Biochemistry, Genetics and Molecular Biology
  Subfield: Molecular Biology
- Single-cell and spatial transcriptomics
  Connections: 121
  Field: Biochemistry, Genetics and Molecular Biology
  Subfield: Molecular Biology

Top 5 Strongest Topic Collaborations:
- Advancements in Battery Materials <-> Advanced Battery Materials and Technologies
  Weight: 3999.097
  Joint Papers: 239.0
  Fields: Engineering <-> Engineering
- Advancements in Battery Materials <-> Advanced Ba

INFO:root:Loaded 1000 papers from openalex_topics_results/highly_cited_articles_2024_20250221_222754.json
INFO:root:Loaded 3000 papers from openalex_topics_results/highly_cited_articles_2024_20250221_222818.json
INFO:root:Loaded 6000 papers from openalex_topics_results/highly_cited_articles_2024_20250221_222842.json
INFO:root:Loaded 10000 papers from openalex_topics_results/highly_cited_articles_2024_20250221_222907.json
INFO:root:Loaded 15000 papers from openalex_topics_results/highly_cited_articles_2024_20250221_222933.json
INFO:root:Loaded 21000 papers from openalex_topics_results/highly_cited_articles_2024_20250221_223000.json
INFO:root:Loaded 28000 papers from openalex_topics_results/highly_cited_articles_2024_20250221_223026.json
INFO:root:Loaded 36000 papers from openalex_topics_results/highly_cited_articles_2024_20250221_223052.json
INFO:root:Loaded 45000 papers from openalex_topics_results/highly_cited_articles_2024_20250221_223119.json
INFO:root:Loaded 55000 papers from opena


GNN Data Summary:
Number of nodes (topics): 3402
Number of edges: 30143
Node feature dimensions: 279
Edge feature dimensions: 6

Feature Breakdowns:
Node features:
- papers: mean = 346.31
- avg_score: mean = 0.97
- total_collabs: mean = 679.11

Edge features:
- weight: mean = 37.47
- normalized_weight: mean = 13.51
- joint_papers: mean = 2.63
- same_domain: mean = 0.68
- same_field: mean = 0.42
- same_subfield: mean = 0.14

GNN Data Summary:
Number of nodes (topics): 3402
Number of edges: 2
Node feature dimensions: 279
Edge feature dimensions: 6


NameError: name 'analyze_graph_structure' is not defined

In [7]:
def analyze_graph_structure(gnn_data: Dict):
        """
        Analyze the graph structure to verify edges and show key insights
        """
        node_features = gnn_data['node_features']
        edge_index = gnn_data['edge_index']
        edge_features = gnn_data['edge_features']
        mappings = gnn_data['mappings']
        
        print("\nGraph Structure Analysis:")
        print(f"Total nodes: {gnn_data['num_nodes']}")
        print(f"Total edges: {len(edge_index[0])}")
        
        # 1. Analyze node connectivity
        node_connections = defaultdict(int)
        for source, target in edge_index.T:  # Transpose to get [source, target] pairs
            node_connections[source] += 1
            node_connections[target] += 1
        
        # Find most connected topics
        most_connected = sorted(node_connections.items(), key=lambda x: x[1], reverse=True)[:5]
        print("\nTop 5 Most Connected Topics:")
        for node_idx, connection_count in most_connected:
            topic_info = mappings['idx_to_topic'][node_idx]['info']
            print(f"- {topic_info['display_name']}")
            print(f"  Connections: {connection_count}")
            print(f"  Field: {topic_info['field']}")
            print(f"  Subfield: {topic_info['subfield']}")
        
        # 2. Analyze strongest collaborations
        edge_strengths = []
        for i, (source, target) in enumerate(edge_index.T):
            source_info = mappings['idx_to_topic'][source]['info']
            target_info = mappings['idx_to_topic'][target]['info']
            
            edge_strengths.append({
                'source': source_info['display_name'],
                'target': target_info['display_name'],
                'weight': edge_features[i][0],  # Raw weight
                'joint_papers': edge_features[i][2],  # Number of joint papers
                'same_field': bool(edge_features[i][4]),  # Same field indicator
                'fields': (source_info['field'], target_info['field'])
            })
        
        # Sort by weight
        strongest_edges = sorted(edge_strengths, key=lambda x: x['weight'], reverse=True)[:5]
        print("\nTop 5 Strongest Topic Collaborations:")
        for edge in strongest_edges:
            print(f"- {edge['source']} <-> {edge['target']}")
            print(f"  Weight: {edge['weight']:.3f}")
            print(f"  Joint Papers: {edge['joint_papers']}")
            print(f"  Fields: {edge['fields'][0]} <-> {edge['fields'][1]}")
        
        # 3. Analyze field-level connections
        field_connections = defaultdict(lambda: defaultdict(int))
        for i, (source, target) in enumerate(edge_index.T):
            source_field = mappings['idx_to_topic'][source]['info']['field']
            target_field = mappings['idx_to_topic'][target]['info']['field']
            if source_field != target_field:
                if source_field > target_field:
                    source_field, target_field = target_field, source_field
                field_connections[source_field][target_field] += 1
        
        print("\nTop 5 Field-Level Collaborations:")
        field_pairs = [(f1, f2, count) 
                      for f1, connections in field_connections.items() 
                      for f2, count in connections.items()]
        top_field_pairs = sorted(field_pairs, key=lambda x: x[2], reverse=True)[:5]
        
        for field1, field2, count in top_field_pairs:
            print(f"- {field1} <-> {field2}: {count} connections")
    
# Usage
with open('processed_data/gnn_data.pkl', 'rb') as f:
    saved_data = pickle.load(f)
    gnn_data = saved_data['gnn_data']

analyze_graph_structure(gnn_data)



Graph Structure Analysis:
Total nodes: 3402
Total edges: 30143

Top 5 Most Connected Topics:
- Gut microbiota and health
  Connections: 158
  Field: Biochemistry, Genetics and Molecular Biology
  Subfield: Molecular Biology
- Artificial Intelligence in Healthcare and Education
  Connections: 152
  Field: Medicine
  Subfield: Health Informatics
- Machine Learning in Materials Science
  Connections: 141
  Field: Materials Science
  Subfield: Materials Chemistry
- Epigenetics and DNA Methylation
  Connections: 129
  Field: Biochemistry, Genetics and Molecular Biology
  Subfield: Molecular Biology
- Single-cell and spatial transcriptomics
  Connections: 121
  Field: Biochemistry, Genetics and Molecular Biology
  Subfield: Molecular Biology

Top 5 Strongest Topic Collaborations:
- Advancements in Battery Materials <-> Advanced Battery Materials and Technologies
  Weight: 3999.097
  Joint Papers: 239.0
  Fields: Engineering <-> Engineering
- Advancements in Battery Materials <-> Advanced Ba