In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [88]:
INPUT_DIM = 10
NUM_LAYERS = 1
BATCH_SIZE = 5
NUM_EPOCHS = 1
LEARNING_RATE = 0.05
ITERATIONS = 50


In [113]:
class InputPairsDataset(Dataset):
    def __init__(self, num_samples, input_dim):
        # Generate pairs of indices, ensuring they match in your desired way
        # For simplicity, using identity matrix pairs here as placeholders
        self.inputs = [torch.eye(input_dim)[i].reshape(1, -1).squeeze(0) for i in range(input_dim)]

        self.input_dim = input_dim
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.inputs[idx % self.input_dim], self.inputs[idx % self.input_dim], self.inputs[(idx + 1) % self.input_dim]


class LayerLocalNetwork(nn.Module):
    def __init__(self, bottom_dim, top_dim, num_layers=1, batch_size=1):
        super().__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        self.optimizers = []

        for _ in range(num_layers):
            layer = {
                'bottom_up': nn.Parameter(torch.randn(bottom_dim, top_dim)),
                'top_down': nn.Parameter(torch.randn(top_dim, top_dim)),
                'recurrent': nn.Parameter(torch.randn(top_dim, top_dim))
            }
            self.layers.append(nn.ParameterDict(layer))
            self.optimizers.append({
                'bottom_up': optim.Adam([layer['bottom_up']], lr=LEARNING_RATE),
                'top_down': optim.Adam([layer['top_down']], lr=LEARNING_RATE),
                'recurrent': optim.Adam([layer['recurrent']], lr=LEARNING_RATE),
            })

        self.activations = [torch.zeros(batch_size, top_dim) for _ in range(num_layers)]

    def forward(self, bottom_input, top_input):
        for i, layer in enumerate(self.layers):
            self.optimizers[i]['bottom_up'].zero_grad()
            self.optimizers[i]['top_down'].zero_grad()
            self.optimizers[i]['recurrent'].zero_grad()

        for i, layer in enumerate(self.layers):
            bottom_up_act = torch.mm(bottom_input.detach(), layer['bottom_up']) if i == 0 else torch.mm(
                self.activations[i-1], layer['bottom_up'])
            top_down_act = torch.mm(top_input.detach(), layer['top_down']) if i == self.num_layers - 1 else torch.mm(
                self.activations[i+1], layer['top_down'])
            recurrent_act = torch.mm(self.activations[i], layer['recurrent'])

            total_input = bottom_up_act + top_down_act + recurrent_act
            self.activations[i] = F.leaky_relu(total_input)

        loss = self.compute_energy()
        loss.backward()

        for i, layer in enumerate(self.layers):
            self.optimizers[i]['bottom_up'].step()
            self.optimizers[i]['top_down'].step()
            self.optimizers[i]['recurrent'].step()
        
        for i in range(0, len(self.activations)):
            self.activations[i] = self.activations[i].detach()

        return loss

    def compute_energy(self):
        # Push energy down proportional to activations
        running_sum = 0
        for act in self.activations:
            # pow the activations, average across neurons in layer, average across batches
            running_sum += torch.mean(torch.mean(act.pow(2), dim=1), dim=0)
        standard_loss = running_sum / len(self.activations)

        # Hebbian loss computation: encourage variance at neuron level
        hebbian_loss = 0
        for act in self.activations:
            hebbian_loss += self.generate_lpl_loss_hebbian(act)

        # TODO: predictive and decorrelative losses

        # Combine losses
        total_loss = standard_loss + hebbian_loss  # Consider weighting factors if necessary
        # total_loss = standard_loss
        return total_loss

    def generate_lpl_loss_hebbian(self, activations):
        mean_act = torch.mean(activations, dim=0)
        mean_subtracted = activations - mean_act
        sigma_squared = torch.sum(mean_subtracted ** 2, dim=0) / (activations.shape[0] - 1)
        loss = -torch.log(sigma_squared + 1e-10).sum() / sigma_squared.shape[0]
        return loss

