*Please run this script before the experiment; you may be surprised by the results.*

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_moons
from tqdm import trange
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import copy
import warnings
from scipy.optimize import curve_fit
from scipy.stats import pearsonr

# Silence the FutureWarning noise emitted by seaborn/numpy.
warnings.filterwarnings("ignore", category=FutureWarning)

# --- 1. Configuration Parameters ---
# Dataset / optimisation settings.
N_SAMPLES = 2000
N_EPOCHS = 5000
BATCH_SIZE = 128
LEARNING_RATE = 0.05
WEIGHT_DECAY = 0  # L2 regularization strength handed to the optimizer; adjust as needed

# Analysis settings.
N_BINS = 30  # number of bins used by the mutual-information estimate
ANALYSIS_SAMPLE_SIZE = 30  # hidden nodes sampled for each H'tse/H'sie evaluation
LOG_INTERVAL = 10  # number of epochs between MI / entropy measurements

# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")
print(f"Weight Decay (L2 Regularization) set to: {WEIGHT_DECAY}")

# --- 2. Define the MLP Model ---
class SimpleMLP(nn.Module):
    """MLP with two tanh hidden layers that can also return its hidden activations."""

    def __init__(self, input_dim, hidden_dim1, hidden_dim2, output_dim):
        super().__init__()
        # All modules live in a single Sequential (`self.layers`) so that code
        # such as TheoryAnalyzer can iterate over them.
        stack = [
            nn.Linear(input_dim, hidden_dim1),
            nn.Tanh(),
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.Tanh(),
            nn.Linear(hidden_dim2, output_dim),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x, return_hidden=False):
        """Run the network on ``x``.

        Args:
            x: Input tensor fed to the first linear layer.
            return_hidden: When true, also return both post-tanh activations.

        Returns:
            ``out`` alone, or the tuple ``(out, h1, h2)`` when
            ``return_hidden`` is true.
        """
        # Apply the modules one by one (rather than calling the Sequential)
        # so the intermediate activations can be kept.
        fc1, act1, fc2, act2, head = self.layers
        h1 = act1(fc1(x))
        h2 = act2(fc2(h1))
        out = head(h2)
        return (out, h1, h2) if return_hidden else out

# --- 3. Prepare Data ---
# Two-moons toy dataset; every generated sample goes into the training loader.
features_np, labels_np = make_moons(n_samples=N_SAMPLES, noise=0.1, random_state=42)

# Convert to tensors (float features, integer class labels) on the target device.
X = torch.FloatTensor(features_np).to(DEVICE)
y = torch.LongTensor(labels_np).to(DEVICE)

train_dataset = torch.utils.data.TensorDataset(X, y)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)


# --- 4. Information Theory & Unified Theory Analysis Tools ---

# 4.1 Mutual Information Calculation Functions
def calculate_mutual_information(x, y, bins):
    """Estimate the mutual information I(X;Y), in bits, from a 2-D histogram.

    Args:
        x: 1-D sequence of samples of the first variable.
        y: 1-D sequence of samples of the second variable (same length as x).
        bins: Bin specification forwarded unchanged to ``np.histogram2d``
            (an int, a pair of ints, or explicit bin edges).

    Returns:
        The estimated mutual information as a float. Returns 0.0 when the
        histogram is empty (no samples).
    """
    joint_hist, _, _ = np.histogram2d(x, y, bins=bins)
    total = np.sum(joint_hist)
    if total <= 0:
        # No samples: avoid normalising by zero.
        return 0.0

    joint_prob = joint_hist / total
    p_x = np.sum(joint_prob, axis=1)
    p_y = np.sum(joint_prob, axis=0)

    # Product of the marginals for every (i, j) cell.
    independent = p_x[:, np.newaxis] * p_y[np.newaxis, :]

    # Only cells with non-negligible joint and marginal mass contribute;
    # this also keeps log2 away from zero.
    mask = (
        (joint_prob > 1e-12)
        & (p_x[:, np.newaxis] > 1e-12)
        & (p_y[np.newaxis, :] > 1e-12)
    )
    cells = joint_prob[mask]
    return float(np.sum(cells * np.log2(cells / independent[mask])))

