In [None]:
!pip install transformers accelerate sentencepiece --quiet

In [None]:
!pip install -U bitsandbytes --quiet

In [None]:
!pip install torch torch_geometric requests tqdm

In [None]:
import requests
import torch
from torch_geometric.data import Data
import json
import os
from tqdm.auto import tqdm
import time
import traceback

# Select the compute device once; the rest of the notebook reads `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    # Total memory of GPU 0, converted from bytes to GiB.
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"GPU Memory: {gpu_mem_gb:.2f} GB")
    if gpu_mem_gb < 15:
        print("Warning: GPU memory might be insufficient for Mistral-7B 4-bit.")

# Config
# Local JSON caches of the raw OpenAlex download: the paper records and the
# per-paper lists of referenced works.
papers_file = 'openalex_papers_raw.json'
citations_file = 'openalex_citations_raw.json'
# Path the final graph object is saved to with torch.save.
final_graph_file = 'openalex_subset_mistral_base_graph.pt'
# Hugging Face model id of the LLM used to embed abstracts.
llm_name = "mistralai/Mistral-7B-v0.1"
# Number of abstracts tokenized and embedded per forward pass.
embedding_batch_size = 8
# Tokenizer truncation limit (in tokens) applied to every abstract.
max_abstract_length = 512

# Load Data
# Start from empty containers; they are only filled when the cache is usable.
papers, citations = [], {}
data_loaded_from_file = False

# Reuse the cached raw dumps only when BOTH files are present
if all(os.path.exists(path) for path in (papers_file, citations_file)):
    print(f"Loading raw data from {papers_file} and {citations_file}...")
    try:
        with open(papers_file, 'r') as papers_fh:
            papers = json.load(papers_fh)
        with open(citations_file, 'r') as citations_fh:
            citations = json.load(citations_fh)
        print(f"Loaded {len(papers)} papers and citation data for {len(citations)} papers from files.")
        data_loaded_from_file = True
    except Exception as e:
        # Any failure (I/O, malformed JSON, unexpected shape) discards partial
        # results so the API fallback starts from a clean state.
        print(f"Error loading files: {e}. Will attempt to fetch from API instead.")
        papers, citations = [], {}

# Else fetch from API
# Runs only when the cached files were missing or unreadable. Pages through the
# OpenAlex /works endpoint with cursor pagination, rebuilds each abstract from
# its inverted index, and caches the result to disk for later runs.
if not data_loaded_from_file:
    print("\nRaw data files not found or failed to load. Fetching from OpenAlex API...")
    filter_params = {
        "filter": "primary_location.source.id:S4306400194,concepts.id:C41008148,from_publication_date:2020-01-01",
        "per-page": 200,
        "cursor": "*",     # For pagination
        "mailto": "rzg3@rice.edu"
    }

    page_count = 0
    max_pages = 100 # for 20,000 samples

    while page_count < max_pages:
        print(f"Requesting page {page_count + 1}...")
        try:
            # A timeout keeps a stalled connection from hanging the notebook forever;
            # a timeout surfaces as a RequestException and is handled below.
            response = requests.get("https://api.openalex.org/works", params=filter_params, timeout=30)
            response.raise_for_status()
            data = response.json()

            if not data.get('results'):
                print("No more results found.")
                break

            current_page_papers = 0
            for work in data['results']:
                paper_id = work.get('id')
                if not paper_id:
                    continue

                # OpenAlex ships abstracts as {word: [positions]}; invert that back
                # into running text. Positions with no word become empty strings.
                abstract_text = ""
                abstract_inverted_index = work.get('abstract_inverted_index')
                if abstract_inverted_index and isinstance(abstract_inverted_index, dict):
                    max_index = 0
                    for positions in abstract_inverted_index.values():
                        if positions:
                            max_index = max(max_index, max(positions))
                    abstract_list = [None] * (max_index + 1)
                    for word, positions in abstract_inverted_index.items():
                        for pos in positions:
                            if 0 <= pos < len(abstract_list):
                                abstract_list[pos] = word
                    abstract_text = ' '.join(word if word is not None else '' for word in abstract_list)

                papers.append({'id': paper_id, 'abstract': abstract_text})
                # `or []` guards against an explicit null in the payload.
                citations[paper_id] = work.get('referenced_works', []) or []
                current_page_papers += 1

            print(f"Retrieved {current_page_papers} papers from page {page_count + 1}. Total papers so far: {len(papers)}")
            page_count += 1

            # Check for next cursor
            next_cursor = data.get('meta', {}).get('next_cursor')
            if next_cursor:
                filter_params['cursor'] = next_cursor
                time.sleep(0.1)  # small pause between pages to be polite to the API
            else:
                print("No next cursor found, stopping API calls.")
                break

        except requests.exceptions.HTTPError as e:
            print(f"HTTP Error fetching page {page_count + 1}: {e}")
            print(f"Response content: {response.text}")
            break # Stop fetching on error
        except json.JSONDecodeError:
            # Must precede RequestException: recent `requests` versions raise a
            # JSON error that is ALSO a RequestException, which would otherwise be
            # misreported as a network error.
            print(f"Failed to decode JSON response from page {page_count + 1}.")
            break
        except requests.exceptions.RequestException as e:
            print(f"Network Error fetching page {page_count + 1}: {e}")
            break # Stop fetching on error
        except Exception as e:
            print(f"An unexpected error occurred during API fetch on page {page_count + 1}: {e}")
            traceback.print_exc()
            break

    print(f"\nFinished API retrieval. Total papers retrieved: {len(papers)}")

    if papers: # Only save if we actually got some papers
        print(f"Saving raw data to {papers_file} and {citations_file}...")
        try:
            with open(papers_file, 'w') as f:
                json.dump(papers, f, indent=2)
            with open(citations_file, 'w') as f:
                json.dump(citations, f, indent=2)
            print("Raw data saved successfully.")
        except Exception as e:
            # Caching is best-effort: the in-memory data is still usable.
            print(f"Error saving raw data: {e}")
    else:
        print("No papers retrieved from API, nothing to save.")


