In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.metrics.pairwise import cosine_similarity
import networkx as nx
from typing import Dict, Optional
import warnings
import pickle
import os
warnings.filterwarnings('ignore')

# PyTorch and PyTorch Geometric imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
from torch_geometric.data import Data

class StreamlinedGNNPipeline:
    """
    Streamlined GNN Pipeline with essential steps:
    1. Data loading and labeling
    2. PCA with standardization (50 features)
    3. Cosine similarity computation
    4. Connected graph construction with minimal edges
    5. Training on GCN, GAT, and SAGE models
    
    NEW: Added state saving/loading functionality for reproducible results
    """
    
    def __init__(self, file_path: str, save_dir: str = "gnn_pipeline_state"):
        self.file_path = file_path
        self.save_dir = save_dir
        self.data = None
        self.processed_data = None
        self.graph_data = None
        self.similarity_matrix = None
        self.optimal_threshold = None
        
        # Create save directory if it doesn't exist
        os.makedirs(save_dir, exist_ok=True)
        
    def save_state(self, state_dict: Dict, filename: str):
        """Save state dictionary to file"""
        filepath = os.path.join(self.save_dir, filename)
        with open(filepath, 'wb') as f:
            pickle.dump(state_dict, f)
        print(f" Saved state to: {filepath}")
    
    def load_state(self, filename: str) -> Optional[Dict]:
        """Load state dictionary from file"""
        filepath = os.path.join(self.save_dir, filename)
        if os.path.exists(filepath):
            with open(filepath, 'rb') as f:
                state_dict = pickle.load(f)
            print(f" Loaded state from: {filepath}")
            return state_dict
        return None
    
    def save_model_state(self, model: nn.Module, model_name: str):
        """Save model weights"""
        filepath = os.path.join(self.save_dir, f"{model_name}_weights.pth")
        torch.save(model.state_dict(), filepath)
        print(f" Saved {model_name} weights to: {filepath}")
    
    def load_model_state(self, model: nn.Module, model_name: str) -> bool:
        """Load model weights"""
        filepath = os.path.join(self.save_dir, f"{model_name}_weights.pth")
        if os.path.exists(filepath):
            model.load_state_dict(torch.load(filepath, map_location='cpu'))
            print(f" Loaded {model_name} weights from: {filepath}")
            return True
        return False
        
    def load_and_label_data(self) -> Dict:
        """Step 1: Load data and extract labels"""
        print("="*60)
        print("STEP 1: DATA LOADING AND LABELING")
        print("="*60)
        
        # Check if data processing state exists
        saved_state = self.load_state("data_processing_state.pkl")
        if saved_state is not None:
            print(" Using saved data processing state")
            return saved_state
        
        # Load data
        self.data = pd.read_csv(self.file_path)
        feature_columns = [col for col in self.data.columns 
                          if col not in ['Energy(keV)', 'Spectrum', 'Acq mode']]
        
        print(f"Data loaded: {self.data.shape[0]} samples, {len(feature_columns)} features")
        
        # Extract labels
        samples_info = []
        for idx, row in self.data.iterrows():
            sample_name = row['Energy(keV)']
            
            # Label extraction
            if '_3C_' in sample_name or '3C' in sample_name:
                label = 0  # Control
            elif '_3T_' in sample_name or '3T' in sample_name:
                label = 1  # Patient
            elif '_5C_' in sample_name or '5C' in sample_name:
                label = 0  # Control
            elif '_5T_' in sample_name or '5T' in sample_name:
                label = 1  # Patient
            elif '_1C_' in sample_name or '1C' in sample_name:
                label = 0  # Control
            elif '_1T_' in sample_name or '1T' in sample_name:
                label = 1  # Patient
            else:
                label = -1  # Unknown
            
            samples_info.append({
                'sample_idx': idx,
                'sample_name': sample_name,
                'label': label
            })
        
        samples_df = pd.DataFrame(samples_info)
        
        # Remove unknown labels
        valid_mask = samples_df['label'] != -1
        samples_df = samples_df[valid_mask].reset_index(drop=True)
        valid_data = self.data.loc[samples_df['sample_idx']].reset_index(drop=True)
        
        # Get features and labels
        X = valid_data[feature_columns].values
        y = samples_df['label'].values
        
        print(f"Valid samples: {len(y)}")
        print(f"Label distribution - Control: {np.sum(y == 0)}, Patient: {np.sum(y == 1)}")
        
        # Save state
        state_dict = {
            'X': X,
            'y': y,
            'samples_info': samples_df,
            'feature_columns': feature_columns
        }
        self.save_state(state_dict, "data_processing_state.pkl")
        
        return state_dict
    
    def apply_pca_standardization(self, X: np.ndarray, n_components: int = 50) -> Dict:
        """Step 2: Apply PCA with standardization to reduce to 50 features"""
        print("="*60)
        print("STEP 2: PCA WITH STANDARDIZATION")
        print("="*60)
        
        # Check if PCA state exists
        saved_state = self.load_state("pca_state.pkl")
        if saved_state is not None:
            print(" Using saved PCA state")
            return saved_state
        
        original_dim = X.shape[1]
        print(f"Original features: {original_dim}")
        
        # Standardization first
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)
        print("Applied standardization")
        
        # PCA with fixed random state
        n_components = min(n_components, X.shape[0] - 1, X.shape[1])
        pca = PCA(n_components=n_components, random_state=42)
        X_pca = pca.fit_transform(X_scaled)
        
        explained_variance = np.sum(pca.explained_variance_ratio_)
        
        print(f"PCA applied: {original_dim} → {X_pca.shape[1]} features")
        print(f"Explained variance: {explained_variance:.4f} ({100*explained_variance:.2f}%)")
        
        # Save state
        state_dict = {
            'X_pca': X_pca,
            'explained_variance': explained_variance,
            'scaler': scaler,
            'pca': pca
        }
        self.save_state(state_dict, "pca_state.pkl")
        
        return state_dict
    
    def compute_cosine_similarity(self, X: np.ndarray) -> np.ndarray:
        """Step 3: Compute cosine similarity matrix"""
        print("="*60)
        print("STEP 3: COSINE SIMILARITY COMPUTATION")
        print("="*60)
        
        # Check if similarity state exists
        saved_state = self.load_state("similarity_state.pkl")
        if saved_state is not None:
            print(" Using saved similarity matrix")
            self.similarity_matrix = saved_state['similarity_matrix']
            return self.similarity_matrix
        
        self.similarity_matrix = cosine_similarity(X)
        n_samples = X.shape[0]
        
        # Get pairwise similarities (upper triangle, excluding diagonal)
        similarities = []
        for i in range(n_samples):
            for j in range(i + 1, n_samples):
                similarities.append(self.similarity_matrix[i, j])
        
        similarities = np.array(similarities)
        
        print(f"Similarity matrix computed: {n_samples} × {n_samples}")
        print(f"Similarity range: [{np.min(similarities):.4f}, {np.max(similarities):.4f}]")
        print(f"Mean similarity: {np.mean(similarities):.4f} ± {np.std(similarities):.4f}")
        
        # Save state
        state_dict = {
            'similarity_matrix': self.similarity_matrix
        }
        self.save_state(state_dict, "similarity_state.pkl")
        
        return self.similarity_matrix
    
    def find_connected_threshold(self, max_edges: int = 500) -> float:
        """Step 4: Find threshold for connected graph with minimal edges"""
        print("="*60)
        print("STEP 4: CONNECTED GRAPH CONSTRUCTION")
        print("="*60)
        
        # Check if threshold state exists
        saved_state = self.load_state("threshold_state.pkl")
        if saved_state is not None:
            print(" Using saved threshold")
            self.optimal_threshold = saved_state['optimal_threshold']
            print(f" LOADED THRESHOLD: {self.optimal_threshold:.6f}")
            return self.optimal_threshold
        
        n_samples = self.similarity_matrix.shape[0]
        max_possible_edges = n_samples * (n_samples - 1) // 2
        
        print(f"Finding connected graph with minimal edges")
        print(f"Target max edges: {max_edges} (out of {max_possible_edges} possible)")
        
        # Get unique similarity values and sort in descending order
        similarities = []
        for i in range(n_samples):
            for j in range(i + 1, n_samples):
                similarities.append(self.similarity_matrix[i, j])
        
        unique_sims = np.unique(similarities)
        unique_sims = np.sort(unique_sims)[::-1]  # Descending order
        
        # Binary search for minimum threshold that ensures connectivity
        def is_connected_at_threshold(threshold):
            G = nx.Graph()
            G.add_nodes_from(range(n_samples))
            
            edge_count = 0
            for i in range(n_samples):
                for j in range(i + 1, n_samples):
                    if self.similarity_matrix[i, j] >= threshold:
                        G.add_edge(i, j)
                        edge_count += 1
            
            return nx.is_connected(G), edge_count
        
        # Find minimum threshold for connectivity
        connected_threshold = None
        min_edges_connected = float('inf')
        
        print("Searching for optimal threshold...")
        for i, threshold in enumerate(unique_sims):
            is_conn, edge_count = is_connected_at_threshold(threshold)
            
            if is_conn:
                if edge_count <= max_edges:
                    # Found connected graph within edge limit
                    connected_threshold = threshold
                    min_edges_connected = edge_count
                    print(f" Found connected graph: threshold={threshold:.6f}, edges={edge_count}")
                    break
                elif connected_threshold is None:
                    # First connected graph (may exceed edge limit)
                    connected_threshold = threshold
                    min_edges_connected = edge_count
                    print(f"⚠️ Connected graph found but exceeds limit: threshold={threshold:.6f}, edges={edge_count}")
            
            if i % 50 == 0:
                status = "Yes" if is_conn else "No"
                print(f"  Testing threshold {threshold:.6f}: {edge_count} edges, connected={status}")
        
        if connected_threshold is None:
            # Fallback: use threshold that gives largest component
            print(" No connected graph found, using largest component approach")
            connected_threshold = np.percentile(similarities, 90)
        
        self.optimal_threshold = connected_threshold
        
        # Build final graph
        final_conn, final_edges = is_connected_at_threshold(connected_threshold)
        print(f"\n SELECTED THRESHOLD: {connected_threshold:.6f}")
        print(f"    Edges: {final_edges}")
        print(f"    Connected: {'YES' if final_conn else 'NO'}")
        print(f"    Density: {final_edges/max_possible_edges:.6f}")
        
        # Save state
        state_dict = {
            'optimal_threshold': connected_threshold
        }
        self.save_state(state_dict, "threshold_state.pkl")
        
        return connected_threshold
    
    def build_graph_data(self, X: np.ndarray, y: np.ndarray) -> Data:
        """Build the PyTorch Geometric graph from features, labels and similarities.

        Every pair ``i < j`` whose similarity is ``>= self.optimal_threshold``
        contributes two directed edges, ``(i, j)`` followed by ``(j, i)``,
        both carrying the similarity as edge attribute.

        Args:
            X: Node feature matrix with one row per sample.
            y: Integer label per sample.

        Returns:
            The ``Data`` object, also stored on ``self.graph_data``.

        Raises:
            ValueError: If the similarity matrix or the threshold has not
                been set yet.
        """
        if self.similarity_matrix is None or self.optimal_threshold is None:
            raise ValueError(
                "similarity matrix and threshold must be set before building the graph"
            )

        n_samples = X.shape[0]

        # All unordered pairs in row-major order (same order as nested i < j loops).
        rows, cols = np.triu_indices(n_samples, k=1)
        pair_sims = np.asarray(self.similarity_matrix)[rows, cols]
        keep = pair_sims >= self.optimal_threshold
        src, dst, kept_sims = rows[keep], cols[keep], pair_sims[keep]

        # Interleave each edge with its reverse: (i, j), (j, i), (i', j'), ...
        edge_array = np.empty((2, 2 * src.size), dtype=np.int64)
        edge_array[0, 0::2] = src
        edge_array[1, 0::2] = dst
        edge_array[0, 1::2] = dst
        edge_array[1, 1::2] = src

        # With no surviving pair these are empty tensors of shape (2, 0) and (0,).
        edge_index = torch.from_numpy(edge_array)
        edge_attr = torch.tensor(np.repeat(kept_sims, 2), dtype=torch.float)

        self.graph_data = Data(
            x=torch.tensor(X, dtype=torch.float),
            edge_index=edge_index,
            edge_attr=edge_attr,
            y=torch.tensor(y, dtype=torch.long)
        )

        return self.graph_data
    
    def split_data(self, train_ratio: float = 0.7, val_ratio: float = 0.15) -> Dict:
        """Step 5: Split data into train/validation/test sets"""
        print("="*60)
        print("STEP 5: DATA SPLITTING")
        print("="*60)
        
        # Check if split state exists
        saved_state = self.load_state("split_state.pkl")
        if saved_state is not None:
            print("Using saved data splits")
            # Apply saved masks to graph data
            self.graph_data.train_mask = torch.tensor(saved_state['train_mask'], dtype=torch.bool)
            self.graph_data.val_mask = torch.tensor(saved_state['val_mask'], dtype=torch.bool)
            self.graph_data.test_mask = torch.tensor(saved_state['test_mask'], dtype=torch.bool)
            return saved_state['splits']
        
        n_nodes = self.graph_data.x.shape[0]
        y = self.graph_data.y.numpy()
        
        # Stratified split with fixed random state
        indices = np.arange(n_nodes)
        train_idx, temp_idx = train_test_split(
            indices, test_size=(1-train_ratio), stratify=y, random_state=42
        )
        
        val_test_ratio = val_ratio / (1-train_ratio)
        val_idx, test_idx = train_test_split(
            temp_idx, test_size=(1-val_test_ratio), stratify=y[temp_idx], random_state=42
        )
        
        # Create masks
        train_mask = torch.zeros(n_nodes, dtype=torch.bool)
        val_mask = torch.zeros(n_nodes, dtype=torch.bool)
        test_mask = torch.zeros(n_nodes, dtype=torch.bool)
        
        train_mask[train_idx] = True
        val_mask[val_idx] = True
        test_mask[test_idx] = True
        
        self.graph_data.train_mask = train_mask
        self.graph_data.val_mask = val_mask
        self.graph_data.test_mask = test_mask
        
        splits = {'train': train_idx, 'val': val_idx, 'test': test_idx}
        
        print(f"Train: {len(train_idx)}, Val: {len(val_idx)}, Test: {len(test_idx)}")
        
        # Save state
        state_dict = {
            'splits': splits,
            'train_mask': train_mask.numpy(),
            'val_mask': val_mask.numpy(),
            'test_mask': test_mask.numpy()
        }
        self.save_state(state_dict, "split_state.pkl")
        
        return splits
    
    def train_model(self, model: nn.Module, model_name: str, num_epochs: int = 200, lr: float = 0.01) -> Dict:
        """Train a GNN model with state saving"""
        # Check if trained model exists
        if self.load_model_state(model, model_name):
            print(f"Using saved {model_name} weights")
            # Still return dummy history for consistency
            return {'train_acc': [], 'val_acc': []}
        
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
        data = self.graph_data.to(device)
        
        # Set seeds for reproducibility
        torch.manual_seed(42)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(42)
        
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
        criterion = nn.CrossEntropyLoss()
        
        history = {'train_acc': [], 'val_acc': []}
        best_val_acc = 0
        
        print(f"Training {model.__class__.__name__}...")
        
        for epoch in range(num_epochs):
            # Training
            model.train()
            optimizer.zero_grad()
            out = model(data.x, data.edge_index)
            loss = criterion(out[data.train_mask], data.y[data.train_mask])
            loss.backward()
            optimizer.step()
            
            # Evaluation
            if epoch % 20 == 0:
                model.eval()
                with torch.no_grad():
                    out = model(data.x, data.edge_index)
                    
                    # Training accuracy
                    train_pred = out[data.train_mask].argmax(dim=1)
                    train_acc = (train_pred == data.y[data.train_mask]).float().mean()
                    
                    # Validation accuracy
                    val_pred = out[data.val_mask].argmax(dim=1)
                    val_acc = (val_pred == data.y[data.val_mask]).float().mean()
                    
                    history['train_acc'].append(train_acc.item())
                    history['val_acc'].append(val_acc.item())
                    
                    if val_acc > best_val_acc:
                        best_val_acc = val_acc
                    
                    if epoch % 40 == 0:
                        print(f'  Epoch {epoch:3d}: Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}')
        
        # Save trained model
        self.save_model_state(model, model_name)
        
        return history
    
    def evaluate_model(self, model: nn.Module) -> Dict:
        """Evaluate ``model`` on the nodes selected by ``graph_data.test_mask``.

        Args:
            model: Module called as ``model(x, edge_index)``.

        Returns:
            Dict with ``'test_accuracy'``, ``'predictions'``, ``'true_labels'``
            (NumPy arrays) and ``'classification_report'`` (dict form).
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
        data = self.graph_data.to(device)

        model.eval()
        with torch.no_grad():
            out = model(data.x, data.edge_index)

        test_pred = out[data.test_mask].argmax(dim=1).cpu().numpy()
        test_true = data.y[data.test_mask].cpu().numpy()

        test_acc = accuracy_score(test_true, test_pred)

        # labels=[0, 1] pins the report to both classes. Without it the two
        # target_names no longer match when only one class occurs in the test
        # labels and predictions, and the report raises a ValueError.
        report = classification_report(
            test_true,
            test_pred,
            labels=[0, 1],
            target_names=['Control', 'Patient'],
            output_dict=True,
            zero_division=0,
        )

        return {
            'test_accuracy': test_acc,
            'predictions': test_pred,
            'true_labels': test_true,
            'classification_report': report
        }
    
    def clear_saved_state(self):
        """Clear all saved states to start fresh"""
        import shutil
        if os.path.exists(self.save_dir):
            shutil.rmtree(self.save_dir)
            os.makedirs(self.save_dir, exist_ok=True)
            print(f" Cleared all saved states from {self.save_dir}")
    
    def run_complete_pipeline(self, use_saved_state: bool = True) -> Dict:
        """Step 6: Run every pipeline step and train/evaluate GCN, GAT and SAGE.

        Args:
            use_saved_state: When ``False`` the save directory is cleared
                first, so every step is recomputed. When ``True`` each step
                reuses its saved state if one exists.

        Returns:
            Dict keyed by model name (``'GCN'``, ``'GAT'``, ``'SAGE'``); each
            value holds the ``'model'``, its training ``'history'`` and the
            ``'evaluation'`` dict from ``evaluate_model``. A summary with the
            test accuracies and graph size is also saved to
            ``final_results.pkl``.
        """
        print("STREAMLINED GNN PIPELINE WITH STATE SAVING")
        print("="*60)

        if not use_saved_state:
            print(" Starting fresh - clearing saved states")
            self.clear_saved_state()

        # Step 1: Load and label data
        data_dict = self.load_and_label_data()

        # Step 2: PCA with standardization
        pca_results = self.apply_pca_standardization(data_dict['X'], n_components=50)

        # Step 3: Cosine similarity
        self.compute_cosine_similarity(pca_results['X_pca'])

        # Step 4: Find connected threshold
        self.find_connected_threshold(max_edges=500)

        # Build graph
        self.build_graph_data(pca_results['X_pca'], data_dict['y'])

        # Step 5: Split data (the masks are attached to self.graph_data)
        self.split_data()

        # Step 6: Train all three models
        print("="*60)
        print("STEP 6: TRAINING ALL THREE MODELS")
        print("="*60)

        input_dim = self.graph_data.x.shape[1]
        hidden_dim = 64
        num_classes = 2

        # Seed before the models are constructed: their initial weights are
        # drawn at construction time, so seeding only inside train_model
        # leaves the initialisation of a fresh run non-reproducible.
        torch.manual_seed(42)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(42)

        models = {
            'GCN': GCNModel(input_dim, hidden_dim, num_classes),
            'GAT': GATModel(input_dim, hidden_dim, num_classes),
            'SAGE': SAGEModel(input_dim, hidden_dim, num_classes)
        }

        results = {}

        for model_name, model in models.items():
            print(f"\n--- Training {model_name} ---")
            history = self.train_model(model, model_name, num_epochs=200)
            evaluation = self.evaluate_model(model)

            print(f"{model_name} Test Accuracy: {evaluation['test_accuracy']:.4f}")

            results[model_name] = {
                'model': model,
                'history': history,
                'evaluation': evaluation
            }

        # edge_index stores both directions, so halve it for the undirected count.
        n_nodes = self.graph_data.x.shape[0]
        n_edges = self.graph_data.edge_index.shape[1]//2

        # Summary
        print("\n" + "="*60)
        print("FINAL RESULTS SUMMARY")
        print("="*60)
        print(f"Graph: {n_nodes} nodes, {n_edges} edges")
        print(f"Features: {input_dim} (after PCA)")
        print("\nModel Performance:")
        for model_name, result in results.items():
            acc = result['evaluation']['test_accuracy']
            print(f"  {model_name}: {acc:.4f}")

        # Save final results
        final_results = {
            'model_performances': {name: result['evaluation']['test_accuracy']
                                   for name, result in results.items()},
            'graph_info': {
                'nodes': n_nodes,
                'edges': n_edges,
                'features': input_dim
            }
        }
        self.save_state(final_results, "final_results.pkl")

        return results


# GNN Model Definitions (unchanged)
class GCNModel(nn.Module):
    """Two-layer GCN: ``input_dim -> hidden_dim -> num_classes``."""

    def __init__(self, input_dim, hidden_dim, num_classes, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, num_classes)
        # Dropout probability applied between the two layers.
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Return per-node log-probabilities over the classes."""
        hidden = F.relu(self.conv1(x, edge_index))
        # Active only while the module is in training mode.
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        logits = self.conv2(hidden, edge_index)
        return F.log_softmax(logits, dim=1)


