# In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy.ndimage import label
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import json
from collections import defaultdict
from matplotlib.backends.backend_pdf import PdfPages

# Default upper bound on the number of recursive coarse-graining passes
# performed by PercolationModel.forward (its `max_steps` default).
MAX_STEPS = 10

def generate_percolation_lattice(size, p):
    """Sample a random ``size`` x ``size`` site lattice.

    Every site is drawn independently: it is occupied (1) with probability
    ``p`` and empty (0) with probability ``1 - p``.

    Returns:
        A ``uint8`` NumPy array of shape ``(size, size)`` holding 0s and 1s.
    """
    site_states = [0, 1]
    state_probs = [1 - p, p]
    lattice = np.random.choice(site_states, (size, size), p=state_probs)
    return lattice.astype(np.uint8)

def check_percolation(lattice):
    """Report whether an occupied cluster spans the lattice.

    Connected clusters of non-zero sites are labelled with
    ``scipy.ndimage.label`` (default connectivity). The lattice percolates
    when one cluster touches both the first and last row, or both the first
    and last column.

    Returns:
        ``1.0`` if a spanning cluster exists, otherwise ``0.0``.
    """
    clusters, _ = label(lattice)

    def shares_cluster(edge_a, edge_b):
        # Label 0 is background, so it never counts as a shared cluster.
        common = np.intersect1d(edge_a, edge_b)
        return bool(np.any(common != 0))

    spans_vertically = shares_cluster(clusters[0], clusters[-1])
    spans_horizontally = shares_cluster(clusters[:, 0], clusters[:, -1])
    return float(spans_vertically or spans_horizontally)

def first_coarse_graining(binary_lattice, dim):
    """Coarse-grain a 2-D lattice by averaging dim x dim tiles.

    The lattice is cut into non-overlapping ``dim`` x ``dim`` tiles and every
    tile is replaced by the mean of its sites.

    Returns:
        A float32 tensor of shape ``[1, H // dim, W // dim]``.
    """
    grid = torch.tensor(binary_lattice, dtype=torch.float32)[None, None]  # [1, 1, H, W]
    tiles = F.unfold(grid, kernel_size=dim, stride=dim)                   # [1, dim*dim, n_tiles]
    tile_means = tiles.mean(dim=1)                                        # [1, n_tiles]
    n_rows = binary_lattice.shape[0] // dim
    n_cols = binary_lattice.shape[1] // dim
    # Tiles come out in row-major order, so they fold straight back to a grid.
    return tile_means.reshape(1, n_rows, n_cols)