# Creating graph
# Builds the citation graph restricted to the loaded papers, embeds every
# abstract with the LLM (mean-pooled hidden states), and saves the result as a
# torch_geometric Data object.
if not papers:
    print("\nNo papers available to process (either not found in files or failed API fetch). Exiting.")
else:
    print(f"\nProceeding with {len(papers)} papers.")
    # Node index = position of a paper's first occurrence in `papers`.
    # dict.fromkeys de-duplicates while preserving insertion order, so the
    # numbering is the same on every run. Enumerating a set of string ids is
    # hash-seed dependent and would renumber the nodes between runs.
    id_to_index = {
        paper_id: i
        for i, paper_id in enumerate(dict.fromkeys(p['id'] for p in papers))
    }
    valid_paper_ids = set(id_to_index)
    num_nodes = len(id_to_index)
    print(f"Created mapping for {num_nodes} unique papers.")

    # Keep only citations whose source AND target are both in the subset.
    edge_list = []
    citation_keys_in_subset = 0
    edges_found = 0
    for paper_id, cited_ids in citations.items():
        if paper_id in id_to_index:
            citation_keys_in_subset += 1
            if not isinstance(cited_ids, list):
                continue
            for cited_id in cited_ids:
                if cited_id in id_to_index:
                    # Edge direction: citing paper -> cited paper.
                    edge_list.append([id_to_index[paper_id], id_to_index[cited_id]])
                    edges_found += 1
    print(f"Checked {citation_keys_in_subset} papers from citations dict that are in the subset.")
    print(f"Number of edges (citations) within the subset: {edges_found}")

    print("Preparing abstracts...")
    # ordered_abstracts[i] is the abstract of node i; if an id occurs more than
    # once in `papers`, the last occurrence wins.
    ordered_abstracts = [""] * num_nodes
    for paper in papers:
        idx = id_to_index.get(paper['id'])
        if idx is not None:
            ordered_abstracts[idx] = paper.get('abstract', '')

    print(f"\n--- Starting Embedding Generation with {llm_name} ---")

    if device == torch.device("cpu"):
        print("FATAL WARNING: Running Mistral-7B on CPU will be extremely slow or impossible.")

    # Both stay None when embedding fails; the save step below checks for that.
    embeddings = None
    llm_hidden_dim = None

    try:
        # Imported lazily so a missing optional dependency is reported cleanly.
        from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

        print(f"Loading LLM Config: {llm_name}")
        llm_cfg = AutoConfig.from_pretrained(llm_name, trust_remote_code=True)
        llm_hidden_dim = llm_cfg.hidden_size

        print("Setting up 4-bit quantization...")
        bnb_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16
        )

        print(f"Loading LLM Model: {llm_name} (this may take time and RAM)...")
        llm = AutoModelForCausalLM.from_pretrained(
            llm_name,
            quantization_config=bnb_cfg,
            trust_remote_code=True,
        ).to(device).eval()
        print("LLM Model loaded.")

        print(f"Loading Tokenizer: {llm_name}")
        tok = AutoTokenizer.from_pretrained(llm_name, trust_remote_code=True)
        if tok.pad_token is None:
            print("Setting pad_token to eos_token")
            tok.pad_token = tok.eos_token
        tok.padding_side = "left"

        print(f"LLM Hidden Dimension: {llm_hidden_dim}")
        print(f"Using batch size: {embedding_batch_size}")
        print(f"Max abstract length (tokens): {max_abstract_length}")

        all_embeddings = []
        print("Generating embeddings in batches...")

        for i in tqdm(range(0, len(ordered_abstracts), embedding_batch_size), desc="Embedding Batches"):
            batch_abstracts = ordered_abstracts[i:i+embedding_batch_size]
            # Empty/missing abstracts are replaced by the pad token so the
            # tokenizer always receives a non-empty string.
            processed_batch = [text if text else tok.pad_token for text in batch_abstracts]

            inputs = tok(
                processed_batch, return_tensors="pt", padding=True,
                truncation=True, max_length=max_abstract_length
            ).to(device)

            with torch.no_grad():
                outputs = llm(**inputs, output_hidden_states=True)
                # NOTE(review): index -2 selects the second-to-last entry of
                # hidden_states, not the last one, despite the variable name.
                # Confirm this is the intended layer.
                last_hidden_states = outputs.hidden_states[-2]

            # Mean-pool over real tokens only: padding positions are zeroed by
            # the attention mask and excluded from the denominator.
            attention_mask = inputs['attention_mask']
            mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
            sum_hidden_states = torch.sum(last_hidden_states * mask_expanded, dim=1)
            sum_mask = torch.clamp(torch.sum(mask_expanded, dim=1), min=1e-9)
            mean_pooled_embeddings = sum_hidden_states / sum_mask
            # Move each batch to CPU right away so GPU memory stays flat.
            all_embeddings.append(mean_pooled_embeddings.cpu())

        embeddings = torch.cat(all_embeddings, dim=0)
        print(f"\nEmbeddings generated successfully with final shape: {embeddings.shape}")

        print("Cleaning up LLM and tokenizer...")
        del llm, tok, llm_cfg, bnb_cfg, inputs, outputs, last_hidden_states, all_embeddings
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("Cleanup complete.")

    except ImportError as e:
        print(f"Import Error: {e}. Make sure 'transformers', 'accelerate', 'bitsandbytes' are installed.")
    except Exception as e:
        print(f"An error occurred during LLM embedding generation: {e}")
        traceback.print_exc()

    # After embeddings create dataset
    if embeddings is not None and llm_hidden_dim is not None and embeddings.shape[0] == num_nodes:
        if edge_list:
            # [E, 2] list of pairs -> [2, E] tensor, the layout Data expects.
            edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
        else:
            print("Warning: No edges found within the subset. Creating empty edge_index.")
            edge_index = torch.empty((2, 0), dtype=torch.long)

        data = Data(x=embeddings, edge_index=edge_index)

        print("\nGraph data prepared:")
        print(data)
        print(f"Number of nodes: {data.num_nodes}")
        print(f"Number of edges: {data.num_edges}")
        print(f"Feature dimension: {data.num_node_features}")

        torch.save(data, final_graph_file)
        print(f"Final graph data saved to '{final_graph_file}'.")
        print("\nNext steps: Use the 'data' object or load from ")
        print(f"'{final_graph_file}' to train a GNN for link prediction.")
        print(f"IMPORTANT: Update the GNN 'in_channels' to {llm_hidden_dim} in the training script.")
    else:
        print("\nFailed to generate embeddings or llm_hidden_dim not found. Cannot proceed.")

