In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage, stats
from sklearn.feature_selection import mutual_info_regression
from sklearn.neighbors import NearestNeighbors
import scipy.spatial.distance as dist
from scipy.signal import correlate2d
import warnings
warnings.filterwarnings('ignore')

class SpatioTemporalInfoAnalyzer:
    """
    Toolkit for quantifying information content in evolving 2D fields.

    The field is held as a ``(time, y, x)`` array. All information
    quantities use the natural logarithm, so they are expressed in nats.
    The histogram-based estimators are plug-in estimators: they are biased
    upwards when the number of samples is small relative to the number of
    bins, so compare values rather than reading them as absolute.
    """

    def __init__(self, field_data, dx=1.0, dt=1.0):
        """
        field_data: 3D array-like (time, y, x) representing the evolving 2D field
        dx, dt: spatial and temporal resolution (dx is used for both axes)

        Raises ValueError when field_data is not three-dimensional.
        """
        self.data = np.array(field_data)
        if self.data.ndim != 3:
            raise ValueError(
                "field_data must be a 3D array shaped (time, y, x); "
                f"got {self.data.ndim} dimension(s)"
            )
        self.T, self.ny, self.nx = self.data.shape
        self.dx = dx
        self.dt = dt

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _lag_slices(size, lag):
        """Return two slices pairing index i (first) with index i + lag (second)."""
        lag = int(lag)
        if lag >= 0:
            return slice(0, max(size - lag, 0)), slice(lag, size)
        return slice(-lag, size), slice(0, max(size + lag, 0))

    @staticmethod
    def _discretize(values, bins):
        """Map values to integer codes 0..bins-1 using equal-width bins over their range."""
        lo, hi = values.min(), values.max()
        if hi == lo:
            # A constant series carries no information: everything in one bin.
            return np.zeros(len(values), dtype=np.int64)
        codes = np.floor((values - lo) / (hi - lo) * bins).astype(np.int64)
        # The maximum lands exactly on the upper edge; fold it into the last bin.
        return np.minimum(codes, bins - 1)

    @staticmethod
    def _joint_entropy(bins, *codes):
        """Plug-in Shannon entropy (nats) of the joint distribution of integer codes."""
        combined = np.zeros(len(codes[0]), dtype=np.int64)
        for code in codes:
            # Mixed-radix encoding: unique per code tuple since 0 <= code < bins.
            combined = combined * bins + code
        _, counts = np.unique(combined, return_counts=True)
        p = counts / counts.sum()
        return -np.sum(p * np.log(p))

    def _frame_entropies(self):
        """Return the KDE entropy estimate of every time frame as a 1D array."""
        return np.array([self.differential_entropy_kde(frame) for frame in self.data])

    @staticmethod
    def _rates_from_entropies(entropies, window_size, damping):
        """Turn per-frame entropies into the clipped, damped entropy-rate series."""
        window_size = int(window_size)
        if window_size < 1:
            raise ValueError("window_size must be at least 1")
        rates = [
            max(0.0, entropies[t] - damping * np.mean(entropies[t - window_size:t]))
            for t in range(window_size, len(entropies))
        ]
        return np.array(rates)

    # ------------------------------------------------------------------
    # Estimators
    # ------------------------------------------------------------------
    def differential_entropy_kde(self, field_slice, bandwidth=None):
        """
        Estimate differential entropy using kernel density estimation.

        field_slice: array-like of samples (any shape); non-finite values are dropped
        bandwidth: forwarded to scipy.stats.gaussian_kde as bw_method

        The density is evaluated on 100 points spanning [min, max] of the
        samples and -p*log(p) is summed (Riemann sum), so probability mass
        outside the sample range is ignored.

        Returns 0.0 when there are no finite samples or when the KDE cannot
        be built (for example a single sample or a constant field).
        """
        samples = np.asarray(field_slice).ravel()
        samples = samples[np.isfinite(samples)]
        if samples.size == 0:
            return 0.0

        x_grid = np.linspace(samples.min(), samples.max(), 100)
        try:
            kde = stats.gaussian_kde(samples, bw_method=bandwidth)
            pdf_vals = kde(x_grid)
        except (np.linalg.LinAlgError, ValueError):
            # Degenerate data (singular covariance / too few points).
            return 0.0

        pdf_vals = pdf_vals[pdf_vals > 1e-10]  # Avoid log(0)
        dx_grid = x_grid[1] - x_grid[0]
        return float(-np.sum(pdf_vals * np.log(pdf_vals)) * dx_grid)

    def spatial_mutual_information(self, t1, t2, lag_x=1, lag_y=1, bins=50):
        """
        Calculate mutual information between spatially shifted frames.

        Pixel (y, x) of frame t1 is paired with pixel (y + lag_y, x + lag_x)
        of frame t2. Lags may be zero or negative. The joint distribution is
        estimated with a 2D histogram and only occupied cells contribute.

        Returns the mutual information in nats, or 0.0 when fewer than 10
        finite sample pairs are available or the histogram cannot be built.
        """
        field1 = self.data[t1]
        field2 = self.data[t2]

        # Explicit slice bounds: the idiom [:-lag] would be empty for lag == 0.
        rows1, rows2 = self._lag_slices(self.ny, lag_y)
        cols1, cols2 = self._lag_slices(self.nx, lag_x)
        region1 = field1[rows1, cols1].ravel()
        region2 = field2[rows2, cols2].ravel()

        # Remove invalid values
        valid = np.isfinite(region1) & np.isfinite(region2)
        region1 = region1[valid]
        region2 = region2[valid]
        if region1.size < 10:
            return 0.0

        try:
            counts, _, _ = np.histogram2d(region1, region2, bins=bins)
        except ValueError:
            # Unusable bin specification.
            return 0.0

        p_xy = counts / counts.sum()
        p_x = p_xy.sum(axis=1, keepdims=True)
        p_y = p_xy.sum(axis=0, keepdims=True)

        # Empty cells contribute nothing (0 * log 0 := 0); skipping them avoids
        # padding the histogram with an artificial epsilon.
        occupied = p_xy > 0
        ratio = p_xy[occupied] / (p_x * p_y)[occupied]
        return float(np.sum(p_xy[occupied] * np.log(ratio)))

    def transfer_entropy_estimator(self, source_ts, target_ts, lag=1, bins=10):
        """
        Estimate transfer entropy between two time series.

        TE(X->Y) = I(Y_t; X_{t-lag} | Y_{t-1})

        source_ts, target_ts: 1D sequences of equal length
        lag: non-negative delay (in samples) from source to target
        bins: number of equal-width bins used to discretise each variable

        All entropy terms are computed from the same discretisation so the
        conditional mutual information identity is applied consistently.

        Returns the estimate in nats (never negative). Returns 0.0 when the
        series differ in length, are shorter than lag + 2, or leave fewer
        than 10 finite samples. Raises ValueError for lag < 0 or bins < 1.
        """
        lag = int(lag)
        bins = int(bins)
        if lag < 0:
            raise ValueError("lag must be non-negative")
        if bins < 1:
            raise ValueError("bins must be at least 1")

        source = np.asarray(source_ts, dtype=float)
        target = np.asarray(target_ts, dtype=float)
        n = len(source)
        if n != len(target) or n < lag + 2:
            return 0.0

        # Align (Y_t, X_{t-lag}, Y_{t-1}); t starts where both pasts exist.
        start = max(lag, 1)
        y_future = target[start:]
        x_past = source[start - lag:n - lag]
        y_past = target[start - 1:n - 1]

        # Remove invalid values
        valid = np.isfinite(y_future) & np.isfinite(x_past) & np.isfinite(y_past)
        y_future = y_future[valid]
        x_past = x_past[valid]
        y_past = y_past[valid]
        if len(y_future) < 10:
            return 0.0

        yf = self._discretize(y_future, bins)
        xp = self._discretize(x_past, bins)
        yp = self._discretize(y_past, bins)

        # I(Yf; Xp | Yp) = H(Yf, Yp) + H(Xp, Yp) - H(Yf, Xp, Yp) - H(Yp)
        te = (
            self._joint_entropy(bins, yf, yp)
            + self._joint_entropy(bins, xp, yp)
            - self._joint_entropy(bins, yf, xp, yp)
            - self._joint_entropy(bins, yp)
        )
        # Guard against tiny negative values from floating-point rounding.
        return max(0.0, float(te))

    def entropy_rate(self, window_size=5, damping=0.8):
        """
        Calculate a heuristic entropy rate - new information per time step.

        For every t >= window_size the rate is
        max(0, H(frame_t) - damping * mean(H(frame_{t-window_size..t-1})))
        where H is differential_entropy_kde. This is a proxy, not a
        conditional-entropy estimate.

        Returns an array of length max(T - window_size, 0).
        Raises ValueError when window_size < 1.
        """
        return self._rates_from_entropies(self._frame_entropies(), window_size, damping)

    def information_flow_field(self, lag=1):
        """
        Calculate information flow between neighboring spatial regions.

        For every interior pixel the transfer entropy from each of its four
        direct neighbours (up, down, left, right) to the pixel is averaged.

        Returns a (ny-2, nx-2) array of local information transfer rates.
        """
        flow_field = np.zeros((max(self.ny - 2, 0), max(self.nx - 2, 0)))

        for i in range(1, self.ny - 1):
            for j in range(1, self.nx - 1):
                center_ts = self.data[:, i, j]
                neighbors = (
                    self.data[:, i - 1, j],    # up
                    self.data[:, i + 1, j],    # down
                    self.data[:, i, j - 1],    # left
                    self.data[:, i, j + 1],    # right
                )
                flow_field[i - 1, j - 1] = np.mean([
                    self.transfer_entropy_estimator(neighbor_ts, center_ts, lag)
                    for neighbor_ts in neighbors
                ])

        return flow_field

    def wavefront_info_velocity(self):
        """
        Estimate a proxy for the velocity of information propagation.

        For each interior time step the absolute centred time derivative of
        the field is computed, and the mean magnitude of its spatial gradient
        is returned. The value is a sharpness measure of the changing region
        rather than a calibrated speed.

        Returns an array of length max(T - 2, 0); an entry is NaN when no
        finite gradient value exists for that step.
        """
        velocities = []

        for t in range(1, self.T - 1):
            # Centred difference in time.
            info_field = np.abs(self.data[t + 1] - self.data[t - 1]) / (2 * self.dt)

            # Same spacing on both axes; np.gradient returns one array per axis.
            grad_y, grad_x = np.gradient(info_field, self.dx)
            magnitude = np.hypot(grad_x, grad_y)

            finite = magnitude[np.isfinite(magnitude)]
            velocities.append(finite.mean() if finite.size else np.nan)

        return np.array(velocities)

    def analyze_complete(self):
        """
        Perform comprehensive spatio-temporal information analysis.

        Returns a dict with keys 'temporal_entropy', 'entropy_rate',
        'spatial_mi', 'flow_field' and 'info_velocity'. Progress messages are
        printed to stdout.
        """
        print("Performing comprehensive spatio-temporal information analysis...")

        results = {}

        # 1. Temporal entropy evolution
        print("  - Calculating temporal entropy evolution...")
        frame_entropies = self._frame_entropies()
        results['temporal_entropy'] = frame_entropies

        # 2. Entropy rate (reuses the per-frame entropies computed above)
        print("  - Calculating entropy rate...")
        results['entropy_rate'] = self._rates_from_entropies(frame_entropies, 5, 0.8)

        # 3. Spatial mutual information evolution
        print("  - Calculating spatial correlations...")
        results['spatial_mi'] = np.array([
            self.spatial_mutual_information(t, t + 1, lag_x=1, lag_y=1)
            for t in range(self.T - 1)
        ])

        # 4. Information flow field
        print("  - Calculating information flow field...")
        results['flow_field'] = self.information_flow_field()

        # 5. Information velocity
        print("  - Estimating information propagation velocity...")
        results['info_velocity'] = self.wavefront_info_velocity()

        return results

    def plot_analysis(self, results):
        """
        Visualize the information analysis results.

        results: dict as returned by analyze_complete

        Shows a 2x3 figure (three field snapshots, entropy evolution, entropy
        rate, flow field) and prints summary statistics. Empty result arrays
        are reported as nan in the summary.
        """
        def _stat(func, values):
            # Reductions over empty arrays either warn or raise; report nan instead.
            values = np.asarray(values)
            return func(values) if values.size else float("nan")

        fig, axes = plt.subplots(2, 3, figsize=(18, 12))

        # Original field evolution (show first, middle, last frames)
        times_to_show = [0, self.T // 2, self.T - 1]
        for idx, t in enumerate(times_to_show):
            ax = axes[0, idx]
            im = ax.imshow(self.data[t], cmap='viridis', aspect='auto')
            ax.set_title(f'Field at t={t}')
            plt.colorbar(im, ax=ax)

        # Temporal entropy evolution
        axes[1, 0].plot(results['temporal_entropy'])
        axes[1, 0].set_title('Temporal Entropy Evolution')
        axes[1, 0].set_xlabel('Time')
        axes[1, 0].set_ylabel('Differential Entropy')

        # Entropy rate
        axes[1, 1].plot(results['entropy_rate'])
        axes[1, 1].set_title('Information Generation Rate')
        axes[1, 1].set_xlabel('Time')
        axes[1, 1].set_ylabel('Entropy Rate')

        # Information flow field
        im = axes[1, 2].imshow(results['flow_field'], cmap='plasma', aspect='auto')
        axes[1, 2].set_title('Information Flow Field')
        plt.colorbar(im, ax=axes[1, 2])

        plt.tight_layout()
        plt.show()

        # Summary statistics
        print("\n=== SPATIO-TEMPORAL INFORMATION ANALYSIS SUMMARY ===")
        print(f"Total entropy change: {results['temporal_entropy'][-1] - results['temporal_entropy'][0]:.3f}")
        print(f"Average entropy rate: {_stat(np.mean, results['entropy_rate']):.3f}")
        print(f"Peak information generation: {_stat(np.max, results['entropy_rate']):.3f}")
        print(f"Average spatial MI: {_stat(np.mean, results['spatial_mi']):.3f}")
        print(f"Mean information flow: {_stat(np.mean, results['flow_field']):.3f}")
        print(f"Average information velocity: {_stat(np.mean, results['info_velocity']):.3f}")

# Example usage with simulated data
def generate_example_field(nx=50, ny=50, nt=100, noise_level=0.05, seed=None):
    """
    Generate a sample evolving 2D field for demonstrations.

    The field is a superposition of two travelling sinusoids, each
    modulated by a fixed Gaussian envelope, plus a weak product term and
    additive Gaussian noise. It is sampled on x, y in [0, 5] and t in [0, 1].

    nx, ny, nt: number of grid points in x, y and time
    noise_level: standard deviation of the additive noise
    seed: optional integer seed for reproducible noise; when None the
          global NumPy random state is used

    Returns an array of shape (nt, ny, nx).
    """
    x = np.linspace(0, 5, nx)
    y = np.linspace(0, 5, ny)
    t = np.linspace(0, 1, nt)

    X, Y = np.meshgrid(x, y)

    # A private generator keeps seeded calls from disturbing global state.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # The Gaussian envelopes do not depend on time, so build them once.
    envelope1 = np.exp(-((X - 2.5)**2 + (Y - 2.5)**2) / 2)
    envelope2 = np.exp(-((X - 1)**2 + (Y - 4)**2) / 1)

    field_data = np.zeros((nt, ny, nx))
    for i, time in enumerate(t):
        # Two wave sources with different frequencies and directions
        wave1 = np.sin(2 * np.pi * (X - 2 * time)) * envelope1
        wave2 = 0.5 * np.sin(3 * np.pi * (Y - 1.5 * time)) * envelope2

        # Nonlinear interaction term
        field_data[i] = wave1 + wave2 + 0.1 * wave1 * wave2

        # Small amount of noise
        field_data[i] += noise_level * rng.randn(ny, nx)

    return field_data

if __name__ == "__main__":
    # Build the synthetic field used for the demonstration.
    print("Generating example 2D evolving field...")
    field_data = generate_example_field()

    # Run every metric on it, then plot the results and print the summary.
    analyzer = SpatioTemporalInfoAnalyzer(field_data, dt=0.01, dx=0.1)
    results = analyzer.analyze_complete()
    analyzer.plot_analysis(results)