# In [None]:
# -*- coding: utf-8 -*-
"""
This script implements the final, focused version of the Cognitive Thermodynamics
analysis, centering the narrative on the core, robust findings: the macroscopic
phase transition observed in ICE and the evolution of the underlying semantic
potential energy landscape.

Core Design Philosophy (Final Phase Transition Focus Version):
1.  Focus on Robust Observables: Removes the unstable estimation of the internal
    temperature T and the associated Smeta efficiency (α) analysis.
2.  Highlight Phase Transition: Re-introduces the generalized logistic (sigmoid)
    fit for the System Temperature (ICE) curve, providing the primary quantitative
    evidence for a phase transition.
3.  Objective Potential Energy: Retains the ultimate, theoretically-grounded
    definition of energy (E_pot) as the unweighted norm of the semantic state vector.
4.  Memory-Enhanced Analyzer: Uses the most stable analyzer design to ensure smooth
    and continuous metric calculation.
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import copy
import warnings
from scipy.stats import entropy
from scipy.optimize import curve_fit
from scipy.spatial.distance import pdist, squareform
import networkx as nx
import os
import json

# Suppress every warning of category UserWarning for the whole run.
warnings.filterwarnings("ignore", category=UserWarning)

# --- 1. Experiment Configuration ---
# Single source of truth for the experiment's hyper-parameters; the analyzer
# and the main workflow both read from this dict.
CONFIG = {
    # Number of leading MNIST training samples used (Subset over range(n)).
    "train_subset_size": 500,
    # Total number of training epochs.
    "epochs": 100,
    # Mini-batch size for both the train and test DataLoaders.
    "batch_size": 64,
    # Learning rate handed to the Adam optimizer.
    "learning_rate": 0.01,
    # Widths of the first and second hidden layers of the MLP.
    "model_h1_size": 256,
    "model_h2_size": 128,
    # Metrics are recorded every `analysis_interval` epochs (and at the last epoch).
    "analysis_interval": 5,
    # Upper bound on how many hidden nodes are sampled per microscopic analysis.
    "analysis_sample_size": 100,
}

# --- 2. Neural Network Model ---
class MLP(nn.Module):
    """Three-layer perceptron for 28x28 inputs with ten output logits.

    The network is stored as a single ``nn.Sequential`` under ``self.layers``
    with the module order Linear, ReLU, Linear, ReLU, Linear.

    Args:
        h1_size: Width of the first hidden layer.
        h2_size: Width of the second hidden layer.
    """

    def __init__(self, h1_size=256, h2_size=128):
        super().__init__()
        stack = [
            nn.Linear(28 * 28, h1_size),
            nn.ReLU(),
            nn.Linear(h1_size, h2_size),
            nn.ReLU(),
            nn.Linear(h2_size, 10),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values and return the output logits."""
        flattened = x.view(-1, 28 * 28)
        return self.layers(flattened)

    def get_hidden_activations(self, x):
        """Run a forward pass and collect the output of every ReLU.

        Returns:
            Dict mapping ``'hidden_<k>'`` to the post-ReLU tensor, where ``k``
            is the ReLU's position in ``self.layers`` integer-divided by two
            (so the ReLUs at positions 1 and 3 become ``hidden_0`` and
            ``hidden_1``).
        """
        signal = x.view(-1, 28 * 28)
        captured = {}
        for position, module in enumerate(self.layers):
            signal = module(signal)
            if not isinstance(module, nn.ReLU):
                continue
            captured[f'hidden_{position // 2}'] = signal
        return captured