# --- End of script ---


Using GPU: Tesla T4
GPU Memory: 14.74 GB
Loading raw data from openalex_papers_raw.json and openalex_citations_raw.json...
Loaded 20000 papers and citation data for 20000 papers from files.

Proceeding with 20000 papers.
Created mapping for 20000 unique papers.
Checked 20000 papers from citations dict that are in the subset.
Number of edges (citations) within the subset: 6114
Preparing abstracts...

--- Starting Embedding Generation with mistralai/Mistral-7B-v0.1 ---
Loading LLM Config: mistralai/Mistral-7B-v0.1


config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Setting up 4-bit quantization...
Loading LLM Model: mistralai/Mistral-7B-v0.1 (this may take time and RAM)...


model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

LLM Model loaded.
Loading Tokenizer: mistralai/Mistral-7B-v0.1


tokenizer_config.json:   0%|          | 0.00/996 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

Setting pad_token to eos_token
LLM Hidden Dimension: 4096
Using batch size: 8
Max abstract length (tokens): 512
Generating embeddings in batches...


Embedding Batches:   0%|          | 0/2500 [00:00<?, ?it/s]


Embeddings generated successfully with final shape: torch.Size([20000, 4096])
Cleaning up LLM and tokenizer...
Cleanup complete.

Graph data prepared:
Data(x=[20000, 4096], edge_index=[2, 6114])
Number of nodes: 20000
Number of edges: 6114
Feature dimension: 4096
Final graph data saved to 'openalex_subset_mistral_base_graph.pt'.