class GATModel(nn.Module):
    """Two-layer GAT: a multi-head first layer followed by a single-head output layer."""

    def __init__(self, input_dim, hidden_dim, num_classes, dropout=0.5, heads=8):
        super().__init__()
        self.conv1 = GATConv(input_dim, hidden_dim, heads=heads, dropout=dropout)
        # The second layer is sized for hidden_dim * heads input features.
        self.conv2 = GATConv(hidden_dim * heads, num_classes, heads=1, dropout=dropout)
        # Dropout probability applied between the two layers.
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Return per-node log-probabilities over the classes."""
        hidden = F.elu(self.conv1(x, edge_index))
        # Active only while the module is in training mode.
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        logits = self.conv2(hidden, edge_index)
        return F.log_softmax(logits, dim=1)


class SAGEModel(nn.Module):
    """Two-layer GraphSAGE: ``input_dim -> hidden_dim -> num_classes``."""

    def __init__(self, input_dim, hidden_dim, num_classes, dropout=0.5):
        super().__init__()
        self.conv1 = SAGEConv(input_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, num_classes)
        # Dropout probability applied between the two layers.
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Return per-node log-probabilities over the classes."""
        hidden = F.relu(self.conv1(x, edge_index))
        # Active only while the module is in training mode.
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        logits = self.conv2(hidden, edge_index)
        return F.log_softmax(logits, dim=1)



