In [6]:
import os
import logging
import torch
import pandas as pd
import numpy as np

def setup_logging():
    """Configure root logging: INFO level, timestamped lines, one stream handler."""
    log_format = '%(asctime)s [%(levelname)s] %(message)s'
    stream_handler = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        format=log_format,
        handlers=[stream_handler],
    )

class SimpleDataset:
    """Simplified dataset class to load the processed data.pt file.

    The file is expected to unpickle to a ``(data, slices)`` pair.
    ``data[key]`` is one sliceable sequence (presumably a tensor) holding
    every sample back to back, and ``slices[key]`` holds the boundaries, so
    sample ``i`` occupies ``data[key][slices[key][i]:slices[key][i + 1]]``.
    """

    class DataObject:
        """Plain attribute container for one sample (one attribute per key)."""

    def __init__(self, root='./processed_experimental/processed'):
        """Load ``data.pt`` from the directory ``root``."""
        data_path = os.path.join(root, 'data.pt')
        # weights_only=False keeps the full-unpickling behaviour this loader
        # has always relied on; newer torch releases warn about (and change)
        # the implicit default. Full unpickling can execute arbitrary code,
        # so only load files from a trusted source.
        self.data, self.slices = torch.load(data_path, weights_only=False)

    def __len__(self):
        """Return the number of samples (one fewer than the 'x' boundaries)."""
        return self.slices['x'].size(0) - 1

    def __getitem__(self, idx):
        """Return one sample, or a list of samples for a slice.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        if isinstance(idx, slice):
            return [self[i] for i in range(*idx.indices(len(self)))]

        size = len(self)
        if idx < 0:
            # Without this, idx=-1 would read slices[-1]:slices[0] and
            # silently yield empty data. Rebind rather than `+=` so a
            # caller-supplied tensor index is never mutated in place.
            idx = idx + size
        if not 0 <= idx < size:
            raise IndexError(f"index out of range for dataset of size {size}")

        sample = self.DataObject()
        for key, boundaries in self.slices.items():
            start, end = boundaries[idx].item(), boundaries[idx + 1].item()
            setattr(sample, key, self.data[key][start:end])

        return sample

def verify_mutual_information_bound(dataset, tolerance=1e-6):
    """
    Verify that mutual information is always lower than or equal to von Neumann entropy
    and analyze any violations.

    Each sample must expose ``mutual_info``, ``y`` (used as the von Neumann
    entropy), ``system_size``, ``nA`` and ``nB`` as single-element values
    supporting ``.item()``. A summary is written through ``logging``.

    Args:
        dataset: Sized, integer-indexable collection of samples. May be empty.
        tolerance: Slack added to the entropy before comparing, so that small
            numerical errors are not reported as violations.

    Returns:
        A ``(violations, size_stats)`` tuple. ``violations`` is a list of
        dicts with keys ``idx``, ``mi``, ``vn``, ``difference``,
        ``system_size``, ``nA`` and ``nB``. ``size_stats`` maps each system
        size to a dict with keys ``count``, ``violations``, ``max_ratio``,
        ``mean_ratio``, ``ratios`` and ``violation_rate`` (a percentage).
    """
    logging.info("Verifying Mutual Information bounds...")

    violations = []
    total_samples = len(dataset)

    # Collect statistics per system size
    size_stats = {}

    for idx in range(total_samples):
        data = dataset[idx]
        mi = data.mutual_info.item()
        vn = data.y.item()  # von Neumann entropy
        system_size = data.system_size.item()

        # Single source of truth for the check, shared by both tallies below.
        is_violation = mi > vn + tolerance

        if is_violation:
            violations.append({
                'idx': idx,
                'mi': mi,
                'vn': vn,
                'difference': mi - vn,
                'system_size': system_size,
                'nA': data.nA.item(),
                'nB': data.nB.item()
            })

        # Update size statistics
        if system_size not in size_stats:
            size_stats[system_size] = {
                'count': 0,
                'violations': 0,
                'max_ratio': 0,  # MI/VN ratio
                'mean_ratio': 0,
                'ratios': []
            }

        stats = size_stats[system_size]
        stats['count'] += 1
        if is_violation:
            stats['violations'] += 1

        # Calculate MI/VN ratio
        ratio = mi / (vn + 1e-10)  # avoid division by zero
        stats['ratios'].append(ratio)
        stats['max_ratio'] = max(stats['max_ratio'], ratio)

        # Print progress every 10k samples
        if (idx + 1) % 10000 == 0:
            logging.info(f"Processed {idx + 1}/{total_samples} samples...")

    # Compute final statistics per size (every entry has count >= 1)
    for stats in size_stats.values():
        stats['violation_rate'] = (stats['violations'] / stats['count']) * 100
        stats['mean_ratio'] = np.mean(stats['ratios'])

    # An empty dataset has no violations; avoid dividing by zero.
    if total_samples:
        overall_rate = (len(violations) / total_samples) * 100
    else:
        overall_rate = 0.0

    # Print summary
    logging.info("\nMutual Information Bound Analysis:")
    logging.info(f"Total samples analyzed: {total_samples}")
    logging.info(f"Number of violations: {len(violations)}")
    logging.info(f"Overall violation rate: {overall_rate:.4f}%")
    if violations:
        sorted_violations = sorted(violations, key=lambda v: v['difference'], reverse=True)
        max_violation = sorted_violations[0]['difference']
        violation_sizes = sorted({v['system_size'] for v in violations})

        logging.info(f"Maximum violation: {max_violation:.6f}")
        logging.info(f"System sizes with violations: {violation_sizes}")

        logging.info("\nTop 10 worst violations:")
        for v in sorted_violations[:10]:
            logging.info(
                f"  Size {v['system_size']:2} (nA={v['nA']:.0f}, nB={v['nB']:.0f}): "
                f"MI={v['mi']:.6f}, VN={v['vn']:.6f}, Diff={v['difference']:.6f}"
            )

    logging.info("\nStatistics per system size:")
    for size in sorted(size_stats):
        stats = size_stats[size]
        logging.info(f"Size {size:2}:")
        logging.info(f"  Samples: {stats['count']}")
        logging.info(f"  Violation rate: {stats['violation_rate']:.4f}%")
        logging.info(f"  Mean MI/VN ratio: {stats['mean_ratio']:.4f}")
        logging.info(f"  Max MI/VN ratio: {stats['max_ratio']:.4f}")

    return violations, size_stats

def main():
    """Load the dataset, run the MI bound check and export any violations.

    A dataset that fails to load is logged as an error and ends the run.
    Violations, if any, are written to ``mi_violations.csv``.
    """
    setup_logging()

    # Load dataset; any failure here aborts the run after being logged.
    try:
        dataset = SimpleDataset()
        logging.info(f"Loaded dataset with {len(dataset)} samples")
    except Exception as err:
        logging.error(f"Error loading dataset: {err}")
        return

    # Verify MI bounds (the per-size statistics are only logged, not used here)
    violations, _size_stats = verify_mutual_information_bound(dataset)

    if not violations:
        logging.info("No violations found - MI is a proper lower bound for VN entropy")
        return

    # Persist the offending samples for later inspection.
    pd.DataFrame(violations).to_csv('mi_violations.csv', index=False)
    logging.warning("Found violations! Details saved to mi_violations.csv")

# Entry point: run the analysis when this code is executed directly.
if __name__ == "__main__":
    main()

  self.data, self.slices = torch.load(data_path)
2025-01-19 03:13:39,732 [INFO] Loaded dataset with 5000 samples
2025-01-19 03:13:39,732 [INFO] Verifying Mutual Information bounds...
2025-01-19 03:13:40,257 [INFO] 
Mutual Information Bound Analysis:
2025-01-19 03:13:40,258 [INFO] Total samples analyzed: 5000
2025-01-19 03:13:40,258 [INFO] Number of violations: 518
2025-01-19 03:13:40,260 [INFO] Overall violation rate: 10.3600%
2025-01-19 03:13:40,260 [INFO] Maximum violation: 0.530045
2025-01-19 03:13:40,261 [INFO] System sizes with violations: [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0]
2025-01-19 03:13:40,261 [INFO] 
Top 10 worst violations:
2025-01-19 03:13:40,261 [INFO]   Size 12.0 (nA=7, nB=5): MI=0.973731, VN=0.443687, Diff=0.530045
2025-01-19 03:13:40,262 [INFO]   Size 2.0 (nA=1, nB=1): MI=0.698826, VN=0.655488, Diff=0.043338
2025-01-19 03:13:40,263 [INFO]   Size 2.0 (nA=1, nB=1): MI=0.698107, VN=0.660415, Diff=0.037692
2025-01-19 03:13:40,264 [INFO]   Size 2.0 (nA=1, nB=1): MI=0.043

In [10]:
import os
import logging
import torch
import pandas as pd
import numpy as np

def setup_logging():
    """Configure root logging: INFO level, timestamped lines, one stream handler."""
    log_format = '%(asctime)s [%(levelname)s] %(message)s'
    stream_handler = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        format=log_format,
        handlers=[stream_handler],
    )

class SimpleDataset:
    """Simplified dataset class to load the processed data.pt file.

    The file is expected to unpickle to a ``(data, slices)`` pair.
    ``data[key]`` is one sliceable sequence (presumably a tensor) holding
    every sample back to back, and ``slices[key]`` holds the boundaries, so
    sample ``i`` occupies ``data[key][slices[key][i]:slices[key][i + 1]]``.
    """

    class DataObject:
        """Plain attribute container for one sample (one attribute per key)."""

    def __init__(self, root='./processed_experimental/processed'):
        """Load ``data.pt`` from the directory ``root``."""
        data_path = os.path.join(root, 'data.pt')
        # weights_only=False keeps the full-unpickling behaviour this loader
        # has always relied on; newer torch releases warn about (and change)
        # the implicit default. Full unpickling can execute arbitrary code,
        # so only load files from a trusted source.
        self.data, self.slices = torch.load(data_path, weights_only=False)

    def __len__(self):
        """Return the number of samples (one fewer than the 'x' boundaries)."""
        return self.slices['x'].size(0) - 1

    def __getitem__(self, idx):
        """Return one sample, or a list of samples for a slice.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        if isinstance(idx, slice):
            return [self[i] for i in range(*idx.indices(len(self)))]

        size = len(self)
        if idx < 0:
            # Without this, idx=-1 would read slices[-1]:slices[0] and
            # silently yield empty data. Rebind rather than `+=` so a
            # caller-supplied tensor index is never mutated in place.
            idx = idx + size
        if not 0 <= idx < size:
            raise IndexError(f"index out of range for dataset of size {size}")

        sample = self.DataObject()
        for key, boundaries in self.slices.items():
            start, end = boundaries[idx].item(), boundaries[idx + 1].item()
            setattr(sample, key, self.data[key][start:end])

        return sample

