In [4]:
#
# --- Step 1: Install Required Libraries ---
#
# Notebook-only cell: the leading "!" is IPython shell-escape syntax, not Python,
# so this line only works inside Jupyter/Colab. -q keeps pip's output quiet.
# networkx is installed here but not imported by the code in this notebook, and
# tqdm (imported below) is not listed explicitly.
# NOTE(review): tqdm presumably arrives as a transitive dependency; confirm or add it.
!pip install -q torch torch-geometric pandas duckdb pyarrow networkx gradio

In [3]:
#
# --- Step 2: Import Libraries ---
#
import gradio as gr
import pandas as pd
import duckdb
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import SAGEConv
from torch_geometric.data import Data
import os
import logging
import io
import json
from dataclasses import dataclass
from abc import ABC, abstractmethod
from tqdm import tqdm
from collections import defaultdict

#
# --- Step 3: Google Colab Drive Mount ---
#
# Mount Google Drive when running inside Colab; the default data paths in the
# configuration below all live under /content/drive/.
try:
    from google.colab import drive
    drive.mount('/content/drive/')
except ImportError:
    # google.colab only exists inside Colab.
    print("Not running in a Google Colab environment. Please ensure your data paths are correct.")
except Exception as e:
    print(f"Error mounting Google Drive: {e}")
else:
    print("Google Drive mounted successfully.")

#
# --- Step 4: Configuration & Logging Classes ---
#

@dataclass
class LinkPredictionConfig:
    """Holds all configuration for the graph prediction pipeline.

    The four artifact paths default to "" and are filled in by ``__post_init__``
    as ``os.path.join(output_dir, <fixed filename>)``. Passing a different
    ``output_dir`` therefore relocates every artifact. An artifact path passed
    explicitly is kept exactly as given.
    """
    # Input Paths
    edge_csv_path: str = "/content/drive/My Drive/master_july_2025/data/link_graph_edges.csv"
    embeddings_dir_path: str = "/content/drive/My Drive/master_july_2025/data/url_embeddings/"

    # Output Artifact Paths (empty string means "derive from output_dir")
    output_dir: str = "/content/drive/My Drive/master_july_2025/data/prediction_model/"
    model_state_path: str = ""
    node_embeddings_path: str = ""
    node_mapping_path: str = ""  # JSON file: URL->index mapping plus model layer sizes
    edge_index_path: str = ""

    # Model Hyperparameters
    hidden_channels: int = 128
    out_channels: int = 64
    learning_rate: float = 0.01
    epochs: int = 100

    def __post_init__(self):
        """Derive every artifact path that was not set explicitly from ``output_dir``."""
        # Computing these here, rather than in the class body, makes them follow
        # the instance's output_dir instead of the class-level default.
        default_filenames = {
            "model_state_path": "graphsage_link_predictor.pth",
            "node_embeddings_path": "final_node_embeddings.pt",
            "node_mapping_path": "model_metadata.json",
            "edge_index_path": "edge_index.pt",
        }
        for attr, filename in default_filenames.items():
            if not getattr(self, attr):
                setattr(self, attr, os.path.join(self.output_dir, filename))

class ILogger(ABC):
    """Minimal logging interface used by the pipeline components.

    Concrete loggers must implement both ``info`` and ``error``.
    """

    @abstractmethod
    def info(self, message: str):
        """Record an informational message."""

    @abstractmethod
    def error(self, message: str):
        """Record an error message."""

class ConsoleAndGradioLogger(ILogger):
    """ILogger that mirrors every record to an in-memory stream and to the console.

    The stream passed in is read back by the Gradio callbacks to show logs in the
    UI. The second handler is a default ``logging.StreamHandler()``, which writes
    to ``sys.stderr``.

    NOTE: every instance configures the same named logger ("GraphLogger") and
    clears its handlers. Creating a new instance therefore redirects output away
    from any earlier instance's stream.
    """

    _FORMAT = '%(asctime)s - %(levelname)s - %(message)s'

    def __init__(self, log_output_stream: io.StringIO, level=logging.INFO):
        self._logger = logging.getLogger("GraphLogger")
        self._logger.setLevel(level)
        if self._logger.hasHandlers():
            self._logger.handlers.clear()
        gradio_handler = logging.StreamHandler(log_output_stream)
        gradio_handler.setFormatter(logging.Formatter(self._FORMAT))
        self._logger.addHandler(gradio_handler)
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(logging.Formatter(self._FORMAT))
        self._logger.addHandler(console_handler)

    def info(self, message: str):
        """Log *message* at INFO level."""
        self._logger.info(message)

    def error(self, message: str):
        """Log *message* at ERROR level."""
        self._logger.error(message)

    def exception(self, message: str):
        """Log *message* at ERROR level, followed by the current exception's traceback.

        Call this from inside an ``except`` block. The training pipeline's error
        handler calls ``logger.exception``, so this method must exist on this class.
        """
        self._logger.exception(message)

