In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
import time

In [2]:
# Load the pre-saved MNIST subset (100 samples per class)
mnist_subset = torch.load("mnist_subset_100_per_class.pt")
images = torch.stack([item[0] for item in mnist_subset])  # Shape: [1000, 1, 28, 28]
labels = torch.tensor([item[1] for item in mnist_subset])

In [None]:
images.shape # 1000 image each of size 

torch.Size([1000, 1, 28, 28])

In [4]:
labels.shape

torch.Size([1000])

In [5]:
# Prepare the dataset (single batch for SKA forward learning)
inputs = images  # No mini-batches, full dataset used for forward-only updates

In [64]:
inputs.shape

torch.Size([1000, 1, 28, 28])

In [77]:

# Define the SKA model with 4 layers
class SKAModel(nn.Module):
    def __init__(self, input_size=784, layer_sizes=[256, 128, 64, 10], K=50):
        super(SKAModel, self).__init__()
        self.input_size = input_size
        self.layer_sizes = layer_sizes
        self.K = K  # Number of forward steps

        # Initialize weights and biases as nn.ParameterList
        self.weights = nn.ParameterList()
        self.biases = nn.ParameterList()
        prev_size = input_size
        for size in layer_sizes:
            self.weights.append(nn.Parameter(torch.randn(prev_size, size) * 0.01))
            self.biases.append(nn.Parameter(torch.zeros(size)))
            prev_size = size

        # Tracking tensors for knowledge accumulation and entropy computation
        self.Z = [None] * len(layer_sizes)  # Knowledge tensors per layer
        self.D = [None] * len(layer_sizes)  # Decision probability tensors
        self.D_prev = [None] * len(layer_sizes)  # Previous decisions for computing shifts
        self.delta_D = [None] * len(layer_sizes)  # Decision shifts per step
        self.entropy = [None] * len(layer_sizes)  # Layer-wise entropy storage

        # Store entropy, cosine, and output distribution history for visualization
        self.entropy_history = [[] for _ in range(len(layer_sizes))]
        self.cosine_history = [[] for _ in range(len(layer_sizes))]
        self.output_history = []  # New: Store mean output distribution (10 classes) per step

    
    def forward(self, x):
        """Computes SKA forward pass, storing knowledge and decisions."""
        batch_size = x.shape[0]
        x = x.view(batch_size, -1)  # Flatten images

        for l in range(len(self.layer_sizes)):
            # Compute knowledge tensor Z = Wx + b
            print(x.shape)
            print(self.weights[l].shape)
            break
            z = torch.mm(x, self.weights[l]) + self.biases[l]
            # Apply sigmoid activation to get decision probabilities
            d = torch.sigmoid(z)
            # Store values for entropy computation
            self.Z[l] = z
            self.D[l] = d
            x = d  # Output becomes input for the next layer
            

        # return x


In [78]:

def forward(self, x):
    """Computes SKA forward pass, storing knowledge and decisions."""
    batch_size = x.shape[0]
    x = x.view(batch_size, -1)  # Flatten images

    for l in range(len(self.layer_sizes)):
        # Compute knowledge tensor Z = Wx + b
        z = torch.mm(x, self.weights[l]) + self.biases[l]
        # Apply sigmoid activation to get decision probabilities
        d = torch.sigmoid(z)
        # Store values for entropy computation
        self.Z[l] = z
        self.D[l] = d
        x = d  # Output becomes input for the next layer

    return x

### training 

In [79]:
# Training parameters
model = SKAModel()
learning_rate = 0.01

# SKA training over multiple forward steps
total_entropy = 0
step_count = 0
start_time = time.time()

# Initialize tensors for first step
# model.initialize_tensors(inputs.size(0))

In [80]:
inputs.shape[0] # we get the number of inputsinputs.shape[0]

1000

In [81]:
inputs.view(inputs.shape[0], -1).shape

torch.Size([1000, 784])

In [82]:
model.forward(inputs)

torch.Size([1000, 784])
torch.Size([784, 256])