def verify_mutual_information_bound(dataset, tolerance=1e-6):
    """
    Verify that mutual information is always lower than or equal to von Neumann entropy
    and analyze any violations.

    Each sample must expose ``mutual_info``, ``y`` (used as the von Neumann
    entropy), ``system_size``, ``nA`` and ``nB`` as single-element values
    supporting ``.item()``. A summary is written through ``logging``.

    Args:
        dataset: Sized, integer-indexable collection of samples. May be empty.
        tolerance: Slack added to the entropy before comparing, so that small
            numerical errors are not reported as violations.

    Returns:
        A ``(violations, size_stats)`` tuple. ``violations`` is a list of
        dicts with keys ``idx``, ``mi``, ``vn``, ``difference``,
        ``system_size``, ``nA`` and ``nB``. ``size_stats`` maps each system
        size to a dict with keys ``count``, ``violations``, ``max_ratio``,
        ``mean_ratio``, ``ratios`` and ``violation_rate`` (a percentage).
    """
    logging.info("Verifying Mutual Information bounds...")

    violations = []
    total_samples = len(dataset)

    # Collect statistics per system size
    size_stats = {}

    for idx in range(total_samples):
        data = dataset[idx]
        mi = data.mutual_info.item()
        vn = data.y.item()  # von Neumann entropy
        system_size = data.system_size.item()

        # Single source of truth for the check, shared by both tallies below.
        is_violation = mi > vn + tolerance

        if is_violation:
            violations.append({
                'idx': idx,
                'mi': mi,
                'vn': vn,
                'difference': mi - vn,
                'system_size': system_size,
                'nA': data.nA.item(),
                'nB': data.nB.item()
            })

        # Update size statistics
        if system_size not in size_stats:
            size_stats[system_size] = {
                'count': 0,
                'violations': 0,
                'max_ratio': 0,  # MI/VN ratio
                'mean_ratio': 0,
                'ratios': []
            }

        stats = size_stats[system_size]
        stats['count'] += 1
        if is_violation:
            stats['violations'] += 1

        # Calculate MI/VN ratio
        ratio = mi / (vn + 1e-10)  # avoid division by zero
        stats['ratios'].append(ratio)
        stats['max_ratio'] = max(stats['max_ratio'], ratio)

        # Print progress every 10k samples
        if (idx + 1) % 10000 == 0:
            logging.info(f"Processed {idx + 1}/{total_samples} samples...")

    # Compute final statistics per size (every entry has count >= 1)
    for stats in size_stats.values():
        stats['violation_rate'] = (stats['violations'] / stats['count']) * 100
        stats['mean_ratio'] = np.mean(stats['ratios'])

    # An empty dataset has no violations; avoid dividing by zero.
    if total_samples:
        overall_rate = (len(violations) / total_samples) * 100
    else:
        overall_rate = 0.0

    # Print summary
    logging.info("\nMutual Information Bound Analysis:")
    logging.info(f"Total samples analyzed: {total_samples}")
    logging.info(f"Number of violations: {len(violations)}")
    logging.info(f"Overall violation rate: {overall_rate:.4f}%")
    if violations:
        sorted_violations = sorted(violations, key=lambda v: v['difference'], reverse=True)
        max_violation = sorted_violations[0]['difference']
        violation_sizes = sorted({v['system_size'] for v in violations})

        logging.info(f"Maximum violation: {max_violation:.6f}")
        logging.info(f"System sizes with violations: {violation_sizes}")

        logging.info("\nTop 10 worst violations:")
        for v in sorted_violations[:10]:
            logging.info(
                f"  Size {v['system_size']:2} (nA={v['nA']:.0f}, nB={v['nB']:.0f}): "
                f"MI={v['mi']:.6f}, VN={v['vn']:.6f}, Diff={v['difference']:.6f}"
            )

    logging.info("\nStatistics per system size:")
    for size in sorted(size_stats):
        stats = size_stats[size]
        logging.info(f"Size {size:2}:")
        logging.info(f"  Samples: {stats['count']}")
        logging.info(f"  Violation rate: {stats['violation_rate']:.4f}%")
        logging.info(f"  Mean MI/VN ratio: {stats['mean_ratio']:.4f}")
        logging.info(f"  Max MI/VN ratio: {stats['max_ratio']:.4f}")

    return violations, size_stats