def get_mi_for_layer(model, data_loader, layer_idx, n_bins):
    """Estimate I(X;T) and I(T;Y) for one hidden layer over the whole dataset.

    Only the first input feature and the first unit of the chosen hidden
    layer are used; both are discretised into equal-width bins spanning
    their observed range. Labels are used as-is.

    Args:
        model: Network whose forward accepts ``return_hidden=True`` and
            returns ``(out, h1, h2)``. It is switched to eval mode.
        data_loader: Loader whose ``dataset.tensors`` holds (inputs, labels).
        layer_idx: 1 selects the first hidden layer; any other value
            selects the second.
        n_bins: Number of bin edges for discretisation and bin count for
            the mutual-information histogram.

    Returns:
        Tuple ``(mi_xt, mi_ty)``.
    """
    def to_bin_codes(column):
        # Equal-width edges between the column's min and max.
        edges = np.linspace(np.min(column), np.max(column), n_bins)
        return np.digitize(column, bins=edges)

    model.eval()
    inputs_t, labels_t = data_loader.dataset.tensors[0], data_loader.dataset.tensors[1]
    inputs_np = inputs_t.cpu().numpy()
    labels_np = labels_t.cpu().numpy()

    with torch.no_grad():
        _, first_hidden, second_hidden = model(inputs_t, return_hidden=True)

    if layer_idx == 1:
        chosen = first_hidden
    else:
        chosen = second_hidden
    activations_np = chosen.cpu().numpy()

    input_codes = to_bin_codes(inputs_np[:, 0])
    hidden_codes = to_bin_codes(activations_np[:, 0])

    mi_xt = calculate_mutual_information(input_codes, hidden_codes, n_bins)
    mi_ty = calculate_mutual_information(hidden_codes, labels_np, n_bins)
    return mi_xt, mi_ty

