In [13]:
"""
Comprehensive Framework Results Verifier

 Goal: Test all optimization results for Sidorenko violations!
"""

import torch
import itertools
import numpy as np
import pickle
import json
import csv
import os
import glob
import time
from pathlib import Path
from typing import List, Tuple, Dict, Optional, Union

class FrameworkResultsLoader:
    """Smart loader for results from the Sidorenko optimization framework."""
    
    def __init__(self, search_dirs: List[str] = None):
        """
        Initialize loader with search directories.
        
        Args:
            search_dirs: List of directories to search. If None, uses common locations.
        """
        if search_dirs is None:
            search_dirs = [
                ".",                    # Current directory
                "./models",             # Default model directory
                "./demo_models",        # Demo model directory  
                "./parallel_results",   # Parallel experiment results
                "./demo_parallel_results",  # Demo parallel results
                "./results",            # General results directory
                "./output",             # Output directory
                "./checkpoints",        # Checkpoint directory
            ]
        
        self.search_dirs = [Path(d) for d in search_dirs if os.path.exists(d)]
        print(f"🔍 Searching in {len(self.search_dirs)} directories:")
        for d in self.search_dirs:
            print(f"   📁 {d.absolute()}")
    
    def find_all_results(self) -> Dict[str, List[Path]]:
        """Find all optimization results across different formats."""
        results = {
            'pickle_matrices': [],      # .pkl files with matrices
            'model_checkpoints': [],    # .pth model files
            'csv_matrices': [],         # .csv matrix files
            'json_results': [],         # .json result files
            'individual_results': []    # Individual result files
        }
        
        for search_dir in self.search_dirs:
            if not search_dir.exists():
                continue
                
            # Find pickle files (from save_optimization_results)
            pickle_files = list(search_dir.glob("*matrices*.pkl"))
            pickle_files.extend(search_dir.glob("sidorenko_optimization_matrices_*.pkl"))
            results['pickle_matrices'].extend(pickle_files)
            
            # Find model checkpoints (.pth files)
            model_files = list(search_dir.glob("*.pth"))
            model_files.extend(search_dir.glob("*model*.pth"))
            results['model_checkpoints'].extend(model_files)
            
            # Find CSV matrices
            csv_files = list(search_dir.glob("*.csv"))
            csv_files.extend(search_dir.glob("W_optimized*.csv"))
            results['csv_matrices'].extend(csv_files)
            
            # Find JSON results
            json_files = list(search_dir.glob("*.json"))
            json_files.extend(search_dir.glob("experiment_summary.json"))
            results['json_results'].extend(json_files)
            
            # Find individual result files
            result_files = list(search_dir.glob("result_*.json"))
            results['individual_results'].extend(result_files)
        
        # Remove duplicates and sort by modification time (newest first)
        for key in results:
            results[key] = sorted(set(results[key]), 
                                key=lambda x: x.stat().st_mtime, 
                                reverse=True)
        
        # Print summary
        total_files = sum(len(files) for files in results.values())
        print(f"\n📊 Found {total_files} optimization result files:")
        for category, files in results.items():
            if files:
                print(f"   {category}: {len(files)} files")
        
        return results
    
    def load_matrices_from_pickle(self, pickle_path: Path) -> List[Dict]:
        """Load matrices from pickle files created by save_optimization_results."""
        try:
            with open(pickle_path, 'rb') as f:
                data = pickle.load(f)
            
            matrices = []
            if isinstance(data, list):
                # Format: [{'episode': int, 'matrix': ndarray, 'score': float}, ...]
                for item in data:
                    if isinstance(item, dict) and 'matrix' in item:
                        matrices.append({
                            'matrix': item['matrix'],
                            'score': item.get('score', 0),
                            'episode': item.get('episode', 0),
                            'source': f"pickle:{pickle_path.name}",
                            'source_path': str(pickle_path)
                        })
            elif isinstance(data, dict) and 'matrix' in data:
                # Single matrix format
                matrices.append({
                    'matrix': data['matrix'],
                    'score': data.get('score', 0),
                    'episode': data.get('episode', 0),
                    'source': f"pickle:{pickle_path.name}",
                    'source_path': str(pickle_path)
                })
            
            print(f"✅ Loaded {len(matrices)} matrices from {pickle_path.name}")
            return matrices
            
        except Exception as e:
            print(f"❌ Failed to load {pickle_path.name}: {e}")
            return []
    
    def load_matrices_from_checkpoint(self, checkpoint_path: Path) -> List[Dict]:
        """Load matrices from model checkpoint files."""
        try:
            # Handle PyTorch 2.6+ security restrictions
            try:
                # First try with weights_only=False for trusted checkpoints
                checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
            except Exception as e1:
                try:
                    # Try with safe globals for numpy arrays
                    import numpy as np
                    torch.serialization.add_safe_globals([np.core.multiarray.scalar, np.ndarray])
                    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
                except Exception as e2:
                    print(f"⚠️  PyTorch 2.6+ security restriction for {checkpoint_path.name}")
                    print(f"   Trying alternative loading method...")
                    # Last resort: load with explicit safety override
                    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
            
            matrices = []
            
            # Look for best_matrices in checkpoint
            if 'best_matrices' in checkpoint:
                best_matrices = checkpoint['best_matrices']
                for item in best_matrices:
                    if 'matrix' in item:
                        matrix = item['matrix']
                        # Convert to numpy if needed
                        if isinstance(matrix, torch.Tensor):
                            matrix = matrix.cpu().numpy()
                        
                        matrices.append({
                            'matrix': matrix,
                            'score': item.get('score', 0),
                            'episode': item.get('episode', 0),
                            'source': f"checkpoint:{checkpoint_path.name}",
                            'source_path': str(checkpoint_path)
                        })
            
            # Look for single matrix in checkpoint
            elif 'matrix' in checkpoint:
                matrix = checkpoint['matrix']
                if isinstance(matrix, torch.Tensor):
                    matrix = matrix.cpu().numpy()
                    
                matrices.append({
                    'matrix': matrix,
                    'score': checkpoint.get('score', 0),
                    'episode': checkpoint.get('episode', 0),
                    'source': f"checkpoint:{checkpoint_path.name}",
                    'source_path': str(checkpoint_path)
                })
            
            # Look for matrices in training_history or other locations
            elif 'training_history' in checkpoint:
                # Some models might store matrices in training history
                history = checkpoint['training_history']
                if isinstance(history, list):
                    for item in history:
                        if isinstance(item, dict) and 'matrix' in item:
                            matrix = item['matrix']
                            if isinstance(matrix, torch.Tensor):
                                matrix = matrix.cpu().numpy()
                            matrices.append({
                                'matrix': matrix,
                                'score': item.get('score', item.get('reward', 0)),
                                'episode': item.get('episode', 0),
                                'source': f"checkpoint:{checkpoint_path.name}",
                                'source_path': str(checkpoint_path)
                            })
            
            if matrices:
                print(f"✅ Loaded {len(matrices)} matrices from {checkpoint_path.name}")
            else:
                print(f"⚠️  No matrices found in {checkpoint_path.name}")
            return matrices
            
        except Exception as e:
            print(f"❌ Failed to load checkpoint {checkpoint_path.name}: {e}")
            return []
    
    def load_matrices_from_csv(self, csv_path: Path) -> List[Dict]:
        """Load matrices from CSV files."""
        try:
            # Read CSV data
            with open(csv_path, 'r') as f:
                reader = csv.reader(f)
                rows = [list(map(float, row)) for row in reader]
            
            if not rows:
                return []
            
            matrix = np.array(rows)
            
            # Verify it's a square matrix
            if matrix.shape[0] != matrix.shape[1]:
                print(f"⚠️  {csv_path.name} is not square: {matrix.shape}")
                return []
            
            matrices = [{
                'matrix': matrix,
                'score': 0,  # Unknown score from CSV
                'episode': 0,
                'source': f"csv:{csv_path.name}",
                'source_path': str(csv_path)
            }]
            
            print(f"✅ Loaded matrix {matrix.shape} from {csv_path.name}")
            return matrices
            
        except Exception as e:
            print(f"❌ Failed to load CSV {csv_path.name}: {e}")
            return []
    
    def load_matrices_from_json(self, json_path: Path) -> List[Dict]:
        """Load matrices from JSON result files."""
        try:
            with open(json_path, 'r') as f:
                data = json.load(f)
            
            matrices = []
            
            # Handle different JSON structures
            if isinstance(data, dict):
                # Check for experiment results structure
                if 'best_matrix' in data:
                    matrix = np.array(data['best_matrix'])
                    matrices.append({
                        'matrix': matrix,
                        'score': data.get('best_score', 0),
                        'episode': 0,
                        'source': f"json:{json_path.name}",
                        'source_path': str(json_path)
                    })
                
                # Check for nested structure (experiment_summary.json)
                elif isinstance(data, dict):
                    for family_name, family_data in data.items():
                        if isinstance(family_data, dict):
                            for graph_name, graph_data in family_data.items():
                                if isinstance(graph_data, dict):
                                    for config_name, result in graph_data.items():
                                        if isinstance(result, dict) and 'evaluation' in result:
                                            eval_data = result['evaluation']
                                            if 'best_matrix' in eval_data:
                                                matrix = np.array(eval_data['best_matrix'])
                                                matrices.append({
                                                    'matrix': matrix,
                                                    'score': result.get('best_score', 0),
                                                    'episode': 0,
                                                    'source': f"json:{json_path.name}",
                                                    'source_path': str(json_path),
                                                    'family': family_name,
                                                    'graph': graph_name,
                                                    'config': config_name
                                                })
            
            if matrices:
                print(f"✅ Loaded {len(matrices)} matrices from {json_path.name}")
            return matrices
            
        except Exception as e:
            print(f"❌ Failed to load JSON {json_path.name}: {e}")
            return []
    
    def load_all_matrices(self) -> List[Dict]:
        """Load all matrices from all found result files."""
        print("\n🔄 Loading all optimization results...")
        
        all_files = self.find_all_results()
        all_matrices = []
        
        # Load from pickle files
        for pickle_file in all_files['pickle_matrices']:
            matrices = self.load_matrices_from_pickle(pickle_file)
            all_matrices.extend(matrices)
        
        # Load from model checkpoints
        for checkpoint_file in all_files['model_checkpoints']:
            matrices = self.load_matrices_from_checkpoint(checkpoint_file)
            all_matrices.extend(matrices)
        
        # Load from CSV files
        for csv_file in all_files['csv_matrices']:
            matrices = self.load_matrices_from_csv(csv_file)
            all_matrices.extend(matrices)
        
        # Load from JSON files
        for json_file in all_files['json_results']:
            matrices = self.load_matrices_from_json(json_file)
            all_matrices.extend(matrices)
        
        # Load from individual result files
        for result_file in all_files['individual_results']:
            matrices = self.load_matrices_from_json(result_file)
            all_matrices.extend(matrices)
        
        print(f"\n📊 Total matrices loaded: {len(all_matrices)}")
        return all_matrices

