In [None]:
# -*- coding: utf-8 -*-
"""
Information Bottleneck (IB) Trajectory Verification via N-Sweep (Multi-Run Version)

This script implements a novel and theoretically sound experiment to demonstrate
the Information Bottleneck phenomenon. This version is enhanced to run the
entire experiment multiple times with different seeds to ensure the statistical
robustness of the findings.

Hypothesis (v3 - The Correct Interpretation):
The IB dynamic does not correspond to a monotonic increase in data (D), but
rather to a change in the internal resources (N) allocated to a concept.
- The Fitting Phase is analogous to INCREASING the model capacity (N).
- The Compression Phase is analogous to subsequently DECREASING the model capacity (N).

This script simulates this by training a series of models with
capacities (N) that first increase and then decrease, repeating this entire
process for multiple runs to ensure the resulting trajectory is a robust phenomenon.
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import networkx as nx
from tqdm import tqdm
import os
import json
import warnings
import copy
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import minmax_scale

warnings.filterwarnings("ignore", category=FutureWarning)

# --- 1. Configuration ---
# Hidden-layer widths visited while capacity grows. The shrinking half of the
# sweep mirrors this list without repeating the peak width.
_ASCENDING_HIDDEN_SIZES = [4, 6, 8, 12, 16, 24, 32, 48, 64, 96, 128]

# Experiment settings read by the sweep function and the main block.
CONFIG = dict(
    # First seed; run i of the main loop uses base_seed + i.
    base_seed=42,
    # Number of independent runs to perform.
    num_runs=5,
    # Every model in the sweep is trained for this many epochs.
    fixed_epochs=40,
    # Number of training samples drawn for each run.
    dataset_size=10000,
    # Widths climb to the peak, then descend back to the starting width.
    hidden_size_sweep=_ASCENDING_HIDDEN_SIZES + _ASCENDING_HIDDEN_SIZES[-2::-1],
    batch_size=256,
    learning_rate=0.002,
    # Upper bound on the number of hidden nodes analysed per trained model.
    analysis_sample_size=50,
    # JSON file that receives the raw results of all runs.
    results_filename="n_sweep_ib_results_multirun.json",
)

# --- 2. Model and Theory Analyzer Definition (remains the same) ---
class SimpleMLP(nn.Module):
    """Single-hidden-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features per sample once the input is flattened.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, input_size=784, hidden_size=10, num_classes=10):
        super().__init__()
        # Remembered so forward() flattens to the width the first layer
        # expects instead of a hard-coded 784.
        self.input_size = input_size
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes)
        )

    def forward(self, x):
        """Flatten ``x`` to rows of ``input_size`` features and return the logits."""
        return self.layers(x.view(-1, self.input_size))