#
# --- Step 5: OOP Component Classes ---
#

class GraphDataLoader:
    """Loads the raw graph inputs (edge CSV and URL embeddings in Parquet) via DuckDB."""

    def __init__(self, config: LinkPredictionConfig, logger: ILogger):
        self.config = config
        self.logger = logger

    @staticmethod
    def _sql_string(value) -> str:
        """Quote *value* as a SQL string literal, doubling embedded single quotes.

        File paths are spliced into the SQL text. Without this, a path containing
        ``'`` would break (or alter) the query.
        """
        return "'" + str(value).replace("'", "''") + "'"

    def load_data(self):
        """Read every node (distinct FROM/TO URL) with its embedding, plus the raw edge list.

        Returns:
            (node_features_df, edge_list_df)

            - ``node_features_df`` has columns ``url`` and ``features``.
              ``features`` is NULL where a URL has no matching Parquet row,
              because of the LEFT JOIN.
            - ``edge_list_df`` is the CSV as read.

            NOTE(review): the queries assume the CSV has "FROM"/"TO" columns and
            the Parquet files have "URL"/"Embedding" columns.

        Raises:
            Whatever DuckDB raises on a read or query failure (logged first).
        """
        self.logger.info("Loading data using DuckDB...")
        edges_source = f"read_csv_auto({self._sql_string(self.config.edge_csv_path)}, header=true)"
        embeddings_glob_path = os.path.join(self.config.embeddings_dir_path, '*.parquet')
        embeddings_source = f"read_parquet({self._sql_string(embeddings_glob_path)})"

        all_nodes_query = f"""
            (SELECT "FROM" AS url FROM {edges_source})
            UNION
            (SELECT "TO" AS url FROM {edges_source})
        """
        node_features_query = f"""
            WITH all_nodes AS ({all_nodes_query})
            SELECT n.url, e.Embedding AS features
            FROM all_nodes AS n
            LEFT JOIN {embeddings_source} AS e ON n.url = e.URL
        """

        con = None
        try:
            con = duckdb.connect()
            node_features_df = con.execute(node_features_query).fetchdf()
            edge_list_df = con.execute(f"SELECT * FROM {edges_source}").fetchdf()
            self.logger.info(f"Loaded {len(edge_list_df)} edges and {len(node_features_df)} unique nodes.")
            return node_features_df, edge_list_df
        except Exception as e:
            self.logger.error(f"Failed to load data: {e}")
            raise
        finally:
            # Release the DuckDB connection on both success and failure.
            if con is not None:
                con.close()