dataset = InputPairsDataset(num_samples=100, input_dim=INPUT_DIM)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Example usage:
model = LayerLocalNetwork(bottom_dim=INPUT_DIM, top_dim=INPUT_DIM, num_layers=NUM_LAYERS, batch_size=BATCH_SIZE)

In [4]:
class ActivationDecoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)


In [5]:
# for epoch in range(NUM_EPOCHS):
#     print("Epoch:", epoch)
#     for bottom_input, top_input, next_input in dataloader:
#         for i in range(ITERATIONS):
#             energy = model(bottom_input, top_input)
#             # layer_activations = torch.stack([layer_activations.clone() for layer_activations in model.activations], dim=1).reshape(-1, INPUT_DIM)
#             print("Energy:", f"{energy.item(): .2f}")
#         print("-----")
#     print()

# print("======TRYING NEGATIVE SAMPLES======")


In [114]:
from collections import deque


for epoch in range(NUM_EPOCHS):
    print("Epoch:", epoch)
    for bottom_input, top_input, _ in dataloader:
        running_sum = 0
        layer_activations_queue = deque(maxlen=10)
        for i in range(ITERATIONS):
            loss = model(bottom_input, top_input)
            running_sum += loss.item()
            # print("Loss:", f"{loss.item(): .2f}")

            # layer_activations = torch.stack([layer_activations.clone() for layer_activations in model.activations], dim=1).reshape(-1, INPUT_DIM)
            # layer_activations_queue.append(layer_activations)

            # input_to_decoder = torch.stack(list(layer_activations_queue), dim=1)
            # print("layer activations shape: ", layer_activations.shape)
            # print("shape: ", input_to_decoder.shape)

        print("Average Loss:", f"{running_sum / ITERATIONS: .3f}")
        # print("----")
    print()

print("======TRYING NEGATIVE SAMPLES======")


bottom_input = torch.eye(10)[0].reshape(1, -1)  # One-hot vector for bottom input
top_input = torch.eye(10)[1].reshape(1, -1)    # One-hot vector for top input

running_sum = 0
for i in range(75):
    loss = model(bottom_input, top_input)
    running_sum += loss.item()
    # print("Loss:", f"{loss.item(): .2f}")
print("Average Loss:", f"{running_sum / 75: .3f}")


print("======TRYING POSITIVE SAMPLE AGAIN======")

for epoch in range(NUM_EPOCHS):
    print("Epoch:", epoch)
    for bottom_input, top_input, _ in dataloader:
        running_sum = 0
        for i in range(ITERATIONS):
            loss = model(bottom_input, top_input)
            running_sum += loss.item()
            # print("Loss:", f"{loss.item(): .2f}")
        print("Average Loss:", f"{running_sum / ITERATIONS: .3f}")
        # print("----")
    print()


Epoch: 0
Average Loss:  10497.498
Average Loss:  84.995
Average Loss:  2.853
Average Loss:  2.407
Average Loss:  2.647
Average Loss:  2.337
Average Loss:  2.771
Average Loss:  2.855
Average Loss:  2.562
Average Loss:  1.706
Average Loss:  2.602
Average Loss:  2.432
Average Loss:  1.860
Average Loss:  2.386
Average Loss:  2.354
Average Loss:  2.220
Average Loss:  2.328
Average Loss:  2.185
Average Loss:  1.965
Average Loss:  1.647

Average Loss:  6.191
Epoch: 0
Average Loss:  3.038
Average Loss:  1.964
Average Loss:  1.805
Average Loss:  1.802
Average Loss:  1.947
Average Loss:  2.461
Average Loss:  2.146
Average Loss:  1.330
Average Loss:  3.428
Average Loss:  3.186
Average Loss:  2.064
Average Loss:  2.277
Average Loss:  1.639
Average Loss:  2.716
Average Loss:  1.819
Average Loss:  1.945
Average Loss:  3.039
Average Loss:  2.388
Average Loss:  6.356
Average Loss:  2.144