if __name__ == "__main__":
    # Build the pipeline around the input CSV.
    pipeline = StreamlinedGNNPipeline('flou.csv')

    # use_saved_state=True reuses whatever state is already saved on disk;
    # pass use_saved_state=False to clear it and recompute every step.
    results = pipeline.run_complete_pipeline(use_saved_state=True)

    # Test accuracy of each trained model.
    gcn_accuracy, gat_accuracy, sage_accuracy = (
        results[name]['evaluation']['test_accuracy'] for name in ('GCN', 'GAT', 'SAGE')
    )

    print(f"\nBest performing model: {max(results, key=lambda name: results[name]['evaluation']['test_accuracy'])}")

    # To start fresh and recompute everything:
    # results = pipeline.run_complete_pipeline(use_saved_state=False)

STREAMLINED GNN PIPELINE WITH STATE SAVING
STEP 1: DATA LOADING AND LABELING
 Loaded state from: gnn_pipeline_state\data_processing_state.pkl
 Using saved data processing state
STEP 2: PCA WITH STANDARDIZATION
 Loaded state from: gnn_pipeline_state\pca_state.pkl
 Using saved PCA state
STEP 3: COSINE SIMILARITY COMPUTATION
 Loaded state from: gnn_pipeline_state\similarity_state.pkl
 Using saved similarity matrix
STEP 4: CONNECTED GRAPH CONSTRUCTION
 Loaded state from: gnn_pipeline_state\threshold_state.pkl
 Using saved threshold
 LOADED THRESHOLD: 0.434903
STEP 5: DATA SPLITTING
 Loaded state from: gnn_pipeline_state\split_state.pkl
Using saved data splits
STEP 6: TRAINING ALL THREE MODELS

--- Training GCN ---
 Loaded GCN weights from: gnn_pipeline_state\GCN_weights.pth
Using saved GCN weights
GCN Test Accuracy: 0.8286

--- Training GAT ---
 Loaded GAT weights from: gnn_pipeline_state\GAT_weights.pth
Using saved GAT weights
GAT Test Accuracy: 0.8571

--- Training SAGE ---
 Loaded SAGE 