def main():
    """Load the dataset, run the MI bound check and export any violations.

    A dataset that fails to load is logged as an error and ends the run.
    Violations, if any, are written to ``mi_violations.csv``.
    """
    setup_logging()

    # Load dataset; any failure here aborts the run after being logged.
    try:
        dataset = SimpleDataset()
        logging.info(f"Loaded dataset with {len(dataset)} samples")
    except Exception as err:
        logging.error(f"Error loading dataset: {err}")
        return

    # Verify MI bounds (the per-size statistics are only logged, not used here)
    violations, _size_stats = verify_mutual_information_bound(dataset)

    if not violations:
        logging.info("No violations found - MI is a proper lower bound for VN entropy")
        return

    # Persist the offending samples for later inspection.
    pd.DataFrame(violations).to_csv('mi_violations.csv', index=False)
    logging.warning("Found violations! Details saved to mi_violations.csv")

# Entry point: run the analysis when this code is executed directly.
if __name__ == "__main__":
    main()

  self.data, self.slices = torch.load(data_path)
2025-01-19 03:22:31,655 [INFO] Loaded dataset with 5000 samples
2025-01-19 03:22:31,656 [INFO] Verifying Mutual Information bounds...
2025-01-19 03:22:32,164 [INFO] 
Mutual Information Bound Analysis:
2025-01-19 03:22:32,165 [INFO] Total samples analyzed: 5000
2025-01-19 03:22:32,165 [INFO] Number of violations: 518
2025-01-19 03:22:32,166 [INFO] Overall violation rate: 10.3600%
2025-01-19 03:22:32,166 [INFO] Maximum violation: 0.530045
2025-01-19 03:22:32,167 [INFO] System sizes with violations: [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0]
2025-01-19 03:22:32,167 [INFO] 
Top 10 worst violations:
2025-01-19 03:22:32,167 [INFO]   Size 12.0 (nA=7, nB=5): MI=0.973731, VN=0.443687, Diff=0.530045
2025-01-19 03:22:32,167 [INFO]   Size 2.0 (nA=1, nB=1): MI=0.698826, VN=0.655488, Diff=0.043338
2025-01-19 03:22:32,167 [INFO]   Size 2.0 (nA=1, nB=1): MI=0.698107, VN=0.660415, Diff=0.037692
2025-01-19 03:22:32,169 [INFO]   Size 2.0 (nA=1, nB=1): MI=0.043