In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy.ndimage import label
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# Default upper bound on the number of learned coarse-graining passes
# that PercolationModel.forward performs.
MAX_STEPS: int = 10

def generate_percolation_lattice(size, p):
    """Return a random ``size x size`` site-occupation lattice.

    Each site is independently 1 with probability ``p`` and 0 otherwise.
    The result is a ``np.uint8`` array drawn from NumPy's global RNG.
    """
    occupancy_probs = [1 - p, p]
    sites = np.random.choice([0, 1], size=(size, size), p=occupancy_probs)
    return sites.astype(np.uint8)

def check_percolation(lattice):
    """Return 1.0 if an occupied cluster spans the lattice, else 0.0.

    Clusters are the connected components produced by
    ``scipy.ndimage.label`` with its default structuring element. The
    lattice percolates when one cluster touches both the first and last
    rows, or both the first and last columns.
    """
    clusters, _ = label(lattice)

    def _cluster_ids(edge):
        # Label 0 marks background, not a cluster.
        return {int(v) for v in edge if v != 0}

    top_bottom = _cluster_ids(clusters[0]) & _cluster_ids(clusters[-1])
    left_right = _cluster_ids(clusters[:, 0]) & _cluster_ids(clusters[:, -1])
    return float(bool(top_bottom) or bool(left_right))

def first_coarse_graining(binary_lattice, dim):
    """Average non-overlapping dim×dim blocks.

    Args:
        binary_lattice: 2-D array of shape (H, W).
        dim: block side length. Rows/columns that do not fill a whole block
            at the bottom/right edge are ignored (``F.unfold`` with
            ``stride=dim``).

    Returns:
        float32 tensor of shape [1, H // dim, W // dim] holding block means.
    """
    rows, cols = binary_lattice.shape
    out_rows, out_cols = rows // dim, cols // dim
    grid = torch.tensor(binary_lattice, dtype=torch.float32)[None, None]  # [1, 1, H, W]
    # [1, dim*dim, n_blocks] -> [1, n_blocks, dim*dim]
    blocks = F.unfold(grid, kernel_size=dim, stride=dim).transpose(1, 2)
    block_means = blocks.mean(dim=-1)                                     # [1, n_blocks]
    # unfold enumerates blocks row-major, so a plain reshape restores the grid.
    return block_means.reshape(1, out_rows, out_cols)

