In [1]:
class ProductionGNNPathPredictor:
    """
    Production-ready GNN Path Predictor with comprehensive features
    """
    
    def __init__(self, config: Config = None):
        """Build the predictor together with its helper components.

        Args:
            config: Settings object shared by every component. When omitted
                (or falsy) a default ``Config()`` is created.
        """
        self.config = config or Config()
        self.logger = logging.getLogger(self.__class__.__name__)

        # Data generation and plotting helpers share the same configuration.
        self.generator = GraphDatasetGenerator(self.config)
        self.visualizer = GraphVisualizationEngine(self.config)

        # The network is moved to the configured device once, up front; the
        # RL trainer receives that same model instance.
        self.model = GraphPathGNN(self.config).to(self.config.device)
        self.rl_trainer = ReinforcementLearningTrainer(self.model, self.config)

        # Filled in by create_dataset(). val_loader is declared here as well
        # so that code reading it before create_dataset() sees None instead
        # of hitting an AttributeError.
        self.dataset = None
        self.train_loader = None
        self.val_loader = None
        self.test_loader = None

        # Per-epoch metrics (lists keyed by metric name) and static metadata
        # that is stored alongside saved checkpoints.
        self.training_history = defaultdict(list)
        self.model_metadata = {
            'created_at': datetime.now().isoformat(),
            'config': self.config.__dict__,
            'version': '2.0.0'
        }

        self.logger.info(f"Initialized ProductionGNNPathPredictor with device: {self.config.device}")
    
    def prepare_torch_data(self, sample: Dict) -> Data:
        """Turn one raw sample dictionary into a ``Data`` graph object.

        The sample is read through the keys ``'adj_matrix'``, ``'start'``,
        ``'end'`` and ``'distance'``.

        Node features have ``config.input_dim`` columns, of which three are
        populated: column 0 is the node index divided by the node count,
        column 1 flags the start node and column 2 flags the end node.
        """
        weights = sample['adj_matrix']
        node_count = len(weights)

        # Each non-zero matrix entry becomes one edge; the entry's value is
        # kept as that edge's attribute.
        nonzero_positions = np.nonzero(weights)
        connectivity = torch.tensor(np.vstack(nonzero_positions), dtype=torch.long)
        edge_weights = torch.tensor(weights[nonzero_positions], dtype=torch.float)

        features = torch.zeros(node_count, self.config.input_dim, dtype=torch.float)
        features[:, 0] = torch.arange(node_count, dtype=torch.float) / node_count
        features[sample['start'], 1] = 1.0
        features[sample['end'], 2] = 1.0

        source_index = torch.tensor([sample['start']], dtype=torch.long)
        target_index = torch.tensor([sample['end']], dtype=torch.long)
        target_distance = torch.tensor([sample['distance']], dtype=torch.float)

        return Data(
            x=features,
            edge_index=connectivity,
            edge_attr=edge_weights,
            start_idx=source_index,
            end_idx=target_index,
            distance=target_distance,
        )
    
    def create_dataset(self, num_samples: int = None, use_cache: bool = True) -> Tuple[List, List]:
        """Generate (or load) the dataset, convert it and build the loaders.

        Side effects: sets ``self.dataset``, ``self.train_loader``,
        ``self.val_loader`` and ``self.test_loader``.

        Args:
            num_samples: Number of samples to generate; defaults to
                ``config.num_samples``.
            use_cache: When True, a previously saved dataset with a matching
                file name is loaded if present, and a freshly generated
                dataset is saved under that name.

        Returns:
            ``(train_dataset, test_dataset)``. The validation subset is not
            returned; it is only reachable through ``self.val_loader``.
        """
        if num_samples is None:
            num_samples = self.config.num_samples

        cache_file = f"dataset_{num_samples}_{self.config.min_nodes}_{self.config.max_nodes}.pkl"
        cache_path = os.path.join(self.config.data_save_path, cache_file)

        raw_dataset = None
        if use_cache and os.path.exists(cache_path):
            # Best effort: any failure while reading the cache falls back to
            # regenerating the data below.
            try:
                raw_dataset = self.generator.load_dataset(cache_file)
                self.logger.info(f"Loaded dataset from cache: {len(raw_dataset)} samples")
            except Exception as e:
                self.logger.warning(f"Failed to load cached dataset: {e}")
                raw_dataset = None

        if raw_dataset is None:
            raw_dataset = self.generator.generate_dataset(num_samples)
            if use_cache:
                self.generator.save_dataset(raw_dataset, cache_file)

        # Convert to PyTorch format with a thread pool. ThreadPoolExecutor
        # rejects max_workers <= 0, while num_workers == 0 is a valid
        # DataLoader setting, so clamp to at least one thread here.
        self.logger.info("Converting to PyTorch format...")
        conversion_workers = max(1, self.config.num_workers)
        with ThreadPoolExecutor(max_workers=conversion_workers) as executor:
            torch_dataset = list(tqdm(
                executor.map(self.prepare_torch_data, raw_dataset),
                total=len(raw_dataset),
                desc="Converting data"
            ))

        self.dataset = torch_dataset

        # 80 / 10 / remainder split; the test subset absorbs rounding.
        train_size = int(0.8 * len(torch_dataset))
        val_size = int(0.1 * len(torch_dataset))
        test_size = len(torch_dataset) - train_size - val_size

        train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
            torch_dataset, [train_size, val_size, test_size]
        )

        loader_workers = min(4, self.config.num_workers)
        # Normalise through torch.device so that 'cuda', 'cuda:0' and
        # torch.device objects are all recognised as CUDA targets.
        pin_memory = torch.device(self.config.device).type == 'cuda'

        def make_loader(subset, shuffle):
            """Build a loader with the settings shared by all three splits."""
            return DataLoader(
                subset,
                batch_size=self.config.batch_size,
                shuffle=shuffle,
                num_workers=loader_workers,
                pin_memory=pin_memory
            )

        # Only the training split is shuffled.
        self.train_loader = make_loader(train_dataset, True)
        self.val_loader = make_loader(val_dataset, False)
        self.test_loader = make_loader(test_dataset, False)

        self.logger.info(f"Dataset split - Train: {len(train_dataset)}, "
                        f"Val: {len(val_dataset)}, Test: {len(test_dataset)}")

        return train_dataset, test_dataset
    
    def train_model(self, epochs: int = None, patience: int = None) -> Dict:
        """Run the training loop with per-epoch validation and early stopping.

        After every epoch the model is validated; whenever the validation
        loss improves, ``best_model.pth`` is saved. Training stops early once
        the loss has failed to improve for ``patience`` consecutive epochs.
        At the end ``final_model.pth`` and the training history are saved.

        Args:
            epochs: Maximum number of epochs; defaults to ``config.epochs``.
            patience: Early-stopping patience; defaults to
                ``config.patience``.

        Returns:
            Dict with ``best_val_loss``, ``total_epochs`` (epochs actually
            run), ``total_time`` in seconds and ``training_history``.

        Raises:
            RuntimeError: If the data loaders have not been created yet.
        """
        if epochs is None:
            epochs = self.config.epochs
        if patience is None:
            patience = self.config.patience

        if self.train_loader is None or getattr(self, 'val_loader', None) is None:
            raise RuntimeError("Data loaders are not initialised; call create_dataset() before train_model()")

        self.logger.info(f"Starting training for {epochs} epochs...")

        best_val_loss = float('inf')
        patience_counter = 0
        # Tracked explicitly so the summary is well defined even when the
        # loop body never runs (epochs == 0).
        epochs_run = 0
        start_time = time.time()

        for epoch in range(epochs):
            epochs_run = epoch + 1
            epoch_start = time.time()

            # Training phase
            self.model.train()
            train_loss = 0.0
            train_reward = 0.0
            num_batches = 0

            for batch in tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
                batch = batch.to(self.config.device)

                # The stored distances serve as the reference solutions
                # handed to the RL trainer.
                dijkstra_solutions = batch.distance.cpu().numpy().flatten()

                # NOTE(review): train_step is assumed to return a scalar
                # (loss, reward) pair - confirm against the trainer.
                loss, reward = self.rl_trainer.train_step([batch], dijkstra_solutions)
                # float() keeps the running sums as plain Python numbers, so
                # the history stays JSON-serialisable and no tensor is kept
                # alive across iterations.
                train_loss += float(loss)
                train_reward += float(reward)
                num_batches += 1

            # Guard against an empty loader instead of dividing by zero.
            divisor = max(num_batches, 1)
            avg_train_loss = train_loss / divisor
            avg_train_reward = train_reward / divisor

            # Validation phase
            val_loss, val_accuracy = self.validate_model()

            # Log metrics
            epoch_time = time.time() - epoch_start
            self.training_history['train_loss'].append(avg_train_loss)
            self.training_history['train_reward'].append(avg_train_reward)
            self.training_history['val_loss'].append(val_loss)
            self.training_history['val_accuracy'].append(val_accuracy)
            self.training_history['epoch_time'].append(epoch_time)

            self.logger.info(
                f"Epoch {epoch+1}/{epochs} - "
                f"Train Loss: {avg_train_loss:.4f}, "
                f"Train Reward: {avg_train_reward:.4f}, "
                f"Val Loss: {val_loss:.4f}, "
                f"Val Accuracy: {val_accuracy:.4f}, "
                f"Time: {epoch_time:.2f}s"
            )

            # Early stopping and model saving
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                self.save_model('best_model.pth')
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    self.logger.info(f"Early stopping at epoch {epoch+1}")
                    break

        total_time = time.time() - start_time
        self.logger.info(f"Training completed in {total_time:.2f} seconds")

        # Save final model and training history
        self.save_model('final_model.pth')
        self.save_training_history()

        return {
            'best_val_loss': best_val_loss,
            'total_epochs': epochs_run,
            'total_time': total_time,
            'training_history': dict(self.training_history)
        }
    
    def validate_model(self) -> Tuple[float, float]:
        """Validation step"""
        self.model.eval()
        total_loss = 0
        correct_predictions = 0
        total_predictions = 0
        
        with torch.no_grad():
            for batch in self.val_loader:
                batch = batch.to(self.config.device)
                
                outputs = self.model(
                    batch.x, batch.edge_index, batch.batch,
                    batch.start_idx, batch.end_idx
                )
                
                targets = batch.distance.unsqueeze(1)
                loss = F.mse_loss(outputs, targets)
                total_loss += loss.item()
                
                # Calculate accuracy within tolerance
                predictions = outputs.cpu().numpy().flatten()
                ground_truths = targets.cpu().numpy().flatten()
                
                for pred, true in zip(predictions, ground_truths):
                    if true > 0 and abs(pred - true) / true < 0.1:
                        correct_predictions += 1
                    total_predictions += 1
        
        avg_loss = total_loss / len(self.val_loader)
        accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
        
        return avg_loss, accuracy
    
    def evaluate_model(self) -> Dict:
        """Evaluate the model on ``self.test_loader`` and plot the results.

        Returns:
            Dict with ``mse``, ``rmse``, ``mae``, ``avg_inference_time``
            (mean seconds per batch), ``throughput`` (samples per second of
            model time) and ``accuracy_5%`` / ``accuracy_10%`` /
            ``accuracy_15%`` / ``accuracy_20%``. Each accuracy is the share
            of all test samples whose prediction lies within that relative
            tolerance of a strictly positive ground truth.

        Raises:
            ValueError: If the test loader yields no samples.
        """
        self.logger.info("Evaluating model on test set...")

        self.model.eval()
        predictions = []
        ground_truths = []
        inference_times = []

        with torch.no_grad():
            for batch in tqdm(self.test_loader, desc="Evaluating"):
                batch = batch.to(self.config.device)

                # perf_counter is monotonic and high resolution, so very
                # short forward passes do not collapse to a 0.0 duration.
                start_time = time.perf_counter()
                outputs = self.model(
                    batch.x, batch.edge_index, batch.batch,
                    batch.start_idx, batch.end_idx
                )
                inference_times.append(time.perf_counter() - start_time)

                predictions.extend(outputs.cpu().numpy().flatten())
                ground_truths.extend(batch.distance.cpu().numpy().flatten())

        if not predictions:
            raise ValueError("Cannot evaluate: the test loader produced no samples")

        predicted = np.asarray(predictions, dtype=float)
        expected = np.asarray(ground_truths, dtype=float)

        # Calculate comprehensive metrics
        mse = mean_squared_error(ground_truths, predictions)
        rmse = np.sqrt(mse)
        mae = np.mean(np.abs(predicted - expected))

        # Relative error is only defined for positive ground truths; the
        # denominator of each accuracy nevertheless stays the full sample
        # count.
        positive = expected > 0
        relative_errors = np.abs(predicted[positive] - expected[positive]) / expected[positive]

        accuracies = {}
        for tolerance in [0.05, 0.1, 0.15, 0.2]:
            correct = int(np.count_nonzero(relative_errors < tolerance))
            # round() instead of int() so the key cannot drop a percent
            # through floating-point truncation.
            accuracies[f'accuracy_{round(tolerance * 100)}%'] = correct / len(predictions)

        # Performance metrics
        avg_inference_time = np.mean(inference_times)
        total_inference_time = sum(inference_times)
        throughput = (
            len(predictions) / total_inference_time
            if total_inference_time > 0 else float('inf')
        )

        results = {
            'mse': mse,
            'rmse': rmse,
            'mae': mae,
            'avg_inference_time': avg_inference_time,
            'throughput': throughput,
            **accuracies
        }

        self.logger.info("Evaluation Results:")
        for key, value in results.items():
            self.logger.info(f"{key}: {value:.4f}")

        # Generate visualizations
        self.plot_evaluation_results(predictions, ground_truths)

        return results
    
    def plot_evaluation_results(self, predictions: List[float], ground_truths: List[float]):
        """Draw a 2x3 figure summarising predictions and training history.

        Top row: predicted vs. true distance scatter with a dashed identity
        line, histogram of absolute errors, histogram of relative errors
        (only for positive ground truths). Bottom row, drawn only when
        ``self.training_history`` is non-empty: loss curves, reward curve
        and validation-accuracy curve.
        """
        def decorate(ax, xlabel, ylabel, title, with_legend=False):
            """Apply axis labels, title, optional legend and a light grid."""
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            if with_legend:
                ax.legend()
            ax.grid(True, alpha=0.3)

        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle('Model Evaluation Results', fontsize=16, fontweight='bold')

        scatter_ax, abs_err_ax, rel_err_ax = axes[0, 0], axes[0, 1], axes[0, 2]

        # Scatter plus the y = x reference line spanning the target range.
        scatter_ax.scatter(ground_truths, predictions, alpha=0.6, s=20)
        lowest, highest = min(ground_truths), max(ground_truths)
        scatter_ax.plot([lowest, highest], [lowest, highest], 'r--', linewidth=2)
        decorate(scatter_ax, 'Ground Truth Distance', 'Predicted Distance',
                 'Predictions vs Ground Truth')

        absolute_errors = [abs(pred - truth) for pred, truth in zip(predictions, ground_truths)]
        abs_err_ax.hist(absolute_errors, bins=50, alpha=0.7, color='skyblue', edgecolor='black')
        decorate(abs_err_ax, 'Absolute Error', 'Frequency', 'Error Distribution')

        # Samples with a non-positive ground truth are skipped here.
        relative_errors = [
            abs(pred - truth) / truth
            for pred, truth in zip(predictions, ground_truths)
            if truth > 0
        ]
        rel_err_ax.hist(relative_errors, bins=50, alpha=0.7, color='lightgreen', edgecolor='black')
        decorate(rel_err_ax, 'Relative Error', 'Frequency', 'Relative Error Distribution')

        history = self.training_history
        if history:
            loss_ax, reward_ax, accuracy_ax = axes[1, 0], axes[1, 1], axes[1, 2]

            loss_ax.plot(history['train_loss'], label='Train Loss', color='blue')
            loss_ax.plot(history['val_loss'], label='Val Loss', color='red')
            decorate(loss_ax, 'Epoch', 'Loss', 'Training History', with_legend=True)

            reward_ax.plot(history['train_reward'], label='Train Reward', color='green')
            decorate(reward_ax, 'Epoch', 'Reward', 'Reward History', with_legend=True)

            accuracy_ax.plot(history['val_accuracy'], label='Val Accuracy', color='purple')
            decorate(accuracy_ax, 'Epoch', 'Accuracy', 'Accuracy History', with_legend=True)

        plt.tight_layout()
        plt.show()
    
    def predict_path(self, adj_matrix: np.ndarray, start_node: int, end_node: int,
                    visualize: bool = True, return_path: bool = False) -> Dict:
        """
        MAIN PREDICTION FUNCTION - Predict shortest path distance and visualize
        
        Args:
            adj_matrix: Adjacency matrix of the graph
            start_node: Starting node index
            end_node: Ending node index
            visualize: Whether to plot the graph
            return_path: Whether to return the actual path
        
        Returns:
            Dictionary containing prediction results
        """
        self.logger.info(f"Predicting path from node {start_node} to node {end_node}")
        
        # Validate inputs
        if start_node < 0 or start_node >= len(adj_matrix):
            raise ValueError(f"Invalid start node: {start_node}")
        if end_node < 0 or end_node >= len(adj_matrix):
            raise ValueError(f"Invalid end node: {end_node}")
        
        # Prepare input data
        sample = {
            'adj_matrix': adj_matrix,
            'start': start_node,
            'end': end_node,
            'distance': 0,
            'path': []
        }
        
        data = self.prepare_torch_data(sample).to(self.config.device)
        
        # Model prediction
        self.model.eval()
        start_time = time.time()
        
        with torch.no_grad():
            prediction = self.model(
                data.x, data.edge_index, 
                torch.zeros(len(data.x), dtype=torch.long, device=self.config.device),
                data.start_idx, data.end_idx
            )
        
        prediction_time = time.time() - start_time
        predicted_distance = prediction.item()
        
        # Calculate true distance using Dijkstra
        true_distance, true_path = self.generator.dijkstra_algorithm(
            adj_matrix, start_node, end_node
        )
        
        # Calculate metrics
        error = abs(predicted_distance - true_distance)
        relative_error = error / true_distance if true_distance > 0 else float('inf')
        
        results = {
            'predicted_distance': predicted_distance,
            'true_distance': true_distance,
            'true_path': true_path,
            'absolute_error': error,
            'relative_error': relative_error,
            'prediction_time': prediction_time,
            'accuracy': relative_error < 0.1
        }
        
        # Print results
        print(f"\n{'='*60}")
        print(f"PATH PREDICTION RESULTS")
        print(f"{'='*60}")
        print(f"Start Node: {start_node}")
        print(f"End Node: {end_node}")
        print(f"Predicted Distance: {predicted_distance:.4f}")
        print(f"True Distance: {true_distance:.4f}")
        print(f"Absolute Error: {error:.4f}")
        print(f"Relative Error: {relative_error:.2%}")
        print(f"Prediction Time: {prediction_time:.4f}s")
        print(f"Accuracy (within 10%): {'✓' if results['accuracy'] else '✗'}")
        print(f"{'='*60}")
        
        # Visualize if requested
        if visualize:
            self.visualizer.visualize_graph_with_path(
                adj_matrix=adj_matrix,
                start_node=start_node,
                end_node=end_node,
                true_path=true_path,
                title=f"Path Prediction: {start_node} → {end_node}"
            )
        
        return results
    
    def save_model(self, filename: str):
        """Save model with metadata"""
        filepath = os.path.join(self.config.model_save_path, filename)
        
        save_dict = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.rl_trainer.optimizer.state_dict(),
            'config': self.config.__dict__,
            'metadata': self.model_metadata,
            'training_history': dict(self.training_history)
        }
        
        torch.save(save_dict, filepath)
        self.logger.info(f"Model saved to {filepath}")
    
    def load_model(self, filename: str):
        """Load model with metadata"""
        filepath = os.path.join(self.config.model_save_path, filename)
        
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Model file not found: {filepath}")
        
        checkpoint = torch.load(filepath, map_location=self.config.device)
        
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.rl_trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.training_history = defaultdict(list, checkpoint.get('training_history', {}))
        
        self.logger.info(f"Model loaded from {filepath}")
        return checkpoint.get('metadata', {})
    
    def save_training_history(self):
        """Save training history to JSON"""
        filepath = os.path.join(self.config.model_save_path, 'training_history.json')
        
        with open(filepath, 'w') as f:
            json.dump(dict(self.training_history), f, indent=2)
        
        self.logger.info(f"Training history saved to {filepath}")
    
    def get_model_info(self) -> Dict:
        """Get comprehensive model information"""
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        
        return {
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'model_size_mb': total_params * 4 / (1024 * 1024),  # Assuming float32
            'device': self.config.device,
            'metadata': self.model_metadata
        }


NameError: name 'Config' is not defined