class SidorenkoVerifier:
    """Optimized Sidorenko conjecture verifier for Möbius ladder."""
    
    def __init__(self, device):
        self.device = device
        
        # Möbius ladder K_{5,5} \ C_{10} structure
        self.left_neighbors = {
            0: (0, 1, 4),  # L0 → R0, R1, R4
            1: (0, 1, 2),  # L1 → R0, R1, R2  
            2: (1, 2, 3),  # L2 → R1, R2, R3
            3: (2, 3, 4),  # L3 → R2, R3, R4
            4: (3, 4, 0),  # L4 → R3, R4, R0
        }
    
    def precompute_left_tables(self, M):
        """Precompute S_i tables for efficient homomorphism counting."""
        n = M.shape[0]
        S_tables = {}
        
        for left_idx, (a, b, c) in self.left_neighbors.items():
            S_table = torch.zeros((n, n, n), device=self.device, dtype=M.dtype)
            
            for x_a in range(n):
                for x_b in range(n):
                    for x_c in range(n):
                        contribution = torch.sum(M[:, x_a] * M[:, x_b] * M[:, x_c])
                        S_table[x_a, x_b, x_c] = contribution
            
            S_tables[left_idx] = S_table
            
        return S_tables
    
    def compute_homomorphism_exact(self, M):
        """Compute exact homomorphism count using optimized algorithm."""
        n = M.shape[0]
        M = M.to(self.device)
        
        # Precompute left-side contribution tables
        S_tables = self.precompute_left_tables(M)
        
        # Generate all 6^5 right-side assignments
        all_right_assignments = list(itertools.product(range(n), repeat=5))
        
        # Convert to tensor for vectorized processing
        assignments = torch.tensor(all_right_assignments, device=self.device, dtype=torch.long)
        r0, r1, r2, r3, r4 = assignments.T
        
        # Vectorized lookup of S_i values
        S0 = S_tables[0][r0, r1, r4]  # L0 neighbors: (R0, R1, R4)
        S1 = S_tables[1][r0, r1, r2]  # L1 neighbors: (R0, R1, R2)
        S2 = S_tables[2][r1, r2, r3]  # L2 neighbors: (R1, R2, R3)
        S3 = S_tables[3][r2, r3, r4]  # L3 neighbors: (R2, R3, R4)
        S4 = S_tables[4][r3, r4, r0]  # L4 neighbors: (R3, R4, R0)
        
        # Product of all S_i for each right assignment
        products = S0 * S1 * S2 * S3 * S4
        
        # Total homomorphism count
        hom_count = torch.sum(products).item()
        
        return hom_count
    
    def verify_matrix(self, matrix_data: Dict, verbose: bool = False) -> Dict:
        """Verify a single matrix for Sidorenko violation."""
        matrix = matrix_data['matrix']
        
        # Convert to tensor
        if isinstance(matrix, np.ndarray):
            M = torch.tensor(matrix, dtype=torch.float32).to(self.device)
        elif isinstance(matrix, torch.Tensor):
            M = matrix.float().to(self.device)
        else:
            raise ValueError(f"Unsupported matrix type: {type(matrix)}")
        
        # Ensure square matrix
        if M.shape[0] != M.shape[1]:
            raise ValueError(f"Matrix must be square, got {M.shape}")
        
        n = M.shape[0]
        if n != 6:
            if verbose:
                print(f"⚠️  Matrix size {n}×{n} (expected 6×6)")
        
        start_time = time.time()
        
        # Matrix properties
        matrix_sum = torch.sum(M).item()
        matrix_mean = torch.mean(M).item()
        
        # Compute homomorphism count
        hom_count = self.compute_homomorphism_exact(M)
        
        # Normalized density
        t_value = hom_count / (n ** 10)
        
        # Sidorenko threshold
        p = matrix_mean
        threshold = p ** 15
        
        # Check violation
        gap = t_value - threshold
        violation = gap < 0
        
        computation_time = time.time() - start_time
        
        result = {
            'source': matrix_data.get('source', 'unknown'),
            'source_path': matrix_data.get('source_path', ''),
            'original_score': matrix_data.get('score', 0),
            'episode': matrix_data.get('episode', 0),
            'matrix_shape': tuple(M.shape),
            'matrix_sum': matrix_sum,
            'matrix_mean': matrix_mean,
            'homomorphism_count': hom_count,
            'normalized_density': t_value,
            'edge_density': p,
            'threshold': threshold,
            'gap': gap,
            'violation': violation,
            'computation_time': computation_time
        }
        
        # Add family/graph/config info if available
        for key in ['family', 'graph', 'config']:
            if key in matrix_data:
                result[key] = matrix_data[key]
        
        return result