Next steps: Use the 'data' object or load from 
'openalex_subset_mistral_base_graph.pt' to train a GNN for link prediction.
IMPORTANT: Update the GNN 'in_channels' to 4096 in the training script.


In [None]:
import torch
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.data import Data

print("Splitting data into training, validation, and test sets...")

# Configure the split: 10% of edges held out for validation and 10% for test,
# directed edges, one sampled negative per positive for val/test, and no
# pre-sampled negatives for training.
transform = RandomLinkSplit(
    is_undirected=False,
    num_val=0.1,
    num_test=0.1,
    neg_sampling_ratio=1.0,
    add_negative_train_samples=False,
)

# Apply the transform
train_data, val_data, test_data = transform(data)

print("\nData Splitting Complete.")
print("Training Data:", train_data)
print("Validation Data:", val_data)
print("Test Data:", test_data)

# Move all three splits to the compute device in one step.
train_data, val_data, test_data = (
    split.to(device) for split in (train_data, val_data, test_data)
)

# edge_index holds the message-passing edges; edge_label_index holds the
# supervision pairs, with edge_label giving their 0/1 labels.
print(f"\nTrain edges (for GNN message passing): {train_data.edge_index.size(1)}")
print(f"Validation edges (to predict): {val_data.edge_label_index.size(1)}")
print(f" - Positive validation edges: {int(val_data.edge_label.sum())}")
print(f"Test edges (to predict): {test_data.edge_label_index.size(1)}")
print(f" - Positive test edges: {int(test_data.edge_label.sum())}")


Splitting data into training, validation, and test sets...

Data Splitting Complete.
Training Data: Data(x=[20000, 4096], edge_index=[2, 4892], edge_label=[4892], edge_label_index=[2, 4892])
Validation Data: Data(x=[20000, 4096], edge_index=[2, 4892], edge_label=[1222], edge_label_index=[2, 1222])
Test Data: Data(x=[20000, 4096], edge_index=[2, 5503], edge_label=[1222], edge_label_index=[2, 1222])

Train edges (for GNN message passing): 4892
Validation edges (to predict): 1222
 - Positive validation edges: 611
Test edges (to predict): 1222
 - Positive test edges: 611


In [None]:
import torch
from torch_geometric.utils import to_scipy_sparse_matrix, degree, contains_isolated_nodes
import numpy as np
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support, accuracy_score

def calculate_heuristic_scores(edge_index, num_nodes, edges_to_score):
    """Compute Common Neighbors and Adamic-Adar scores for candidate edges.

    A "common neighbor" of a pair (src, tgt) is a node that both src and tgt
    point to in ``edge_index`` (the adjacency matrix is used as given, without
    symmetrizing it).

    The Adamic-Adar weight of a common neighbor z is ``1 / log(deg(z))``,
    where ``deg(z)`` is the number of distinct nodes connected to z in either
    direction. Nodes with ``deg <= 1`` get weight 0, which avoids both the
    division by ``log(1) == 0`` and the negative weight of ``log(0)``.

    Args:
        edge_index: ``[2, E]`` tensor of observed edges used to build the
            adjacency matrix.
        num_nodes: Total number of nodes in the graph.
        edges_to_score: ``[2, K]`` tensor of candidate (source, target) pairs.

    Returns:
        Tuple ``(common_neighbors_scores, adamic_adar_scores)`` of NumPy arrays
        with one entry per candidate edge, in input order.
    """
    print("Calculating adjacency matrix for heuristics...")
    adj_sparse = to_scipy_sparse_matrix(edge_index, num_nodes=num_nodes).tocsr()

    # Direction-agnostic degree: every common neighbor is linked to both
    # endpoints, so its degree is >= 2 and its log weight is well defined.
    undirected = (adj_sparse + adj_sparse.T).sign()
    deg = np.asarray(undirected.sum(axis=1)).ravel()

    aa_weights = np.zeros(num_nodes, dtype=np.float64)
    informative = deg > 1
    aa_weights[informative] = 1.0 / np.log(deg[informative])

    print(f"Calculating heuristic scores for {edges_to_score.shape[1]} edges...")
    source_nodes = edges_to_score[0].cpu().numpy()
    target_nodes = edges_to_score[1].cpu().numpy()

    batch_size = 10000  # Process edges in batches to bound the intermediate sparse matrices
    num_batches = (len(source_nodes) + batch_size - 1) // batch_size

    cn_parts = []
    aa_parts = []
    for i in range(num_batches):
        start = i * batch_size
        end = min((i + 1) * batch_size, len(source_nodes))
        batch_sources = source_nodes[start:end]
        batch_targets = target_nodes[start:end]

        # Row k is non-zero exactly at the common neighbors of pair k.
        shared = adj_sparse[batch_sources].multiply(adj_sparse[batch_targets]).tocsr()
        cn_parts.append(np.asarray(shared.sum(axis=1)).ravel())

        # sign() turns the row into a 0/1 indicator, so the product with the
        # weight vector sums the weights of the common neighbors. This
        # replaces the per-edge Python loop.
        aa_parts.append(np.asarray(shared.sign().dot(aa_weights)).ravel())

        if (i + 1) % 10 == 0 or end == len(source_nodes):
            print(f"  Processed batch {i+1}/{num_batches}...")

    print("Heuristic score calculation finished.")
    if not cn_parts:
        # No candidate edges: return empty arrays, as np.concatenate rejects [].
        return np.array([]), np.array([])
    return np.concatenate(cn_parts), np.concatenate(aa_parts)