class TheoryAnalyzer:
    """Scores the hidden units of a model via paths through a weighted digraph.

    The graph has one node per unit, named ``"<layer>-<index>"``. Layer 0 holds
    the inputs of the first ``nn.Linear`` module and layer ``k + 1`` holds the
    outputs of the k-th one. Nodes in the last layer are the "grounding" nodes;
    nodes in layers strictly between the first and last are the "hidden" nodes.
    """

    def __init__(self, model):
        """Deep-copy ``model``, put the copy in eval mode on the CPU and build the graph."""
        snapshot = copy.deepcopy(model)
        snapshot.eval()
        self.model = snapshot.to('cpu')
        self.linear_layers = [
            module for module in self.model.modules() if isinstance(module, nn.Linear)
        ]
        self.graph = self._build_graph()
        self.grounding_nodes = self._get_grounding_nodes()
        self.hidden_nodes = self._get_hidden_nodes()
        # Cache for find_all_paths_dfs, keyed by (start, sorted targets tuple).
        self.memoized_paths = {}

    def _build_graph(self):
        """Return a DiGraph whose edges carry ``cost = 1 - ln(p)``.

        ``p`` is the softmax, taken across destination units, of the absolute
        weights leaving a source unit. Edges with ``p <= 1e-9`` are left out.
        """
        graph = nx.DiGraph()
        if not self.linear_layers:
            return graph

        for unit in range(self.linear_layers[0].in_features):
            graph.add_node(f"0-{unit}", layer=0)

        for depth, linear in enumerate(self.linear_layers):
            next_depth = depth + 1
            for unit in range(linear.out_features):
                graph.add_node(f"{next_depth}-{unit}", layer=next_depth)

            # Transposed so that rows are source units and columns are destinations.
            magnitudes = torch.abs(linear.weight.data.t())
            transition = torch.softmax(magnitudes, dim=1)
            for src, row in enumerate(transition.tolist()):
                for dst, prob in enumerate(row):
                    if prob > 1e-9:
                        graph.add_edge(
                            f"{depth}-{src}",
                            f"{next_depth}-{dst}",
                            cost=1.0 - np.log(prob),
                        )
        return graph

    def _get_grounding_nodes(self):
        """Return the set of nodes in the last layer of the graph."""
        output_layer = len(self.linear_layers)
        return {
            name
            for name, attrs in self.graph.nodes(data=True)
            if attrs['layer'] == output_layer
        }

    def _get_hidden_nodes(self):
        """Return the list of nodes lying strictly between the first and last layer."""
        output_layer = len(self.linear_layers)
        return [
            name
            for name, attrs in self.graph.nodes(data=True)
            if 0 < attrs['layer'] < output_layer
        ]

    def find_all_paths_dfs(self, start, targets):
        """Enumerate simple paths from ``start`` to any node in ``targets``.

        Uses an iterative depth-first search. A path stops at the first target
        it reaches and is abandoned once it holds more than
        ``len(self.linear_layers) + 2`` nodes. Results are cached per
        ``(start, targets)`` pair.

        Returns:
            A list of ``{'path': [nodes...], 'cost': summed edge cost}`` dicts.
        """
        cache_key = (start, tuple(sorted(list(targets))))
        try:
            return self.memoized_paths[cache_key]
        except KeyError:
            pass

        max_path_len = len(self.linear_layers) + 2
        found = []
        pending = [(start, [start], 0)]
        while pending:
            node, trail, cost_so_far = pending.pop()
            if node in targets:
                found.append({'path': trail, 'cost': cost_so_far})
                continue
            if len(trail) > max_path_len:
                continue
            for successor in self.graph.neighbors(node):
                if successor in trail:
                    # Never revisit a node already on this path.
                    continue
                step_total = cost_so_far + self.graph[node][successor]['cost']
                pending.append((successor, trail + [successor], step_total))

        self.memoized_paths[cache_key] = found
        return found

    def calculate_metrics_for_node(self, node):
        """Return ``(h_tse, h_sie)`` for ``node`` from its paths to the grounding nodes.

        ``h_tse`` is the reciprocal of the summed reciprocal path costs.
        ``h_sie`` is the base-2 entropy of the path weights ``exp(-cost)`` after
        normalisation, with ``1e-9`` added inside the logarithm. Both values
        are ``inf`` when no path exists.
        """
        paths = self.find_all_paths_dfs(node, self.grounding_nodes)
        if not paths:
            return float('inf'), float('inf')

        costs = np.array([entry['cost'] for entry in paths])
        importances = np.exp(-1.0 * costs)
        conductances = 1.0 / costs

        total_conductance = np.sum(conductances)
        if total_conductance > 0:
            h_tse = 1.0 / total_conductance
        else:
            h_tse = float('inf')

        total_importance = np.sum(importances)
        if total_importance > 0:
            probabilities = importances / total_importance
        else:
            probabilities = importances
        h_sie = -np.sum(probabilities * np.log2(probabilities + 1e-9))
        return h_tse, h_sie

    def analyze_model_state(self, sample_size):
        """Average both metrics over a random sample of hidden nodes.

        At most ``sample_size`` hidden nodes are drawn without replacement via
        ``np.random.choice``. Nodes with a non-finite metric are skipped.

        Returns:
            ``(mean h_tse, mean h_sie)``; each falls back to ``0`` when there is
            nothing to average.
        """
        htse_vals = []
        hsie_vals = []
        if not self.hidden_nodes:
            return 0, 0

        draw_count = min(sample_size, len(self.hidden_nodes))
        chosen = np.random.choice(self.hidden_nodes, size=draw_count, replace=False)
        for node in chosen:
            h_tse, h_sie = self.calculate_metrics_for_node(node)
            if np.isfinite(h_tse) and np.isfinite(h_sie):
                htse_vals.append(h_tse)
                hsie_vals.append(h_sie)

        mean_htse = np.mean(htse_vals) if htse_vals else 0
        mean_hsie = np.mean(hsie_vals) if hsie_vals else 0
        return mean_htse, mean_hsie