# --- 3. Cognitive Thermodynamics Analyzer (with Memory) ---
class CognitiveThermodynamicsAnalyzer:
    """Compute graph-based ("microscopic") and activation-based ("macroscopic")
    metrics for a model whose ``layers`` attribute holds ``nn.Linear`` modules.

    Microscopic view: every neuron becomes a node ``"<layer>-<index>"`` in a
    directed graph; each weight becomes an edge whose cost is
    ``1 - log(softmax(|W|))``. For a hidden node, all paths to the output
    ("grounding") nodes are enumerated and summarised as H'_TSE and H'_SIE.

    Macroscopic view: per-class mean activations of the last hidden layer are
    compared by cosine similarity and the entropy of that similarity
    distribution (ICE) is returned.

    Args:
        initial_model: Model to analyze; must expose ``layers`` and
            ``get_hidden_activations``.
        device: Device on which ``initial_model`` lives (used for the
            macroscopic forward passes).
    """

    NUM_CLASSES = 10          # class labels 0..NUM_CLASSES-1 probed for prototypes
    SAMPLES_PER_CLASS = 100   # images averaged into one class prototype
    MAX_PATH_LENGTH = 8       # paths holding more nodes than this are not extended
    MIN_EDGE_PROB = 1e-9      # edges with a softmax probability at or below this are dropped

    def __init__(self, initial_model, device):
        self.device = device
        # Path cache; only valid for the current graph snapshot (see update_graph).
        self.memoized_paths = {}
        self.model = initial_model
        self.update_graph(initial_model)  # Initial graph build

    def update_graph(self, model):
        """Rebuild the weight graph from the current state of ``model``.

        A CPU deep copy of the model is taken so the graph reflects a frozen
        snapshot. The path cache is cleared because cached entries embed edge
        costs derived from the previous weights; reusing them would report
        stale metrics for every node analyzed before this update.
        """
        self.model = model
        self.analyzer_model = copy.deepcopy(model).to('cpu')
        self.linear_layers = [l for l in self.analyzer_model.layers if isinstance(l, nn.Linear)]
        self.memoized_paths = {}

        G = nx.DiGraph()
        G.add_nodes_from(
            (f"0-{i}" for i in range(self.linear_layers[0].in_features)), layer=0
        )
        for i, layer in enumerate(self.linear_layers):
            G.add_nodes_from(
                (f"{i+1}-{j}" for j in range(layer.out_features)), layer=i + 1
            )
            # Rows index source neurons, columns index destination neurons; the
            # softmax turns each source's outgoing |weights| into probabilities.
            # Converting to float64 first matches the values `.item()` yields.
            probs = torch.softmax(torch.abs(layer.weight.data.t()), dim=1).double().numpy()
            keep = probs > self.MIN_EDGE_PROB
            # Masked-out entries are replaced by 1.0 so log() never sees 0.
            costs = 1.0 - np.log(np.where(keep, probs, 1.0))
            sources, sinks = np.nonzero(keep)  # row-major: same order as nested loops
            G.add_edges_from(
                (f"{i}-{u}", f"{i+1}-{v}", {"cost": float(costs[u, v])})
                for u, v in zip(sources.tolist(), sinks.tolist())
            )

        self.graph = G
        output_layer = len(self.linear_layers)
        self.grounding_nodes = {
            f"{output_layer}-{i}" for i in range(self.linear_layers[-1].out_features)
        }
        self.hidden_nodes = [
            n for n, d in self.graph.nodes(data=True) if 0 < d['layer'] < output_layer
        ]

    def _find_all_paths_dfs(self, start, targets):
        """Enumerate simple paths from ``start`` to any node in ``targets``.

        Uses an explicit-stack depth-first search. A path ends as soon as it
        reaches a target and is abandoned once it holds more than
        ``MAX_PATH_LENGTH`` nodes. Results are memoized per (start, targets).

        Returns:
            List of ``{'path': [nodes...], 'cost': summed edge cost}`` dicts.
        """
        memo_key = (start, tuple(sorted(targets)))
        if memo_key in self.memoized_paths:
            return self.memoized_paths[memo_key]

        paths = []
        stack = [(start, [start], 0)]
        while stack:
            curr, path, cost = stack.pop()
            if curr in targets:
                paths.append({'path': path, 'cost': cost})
                continue
            if len(path) > self.MAX_PATH_LENGTH:
                continue
            for neighbor in self.graph.neighbors(curr):
                if neighbor not in path:
                    step = self.graph[curr][neighbor]['cost']
                    stack.append((neighbor, path + [neighbor], cost + step))

        self.memoized_paths[memo_key] = paths
        return paths

    def _calculate_metrics_for_node(self, node):
        """Return ``(h_tse, h_sie)`` for one node.

        * ``h_tse`` is the reciprocal of the summed path conductances
          (``1 / cost``), i.e. the paths combined like parallel resistors.
        * ``h_sie`` is the base-2 entropy of the path distribution obtained by
          normalising ``exp(-cost)`` over all paths.

        Either value is ``inf`` when it cannot be computed: no path exists,
        the total conductance is not positive, or every ``exp(-cost)``
        underflows to zero.
        """
        paths = self._find_all_paths_dfs(node, self.grounding_nodes)
        if not paths:
            return float('inf'), float('inf')

        costs = np.array([p['cost'] for p in paths], dtype=float)

        total_conductance = np.sum(1.0 / costs)
        h_tse = 1.0 / total_conductance if total_conductance > 0 else float('inf')

        importances = np.exp(-1.0 * costs)
        groundingness = np.sum(importances)
        if groundingness > 0:
            probabilities = importances / groundingness
            # The 1e-9 offset keeps log2 finite for vanishing probabilities.
            h_sie = -np.sum(probabilities * np.log2(probabilities + 1e-9))
        else:
            h_sie = float('inf')
        return h_tse, h_sie

    def get_all_metrics_and_distributions(self, sample_size=None):
        """Sample hidden nodes and aggregate their microscopic metrics.

        Args:
            sample_size: Maximum number of hidden nodes to sample without
                replacement. Defaults to ``CONFIG["analysis_sample_size"]``.

        Returns:
            ``{}`` when the graph has no hidden nodes; otherwise a dict with
            ``avg_htse``, ``avg_hsie`` (means over nodes with finite metrics,
            or 0 if there are none) and ``potential_energy_dist``, the list of
            ``sqrt(h_tse**2 + h_sie**2)`` values for those nodes.
        """
        if not self.hidden_nodes:
            return {}

        if sample_size is None:
            sample_size = CONFIG["analysis_sample_size"]
        sample_size = min(sample_size, len(self.hidden_nodes))
        sampled_nodes = np.random.choice(self.hidden_nodes, size=sample_size, replace=False)

        htse_vals, hsie_vals, potential_energy_vals = [], [], []
        for node in sampled_nodes:
            h_tse, h_sie = self._calculate_metrics_for_node(node)
            if not (np.isfinite(h_tse) and np.isfinite(h_sie)):
                continue
            htse_vals.append(h_tse)
            hsie_vals.append(h_sie)
            potential_energy_vals.append(np.sqrt(h_tse**2 + h_sie**2))

        return {
            "avg_htse": np.mean(htse_vals) if htse_vals else 0,
            "avg_hsie": np.mean(hsie_vals) if hsie_vals else 0,
            "potential_energy_dist": potential_energy_vals,
        }

    def _collect_class_images(self, dataset):
        """Gather the first ``SAMPLES_PER_CLASS`` images of each class label.

        Makes a single pass over ``dataset`` (an iterable of ``(image, label)``
        pairs) and stops as soon as every class bucket is full.
        """
        buckets = {c: [] for c in range(self.NUM_CLASSES)}
        unfilled = self.NUM_CLASSES
        for img, label in dataset:
            bucket = buckets.get(int(label))
            if bucket is None or len(bucket) >= self.SAMPLES_PER_CLASS:
                continue
            bucket.append(img)
            if len(bucket) == self.SAMPLES_PER_CLASS:
                unfilled -= 1
                if unfilled == 0:
                    break
        return buckets

    def analyze_macroscopic(self, test_loader):
        """Return the ICE value: entropy of inter-class prototype similarity.

        For each class, the last hidden layer's activations are averaged over
        the collected images to form a prototype. Pairwise cosine similarities
        between prototypes (diagonal zeroed) are normalised to sum to one and
        their Shannon entropy is returned. Returns 0 when fewer than two
        prototypes could be built.
        """
        self.model.eval()
        class_images = self._collect_class_images(test_loader.dataset)
        # get_hidden_activations names the k-th ReLU output 'hidden_k'; with n
        # linear layers there are n-1 ReLUs, so the last one is 'hidden_{n-2}'.
        last_hidden_key = f'hidden_{len(self.linear_layers) - 2}'

        prototypes = []
        with torch.no_grad():
            for class_idx in range(self.NUM_CLASSES):
                images = class_images[class_idx]
                if not images:
                    continue
                class_tensor = torch.stack(images).to(self.device)
                hidden_acts = self.model.get_hidden_activations(class_tensor)
                if last_hidden_key in hidden_acts:
                    class_rep = hidden_acts[last_hidden_key].mean(dim=0)
                    prototypes.append(class_rep.cpu().numpy())

        if len(prototypes) < 2:
            return 0
        sim_matrix = 1 - squareform(pdist(np.array(prototypes), 'cosine'))
        np.fill_diagonal(sim_matrix, 0)
        total_similarity = sim_matrix.sum()
        prob_dist = sim_matrix / total_similarity if total_similarity > 0 else sim_matrix
        return entropy(prob_dist.flatten())