# Score the validation and test candidate pairs with both heuristics.
# In both cases the adjacency is built from train_data.edge_index only.
print("\n--- Calculating Heuristics for Validation Set ---")
cn_val_scores, aa_val_scores = calculate_heuristic_scores(
    train_data.edge_index, data.num_nodes, val_data.edge_label_index
)

print("\n--- Calculating Heuristics for Test Set ---")
cn_test_scores, aa_test_scores = calculate_heuristic_scores(
    train_data.edge_index, data.num_nodes, test_data.edge_label_index
)

def evaluate_scores(scores, labels):
    """Score a vector of predictions against binary ground-truth labels.

    AUC is computed directly on the raw scores. The thresholded metrics
    binarize the scores at their median (0.5 when there are no scores), with
    scores equal to the threshold counted as positive.

    Returns:
        Tuple ``(auc, precision, recall, f1, accuracy)``; ``auc`` is NaN when
        the labels contain a single class.
    """
    both_classes_present = len(np.unique(labels)) >= 2
    if both_classes_present:
        auc = roc_auc_score(labels, scores)
    else:
        print("Warning: Only one class present in labels. AUC calculation is not possible.")
        auc = float('nan')

    cutoff = 0.5 if len(scores) == 0 else np.median(scores)
    predicted = (scores >= cutoff).astype(int)

    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predicted, average='binary', zero_division=0
    )
    accuracy = accuracy_score(labels, predicted)
    return auc, precision, recall, f1, accuracy

# Report the Common Neighbors baseline on the validation and test pairs.
print("\n--- Evaluating Common Neighbors ---")
# Ground-truth 0/1 labels of the candidate pairs, as NumPy arrays for sklearn.
val_labels = val_data.edge_label.cpu().numpy()
test_labels = test_data.edge_label.cpu().numpy()

cn_val_auc, cn_val_p, cn_val_r, cn_val_f1, cn_val_acc = evaluate_scores(cn_val_scores, val_labels)
print(f"Validation CN - AUC: {cn_val_auc:.4f}, P: {cn_val_p:.4f}, R: {cn_val_r:.4f}, F1: {cn_val_f1:.4f}, Acc: {cn_val_acc:.4f}")

cn_test_auc, cn_test_p, cn_test_r, cn_test_f1, cn_test_acc = evaluate_scores(cn_test_scores, test_labels)
print(f"Test CN - AUC: {cn_test_auc:.4f}, P: {cn_test_p:.4f}, R: {cn_test_r:.4f}, F1: {cn_test_f1:.4f}, Acc: {cn_test_acc:.4f}")


# Same report for the Adamic-Adar baseline, reusing the labels from above.
print("\n--- Evaluating Adamic-Adar ---")
aa_val_auc, aa_val_p, aa_val_r, aa_val_f1, aa_val_acc = evaluate_scores(aa_val_scores, val_labels)
print(f"Validation AA - AUC: {aa_val_auc:.4f}, P: {aa_val_p:.4f}, R: {aa_val_r:.4f}, F1: {aa_val_f1:.4f}, Acc: {aa_val_acc:.4f}")

aa_test_auc, aa_test_p, aa_test_r, aa_test_f1, aa_test_acc = evaluate_scores(aa_test_scores, test_labels)
print(f"Test AA - AUC: {aa_test_auc:.4f}, P: {aa_test_p:.4f}, R: {aa_test_r:.4f}, F1: {aa_test_f1:.4f}, Acc: {aa_test_acc:.4f}")