# 4.2 Unified Theory Analyzer (H'_tse and H'_sie) from your script
class TheoryAnalyzer:
    """Path-based structural metrics (H'_tse and H'_sie) for a layered model.

    The model's ``nn.Linear`` layers (read from ``model.layers``) are turned
    into a directed graph: one node per input feature and per linear output
    unit, with an edge between consecutive layers whose cost is
    ``1 - ln(p + 1e-9)``, where ``p`` is the softmax of the absolute weights
    over the outgoing edges of the source node. Nodes of the last linear
    layer are the "grounding" nodes that paths must reach.
    """

    def __init__(self, model):
        # Analyse a CPU copy in eval mode so the caller's model is untouched.
        model_copy = copy.deepcopy(model)
        model_copy.eval()
        self.model = model_copy.to('cpu')
        # Defined up front so the attribute exists even if the model has no
        # Linear layer and _build_graph returns an empty graph.
        self.grounding_nodes = set()
        self.graph = self._build_graph()
        self.hidden_nodes = self._get_hidden_nodes()
        self.memoized_paths = {}

    def _build_graph(self):
        """Build the layered cost graph and record the grounding nodes."""
        G = nx.DiGraph()
        linear_layers = [l for l in self.model.layers if isinstance(l, nn.Linear)]
        if not linear_layers:
            self.grounding_nodes = set()
            return G

        # Layer 0: one node per input feature.
        next_id = linear_layers[0].in_features
        prev_nodes = list(range(next_id))
        for node in prev_nodes:
            G.add_node(node, layer=0)

        for depth, layer in enumerate(linear_layers, start=1):
            curr_nodes = list(range(next_id, next_id + layer.out_features))
            next_id += layer.out_features
            for node in curr_nodes:
                G.add_node(node, layer=depth)

            # Rows index source units, columns index destination units.
            weights = torch.abs(layer.weight.data.t())
            # Convert once to nested Python floats instead of calling
            # .item() for every single edge.
            probs = torch.softmax(weights, dim=1).tolist()
            for u, row in zip(prev_nodes, probs):
                for v, prob in zip(curr_nodes, row):
                    if prob > 1e-9:
                        G.add_edge(u, v, cost=1.0 - np.log(prob + 1e-9))
            prev_nodes = curr_nodes

        # Units of the last linear layer are the path targets.
        self.grounding_nodes = set(prev_nodes)
        return G

    def _get_hidden_nodes(self):
        """Return the nodes that are in neither the first nor the last graph layer."""
        max_layer_idx = max((data['layer'] for _, data in self.graph.nodes(data=True)), default=0)
        return [
            node for node, data in self.graph.nodes(data=True)
            if data['layer'] not in [0, max_layer_idx]
        ]

    def find_all_paths_dfs(self, start, targets):
        """Enumerate simple paths from ``start`` to any node in ``targets``.

        Results are memoised per (start, targets) pair.

        Returns:
            List of dicts with keys ``'path'`` (node list) and ``'cost'``
            (sum of edge costs along the path).
        """
        memo_key = (start, tuple(sorted(targets)))
        if memo_key in self.memoized_paths:
            return self.memoized_paths[memo_key]

        paths = []
        stack = [(start, [start], 0)]
        while stack:
            curr, path, cost = stack.pop()
            if curr in targets:
                paths.append({'path': path, 'cost': cost})
                continue
            # Depth cap: stop extending paths that already hold more than 10 nodes.
            if len(path) > 10:
                continue
            for neighbor in self.graph.neighbors(curr):
                if neighbor in path:
                    continue
                edge_cost = self.graph.get_edge_data(curr, neighbor, {}).get('cost', float('inf'))
                stack.append((neighbor, path + [neighbor], cost + edge_cost))

        self.memoized_paths[memo_key] = paths
        return paths

    def calculate_metrics_for_node(self, node):
        """Compute ``(htse, hsie)`` for one node.

        ``htse`` is the reciprocal of the summed reciprocal path costs;
        ``hsie`` is the entropy (bits) of the distribution proportional to
        ``exp(-cost)`` over the paths.

        Returns:
            ``(inf, 0.0)`` when no path reaches a grounding node.
        """
        paths = self.find_all_paths_dfs(node, self.grounding_nodes)
        if not paths:
            return float('inf'), 0.0

        costs = np.array([p['cost'] for p in paths])
        if np.any(costs <= 0):
            # A zero-cost path has unbounded conductance, so the combined
            # cost is 0; handled explicitly to avoid dividing by zero.
            htse = 0.0
        else:
            total_conductance = np.sum(1.0 / costs)
            htse = 1.0 / total_conductance if total_conductance > 0 else float('inf')

        importances = np.exp(-1.0 * costs)
        total_importance = np.sum(importances)
        if total_importance > 0:
            probabilities = importances / total_importance
        else:
            probabilities = np.zeros_like(importances)
        hsie = -np.sum(probabilities * np.log2(probabilities + 1e-9))
        return htse, hsie

    def analyze_model_structure(self, analysis_sample_size):
        """Average the node metrics over a random sample of hidden nodes.

        Args:
            analysis_sample_size: Maximum number of hidden nodes to sample
                (without replacement, via ``np.random.choice``).

        Returns:
            ``(mean_htse, mean_hsie)`` over the sampled nodes whose metrics
            are both finite; ``(0, 0)`` when there is nothing to average.
        """
        if not self.hidden_nodes:
            return 0, 0

        sample_size = min(analysis_sample_size, len(self.hidden_nodes))
        sampled_nodes = np.random.choice(self.hidden_nodes, size=sample_size, replace=False)

        htse_vals, hsie_vals = [], []
        for node in sampled_nodes:
            htse, hsie = self.calculate_metrics_for_node(node)
            if np.isfinite(htse) and np.isfinite(hsie):
                htse_vals.append(htse)
                hsie_vals.append(hsie)

        mean_htse = np.mean(htse_vals) if htse_vals else 0
        mean_hsie = np.mean(hsie_vals) if hsie_vals else 0
        return mean_htse, mean_hsie


