# In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy.ndimage import label
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# Default upper bound on the number of coarse-graining passes performed by
# PercolationModel.forward (it is the default value of its `max_steps`).
MAX_STEPS = 10

def generate_percolation_lattice(size, p):
    """Draw a random ``size`` x ``size`` site lattice.

    Every site is sampled independently: 1 with probability ``p`` and 0 with
    probability ``1 - p``. The lattice is returned as a ``uint8`` NumPy array.
    """
    site_states = [0, 1]
    state_probs = [1 - p, p]
    sampled = np.random.choice(site_states, (size, size), p=state_probs)
    return sampled.astype(np.uint8)

def check_percolation(lattice):
    """Return 1.0 if some cluster spans the lattice, otherwise 0.0.

    Clusters are identified with ``scipy.ndimage.label`` using its default
    structuring element. A cluster spans when it touches both the first and
    last row, or both the first and last column.
    """
    clusters, _ = label(lattice)

    def edge_ids(edge):
        # Cluster ids found along one boundary line; id 0 is background.
        ids = set(edge)
        ids.discard(0)
        return ids

    spans_rows = not edge_ids(clusters[0]).isdisjoint(edge_ids(clusters[-1]))
    spans_cols = not edge_ids(clusters[:, 0]).isdisjoint(edge_ids(clusters[:, -1]))
    return float(spans_rows or spans_cols)