class GraphDataProcessor:
    """Converts the loader's DataFrames into a PyG ``Data`` graph.

    A node without an embedding is imputed from its neighbours, with edges
    treated as undirected. It gets the mean of the neighbours' embeddings, or a
    zero vector when no neighbour has one.
    """

    def __init__(self, logger: ILogger):
        self.logger = logger

    @staticmethod
    def _is_missing(value) -> bool:
        """Return True if *value* is None/NA, or an array-like containing any NA.

        ``pd.isna`` returns a plain bool for scalars (e.g. a NaN placeholder for a
        missing embedding). For list/array embeddings it returns an element-wise
        array. Both shapes are handled here.
        """
        if value is None:
            return True
        missing = pd.isna(value)
        if isinstance(missing, (bool, np.bool_)):
            return bool(missing)
        return bool(np.asarray(missing).any())

    @staticmethod
    def _build_adjacency(edge_list_df: pd.DataFrame):
        """Return an undirected adjacency map: URL -> set of neighbouring URLs."""
        adj = defaultdict(set)
        for src, dst in zip(edge_list_df['FROM'], edge_list_df['TO']):
            adj[src].add(dst)
            adj[dst].add(src)
        return adj

    def _impute_features(self, url_to_features, adj, feature_dim):
        """Return URL -> float32 feature vector, filling in missing vectors from neighbours."""
        imputed_features = {}
        missing_count = 0
        for url, features in url_to_features.items():
            if not self._is_missing(features):
                imputed_features[url] = np.array(features, dtype=np.float32)
                continue
            missing_count += 1
            # Only neighbours that have a real embedding themselves contribute.
            neighbor_features = [
                np.array(url_to_features[n], dtype=np.float32)
                for n in adj.get(url, ())
                if not self._is_missing(url_to_features.get(n))
            ]
            if neighbor_features:
                imputed_features[url] = np.mean(neighbor_features, axis=0)
            else:
                imputed_features[url] = np.zeros(feature_dim, dtype=np.float32)
        if missing_count > 0:
            self.logger.info(f"Imputed features for {missing_count} nodes.")
        return imputed_features

    def process(self, node_features_df: pd.DataFrame, edge_list_df: pd.DataFrame):
        """Build the node-feature matrix and edge index.

        Returns:
            (data, url_to_idx). ``data`` is ``Data(x=..., edge_index=...)``.
            ``url_to_idx`` maps each distinct URL to its row in ``data.x``.

        Raises:
            ValueError: if no node has an embedding, so the feature size cannot
            be inferred.
        """
        self.logger.info("Processing data into tensors with neighbor feature inference...")
        url_to_features = pd.Series(node_features_df.features.values, index=node_features_df.url).to_dict()
        adj = self._build_adjacency(edge_list_df)

        # Drop rows with no valid embeddings to find the feature dimension reliably
        valid_features = node_features_df['features'].dropna()
        if valid_features.empty:
            raise ValueError("No nodes with features found. Cannot determine feature dimension.")
        feature_dim = len(valid_features.iloc[0])
        self.logger.info(f"Detected feature dimension: {feature_dim}")

        imputed_features = self._impute_features(url_to_features, adj, feature_dim)

        self.logger.info("Constructing final PyTorch tensors...")
        # A URL can occupy several rows: a URL with more than one embedding row
        # gets one LEFT JOIN row per match. Keep one row per URL (first-seen
        # order) so x and url_to_idx stay aligned.
        node_list = list(dict.fromkeys(node_features_df['url']))
        url_to_idx = {url: i for i, url in enumerate(node_list)}
        x = torch.tensor(np.array([imputed_features[url] for url in node_list]), dtype=torch.float)

        index_pairs = [
            (url_to_idx.get(src), url_to_idx.get(dst))
            for src, dst in zip(edge_list_df['FROM'], edge_list_df['TO'])
        ]
        # An endpoint missing from url_to_idx would give a None index and make
        # torch.tensor fail, so drop such edges and report how many there were.
        valid_pairs = [(s, d) for s, d in index_pairs if s is not None and d is not None]
        dropped = len(index_pairs) - len(valid_pairs)
        if dropped:
            self.logger.info(f"Dropped {dropped} edges whose endpoints are not known nodes.")
        if valid_pairs:
            edge_index = torch.tensor(valid_pairs, dtype=torch.long).t().contiguous()
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)

        data = Data(x=x, edge_index=edge_index)
        self.logger.info(f"Created PyG Data object with {data.num_nodes} nodes and {data.num_edges} edges.")
        return data, url_to_idx