# --- 5. Train Model and Log Metrics ---
model = SimpleMLP(input_dim=2, hidden_dim1=10, hidden_dim2=7, output_dim=2).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# One entry per epoch in every list; the non-loss series hold NaN for
# epochs on which no metrics were computed.
history = {
    'loss': [],
    'mi_xt_h1': [],
    'mi_ty_h1': [],
    'htse': [],
    'hsie': [],
}

pbar = trange(N_EPOCHS, desc="Training")
for epoch in pbar:
    # Metric helpers put the model in eval mode, so re-enable training each epoch.
    model.train()
    batch_losses = []
    for batch_x, batch_y in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(batch_x), batch_y)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    avg_loss = sum(batch_losses) / len(train_loader)
    history['loss'].append(avg_loss)

    if epoch % LOG_INTERVAL != 0:
        # Not a logging epoch: pad the metric series with NaN placeholders.
        for key, series in history.items():
            if key != 'loss':
                series.append(float('nan'))
        continue

    # Information-bottleneck metrics for the first hidden layer.
    mi_xt1, mi_ty1 = get_mi_for_layer(model, train_loader, layer_idx=1, n_bins=N_BINS)
    history['mi_xt_h1'].append(mi_xt1)
    history['mi_ty_h1'].append(mi_ty1)

    # Unified Theory metrics, computed on a fresh snapshot of the model.
    analyzer = TheoryAnalyzer(model)
    htse, hsie = analyzer.analyze_model_structure(ANALYSIS_SAMPLE_SIZE)
    history['htse'].append(htse)
    history['hsie'].append(hsie)

    pbar.set_postfix_str(f"Loss:{avg_loss:.3f}, I(X;T):{mi_xt1:.2f}, I(T;Y):{mi_ty1:.2f}, H'tse:{htse:.3f}, H'sie:{hsie:.3f}")

# --- 6. Curve Fitting and Statistical Analysis ---

# Define fitting functions based on your logarithmic hypothesis
def log_growth_func(x, a, b):
    """Evaluate the logarithmic growth model ``a * log(x + 1) + b``.

    ``x`` may be a scalar or a NumPy array; ``a`` scales the log term and
    ``b`` is the value at ``x = 0``.
    """
    log_term = np.log(x + 1)
    return a * log_term + b

def log_decay_func(x, a, b):
    """Evaluate the logarithmic decay model ``-a * log(x + 1) + b``.

    ``x`` may be a scalar or a NumPy array; ``a`` scales the (subtracted)
    log term and ``b`` is the value at ``x = 0``.
    """
    log_term = np.log(x + 1)
    return b - a * log_term

# Strip the NaN placeholders so only the logged epochs remain for plotting and fitting.
def _drop_nan(series):
    """Return the non-NaN entries of ``series`` as an array, order preserved."""
    return np.array([value for value in series if not np.isnan(value)])

# Positions in the per-epoch history at which metrics were actually recorded.
log_epochs_np = np.array([
    position
    for position, value in enumerate(history['mi_xt_h1'])
    if not np.isnan(value)
])
mi_xt_h1 = _drop_nan(history['mi_xt_h1'])
mi_ty_h1 = _drop_nan(history['mi_ty_h1'])
htse_hist = _drop_nan(history['htse'])
hsie_hist = _drop_nan(history['hsie'])

