In [5]:
"""
Enhanced LightGCN Experiments with Comprehensive Result Saving
===============================================================

A clean, well-organized implementation of LightGCN experiments with:
- Phase 1: Graph Structure Ablation Analysis
- Phase 2: Feature Importance Analysis
- Comprehensive result saving and statistical analysis

Author: Enhanced from original implementation
Date: 2024
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
import pickle
import json
import os
import math
import random
import time
import gc
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
from scipy.sparse import coo_matrix, csr_matrix
from scipy import stats as scipy_stats
from collections import Counter, defaultdict
from typing import Dict, List, Tuple, Set
from datetime import datetime

# Silence every warning category for the remainder of the run to keep the
# experiment output readable. Note this also hides Deprecation/Runtime warnings.
warnings.simplefilter('ignore')


# =============================================================================
# 1. UTILITY FUNCTIONS
# =============================================================================

def set_all_seeds(seed=42, use_deterministic=True):
    """Seed every random number generator the experiments rely on.

    Seeds Python's ``random`` module, NumPy's global RNG and torch. When a GPU
    is present, the CUDA generators of every device are seeded as well.

    Args:
        seed: Integer seed applied to all generators.
        use_deterministic: When True, force cuDNN onto deterministic kernels
            and disable its autotuner (``benchmark``). When False, the cuDNN
            flags are left exactly as they were.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)

    if use_deterministic:
        cudnn_backend = torch.backends.cudnn
        cudnn_backend.deterministic = True
        cudnn_backend.benchmark = False

    # PYTHONHASHSEED is read at interpreter start-up. Setting it here only
    # affects Python subprocesses launched afterwards, not this process.
    os.environ['PYTHONHASHSEED'] = str(seed)