def setup_device():
    """Setup optimal device for computation."""
    if torch.backends.mps.is_available():
        device = torch.device("mps")
        print("🚀 Using Apple M4 Max GPU with Metal Performance Shaders (MPS)")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
        print("🚀 Using CUDA GPU")
    else:
        device = torch.device("cpu")
        print("💻 Using CPU")
    return device

def main():
    """Main verification function for all framework results."""
    print("🌟 COMPREHENSIVE FRAMEWORK RESULTS VERIFICATION")
    print("="*60)
    
    # Setup
    device = setup_device()
    verifier = SidorenkoVerifier(device)
    
    # Load all results
    loader = FrameworkResultsLoader()
    all_matrices = loader.load_all_matrices()
    
    if not all_matrices:
        print("❌ No matrices found! Make sure you've run the optimization framework.")
        print("\n💡 Expected file locations:")
        print("   - *.pkl files (from save_optimization_results)")
        print("   - *.pth files (model checkpoints)")  
        print("   - W_optimized.csv (from AMCS)")
        print("   - experiment_summary.json (from parallel experiments)")
        return {}
    
    print(f"\n🧮 Verifying {len(all_matrices)} matrices...")
    print("="*50)
    
    # Verify all matrices
    all_results = []
    violations_found = []
    near_violations = []
    
    for i, matrix_data in enumerate(all_matrices):
        try:
            result = verifier.verify_matrix(matrix_data, verbose=False)
            all_results.append(result)
            
            # Categorize results
            if result['violation']:
                violations_found.append(result)
                print(f"🚨 VIOLATION #{len(violations_found)}: {result['source']} "
                      f"(gap: {result['gap']:.2e})")
            elif abs(result['gap']) < 1e-5:
                near_violations.append(result)
                print(f"🎯 Near-violation: {result['source']} "
                      f"(gap: {result['gap']:.2e})")
            else:
                print(f"✅ Matrix {i+1:3d}: {result['source']} "
                      f"(gap: {result['gap']:+.2e})")
                
        except Exception as e:
            print(f"❌ Failed to verify matrix {i+1}: {e}")
    
    # Summary analysis
    print(f"\n{'='*60}")
    print("📊 COMPREHENSIVE VERIFICATION SUMMARY")
    print("="*60)
    
    print(f"Total matrices tested: {len(all_results)}")
    print(f"Violations found: {len(violations_found)}")
    print(f"Near-violations (<1e-5): {len(near_violations)}")
    print(f"Regular matrices: {len(all_results) - len(violations_found) - len(near_violations)}")
    
    if violations_found:
        print(f"\n🎉🎉🎉 BREAKTHROUGH! 🎉🎉🎉")
        print(f"Found {len(violations_found)} Sidorenko violations!")
        print(f"\nViolation Details:")
        for i, result in enumerate(violations_found, 1):
            print(f"  {i}. Source: {result['source']}")
            print(f"     Gap: {result['gap']:.6e}")
            print(f"     Density: {result['normalized_density']:.10f}")
            print(f"     Threshold: {result['threshold']:.10f}")
        
        print(f"\n📝 PUBLICATION READY!")
        print(f"These are counterexamples to Sidorenko's conjecture!")
        
    elif near_violations:
        print(f"\n🎯 VERY PROMISING RESULTS!")
        print(f"Found {len(near_violations)} near-violations!")
        
        # Find the closest one
        closest = min(near_violations, key=lambda x: abs(x['gap']))
        print(f"\nClosest to violation:")
        print(f"  Source: {closest['source']}")
        print(f"  Gap: {closest['gap']:.2e}")
        print(f"  Distance to violation: {abs(closest['gap']):.2e}")
        
        print(f"\n💡 Recommendations:")
        print(f"  - Continue optimization from best results")
        print(f"  - Try different random seeds")
        print(f"  - Increase training episodes")
        print(f"  - Use higher precision for matrices near boundary")
        
    else:
        print(f"\n✅ All matrices satisfy Sidorenko's conjecture")
        print(f"Continue optimization to search for violations!")
    
    # Performance stats
    total_time = sum(r['computation_time'] for r in all_results)
    avg_time = total_time / len(all_results) if all_results else 0
    
    print(f"\n⚡ Performance Statistics:")
    print(f"  Total verification time: {total_time:.2f}s")
    print(f"  Average per matrix: {avg_time:.4f}s")
    print(f"  Throughput: {len(all_results)/total_time:.1f} matrices/sec")
    
    # Save detailed results
    results_summary = {
        'total_matrices': len(all_results),
        'violations_found': len(violations_found),
        'near_violations': len(near_violations),
        'violation_details': violations_found,
        'near_violation_details': near_violations,
        'all_results': all_results
    }
    
    # Save to JSON
    with open('sidorenko_verification_results.json', 'w') as f:
        json.dump(results_summary, f, indent=2, default=str)
    
    print(f"\n💾 Detailed results saved to: sidorenko_verification_results.json")
    
    return results_summary