def first_coarse_graining(binary_lattice, dim):
    """Replace each non-overlapping ``dim`` x ``dim`` tile by its mean value.

    The lattice is converted to a float32 tensor, cut into tiles with
    ``F.unfold`` (kernel and stride both ``dim``), and each tile is averaged.
    The result has shape ``[1, H // dim, W // dim]``.
    """
    image = torch.tensor(binary_lattice, dtype=torch.float32)[None, None]  # [1, 1, H, W]
    tiles = F.unfold(image, kernel_size=dim, stride=dim)                   # [1, dim*dim, num_tiles]
    tile_means = tiles.mean(dim=1)                                         # [1, num_tiles]
    rows, cols = binary_lattice.shape
    return tile_means.reshape(1, rows // dim, cols // dim)                 # [1, H//dim, W//dim]

class PercolationModel(nn.Module):
    """Learned block rule applied repeatedly to shrink a lattice.

    A small MLP (``dim*dim`` -> 64 -> 1 with a sigmoid output) maps every
    non-overlapping ``dim`` x ``dim`` patch to a single value in (0, 1).
    ``forward`` applies that rule pass after pass, dividing each spatial side
    by ``dim`` per pass.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        patch_area = dim * dim
        self.rule = nn.Sequential(
            nn.Linear(patch_area, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, max_steps=MAX_STEPS):
        """Coarse-grain ``x`` with the learned rule.

        ``x`` must be 4-D ``[batch, channels, H, W]``; the reshapes below
        only line up for a single channel. Passes stop after ``max_steps``
        iterations or as soon as either spatial side is smaller than
        ``self.dim``. The return value is always 2-D:
        ``[batch, remaining_h * remaining_w]``, i.e. ``[batch, 1]`` once the
        lattice has been reduced to a single site.
        """
        batch, _, height, width = x.shape
        block = self.dim
        patch_area = block * block
        for _ in range(max_steps):
            if height < block or width < block:
                break
            # One column per patch -> one row per patch for the MLP.
            columns = F.unfold(x, kernel_size=block, stride=block)      # [batch, dim*dim, n_patches]
            per_patch = columns.permute(0, 2, 1).contiguous()           # [batch, n_patches, dim*dim]
            scores = self.rule(per_patch.view(-1, patch_area))          # [batch*n_patches, 1]
            x = scores.view(batch, 1, height // block, width // block)
            height, width = x.shape[2], x.shape[3]
        return x.squeeze(1).view(batch, -1)

def prepare_dataset(N, sizes):
    """Build ``N`` labelled lattices of randomly chosen size and density.

    For each sample an occupation probability is drawn uniformly from
    [0.1, 0.9) and a side length is picked from ``sizes``. Returns a list of
    ``(lattice, label)`` pairs where ``label`` comes from ``check_percolation``.
    """
    def draw_sample():
        # Order of the random draws is kept: probability first, then size.
        occupation = np.random.uniform(0.1, 0.9)
        side = np.random.choice(sizes)
        lattice = generate_percolation_lattice(side, occupation)
        return lattice, check_percolation(lattice)

    return [draw_sample() for _ in tqdm(range(N), desc="Generating data")]

def train_epoch(model, device, data, batch_size, opt, crit):
    """Run one training pass over ``data`` and return the mean per-sample loss.

    Args:
        model: Network to train. It must expose a ``dim`` attribute (as
            ``PercolationModel`` does); that patch size is also used for the
            manual first coarse-graining so both stages always agree.
        device: Device the inputs and targets are moved to.
        data: Sequence of ``(lattice, label)`` pairs, lattices being 2-D
            arrays accepted by ``first_coarse_graining``.
        batch_size: Number of consecutive samples taken per outer step.
        opt: Optimizer updating ``model``'s parameters.
        crit: Loss called as ``crit(predictions, targets)`` on 1-D tensors.

    Returns:
        Sum of per-group losses weighted by group size, divided by
        ``len(data)``; ``0.0`` when ``data`` is empty.

    Lattices of different sizes cannot be stacked into one tensor, so each
    batch is split into groups of equal coarse-grained size and the optimizer
    takes one step per group.
    """
    model.train()
    if not data:
        # Nothing to train on; avoids dividing by zero below.
        return 0.0

    # Use the model's own patch size instead of a module-level global so the
    # function does not depend on script-level configuration.
    dim = model.dim
    total_loss = 0.0
    for start in tqdm(range(0, len(data), batch_size), desc="Training"):
        # Group the coarse-grained lattices of this batch by spatial size.
        groups = {}
        for lattice, lbl in data[start:start + batch_size]:
            coarse = first_coarse_graining(lattice, dim)  # [1, H', W']
            size_key = (coarse.shape[-2], coarse.shape[-1])
            groups.setdefault(size_key, []).append((coarse, lbl))

        for group in groups.values():
            inputs = torch.stack([cg for cg, _ in group]).to(device)  # [B, 1, H', W']
            targets = torch.tensor(
                [lbl for _, lbl in group], dtype=torch.float32, device=device
            )

            opt.zero_grad()
            outputs = model(inputs)
            # Flattening assumes the model reduces each lattice to a single
            # value, so predictions line up one-to-one with targets.
            loss = crit(outputs.view(-1), targets)
            loss.backward()
            opt.step()
            # Weight by group size so the final figure is a per-sample mean.
            total_loss += loss.item() * len(group)

    return total_loss / len(data)

def test_systems(model, dim, power, device='cpu',
                 num_tests=50, system_size='standard',
                 p_range=(0, 1), verbose=True):
    """Evaluate ``model`` on freshly generated lattices of one size.

    For each test:
      1) generate a raw ``L`` x ``L`` lattice with ``L = dim ** size_power``
      2) compute the true percolation label on that raw lattice
      3) manually coarse-grain once (patch size = ``dim``)
      4) feed the result into ``model``, which performs the remaining steps

    Args:
        model: Trained network; moved to ``device`` and put in eval mode.
        dim: Patch size used for the first coarse-graining and for ``L``.
        power: Base exponent; ``system_size`` selects an offset from it.
        device: Device used for inference.
        num_tests: Number of random lattices to evaluate.
        system_size: One of ``'standard'`` (exponent ``power``), ``'3^2'``
            (``power - 1``), ``'3^3'`` (``power``), ``'3^4'`` ... ``'3^7'``
            (``power + 1`` ... ``power + 4``).
        p_range: ``(low, high)`` bounds for the uniform occupation probability.
        verbose: When true, print accuracy and mean prediction per class.

    Returns:
        List of ``(raw_lattice, true_label, prediction)`` tuples.

    Raises:
        KeyError: If ``system_size`` is not one of the supported names.
    """
    model.to(device).eval()

    # Exponent offsets relative to `power`. 'standard' is the default
    # argument and therefore must be present; it means "exactly `power`".
    exponent_offsets = {
        'standard': 0,
        '3^2': -1,
        '3^3': 0,
        '3^4': 1,
        '3^5': 2,
        '3^6': 3,
        '3^7': 4,
    }
    if system_size not in exponent_offsets:
        raise KeyError(
            f"unknown system_size {system_size!r}; "
            f"expected one of {sorted(exponent_offsets)}"
        )
    size_power = power + exponent_offsets[system_size]
    L = dim ** size_power

    results = []
    for _ in tqdm(range(num_tests), desc=f"Testing {L}×{L}"):
        # 1) Raw lattice + label
        p = np.random.uniform(*p_range)
        raw = generate_percolation_lattice(L, p)
        lbl = check_percolation(raw)

        # 2) Manual first coarse-graining -> tensor [1, L/dim, L/dim]
        coarse = first_coarse_graining(raw, dim)

        # 3) Add the batch axis expected by the model -> [1, 1, L/dim, L/dim]
        inp = coarse.unsqueeze(0).to(device)

        # 4) Network prediction. `.item()` requires a single output value,
        #    i.e. the model must reduce the lattice all the way down.
        with torch.no_grad():
            out = model(inp).view(-1).item()

        results.append((raw, lbl, out))

    # Accuracy at the 0.5 threshold; NaN when no tests were run.
    if results:
        acc = sum((pred > 0.5) == lbl for _, lbl, pred in results) / len(results)
    else:
        acc = float('nan')

    if verbose:
        pos = [pred for _, lbl, pred in results if lbl == 1]
        neg = [pred for _, lbl, pred in results if lbl == 0]
        # A class can be absent from a small sample; report NaN for it
        # instead of letting np.mean warn on an empty list.
        pos_mean = np.mean(pos) if pos else float('nan')
        neg_mean = np.mean(neg) if neg else float('nan')
        print(f"\nAfter manual first coarse-grain -> NN cascade on {L}×{L}:")
        print(f" Accuracy        : {acc:.2%}")
        print(f" Avg pred | Perc     : {pos_mean:.3f}")
        print(f" Avg pred | Non-Perc : {neg_mean:.3f}")

    return results



# ----------------- Run -----------------
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Experiment configuration
DIM = 3      # Patch side length used for every coarse-graining step
POWER = 3    # Base exponent for lattice sizes (DIM**POWER = 27, the "standard" side)
SIZES = [DIM ** exponent for exponent in (2, 3, 4)]  # Training sides: 9, 27, 81

# Training set with lattice sides drawn from SIZES
train_data = prepare_dataset(N=20_000, sizes=SIZES)

# The model's patch size is DIM, the same value used for SIZES above
model = PercolationModel(DIM).to(DEVICE)
opt = optim.Adam(model.parameters(), lr=1e-3)
crit = nn.BCELoss()

# Epochs are numbered 1 through 19
for epoch in range(1, 20):
    loss = train_epoch(model, DEVICE, train_data, batch_size=10, opt=opt, crit=crit)
    print(f"Epoch {epoch} — Loss: {loss:.4f}")