def seed_worker(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker.

    The worker's torch seed (``torch.initial_seed()``) is reduced modulo 2**32
    so NumPy accepts it. The same value is then applied to NumPy's global RNG
    and to Python's ``random`` module. ``worker_id`` is accepted for the
    ``worker_init_fn`` signature but not used.
    """
    derived_seed = torch.initial_seed() % (2 ** 32)
    for seeder in (np.random.seed, random.seed):
        seeder(derived_seed)


# =============================================================================
# 2. RESULT SAVING SYSTEM
# =============================================================================

class ResultSaver:
    """Centralized system for saving all experiment results.

    Creates ``results_dir`` plus four subdirectories. Every artefact name
    carries a ``timestamp`` suffix: either the caller's value or, when
    omitted, the current local time formatted ``YYYYmmdd_HHMMSS``.

    * ``models/``  - model checkpoints (``save_model_checkpoint``)
    * ``metrics/`` - per-config and analysis results (``save_config_results``,
      ``save_analysis_results``)
    * ``logs/``    - individual training runs (``save_training_run``)
    * ``plots/``   - created here, but nothing in this class writes to it
    * top level    - complete results and the text summary report

    JSON output accepts NumPy scalars/arrays and torch tensors in addition to
    plain Python values. ``.pkl`` files store the original objects unchanged.
    """

    def __init__(self, results_dir):
        self.results_dir = results_dir
        os.makedirs(results_dir, exist_ok=True)

        # Create organized subdirectories
        self.model_dir = os.path.join(results_dir, "models")
        self.metrics_dir = os.path.join(results_dir, "metrics")
        self.plots_dir = os.path.join(results_dir, "plots")
        self.logs_dir = os.path.join(results_dir, "logs")

        for directory in [self.model_dir, self.metrics_dir, self.plots_dir, self.logs_dir]:
            os.makedirs(directory, exist_ok=True)

        print(f"📁 Result directories created in: {results_dir}")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _timestamp(timestamp=None):
        """Return *timestamp* unchanged, or the current time as ``YYYYmmdd_HHMMSS``."""
        if timestamp is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return timestamp

    @staticmethod
    def _json_default(obj):
        """Fallback encoder for values the stdlib ``json`` module rejects.

        Metric values may arrive as NumPy scalars (e.g. ``np.float32`` or
        ``np.int64``, which are not ``float``/``int`` subclasses), NumPy arrays
        or torch tensors. These are converted to plain Python values. Any other
        type still raises ``TypeError`` so that nothing is silently stringified.
        """
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu().tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def _write_json(self, path, payload):
        """Write *payload* to *path* as indented JSON (UTF-8 file)."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2, default=self._json_default)

    @staticmethod
    def _metric_mean(result, metric):
        """Return ``result['statistics'][metric]['mean']``, or 0 when any level is missing."""
        return result.get('statistics', {}).get(metric, {}).get('mean', 0)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def save_training_run(self, config_name, seed, training_result, test_metrics, timestamp=None):
        """Save one training run to ``logs/<config>_seed<seed>_<timestamp>.json``.

        ``training_result`` must contain ``training_time`` and ``final_loss``.
        ``best_val_loss`` (default ``inf``) and ``training_losses`` (default
        ``[]``) are optional. Returns the path of the written file.
        """
        timestamp = self._timestamp(timestamp)

        run_data = {
            'config_name': config_name,
            'seed': seed,
            'timestamp': timestamp,
            'training_time': training_result['training_time'],
            'final_loss': training_result['final_loss'],
            'best_val_loss': training_result.get('best_val_loss', float('inf')),
            'training_losses': training_result.get('training_losses', []),
            'test_metrics': test_metrics
        }

        run_file = os.path.join(self.logs_dir, f"{config_name}_seed{seed}_{timestamp}.json")
        self._write_json(run_file, run_data)

        return run_file

    def save_model_checkpoint(self, model, config_name, seed, timestamp=None):
        """Save ``model.state_dict()`` plus identifying metadata to ``models/``.

        Returns the path of the written ``.pth`` file.
        """
        timestamp = self._timestamp(timestamp)

        model_file = os.path.join(self.model_dir, f"{config_name}_seed{seed}_{timestamp}.pth")
        torch.save({
            'model_state_dict': model.state_dict(),
            'config_name': config_name,
            'seed': seed,
            'timestamp': timestamp
        }, model_file)

        print(f"   🤖 Saved model checkpoint: {os.path.basename(model_file)}")
        return model_file

    def save_config_results(self, config_name, config_results, timestamp=None):
        """Save results for one configuration as JSON (summary) and pickle (full).

        Each run in ``config_results['runs']`` must contain ``seed``,
        ``metrics``, ``training_time`` and ``final_loss``. ``best_val_loss`` is
        optional (default ``inf``), matching ``save_training_run``.

        Returns:
            ``(json_path, pickle_path)``.
        """
        timestamp = self._timestamp(timestamp)

        # Save as JSON
        results_file = os.path.join(self.metrics_dir, f"{config_name}_results_{timestamp}.json")
        serializable_results = {
            'config_name': config_name,
            'timestamp': timestamp,
            'statistics': config_results.get('statistics', {}),
            'runs': [
                {
                    'seed': run['seed'],
                    'metrics': run['metrics'],
                    'training_time': run['training_time'],
                    'final_loss': run['final_loss'],
                    'best_val_loss': run.get('best_val_loss', float('inf'))
                }
                for run in config_results.get('runs', [])
            ]
        }
        self._write_json(results_file, serializable_results)

        # Save as pickle for easier loading (keeps the original objects intact)
        pickle_file = os.path.join(self.metrics_dir, f"{config_name}_results_{timestamp}.pkl")
        with open(pickle_file, 'wb') as f:
            pickle.dump(config_results, f)

        print(f"   💾 Saved {config_name} results")
        return results_file, pickle_file

    def save_complete_results(self, all_results, timestamp=None):
        """Save every configuration's results to the top-level results directory.

        The JSON file holds config, statistics and a per-run summary. The
        pickle holds ``all_results`` verbatim.

        Returns:
            ``(json_path, pickle_path)``.
        """
        timestamp = self._timestamp(timestamp)

        # Prepare serializable data
        serializable_complete = {}
        for config_name, config_data in all_results.items():
            runs = config_data.get('runs', [])
            serializable_complete[config_name] = {
                'config': config_data.get('config', {}),
                'statistics': config_data.get('statistics', {}),
                'run_count': len(runs),
                'runs_summary': [
                    {
                        'seed': run['seed'],
                        'metrics': run['metrics'],
                        'training_time': run['training_time']
                    }
                    for run in runs
                ]
            }

        # Save JSON
        complete_file = os.path.join(self.results_dir, f"complete_results_{timestamp}.json")
        self._write_json(complete_file, serializable_complete)

        # Save pickle
        pickle_file = os.path.join(self.results_dir, f"complete_results_{timestamp}.pkl")
        with open(pickle_file, 'wb') as f:
            pickle.dump(all_results, f)

        print(f"📦 Complete results saved: {os.path.basename(complete_file)}")
        return complete_file, pickle_file

    def save_analysis_results(self, analysis_results, timestamp=None):
        """Save statistical analysis results to ``metrics/`` and return the path."""
        timestamp = self._timestamp(timestamp)

        analysis_file = os.path.join(self.metrics_dir, f"analysis_results_{timestamp}.json")
        self._write_json(analysis_file, analysis_results)

        print(f"📊 Analysis results saved: {os.path.basename(analysis_file)}")
        return analysis_file

    def create_experiment_summary(self, all_results, timestamp=None):
        """Write a human-readable summary report and return its path.

        Configurations are ranked by mean ``ndcg@10``. A configuration without
        that statistic (or without ``auc``/``runs``) is reported with 0 rather
        than aborting the report.
        """
        timestamp = self._timestamp(timestamp)

        summary_file = os.path.join(self.results_dir, f"experiment_summary_{timestamp}.txt")

        with open(summary_file, 'w', encoding='utf-8') as f:
            f.write("="*80 + "\n")
            f.write("LIGHTGCN EXPERIMENT SUMMARY REPORT\n")
            f.write("="*80 + "\n")
            f.write(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"Timestamp: {timestamp}\n\n")

            # Overall summary
            f.write("EXPERIMENT OVERVIEW:\n")
            f.write("-"*50 + "\n")
            f.write(f"Total configurations tested: {len(all_results)}\n")

            if all_results:
                best_name, best_result = max(
                    all_results.items(),
                    key=lambda item: self._metric_mean(item[1], 'ndcg@10'))
                f.write(f"Best configuration: {best_name}\n")
                # Use the same defaulted lookup as the ranking key, so a config
                # lacking ndcg@10 statistics no longer raises KeyError here.
                f.write(f"Best NDCG@10: {self._metric_mean(best_result, 'ndcg@10'):.4f}\n\n")

            # Configuration results table
            f.write("CONFIGURATION RESULTS:\n")
            f.write("-"*50 + "\n")
            f.write(f"{'Configuration':<25} {'NDCG@10':<15} {'AUC':<15} {'Time(s)':<10}\n")
            f.write("-"*65 + "\n")

            for config_name, result in all_results.items():
                ndcg_mean = self._metric_mean(result, 'ndcg@10')
                auc_mean = self._metric_mean(result, 'auc')
                runs = result.get('runs', [])
                avg_time = np.mean([run['training_time'] for run in runs]) if runs else 0
                f.write(f"{config_name:<25} {ndcg_mean:<15.4f} {auc_mean:<15.4f} {avg_time:<10.1f}\n")

            f.write("\n")
            f.write("FILES GENERATED:\n")
            f.write("-"*50 + "\n")
            f.write(f"- Complete results: complete_results_{timestamp}.json/.pkl\n")
            f.write(f"- Analysis results: analysis_results_{timestamp}.json\n")
            f.write(f"- Individual config results: {self.metrics_dir}/\n")
            f.write(f"- Training logs: {self.logs_dir}/\n")
            f.write(f"- Model checkpoints: {self.model_dir}/\n")

        print(f"📋 Experiment summary saved: {os.path.basename(summary_file)}")
        return summary_file


# =============================================================================
# 3. CONFIGURATION MANAGEMENT
# =============================================================================

class ExperimentConfig:
    """Single place holding every experiment hyper-parameter.

    Building an instance has side effects: it selects the compute device,
    ensures ``results_dir`` exists, attaches a ``ResultSaver`` for that
    directory and prints a short settings summary.
    """

    def __init__(self):
        # --- experiment bookkeeping ---
        self.random_seeds = [42, 123, 456, 789, 999]
        self.target_playlists = 1500
        self.data_dir = "../data/processed/gnn_ready"
        self.results_dir = "../results/lightgcn_experiments_enhanced"

        # --- model architecture ---
        self.embedding_dim = 128
        self.n_layers = 2
        self.dropout = 0.2

        # --- optimisation / training loop ---
        self.batch_size = 256
        self.learning_rate = 0.0005
        self.epochs = 200
        self.early_stopping_patience = 15
        self.val_every = 10
        self.reg_weight = 1e-3

        # --- training-data sampling ---
        self.max_train_edges = 50000
        self.num_neg_samples = 1

        # --- evaluation ---
        self.k_values = [5, 10, 20]
        self.eval_sample_size = 500

        # --- runtime setup (device, output directory, saver, summary) ---
        has_gpu = torch.cuda.is_available()
        self.device = torch.device('cuda' if has_gpu else 'cpu')
        os.makedirs(self.results_dir, exist_ok=True)
        self.result_saver = ResultSaver(self.results_dir)

        self._print_config()

    def _print_config(self):
        """Print a one-screen summary of the key settings."""
        summary_lines = [
            "🎯 Experiment Configuration:",
            f"   📱 Device: {self.device}",
            f"   🎲 Seeds: {len(self.random_seeds)} seeds",
            f"   📊 Target playlists: {self.target_playlists:,}",
            f"   🧠 Embedding dim: {self.embedding_dim}",
            f"   ⚡ Learning rate: {self.learning_rate}",
            f"   🛡️ Regularization: {self.reg_weight}",
            f"   💾 Results dir: {self.results_dir}",
        ]
        for line in summary_lines:
            print(line)

    @staticmethod
    def _make_experiment(name, description, edge_types, feature_types=()):
        """Build one experiment spec; the ``use_*`` flags are derived from *feature_types*."""
        features = list(feature_types)
        return {
            "name": name,
            "description": description,
            "edge_types": list(edge_types),
            "use_features": bool(features),
            "feature_types": features,
            "use_playlist_features": "playlist" in features,
            "use_track_features": "track" in features,
            "use_user_features": "user" in features,
        }

    def get_experiment_configs(self):
        """Return every experiment spec keyed by id, in execution order.

        Phase 1 varies the graph's edge types without features. Phase 2 keeps
        the baseline playlist-track graph and toggles feature groups.
        """
        make = self._make_experiment
        base_graph = ["playlist_track"]

        return {
            # Phase 1: Graph Structure Ablation
            "baseline": make("Baseline", "Playlist-track edges only", base_graph),
            "with_artists": make("With Artists", "Add track-artist relationships",
                                 ["playlist_track", "track_artist"]),
            "with_users": make("With Users", "Add user-playlist relationships",
                               ["playlist_track", "user_playlist"]),
            "full_graph": make("Full Graph", "All edge types",
                               ["playlist_track", "track_artist", "user_playlist", "track_album"]),

            # Phase 2: Feature Importance Analysis (all on baseline graph)
            "playlist_features": make("Playlist Features",
                                      "Only playlist features (6D) on baseline graph",
                                      base_graph, ["playlist"]),
            "track_features": make("Track Features",
                                   "Only track features (4D) on baseline graph",
                                   base_graph, ["track"]),
            "user_features": make("User Features",
                                  "Only user features (4D) on baseline graph",
                                  base_graph, ["user"]),
            "all_features": make("All Features",
                                 "All feature types combined on baseline graph",
                                 base_graph, ["playlist", "track", "user"]),
        }


# =============================================================================
# 4. DATA PROCESSING
# =============================================================================

class SpotifyDataProcessor:
    """Process real Spotify data for experiments.

    :meth:`load_and_process_data` is the entry point. All entities share one
    global node index space laid out as playlists, tracks, artists, albums,
    users (see :meth:`_calculate_node_offsets`).
    """

    def __init__(self, config):
        # Kept for callers; not read by the methods of this class.
        self.config = config

    @staticmethod
    def _extract_user_id(playlist: Dict) -> str:
        """Return the pseudo user id for *playlist*.

        The first whitespace-separated word of the playlist name serves as the
        user id. Playlists with an empty or ``None`` name fall back to
        ``"user_<pid>"``. Every method below uses this helper so that entity
        mapping, edge extraction and feature generation always agree.
        """
        name = (playlist.get('name') or '').strip()
        return name.split()[0] if name else f"user_{playlist.get('pid')}"

    def load_and_process_data(self, data_path: str, seed: int = 42) -> Dict:
        """Load a Spotify JSON dump and build every graph input.

        Args:
            data_path: Path to a UTF-8 JSON file with a top-level ``playlists`` list.
            seed: Seeds ``random``/NumPy and the split and feature generation.

        Returns:
            Dict with ``entity_counts``, ``node_offsets``, ``total_nodes``,
            ``edges`` (edge type -> ``(E, 2)`` array of global node ids),
            ``splits`` (train/val/test playlist-track edges) and ``features``.

        Raises:
            KeyError: if the data yields no playlist-track edges.
        """
        random.seed(seed)
        np.random.seed(seed)

        print(f"🎵 Loading Spotify data from: {data_path}")

        # Explicit UTF-8: playlist names may contain non-ASCII text, and the
        # platform default encoding is not guaranteed to be UTF-8.
        with open(data_path, 'r', encoding='utf-8') as f:
            spotify_data = json.load(f)

        playlists = spotify_data.get('playlists', [])
        print(f"✅ Loaded {len(playlists):,} playlists")

        # Extract entities and create mappings
        entity_mappings = self._create_entity_mappings(playlists)
        entity_counts = {k: len(v) for k, v in entity_mappings.items()}

        print(f"📊 Dataset statistics:")
        for entity_type, count in entity_counts.items():
            print(f"   • {entity_type.title()}: {count:,}")

        # Calculate node offsets
        node_offsets = self._calculate_node_offsets(entity_counts)
        total_nodes = sum(entity_counts.values())

        # Extract edges
        edges = self._extract_edges(playlists, entity_mappings, node_offsets)

        # Create data splits
        splits = self._create_data_splits(edges['playlist_track'], seed=seed)

        # Generate features
        features = self._generate_features(playlists, entity_mappings, entity_counts, seed=seed)

        return {
            'entity_counts': entity_counts,
            'node_offsets': node_offsets,
            'total_nodes': total_nodes,
            'edges': edges,
            'splits': splits,
            'features': features
        }

    def _create_entity_mappings(self, playlists: List[Dict]) -> Dict[str, Dict]:
        """Map every entity id/URI to a dense local index (sorted order).

        Returns a dict keyed by ``'playlists'``, ``'tracks'``, ``'artists'``,
        ``'albums'`` and ``'users'``. Playlists without a ``pid`` and empty URIs
        are skipped.
        """
        playlist_ids = set()
        track_uris = set()
        artist_uris = set()
        album_uris = set()
        user_ids = set()

        for playlist in playlists:
            pid = playlist.get('pid')
            if pid is not None:
                playlist_ids.add(pid)

            user_ids.add(self._extract_user_id(playlist))

            # Extract track info
            for track in playlist.get('tracks', []):
                track_uri = track.get('track_uri', '')
                artist_uri = track.get('artist_uri', '')
                album_uri = track.get('album_uri', '')

                if track_uri:
                    track_uris.add(track_uri)
                if artist_uri:
                    artist_uris.add(artist_uri)
                if album_uri:
                    album_uris.add(album_uri)

        # Create mappings
        entity_mappings = {
            'playlists': {pid: i for i, pid in enumerate(sorted(playlist_ids))},
            'tracks': {uri: i for i, uri in enumerate(sorted(track_uris))},
            'artists': {uri: i for i, uri in enumerate(sorted(artist_uris))},
            'albums': {uri: i for i, uri in enumerate(sorted(album_uris))},
            'users': {uid: i for i, uid in enumerate(sorted(user_ids))}
        }

        return entity_mappings

    def _calculate_node_offsets(self, entity_counts: Dict[str, int]) -> Dict[str, int]:
        """Return the first global node id of each entity type (fixed type order)."""
        node_offsets = {}
        current_offset = 0

        for entity_type in ['playlists', 'tracks', 'artists', 'albums', 'users']:
            node_offsets[entity_type] = current_offset
            current_offset += entity_counts[entity_type]

        return node_offsets

    def _extract_edges(self, playlists: List[Dict], entity_mappings: Dict, node_offsets: Dict) -> Dict[str, np.ndarray]:
        """Build de-duplicated edge arrays in global node ids.

        Returns a dict with ``'playlist_track'``, ``'track_artist'``,
        ``'user_playlist'`` and ``'track_album'`` entries. An edge type with no
        edges is omitted from the dict entirely.
        """
        playlist_track_edges = []
        track_artist_edges = []
        user_playlist_edges = []
        track_album_edges = []

        for playlist in playlists:
            pid = playlist.get('pid')
            if pid not in entity_mappings['playlists']:
                continue

            playlist_node = node_offsets['playlists'] + entity_mappings['playlists'][pid]

            # User-playlist edges
            user_id = self._extract_user_id(playlist)
            if user_id in entity_mappings['users']:
                user_node = node_offsets['users'] + entity_mappings['users'][user_id]
                user_playlist_edges.append([user_node, playlist_node])

            # Track-related edges
            for track in playlist.get('tracks', []):
                track_uri = track.get('track_uri', '')
                artist_uri = track.get('artist_uri', '')
                album_uri = track.get('album_uri', '')

                if track_uri in entity_mappings['tracks']:
                    track_node = node_offsets['tracks'] + entity_mappings['tracks'][track_uri]

                    # Playlist-track edge
                    playlist_track_edges.append([playlist_node, track_node])

                    # Track-artist edge
                    if artist_uri in entity_mappings['artists']:
                        artist_node = node_offsets['artists'] + entity_mappings['artists'][artist_uri]
                        track_artist_edges.append([track_node, artist_node])

                    # Track-album edge
                    if album_uri in entity_mappings['albums']:
                        album_node = node_offsets['albums'] + entity_mappings['albums'][album_uri]
                        track_album_edges.append([track_node, album_node])

        # Convert to numpy arrays and remove duplicates (np.unique also sorts rows)
        edges = {}
        edge_lists = {
            'playlist_track': playlist_track_edges,
            'track_artist': track_artist_edges,
            'user_playlist': user_playlist_edges,
            'track_album': track_album_edges
        }

        for edge_type, edge_list in edge_lists.items():
            if edge_list:
                edges[edge_type] = np.unique(np.array(edge_list), axis=0)

        print(f"📈 Extracted edges:")
        for edge_type, edge_array in edges.items():
            print(f"   • {edge_type}: {len(edge_array):,} edges")

        return edges

    def _create_data_splits(self, playlist_track_edges: np.ndarray, seed: int = 42) -> Dict[str, np.ndarray]:
        """Shuffle edges and split them 70% / 15% / rest into train/val/test.

        Sizes are truncated with ``int()``, so the test split absorbs rounding.
        The input array is not modified.
        """
        np.random.seed(seed)

        shuffled_edges = playlist_track_edges.copy()
        np.random.shuffle(shuffled_edges)

        total_edges = len(shuffled_edges)
        train_size = int(0.7 * total_edges)
        val_size = int(0.15 * total_edges)

        train_edges = shuffled_edges[:train_size]
        val_edges = shuffled_edges[train_size:train_size + val_size]
        test_edges = shuffled_edges[train_size + val_size:]

        # Avoid ZeroDivisionError in the percentage report for empty input.
        denominator = total_edges or 1
        print(f"📋 Data splits:")
        print(f"   • Train: {len(train_edges):,} edges ({len(train_edges)/denominator:.1%})")
        print(f"   • Validation: {len(val_edges):,} edges ({len(val_edges)/denominator:.1%})")
        print(f"   • Test: {len(test_edges):,} edges ({len(test_edges)/denominator:.1%})")

        return {
            'train_edges': train_edges,
            'val_edges': val_edges,
            'test_edges': test_edges
        }

    def _generate_features(self, playlists: List[Dict], entity_mappings: Dict,
                          entity_counts: Dict, seed: int = 42) -> Dict[str, np.ndarray]:
        """Build per-entity float32 feature matrices (rows in mapping order).

        Only some columns come from the data: playlist length (normalized by
        the maximum), a length-above-median flag and a synthetic follower count
        for playlists; log-frequency popularity for tracks; log playlist count
        for users. All remaining columns, and the artist/album matrices, are
        random placeholders. NumPy is reseeded with *seed*, and draws happen in
        a fixed order, so the output is reproducible.
        """
        np.random.seed(seed)

        # Extract real statistics in a single pass.
        playlist_lengths = []
        track_frequencies = Counter()
        user_playlist_counts = Counter()
        # First playlist seen for each pid. This matches the earlier linear
        # "first match" search but gives O(1) lookup per playlist.
        playlist_by_pid = {}

        for playlist in playlists:
            tracks = playlist.get('tracks', [])
            playlist_lengths.append(len(tracks))
            user_playlist_counts[self._extract_user_id(playlist)] += 1
            playlist_by_pid.setdefault(playlist.get('pid'), playlist)

            for track in tracks:
                track_uri = track.get('track_uri', '')
                if track_uri:
                    track_frequencies[track_uri] += 1

        # Loop-invariant statistics, computed once instead of once per playlist.
        median_length = np.median(playlist_lengths) if playlist_lengths else 0.0
        max_length = max(playlist_lengths) if playlist_lengths else 0

        features = {}

        # Playlist features (6 dimensions)
        playlist_features = []
        for pid in sorted(entity_mappings['playlists'].keys()):
            playlist_data = playlist_by_pid.get(pid)

            if playlist_data is not None:
                length = len(playlist_data.get('tracks', []))
                collaborative = 1.0 if length > median_length else 0.0
                # Synthetic follower count drawn with a length-proportional scale.
                followers = max(0, int(np.random.exponential(scale=length/10)))
                # Guard: when every playlist is empty, max_length is 0.
                length_norm = length / max_length if max_length else 0
                followers_norm = min(1.0, followers / 1000)

                playlist_features.append([
                    length_norm, collaborative, followers_norm,
                    np.random.random(), np.random.random(), np.random.random()
                ])
            else:
                playlist_features.append([0, 0, 0, 0, 0, 0])

        features['playlist'] = np.array(playlist_features, dtype=np.float32)

        # Track features (4 dimensions)
        track_features = []
        for track_uri in sorted(entity_mappings['tracks'].keys()):
            frequency = track_frequencies.get(track_uri, 1)
            popularity = min(1.0, np.log1p(frequency) / 10)
            track_features.append([
                popularity,
                np.random.random(),
                np.random.random(),
                np.random.random()
            ])

        features['track'] = np.array(track_features, dtype=np.float32)

        # User features (4 dimensions)
        user_features = []
        for user_id in sorted(entity_mappings['users'].keys()):
            playlist_count = user_playlist_counts.get(user_id, 1)
            activity = min(1.0, np.log1p(playlist_count) / 5)
            user_features.append([
                activity,
                np.random.random(),
                np.random.random(),
                np.random.random()
            ])

        features['user'] = np.array(user_features, dtype=np.float32)

        # Artist and Album features (random placeholders)
        features['artist'] = np.random.randn(entity_counts['artists'], 4).astype(np.float32)
        features['album'] = np.random.randn(entity_counts['albums'], 4).astype(np.float32)

        print(f"🎨 Generated features:")
        for entity_type, feature_matrix in features.items():
            print(f"   • {entity_type}: {feature_matrix.shape}")

        return features


# =============================================================================
# 5. MODEL ARCHITECTURE
# =============================================================================

class LightGCN(nn.Module):
    """LightGCN model with feature integration.

    All entities share one embedding table indexed by global node id.
    ``node_offsets`` gives the start of each entity's slice (keys
    ``'playlists'``, ``'tracks'``, ``'users'`` are read).

    When side features are given, they are projected to ``embedding_dim`` by a
    small MLP. The result is blended into that entity's slice as
    ``0.9 * embedding + 0.1 * projected``. Propagation then applies
    ``n_layers`` rounds of ``adj_matrix @ embeddings`` (dropout optional), and
    the output is the mean of the initial embeddings and every propagated layer.

    Args:
        total_nodes: Number of rows in the shared embedding table.
        playlist_count, track_count, user_count: Length of each entity slice
            that the corresponding features are blended into.
        embedding_dim: Embedding width.
        n_layers: Number of propagation rounds.
        playlist_features, track_features, user_features: Optional feature
            tensors whose dim 1 is the feature width. These are presumably
            ``(entity_count, feature_dim)`` floats; verify against callers.
            ``None`` disables that entity's feature path.
        dropout: Dropout probability for propagated embeddings and inside the
            feature MLPs.
        node_offsets: Start index of each entity type in the global id space.
    """

    def __init__(self, total_nodes, playlist_count, track_count, user_count, embedding_dim, n_layers,
                 playlist_features=None, track_features=None, user_features=None, dropout=0.0, node_offsets=None):
        super(LightGCN, self).__init__()

        self.total_nodes = total_nodes
        self.playlist_count = playlist_count
        self.track_count = track_count
        self.user_count = user_count
        self.embedding_dim = embedding_dim
        self.n_layers = n_layers
        self.dropout = dropout
        self.node_offsets = node_offsets or {}

        # Core embeddings: one row per node in the shared index space.
        self.node_embedding = nn.Embedding(total_nodes, embedding_dim)

        # Feature integration
        self.use_playlist_features = playlist_features is not None
        self.use_track_features = track_features is not None
        self.use_user_features = user_features is not None

        # Feature tensors are registered as non-persistent buffers. This way
        # ``model.to(device)`` moves them together with the parameters (a plain
        # attribute would stay behind and break forward() on GPU). Being
        # non-persistent also keeps state_dict keys, and therefore existing
        # checkpoints, unchanged. Module creation order is unchanged, so
        # seeded weight initialization is identical.
        if self.use_playlist_features:
            self.register_buffer('playlist_features', playlist_features, persistent=False)
            self.playlist_feature_transform = self._create_feature_transform(playlist_features.size(1))

        if self.use_track_features:
            self.register_buffer('track_features', track_features, persistent=False)
            self.track_feature_transform = self._create_feature_transform(track_features.size(1))

        if self.use_user_features:
            self.register_buffer('user_features', user_features, persistent=False)
            self.user_feature_transform = self._create_feature_transform(user_features.size(1))

        self.dropout_layer = nn.Dropout(dropout)
        self._init_embeddings()

        print(f"   🧠 LightGCN initialized:")
        print(f"      📏 Nodes: {total_nodes:,}, Embedding: {embedding_dim}, Layers: {n_layers}")
        print(f"      🎨 Features: P:{self.use_playlist_features}, T:{self.use_track_features}, U:{self.use_user_features}")

    def _create_feature_transform(self, feature_dim):
        """Return an MLP mapping ``feature_dim`` -> ``embedding_dim`` (via ``embedding_dim // 2``)."""
        return nn.Sequential(
            nn.BatchNorm1d(feature_dim),
            nn.Linear(feature_dim, self.embedding_dim // 2),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.embedding_dim // 2, self.embedding_dim),
            nn.BatchNorm1d(self.embedding_dim)
        )

    def _init_embeddings(self):
        """Xavier-uniform initialise the embedding table, then scale it by 0.1."""
        nn.init.xavier_uniform_(self.node_embedding.weight)
        self.node_embedding.weight.data *= 0.1

    @staticmethod
    def _blend_features(embeddings, start, count, projected):
        """In place: rows ``[start, start + count)`` become ``0.9 * rows + 0.1 * projected``."""
        end = start + count
        embeddings[start:end] = 0.9 * embeddings[start:end] + 0.1 * projected

    def forward(self, adj_matrix):
        """Return final node embeddings of shape ``(total_nodes, embedding_dim)``.

        Args:
            adj_matrix: Sparse ``(total_nodes, total_nodes)`` tensor used with
                ``torch.sparse.mm``. It is moved to the embeddings' device if
                needed.
        """
        # Clone so the in-place feature blending does not touch the parameter.
        all_embeddings = self.node_embedding.weight.clone()

        # Integrate features
        if self.use_playlist_features:
            self._blend_features(
                all_embeddings,
                self.node_offsets.get('playlists', 0),
                self.playlist_count,
                self.playlist_feature_transform(self.playlist_features),
            )

        if self.use_track_features:
            self._blend_features(
                all_embeddings,
                self.node_offsets.get('tracks', self.playlist_count),
                self.track_count,
                self.track_feature_transform(self.track_features),
            )

        if self.use_user_features:
            # NOTE(review): the fallback offset 0 overlaps the playlist slice.
            # Callers presumably always pass node_offsets['users'].
            self._blend_features(
                all_embeddings,
                self.node_offsets.get('users', 0),
                self.user_count,
                self.user_feature_transform(self.user_features),
            )

        # The adjacency's device never changes between layers, so align it once.
        if adj_matrix.device != all_embeddings.device:
            adj_matrix = adj_matrix.to(all_embeddings.device)

        # Message passing layers
        embeddings_layers = [all_embeddings]
        current_embeddings = all_embeddings
        for _ in range(self.n_layers):
            current_embeddings = torch.sparse.mm(adj_matrix, current_embeddings)
            if self.dropout > 0:
                current_embeddings = self.dropout_layer(current_embeddings)
            embeddings_layers.append(current_embeddings)

        # Average the initial embeddings and all propagated layers.
        return torch.mean(torch.stack(embeddings_layers), dim=0)

    def predict(self, playlist_indices, track_indices, all_embeddings=None):
        """Return cosine similarities between paired playlist and track nodes.

        Args:
            playlist_indices, track_indices: Global node ids, paired element-wise.
            all_embeddings: Output of :meth:`forward` (required).

        Raises:
            ValueError: if ``all_embeddings`` is ``None``.
        """
        if all_embeddings is None:
            raise ValueError("all_embeddings must be provided")

        playlist_embs = all_embeddings[playlist_indices]
        track_embs = all_embeddings[track_indices]

        # Normalized cosine similarity
        playlist_embs = F.normalize(playlist_embs, p=2, dim=1)
        track_embs = F.normalize(track_embs, p=2, dim=1)

        scores = (playlist_embs * track_embs).sum(dim=1)
        return scores


# =============================================================================
# 6. DATA LOADING
# =============================================================================

class MusicRecommendationDataset(Dataset):
    """Observed playlist-track pairs with inverse-popularity negative sampling.

    Each item is one positive ``(playlist, track)`` pair plus
    ``num_neg_samples`` negative track nodes. Negatives are drawn with
    probability proportional to ``max_count - count + 1`` (``count`` = number
    of positive pairs for that track in this dataset), so less frequent tracks
    are drawn more often. Draws that hit one of the playlist's positives are
    rejected, for at most 100 draws per item. Any remaining slots are then
    filled uniformly at random, and those fallback draws are not filtered.

    Args:
        positive_edges: Iterable of ``(playlist_node, track_node)`` pairs. The
            playlist node is used as the playlist id; the track id is
            ``track_node - track_offset``. Pairs outside
            ``[0, playlist_count) x [0, track_count)`` are dropped.
        playlist_count: Number of playlists.
        track_count: Number of tracks.
        track_offset: Node index of track id 0.
        max_edges: When truthy and smaller than the number of kept pairs, a
            subset of this size is sampled without replacement.
        num_neg_samples: Negatives returned per item.
        seed: Seed for subsampling and negative sampling (42 if ``None``).
            When not ``None`` it is also passed to ``set_all_seeds``, which
            reseeds the global ``random``/NumPy/torch RNGs.
    """

    def __init__(self, positive_edges, playlist_count, track_count, track_offset,
                 max_edges=None, num_neg_samples=1, seed=None):
        self.seed = seed
        if seed is not None:
            set_all_seeds(seed)
        effective_seed = seed if seed is not None else 42

        # Keep only pairs whose ids fall inside the playlist/track ranges.
        self.positive_pairs = []
        for playlist_node, track_node in positive_edges:
            playlist_id = playlist_node
            track_id = track_node - track_offset
            if 0 <= playlist_id < playlist_count and 0 <= track_id < track_count:
                self.positive_pairs.append((playlist_id, track_id))

        # Subsample with a private RNG. RandomState(s).choice(...) draws exactly
        # what np.random.seed(s); np.random.choice(...) would, but without
        # overwriting the caller's global NumPy RNG state.
        if max_edges and len(self.positive_pairs) > max_edges:
            subsample_rng = np.random.RandomState(effective_seed)
            indices = subsample_rng.choice(len(self.positive_pairs), max_edges, replace=False)
            self.positive_pairs = [self.positive_pairs[i] for i in indices]

        self.playlist_count = playlist_count
        self.track_count = track_count
        self.track_offset = track_offset
        self.num_neg_samples = num_neg_samples

        # Tracks that are positives for each playlist (rejected as negatives).
        self.user_items = defaultdict(set)
        for playlist_id, track_id in self.positive_pairs:
            self.user_items[playlist_id].add(track_id)

        # Inverse-popularity distribution over all track ids.
        track_popularity = Counter(track_id for _, track_id in self.positive_pairs)
        track_counts = [track_popularity.get(track_id, 0) for track_id in range(track_count)]
        max_count = max(track_counts) if track_counts else 1
        inv_popularity = [max_count - count + 1 for count in track_counts]

        self.neg_sampling_probs = np.array(inv_popularity, dtype=float)
        self.neg_sampling_probs = self.neg_sampling_probs / self.neg_sampling_probs.sum()

        # RandomState.choice(n, p=p) rebuilds this normalised CDF on every call
        # (O(track_count) per draw). Caching it and using searchsorted consumes
        # the same uniform draw and yields the same track id.
        self._neg_sampling_cdf = self.neg_sampling_probs.cumsum()
        if self._neg_sampling_cdf.size:
            self._neg_sampling_cdf /= self._neg_sampling_cdf[-1]
        self.rng = np.random.RandomState(effective_seed)

        print(f"      📊 Dataset: {len(self.positive_pairs):,} positive pairs")

    def __len__(self):
        """Return the number of positive pairs."""
        return len(self.positive_pairs)

    def _draw_negative_track_id(self):
        """Draw one track id from the inverse-popularity distribution.

        Equivalent to ``self.rng.choice(self.track_count, p=self.neg_sampling_probs)``
        but without recomputing the cumulative distribution.
        """
        uniform = self.rng.random_sample()
        return int(self._neg_sampling_cdf.searchsorted(uniform, side='right'))

    def __getitem__(self, idx):
        """Return one positive pair and its sampled negatives as node indices.

        Returns:
            Dict with ``'playlist'`` and ``'pos_track'`` (1-element LongTensors)
            and ``'neg_tracks'`` (LongTensor of length ``num_neg_samples``).
        """
        playlist_id, track_id = self.positive_pairs[idx]

        playlist_node = playlist_id
        track_node = track_id + self.track_offset
        known_positives = self.user_items[playlist_id]

        # Rejection sampling against the playlist's positives, capped at 100 draws.
        neg_tracks = []
        max_attempts = 100
        attempts = 0
        while len(neg_tracks) < self.num_neg_samples and attempts < max_attempts:
            neg_track_id = self._draw_negative_track_id()
            if neg_track_id not in known_positives:
                neg_tracks.append(neg_track_id + self.track_offset)
            attempts += 1

        # Fill any remaining slots uniformly; these draws may hit positives.
        while len(neg_tracks) < self.num_neg_samples:
            neg_track_id = self.rng.randint(0, self.track_count)
            neg_tracks.append(neg_track_id + self.track_offset)

        return {
            'playlist': torch.LongTensor([playlist_node]),
            'pos_track': torch.LongTensor([track_node]),
            'neg_tracks': torch.LongTensor(neg_tracks)
        }


# =============================================================================
# 7. GRAPH BUILDING
# =============================================================================

class GraphBuilder:
    """Build symmetrically normalised adjacency matrices from typed edge lists.

    Args:
        data: Dataset dict. ``'entity_counts'``, ``'node_offsets'`` and
            ``'total_nodes'`` are read here; ``'edges'`` (edge type -> sequence
            of ``(src, dst)`` node pairs) is read by ``build_adjacency_matrix``.
        config: Experiment config (stored; not used by this class).
    """

    def __init__(self, data, config):
        self.data = data
        self.config = config
        self.entity_counts = data['entity_counts']
        self.node_offsets = data['node_offsets']
        self.total_nodes = data['total_nodes']

    def build_adjacency_matrix(self, edge_types, device, seed=None):
        """Build ``D^-1/2 (A + I) D^-1/2`` over the requested edge types.

        Every ``(src, dst)`` edge is inserted in both directions. Duplicate
        entries are summed by the COO -> CSR conversion, so an edge listed in
        both directions in the data gets weight 2.

        Args:
            edge_types: Keys of ``data['edges']`` to include; missing keys are
                ignored.
            device: Device for the returned tensor.
            seed: If not None, passed to ``set_all_seeds``.

        Returns:
            Coalesced ``torch`` sparse COO tensor of shape
            ``(total_nodes, total_nodes)``, or ``None`` if no edges were found.
        """
        if seed is not None:
            set_all_seeds(seed)

        print(f"   🔗 Building adjacency matrix: {edge_types}")

        src_parts = []
        dst_parts = []
        edge_count = 0

        for edge_type in edge_types:
            if edge_type not in self.data['edges']:
                continue
            edges = self.data['edges'][edge_type]
            print(f"      📊 Adding {edge_type}: {len(edges):,} edges")

            # (E, 2) array of node pairs; reshape keeps an empty list as (0, 2).
            pairs = np.asarray(edges, dtype=np.int64).reshape(-1, 2)
            # Add both directions so the graph is undirected.
            src_parts.extend([pairs[:, 0], pairs[:, 1]])
            dst_parts.extend([pairs[:, 1], pairs[:, 0]])
            edge_count += 2 * len(pairs)

        if edge_count == 0:
            print("      ❌ No valid edges found!")
            return None

        print(f"      ✅ Total edges: {edge_count:,}")

        row_indices = np.concatenate(src_parts)
        col_indices = np.concatenate(dst_parts)
        values = np.ones(len(row_indices), dtype=np.float32)
        adj_coo = coo_matrix(
            (values, (row_indices, col_indices)),
            shape=(self.total_nodes, self.total_nodes),
            dtype=np.float32
        )

        adj_normalized = self._symmetric_normalize(adj_coo.tocsr())
        adj_tensor = self._to_torch_sparse(adj_normalized, device)

        print(f"      ✅ Adjacency matrix: {adj_tensor.shape}, nnz: {adj_tensor._nnz()}")
        return adj_tensor

    def _symmetric_normalize(self, adj_matrix):
        """Return ``D^-1/2 (A + I) D^-1/2`` for a square scipy sparse matrix ``A``.

        ``D`` is the row-sum degree matrix of ``A + I``. Non-finite inverse
        square roots are zeroed.
        """
        n_nodes = adj_matrix.shape[0]
        diag_indices = np.arange(n_nodes)

        # Self-loops as a *sparse* identity: csr_matrix(np.eye(n)) would first
        # allocate a dense n x n array (O(n^2) memory). float64 matches the
        # dtype np.eye produced, so the numeric result is unchanged.
        identity = csr_matrix(
            (np.ones(n_nodes), (diag_indices, diag_indices)),
            shape=(n_nodes, n_nodes)
        )
        adj_matrix = adj_matrix + identity

        # Degrees and D^-1/2 (epsilon guards against division by zero).
        degrees = np.asarray(adj_matrix.sum(axis=1)).flatten()
        degrees_inv_sqrt = np.power(degrees + 1e-10, -0.5)
        degrees_inv_sqrt[~np.isfinite(degrees_inv_sqrt)] = 0.

        degree_matrix = csr_matrix(
            (degrees_inv_sqrt, (diag_indices, diag_indices)),
            shape=(n_nodes, n_nodes)
        )

        return degree_matrix @ adj_matrix @ degree_matrix

    def _to_torch_sparse(self, scipy_matrix, device):
        """Convert a scipy sparse matrix to a coalesced torch sparse COO tensor."""
        coo = scipy_matrix.tocoo()
        indices = torch.LongTensor(np.vstack([coo.row, coo.col]))
        values = torch.FloatTensor(coo.data)

        sparse_tensor = torch.sparse_coo_tensor(
            indices, values, coo.shape, device=device
        ).coalesce()

        return sparse_tensor


# =============================================================================
# 8. FEATURE PROCESSING
# =============================================================================

class FeatureProcessor:
    """Convert per-entity feature arrays into float tensors on a device.

    Args:
        data: Dataset dict. ``data['features']`` maps an entity name
            (``'playlist'``, ``'track'``, ``'user'``) to a feature array;
            ``data['entity_counts']`` is stored for reference.
        config: Experiment config (stored; not used by this class).
    """

    def __init__(self, data, config):
        self.data = data
        self.config = config
        self.features = data['features']
        self.entity_counts = data['entity_counts']

    def build_feature_matrices(self, feature_types, device, seed=None):
        """Build the playlist/track/user feature tensors named in ``feature_types``.

        Expected widths are 6 (playlist), 4 (track) and 4 (user) columns.

        Args:
            feature_types: Container checked with ``in`` for ``'playlist'``,
                ``'track'`` and ``'user'``.
            device: Device for the returned tensors.
            seed: If not None, passed to ``set_all_seeds``.

        Returns:
            ``(playlist, track, user)`` tensors; each is ``None`` when not
            requested, missing, or of unexpected shape.
        """
        if seed is not None:
            set_all_seeds(seed)

        if not feature_types:
            return None, None, None

        print(f"   🔧 Building features: {feature_types}")

        # (entity, expected column count) in build order.
        specs = (('playlist', 6), ('track', 4), ('user', 4))
        built = {}
        for entity, width in specs:
            if entity in feature_types:
                built[entity] = self._build_feature_matrix(entity, device, expected_dim=width)
            else:
                built[entity] = None

        return built['playlist'], built['track'], built['user']

    def _build_feature_matrix(self, feature_type, device, expected_dim):
        """Return ``features[feature_type]`` as a float tensor on ``device``.

        Returns ``None`` (after printing why) when the entry is missing or is
        not a 2-D array with ``expected_dim`` columns.
        """
        print(f"      🎨 Building {feature_type} features...")

        if feature_type in self.features:
            feature_array = self.features[feature_type]
            shape = feature_array.shape
            if len(shape) == 2 and shape[1] == expected_dim:
                feature_tensor = torch.FloatTensor(feature_array).to(device)
                print(f"      ✅ {feature_type.title()} features: {feature_tensor.shape}")
                return feature_tensor
            print(f"      ❌ Unexpected {feature_type} feature shape: {feature_array.shape}")
            return None

        print(f"      ❌ No {feature_type} features found")
        return None


# =============================================================================
# 9. TRAINING AND EVALUATION
# =============================================================================

class ModelTrainer:
    """Train LightGCN configurations and evaluate them with sampled ranking metrics.

    Args:
        config: Experiment config object. Attributes read by this class:
            ``device``, ``embedding_dim``, ``n_layers``, ``dropout``,
            ``learning_rate``, ``reg_weight``, ``batch_size``, ``epochs``,
            ``val_every``, ``early_stopping_patience``, ``max_train_edges``,
            ``num_neg_samples``, ``eval_sample_size``, ``k_values`` and
            ``result_saver``.
        data: Dataset dict with ``'entity_counts'`` (``'playlists'``,
            ``'tracks'``, ``'users'``), ``'node_offsets'`` (including
            ``'tracks'``), ``'total_nodes'`` and ``'splits'``
            (``'train_edges'``, ``'val_edges'``, ``'test_edges'``). It is also
            handed to ``GraphBuilder`` and ``FeatureProcessor``.
    """

    def __init__(self, config, data):
        self.config = config
        self.data = data
        self.device = config.device

        # Extract data info
        self.entity_counts = data['entity_counts']
        self.node_offsets = data['node_offsets']
        self.total_nodes = data['total_nodes']
        self.playlist_count = self.entity_counts['playlists']
        self.track_count = self.entity_counts['tracks']
        self.user_count = self.entity_counts['users']
        # Playlist nodes are used as playlist ids directly; track ids are
        # track node ids minus this offset.
        self.track_offset = self.node_offsets['tracks']

        print(f"🎯 Model Trainer initialized:")
        print(f"   👥 Playlists: {self.playlist_count:,}")
        print(f"   🎵 Tracks: {self.track_count:,}")
        print(f"   👤 Users: {self.user_count:,}")

    def train_model(self, config_spec, seed=None, save_model=True):
        """Build graph and features for one configuration, train, and optionally save.

        Args:
            config_spec: Dict with ``'name'``, ``'description'`` and
                ``'edge_types'``; optional ``'use_features'`` and
                ``'feature_types'``.
            seed: If not None, applied with ``set_all_seeds`` and forwarded to
                the graph, feature and dataset builders.
            save_model: Save a checkpoint through ``config.result_saver``.

        Returns:
            Dict with the trained ``'model'``, ``'adj_matrix'``, training and
            validation loss curves, timing, seed, config and checkpoint path;
            or ``None`` if the adjacency matrix is empty or any step raises
            (the traceback is printed).
        """
        print(f"\n🚀 Training: {config_spec['name']} (seed={seed})")
        print(f"   📝 {config_spec['description']}")

        if seed is not None:
            set_all_seeds(seed)

        start_time = time.time()
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        try:
            # Build graph
            graph_builder = GraphBuilder(self.data, self.config)
            adj_matrix = graph_builder.build_adjacency_matrix(
                config_spec['edge_types'], self.device, seed=seed
            )

            if adj_matrix is None:
                print("   ❌ Failed to build adjacency matrix")
                return None

            # Build features
            playlist_features = None
            track_features = None
            user_features = None

            if config_spec.get('use_features', False) and config_spec.get('feature_types'):
                feature_processor = FeatureProcessor(self.data, self.config)
                playlist_features, track_features, user_features = feature_processor.build_feature_matrices(
                    config_spec['feature_types'], self.device, seed=seed
                )

            # Initialize model
            model = LightGCN(
                total_nodes=self.total_nodes,
                playlist_count=self.playlist_count,
                track_count=self.track_count,
                user_count=self.user_count,
                embedding_dim=self.config.embedding_dim,
                n_layers=self.config.n_layers,
                playlist_features=playlist_features,
                track_features=track_features,
                user_features=user_features,
                dropout=self.config.dropout,
                node_offsets=self.node_offsets
            ).to(self.device)

            # Train
            train_result = self._train_loop(model, adj_matrix, seed)
            if train_result is None:
                return None

            training_time = time.time() - start_time

            # Save model if requested
            model_file = None
            if save_model:
                model_file = self.config.result_saver.save_model_checkpoint(
                    model, config_spec['name'], seed, timestamp
                )

            return {
                'model': model,
                'adj_matrix': adj_matrix,
                'training_losses': train_result['losses'],
                'validation_losses': train_result.get('val_losses', []),
                'training_time': training_time,
                'final_loss': train_result['final_loss'],
                'best_val_loss': train_result.get('best_val_loss', float('inf')),
                'seed': seed,
                'config': config_spec,
                'model_file': model_file,
                'timestamp': timestamp
            }

        except Exception as e:
            # Top-level boundary for one configuration: report and let the
            # caller move on to the next one.
            print(f"   ❌ Training failed: {e}")
            import traceback
            traceback.print_exc()
            return None

    def _unpack_batch(self, batch):
        """Move a collated batch to the device as node-index tensors.

        ``view(-1)`` (not ``squeeze()``) keeps playlists/positives 1-D even for
        a batch of size one. The validation loader keeps its last partial
        batch, and ``squeeze()`` would make that batch a 0-d tensor, on which
        ``size(0)`` fails.

        Returns:
            ``(playlists, pos_tracks, neg_tracks)``, or ``None`` if any index
            is ``>= total_nodes``.
        """
        playlists = batch['playlist'].view(-1).to(self.device)
        pos_tracks = batch['pos_track'].view(-1).to(self.device)
        neg_tracks = batch['neg_tracks'].to(self.device)

        if (playlists.max() >= self.total_nodes or
                pos_tracks.max() >= self.total_nodes or
                neg_tracks.max() >= self.total_nodes):
            return None
        return playlists, pos_tracks, neg_tracks

    def _score_batch(self, model, all_embeddings, playlists, pos_tracks, neg_tracks):
        """Score a batch with ``model.predict``.

        Returns:
            ``(pos_scores, neg_scores)`` where ``neg_scores`` is reshaped to
            ``(batch_size, neg_samples)``.
        """
        pos_scores = model.predict(playlists, pos_tracks, all_embeddings)

        batch_size = playlists.size(0)
        neg_samples = neg_tracks.size(1)
        # Repeat each playlist once per negative so the flattened pairs align.
        playlists_expanded = playlists.unsqueeze(1).expand(-1, neg_samples).contiguous().view(-1)
        neg_tracks_flat = neg_tracks.view(-1)
        neg_scores = model.predict(playlists_expanded, neg_tracks_flat, all_embeddings)
        return pos_scores, neg_scores.view(batch_size, neg_samples)

    def _train_loop(self, model, adj_matrix, seed):
        """Train with BPR loss, periodic validation, LR plateau decay and early stopping.

        NOTE(review): the weights from the best validation epoch are not
        restored; the model is left in its final-epoch state.

        Args:
            model: Model to train in place.
            adj_matrix: Normalised sparse adjacency passed to ``model(...)``.
            seed: Seed for the datasets and the shuffling generator (may be None).

        Returns:
            Dict with ``'losses'`` (mean training loss per epoch),
            ``'val_losses'``, ``'final_loss'`` and ``'best_val_loss'``.
        """
        # Prepare datasets
        train_dataset = MusicRecommendationDataset(
            self.data['splits']['train_edges'],
            self.playlist_count, self.track_count, self.track_offset,
            max_edges=self.config.max_train_edges,
            num_neg_samples=self.config.num_neg_samples, seed=seed
        )

        # Distinct, reproducible validation seed. Compare against None so that
        # seed=0 yields 1 instead of falling through to the default of 43.
        val_seed = seed + 1 if seed is not None else 43
        val_dataset = MusicRecommendationDataset(
            self.data['splits']['val_edges'],
            self.playlist_count, self.track_count, self.track_offset,
            max_edges=5000, num_neg_samples=self.config.num_neg_samples,
            seed=val_seed
        )

        # Create data loaders
        generator = torch.Generator()
        if seed is not None:
            generator.manual_seed(seed)

        train_loader = DataLoader(
            train_dataset, batch_size=self.config.batch_size, shuffle=True,
            num_workers=0, worker_init_fn=seed_worker, generator=generator, drop_last=True
        )

        val_loader = DataLoader(
            val_dataset, batch_size=self.config.batch_size, shuffle=False,
            num_workers=0, drop_last=False
        )

        # Optimizer and scheduler
        optimizer = optim.AdamW(
            model.parameters(), lr=self.config.learning_rate,
            weight_decay=self.config.reg_weight, betas=(0.9, 0.999), eps=1e-8
        )

        # `verbose` is not passed: False is its default and the argument is
        # deprecated in recent PyTorch releases.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.7, patience=8, min_lr=1e-6
        )

        # Training variables
        training_losses = []
        validation_losses = []
        best_val_loss = float('inf')
        patience_counter = 0

        print(f"   🏃 Training: {len(train_dataset):,} samples, validating: {len(val_dataset):,}")

        for epoch in range(self.config.epochs):
            # Training phase
            model.train()
            epoch_loss = 0
            num_batches = 0

            for batch in train_loader:
                optimizer.zero_grad()

                unpacked = self._unpack_batch(batch)
                if unpacked is None:
                    continue
                playlists, pos_tracks, neg_tracks = unpacked

                # Parameters change after every step, so propagate per batch.
                all_embeddings = model(adj_matrix)
                pos_scores, neg_scores = self._score_batch(
                    model, all_embeddings, playlists, pos_tracks, neg_tracks
                )

                # BPR loss
                loss = self._bpr_loss(pos_scores, neg_scores)

                # Explicit penalty: sum of (un-squared) L2 norms of parameters
                # whose names contain 'embedding' or 'transform'.
                # NOTE(review): AdamW above already applies
                # weight_decay=reg_weight, so these parameters are regularised
                # twice — confirm this is intended.
                reg_loss = 0
                for name, param in model.named_parameters():
                    if 'embedding' in name or 'transform' in name:
                        reg_loss += torch.norm(param, p=2)

                total_loss = loss + self.config.reg_weight * reg_loss

                if torch.isnan(total_loss) or torch.isinf(total_loss):
                    continue

                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                optimizer.step()

                epoch_loss += total_loss.item()
                num_batches += 1

            if num_batches == 0:
                break

            avg_train_loss = epoch_loss / num_batches
            training_losses.append(avg_train_loss)

            # Validation phase
            if (epoch + 1) % self.config.val_every == 0:
                model.eval()
                val_loss = 0
                val_batches = 0

                with torch.no_grad():
                    # No parameter updates happen here, so one propagation
                    # serves every validation batch (assumes forward() is
                    # deterministic in eval mode, i.e. dropout is off).
                    all_embeddings = model(adj_matrix)
                    for batch in val_loader:
                        unpacked = self._unpack_batch(batch)
                        if unpacked is None:
                            continue

                        pos_scores, neg_scores = self._score_batch(model, all_embeddings, *unpacked)
                        batch_val_loss = self._bpr_loss(pos_scores, neg_scores)
                        val_loss += batch_val_loss.item()
                        val_batches += 1

                if val_batches > 0:
                    avg_val_loss = val_loss / val_batches
                    validation_losses.append(avg_val_loss)
                    scheduler.step(avg_val_loss)

                    # Early stopping
                    if avg_val_loss < best_val_loss:
                        best_val_loss = avg_val_loss
                        patience_counter = 0
                    else:
                        patience_counter += 1

                    if patience_counter >= self.config.early_stopping_patience:
                        print(f"      ⏰ Early stopping at epoch {epoch + 1}")
                        break

                    if (epoch + 1) % 50 == 0:
                        print(f"      Epoch {epoch + 1}: Train={avg_train_loss:.4f}, Val={avg_val_loss:.4f}")

        return {
            'losses': training_losses,
            'val_losses': validation_losses,
            'final_loss': training_losses[-1] if training_losses else float('inf'),
            'best_val_loss': best_val_loss
        }

    def _bpr_loss(self, pos_scores, neg_scores):
        """Bayesian Personalized Ranking loss.

        Args:
            pos_scores: 1-D tensor, one score per example.
            neg_scores: ``(batch, n_neg)`` tensor of negative scores.

        Returns:
            Mean of ``-log(sigmoid(pos - neg) + 1e-8)``, with the score
            difference clamped to [-15, 15] for numerical stability.
        """
        pos_scores_expanded = pos_scores.unsqueeze(1)
        diff = pos_scores_expanded - neg_scores
        diff = torch.clamp(diff, min=-15, max=15)
        loss = -torch.log(torch.sigmoid(diff) + 1e-8).mean()
        return loss

    def evaluate_model(self, model, adj_matrix, split='test', seed=None):
        """Evaluate on the test split (``split='test'``) or otherwise the validation split.

        The model is put in eval mode for scoring and switched back to train
        mode afterwards.

        Returns:
            The metrics dict from ``_calculate_metrics``.
        """
        if seed is not None:
            set_all_seeds(seed)

        print(f"   📊 Evaluating on {split} set...")

        model.eval()
        with torch.no_grad():
            all_embeddings = model(adj_matrix)

            if split == 'test':
                test_edges = self.data['splits']['test_edges']
            else:
                test_edges = self.data['splits']['val_edges']

            metrics = self._calculate_metrics(all_embeddings, test_edges, seed)

        model.train()
        return metrics

    def _calculate_metrics(self, all_embeddings, test_edges, seed):
        """Compute sampled Precision/Recall/NDCG@K and AUC over held-out edges.

        Each evaluated playlist's held-out tracks are ranked against 99
        negatives sampled from tracks not held out for that playlist. Only
        playlists with at least two held-out tracks are evaluated; if there are
        more than ``config.eval_sample_size`` of them, that many are sampled
        without replacement. Scores are cosine similarities of embeddings.

        NOTE(review): negatives are not filtered against the playlist's
        training tracks, so a training positive can be scored as a negative.

        Args:
            all_embeddings: Node embedding matrix indexed by node id.
            test_edges: Iterable of ``(playlist_node, track_node)`` pairs.
            seed: If not None, reseeds the global NumPy RNG used for sampling.

        Returns:
            Dict with ``'precision@K'``, ``'recall@K'`` and ``'ndcg@K'`` for
            each K in ``config.k_values`` plus ``'auc'``; all 0.0 when no
            playlist could be evaluated.
        """
        if seed is not None:
            np.random.seed(seed)

        # Convert test edges to playlist-track pairs
        test_pairs = []
        for edge in test_edges:
            playlist_node, track_node = edge
            playlist_id = playlist_node
            track_id = track_node - self.track_offset

            if 0 <= playlist_id < self.playlist_count and 0 <= track_id < self.track_count:
                test_pairs.append((playlist_id, track_id))

        # Group by playlist
        playlist_test_tracks = defaultdict(set)
        for playlist_id, track_id in test_pairs:
            playlist_test_tracks[playlist_id].add(track_id)

        # Filter valid playlists
        valid_playlists = [p for p, tracks in playlist_test_tracks.items()
                          if len(tracks) >= 2 and p < self.playlist_count]

        if len(valid_playlists) > self.config.eval_sample_size:
            valid_playlists = np.random.choice(
                valid_playlists, self.config.eval_sample_size, replace=False
            )

        # Calculate metrics
        all_precisions = {k: [] for k in self.config.k_values}
        all_recalls = {k: [] for k in self.config.k_values}
        all_ndcgs = {k: [] for k in self.config.k_values}
        all_auc_scores = []
        valid_evaluations = 0

        # Loop-invariant: the full track-id set (previously rebuilt per playlist).
        all_tracks = set(range(self.track_count))

        for playlist_id in valid_playlists:
            pos_tracks = list(playlist_test_tracks[playlist_id])
            if len(pos_tracks) < 2:
                continue

            # Sample negative tracks
            available_negatives = list(all_tracks - playlist_test_tracks[playlist_id])

            if len(available_negatives) >= 99:
                neg_tracks = np.random.choice(available_negatives, 99, replace=False)
            else:
                continue

            # Calculate scores
            all_candidate_tracks = pos_tracks + list(neg_tracks)
            playlist_emb = all_embeddings[playlist_id]
            track_nodes = [tid + self.track_offset for tid in all_candidate_tracks]
            track_embs = all_embeddings[track_nodes]

            # Normalized scores
            playlist_emb_norm = F.normalize(playlist_emb.unsqueeze(0), p=2, dim=1)
            track_embs_norm = F.normalize(track_embs, p=2, dim=1)
            scores = torch.matmul(playlist_emb_norm, track_embs_norm.t()).squeeze().cpu().numpy()

            # Ground truth
            ground_truth = np.array([1] * len(pos_tracks) + [0] * len(neg_tracks))

            # AUC (a failure here skips this playlist entirely)
            try:
                if len(np.unique(ground_truth)) > 1 and len(np.unique(scores)) > 1:
                    auc = roc_auc_score(ground_truth, scores)
                    all_auc_scores.append(auc)
            except Exception:
                continue

            # Ranking metrics
            sorted_indices = np.argsort(scores)[::-1]
            sorted_tracks = [all_candidate_tracks[i] for i in sorted_indices]
            relevant_tracks = set(pos_tracks)

            for k in self.config.k_values:
                if k > len(sorted_tracks):
                    continue

                top_k_tracks = sorted_tracks[:k]
                recommended_relevant = set(top_k_tracks) & relevant_tracks

                # Precision@K
                precision = len(recommended_relevant) / k if k > 0 else 0
                all_precisions[k].append(precision)

                # Recall@K
                recall = len(recommended_relevant) / len(relevant_tracks) if len(relevant_tracks) > 0 else 0
                all_recalls[k].append(recall)

                # NDCG@K
                ndcg = self._calculate_ndcg(top_k_tracks, relevant_tracks, k)
                all_ndcgs[k].append(ndcg)

            valid_evaluations += 1

        # Aggregate metrics
        if valid_evaluations == 0:
            print("      ❌ No valid evaluations!")
            metrics = {f'{metric}@{k}': 0.0 for metric in ['precision', 'recall', 'ndcg'] for k in self.config.k_values}
            # Same keys as the success path so callers can read 'auc' safely.
            metrics['auc'] = 0.0
            return metrics

        metrics = {}
        for k in self.config.k_values:
            metrics[f'precision@{k}'] = np.mean(all_precisions[k]) if all_precisions[k] else 0.0
            metrics[f'recall@{k}'] = np.mean(all_recalls[k]) if all_recalls[k] else 0.0
            metrics[f'ndcg@{k}'] = np.mean(all_ndcgs[k]) if all_ndcgs[k] else 0.0

        metrics['auc'] = np.mean(all_auc_scores) if all_auc_scores else 0.0

        print(f"      ✅ Evaluated {valid_evaluations} playlists")
        print(f"      📊 NDCG@10: {metrics.get('ndcg@10', 0):.4f}, AUC: {metrics.get('auc', 0):.4f}")

        return metrics

    def _calculate_ndcg(self, ranked_list, relevant_items, k):
        """Binary-relevance NDCG@K of ``ranked_list`` against ``relevant_items``.

        Returns 0.0 when ``k`` is 0 or there are no relevant items.
        """
        if k == 0 or not relevant_items:
            return 0.0

        # DCG
        dcg = 0.0
        for i, item in enumerate(ranked_list[:k]):
            if item in relevant_items:
                dcg += 1.0 / math.log2(i + 2)

        # IDCG: all relevant items ranked first.
        idcg = 0.0
        for i in range(min(k, len(relevant_items))):
            idcg += 1.0 / math.log2(i + 2)

        return dcg / idcg if idcg > 0 else 0.0


# =============================================================================
# 10. EXPERIMENT ORCHESTRATION
# =============================================================================

class ExperimentRunner:
    """
    Main experiment orchestrator with comprehensive result saving.

    For every configuration returned by ``config.get_experiment_configs()``
    the runner trains and evaluates one model per seed in
    ``config.random_seeds`` (through ``ModelTrainer``), aggregates per-metric
    statistics across seeds, runs pairwise significance tests on NDCG@10 and
    persists the artifacts through ``config.result_saver``.
    """

    def __init__(self, config, data):
        """
        Args:
            config: Experiment configuration; seeds, hyper-parameters,
                ``results_dir`` and ``result_saver`` are read from it.
            data: Processed dataset dict; ``entity_counts``, ``node_offsets``
                and ``total_nodes`` are read by this class.
        """
        self.config = config
        self.data = data
        self.trainer = ModelTrainer(config, data)
        # Shared tag embedded in the names of the files saved for this run
        self.experiment_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        print("🎯 EXPERIMENT RUNNER INITIALIZED")
        print(f"   📊 Dataset: {data['entity_counts']['playlists']:,} playlists")
        print(f"   🎲 Seeds per config: {len(config.random_seeds)}")
        print(f"   📁 Results dir: {config.results_dir}")
        print(f"   🕐 Timestamp: {self.experiment_timestamp}")

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _metric_stat(result, stat, metric='ndcg@10', default=0.0):
        """Return ``result['statistics'][metric][stat]``, or ``default`` if any level is missing."""
        return result.get('statistics', {}).get(metric, {}).get(stat, default)

    @staticmethod
    def _percent_change(value, reference):
        """
        Percentage change of ``value`` relative to ``reference``.

        A non-positive ``reference`` yields 0.0 instead of a division by zero
        (NDCG is non-negative, so this only triggers for an all-zero reference).
        """
        if reference > 0:
            return (value - reference) / reference * 100
        return 0.0

    @staticmethod
    def _significance_stars(p_value):
        """Map a p-value to '***' (<0.001), '**' (<0.01), '*' (<0.05) or 'n.s.' (NaN -> 'n.s.')."""
        if p_value < 0.001:
            return "***"
        if p_value < 0.01:
            return "**"
        if p_value < 0.05:
            return "*"
        return "n.s."

    @staticmethod
    def _interpret_effect_size(cohens_d):
        """
        Label |Cohen's d| using Cohen's conventional benchmarks.

        < 0.2 -> "negligible", [0.2, 0.5) -> "small", [0.5, 0.8) -> "medium",
        >= 0.8 -> "large"; NaN -> "unknown".
        """
        magnitude = abs(cohens_d)
        if np.isnan(magnitude):
            return "unknown"
        if magnitude < 0.2:
            return "negligible"
        if magnitude < 0.5:
            return "small"
        if magnitude < 0.8:
            return "medium"
        return "large"

    @staticmethod
    def _json_default(obj):
        """``json.dump`` fallback converting NumPy scalars/arrays to native Python types."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    # ------------------------------------------------------------------
    # Orchestration
    # ------------------------------------------------------------------

    def run_all_experiments(self):
        """
        Run every configured experiment and save all results.

        Returns:
            Dict mapping configuration name -> ``{'config', 'runs', 'statistics'}``
            for each configuration with at least one successful seed (may be empty).
        """
        print("\n" + "="*80)
        print("🔬 STARTING LIGHTGCN EXPERIMENTS WITH RESULT SAVING")
        print("="*80)

        experiment_configs = self.config.get_experiment_configs()
        results = {}

        print(f"\n🔬 Running {len(experiment_configs)} configurations:")
        for name, spec in experiment_configs.items():
            print(f"   • {name}: {spec['description']}")

        # Save the experiment configuration before any training starts
        self._save_experiment_config(experiment_configs)

        for config_name, config_spec in experiment_configs.items():
            print(f"\n{'='*60}")
            print(f"🧪 Configuration: {config_spec['name']}")
            print(f"📝 {config_spec['description']}")

            config_results = self._run_single_config(config_name, config_spec)

            if config_results:
                results[config_name] = config_results
                # Persist immediately so a later failure does not lose finished configs
                self.config.result_saver.save_config_results(
                    config_name, config_results, self.experiment_timestamp
                )

        if results:
            self.config.result_saver.save_complete_results(results, self.experiment_timestamp)
            self._print_final_results(results)
            analysis_results = self._perform_statistical_analysis(results)
            self.config.result_saver.save_analysis_results(analysis_results, self.experiment_timestamp)
            self.config.result_saver.create_experiment_summary(results, self.experiment_timestamp)

            print(f"\n🎉 ALL RESULTS SAVED TO: {self.config.results_dir}")
            print(f"📁 Experiment timestamp: {self.experiment_timestamp}")

        return results

    def _save_experiment_config(self, experiment_configs):
        """
        Write the experiment specs, key settings and dataset info to
        ``experiment_config_<timestamp>.json`` in ``config.results_dir``.

        NumPy scalars/arrays are converted to native types; any other
        non-JSON type still raises ``TypeError``.
        """
        config_info = {
            'timestamp': self.experiment_timestamp,
            'configurations': experiment_configs,
            'settings': {
                'random_seeds': self.config.random_seeds,
                'target_playlists': self.config.target_playlists,
                'embedding_dim': self.config.embedding_dim,
                'n_layers': self.config.n_layers,
                'learning_rate': self.config.learning_rate,
                'epochs': self.config.epochs,
                'batch_size': self.config.batch_size,
                'k_values': self.config.k_values
            },
            'dataset_info': {
                'entity_counts': self.data['entity_counts'],
                'node_offsets': self.data['node_offsets'],
                'total_nodes': self.data['total_nodes']
            }
        }

        config_file = os.path.join(self.config.results_dir, f"experiment_config_{self.experiment_timestamp}.json")
        with open(config_file, 'w', encoding='utf-8') as f:
            json.dump(config_info, f, indent=2, default=self._json_default)
        print(f"💾 Experiment configuration saved: {os.path.basename(config_file)}")

    def _run_single_config(self, config_name, config_spec):
        """
        Train and test one configuration for every seed.

        Seeds whose training returns ``None`` are skipped. Each successful run
        is saved through ``result_saver.save_training_run``.

        Returns:
            ``{'config', 'runs', 'statistics'}`` or ``None`` if every seed failed.
        """
        config_results = []

        for seed_idx, seed in enumerate(self.config.random_seeds):
            print(f"\n🎲 Seed {seed_idx+1}/{len(self.config.random_seeds)} (seed={seed}):")

            training_result = self.trainer.train_model(config_spec, seed=seed, save_model=True)

            if training_result is None:
                print(f"   ❌ Training failed for seed {seed}")
                continue

            test_metrics = self.trainer.evaluate_model(
                training_result['model'],
                training_result['adj_matrix'],
                'test',
                seed=seed
            )

            run_result = {
                'seed': seed,
                'metrics': test_metrics,
                'training_time': training_result['training_time'],
                'final_loss': training_result['final_loss'],
                'best_val_loss': training_result['best_val_loss'],
                'model_file': training_result.get('model_file'),
                'timestamp': training_result.get('timestamp')
            }
            config_results.append(run_result)

            self.config.result_saver.save_training_run(
                config_name, seed, training_result, test_metrics,
                training_result.get('timestamp')
            )

            print(f"      📊 NDCG@10: {test_metrics.get('ndcg@10', 0):.4f}")
            print(f"      📊 AUC: {test_metrics.get('auc', 0):.4f}")
            print(f"      ⏱️ Time: {training_result['training_time']:.1f}s")

            # Drop the model/adjacency references before the next seed to cap peak memory
            del training_result
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        if not config_results:
            return None

        stats = self._calculate_statistics(config_results)
        config_data = {
            'config': config_spec,
            'runs': config_results,
            'statistics': stats
        }

        ndcg_mean = stats.get('ndcg@10', {}).get('mean', 0)
        ndcg_std = stats.get('ndcg@10', {}).get('std', 0)
        print(f"\n   📊 CONFIGURATION SUMMARY: {config_spec['name']}")
        print(f"      NDCG@10: {ndcg_mean:.4f} ± {ndcg_std:.4f}")
        print(f"      ✅ Results saved for {config_name}")

        return config_data

    def _calculate_statistics(self, config_results):
        """
        Aggregate every metric across the seed runs of one configuration.

        Args:
            config_results: List of run dicts, each holding a ``'metrics'`` dict.

        Returns:
            Dict metric -> {'mean', 'std', 'min', 'max', 'ci_lower', 'ci_upper', 'n'}.
            ``std`` uses ddof=1 and the CI is a two-sided 95% Student-t interval.
            A metric missing from a run counts as 0. Empty input returns ``{}``.
        """
        if not config_results:
            return {}

        statistics = {}

        all_metrics = set()
        for run in config_results:
            all_metrics.update(run['metrics'].keys())

        for metric in all_metrics:
            values = [run['metrics'].get(metric, 0) for run in config_results]
            n = len(values)

            if n == 1:
                only = values[0]
                statistics[metric] = {
                    'mean': only,
                    'std': 0.0,
                    'min': only,
                    'max': only,
                    'ci_lower': only,
                    'ci_upper': only,
                    'n': 1
                }
                continue

            mean_val = np.mean(values)
            std_val = np.std(values, ddof=1)
            # The Student-t interval is defined for any n >= 2 (df = n - 1), so
            # two seeds still get a real (wide) interval rather than a point.
            t_value = scipy_stats.t.ppf(0.975, n - 1)
            margin_error = t_value * std_val / np.sqrt(n)

            statistics[metric] = {
                'mean': mean_val,
                'std': std_val,
                'min': np.min(values),
                'max': np.max(values),
                'ci_lower': mean_val - margin_error,
                'ci_upper': mean_val + margin_error,
                'n': n
            }

        return statistics

    # ------------------------------------------------------------------
    # Reporting
    # ------------------------------------------------------------------

    def _print_final_results(self, results):
        """Print a table of all configurations sorted by mean NDCG@10, then the best one."""
        print("\n" + "="*80)
        print("📊 EXPERIMENTAL RESULTS")
        print("="*80)

        if not results:
            print("❌ No results to display")
            return

        print("\nConfiguration                    NDCG@10 (Mean±Std)      AUC (Mean±Std)       Time (s)")
        print("-" * 85)

        sorted_results = sorted(results.items(),
                                key=lambda item: self._metric_stat(item[1], 'mean'),
                                reverse=True)

        for config_name, result in sorted_results:
            ndcg_mean = self._metric_stat(result, 'mean')
            ndcg_std = self._metric_stat(result, 'std')
            auc_mean = self._metric_stat(result, 'mean', metric='auc')
            auc_std = self._metric_stat(result, 'std', metric='auc')
            time_mean = np.mean([run['training_time'] for run in result['runs']]) if result['runs'] else 0

            print(f"{config_name:<30} {ndcg_mean:.4f}±{ndcg_std:.4f}        {auc_mean:.4f}±{auc_std:.4f}       {time_mean:.1f}")

        best_name, best_result = sorted_results[0]
        print(f"\n🏆 BEST CONFIGURATION: {best_name}")
        print(f"   📊 NDCG@10: {self._metric_stat(best_result, 'mean'):.4f} ± {self._metric_stat(best_result, 'std'):.4f}")
        print(f"   🔍 95% CI: [{self._metric_stat(best_result, 'ci_lower'):.4f}, {self._metric_stat(best_result, 'ci_upper'):.4f}]")

        self._print_performance_insights(results)

    def _print_performance_insights(self, results):
        """Compare the best configuration with ``'features'`` in its name against the best one without."""
        print(f"\n💡 PERFORMANCE INSIGHTS:")

        feature_configs = [name for name in results if 'features' in name]
        non_feature_configs = [name for name in results if 'features' not in name]

        if not (feature_configs and non_feature_configs):
            return

        def ndcg_mean(name):
            return self._metric_stat(results[name], 'mean')

        best_with_features = max(feature_configs, key=ndcg_mean)
        best_without_features = max(non_feature_configs, key=ndcg_mean)

        feat_performance = ndcg_mean(best_with_features)
        no_feat_performance = ndcg_mean(best_without_features)

        if no_feat_performance > 0:
            improvement = self._percent_change(feat_performance, no_feat_performance)
            print(f"   🎯 Feature Impact: {improvement:+.1f}% improvement")
        else:
            print("   🎯 Feature Impact: n/a (best non-feature NDCG@10 is 0)")
        print(f"      Best w/ features: {best_with_features} ({feat_performance:.4f})")
        print(f"      Best w/o features: {best_without_features} ({no_feat_performance:.4f})")

    # ------------------------------------------------------------------
    # Statistical analysis
    # ------------------------------------------------------------------

    def _perform_statistical_analysis(self, results):
        """
        Run pairwise independent two-sample t-tests on per-seed NDCG@10.

        Each pair also gets Cohen's d (pooled std). Pairs with p < 0.05 are listed
        under ``'significant_differences'``. Feature-importance and
        graph-structure summaries are attached as well.

        Returns:
            Dict with keys ``timestamp``, ``pairwise_comparisons``,
            ``significant_differences``, ``feature_importance_analysis`` and
            ``graph_structure_analysis``.
        """
        print("\n" + "="*80)
        print("🔬 STATISTICAL SIGNIFICANCE ANALYSIS")
        print("="*80)

        analysis_results = {
            'timestamp': self.experiment_timestamp,
            'pairwise_comparisons': [],
            'significant_differences': [],
            'feature_importance_analysis': {},
            'graph_structure_analysis': {}
        }

        if len(results) < 2:
            print("❌ Need at least 2 configurations for statistical testing")
            return analysis_results

        config_values = {
            config_name: [run['metrics'].get('ndcg@10', 0) for run in result['runs']]
            for config_name, result in results.items()
        }
        config_names = list(config_values)
        significant_pairs = []

        print("\nPairwise analysis (NDCG@10):")
        print("Configuration 1          vs Configuration 2          p-value    Effect Size  Significant")
        print("-" * 95)

        for i, config1 in enumerate(config_names):
            for config2 in config_names[i + 1:]:
                values1 = config_values[config1]
                values2 = config_values[config2]

                comparison_result = {
                    'config1': config1,
                    'config2': config2,
                    'config1_mean': np.mean(values1) if values1 else 0,
                    'config2_mean': np.mean(values2) if values2 else 0
                }

                if len(values1) > 1 and len(values2) > 1:
                    t_stat, p_val = scipy_stats.ttest_ind(values1, values2)

                    # Cohen's d with the pooled (ddof=1) standard deviation
                    pooled_std = np.sqrt(((len(values1) - 1) * np.var(values1, ddof=1) +
                                          (len(values2) - 1) * np.var(values2, ddof=1)) /
                                         (len(values1) + len(values2) - 2))
                    cohens_d = abs(np.mean(values1) - np.mean(values2)) / pooled_std if pooled_std > 0 else 0

                    significance = self._significance_stars(p_val)
                    if p_val < 0.05:
                        significant_pairs.append((config1, config2, p_val, cohens_d))

                    effect_size = self._interpret_effect_size(cohens_d)
                    comparison_result.update({
                        'p_value': p_val,
                        'cohens_d': cohens_d,
                        't_statistic': t_stat,
                        'significance': significance,
                        'effect_size_interpretation': effect_size
                    })

                    print(f"{config1:<25} vs {config2:<25} {p_val:.4f}      {cohens_d:.3f}({effect_size:<10}) {significance}")
                else:
                    comparison_result.update({
                        'p_value': None,
                        'cohens_d': None,
                        't_statistic': None,
                        'significance': 'insufficient_data',
                        'effect_size_interpretation': 'unknown'
                    })
                    print(f"{config1:<25} vs {config2:<25} N/A        N/A            insufficient data")

                analysis_results['pairwise_comparisons'].append(comparison_result)

        if significant_pairs:
            print(f"\n📊 SIGNIFICANT DIFFERENCES FOUND:")
            for config1, config2, p_val, effect_size in significant_pairs:
                mean1 = np.mean(config_values[config1])
                mean2 = np.mean(config_values[config2])
                better = config1 if mean1 > mean2 else config2
                worse = config2 if mean1 > mean2 else config1
                # Relative gain of the better mean over the worse one (0.0 if the worse mean is 0)
                improvement = self._percent_change(max(mean1, mean2), min(mean1, mean2))

                analysis_results['significant_differences'].append({
                    'better_config': better,
                    'worse_config': worse,
                    'improvement_percent': improvement,
                    'p_value': p_val,
                    'effect_size': effect_size
                })

                print(f"   🔸 {better} > {worse}")
                print(f"      Improvement: {improvement:.1f}%, p={p_val:.4f}, effect size={effect_size:.3f}")
        else:
            print(f"\n📊 NO SIGNIFICANT DIFFERENCES FOUND")

        analysis_results['feature_importance_analysis'] = self._analyze_feature_importance(results)
        analysis_results['graph_structure_analysis'] = self._analyze_graph_structure(results)

        return analysis_results

    def _analyze_feature_importance(self, results):
        """
        Compare the Phase 2 feature configurations against ``'baseline'`` on NDCG@10.

        Improvements are percentages of the baseline mean (0.0 when the baseline
        mean is 0). The synergy effect is the combined (``all_features``) gain
        minus the sum of the three single-feature gains.

        Returns:
            Analysis dict. Contains an ``'error'`` key when no feature
            configuration or no baseline is present.
        """
        analysis = {
            'baseline_performance': None,
            'individual_feature_impacts': {},
            'combined_feature_effect': None,
            'feature_ranking': [],
            'synergy_analysis': None,
            'recommendations': {}
        }

        required_configs = ['playlist_features', 'track_features', 'user_features', 'all_features']
        available_configs = [config for config in required_configs if config in results]

        if not available_configs:
            analysis['error'] = "No feature configurations found"
            return analysis

        if 'baseline' not in results:
            analysis['error'] = "No baseline configuration found"
            return analysis

        baseline_result = results['baseline']
        baseline_performance = self._metric_stat(baseline_result, 'mean')
        analysis['baseline_performance'] = {
            'ndcg_mean': baseline_performance,
            'ndcg_std': self._metric_stat(baseline_result, 'std')
        }
        # Per-seed baseline scores; loop-invariant for the t-tests below
        baseline_values = [run['metrics'].get('ndcg@10', 0) for run in baseline_result['runs']]

        feature_configs = {
            'playlist_features': 'Playlist Features (6D)',
            'track_features': 'Track Features (4D)',
            'user_features': 'User Features (4D)'
        }

        feature_improvements = {}

        for config_name, description in feature_configs.items():
            if config_name not in results:
                continue

            result = results[config_name]
            perf = self._metric_stat(result, 'mean')
            improvement = self._percent_change(perf, baseline_performance)
            feature_improvements[config_name] = improvement

            feature_impact = {
                'description': description,
                'ndcg_mean': perf,
                'ndcg_std': self._metric_stat(result, 'std'),
                'improvement_percent': improvement
            }

            feature_values = [run['metrics'].get('ndcg@10', 0) for run in result['runs']]
            if len(baseline_values) > 1 and len(feature_values) > 1:
                _, p_val = scipy_stats.ttest_ind(feature_values, baseline_values)
                feature_impact['p_value'] = p_val
                feature_impact['significance'] = self._significance_stars(p_val)

            analysis['individual_feature_impacts'][config_name] = feature_impact

        if 'all_features' in results:
            combined_result = results['all_features']
            combined_perf = self._metric_stat(combined_result, 'mean')
            combined_improvement = self._percent_change(combined_perf, baseline_performance)

            analysis['combined_feature_effect'] = {
                'ndcg_mean': combined_perf,
                'ndcg_std': self._metric_stat(combined_result, 'std'),
                'improvement_percent': combined_improvement
            }

            # Synergy needs all three single-feature runs to form the additive expectation
            if len(feature_improvements) >= 3:
                expected_improvement = sum(feature_improvements.values())
                synergy_effect = combined_improvement - expected_improvement

                synergy_interpretation = "neutral"
                if synergy_effect > 2.0:
                    synergy_interpretation = "strong_positive"
                elif synergy_effect > 0.5:
                    synergy_interpretation = "moderate_positive"
                elif synergy_effect < -2.0:
                    synergy_interpretation = "strong_negative"
                elif synergy_effect < -0.5:
                    synergy_interpretation = "moderate_negative"

                analysis['synergy_analysis'] = {
                    'expected_additive_improvement': expected_improvement,
                    'actual_combined_improvement': combined_improvement,
                    'synergy_effect': synergy_effect,
                    'interpretation': synergy_interpretation
                }

        feature_ranking = sorted(feature_improvements.items(), key=lambda x: x[1], reverse=True)
        analysis['feature_ranking'] = [
            {
                'rank': rank,
                'feature_type': config_name.replace('_features', '').replace('_', ' ').title(),
                'config_name': config_name,
                'improvement_percent': improvement
            }
            for rank, (config_name, improvement) in enumerate(feature_ranking, start=1)
        ]

        return analysis

    def _analyze_graph_structure(self, results):
        """
        Rank the Phase 1 graph-structure configurations by mean NDCG@10.

        Improvements are relative to ``'baseline'`` (0.0 when it is missing or 0).
        """
        analysis = {
            'configurations': {},
            'best_structure': None,
            'structure_ranking': []
        }

        phase1_configs = ['baseline', 'with_artists', 'with_users', 'full_graph']
        structure_performances = {}

        for config_name in phase1_configs:
            if config_name not in results:
                continue
            result = results[config_name]
            perf = self._metric_stat(result, 'mean')
            std = self._metric_stat(result, 'std')
            description = result['config']['description']

            structure_performances[config_name] = perf
            analysis['configurations'][config_name] = {
                'ndcg_mean': perf,
                'ndcg_std': std,
                'description': description
            }

        if structure_performances:
            best_name, best_perf = max(structure_performances.items(), key=lambda x: x[1])
            baseline_perf = structure_performances.get('baseline', 0)

            analysis['best_structure'] = {
                'config_name': best_name,
                'ndcg_mean': best_perf,
                'improvement_over_baseline': self._percent_change(best_perf, baseline_perf)
            }

            structure_ranking = sorted(structure_performances.items(), key=lambda x: x[1], reverse=True)
            analysis['structure_ranking'] = [
                {
                    'rank': rank,
                    'config_name': config_name,
                    'ndcg_mean': perf,
                    'improvement_over_baseline': self._percent_change(perf, baseline_perf)
                }
                for rank, (config_name, perf) in enumerate(structure_ranking, start=1)
            ]

        return analysis


# =============================================================================
# 11. MAIN EXECUTION
# =============================================================================

def run_lightgcn_experiments(data_path="../data/processed/spotify_scaled_hybrid_7500.json"):
    """
    Run the full LightGCN experiment pipeline end to end.

    Steps:
    1. Build an ``ExperimentConfig``.
    2. Load and process the dataset at ``data_path`` (processing seeded with 42).
    3. Run every configured experiment.
    4. List the files whose names contain this run's experiment timestamp.

    Args:
        data_path: Path to the processed Spotify JSON dataset.

    Returns:
        The (truthy) results dict from ``ExperimentRunner.run_all_experiments``,
        or ``None`` when it produced no results.
    """
    print("🎯 Starting Enhanced LightGCN Experiments with Result Saving...")

    config = ExperimentConfig()
    print(f"📁 All results will be saved to: {config.results_dir}")

    # Data processing uses a fixed seed, independent of the per-run training seeds
    processor = SpotifyDataProcessor(config)
    data = processor.load_and_process_data(data_path, seed=42)

    runner = ExperimentRunner(config, data)
    results = runner.run_all_experiments()

    if not results:
        print("\n❌ Experiments failed!")
        return None

    print("\n🎉 All experiments completed successfully!")

    print(f"\n📋 FINAL SUMMARY:")
    print(f"   🔬 Configurations tested: {len(results)}")
    print(f"   📁 Results directory: {config.results_dir}")
    print(f"   🕐 Experiment timestamp: {runner.experiment_timestamp}")

    # Every file written for this run carries the experiment timestamp in its name
    saved_files = sorted(
        os.path.join(root, name)
        for root, _dirs, names in os.walk(config.results_dir)
        for name in names
        if runner.experiment_timestamp in name
    )

    print(f"\n📄 Generated {len(saved_files)} files:")
    for path in saved_files[:10]:
        print(f"   📄 {os.path.relpath(path, config.results_dir)}")

    remaining = len(saved_files) - 10
    if remaining > 0:
        print(f"   ... and {remaining} more files")

    return results


if __name__ == "__main__":
    # Describe the two experimental phases before launching the run.
    for banner_line in (
        "🎯 Enhanced LightGCN Experiments with Comprehensive Result Saving",
        "\n📋 EXPERIMENTAL DESIGN:",
        "   Phase 1: Graph Structure Ablation (4 configurations)",
        "      • baseline: playlist_track only",
        "      • with_artists: + track_artist edges",
        "      • with_users: + user_playlist edges",
        "      • full_graph: all edge types",
        "\n   Phase 2: Feature Importance Analysis (4 configurations)",
        "      • playlist_features: baseline graph + playlist features (6D)",
        "      • track_features: baseline graph + track features (4D)",
        "      • user_features: baseline graph + user features (4D)",
        "      • all_features: baseline graph + all feature types",
    ):
        print(banner_line)

    # Direct runs use the "_tiny" dataset file rather than the function's default path.
    results = run_lightgcn_experiments(
        data_path="../data/processed/spotify_scaled_hybrid_tiny.json"
    )

    outcome = (
        "\n🎉 Enhanced LightGCN experiments completed successfully!"
        if results
        else "\n❌ Experiments failed! Check error messages above."
    )
    print(outcome)

🎯 Enhanced LightGCN Experiments with Comprehensive Result Saving

📋 EXPERIMENTAL DESIGN:
   Phase 1: Graph Structure Ablation (4 configurations)
      • baseline: playlist_track only
      • with_artists: + track_artist edges
      • with_users: + user_playlist edges
      • full_graph: all edge types

   Phase 2: Feature Importance Analysis (4 configurations)
      • playlist_features: baseline graph + playlist features (6D)
      • track_features: baseline graph + track features (4D)
      • user_features: baseline graph + user features (4D)
      • all_features: baseline graph + all feature types
🎯 Starting Enhanced LightGCN Experiments with Result Saving...
📁 Result directories created in: ../results/lightgcn_experiments_enhanced
🎯 Experiment Configuration:
   📱 Device: cpu
   🎲 Seeds: 5 seeds
   📊 Target playlists: 1,500
   🧠 Embedding dim: 128
   ⚡ Learning rate: 0.0005
   🛡️ Regularization: 0.001
   💾 Results dir: ../results/lightgcn_experiments_enhanced
📁 All results will be sa

In [8]:
import torch
import numpy as np
import json
import os
import pickle
from scipy import stats as scipy_stats
from typing import Dict, List, Tuple, Optional
from datetime import datetime

# =============================================================================
# PHASE 3: COMBINED RESULTS CONFIGURATION SELECTOR
# =============================================================================

class Phase3CombinedResultsSelector:
    """
    Configuration selector for a combined Phase 1 & 2 results file.

    Loads one results mapping (``config_name -> config_data``), splits it into
    Phase 1 (graph-structure ablation) and Phase 2 (feature importance)
    configurations, and derives an optimal Phase 3 configuration:

    * edge types: the best mean NDCG@10 in Phase 1 is accepted over
      ``baseline`` only if a two-sample t-test is significant AND Cohen's d
      exceeds the effect-size threshold;
    * features: the best mean NDCG@10 in Phase 2 is accepted only if its
      relative improvement over the Phase 1 ``baseline`` reaches the threshold.

    Per-configuration NDCG@10 values are read from ``'runs_summary'`` or
    ``'runs'`` (``run['metrics']['ndcg@10']``), falling back to the
    ``'statistics'['ndcg@10']`` aggregates when no per-run data is stored.
    """

    def __init__(self, combined_results_path: str):
        """
        Initialize selector with combined results file.

        Args:
            combined_results_path: Path to combined Phase 1&2 results file (.pkl or .json)
        """
        self.combined_results = self._load_results(combined_results_path)

        # Automatically separate Phase 1 and Phase 2 configurations
        self.phase1_results, self.phase2_results = self._separate_phases()

        # Statistical selection criteria
        self.significance_threshold = 0.05  # p < 0.05
        self.effect_size_threshold = 0.2    # Cohen's d > 0.2
        self.improvement_threshold = 5.0    # >= 5% relative NDCG@10 gain for features

        print("🔬 Phase 3 Combined Results Selector Initialized")
        print(f"   📊 Total configurations loaded: {len(self.combined_results)}")
        print(f"   📊 Phase 1 configs (graph structure): {len(self.phase1_results)}")
        print(f"   📊 Phase 2 configs (feature importance): {len(self.phase2_results)}")
        print(f"   📋 Statistical criteria:")
        print(f"      • Significance: p < {self.significance_threshold}")
        print(f"      • Effect size: Cohen's d > {self.effect_size_threshold}")
        print(f"      • Feature improvement: >{self.improvement_threshold}%")

    def _load_results(self, results_path: str) -> Dict:
        """
        Load results from a ``.pkl`` or ``.json`` file.

        Raises:
            ValueError: if the extension is neither ``.pkl`` nor ``.json``.
        """
        # NOTE(security): pickle.load can execute arbitrary code embedded in the
        # file; only load result files produced by this pipeline / trusted sources.
        if results_path.endswith('.pkl'):
            with open(results_path, 'rb') as f:
                return pickle.load(f)
        if results_path.endswith('.json'):
            with open(results_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        raise ValueError(f"Unsupported file format: {results_path}")

    @staticmethod
    def _extract_ndcg_values(config_data: Dict) -> Tuple[Optional[List[float]], bool]:
        """
        Extract the NDCG@10 values of one configuration.

        Looks at ``'runs_summary'`` first, then ``'runs'`` (each run providing
        ``run['metrics']['ndcg@10']``), then the ``'statistics'`` aggregate.

        Returns:
            (values, approximated): ``values`` is a list of Python floats, or
            None when no usable data exists. ``approximated`` is True when the
            values were reconstructed from aggregate statistics (the mean
            repeated ``n`` times) and therefore carry no variance information.
        """
        for runs_key in ('runs_summary', 'runs'):
            runs = config_data.get(runs_key)
            if runs:  # an empty list falls through to the next source
                return [float(run['metrics']['ndcg@10']) for run in runs], False

        ndcg_stats = (config_data.get('statistics') or {}).get('ndcg@10')
        if ndcg_stats is not None and 'mean' in ndcg_stats:
            n = int(ndcg_stats.get('n', 1))
            return [float(ndcg_stats['mean'])] * max(n, 1), True

        return None, False

    @staticmethod
    def _cohens_d(sample_a: List[float], sample_b: List[float]) -> float:
        """
        Absolute Cohen's d using the pooled (ddof=1) standard deviation.

        Returns ``inf`` when both samples have zero variance but different
        means and ``0.0`` when they are identical, instead of dividing by zero.
        """
        n_a, n_b = len(sample_a), len(sample_b)
        pooled_std = np.sqrt(
            ((n_a - 1) * np.var(sample_a, ddof=1) + (n_b - 1) * np.var(sample_b, ddof=1))
            / (n_a + n_b - 2)
        )
        mean_diff = abs(np.mean(sample_a) - np.mean(sample_b))
        if pooled_std > 0:
            return float(mean_diff / pooled_std)
        return float('inf') if mean_diff > 0 else 0.0

    def _separate_phases(self) -> Tuple[Dict, Dict]:
        """
        Separate Phase 1 and Phase 2 configurations based on their config flags.

        Phase 1: graph structure ablation (``use_features`` falsy).
        Phase 2: features on the baseline graph (``edge_types == ['playlist_track']``).
        Configs with features AND several edge types go to Phase 1; any other
        remaining config goes to Phase 2.

        Returns:
            Tuple[Dict, Dict]: (phase1_results, phase2_results)
        """
        phase1_configs = {}
        phase2_configs = {}

        print("\n🔍 Analyzing configurations to separate phases...")

        for config_name, config_data in self.combined_results.items():
            config = config_data['config']
            description = config.get('description', '')

            if not config.get('use_features', False):
                # Phase 1: structure-only runs
                phase1_configs[config_name] = config_data
                print(f"   📊 Phase 1: {config_name} - {description}")
            elif config.get('edge_types') == ['playlist_track']:
                # Phase 2: features on top of the baseline graph
                phase2_configs[config_name] = config_data
                print(f"   🎨 Phase 2: {config_name} - {description}")
            elif len(config.get('edge_types', [])) > 1:
                # Edge case: features plus a richer graph -> primarily a structure experiment
                phase1_configs[config_name] = config_data
                print(f"   📊 Phase 1 (edge case): {config_name} - {description}")
            else:
                phase2_configs[config_name] = config_data
                print(f"   🎨 Phase 2 (edge case): {config_name} - {description}")

        print(f"\n✅ Phase separation complete:")
        print(f"   Phase 1 (Graph Structure): {list(phase1_configs.keys())}")
        print(f"   Phase 2 (Feature Importance): {list(phase2_configs.keys())}")

        return phase1_configs, phase2_configs

    def determine_best_edge_types(self) -> Tuple[List[str], Dict]:
        """
        Determine best edge types from Phase 1 results using statistical criteria.

        The top configuration by mean NDCG@10 is accepted over ``baseline`` only
        when both p < significance_threshold (two-sample t-test) and Cohen's d
        > effect_size_threshold hold; otherwise the baseline edge types are kept.
        If real per-run samples are unavailable (fewer than two runs, or values
        reconstructed from aggregate statistics) no test is run and the best
        performer is selected.

        Returns:
            Tuple[List[str], Dict]: (selected_edge_types, analysis_details).
            ``analysis_details['edge_types']`` and ``['performance']`` always
            describe the *selected* configuration; all values are plain Python
            types so the dict is JSON-serializable.
        """
        print("\n🔗 PHASE 1 ANALYSIS: Edge Type Selection")
        print("="*60)

        if not self.phase1_results:
            print("❌ No Phase 1 configurations found!")
            return ["playlist_track"], {
                "error": "No Phase 1 configurations available",
                "decision": "No Phase 1 configurations available - defaulted to ['playlist_track']",
            }

        # Extract NDCG@10 performance for all Phase 1 configs
        config_performances = {}
        for config_name, config_data in self.phase1_results.items():
            ndcg_values, approximated = self._extract_ndcg_values(config_data)
            if ndcg_values is None:
                print(f"   ⚠️ Skipping {config_name} - no usable performance data")
                continue
            config_performances[config_name] = {
                'ndcg_values': ndcg_values,
                'approximated': approximated,
                'mean_ndcg': float(np.mean(ndcg_values)),
                'std_ndcg': float(np.std(ndcg_values, ddof=1)) if len(ndcg_values) > 1 else 0.0,
                'config': config_data['config'],
            }

        if not config_performances:
            print("❌ No usable Phase 1 performance data found!")
            return ["playlist_track"], {
                "error": "No usable Phase 1 performance data",
                "decision": "No usable Phase 1 performance data - defaulted to ['playlist_track']",
            }

        # Rank configurations by mean NDCG@10
        ranked_configs = sorted(
            config_performances.items(),
            key=lambda item: item[1]['mean_ndcg'],
            reverse=True
        )

        print("📊 Phase 1 Configuration Rankings (by NDCG@10):")
        for i, (config_name, perf) in enumerate(ranked_configs, 1):
            print(f"   {i}. {config_name}: {perf['mean_ndcg']:.4f} ± {perf['std_ndcg']:.4f}")

        best_config_name, best_config_perf = ranked_configs[0]
        best_edge_types = best_config_perf['config']['edge_types']

        analysis_details = {
            'ranked_configs': [(name, perf['mean_ndcg']) for name, perf in ranked_configs],
            'selected_config': best_config_name,
            'performance': best_config_perf['mean_ndcg'],
            'edge_types': best_edge_types
        }

        baseline_perf = config_performances.get('baseline')
        if best_config_name != 'baseline' and baseline_perf is not None:
            baseline_values = baseline_perf['ndcg_values']
            best_values = best_config_perf['ndcg_values']
            # Reconstructed "samples" (mean repeated n times) have no variance,
            # so a t-test on them would be meaningless.
            has_real_samples = (
                len(baseline_values) > 1 and len(best_values) > 1
                and not baseline_perf['approximated']
                and not best_config_perf['approximated']
            )

            if has_real_samples:
                _, p_value = scipy_stats.ttest_ind(best_values, baseline_values)
                p_value = float(p_value)
                cohens_d = self._cohens_d(best_values, baseline_values)
                # Plain bools (not np.bool_) keep the analysis JSON-serializable
                significant = bool(p_value < self.significance_threshold)
                meaningful = bool(cohens_d > self.effect_size_threshold)

                analysis_details['vs_baseline'] = {
                    'p_value': p_value,
                    'cohens_d': cohens_d,
                    'significant': significant,
                    'meaningful_effect': meaningful
                }

                print(f"\n🧪 Statistical Analysis vs Baseline:")
                print(f"   • Best config: {best_config_name} (NDCG@10: {best_config_perf['mean_ndcg']:.4f})")
                print(f"   • Baseline: {baseline_perf['mean_ndcg']:.4f}")
                print(f"   • p-value: {p_value:.4f} ({'significant' if significant else 'not significant'})")
                print(f"   • Cohen's d: {cohens_d:.4f} ({'meaningful' if meaningful else 'small effect'})")

                # Conservative decision: both criteria must hold
                if significant and meaningful:
                    print(f"   ✅ Selected: {best_config_name} (meets all criteria)")
                    selected_edge_types = best_edge_types
                    analysis_details['decision'] = f"Selected {best_config_name} - statistically significant improvement"
                else:
                    print(f"   ⚠️ Conservative choice: baseline (insufficient evidence for improvement)")
                    selected_edge_types = baseline_perf['config']['edge_types']
                    analysis_details['decision'] = "Conservative fallback to baseline - insufficient statistical evidence"
                    analysis_details['selected_config'] = 'baseline'
            else:
                print(f"   ⚠️ Insufficient data for statistical testing")
                selected_edge_types = best_edge_types
                analysis_details['decision'] = "Insufficient data for statistical testing - selected best performer"
        elif best_config_name == 'baseline':
            print(f"   ✅ Selected: baseline (best performing)")
            selected_edge_types = best_edge_types
            analysis_details['decision'] = "Baseline was best performing"
        else:
            print(f"   ⚠️ No baseline for comparison, selecting: {best_config_name}")
            selected_edge_types = best_edge_types
            analysis_details['decision'] = "No baseline comparison available"

        # Keep the reported edge types / performance in sync with the final choice
        selected_name = analysis_details['selected_config']
        analysis_details['edge_types'] = selected_edge_types
        analysis_details['performance'] = config_performances[selected_name]['mean_ndcg']

        print(f"\n🎯 Final Edge Type Selection: {selected_edge_types}")

        return selected_edge_types, analysis_details

    def determine_best_features(self) -> Tuple[bool, List[str], Dict]:
        """
        Determine best features from Phase 2 results using the improvement threshold.

        Improvement is the relative mean-NDCG@10 gain over the Phase 1
        ``baseline``. If no usable (non-zero) baseline exists, improvements are
        reported as 0%, so features are never justified.

        Returns:
            Tuple[bool, List[str], Dict]: (use_features, feature_types, analysis_details)
        """
        print("\n🎨 PHASE 2 ANALYSIS: Feature Type Selection")
        print("="*60)

        if not self.phase2_results:
            print("❌ No Phase 2 configurations found!")
            return False, [], {
                "error": "No Phase 2 configurations available",
                "decision": "No Phase 2 configurations available - features not used",
            }

        # Baseline performance from Phase 1 (same extraction as edge analysis)
        baseline_performance = None
        baseline_data = self.phase1_results.get('baseline')
        if baseline_data is None:
            print("⚠️ No baseline found in Phase 1 results - improvements cannot be computed (reported as 0%)")
        else:
            baseline_values, _ = self._extract_ndcg_values(baseline_data)
            if baseline_values is None:
                print("⚠️ Baseline has no usable performance data - improvements cannot be computed (reported as 0%)")
            else:
                baseline_performance = float(np.mean(baseline_values))
                print(f"📊 Baseline Performance (Structure-only): {baseline_performance:.4f}")

        # Analyze feature configurations
        feature_analysis = {}
        for config_name, config_data in self.phase2_results.items():
            ndcg_values, _ = self._extract_ndcg_values(config_data)
            if ndcg_values is None:
                print(f"   ⚠️ Skipping {config_name} - no usable performance data")
                continue

            mean_ndcg = float(np.mean(ndcg_values))

            # Relative improvement is undefined without a non-zero baseline
            if baseline_performance:
                improvement_pct = ((mean_ndcg - baseline_performance) / baseline_performance) * 100
            else:
                improvement_pct = 0.0

            feature_analysis[config_name] = {
                'mean_ndcg': mean_ndcg,
                'improvement_pct': float(improvement_pct),
                'ndcg_values': ndcg_values,
                'feature_types': list(config_data['config'].get('feature_types', []))
            }

            print(f"   • {config_name}: {mean_ndcg:.4f} ({improvement_pct:+.1f}% vs baseline)")

        if feature_analysis:
            best_config_name, best_config_perf = max(
                feature_analysis.items(), key=lambda item: item[1]['mean_ndcg']
            )
            best_improvement = best_config_perf['improvement_pct']

            print(f"\n🏆 Best Feature Configuration: {best_config_name}")
            print(f"   • Performance: {best_config_perf['mean_ndcg']:.4f}")
            print(f"   • Improvement: {best_improvement:+.1f}% vs baseline")
            print(f"   • Threshold: {self.improvement_threshold}%")

            if best_improvement >= self.improvement_threshold:
                print(f"   ✅ Features justified: {best_improvement:.1f}% ≥ {self.improvement_threshold}%")
                use_features = True
                feature_types = best_config_perf['feature_types']
                decision = f"Selected {best_config_name} - {best_improvement:.1f}% improvement exceeds threshold"
            else:
                print(f"   ❌ Features not justified: {best_improvement:.1f}% < {self.improvement_threshold}%")
                use_features = False
                feature_types = []
                decision = f"Improvement {best_improvement:.1f}% < {self.improvement_threshold}% threshold"
        else:
            print("   ❌ No feature configurations found")
            use_features = False
            feature_types = []
            decision = "No feature configurations available"
            best_config_name = None
            best_improvement = 0.0

        analysis_details = {
            'baseline_performance': baseline_performance,
            'feature_analysis': feature_analysis,
            'best_config': best_config_name,
            'best_improvement': best_improvement,
            'threshold': self.improvement_threshold,
            'decision': decision,
            'use_features': use_features,
            'selected_feature_types': feature_types
        }

        print(f"\n🎯 Final Feature Selection:")
        print(f"   • Use features: {use_features}")
        print(f"   • Feature types: {feature_types}")

        return use_features, feature_types, analysis_details

    def create_optimal_configuration(self) -> Tuple[Dict, Dict]:
        """
        Create the optimal configuration from the Phase 1 and Phase 2 analyses.

        Returns:
            Tuple[Dict, Dict]: (optimal_config, selection_analysis)
        """
        print("\n🎯 PHASE 3: EMPIRICAL CONFIGURATION SELECTION")
        print("="*80)

        best_edge_types, edge_analysis = self.determine_best_edge_types()
        use_features, feature_types, feature_analysis = self.determine_best_features()

        optimal_config = {
            "name": "Dynamic Best Combined",
            "description": "Empirically determined best configuration based on statistical criteria from combined results",
            "edge_types": best_edge_types,
            "use_features": use_features,
            "feature_types": feature_types,
            "use_playlist_features": "playlist" in feature_types,
            "use_track_features": "track" in feature_types,
            "use_user_features": "user" in feature_types
        }

        selection_analysis = {
            'methodology': {
                'significance_threshold': self.significance_threshold,
                'effect_size_threshold': self.effect_size_threshold,
                'improvement_threshold': self.improvement_threshold,
                'approach': 'Conservative, evidence-based selection from combined results'
            },
            'data_source': {
                'combined_results': True,
                'phase1_configs': list(self.phase1_results.keys()),
                'phase2_configs': list(self.phase2_results.keys()),
                'separation_method': 'Automatic based on configuration characteristics'
            },
            'edge_selection': edge_analysis,
            'feature_selection': feature_analysis,
            'final_configuration': optimal_config,
            'timestamp': datetime.now().strftime("%Y%m%d_%H%M%S")
        }

        print(f"\n🏆 OPTIMAL CONFIGURATION DETERMINED:")
        print(f"="*50)
        print(f"Name: {optimal_config['name']}")
        print(f"Description: {optimal_config['description']}")
        print(f"Edge Types: {optimal_config['edge_types']}")
        print(f"Use Features: {optimal_config['use_features']}")
        print(f"Feature Types: {optimal_config['feature_types']}")

        return optimal_config, selection_analysis


# =============================================================================
# PHASE 3: EXPERIMENT RUNNER (EMBEDDED)
# =============================================================================

class Phase3ExperimentRunner:
    """
    Run the Phase 3 experiment for the empirically selected configuration.

    Trains and evaluates the configuration once per seed in
    ``config.random_seeds`` via ``ModelTrainer``, saves per-run and aggregated
    results through ``config.result_saver``, and writes a JSON setup file
    describing how the configuration was chosen.
    """

    def __init__(self, config: "ExperimentConfig", data: Dict):
        self.config = config
        self.data = data
        self.trainer = ModelTrainer(config, data)
        self.experiment_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        print("🎯 PHASE 3 EXPERIMENT RUNNER INITIALIZED")
        print(f"   📊 Dataset: {data['entity_counts']['playlists']:,} playlists")
        print(f"   🎲 Seeds: {len(config.random_seeds)}")
        print(f"   📁 Results dir: {config.results_dir}")
        print(f"   🕐 Timestamp: {self.experiment_timestamp}")

    @staticmethod
    def _json_default(obj):
        """Convert numpy/torch values that ``json`` cannot serialize natively."""
        if isinstance(obj, np.generic):  # np.bool_, np.float32, np.int64, ...
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, torch.Tensor):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def run_phase3_experiment(self, optimal_config: Dict, selection_analysis: Dict) -> Optional[Dict]:
        """
        Run Phase 3 experiment with the empirically determined optimal configuration.

        Args:
            optimal_config: The empirically determined optimal configuration
            selection_analysis: Analysis details from configuration selection

        Returns:
            Optional[Dict]: Complete Phase 3 results, or None if every seed failed.
        """
        print("\n" + "="*80)
        print("🔬 PHASE 3: OPTIMAL CONFIGURATION EXPERIMENT")
        print("="*80)

        config_name = "phase3_optimal"
        config_spec = optimal_config

        print(f"\n🧪 Configuration: {config_spec['name']}")
        print(f"📝 {config_spec['description']}")
        print(f"🔗 Edge types: {config_spec['edge_types']}")
        print(f"🎨 Features: {config_spec['feature_types'] if config_spec['use_features'] else 'None'}")

        self._save_phase3_setup(optimal_config, selection_analysis)

        phase3_results = []
        n_seeds = len(self.config.random_seeds)

        for seed_idx, seed in enumerate(self.config.random_seeds, 1):
            print(f"\n🎲 Seed {seed_idx}/{n_seeds} (seed={seed}):")

            training_result = self.trainer.train_model(config_spec, seed=seed, save_model=True)
            if training_result is None:
                print(f"   ❌ Training failed for seed {seed}")
                continue

            test_metrics = self.trainer.evaluate_model(
                training_result['model'],
                training_result['adj_matrix'],
                'test',
                seed=seed
            )

            phase3_results.append({
                'seed': seed,
                'metrics': test_metrics,
                'training_time': training_result['training_time'],
                'final_loss': training_result['final_loss'],
                'best_val_loss': training_result['best_val_loss'],
                'model_file': training_result.get('model_file'),
                'timestamp': training_result.get('timestamp')
            })

            self.config.result_saver.save_training_run(
                config_name, seed, training_result, test_metrics,
                training_result.get('timestamp')
            )

            print(f"      📊 NDCG@10: {test_metrics.get('ndcg@10', 0):.4f}")
            print(f"      📊 AUC: {test_metrics.get('auc', 0):.4f}")
            print(f"      ⏱️ Time: {training_result['training_time']:.1f}s")

            # Release the model before the next seed
            del training_result
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        if not phase3_results:
            return None

        complete_results = {
            'config': config_spec,
            'runs': phase3_results,
            'statistics': self._calculate_statistics(phase3_results),
            'selection_analysis': selection_analysis,
            'methodology': 'Empirical configuration selection based on combined Phase 1 & 2 analysis'
        }

        self.config.result_saver.save_config_results(
            config_name, complete_results, self.experiment_timestamp
        )

        self._print_phase3_summary(complete_results)

        return complete_results

    def _save_phase3_setup(self, optimal_config: Dict, selection_analysis: Dict):
        """
        Write ``phase3_setup_<timestamp>.json`` into ``config.results_dir``.

        The statistical criteria are taken from
        ``selection_analysis['methodology']`` when present (defaults: p < 0.05,
        Cohen's d > 0.2, >5% improvement), so the file reflects the thresholds
        actually used. numpy/torch values are converted to JSON types.
        """
        methodology = selection_analysis.get('methodology', {}) if isinstance(selection_analysis, dict) else {}
        significance = methodology.get('significance_threshold', 0.05)
        effect_size = methodology.get('effect_size_threshold', 0.2)
        improvement = methodology.get('improvement_threshold', 5.0)

        setup_info = {
            'phase3_timestamp': self.experiment_timestamp,
            'optimal_configuration': optimal_config,
            'selection_analysis': selection_analysis,
            'methodology': {
                'approach': 'Empirical, performance-driven configuration selection from combined results',
                'statistical_criteria': {
                    'significance_testing': f"p < {significance:g}",
                    'effect_size_analysis': f"Cohen's d > {effect_size:g}",
                    'improvement_threshold': f">{improvement:g}% for feature justification"
                },
                'ranking_method': 'Mean NDCG@10 performance across multiple random seeds',
                'conservative_approach': 'Evidence-based selection with fallback to baseline'
            },
            'experiment_settings': {
                'random_seeds': self.config.random_seeds,
                'evaluation_metrics': self.config.k_values,
                'training_parameters': {
                    'epochs': self.config.epochs,
                    'learning_rate': self.config.learning_rate,
                    'embedding_dim': self.config.embedding_dim,
                    'n_layers': self.config.n_layers
                }
            }
        }

        setup_file = os.path.join(
            self.config.results_dir,
            f"phase3_setup_{self.experiment_timestamp}.json"
        )
        with open(setup_file, 'w', encoding='utf-8') as f:
            json.dump(setup_info, f, indent=2, default=self._json_default)

        print(f"💾 Phase 3 setup saved: {os.path.basename(setup_file)}")

    def _calculate_statistics(self, phase3_results: List[Dict]) -> Dict:
        """
        Per-metric mean/std/min/max and a t-based 95% CI over the runs.

        Kept identical to the stated Phase 1 & 2 method so results are
        comparable. NOTE(review): a run missing a metric contributes 0, and for
        exactly two runs the CI collapses to the mean. Revisit both together
        with the Phase 1 & 2 implementation.
        """
        if not phase3_results:
            return {}

        statistics = {}

        all_metrics = set()
        for run in phase3_results:
            all_metrics.update(run['metrics'].keys())

        for metric in all_metrics:
            values = [run['metrics'].get(metric, 0) for run in phase3_results]
            n = len(values)

            if n > 1:
                mean_val = np.mean(values)
                std_val = np.std(values, ddof=1)

                if n > 2:
                    t_value = scipy_stats.t.ppf(0.975, n - 1)
                    margin_error = t_value * std_val / np.sqrt(n)
                    ci_lower = mean_val - margin_error
                    ci_upper = mean_val + margin_error
                else:
                    ci_lower = mean_val
                    ci_upper = mean_val

                statistics[metric] = {
                    'mean': mean_val,
                    'std': std_val,
                    'min': np.min(values),
                    'max': np.max(values),
                    'ci_lower': ci_lower,
                    'ci_upper': ci_upper,
                    'n': n
                }
            elif n == 1:
                statistics[metric] = {
                    'mean': values[0],
                    'std': 0.0,
                    'min': values[0],
                    'max': values[0],
                    'ci_lower': values[0],
                    'ci_upper': values[0],
                    'n': 1
                }

        return statistics

    def _print_phase3_summary(self, results: Dict):
        """Print the Phase 3 results summary; tolerates error-only selection dicts."""
        print("\n" + "="*80)
        print("📊 PHASE 3: OPTIMAL CONFIGURATION RESULTS")
        print("="*80)

        stats = results['statistics']
        config = results['config']

        print(f"\n🏆 EMPIRICALLY DETERMINED OPTIMAL CONFIGURATION:")
        print(f"   Name: {config['name']}")
        print(f"   Description: {config['description']}")
        print(f"   Edge Types: {config['edge_types']}")
        print(f"   Features: {config['feature_types'] if config['use_features'] else 'None'}")

        print(f"\n📊 PERFORMANCE METRICS:")

        key_metrics = ['ndcg@10', 'ndcg@20', 'precision@10', 'recall@10', 'auc']
        for metric in key_metrics:
            if metric in stats:
                s = stats[metric]
                print(f"   • {metric.upper()}: {s['mean']:.4f} ± {s['std']:.4f} "
                      f"(95% CI: [{s['ci_lower']:.4f}, {s['ci_upper']:.4f}])")

        avg_time = np.mean([run['training_time'] for run in results['runs']])
        print(f"\n⏱️ TRAINING EFFICIENCY:")
        print(f"   • Average training time: {avg_time:.1f}s ({avg_time/60:.1f} minutes)")
        print(f"   • Seeds completed: {len(results['runs'])}/{len(self.config.random_seeds)}")

        print(f"\n🔬 EMPIRICAL SELECTION SUMMARY:")
        selection = results.get('selection_analysis') or {}
        edge_sel = selection.get('edge_selection') or {}
        feature_sel = selection.get('feature_selection') or {}
        # Error dicts from the selector may carry only an 'error' key
        edge_decision = edge_sel.get('decision', edge_sel.get('error', 'N/A'))
        feature_decision = feature_sel.get('decision', feature_sel.get('error', 'N/A'))
        print(f"   • Edge selection: {edge_decision}")
        print(f"   • Feature selection: {feature_decision}")

        print(f"\n✅ Phase 3 experiment completed successfully!")
        print(f"📁 Results saved with timestamp: {self.experiment_timestamp}")


# =============================================================================
# MAIN PHASE 3 EXECUTION FOR COMBINED RESULTS
# =============================================================================

def run_phase3_experiments_combined(combined_results_path: str,
                                   data_path: str = "../data/processed/spotify_scaled_hybrid_tiny.json",
                                   results_dir: str = "../results/lightgcn_phase3_optimal_combined") -> Optional[Dict]:
    """
    Run Phase 3 experiments using a combined Phase 1&2 results file.

    Steps: select the optimal configuration from the combined results, build
    the experiment config (results redirected to ``results_dir``), load the
    dataset with seed 42, run all seeds, then save complete results and a
    summary report.

    Args:
        combined_results_path: Path to combined Phase 1&2 results (.pkl or .json)
        data_path: Path to the dataset
        results_dir: Directory for Phase 3 outputs (default keeps the previous location)

    Returns:
        Optional[Dict]: Complete Phase 3 results, or None if every seed failed.
    """
    print("🎯 PHASE 3: EMPIRICAL CONFIGURATION SELECTION FROM COMBINED RESULTS")
    print("="*80)

    # 1-2. Select the optimal configuration empirically
    selector = Phase3CombinedResultsSelector(combined_results_path)
    optimal_config, selection_analysis = selector.create_optimal_configuration()

    # 3. Experiment infrastructure, redirected to the Phase 3 results directory
    config = ExperimentConfig()
    config.results_dir = results_dir
    config.result_saver = ResultSaver(config.results_dir)

    print(f"\n📁 Phase 3 results will be saved to: {config.results_dir}")

    # 4. Load and process data (same methodology as Phase 1 & 2)
    data_processor = SpotifyDataProcessor(config)
    data = data_processor.load_and_process_data(data_path, seed=42)

    # 5. Run the experiment
    phase3_runner = Phase3ExperimentRunner(config, data)
    results = phase3_runner.run_phase3_experiment(optimal_config, selection_analysis)

    if not results:
        print("\n❌ Phase 3 experiments failed!")
        return None

    # 6. Annotate with the data source and persist
    results['data_source'] = {
        'combined_results_file': combined_results_path,
        'automatic_phase_separation': True,
        'methodology': 'Empirical selection from combined Phase 1&2 results'
    }

    wrapped_results = {'phase3_optimal_combined': results}
    config.result_saver.save_complete_results(
        wrapped_results,
        phase3_runner.experiment_timestamp
    )
    config.result_saver.create_experiment_summary(
        wrapped_results,
        phase3_runner.experiment_timestamp
    )

    print("\n🎉 Phase 3 experiments completed successfully!")
    print("📊 Empirical selection from combined results successful")
    print(f"📁 All results saved to: {config.results_dir}")

    return results


# =============================================================================
# ANALYSIS UTILITIES FOR COMBINED RESULTS
# =============================================================================

def analyze_combined_results_selection(combined_results_path: str) -> Dict:
    """
    Dry-run the Phase 3 empirical configuration selection without training.

    Asks ``Phase3CombinedResultsSelector`` which configuration it would choose
    from the combined Phase 1&2 results, prints a short summary of the
    decision, and returns the details so they can be inspected before
    launching a full Phase 3 run.

    Args:
        combined_results_path: Path to the combined Phase 1&2 results file;
            forwarded unchanged to ``Phase3CombinedResultsSelector``.

    Returns:
        Dict with three entries:
            - 'optimal_configuration': config dict returned by the selector
            - 'selection_analysis': the selector's analysis dict
            - 'summary': edge/feature decisions, final edge types, final
              features, and a coarse 'complexity_level' label
    """
    print("🔍 ANALYZING EMPIRICAL CONFIGURATION SELECTION")
    print("=" * 60)

    # Run the selection logic only; no experiments are executed here.
    chosen_config, analysis = Phase3CombinedResultsSelector(
        combined_results_path
    ).create_optimal_configuration()

    # Condensed view of the decision. Lookups happen in the same order as
    # the fields are listed, so a missing key surfaces at the same place.
    summary = {
        'edge_decision': analysis['edge_selection']['decision'],
        'feature_decision': analysis['feature_selection']['decision'],
        'final_edge_types': chosen_config['edge_types'],
        'final_features': chosen_config['feature_types'],
        # Labelled 'Simple' whenever the chosen config does not use features.
        'complexity_level': (
            'Feature-Enhanced' if chosen_config['use_features'] else 'Simple'
        ),
    }

    report = {
        'optimal_configuration': chosen_config,
        'selection_analysis': analysis,
        'summary': summary,
    }

    print("\n📊 SELECTION ANALYSIS COMPLETE")
    summary_lines = (
        ("🔗", "Edge Selection", 'edge_decision'),
        ("🎨", "Feature Selection", 'feature_decision'),
        ("🎯", "Final Configuration", 'complexity_level'),
    )
    for icon, label, key in summary_lines:
        print(f"   {icon} {label}: {summary[key]}")

    return report


# =============================================================================
# USAGE EXAMPLE FOR YOUR DATA
# =============================================================================

if __name__ == "__main__":
    print("🔬 Phase 3: Empirical Configuration Selection from Combined Results")
    print("\n📋 METHODOLOGY:")
    print("   • Automatic Phase 1/2 separation from combined results")
    print("   • Statistical criteria: p < 0.05, Cohen's d > 0.2, >5% improvement")
    print("   • Conservative evidence-based selection")
    print("   • Same training/evaluation methodology as Phase 1 & 2")

    # Inputs: the combined Phase 1&2 results (used for configuration
    # selection) and the processed dataset (used for the Phase 3 run).
    # Point these at your own files.
    combined_results_path = "../results/lightgcn_experiments_enhanced/complete_results_20250816_211920.json"
    data_path = "../data/processed/spotify_scaled_hybrid_tiny.json"

    try:
        # Step 1: dry-run the selection so the chosen configuration is visible
        # before any training starts.
        print("\n🔍 STEP 1: Analyzing empirical selection...")
        analysis = analyze_combined_results_selection(combined_results_path)

        predicted_config = analysis['optimal_configuration']
        selection_summary = analysis['summary']
        print("\n📋 PREDICTED OPTIMAL CONFIGURATION:")
        print(f"   • Edge Types: {predicted_config['edge_types']}")
        print(f"   • Features: {predicted_config['feature_types']}")
        # The selection makes two separate decisions (edges and features);
        # report both instead of presenting the edge decision alone as the
        # rationale for the whole configuration.
        print(f"   • Edge rationale: {selection_summary['edge_decision']}")
        print(f"   • Feature rationale: {selection_summary['feature_decision']}")

        # Step 2: full Phase 3 training/evaluation run.
        print("\n🚀 STEP 2: Running Phase 3 experiment with optimal configuration...")

        results = run_phase3_experiments_combined(
            combined_results_path=combined_results_path,
            data_path=data_path,
        )

        if results:
            print("\n🎉 Phase 3 completed successfully!")
            print("📊 Key results:")
            result_stats = results['statistics']
            print(f"   • NDCG@10: {result_stats['ndcg@10']['mean']:.4f} ± {result_stats['ndcg@10']['std']:.4f}")
            print(f"   • AUC: {result_stats['auc']['mean']:.4f} ± {result_stats['auc']['std']:.4f}")
            print(f"   • Configuration: {results['config']['edge_types']} + {results['config']['feature_types']}")

    except FileNotFoundError as e:
        # Either input (or a file opened further downstream) can be the one
        # that is missing, so report the path the OS actually failed on
        # rather than always blaming the combined results file.
        missing_file = e.filename if e.filename is not None else "<unknown>"
        print("\n❌ Error: Could not find a required input file")
        print(f"Missing file: {missing_file}")
        print(f"Combined results file expected at: {combined_results_path}")
        print(f"Processed data file expected at: {data_path}")
        print(f"Error details: {e}")

    except Exception as e:
        # Top-level boundary: report the failure and show the full traceback.
        print(f"\n❌ Error running Phase 3 experiments: {e}")
        import traceback
        traceback.print_exc()

🔬 Phase 3: Empirical Configuration Selection from Combined Results

📋 METHODOLOGY:
   • Automatic Phase 1/2 separation from combined results
   • Statistical criteria: p < 0.05, Cohen's d > 0.2, >5% improvement
   • Conservative evidence-based selection
   • Same training/evaluation methodology as Phase 1 & 2

🔍 STEP 1: Analyzing empirical selection...
🔍 ANALYZING EMPIRICAL CONFIGURATION SELECTION

🔍 Analyzing configurations to separate phases...
   📊 Phase 1: baseline - Playlist-track edges only
   📊 Phase 1: with_artists - Add track-artist relationships
   📊 Phase 1: with_users - Add user-playlist relationships
   📊 Phase 1: full_graph - All edge types
   🎨 Phase 2: playlist_features - Only playlist features (6D) on baseline graph
   🎨 Phase 2: track_features - Only track features (4D) on baseline graph
   🎨 Phase 2: user_features - Only user features (4D) on baseline graph
   🎨 Phase 2: all_features - All feature types combined on baseline graph

✅ Phase separation complete:
   Phas