class GraphSAGEModel(nn.Module):
    """Two-layer GraphSAGE encoder with a dot-product link decoder.

    ``forward`` produces one embedding per node. A ReLU sits between the two
    SAGEConv layers, and there is no activation after the second.

    ``predict_link`` scores node pairs by the inner product of their embeddings.
    The result is a raw logit (no sigmoid applied).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # in_channels -> hidden_channels -> out_channels
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)

    def predict_link(self, z, edge_label_index):
        # Row 0 holds source node indices, row 1 destination node indices.
        src_rows, dst_rows = edge_label_index[0], edge_label_index[1]
        return torch.sum(z[src_rows] * z[dst_rows], dim=-1)

class LinkPredictionTrainer:
    """Full-batch trainer for GraphSAGEModel with a binary link/no-link objective.

    Each epoch:
    - encode all nodes over ``data.edge_index``;
    - score every existing edge (label 1) plus an equal number of random node
      pairs (label 0);
    - take one Adam step on the BCE-with-logits loss.
    """

    def __init__(self, model: GraphSAGEModel, data: Data, config: LinkPredictionConfig, logger: ILogger):
        self.model = model
        self.data = data
        self.config = config
        self.logger = logger
        self.optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
        self.criterion = nn.BCEWithLogitsLoss()

    def _get_negative_samples(self):
        """Draw ``num_edges`` (source, dest) index pairs uniformly at random.

        Pairs are not filtered against existing edges or self-loops, so a small
        fraction of these "negatives" may be true links.
        """
        return torch.randint(0, self.data.num_nodes, (2, self.data.num_edges), dtype=torch.long)

    def train(self):
        """Run ``config.epochs`` optimisation steps, yielding ``(epoch, loss)`` after each.

        Negatives are resampled every epoch. With one fixed negative set, the
        model can memorise those particular pairs instead of learning to separate
        linked from unlinked pairs in general.

        Raises:
            ValueError: if the graph has no edges. There would be nothing to learn
            from, and the mean loss over an empty batch is NaN.
        """
        num_edges = self.data.num_edges
        if num_edges == 0:
            raise ValueError("Cannot train link predictor: the graph has no edges.")
        # Labels are identical every epoch: positives first, then negatives.
        edge_label = torch.cat([torch.ones(num_edges), torch.zeros(num_edges)], dim=0)
        for epoch in range(1, self.config.epochs + 1):
            self.model.train()
            self.optimizer.zero_grad()
            edge_label_index = torch.cat([self.data.edge_index, self._get_negative_samples()], dim=1)
            z = self.model(self.data.x, self.data.edge_index)
            out = self.model.predict_link(z, edge_label_index)
            loss = self.criterion(out, edge_label)
            loss.backward()
            self.optimizer.step()
            yield epoch, loss.item()

class RecommendationEngine:
    """Loads trained artifacts and provides link recommendations using a Top-K strategy.

    The artifacts are the metadata JSON, node embeddings, edge index and model
    weights. They are loaded lazily, on the first request, from the paths in
    ``config``.

    For a source URL, every node is scored with the model's dot-product decoder.
    The source itself and its existing outgoing links are excluded before
    ranking, so ``top_n`` novel links are returned whenever that many candidates
    exist.
    """

    RESULT_COLUMNS = ["RECOMMENDED_URL", "SCORE"]

    def __init__(self, config: LinkPredictionConfig, logger: ILogger):
        self.config = config
        self.logger = logger
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Artifacts are loaded on the first request; ``model`` doubles as the "loaded" flag.
        self.model = None
        self.node_embeddings = None
        self.url_to_idx = None
        self.idx_to_url = None
        self.existing_edges = None
        # source index -> set of destination indices it already links to
        self._outgoing_by_source = None

    def load_artifacts(self):
        """Load the trained model, embeddings and mappings into memory (only once).

        Returns:
            True if the artifacts are (now) loaded. False if an artifact file is
            missing.

        Raises:
            Any other exception raised while reading or deserialising the
            artifacts (logged first).
        """
        if self.model is not None:
            return True  # Already loaded

        self.logger.info("Loading trained artifacts for recommendations...")
        try:
            # The metadata file holds the URL mapping and the model's layer sizes.
            with open(self.config.node_mapping_path, 'r') as f:
                model_metadata = json.load(f)

            url_to_idx = model_metadata['url_to_idx']
            in_channels = model_metadata['in_channels']
            hidden_channels = model_metadata['hidden_channels']
            out_channels = model_metadata['out_channels']

            # map_location lets tensors saved on one device load on another.
            node_embeddings = torch.load(self.config.node_embeddings_path, map_location=self.device)
            edge_index = torch.load(self.config.edge_index_path, map_location='cpu')
            existing_edges = set(zip(edge_index[0].tolist(), edge_index[1].tolist()))
            outgoing_by_source = defaultdict(set)
            for src, dst in existing_edges:
                outgoing_by_source[src].add(dst)

            # Recreate the model with the saved dimensions and load its weights.
            model = GraphSAGEModel(in_channels, hidden_channels, out_channels)
            model.load_state_dict(torch.load(self.config.model_state_path, map_location=self.device))
            model.to(self.device)
            model.eval()
        except FileNotFoundError:
            self.logger.error("Could not find trained model artifacts. Please run the training pipeline first.")
            return False
        except Exception as e:
            self.logger.error(f"An error occurred while loading artifacts: {e}")
            raise

        # Publish state only after every artifact loaded, so a failed load never
        # leaves the engine half-initialised. The model is set last because it
        # is the "loaded" flag.
        self.url_to_idx = url_to_idx
        self.idx_to_url = {v: k for k, v in url_to_idx.items()}
        self.node_embeddings = node_embeddings
        self.existing_edges = existing_edges
        self._outgoing_by_source = dict(outgoing_by_source)
        self.model = model
        self.logger.info("Artifacts loaded successfully.")
        return True

    def get_recommendations(self, source_url: str, top_n: int = 20):
        """Return up to ``top_n`` highest-scoring new link targets for ``source_url``.

        Returns:
            On success: (DataFrame with columns RECOMMENDED_URL and SCORE, None).
            SCORE is the sigmoid of the decoder's logit.
            On failure: (None, error message), when artifacts are missing or the
            URL is not in the trained graph.
        """
        if not self.load_artifacts():
            return None, "Error: Trained model artifacts not found. Please run the training pipeline first."
        if source_url not in self.url_to_idx:
            return None, f"Error: Source URL '{source_url}' not found in the graph's training data."

        source_idx = self.url_to_idx[source_url]
        num_nodes = len(self.url_to_idx)

        # 1. Candidate edges from the source to every node.
        candidate_dest_indices = torch.arange(num_nodes, device=self.device)
        candidate_source_indices = torch.full_like(candidate_dest_indices, fill_value=source_idx)
        candidate_edge_index = torch.stack([candidate_source_indices, candidate_dest_indices])

        # 2. Score all candidates in one pass.
        with torch.no_grad():
            scores = self.model.predict_link(self.node_embeddings, candidate_edge_index)

        # 3. Mask the self-link and existing outgoing links *before* ranking.
        # Masking first (rather than filtering a fixed-size top-k afterwards)
        # guarantees top_n results even when existing links dominate the
        # highest scores.
        outgoing = (self._outgoing_by_source or {}).get(source_idx, ())
        excluded = [source_idx, *(d for d in outgoing if d < num_nodes)]
        scores[torch.tensor(excluded, dtype=torch.long, device=self.device)] = float('-inf')

        # 4. Take the top-k and drop masked entries. These can only appear when
        # fewer than top_n eligible candidates exist.
        k = min(max(int(top_n), 0), num_nodes)
        top_scores, top_indices = torch.topk(scores, k=k)
        keep = torch.isfinite(top_scores)
        probabilities = torch.sigmoid(top_scores[keep]).tolist()
        dest_indices = top_indices[keep].tolist()

        recommendations = [
            {"RECOMMENDED_URL": self.idx_to_url[idx], "SCORE": prob}
            for idx, prob in zip(dest_indices, probabilities)
        ]
        return pd.DataFrame(recommendations, columns=self.RESULT_COLUMNS), None
#
# --- Main Gradio Application Functions ---
#

def _save_training_artifacts(config, logger, model, data, url_to_idx, final_node_embeddings):
    """Write the four artifacts that the recommendation engine reloads.

    - ``node_mapping_path``: JSON with ``url_to_idx`` plus the layer sizes
      needed to rebuild the model.
    - ``model_state_path``: the model's ``state_dict``.
    - ``node_embeddings_path``: the final node-embedding tensor.
    - ``edge_index_path``: the training ``edge_index``, used to filter out
      existing links.
    """
    logger.info(f"Saving model metadata to {config.node_mapping_path}")
    model_metadata = {
        "url_to_idx": url_to_idx,
        "in_channels": data.num_node_features,
        "hidden_channels": config.hidden_channels,
        "out_channels": config.out_channels
    }
    with open(config.node_mapping_path, 'w') as f:
        json.dump(model_metadata, f, indent=2)

    logger.info(f"Saving model weights to {config.model_state_path}")
    torch.save(model.state_dict(), config.model_state_path)
    logger.info(f"Saving final node embeddings to {config.node_embeddings_path}")
    torch.save(final_node_embeddings, config.node_embeddings_path)
    logger.info(f"Saving edge index to {config.edge_index_path}")
    torch.save(data.edge_index, config.edge_index_path)


def run_training_pipeline(csv_path, embeddings_path, hidden_channels, out_channels, lr, epochs, progress=gr.Progress(track_tqdm=True)):
    """Gradio generator callback: load data, train the model, save artifacts.

    Each yield is a ``(status_text, accumulated_log_text, results_table_or_None)``
    tuple. Any exception is logged with its traceback, and a final
    "Pipeline Failed" tuple is yielded instead of propagating to the UI.
    """
    import traceback

    log_stream = io.StringIO()
    logger = ConsoleAndGradioLogger(log_stream)
    try:
        yield "Step 1/5: Initializing...", log_stream.getvalue(), None
        # gr.Number may deliver floats; coerce each value to the type the config expects.
        config = LinkPredictionConfig(
            edge_csv_path=csv_path, embeddings_dir_path=embeddings_path,
            hidden_channels=int(hidden_channels), out_channels=int(out_channels),
            learning_rate=float(lr), epochs=int(epochs)
        )
        os.makedirs(config.output_dir, exist_ok=True)

        yield "Step 2/5: Loading & processing data...", log_stream.getvalue(), None
        loader = GraphDataLoader(config, logger)
        node_features_df, edge_list_df = loader.load_data()
        processor = GraphDataProcessor(logger)
        data, url_to_idx = processor.process(node_features_df, edge_list_df)

        yield "Step 3/5: Initializing model...", log_stream.getvalue(), None
        model = GraphSAGEModel(in_channels=data.num_node_features, hidden_channels=config.hidden_channels, out_channels=config.out_channels)
        trainer = LinkPredictionTrainer(model, data, config, logger)

        yield "Step 4/5: Training model...", log_stream.getvalue(), None
        for epoch, loss in progress.tqdm(trainer.train(), total=config.epochs, desc="Training Model"):
            if epoch % 10 == 0 or epoch == 1:
                logger.info(f"Epoch {epoch}/{config.epochs}, Loss: {loss:.4f}")

        yield "Step 5/5: Evaluating and saving artifacts...", log_stream.getvalue(), None
        model.eval()
        with torch.no_grad():
            final_node_embeddings = model(data.x, data.edge_index)

        _save_training_artifacts(config, logger, model, data, url_to_idx, final_node_embeddings)

        final_status = "✅ Pipeline Finished Successfully!"
        logger.info(final_status)
        yield final_status, log_stream.getvalue(), pd.DataFrame({"Message": ["Artifacts saved successfully. You can now use the recommendation tab."]})

    except Exception as e:
        # ILogger only declares info() and error(). Format the traceback here
        # rather than relying on a logger.exception() method.
        logger.error(f"A critical error occurred: {e}\n{traceback.format_exc()}")
        yield "Pipeline Failed", log_stream.getvalue(), pd.DataFrame({"Error": [str(e)]})


def run_recommendation_interface(source_url: str):
    """Gradio callback for the recommendation tab.

    Returns a ``(recommendations, log_text)`` pair:
    - ``recommendations`` is the engine's DataFrame, or None on error or when no
      URL is selected;
    - ``log_text`` is the text logged while serving this request (or a prompt
      when no URL is selected).
    A fresh engine is built for every call, so artifacts are re-read from disk
    each time.
    """
    if not source_url:
        return None, "Please select a source URL from the dropdown."

    captured_logs = io.StringIO()
    request_logger = ConsoleAndGradioLogger(captured_logs)
    engine = RecommendationEngine(LinkPredictionConfig(), request_logger)

    results, problem = engine.get_recommendations(source_url, top_n=20)
    if problem:
        request_logger.error(problem)
    return results, captured_logs.getvalue()

def get_all_nodes_for_dropdown():
    """Return the sorted, distinct, non-NULL URLs in the edge CSV, for the dropdown.

    On failure, returns a one-element list holding a human-readable message.
    That message then shows up as the only dropdown choice.
    """
    con = None
    try:
        config = LinkPredictionConfig()
        if not os.path.exists(config.edge_csv_path):
            return ["Run training first to create edge list"]
        # Escape single quotes: the path is spliced into a SQL string literal.
        csv_source = "read_csv_auto('" + str(config.edge_csv_path).replace("'", "''") + "', header=true)"
        con = duckdb.connect()
        # NULL endpoints are filtered out because a None among the strings
        # would make sorted() raise TypeError, losing the whole list.
        nodes_df = con.execute(f"""
            SELECT url FROM (
                (SELECT "FROM" AS url FROM {csv_source})
                UNION
                (SELECT "TO" AS url FROM {csv_source})
            ) WHERE url IS NOT NULL
        """).fetchdf()
        return sorted(nodes_df['url'].tolist())
    except Exception as e:
        return [f"Could not load URLs: {e}"]
    finally:
        if con is not None:
            con.close()

#
# --- Gradio UI Definition ---
#
# Build the two-tab Gradio app. Components are placed in the order they are
# created inside the nested context managers below, so statement order here
# defines the on-screen layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📈 GNN Link Prediction & Recommendation Engine")
    gr.Markdown("First, use the 'Train Model' tab to process your data. Then, use the 'Get Link Recommendations' tab to get predictions for new, non-existent links.")

    with gr.Tabs():
        # Tab 1: training inputs (left column) and status/log outputs (right column).
        with gr.TabItem("Train Model"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("## 1. Configuration")
                    # Initial values come from the LinkPredictionConfig class-level defaults.
                    with gr.Accordion("Data Paths", open=True):
                        edge_csv_path_input = gr.Textbox(label="Edge List CSV Path", value=LinkPredictionConfig.edge_csv_path)
                        embeddings_dir_path_input = gr.Textbox(label="Embeddings Directory Path", value=LinkPredictionConfig.embeddings_dir_path)
                    with gr.Accordion("Model Hyperparameters", open=True):
                        hidden_channels_input = gr.Number(label="Hidden Channels", value=LinkPredictionConfig.hidden_channels)
                        out_channels_input = gr.Number(label="Output Embedding Size", value=LinkPredictionConfig.out_channels)
                    with gr.Accordion("Training Parameters", open=True):
                        learning_rate_input = gr.Number(label="Learning Rate", value=LinkPredictionConfig.learning_rate)
                        epochs_input = gr.Number(label="Training Epochs", value=LinkPredictionConfig.epochs)
                    start_button = gr.Button("Train Link Prediction Model", variant="primary")
                with gr.Column(scale=2):
                    gr.Markdown("## 2. Training Status")
                    train_status_output = gr.Textbox(label="Current Status", interactive=False)
                    train_log_output = gr.Textbox(label="Pipeline Logs", interactive=False, lines=15)
                    train_results_output = gr.DataFrame(label="Training Completion Status")

        # Tab 2: pick a source URL and show the recommended new links.
        with gr.TabItem("Get Link Recommendations"):
            gr.Markdown("## 1. Select a Source Page")
            gr.Markdown("Choose a URL and the model will recommend top pages it should link to. (You must train the model on the tab to the left first).")
            with gr.Row():
                # Choices are computed once, when this UI is built; they are not refreshed later.
                source_url_dropdown = gr.Dropdown(label="Source URL", choices=get_all_nodes_for_dropdown(), interactive=True)
            recommend_button = gr.Button("Get Recommendations", variant="primary")
            gr.Markdown("## 2. Results: High-Potential Missing Links")
            recommend_results_output = gr.DataFrame(label="Top 20 Link Recommendations", headers=["RECOMMENDED_URL", "SCORE"])
            recommend_log_output = gr.Textbox(label="Logs", interactive=False, lines=4)

    # run_training_pipeline is a generator. Each yielded (status, logs, table)
    # tuple updates the three outputs. The inputs list order must match its
    # positional parameters: csv_path, embeddings_path, hidden_channels,
    # out_channels, lr, epochs.
    start_button.click(
        fn=run_training_pipeline,
        inputs=[edge_csv_path_input, embeddings_dir_path_input, hidden_channels_input, out_channels_input, learning_rate_input, epochs_input],
        outputs=[train_status_output, train_log_output, train_results_output]
    )

    # run_recommendation_interface returns (recommendations_df_or_None, log_text).
    recommend_button.click(
        fn=run_recommendation_interface,
        inputs=[source_url_dropdown],
        outputs=[recommend_results_output, recommend_log_output]
    )

if __name__ == '__main__':
    # Drive mounting and launching are handled separately. Previously, a
    # failure to import google.colab (i.e. any non-Colab environment) also
    # skipped demo.launch().
    try:
        from google.colab import drive
        drive.mount('/content/drive/', force_remount=True)
    except ImportError:
        print("Not running in Google Colab; skipping Drive mount.")
    except Exception as e:
        print(f"Error mounting Google Drive: {e}")

    try:
        demo.launch(debug=True, share=True)
    except Exception as e:
        print(f"Could not launch Gradio demo in this environment: {e}")

ERROR: Operation cancelled by user
^C


KeyboardInterrupt: 