if __name__ == "__main__":
    results = main()
    
    print(f"\n🏁 Verification complete!")
    if results.get('violations_found', 0) > 0:
        print(f"🎉 MATHEMATICAL BREAKTHROUGH ACHIEVED!")
    else:
        print(f"🔍 Keep optimizing - you're making great progress!")

🌟 COMPREHENSIVE FRAMEWORK RESULTS VERIFICATION
🚀 Using Apple M4 Max GPU with Metal Performance Shaders (MPS)
🔍 Searching in 1 directories:
   📁 /Users/aburyan/Desktop/VS/rl-env/Raymond_ACMS

🔄 Loading all optimization results...

📊 Found 3 optimization result files:
   csv_matrices: 1 files
   json_results: 2 files
✅ Loaded matrix (6, 6) from EXACT_VERIFIED_VIOLATION_1_1751047683.csv

📊 Total matrices loaded: 1

🧮 Verifying 1 matrices...
🚨 VIOLATION #1: csv:EXACT_VERIFIED_VIOLATION_1_1751047683.csv (gap: -4.04e-09)

📊 COMPREHENSIVE VERIFICATION SUMMARY
Total matrices tested: 1
Violations found: 1
Near-violations (<1e-5): 0
Regular matrices: 0

🎉🎉🎉 BREAKTHROUGH! 🎉🎉🎉
Found 1 Sidorenko violations!

Violation Details:
  1. Source: csv:EXACT_VERIFIED_VIOLATION_1_1751047683.csv
     Gap: -4.039570e-09
     Density: 1.0000035722
     Threshold: 1.0000035763

📝 PUBLICATION READY!
These are counterexamples to Sidorenko's conjecture!

⚡ Performance Statistics:
  Total verification time: 0.08s
  