class PercolationModel(nn.Module):
    """Classifier that repeatedly applies one learned coarse-graining rule.

    A single MLP (``self.rule``) maps each non-overlapping ``dim x dim``
    patch to one value in (0, 1). ``forward`` reuses this same rule at every
    step, shrinking each spatial side to ``side // dim``. It stops when a
    side drops below ``dim`` or after ``max_steps`` passes.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # One flattened patch (dim*dim values) -> one sigmoid output.
        self.rule = nn.Sequential(
            nn.Linear(dim * dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, x, max_steps=MAX_STEPS):
        """Recursively coarse-grain ``x`` with the learned rule.

        Args:
            x: tensor of shape [b, c, H, W]. The rule consumes exactly
                dim*dim features per patch, so c == 1 is expected.
            max_steps: maximum number of coarse-graining passes.

        Returns:
            Tensor of shape [b, h*w]: the final grid flattened per sample.
            When the input reduces to a single cell this is [b, 1] (not [b]).
            If ``x`` is already smaller than ``dim`` on either side, it is
            returned flattened without going through the rule.
        """
        b = x.shape[0]
        d = self.dim
        for _ in range(max_steps):
            H, W = x.shape[-2:]
            if H < d or W < d:
                break
            # [b, d*d, n_patches] -> [b, n_patches, d*d]; unfold orders patches
            # row-major over the output grid, which the view below relies on.
            patches = F.unfold(x, kernel_size=d, stride=d).transpose(1, 2)
            out = self.rule(patches.reshape(-1, d * d))  # [b*n_patches, 1]
            x = out.view(b, 1, H // d, W // d)
        # reshape (not view) so a non-contiguous input that skipped the loop
        # is still flattened instead of raising.
        return x.reshape(b, -1)

def prepare_dataset(N, sizes):
    """Build a list of labelled random percolation lattices.

    One half of the samples draws the occupation probability p uniformly
    from [0.1, 0.9). The other half draws p from the narrow window
    [0.55, 0.65), presumably chosen to bracket the square-lattice
    site-percolation threshold (TODO confirm). Each sample's side length
    is drawn from ``sizes``.

    Args:
        N: total number of samples to return.
        sizes: sequence of candidate lattice side lengths.

    Returns:
        List of exactly ``int(N)`` ``(lattice, label)`` pairs. ``lattice``
        comes from ``generate_percolation_lattice`` and ``label`` from
        ``check_percolation``. The uniform-p samples come first.
    """
    total = int(N)
    n_window = total // 2
    # The uniform half absorbs the extra sample when N is odd. The old
    # int(N/2) + int(N/2) silently returned N - 1 samples.
    n_uniform = total - n_window
    data = []
    data.extend(_sample_labelled_lattices(n_uniform, (0.1, 0.9), sizes,
                                          "Generating data"))
    data.extend(_sample_labelled_lattices(n_window, (0.55, 0.65), sizes,
                                          "Generating fractal data"))
    return data


def _sample_labelled_lattices(count, p_bounds, sizes, desc):
    """Draw ``count`` (lattice, label) pairs with p ~ Uniform(*p_bounds)."""
    samples = []
    for _ in tqdm(range(count), desc=desc):
        p = np.random.uniform(*p_bounds)
        size = np.random.choice(sizes)
        lattice = generate_percolation_lattice(size, p)
        samples.append((lattice, check_percolation(lattice)))
    return samples

def train_epoch(model, device, data, batch_size, opt, crit, shuffle=True):
    """Train ``model`` for one pass over ``data``; return the mean sample loss.

    Each mini-batch of raw lattices is coarse-grained once with
    ``first_coarse_graining``. The results are then split by spatial size,
    because lattices of different sizes cannot be stacked into one tensor.
    Every size group gets its own forward/backward pass and optimizer step.

    Args:
        model: network taking [B, 1, H, W] inputs whose output has B
            elements after ``.view(-1)`` (e.g. ``PercolationModel``).
        device: device that inputs and targets are moved to.
        data: sequence of ``(lattice, label)`` pairs.
        batch_size: samples per mini-batch, counted before size grouping.
        opt: optimizer over ``model``'s parameters.
        crit: loss taking 1-D float ``(outputs, targets)`` tensors.
        shuffle: visit samples in a fresh random order on every call.
            Defaults to True because ``prepare_dataset`` returns its two
            p-regimes back to back. Without shuffling, every epoch would end
            with a long run of only narrow-window samples.

    Returns:
        Sum over groups of ``loss * group_size``, divided by ``len(data)``.
    """
    # Patch size for the manual first coarse-graining. Prefer the model's own
    # ``dim`` so this function works without the module-level ``DIM`` global.
    # Fall back to that global for models that don't expose ``dim``.
    dim = getattr(model, "dim", None)
    if dim is None:
        dim = DIM
    model.train()
    n = len(data)
    order = np.random.permutation(n) if shuffle else range(n)
    total_loss = 0.0
    for start in tqdm(range(0, n, batch_size), desc="Training"):
        batch = [data[j] for j in order[start:start + batch_size]]

        # Group coarse-grained lattices by spatial shape so each group stacks.
        groups = {}
        for lattice, target in batch:
            coarse = first_coarse_graining(lattice, dim)  # [1, H', W']
            groups.setdefault(tuple(coarse.shape[-2:]), []).append((coarse, target))

        for group in groups.values():
            inputs = torch.stack([c for c, _ in group]).to(device)  # [B, 1, H', W']
            targets = torch.tensor([t for _, t in group],
                                   dtype=torch.float32, device=device)

            opt.zero_grad()
            outputs = model(inputs)  # [B, 1] once fully reduced
            loss = crit(outputs.view(-1), targets)
            loss.backward()
            opt.step()
            # crit is assumed to average over the group; re-weight by its size.
            total_loss += loss.item() * len(group)

    return total_loss / n

def test_systems(model, dim, power, device='cpu',
                 num_tests=50, system_size='standard',
                 p_range=(0, 1), verbose=True):
    """Measure ``model``'s accuracy on fresh lattices of a single size.

    For each test:
      1) generate a raw L×L lattice, with L = dim**size_power and p drawn
         uniformly from ``p_range``;
      2) compute the true percolation label on that raw lattice;
      3) manually coarse-grain once (patch size = dim);
      4) feed the result into ``model``, which does the remaining
         recursive steps.

    Args:
        model: trained network; it is moved to ``device`` and put in eval mode.
        dim: patch size of the manual first coarse-graining.
        power: exponent offset. ``f'{dim}^k'`` (k = 2..7) maps to
            ``size_power = power + k - 3``.
        device: torch device for inference.
        num_tests: number of random lattices to evaluate (must be > 0).
        system_size: one of ``f'{dim}^2'`` .. ``f'{dim}^7'`` or ``'standard'``.
        p_range: ``(low, high)`` bounds for the occupation probability.
        verbose: show a progress bar and print a summary.

    Returns:
        Fraction of tests where ``pred > 0.5`` matches the true label.

    Raises:
        ValueError: if ``num_tests`` is not positive.
        KeyError: if ``system_size`` is not recognised.
    """
    if num_tests <= 0:
        raise ValueError(f"num_tests must be positive, got {num_tests}")
    model.to(device).eval()

    # Exponent for the lattice side length.
    size_power_dict = {f'{dim}^{k}': power + k - 3 for k in range(2, 8)}
    # The default 'standard' previously had no entry and always raised
    # KeyError. It is mapped to the same size as f'{dim}^3'
    # (size_power = power). TODO confirm the intended size.
    size_power_dict['standard'] = power
    size_power = size_power_dict[system_size]
    L = dim ** size_power

    # Keep only labels and predictions. Retaining every raw L×L lattice
    # (never used afterwards) could cost hundreds of MB for large L.
    labels = []
    preds = []
    for _ in tqdm(range(num_tests), desc=f"Testing {L}×{L}", disable=not verbose):
        # 1) Raw lattice + label
        p = np.random.uniform(*p_range)
        raw = generate_percolation_lattice(L, p)
        lbl = check_percolation(raw)

        # 2) Manual first coarse-graining -> [1, L/dim, L/dim]
        coarse = first_coarse_graining(raw, dim)

        # 3) Add the batch axis -> [1, 1, L/dim, L/dim]
        inp = coarse.unsqueeze(0).to(device)

        # 4) The model performs the remaining recursion itself.
        with torch.no_grad():
            out = model(inp).view(-1).item()

        labels.append(lbl)
        preds.append(out)

    # Accuracy at a 0.5 threshold
    acc = sum((pred > 0.5) == lbl for lbl, pred in zip(labels, preds)) / num_tests

    if verbose:
        pos = [pred for lbl, pred in zip(labels, preds) if lbl == 1]
        neg = [pred for lbl, pred in zip(labels, preds) if lbl == 0]

        def _avg(values):
            # np.mean([]) would warn and print 'nan' when a class is absent.
            return f"{np.mean(values):.3f}" if values else "n/a"

        print(f"\nAfter manual first coarse-grain -> NN cascade on {L}×{L}:")
        print(f" Accuracy        : {acc:.2%}")
        print(f" Avg pred | Perc     : {_avg(pos)}")
        print(f" Avg pred | Non-Perc : {_avg(neg)}")

    return acc

# ----------------- Main Function -----------------
def run_experiment(base, epochs=10, train_samples=2000):
    """Train a PercolationModel with patch size ``base`` and track generalization.

    Side effects:
      - sets the module-level ``DIM`` to ``base``;
      - prints per-epoch loss and errors;
      - saves the plot to ``loss_gen_base{base}.png`` and shows it.

    Training lattices have side base**2, base**3 or base**4. After every
    epoch the model is evaluated on test sizes base^3..base^6 (base <= 4) or
    base^3..base^5 (larger bases).

    Returns:
        ``(train_losses, gen_errors)``. ``train_losses`` is a per-epoch list.
        ``gen_errors`` maps each test-size label (e.g. '3^5') to a per-epoch
        list of ``1 - accuracy``.
    """
    banner = '=' * 50
    print(f"\n{banner}")
    print(f"Running experiment for base {base}")
    print(f"{banner}")

    global DIM  # other functions in this file read the module-level DIM
    DIM = base
    POWER = 3
    SIZES = [base ** k for k in (2, 3, 4)]

    # Larger bases stop at base^5 to keep the biggest test lattice manageable.
    top_exponent = 6 if base <= 4 else 5
    TEST_SIZES = [f'{base}^{k}' for k in range(3, top_exponent + 1)]

    train_data = prepare_dataset(train_samples, SIZES)

    model = PercolationModel(DIM).to(DEVICE)
    opt = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    crit = nn.BCELoss()

    train_losses = []
    gen_errors = {size_key: [] for size_key in TEST_SIZES}

    for epoch in range(1, epochs + 1):
        epoch_loss = train_epoch(model, DEVICE, train_data, 10, opt, crit)
        train_losses.append(epoch_loss)

        print(f"\nEpoch {epoch} - Loss: {epoch_loss:.4f}")
        for size_key in TEST_SIZES:
            exponent = int(size_key.partition('^')[2])
            # Fewer tests once the lattice side exceeds 1000 sites.
            n_tests = 100 if base ** exponent <= 1000 else 50

            accuracy = test_systems(model, DIM, POWER, DEVICE, num_tests=n_tests,
                                    system_size=size_key, p_range=(0.1, 0.9),
                                    verbose=False)
            err = 1 - accuracy
            gen_errors[size_key].append(err)
            print(f"Generalization error ({size_key}): {err:.4f}")

    xs = range(1, epochs + 1)
    plt.figure(figsize=(10, 6))
    plt.plot(xs, train_losses, 'b-o', label='Training Loss', linewidth=2)
    for color, size_key in zip(['g', 'r', 'c', 'm', 'y'], TEST_SIZES):
        plt.plot(xs, gen_errors[size_key], f'{color}-s',
                 label=f'Gen Error ({size_key})', alpha=0.7)

    plt.xlabel('Epochs', fontsize=16)
    plt.ylabel('Loss / Error', fontsize=16)
    plt.title(f'Training Loss and Generalization Error (Base {base})', fontsize=20)
    plt.legend(fontsize=13)
    plt.xticks(xs)
    plt.tight_layout()
    plt.savefig(f'loss_gen_base{base}.png', dpi=300)
    plt.show()

    return train_losses, gen_errors

# ----------------- Run Experiments -----------------
# Prefer the GPU whenever PyTorch can see one.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Bases 3 and 4 share the same schedule.
for _base in (3, 4):
    run_experiment(base=_base, epochs=10, train_samples=1500)

# Base 5: fewer test sizes (decided inside run_experiment), two extra epochs.
# Kept as the cell's final expression so the notebook still displays its result.
run_experiment(base=5, epochs=12, train_samples=1500)