In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
import time

In [2]:
# Load the pre-saved MNIST subset (100 samples per class).
# NOTE(review): torch.load unpickles the file, so only open subsets from a trusted source.
mnist_subset = torch.load("mnist_subset_100_per_class.pt")
# Every entry is indexed by position: [0] is the image tensor, [1] is its label.
images = torch.stack([sample[0] for sample in mnist_subset])  # Shape: [1000, 1, 28, 28]
labels = torch.tensor([sample[1] for sample in mnist_subset])  # Shape: [1000]

In [3]:
images.shape  # 1000 images, each of shape [1, 28, 28]

torch.Size([1000, 1, 28, 28])

In [4]:
# One label per image, so the length should equal images.shape[0]
labels.shape

torch.Size([1000])

In [5]:
# Prepare the dataset (single batch for SKA forward learning)
# `inputs` is a second name for the same tensor object as `images`; no copy is made.
inputs = images  # No mini-batches, full dataset used for forward-only updates

In [6]:
# Sanity check: `inputs` reports the same shape as `images`
inputs.shape

torch.Size([1000, 1, 28, 28])

In [None]:

# Define the SKA model (four sigmoid layers by default)
class SKAModel(nn.Module):
    """SKA network: a stack of fully connected sigmoid layers updated forward-only.

    Each layer ``l`` computes a knowledge tensor ``Z[l] = x @ W[l] + b[l]`` and a
    decision tensor ``D[l] = sigmoid(Z[l])``.  Between two consecutive forward
    steps the model measures the decision shift ``delta_D[l] = D[l] - D_prev[l]``
    and uses it both for a layer-wise entropy value and for a local weight
    update.  No ``backward()`` call is involved anywhere in the class.

    The caller is responsible for refreshing ``D_prev`` after every step (the
    class never writes to it except to reset it in ``initialize_tensors``).

    Args:
        input_size: Number of features per sample after flattening.
        layer_sizes: Output width of every layer, in order.  The sequence is
            copied, so later changes to the caller's list do not affect the model.
        K: Number of forward steps of a training run.  Only stored here; the
            driving loop reads it.
    """

    def __init__(self, input_size=784, layer_sizes=[256, 128, 64, 10], K=50):
        super().__init__()
        self.input_size = input_size
        # Copy: never alias the shared default list or a list owned by the caller.
        self.layer_sizes = list(layer_sizes)
        self.K = K  # Number of forward steps

        # Weights are stored as [in_features, out_features] so forward can use x @ W.
        self.weights = nn.ParameterList()
        self.biases = nn.ParameterList()
        prev_size = input_size
        for size in self.layer_sizes:
            self.weights.append(nn.Parameter(torch.randn(prev_size, size) * 0.01))
            self.biases.append(nn.Parameter(torch.zeros(size)))
            prev_size = size

        n_layers = len(self.layer_sizes)

        # Per-layer tracking state; every slot stays None until it is first computed.
        self.Z = [None] * n_layers         # Knowledge tensors (pre-activations)
        self.D = [None] * n_layers         # Decision tensors (sigmoid outputs)
        self.D_prev = [None] * n_layers    # Decisions of the previous step (set by the caller)
        self.delta_D = [None] * n_layers   # Decision shifts D - D_prev
        self.entropy = [None] * n_layers   # Most recent entropy value per layer

        # Histories used by the visualisation helpers (one entry per evaluated step).
        self.entropy_history = [[] for _ in range(n_layers)]
        self.cosine_history = [[] for _ in range(n_layers)]
        self.output_history = []  # Appended to by the caller: mean final-layer output per step

    def forward(self, x):
        """Run one forward pass and remember every layer's Z and D.

        Args:
            x: Input tensor whose first dimension is the batch; all remaining
                dimensions are flattened into features.

        Returns:
            The sigmoid output of the last layer, shape ``[batch, layer_sizes[-1]]``.
        """
        batch_size = x.shape[0]
        x = x.view(batch_size, -1)  # Flatten images

        for l in range(len(self.layer_sizes)):
            # Knowledge tensor Z = xW + b
            z = torch.mm(x, self.weights[l]) + self.biases[l]
            # Decision probabilities
            d = torch.sigmoid(z)
            # Keep both for calculate_entropy / ska_update
            self.Z[l] = z
            self.D[l] = d
            x = d  # Output becomes input for the next layer

        return x

    def calculate_entropy(self):
        """Compute the per-layer entropy and cos(theta) for the latest step.

        A layer is evaluated only when ``Z``, ``D`` and ``D_prev`` are all set
        for it.  For such a layer this stores ``delta_D``, records
        ``-1/ln(2) * sum(Z * delta_D)`` in ``entropy`` / ``entropy_history`` and
        appends the cosine between ``Z`` and ``delta_D`` to ``cosine_history``
        (``0.0`` when either norm is zero).

        The computation runs under ``torch.no_grad()``: the values are only
        logged and fed to the forward-only update, so no autograd graph is kept.

        Returns:
            The sum of the evaluated layers' entropies as a 0-dim tensor, or the
            integer ``0`` when no layer could be evaluated.
        """
        total_entropy = 0
        # Plain Python float so the product does not depend on NumPy-scalar/tensor dispatch.
        scale = -1.0 / float(np.log(2))
        with torch.no_grad():
            for l in range(len(self.layer_sizes)):
                if self.Z[l] is None or self.D_prev[l] is None or self.D[l] is None:
                    continue

                # Decision shift since the previous step
                self.delta_D[l] = self.D[l] - self.D_prev[l]

                dot_product = torch.sum(self.Z[l] * self.delta_D[l])
                layer_entropy = scale * dot_product
                entropy_value = layer_entropy.item()
                self.entropy[l] = entropy_value
                self.entropy_history[l].append(entropy_value)

                # cos(theta) between Z and delta_D, both treated as flat vectors
                z_norm = torch.norm(self.Z[l])
                delta_d_norm = torch.norm(self.delta_D[l])
                if z_norm > 0 and delta_d_norm > 0:
                    cos_theta = dot_product / (z_norm * delta_d_norm)
                    self.cosine_history[l].append(cos_theta.item())
                else:
                    self.cosine_history[l].append(0.0)  # Default if norms are zero

                total_entropy += layer_entropy
        return total_entropy

    def ska_update(self, inputs, learning_rate=0.01):
        """Apply the forward-only SKA update to every layer that has a decision shift.

        Layers whose ``delta_D`` is still ``None`` are skipped, so calling this
        before ``calculate_entropy`` has produced a shift changes nothing.

        Parameters are modified in place under ``torch.no_grad()``.  The entries
        of ``weights`` / ``biases`` therefore remain the same ``nn.Parameter``
        objects, and no autograd graph is built for the update.

        Args:
            inputs: The batch that was fed to ``forward``; flattened and used as
                the input of layer 0.
            learning_rate: Step size of the update.
        """
        scale = -1.0 / float(np.log(2))
        with torch.no_grad():
            for l in range(len(self.layer_sizes)):
                if self.delta_D[l] is None:
                    continue
                # Input of layer l: the raw batch for layer 0, otherwise the
                # previous step's decisions of the layer below.
                prev_output = inputs.view(inputs.shape[0], -1) if l == 0 else self.D_prev[l - 1]
                # Sigmoid derivative: D * (1 - D)
                d_prime = self.D[l] * (1 - self.D[l])
                # Entropy gradient with respect to Z
                gradient = scale * (self.Z[l] * d_prime + self.delta_D[l])
                # Batch-averaged weight update
                dW = torch.matmul(prev_output.t(), gradient) / prev_output.shape[0]
                self.weights[l].sub_(learning_rate * dW)
                self.biases[l].sub_(learning_rate * gradient.mean(dim=0))

    def initialize_tensors(self, batch_size):
        """Clear all tracking tensors and histories before a new run.

        Args:
            batch_size: Unused; kept so existing callers keep working.
        """
        for l in range(len(self.layer_sizes)):
            self.Z[l] = None         # Reset knowledge tensors
            self.D[l] = None         # Reset current decision probabilities
            self.D_prev[l] = None    # Reset previous decision probabilities
            self.delta_D[l] = None   # Reset decision shifts
            self.entropy[l] = None   # Reset entropy storage
            self.entropy_history[l] = []  # Reset entropy history
            self.cosine_history[l] = []   # Reset cosine history
        self.output_history = []  # Reset output history

    def visualize_entropy_heatmap(self, step):
        """Draw the layer-by-step entropy heatmap and save it as a PNG.

        The colour range runs from the smallest recorded entropy up to 0.  The
        figure is written to ``entropy_heatmap_step_{step}.png``, shown without
        blocking for two seconds, then closed.

        Args:
            step: Step number used in the title and in the file name.
        """
        entropy_data = np.array(self.entropy_history)
        vmin = np.min(entropy_data)  # Dynamically set minimum entropy value
        vmax = 0.0  # Keep 0 as the upper limit for standardization
        plt.figure(figsize=(12, 8))
        sns.heatmap(entropy_data, cmap="Blues_r", vmin=vmin, vmax=vmax,
                    xticklabels=range(1, entropy_data.shape[1] + 1),
                    yticklabels=[f"Layer {i+1}" for i in range(len(self.layer_sizes))])
        plt.title(f"Layer-wise Entropy Heatmap (Step {step})")
        plt.xlabel("Step Index K")
        plt.ylabel("Network Layers")
        plt.tight_layout()
        plt.savefig(f"entropy_heatmap_step_{step}.png")
        plt.show(block=False)  # Non-blocking
        plt.pause(2)  # Wait for 2 seconds
        plt.close()  # Close automatically

    def visualize_cosine_heatmap(self, step):
        """Draw the layer-by-step cos(theta) heatmap on a fixed [-1, 1] scale.

        The figure is written to ``cosine_heatmap_step_{step}.png``, shown
        without blocking for two seconds, then closed.

        Args:
            step: Step number used in the title and in the file name.
        """
        cosine_data = np.array(self.cosine_history)
        plt.figure(figsize=(12, 8))
        sns.heatmap(cosine_data, cmap="coolwarm_r", vmin=-1.0, vmax=1.0,
                    xticklabels=range(1, cosine_data.shape[1] + 1),
                    yticklabels=[f"Layer {i+1}" for i in range(len(self.layer_sizes))])
        plt.title(f"Layer-wise Cos(\u03B8) Alignment Heatmap (Step {step})")
        plt.xlabel("Step Index K")
        plt.ylabel("Network Layers")
        plt.tight_layout()
        plt.savefig(f"cosine_heatmap_step_{step}.png")
        plt.show(block=False)  # Non-blocking
        plt.pause(2)  # Wait for 2 seconds
        plt.close()  # Close automatically

    def visualize_output_distribution(self):
        """Plot how the mean output of every final-layer unit evolves over the steps.

        Reads ``output_history`` (one vector per step, appended by the caller),
        saves the figure to ``output_distribution_single_pass.png``, shows it
        without blocking for two seconds, then closes it.
        """
        output_data = np.array(self.output_history)  # Shape: [steps, n_outputs]
        # One legend entry per recorded output unit instead of a hard-coded 10.
        n_outputs = output_data.shape[1] if output_data.ndim == 2 else 0
        plt.figure(figsize=(10, 6))
        plt.plot(output_data)  # Plot each class as a line
        plt.title('Output Decision Probability Evolution Across Steps (Single Pass)')
        plt.xlabel('Step Index K')
        plt.ylabel('Mean Sigmoid Output')
        plt.legend([f"Class {i}" for i in range(n_outputs)], loc='upper right', bbox_to_anchor=(1.15, 1))
        plt.grid(True)
        plt.tight_layout()
        plt.savefig("output_distribution_single_pass.png")
        plt.show(block=False)  # Non-blocking
        plt.pause(2)  # Wait for 2 seconds
        plt.close()  # Close automatically