class PercolationModel(nn.Module):
    """Learned coarse-graining rule applied recursively to a lattice.

    ``rule`` is a small MLP (Linear -> ReLU -> Linear -> Sigmoid) that maps a
    flattened ``dim`` x ``dim`` patch to one value. ``forward`` applies it to
    every non-overlapping patch, shrinking the lattice by a factor of
    ``dim`` per pass.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        patch_area = dim * dim
        self.rule = nn.Sequential(
            nn.Linear(patch_area, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, max_steps=MAX_STEPS):
        """Repeatedly coarse-grain ``x`` with the learned rule.

        Args:
            x: tensor of shape ``[b, c, H, W]``; the reshapes below assume a
                single channel (``c == 1``).
            max_steps: maximum number of coarse-graining passes.

        Returns:
            Tensor of shape ``[b, h * w]`` where ``h`` x ``w`` is the size of
            the lattice when the loop stops; ``[b, 1]`` once the lattice has
            been reduced to a single site.
        """
        k = self.dim
        batch, _, height, width = x.shape
        for _ in range(max_steps):
            # Stop once a full k x k patch no longer fits.
            if min(height, width) < k:
                break
            columns = F.unfold(x, kernel_size=k, stride=k)       # [b, k*k, n_patches]
            flat = columns.transpose(1, 2).reshape(-1, k * k)    # [b*n_patches, k*k]
            height, width = height // k, width // k
            x = self.rule(flat).view(batch, 1, height, width)
        return x.squeeze(1).view(batch, -1)

def prepare_dataset(N, sizes):
    """Build a list of ``(lattice, percolation_label)`` samples.

    Two halves of ``int(N / 2)`` samples each are generated: the first with
    the occupation probability drawn uniformly from [0.1, 0.9], the second
    from the narrower band [0.55, 0.65]. The lattice side of every sample is
    picked at random from ``sizes``.
    """
    half = int(N / 2)
    # (progress-bar caption, occupation-probability bounds) for each half.
    phases = (
        ("Generating data", (0.1, 0.9)),
        ("Generating fractal data", (0.55, 0.65)),
    )

    samples = []
    for caption, (p_low, p_high) in phases:
        for _ in tqdm(range(half), desc=caption):
            prob = np.random.uniform(p_low, p_high)
            side = np.random.choice(sizes)
            lattice = generate_percolation_lattice(side, prob)
            samples.append((lattice, check_percolation(lattice)))
    return samples

def train_epoch(model, device, data, batch_size, opt, crit, dim):
    """Run one training pass over ``data`` and return the mean loss.

    ``data`` is a sequence of ``(lattice, label)`` pairs. Each mini-batch is
    coarse-grained once with ``first_coarse_graining`` and then split into
    buckets of equal spatial size, because only same-sized lattices can be
    stacked into one tensor. One optimizer step is taken per bucket.

    Returns:
        The sample-weighted sum of bucket losses divided by ``len(data)``.
    """
    model.train()
    running_loss = 0.0

    for start in tqdm(range(0, len(data), batch_size), desc="Training"):
        # Coarse-grain and bucket by (height, width) in a single pass.
        buckets = {}
        for lattice, target in data[start:start + batch_size]:
            coarse = first_coarse_graining(lattice, dim)  # [1, H', W']
            shape_key = (coarse.shape[-2], coarse.shape[-1])
            buckets.setdefault(shape_key, []).append((coarse, target))

        batch_loss = 0.0
        for members in buckets.values():
            coarse_list = [coarse for coarse, _ in members]
            target_list = [target for _, target in members]
            inputs = torch.stack(coarse_list).to(device)  # [B, 1, H', W']
            targets = torch.tensor(target_list, dtype=torch.float32, device=device)

            opt.zero_grad()
            loss = crit(model(inputs).view(-1), targets)
            loss.backward()
            opt.step()
            # Weight by bucket size so the final average is per sample.
            batch_loss += loss.item() * len(members)

        running_loss += batch_loss

    return running_loss / len(data)

def test_systems(model, dim, power, device='cpu',
                 num_tests=50, system_size='standard',
                 p_range=(0,1), verbose=True):
    """Evaluate ``model`` on freshly generated square lattices.

    For each of ``num_tests`` trials:
      1) draw p uniformly from ``p_range`` and generate a raw L x L lattice,
         with L = dim ** size_power
      2) compute the true percolation label on that raw lattice
      3) manually coarse-grain once (patch size = dim)
      4) feed the result into the model (which performs the remaining
         recursive steps) and record its scalar prediction

    Args:
        model: network to evaluate; moved to ``device`` and put in eval mode.
        dim: patch side length and base of the lattice side.
        power: exponent used for the ``'standard'`` / ``'{dim}^3'`` size.
        device: device on which inference runs.
        num_tests: number of lattices to generate.
        system_size: ``'standard'`` (L = dim ** power) or one of
            ``'{dim}^2'`` ... ``'{dim}^7'``, which select the exponents
            ``power - 1`` ... ``power + 4``.
        p_range: (low, high) bounds of the occupation probability.
        verbose: print a short summary when true.

    Returns:
        dict with ``'accuracy'`` (0.5 threshold), ``'avg_pred_perc'`` and
        ``'avg_pred_non_perc'``; each falls back to 0 when there is nothing
        to average.

    Raises:
        KeyError: if ``system_size`` is not a recognised label.
    """
    model.to(device).eval()

    # Dimension-dependent mapping from size labels to exponents:
    # '{dim}^2' -> power - 1, '{dim}^3' -> power, ..., '{dim}^7' -> power + 4.
    mapping = {f'{dim}^{k}': power + (k - 3) for k in range(2, 8)}
    # The default label must resolve too; without this entry a call that
    # relies on the default `system_size` fails with KeyError.
    mapping['standard'] = power
    try:
        size_power = mapping[system_size]
    except KeyError:
        raise KeyError(
            f"Unknown system_size {system_size!r}; "
            f"expected one of {sorted(mapping)}"
        ) from None
    L = dim ** size_power

    # Only (label, prediction) pairs are kept: the raw lattices are never
    # read again, and holding num_tests of them wastes memory for large L.
    outcomes = []
    for _ in tqdm(range(num_tests), desc=f"Testing {L}×{L}"):
        # 1) Raw lattice + label
        p = np.random.uniform(*p_range)
        raw = generate_percolation_lattice(L, p)
        lbl = check_percolation(raw)

        # 2) Manual first coarse-graining -> tensor [1, L/dim, L/dim]
        coarse = first_coarse_graining(raw, dim)

        # 3) Add the batch axis -> [1, 1, L/dim, L/dim]
        inp = coarse.unsqueeze(0).to(device)

        # 4) Network prediction; `.item()` requires the model to have reduced
        #    the lattice to a single value.
        with torch.no_grad():
            out = model(inp).view(-1).item()

        outcomes.append((lbl, out))

    # Accuracy at the 0.5 threshold; guard the empty case like the averages.
    n_done = len(outcomes)
    hits = sum((pred > 0.5) == lbl for lbl, pred in outcomes)
    acc = hits / n_done if n_done else 0.0
    pos = [pred for lbl, pred in outcomes if lbl == 1]
    neg = [pred for lbl, pred in outcomes if lbl == 0]

    metrics = {
        'accuracy': acc,
        'avg_pred_perc': np.mean(pos) if pos else 0,
        'avg_pred_non_perc': np.mean(neg) if neg else 0
    }

    if verbose:
        print(f"\nAfter manual first coarse-grain -> NN cascade on {L}×{L}:")
        print(f" Accuracy        : {acc:.2%}")
        print(f" Avg pred | Perc     : {metrics['avg_pred_perc']:.3f}")
        print(f" Avg pred | Non-Perc : {metrics['avg_pred_non_perc']:.3f}")

    return metrics

def visualize_rule(model, dim, device='cpu', num_samples=1000):
    """Probe ``model.rule`` on uniform patches and locate its fixed points.

    For 100 densities p in [0, 1], ``model.rule`` is evaluated on
    ``num_samples`` patches whose ``dim * dim`` entries all equal p, and the
    outputs are averaged. The resulting curve is compared with the bisector
    y = x.

    Returns:
        A tuple ``(p_values, mean_outputs, p_c_model, crossings)`` where
        ``crossings`` are the linearly interpolated intersections with the
        bisector whose left grid point lies in [0.1, 0.9], and ``p_c_model``
        is the grid point in [0.1, 0.9] closest to the bisector.
    """
    p_values = np.linspace(0, 1, 100)
    model.eval()
    patch_len = dim * dim

    with torch.no_grad():
        mean_outputs = np.array([
            model.rule(
                torch.full((num_samples, patch_len), p,
                           dtype=torch.float32, device=device)
            ).cpu().numpy().mean()
            for p in p_values
        ])

    # Signed distance of the rule curve from the bisector y = x.
    gap = mean_outputs - p_values

    crossings = []
    neighbours = zip(p_values[:-1], p_values[1:], gap[:-1], gap[1:])
    for left, right, gap_left, gap_right in neighbours:
        # Ignore the regions near 0 and 1.
        if not (0.1 <= left <= 0.9):
            continue
        # A crossing needs a sign change (or a zero) between the two points.
        if not (gap_left * gap_right <= 0):
            continue
        # Equal gaps would make the interpolation divide by zero.
        if gap_left == gap_right:
            continue
        crossings.append(left - gap_left * (right - left) / (gap_right - gap_left))

    # Critical point: closest approach to the bisector inside [0.1, 0.9].
    in_window = (p_values >= 0.1) & (p_values <= 0.9)
    window_p = p_values[in_window]
    if len(window_p) > 0:
        window_gap = np.abs(mean_outputs[in_window] - window_p)
        p_c_model = window_p[np.argmin(window_gap)]
    else:
        p_c_model = p_values[np.argmin(np.abs(mean_outputs - p_values))]

    return p_values, mean_outputs, p_c_model, crossings

def run_experiment(dim, power, num_runs=50, device='cpu'):
    """Train and evaluate ``num_runs`` independent models for one ``dim``.

    Each run generates 10,000 mixed-size training lattices (sides dim**2,
    dim**3, dim**4), trains a fresh ``PercolationModel`` for 6 epochs,
    evaluates it on four system sizes with p in (0.55, 0.65), and probes the
    learned rule with ``visualize_rule``.

    Returns:
        ``(rule_curves, test_results, pc_values)``: the per-run
        ``visualize_rule`` outputs, a mapping from
        ``"<size>_<p_low>-<p_high>"`` to the list of per-run metric dicts,
        and the per-run critical-point estimates.
    """
    rule_curves = []
    test_results = defaultdict(list)
    pc_values = []

    banner = '=' * 40
    lattice_sides = [dim ** exponent for exponent in (2, 3, 4)]
    p_low, p_high = 0.55, 0.65

    for run_idx in range(num_runs):
        print(f"\n{banner}")
        print(f"Run {run_idx+1}/{num_runs} for DIM={dim}")
        print(banner)

        # Fresh data, model and optimizer for every run.
        train_data = prepare_dataset(10_000, lattice_sides)
        model = PercolationModel(dim).to(device)
        opt = optim.Adam(model.parameters(), lr=1e-3)
        crit = nn.BCELoss()

        for epoch in range(1, 7):
            loss = train_epoch(model, device, train_data, 10, opt, crit, dim)
            print(f"Epoch {epoch} — Loss: {loss:.4f}")

        # Evaluate on dim^2 ... dim^5 (dimension-dependent size labels).
        for exponent in (2, 3, 4, 5):
            size_label = f'{dim}^{exponent}'
            key = f"{size_label}_{p_low}-{p_high}"
            metrics = test_systems(
                model, dim, power, device,
                num_tests=100,
                system_size=size_label,
                p_range=(p_low, p_high),
                verbose=False
            )
            test_results[key].append(metrics)

        # Record the learned rule curve and its critical-point estimate.
        curve = visualize_rule(model, dim, device)
        p_vals, outputs, p_c_model, crossings = curve
        rule_curves.append((p_vals, outputs, p_c_model, crossings))
        pc_values.append(p_c_model)

        # Drop per-run objects before the next run to save memory.
        del model, opt, crit, train_data
        torch.cuda.empty_cache()

    return rule_curves, test_results, pc_values

def save_consolidated_results(results_dict):
    """Write all experiment results to one PDF and one UTF-8 text file.

    ``results_dict`` maps each ``dim`` to the
    ``(rule_curves, test_results, pc_values)`` tuple produced by
    ``run_experiment``.

    * ``consolidated_results.pdf`` gets one page per ``dim`` showing every
      run's rule curve, its bisector crossings and the average critical point.
    * ``consolidated_results.txt`` gets the per-run critical points and the
      per-configuration test metrics with summary statistics.
    """
    with PdfPages("consolidated_results.pdf") as pdf:
        for dim, (rule_curves, _tests, pc_values) in results_dict.items():
            plt.figure(figsize=(10, 6))

            for p_vals, outputs, _pc, crossings in rule_curves:
                # One translucent curve per run.
                plt.plot(p_vals, outputs, alpha=0.5, color='blue')

                # Mark bisector intersections, ignoring those near 0 and 1.
                for cross in crossings:
                    if 0.1 <= cross <= 0.9:
                        plt.scatter(cross, cross, color='black', s=30, zorder=3)

            # Dashed bisector y = x.
            plt.plot([0, 1], [0, 1], 'k--', label=r'$f(p) = p$', linewidth=1.5)

            # Invisible artist that only contributes a legend entry.
            avg_pc = np.mean(pc_values)
            plt.plot([], [], ' ', label=f'Avg Model $p_c$ = {avg_pc:.3f}')

            plt.xlabel('Density (p)', fontsize=14)
            plt.ylabel(r'$f_{\theta}(p\mathbf{1})$', fontsize=14)
            plt.title(f"Rule Projection - Mixed size - DIM={dim}", fontsize=18)
            plt.legend(fontsize=12)
            plt.grid(True, alpha=0.3)
            plt.xlim(0, 1)
            plt.ylim(0, 1)
            pdf.savefig(bbox_inches='tight')
            plt.close()

    section_rule = '=' * 40
    with open("consolidated_results.txt", "w", encoding="utf-8") as f:
        for dim, (_curves, test_results, pc_values) in results_dict.items():
            # --- critical points ---
            f.write(f"{section_rule}\n")
            f.write(f"Critical Point Analysis (DIM={dim})\n")
            f.write(f"{section_rule}\n\n")
            f.write("p_c values from each run:\n")
            for run_no, pc in enumerate(pc_values, start=1):
                f.write(f"Run {run_no}: {pc:.6f}\n")

            f.write(f"\nAverage p_c: {np.mean(pc_values):.6f}\n")
            f.write(f"Standard deviation: {np.std(pc_values):.6f}\n\n")

            # --- test metrics ---
            f.write(f"{section_rule}\n")
            f.write(f"Test Performance Metrics (DIM={dim})\n")
            f.write(f"{section_rule}\n\n")

            for config, results_list in test_results.items():
                # Keys look like "<system_size>_<p_low>-<p_high>".
                pieces = config.split('_')
                system_size, p_range = pieces[0], pieces[1]

                accuracies = [r['accuracy'] for r in results_list]
                avg_perc = [r['avg_pred_perc'] for r in results_list]
                avg_non_perc = [r['avg_pred_non_perc'] for r in results_list]

                f.write(f"System: {system_size}, p-range: {p_range}\n")
                f.write("-" * 50 + "\n")

                # Per-run table.
                f.write("Run | Accuracy | Avg Perc | Avg Non-Perc\n")
                f.write("----|----------|----------|------------\n")
                rows = zip(accuracies, avg_perc, avg_non_perc)
                for run_no, (acc, perc, non_perc) in enumerate(rows, start=1):
                    f.write(f"{run_no:3d} | {acc:.4f} | {perc:.4f} | {non_perc:.4f}\n")

                # Mean ± standard deviation across runs.
                f.write("\nSummary Statistics:\n")
                summaries = (
                    ("Accuracy", accuracies),
                    ("Avg Perc", avg_perc),
                    ("Avg Non-Perc", avg_non_perc),
                )
                for title, values in summaries:
                    f.write(f"{title}: {np.mean(values):.4f} ± {np.std(values):.4f}\n")
                f.write("=" * 50 + "\n\n")
            f.write("\n\n")  # Space between dimensions

def main():
    """Run the full experiment for DIM = 3, 4 and 5 and save the results.

    Uses CUDA when available, otherwise the CPU. Every dimension is run 50
    times with ``power=3``; all results are then written by
    ``save_consolidated_results``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    power = 3
    banner = "=" * 50

    # dim -> (rule_curves, test_results, pc_values)
    results_dict = {}
    for dim in (3, 4, 5):
        print("\n\n" + banner)
        print(f"STARTING EXPERIMENTS FOR DIM={dim}")
        print(banner)
        results_dict[dim] = run_experiment(
            dim=dim, power=power, num_runs=50, device=device
        )

    # Save all results in consolidated files
    save_consolidated_results(results_dict)


if __name__ == "__main__":
    main()