def _fit_log_curve(fit_func, x_data, y_data, p0):
    """Fit ``fit_func`` to (x_data, y_data) and score the fit.

    Returns:
        ``(params, fitted_y, r_squared, p_value)`` where ``p_value`` comes
        from the Pearson correlation between the data and the fitted curve.
        ``r_squared`` is NaN when the data have zero variance.

    Raises:
        Whatever ``curve_fit`` / ``pearsonr`` raise on failure; callers
        handle those.
    """
    params, _ = curve_fit(fit_func, x_data, y_data, p0=p0, maxfev=5000)
    fitted_y = fit_func(x_data, *params)
    ss_res = np.sum((y_data - fitted_y) ** 2)
    ss_tot = np.sum((y_data - np.mean(y_data)) ** 2)
    # A constant series has no variance to explain; avoid dividing by zero.
    r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else float('nan')
    _, p_value = pearsonr(y_data, fitted_y)
    return params, fitted_y, r_squared, p_value


# Failures to treat as "could not fit": RuntimeError when the optimiser does
# not converge, ValueError / TypeError for unusable input (e.g. non-finite
# values or too few points).
_FIT_ERRORS = (RuntimeError, ValueError, TypeError)

# Fit H'_tse (Cognitive Cost)
try:
    params_htse, y_fit_htse, r2_htse, p_htse = _fit_log_curve(
        log_growth_func, log_epochs_np, htse_hist, p0=[0.001, 0.54]
    )
except _FIT_ERRORS:
    print("Could not fit H'_tse curve.")
    # Placeholder values; y_fit_htse = None tells the plotting code to skip the fit line.
    y_fit_htse, r2_htse, p_htse, params_htse = None, 0, 1, [0,0]
else:
    print("\n--- H'_tse (Cognitive Cost) Curve Fit ---")
    print("Fit function: y = a * ln(x + 1) + b")
    print(f"  - Parameter a: {params_htse[0]:.6f}")
    print(f"  - Parameter b: {params_htse[1]:.6f}")
    print(f"R-squared: {r2_htse:.4f}")
    print(f"p-value: {p_htse:.2e}")

# Fit H'_sie (Robustness)
try:
    params_hsie, y_fit_hsie, r2_hsie, p_hsie = _fit_log_curve(
        log_decay_func, log_epochs_np, hsie_hist, p0=[0.01, 2.7]
    )
except _FIT_ERRORS:
    print("Could not fit H'_sie curve.")
    # Placeholder values; y_fit_hsie = None tells the plotting code to skip the fit line.
    y_fit_hsie, r2_hsie, p_hsie, params_hsie = None, 0, 1, [0,0]
else:
    print("\n--- H'_sie (Robustness) Curve Fit ---")
    print("Fit function: y = -a * ln(x + 1) + b")
    print(f"  - Parameter a: {params_hsie[0]:.6f}")
    print(f"  - Parameter b: {params_hsie[1]:.6f}")
    print(f"R-squared: {r2_hsie:.4f}")
    print(f"p-value: {p_hsie:.2e}\n")

# --- 7. Visualize Results ---
def _fit_equation_label(slope, intercept):
    """Return the LaTeX label ``y = slope * ln(x+1) +/- |intercept|``.

    The signs are derived from the values, so a negative slope or intercept
    never renders as a doubled sign such as ``--0.01`` or ``+ -0.5``.
    """
    sign = '+' if intercept >= 0 else '-'
    return f"$y = {slope:.4f} \\cdot \\ln(x+1) {sign} {abs(intercept):.4f}$"


# Shared style of the annotation box that reports the fit statistics.
_FIT_BOX_STYLE = dict(boxstyle='round,pad=0.5', fc='wheat', alpha=0.5)

sns.set_style("whitegrid")
fig, axes = plt.subplots(2, 3, figsize=(24, 12))
fig.suptitle('Unified Analysis of Learning Dynamics (with Curve Fitting)', fontsize=20)