In [25]:

def forward(self, x):
    """Stand-alone copy of the SKA forward pass.

    `self` may be any object exposing `layer_sizes`, `weights`, `biases`, `Z`
    and `D`.  The input is flattened to [batch, features], pushed through every
    layer as sigmoid(x @ W + b), and each layer's pre-activation and activation
    are written to `self.Z[i]` / `self.D[i]`.  Returns the last layer's activation.
    """
    activations = x.view(x.shape[0], -1)  # one flat feature row per sample

    for layer_idx, _ in enumerate(self.layer_sizes):
        # Knowledge tensor, then decision probabilities, for this layer
        knowledge = activations.mm(self.weights[layer_idx]) + self.biases[layer_idx]
        decisions = knowledge.sigmoid()
        self.Z[layer_idx] = knowledge
        self.D[layer_idx] = decisions
        # The decisions feed the next layer
        activations = decisions

    return activations

### training 

In [26]:
# Exploration setup: a fresh model plus the bookkeeping variables of a training run
learning_rate = 0.01
model = SKAModel()

# Running totals for the multi-step SKA run
step_count = 0
total_entropy = 0
start_time = time.time()

# initialize_tensors is deliberately left out here: a newly constructed model
# already has every tracking slot set to None.
# model.initialize_tensors(inputs.size(0))