--- Calculating Heuristics for Validation Set ---
Calculating adjacency matrix for heuristics...
Calculating heuristic scores for 1222 edges...
  Processed batch 1/1...
Heuristic score calculation finished.

--- Calculating Heuristics for Test Set ---
Calculating adjacency matrix for heuristics...
Calculating heuristic scores for 1222 edges...
  Processed batch 1/1...
Heuristic score calculation finished.

--- Evaluating Common Neighbors ---
Validation CN - AUC: 0.6113, P: 0.5000, R: 1.0000, F1: 0.6667, Acc: 0.5000
Test CN - AUC: 0.6121, P: 0.5000, R: 1.0000, F1: 0.6667, Acc: 0.5000

--- Evaluating Adamic-Adar ---
Validation AA - AUC: 0.5458, P: 0.4831, R: 0.9345, F1: 0.6369, Acc: 0.4673
Test AA - AUC: 0.5516, P: 0.4844, R: 0.9394, F1: 0.6392, Acc: 0.4697


In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, LayerNorm, GATConv

class LinkPredictorGNN(torch.nn.Module):
    """GraphSAGE encoder with residual connections plus an MLP edge decoder.

    The encoder stacks ``num_layers`` SAGEConv layers. Every layer adds a skip
    connection from its own input (a Linear projection when the widths differ,
    Identity otherwise). All layers except the last are followed by ReLU,
    LayerNorm and dropout. The decoder scores a node pair by feeding the
    concatenated pair of embeddings through an MLP that ends in a single logit.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3, decoder_layers=2, dropout=0.5):
        super().__init__()
        assert num_layers >= 2, "Need at least 2 GNN layers"
        assert decoder_layers >= 1, "Need at least 1 decoder layer"

        self.num_layers = num_layers
        self.dropout = dropout

        self.convs = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()
        self.residuals = torch.nn.ModuleList()

        # Encoder layer k maps widths[k] -> widths[k + 1]:
        # input -> hidden, (num_layers - 2) x hidden -> hidden, hidden -> output.
        widths = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        final_layer = num_layers - 1
        for k, (width_in, width_out) in enumerate(zip(widths[:-1], widths[1:])):
            self.convs.append(SAGEConv(width_in, width_out))
            # The output layer is not normalized.
            if k != final_layer:
                self.norms.append(LayerNorm(width_out))
            # The skip path needs a projection only when the widths differ.
            if width_in == width_out:
                self.residuals.append(torch.nn.Identity())
            else:
                self.residuals.append(torch.nn.Linear(width_in, width_out))

        # Decoder: pair features -> `decoder_layers` hidden Linear layers -> 1 logit
        pair_dim = out_channels * 2
        mlp_dim = out_channels * 4
        mlp_widths = [pair_dim] + [mlp_dim] * decoder_layers + [1]
        self.decoder_layers = torch.nn.ModuleList(
            [torch.nn.Linear(w_in, w_out) for w_in, w_out in zip(mlp_widths[:-1], mlp_widths[1:])]
        )

    def encode(self, x, edge_index):
        """Return node embeddings produced by the residual GNN encoder."""
        h = x
        final_layer = self.num_layers - 1
        for depth, (conv, skip) in enumerate(zip(self.convs, self.residuals)):
            # Message passing plus the (possibly projected) layer input.
            h = conv(h, edge_index) + skip(h)
            if depth == final_layer:
                break
            h = F.relu(h)
            h = self.norms[depth](h)
            h = F.dropout(h, p=self.dropout, training=self.training)
        return h

    def decode(self, z, edge_label_index):
        """Return one logit per candidate edge.

        Each edge is represented by the concatenation of its source and target
        node embeddings, which is then passed through the MLP decoder.
        """
        src_idx, dst_idx = edge_label_index[0], edge_label_index[1]
        h = torch.cat([z[src_idx], z[dst_idx]], dim=-1)

        # Every decoder layer except the last is followed by ReLU and dropout.
        *hidden_layers, head = self.decoder_layers
        for layer in hidden_layers:
            h = F.relu(layer(h))
            h = F.dropout(h, p=self.dropout, training=self.training)

        return head(h).squeeze(-1)

    def forward(self, x, edge_index, edge_label_index):
        """Encode all nodes, then score the requested candidate edges."""
        return self.decode(self.encode(x, edge_index), edge_label_index)

In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.utils import negative_sampling
from sklearn.metrics import roc_auc_score
import numpy as np

# Optimisation hyper-parameters.
epochs = 150
lr = 1e-5
weight_decay = 5e-4

# Input width is taken from the node feature dimension of the graph.
# num_layers and decoder_layers are left at the class defaults.
model = LinkPredictorGNN(
    in_channels=data.num_node_features,
    hidden_channels=128,
    out_channels=64,
    dropout=0.5
).to(device)

optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# The decoder emits raw logits, so the sigmoid is applied inside the loss.
criterion = torch.nn.BCEWithLogitsLoss()

# Early-stopping state: stop after `patience` consecutive epochs without a new
# best validation AUC.
patience = 15
epochs_no_improve = 0
best_val_auc_es = 0.0
best_epoch = 0
early_stop_triggered = False


def train():
    """Run one full-batch optimisation step and return the loss.

    All training edges are used as positives, and an equal number of
    negatives is requested from ``negative_sampling`` on every call.

    Returns:
        The scalar loss as a float, or ``None`` when the loss is NaN (no
        parameter update is performed in that case).
    """
    model.train()
    optimizer.zero_grad()

    node_emb = model.encode(train_data.x, train_data.edge_index)

    positives = train_data.edge_index
    num_positives = positives.shape[1]
    negatives = negative_sampling(
        edge_index=train_data.edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=num_positives,
        method='sparse'
    ).to(device)

    # Supervision pairs: positives first, then negatives, with matching labels.
    candidate_pairs = torch.cat([positives, negatives], dim=-1)
    targets = torch.cat(
        [
            torch.ones(num_positives, device=device),
            torch.zeros(negatives.shape[1], device=device),
        ],
        dim=0,
    )

    logits = model.decode(node_emb, candidate_pairs)
    loss = criterion(logits, targets)

    # Bail out before backprop so a NaN never reaches the weights.
    if torch.isnan(loss):
        print("Warning: NaN loss detected! Stopping training.")
        return None

    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def evaluate(eval_data):
    """Return the ROC-AUC of the model on ``eval_data``'s labelled edges.

    Node embeddings are always computed from the training graph. Only the
    scored pairs and their labels come from ``eval_data``.

    Returns:
        The AUC; 0.5 when the labels contain a single class; 0.0 when the
        embeddings are non-finite, the AUC is NaN, or any error occurs.
    """
    model.eval()
    node_emb = model.encode(train_data.x, train_data.edge_index)
    if not torch.isfinite(node_emb).all():
        print("Warning: NaN/Inf detected in node embeddings during evaluation!")
        return 0.0

    try:
        logits = model.decode(node_emb, eval_data.edge_label_index).cpu()
        targets = eval_data.edge_label.cpu()
        # AUC is undefined with a single class; report chance level instead.
        if targets.unique().numel() < 2:
            return 0.5
        auc = roc_auc_score(targets, logits.float())
        if np.isnan(auc):
            print("Warning: AUC calculation resulted in NaN. Returning 0.0.")
            return 0.0
        return auc
    except Exception as e:
        print(f"Error during evaluation: {e}")
        return 0.0

# Test AUC recorded at the epoch with the best validation AUC so far.
final_test_auc_at_best_val = 0

# Main training loop with early stopping on validation AUC.
for epoch in range(1, epochs + 1):
    loss = train()
    # train() returns None when the loss is NaN.
    if loss is None:
        print(f"Stopping training at epoch {epoch} due to NaN loss.")
        break

    val_auc = evaluate(val_data)
    test_auc = evaluate(test_data)

    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val AUC: {val_auc:.4f}, Test AUC: {test_auc:.4f}')

    if val_auc > best_val_auc_es:
        # New best: record it, reset the patience counter, and checkpoint.
        best_val_auc_es = val_auc
        final_test_auc_at_best_val = test_auc
        best_epoch = epoch
        epochs_no_improve = 0
        print(f"*** New Best Validation AUC: {best_val_auc_es:.4f} at Epoch {epoch} ***")
        try:
            torch.save(model.state_dict(), 'best_link_pred_model.pt')
            print("   (Best model checkpoint saved)")
        except Exception as e:
            # Checkpointing is best-effort; training continues without it.
            print(f"   (Error saving model checkpoint: {e})")
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= patience:
        print(f"\nEarly stopping triggered after {patience} epochs with no improvement.")
        early_stop_triggered = True
        break


print(f"\n--- Training Finished ---")
if early_stop_triggered:
    print(f"Stopped early at epoch {epoch}.")
print(f"Best Validation AUC achieved: {best_val_auc_es:.4f} at epoch {best_epoch}")
print(f"Test AUC at Best Validation Epoch: {final_test_auc_at_best_val:.4f}")

# Restore the best checkpoint so later cells evaluate the best weights rather
# than the last-epoch weights. On failure the current weights are kept.
print("\nLoading best model state for final evaluation...")
try:
    model.load_state_dict(torch.load('best_link_pred_model.pt', map_location=device))
    print("Best model loaded successfully.")
except FileNotFoundError:
    print("Warning: 'best_link_pred_model.pt' not found. Evaluating with the final model state.")
except Exception as e:
    print(f"Error loading best model state: {e}. Evaluating with the final model state.")




--- Starting GNN Training (with fixes) ---
Epoch: 001, Loss: 0.6816, Val AUC: 0.8637, Test AUC: 0.8629
*** New Best Validation AUC: 0.8637 at Epoch 1 ***
   (Best model checkpoint saved)
Epoch: 002, Loss: 0.6778, Val AUC: 0.8755, Test AUC: 0.8757
*** New Best Validation AUC: 0.8755 at Epoch 2 ***
   (Best model checkpoint saved)
Epoch: 003, Loss: 0.6780, Val AUC: 0.8828, Test AUC: 0.8846
*** New Best Validation AUC: 0.8828 at Epoch 3 ***
   (Best model checkpoint saved)
Epoch: 004, Loss: 0.6749, Val AUC: 0.8885, Test AUC: 0.8912
*** New Best Validation AUC: 0.8885 at Epoch 4 ***
   (Best model checkpoint saved)
Epoch: 005, Loss: 0.6709, Val AUC: 0.8936, Test AUC: 0.8963
*** New Best Validation AUC: 0.8936 at Epoch 5 ***
   (Best model checkpoint saved)
Epoch: 006, Loss: 0.6672, Val AUC: 0.8979, Test AUC: 0.9000
*** New Best Validation AUC: 0.8979 at Epoch 6 ***
   (Best model checkpoint saved)
Epoch: 007, Loss: 0.6633, Val AUC: 0.9011, Test AUC: 0.9032
*** New Best Validation AUC: 0.9

In [None]:
import torch
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support, accuracy_score
import numpy as np

# Eval
@torch.no_grad()
def final_evaluate(eval_data):
    """Print and return final link-prediction metrics for ``eval_data``.

    AUC is computed on the predicted probabilities. The decision threshold
    for precision/recall/F1/accuracy is the one that maximises F1 on the
    validation split, chosen from 50 evenly spaced values in [0, 1].

    Returns:
        Tuple ``(auc, precision, recall, f1, accuracy)``.
    """
    model.eval()

    # Embeddings always come from the training graph.
    node_emb = model.encode(train_data.x, train_data.edge_index)

    def _edge_probabilities(pair_index):
        # Logits -> probabilities, returned as a NumPy array on the CPU.
        return torch.sigmoid(model.decode(node_emb, pair_index).cpu()).numpy()

    probs = _edge_probabilities(eval_data.edge_label_index)
    labels = eval_data.edge_label.cpu().numpy()

    # Calculate Metrics
    auc = roc_auc_score(labels, probs)

    # Tune the threshold on validation data only, never on eval_data itself.
    val_probs = _edge_probabilities(val_data.edge_label_index)
    val_targets = val_data.edge_label.cpu().numpy()
    threshold, top_f1 = 0, -1
    for candidate in np.linspace(0, 1, 50):
        val_preds = (val_probs >= candidate).astype(int)
        candidate_f1 = precision_recall_fscore_support(
            val_targets, val_preds, average='binary', zero_division=0
        )[2]
        # Strict '>' keeps the smallest threshold among ties.
        if candidate_f1 > top_f1:
            top_f1, threshold = candidate_f1, candidate
    print(f"Best threshold found on validation set: {threshold:.4f}")

    pred_labels = (probs >= threshold).astype(int)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, pred_labels, average='binary', zero_division=0
    )
    acc = accuracy_score(labels, pred_labels)

    print("\n--- Final GNN Evaluation on Test Set ---")
    print(f"Test AUC: {auc:.4f}")
    print(f"Using Threshold: {threshold:.4f}")
    print(f"Test Precision: {precision:.4f}")
    print(f"Test Recall: {recall:.4f}")
    print(f"Test F1-Score: {f1:.4f}")
    print(f"Test Accuracy: {acc:.4f}")

    return auc, precision, recall, f1, acc

# Evaluate on the held-out test split. As the last expression of the cell,
# the returned (auc, precision, recall, f1, acc) tuple is echoed as output.
final_evaluate(test_data)


Best threshold found on validation set: 0.5918

--- Final GNN Evaluation on Test Set ---
Test AUC: 0.9298
Using Threshold: 0.5918
Test Precision: 0.8854
Test Recall: 0.8347
Test F1-Score: 0.8593
Test Accuracy: 0.8633


(np.float64(0.9297869661765612),
 0.8854166666666666,
 0.8346972176759411,
 0.8593091828138163,
 0.8633387888707038)