# Column 1: Loss and IB Metrics vs Time
ax_loss = axes[0, 0]
ax_loss.plot(history['loss'], color='red')
ax_loss.set_title('Training Loss vs. Epochs')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Cross-Entropy Loss')

ax_mi = axes[1, 0]
ax_mi.plot(log_epochs_np, mi_xt_h1, label='I(X; T)', marker='.')
ax_mi.plot(log_epochs_np, mi_ty_h1, label='I(T; Y)', marker='.')
ax_mi.set_title('IB Metrics vs. Epochs')
ax_mi.set_xlabel('Epoch')
ax_mi.set_ylabel('Mutual Information (bits)')
ax_mi.legend()

# Column 2: Information Plane and Unified Theory Cost vs Time
ax_plane = axes[0, 1]
points1 = ax_plane.scatter(mi_xt_h1, mi_ty_h1, c=log_epochs_np, cmap='viridis', s=15, alpha=0.8)
cbar1 = fig.colorbar(points1, ax=ax_plane)
cbar1.set_label('Epoch')
ax_plane.set_title('The Information Plane (IB Theory)')
ax_plane.set_xlabel('I(X; T) - Compression')
ax_plane.set_ylabel('I(T; Y) - Fitting')

ax_htse = axes[1, 1]
ax_htse.scatter(log_epochs_np, htse_hist, label="H'_tse (Actual Data)", marker='.', color='green', alpha=0.5)
if y_fit_htse is not None:
    ax_htse.plot(log_epochs_np, y_fit_htse, label="Logarithmic Fit", color='black', linestyle='--')
    # Growth model y = a*ln(x+1) + b: the slope is `a` itself.
    fit_eq_htse = _fit_equation_label(params_htse[0], params_htse[1])
    ax_htse.text(0.95, 0.05, f'{fit_eq_htse}\n$R^2 = {r2_htse:.4f}$\n$p = {p_htse:.2e}$',
                 transform=ax_htse.transAxes, fontsize=12,
                 verticalalignment='bottom', horizontalalignment='right',
                 bbox=_FIT_BOX_STYLE)
ax_htse.set_title("H'_tse (Cognitive Cost) vs. Epochs")
ax_htse.set_xlabel('Epoch')
ax_htse.set_ylabel("H'_tse")
ax_htse.legend()

# Column 3: Unified Theory Space and Robustness vs Time
ax_space = axes[0, 2]
points2 = ax_space.scatter(htse_hist, hsie_hist, c=log_epochs_np, cmap='magma', s=15, alpha=0.8)
cbar2 = fig.colorbar(points2, ax=ax_space)
cbar2.set_label('Epoch')
ax_space.set_title('Weighted Semantic State Space (Unified Theory)')
ax_space.set_xlabel("H'_tse - Cognitive Cost")
ax_space.set_ylabel("H'_sie - Robustness")

ax_hsie = axes[1, 2]
ax_hsie.scatter(log_epochs_np, hsie_hist, label="H'_sie (Actual Data)", marker='.', color='purple', alpha=0.5)
if y_fit_hsie is not None:
    ax_hsie.plot(log_epochs_np, y_fit_hsie, label="Logarithmic Fit", color='black', linestyle='--')
    # Decay model y = -a*ln(x+1) + b: the displayed slope is `-a`.
    fit_eq_hsie = _fit_equation_label(-params_hsie[0], params_hsie[1])
    ax_hsie.text(0.95, 0.95, f'{fit_eq_hsie}\n$R^2 = {r2_hsie:.4f}$\n$p = {p_hsie:.2e}$',
                 transform=ax_hsie.transAxes, fontsize=12,
                 verticalalignment='top', horizontalalignment='right',
                 bbox=_FIT_BOX_STYLE)
ax_hsie.set_title("H'_sie (Robustness) vs. Epochs")
ax_hsie.set_xlabel('Epoch')
ax_hsie.set_ylabel("H'_sie")
ax_hsie.legend()

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()