In [27]:
inputs.shape[0]  # number of samples in the batch (first dimension of `inputs`)

1000

In [28]:
# Flatten every image into one feature row, the same reshape the model applies at the start of forward()
inputs.view(inputs.shape[0], -1).shape

torch.Size([1000, 784])

In [29]:
# Final-layer sigmoid activations of the untrained model, one row per input.
# NOTE: in the recorded output the displayed rows agree to 4 decimals.
model.forward(inputs)

tensor([[0.4875, 0.5045, 0.5058,  ..., 0.4770, 0.4973, 0.5064],
        [0.4875, 0.5045, 0.5058,  ..., 0.4770, 0.4973, 0.5064],
        [0.4875, 0.5045, 0.5058,  ..., 0.4770, 0.4973, 0.5064],
        ...,
        [0.4875, 0.5045, 0.5058,  ..., 0.4770, 0.4973, 0.5064],
        [0.4875, 0.5045, 0.5058,  ..., 0.4770, 0.4973, 0.5064],
        [0.4875, 0.5045, 0.5058,  ..., 0.4770, 0.4973, 0.5064]],
       grad_fn=<SigmoidBackward0>)

In [30]:
# Keep the final-layer activations so the following cells can inspect them
output = model.forward(inputs)

In [31]:
# One row per input sample, one column per unit of the last layer
output.shape