# --- 4. Training and Evaluation Functions ---
def evaluate(model, dataloader, device):
    """Return the top-1 accuracy of ``model`` over ``dataloader`` as a percentage.

    The model is switched to eval mode and gradients are disabled while the
    batches are scored. The denominator is ``len(dataloader.dataset)``.
    """
    model.eval()
    n_hits = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            predicted = logits.argmax(dim=1, keepdim=True)
            matches = predicted.eq(labels.view_as(predicted))
            n_hits += matches.sum().item()
    total = len(dataloader.dataset)
    return 100. * n_hits / total

# --- 5. Main Experiment Workflow ---
if __name__ == '__main__':
    # Run on the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Convert images to tensors, then normalise with mean 0.5 and std 0.5.
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    mnist_train_full = datasets.MNIST('.', train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST('.', train=False, download=True, transform=transform)

    # Train on only the first `train_subset_size` samples of the training set.
    train_subset = Subset(mnist_train_full, range(CONFIG['train_subset_size']))
    train_loader = DataLoader(train_subset, batch_size=CONFIG['batch_size'], shuffle=True)
    test_loader = DataLoader(mnist_test, batch_size=CONFIG['batch_size'])

    print(f"Experiment Setup: Training on {len(train_subset)} samples for {CONFIG['epochs']} epochs.")

    model = MLP(CONFIG['model_h1_size'], CONFIG['model_h2_size']).to(device)
    optimizer = optim.Adam(model.parameters(), lr=CONFIG['learning_rate'])
    criterion = nn.CrossEntropyLoss()

    analyzer = CognitiveThermodynamicsAnalyzer(model, device)

    # One entry per analysis checkpoint for the list-valued series; the two
    # *_potential_energy_dist entries hold a single distribution each.
    history = {
        "epoch": [], "loss": [], "train_acc": [], "test_acc": [],
        "htse": [], "hsie": [], "ice": [],
        "initial_potential_energy_dist": [], "final_potential_energy_dist": [],
    }

    print("Analyzing initial (t=0) random state...")
    initial_metrics = analyzer.get_all_metrics_and_distributions()
    history["initial_potential_energy_dist"] = initial_metrics.get("potential_energy_dist", [])

    pbar = tqdm(range(1, CONFIG['epochs'] + 1), desc="Training Progress")
    for epoch in pbar:
        model.train()
        epoch_losses = []
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())
        avg_epoch_loss = np.mean(epoch_losses)

        # Record metrics every `analysis_interval` epochs and always at the last epoch.
        if epoch % CONFIG['analysis_interval'] == 0 or epoch == CONFIG['epochs']:
            analyzer.update_graph(model)

            train_acc = evaluate(model, train_loader, device)
            test_acc = evaluate(model, test_loader, device)

            current_metrics = analyzer.get_all_metrics_and_distributions()
            ice_current = analyzer.analyze_macroscopic(test_loader)

            history["epoch"].append(epoch)
            history["loss"].append(avg_epoch_loss)
            history["train_acc"].append(train_acc)
            history["test_acc"].append(test_acc)
            history["htse"].append(current_metrics.get("avg_htse", 0))
            history["hsie"].append(current_metrics.get("avg_hsie", 0))
            history["ice"].append(ice_current)

            pbar.set_postfix({"Loss": f"{avg_epoch_loss:.2f}", "ICE": f"{ice_current:.2f}"})

    # The reported time step follows the configured epoch count instead of a
    # hard-coded 100, so the message stays correct when CONFIG changes.
    print(f"Analyzing final (t={CONFIG['epochs']}) ordered state...")
    final_metrics = analyzer.get_all_metrics_and_distributions()
    history["final_potential_energy_dist"] = final_metrics.get("potential_energy_dist", [])

    results_dir = "results_phase_transition"
    os.makedirs(results_dir, exist_ok=True)
    results_path = os.path.join(results_dir, "results.json")
    with open(results_path, 'w') as f:
        json.dump(history, f, indent=4)
    print(f"\nExperiment data successfully saved to: {results_path}")

    # --- 6. Results Visualization ---
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, axs = plt.subplots(2, 2, figsize=(20, 12))
    fig.suptitle('Cognitive Thermodynamics: Analysis of Phase Transition', fontsize=22, y=0.98)

    epochs = np.array(history['epoch'])

    # Plot 1 (Position [0, 0]): Accuracy & Loss Dynamics
    ax1 = axs[0, 0]
    ax1_twin = ax1.twinx()
    ax1.plot(epochs, history['train_acc'], 'b-o', label='Training Accuracy')
    ax1.plot(epochs, history['test_acc'], 'r-s', label='Test Accuracy')
    ax1_twin.plot(epochs, history['loss'], 'g--p', label='Training Loss')
    ax1.set_title('1. Accuracy & Loss Dynamics', fontsize=16)
    ax1.set_ylabel('Accuracy (%)', fontsize=12)
    ax1_twin.set_ylabel('Loss', fontsize=12, color='g')
    ax1.legend(loc='center left')
    ax1_twin.legend(loc='center right')

    # Plot 2 (Position [0, 1]): Semantic Temperature (ICE) Dynamics with Logistic Fit
    ax2 = axs[0, 1]
    ice_data = np.array(history['ice'])
    ax2.plot(epochs, ice_data, 'm-^', label='Semantic Temperature (ICE)')

    def generalized_logistic_function(t, L, A, k, t0):
        """Sigmoid rising from baseline L by amplitude A, with rate k and midpoint t0."""
        return L + A / (1 + np.exp(-k * (t - t0)))

    # Fit only from the ICE minimum onwards, and only when at least four
    # points remain (the model has four free parameters). The size check
    # avoids np.argmin raising on an empty history.
    min_ice_idx = int(np.argmin(ice_data)) if ice_data.size else 0
    if min_ice_idx < len(ice_data) - 3:
        epochs_to_fit = epochs[min_ice_idx:]
        ice_to_fit = ice_data[min_ice_idx:]
        props = dict(boxstyle='round', facecolor='wheat', alpha=0.6)
        try:
            L_guess = np.min(ice_to_fit)
            A_guess = np.max(ice_to_fit) - L_guess
            half_way_val = L_guess + A_guess / 2
            # Midpoint guess: first epoch at which ICE reaches half its rise;
            # fall back to the median epoch if no sample qualifies.
            crossing = np.flatnonzero(ice_to_fit >= half_way_val)
            if crossing.size:
                t0_guess = epochs_to_fit[crossing[0]]
            else:
                t0_guess = np.median(epochs_to_fit)
            k_guess = 1.0
            initial_guesses = [L_guess, A_guess, k_guess, t0_guess]
            bounds = ([np.min(ice_to_fit)*0.9, A_guess*0.5, 1e-5, epochs_to_fit[0]],
                      [np.min(ice_to_fit)*1.1, A_guess*1.5, 5.0, epochs_to_fit[-1]])

            popt, _ = curve_fit(generalized_logistic_function, epochs_to_fit, ice_to_fit,
                                p0=initial_guesses, bounds=bounds, maxfev=10000)
            residual_ss = np.sum((ice_to_fit - generalized_logistic_function(epochs_to_fit, *popt))**2)
            total_ss = np.sum((ice_to_fit - np.mean(ice_to_fit))**2)
            # R² is undefined for a constant series; report NaN instead of dividing by zero.
            r_squared = 1 - residual_ss / total_ss if total_ss > 0 else float('nan')

            epochs_fine = np.linspace(epochs_to_fit[0], epochs_to_fit[-1], 200)
            ax2.plot(epochs_fine, generalized_logistic_function(epochs_fine, *popt), 'g--', lw=2.5,
                     label=f'Logistic Fit (R²={r_squared:.3f})')
            result_text = (f"Phase Transition Analysis:\n"
                           f"R² = {r_squared:.4f}\nMax Temp: {popt[0]+popt[1]:.3f}\n"
                           f"Speed (k): {popt[2]:.3f}\nCritical Point (t0): {popt[3]:.2f} Epochs")
            ax2.text(0.05, 0.95, result_text, transform=ax2.transAxes, fontsize=10, va='top', bbox=props)
        except (RuntimeError, ValueError) as e:
            # Keep the figure usable, but surface why the fit was rejected.
            print(f"Warning: logistic fit failed: {e}")
            ax2.text(0.05, 0.95, "Logistic fit failed", transform=ax2.transAxes, fontsize=10, va='top', bbox=props)

    ax2.set_title('2. Semantic Temperature (ICE) Dynamics', fontsize=16)
    ax2.set_ylabel('Semantic Temperature (Entropy)', fontsize=12)
    ax2.legend()

    # Plot 3 (Position [1, 0]): Microscopic Entropy Dynamics
    ax3 = axs[1, 0]
    ax3.plot(epochs, history['htse'], 'c-p', label="H'_TSE (Cognitive Cost)")
    ax3.plot(epochs, history['hsie'], 'y-h', label="H'_SIE (Structural Robustness)")
    ax3.set_title('3. Microscopic Entropy Dynamics', fontsize=16)
    ax3.set_ylabel('Entropy Value', fontsize=12)
    ax3.legend(loc='center right')

    # Plot 4 (Position [1, 1]): Semantic Potential Energy Distribution
    ax4 = axs[1, 1]
    energy_initial = history['initial_potential_energy_dist']
    energy_final = history['final_potential_energy_dist']

    if energy_initial and energy_final:
        combined_energy = np.hstack((energy_initial, energy_final))
        combined_energy = combined_energy[np.isfinite(combined_energy)]
        if combined_energy.size > 0:
            # Shared bin edges so both histograms are directly comparable.
            bins = np.histogram(combined_energy, bins=25)[1]
            ax4.hist(energy_initial, bins=bins, alpha=0.7, label='Initial State (High E_pot)', color='blue', density=True)
            ax4.hist(energy_final, bins=bins, alpha=0.7, label='Final State (Low E_pot)', color='red', density=True)
            ax4.set_yscale('log')
            ax4.legend()
        else:
            # Axes-fraction coordinates keep the message centred whatever the data limits are.
            ax4.text(0.5, 0.5, "No finite energy data to plot.", ha='center', va='center',
                     transform=ax4.transAxes)
    else:
        if not energy_initial:
            print("Warning: Initial potential energy distribution is empty.")
        if not energy_final:
            print("Warning: Final potential energy distribution is empty.")
        ax4.text(0.5, 0.5, "Energy data not available.", ha='center', va='center',
                 transform=ax4.transAxes)

    ax4.set_title('4. Objective Potential Energy Distribution', fontsize=16)
    ax4.set_ylabel('Probability Density', fontsize=12)

    # General labeling and grid for the 2x2 layout
    for ax in axs.flat:
        ax.set_xlabel('Epochs', fontsize=12)
        ax.grid(True, which="both", ls="--")

    # The energy histogram's x-axis is not time, so override the generic label.
    axs[1, 1].set_xlabel('E_pot (Unweighted Norm of State Vector)', fontsize=12)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