# --- 3. Core Experiment Function (plotting is handled by a separate script) ---

def run_single_sweep(seed, full_train_dataset, device):
    """Train and analyse one model per entry of ``CONFIG["hidden_size_sweep"]``.

    Args:
        seed: Value used to seed both the torch and NumPy random generators.
        full_train_dataset: Dataset from which ``CONFIG["dataset_size"]``
            randomly permuted samples are taken for training.
        device: torch device on which the models are trained.

    Returns:
        A list with one dict per sweep entry, in sweep order, holding the keys
        ``"hidden_size"``, ``"final_htse"`` and ``"final_hsie"``.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    # One random subset per run, shared by every model in the sweep.
    subset_indices = torch.randperm(len(full_train_dataset))[:CONFIG["dataset_size"]]
    train_loader = DataLoader(
        Subset(full_train_dataset, subset_indices),
        batch_size=CONFIG["batch_size"],
        shuffle=True,
    )
    # The loss module holds no state, so one instance serves all models.
    loss_fn = nn.CrossEntropyLoss()

    sweep_records = []
    progress = tqdm(CONFIG["hidden_size_sweep"], desc=f"Running Sweep for Seed {seed}", leave=False)
    for hidden_size in progress:
        # Each sweep entry gets a freshly initialised model and optimizer.
        model = SimpleMLP(hidden_size=hidden_size).to(device)
        optimizer = optim.Adam(model.parameters(), lr=CONFIG["learning_rate"])

        model.train()
        for _ in range(CONFIG["fixed_epochs"]):
            for inputs, labels in train_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                batch_loss = loss_fn(model(inputs), labels)
                batch_loss.backward()
                optimizer.step()

        model.eval()
        htse, hsie = TheoryAnalyzer(model).analyze_model_state(CONFIG["analysis_sample_size"])
        sweep_records.append({
            "hidden_size": hidden_size,
            "final_htse": htse,
            "final_hsie": hsie,
        })
    return sweep_records

# --- 4. Main Execution Block ---
if __name__ == '__main__':
    # Train on the GPU when torch reports one, otherwise on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # FashionMNIST training split, stored in (and downloaded to) the working directory.
    full_train_dataset = datasets.FashionMNIST(
        '.',
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    )

    # One full sweep per seed; seeds are consecutive, starting at base_seed.
    all_runs_results = []
    run_seeds = range(CONFIG["base_seed"], CONFIG["base_seed"] + CONFIG["num_runs"])
    for i, current_seed in enumerate(run_seeds):
        print(f"\n{'='*20} Starting Run {i+1}/{CONFIG['num_runs']} with Seed {current_seed} {'='*20}")
        all_runs_results.append(
            run_single_sweep(current_seed, full_train_dataset, device)
        )

    # Persist the raw per-run results as a list of lists of records.
    with open(CONFIG["results_filename"], 'w') as f:
        json.dump(all_runs_results, f, indent=4)

    print(f"\nRaw experiment data for all {CONFIG['num_runs']} runs saved to: {CONFIG['results_filename']}")
    print("Experiment complete. You can now use the updated plotting script to visualize the aggregated results.")