torch.Size([1000, 10])

In [32]:
# Batch mean of every output unit: the quantity the training loop appends to model.output_history
output.mean(dim=0)

tensor([0.4875, 0.5045, 0.5058, 0.5021, 0.4920, 0.4965, 0.4966, 0.4770, 0.4973,
        0.5064], grad_fn=<MeanBackward1>)

In [33]:
# Averaging over dim 0 removes the batch dimension, leaving one value per output unit
output.mean(dim=0).shape

torch.Size([10])

In [34]:
# Final-layer activations of the first sample
output[0]

tensor([0.4875, 0.5045, 0.5058, 0.5021, 0.4920, 0.4965, 0.4966, 0.4770, 0.4973,
        0.5064], grad_fn=<SelectBackward0>)

In [35]:
# Returns the integer 0 here: D_prev is still None for every layer, so no layer is evaluated
model.calculate_entropy()

0

In [36]:
# No-op at this point: delta_D is still None for every layer, so ska_update skips them all
model.ska_update(inputs, learning_rate)

In [37]:
# Still 0: nothing has assigned D_prev since the model was created
batch_entropy = model.calculate_entropy()

In [38]:
# Show the entropy with four decimals, the same format the training loop prints
print(f'{batch_entropy:.4f}')

0.0000


In [39]:
# Still a no-op: delta_D has not been computed for any layer yet
model.ska_update(inputs, learning_rate)

In [40]:
# Snapshot the current decisions as detached copies; from here on calculate_entropy
# has a previous step to compare the next forward pass against.
model.D_prev = [d.clone().detach() if d is not None else None for d in model.D]

In [41]:
# Training parameters
model = SKAModel()
learning_rate = 0.01

# SKA training over multiple forward steps
total_entropy = 0   # Sum of the per-step entropies returned by calculate_entropy
step_count = 0      # Number of steps for which an entropy/update was performed
start_time = time.time()

# Initialize tensors for first step
model.initialize_tensors(inputs.size(0))

# Process K forward steps (without backpropagation).
# SKA never calls backward(), so autograd is switched off for the whole loop.
# Otherwise every step builds a graph, and `total_entropy += batch_entropy`
# keeps all of them (and their saved activations) alive until the run ends.
with torch.no_grad():
    for k in range(model.K):
        outputs = model(inputs)
        # Store mean output distribution for the final layer
        model.output_history.append(outputs.mean(dim=0).detach().cpu().numpy())  # [10] vector
        if k > 0:  # A decision shift needs a previous step, so skip k == 0
            batch_entropy = model.calculate_entropy()
            model.ska_update(inputs, learning_rate)
            total_entropy += batch_entropy
            step_count += 1
            # Step 1 reports 0: no update has happened yet, so the decisions did not move.
            print(f'Step: {k}, Total Steps: {step_count}, Entropy: {batch_entropy:.4f}')
            # model.visualize_entropy_heatmap(step_count)
            # model.visualize_cosine_heatmap(step_count)  # Add cosine heatmap
        # Update previous decision tensors (detached copies of this step's decisions)
        model.D_prev = [d.clone().detach() if d is not None else None for d in model.D]


Step: 1, Total Steps: 1, Entropy: 0.0000
Step: 2, Total Steps: 2, Entropy: -142.8350
Step: 3, Total Steps: 3, Entropy: -239.0253
Step: 4, Total Steps: 4, Entropy: -365.7324
Step: 5, Total Steps: 5, Entropy: -554.8585
Step: 6, Total Steps: 6, Entropy: -838.3662
Step: 7, Total Steps: 7, Entropy: -1254.4277
Step: 8, Total Steps: 8, Entropy: -1844.3822
Step: 9, Total Steps: 9, Entropy: -2643.5168
Step: 10, Total Steps: 10, Entropy: -3666.1729
Step: 11, Total Steps: 11, Entropy: -4886.8608
Step: 12, Total Steps: 12, Entropy: -6225.5479
Step: 13, Total Steps: 13, Entropy: -7555.4561
Step: 14, Total Steps: 14, Entropy: -8744.6416
Step: 15, Total Steps: 15, Entropy: -9709.4658
Step: 16, Total Steps: 16, Entropy: -10440.2334
Step: 17, Total Steps: 17, Entropy: -10984.6787
Step: 18, Total Steps: 18, Entropy: -11410.0859
Step: 19, Total Steps: 19, Entropy: -11773.5566
Step: 20, Total Steps: 20, Entropy: -12112.8535
Step: 21, Total Steps: 21, Entropy: -12449.5439
Step: 22, Total Steps: 22, Entropy