In [2]:
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import time
import math
import numpy as np
import random
import torch
from torch.utils.data import Dataset, DataLoader

In [280]:
class DendriticGatedLayer(nn.Module):
    """Linear layer whose individual synapses are gated by the network input.

    Every synapse (unit ``u``, presynaptic unit ``p``) owns a fixed random
    gating vector ``halfplane_gates[u, p]`` of length ``n_network_inputs``.
    For each sample, a synapse is open when the dot product of its gating
    vector with that sample's ``network_inputs`` row exceeds
    ``gate_threshold``; closed synapses contribute nothing to the output.

    Args:
        n_units: number of units in this layer (output width).
        n_prev_layer_units: width of the layer input ``X``.
        n_network_inputs: width of ``network_inputs`` used for gating.
    """

    def __init__(self, n_units, n_prev_layer_units, n_network_inputs):
        super().__init__()
        # Learnable synaptic weights, shape (n_units, n_prev_layer_units).
        self.weights = nn.Parameter(torch.rand(n_units, n_prev_layer_units))
        # Fixed (untrained) gating vectors, shape
        # (n_units, n_prev_layer_units, n_network_inputs). Registered as a
        # non-persistent buffer so it follows the module through
        # .to()/.cuda()/.double() without adding a key to state_dict().
        self.register_buffer(
            "halfplane_gates",
            torch.randn(n_units, n_prev_layer_units, n_network_inputs),
            persistent=False,
        )

        self.n_units = n_units
        self.n_prev_layer_units = n_prev_layer_units
        self.n_network_inputs = n_network_inputs

    def forward(self, X, network_inputs, gate_threshold=0):
        """Apply the per-sample gated linear map.

        Args:
            X: layer input, shape (batch, n_prev_layer_units).
            network_inputs: gating context, shape (batch, n_network_inputs).
            gate_threshold: a synapse is open when its gate projection is
                strictly greater than this value (default 0).

        Returns:
            Tensor of shape (batch, n_units) with
            ``out[b, u] = sum_p weights[u, p] * open[b, u, p] * X[b, p]``.
        """
        # (U, P, I) @ (I, B) -> (U, P, B); reorder to (B, U, P) and threshold.
        gating = self.halfplane_gates @ network_inputs.T
        gate_states = gating.permute(2, 0, 1) > gate_threshold
        # (U, P) weights broadcast against the (B, U, P) boolean gate mask.
        active_weights = self.weights * gate_states
        # Contract each sample with its *own* gated weight matrix.
        return torch.einsum("bup,bp->bu", active_weights, X)
        

class DendriticGatedNet(nn.Module):
    """Two dendritic-gated hidden layers, a linear readout, and an output ReLU.

    Both hidden layers are gated by the raw network input ``X``; the second
    layer's feedforward input is the first layer's output.

    Args:
        n_inputs: width of the network input.
        n_hiddens: width of both hidden layers.
        n_out: width of the network output.
    """

    def __init__(self, n_inputs, n_hiddens, n_out):
        super().__init__()
        self.n_inputs = n_inputs
        self.n_hiddens = n_hiddens
        self.n_out = n_out

        self.layer1 = DendriticGatedLayer(
            n_units=n_hiddens,
            n_prev_layer_units=n_inputs,
            n_network_inputs=n_inputs,
        )
        self.layer2 = DendriticGatedLayer(
            n_units=n_hiddens,
            n_prev_layer_units=n_hiddens,
            n_network_inputs=n_inputs,
        )
        self.outlayer = nn.Linear(n_hiddens, n_out)

    def forward(self, X):
        """Map a batch ``X`` of shape (batch, n_inputs) to (batch, n_out)."""
        # Invoke submodules via __call__ (not .forward) so module hooks run.
        layer1_activity = self.layer1(X, X)
        layer2_activity = self.layer2(layer1_activity, X)
        out_activity = self.outlayer(layer2_activity)
        # NOTE(review): ReLU clamps outputs to >= 0, while RandomData targets
        # are drawn from U(-1, 1); confirm this output nonlinearity is intended.
        return F.relu(out_activity)

In [297]:
class RandomData(Dataset):
    """Dataset of i.i.d. U(-1, 1) patterns that serve as their own targets.

    Inputs and targets are the same tensor, so ``dataset[i]`` yields
    ``(pattern, pattern)``.

    Args:
        pattern_length: length of every pattern vector.
        dataset_size: number of patterns to generate.
    """

    def __init__(self, pattern_length, dataset_size):
        super().__init__()
        # Sampled from NumPy's global RNG, then cast to torch's default dtype.
        samples = np.random.uniform(-1, 1, size=(dataset_size, pattern_length))
        patterns = torch.as_tensor(samples, dtype=torch.get_default_dtype())

        # data, x and y deliberately alias one tensor.
        self.data = self.x = self.y = patterns
        self.n_samples = patterns.shape[0]

    def __len__(self):
        """Number of stored patterns."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair stored at ``index``."""
        return self.x[index], self.y[index]

In [298]:
# 512 random patterns of length 32; each pattern is also its own target.
data = RandomData(pattern_length=32, dataset_size=512)

In [299]:
# Mini-batches of 32 patterns, reshuffled every epoch.
loader = DataLoader(dataset=data, batch_size=32, shuffle=True)

In [300]:
# Input/output widths are read from the dataset so they cannot drift from the
# pattern length; 10 hidden units per gated layer.
net = DendriticGatedNet(n_inputs=data.x.shape[1], n_hiddens=10, n_out=data.y.shape[1])

In [301]:
# Draw a single (inputs, targets) batch from the loader for the shape check below.
sample = next(iter(loader))

In [303]:
# Forward one batch of inputs; the result should be (batch_size, n_out).
net(sample[0]).size()

torch.Size([32, 32])