In [2]:
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import time
import math
import numpy as np
import random
import torch
from torch.utils.data import Dataset, DataLoader

In [232]:
class DendriticGatedLayer(nn.Module):
    """Linear layer whose individual synapses are switched on/off by half-plane gates.

    Each weight ``weights[u, p]`` (from unit ``p`` of the previous layer to unit
    ``u`` of this layer) has its own random gating vector
    ``halfplane_gates[u, p, :]``. For every sample, that synapse is active only
    when the projection of the sample's *network* input onto its gating vector
    is greater than ``gate_threshold``.

    Args:
        n_units: number of units in this layer.
        n_prev_layer_units: number of units feeding into this layer.
        n_network_inputs: dimensionality of the raw network input used for gating.
    """

    def __init__(self, n_units, n_prev_layer_units, n_network_inputs):
        super().__init__()
        # Trainable synaptic weights, initialised uniformly in [0, 1).
        self.weights = nn.Parameter(torch.rand(n_units, n_prev_layer_units), requires_grad=True)
        # Fixed (not an nn.Parameter, so never trained) random gating vectors,
        # one per synapse. Registered as a buffer so they follow `.to(device)`
        # and are stored in `state_dict()`; a plain attribute would be lost on
        # save/load and the reloaded model would gate differently.
        self.register_buffer(
            "halfplane_gates",
            torch.randn(n_units, n_prev_layer_units, n_network_inputs),
        )

        self.n_units = n_units
        self.n_prev_layer_units = n_prev_layer_units
        self.n_network_inputs = n_network_inputs

    def forward(self, X, network_inputs, gate_threshold=0):
        """Gated linear map of ``X``, with gates set per sample by ``network_inputs``.

        Args:
            X: (batch, n_prev_layer_units) activity of the previous layer.
            network_inputs: (batch, n_network_inputs) raw network input that
                drives the gates.
            gate_threshold: a synapse is open when its gate projection is
                strictly greater than this value.

        Returns:
            (batch, n_units) tensor. Row ``b`` uses the gate states computed
            from ``network_inputs[b]``.
        """
        # (batch, n_units, n_prev_layer_units): projection of each sample's
        # network input onto every synapse's gating vector.
        gating = torch.einsum("upn,bn->bup", self.halfplane_gates, network_inputs)
        gate_states = gating > gate_threshold
        # Broadcast the shared weights over the batch and zero the closed synapses.
        active_weights = self.weights * gate_states
        # Per-sample matrix-vector product. The previous `(W @ X.T)[0].T`
        # applied sample 0's gates to every sample.
        # out[b, u] = sum_p active_weights[b, u, p] * X[b, p]
        return torch.einsum("bup,bp->bu", active_weights, X)
        

class DendriticGatedNet(nn.Module):
    """Two dendritic-gated hidden layers followed by a linear read-out and ReLU.

    Both hidden layers are gated by the raw network input ``X``. The first layer
    also takes ``X`` as its input activity.

    Args:
        n_inputs: dimensionality of the network input.
        n_hiddens: number of units in each hidden layer.
        n_out: number of output units.
    """

    def __init__(self, n_inputs, n_hiddens, n_out):
        super().__init__()
        self.n_inputs = n_inputs
        self.n_hiddens = n_hiddens
        self.n_out = n_out

        self.layer1 = DendriticGatedLayer(
            n_units=n_hiddens,
            n_prev_layer_units=n_inputs,
            n_network_inputs=n_inputs
        )

        self.layer2 = DendriticGatedLayer(
            n_units=n_hiddens,
            n_prev_layer_units=n_hiddens,
            n_network_inputs=n_inputs
        )

        self.outlayer = nn.Linear(n_hiddens, n_out)

    def forward(self, X):
        """Map the network input to non-negative outputs.

        Args:
            X: network input, expected shape (batch, n_inputs).

        Returns:
            (batch, n_out) tensor after a final ReLU.
        """
        # Invoke submodules via __call__ (not `.forward`) so nn.Module hooks run.
        layer1_activity = self.layer1(X, X)
        layer2_activity = self.layer2(layer1_activity, X)
        out_activity = self.outlayer(layer2_activity)

        return F.relu(out_activity)

In [233]:
# Seed torch so the random weights/gates (and the displayed output) are reproducible.
torch.manual_seed(0)
net = DendriticGatedNet(n_inputs=2, n_hiddens=5, n_out=10)

In [234]:
# A single probe input of shape (batch=1, n_inputs=2).
network_inputs = torch.tensor([[1.0, 1.0]])

In [235]:
# Forward pass on the probe input; the result is shown as the cell's output.
net(X=network_inputs)

tensor([[0.0727, 0.0000, 0.7340, 0.0000, 0.0000, 0.0000, 0.0000, 0.3953, 0.0000,
         0.5551]], grad_fn=<ReluBackward0>)

In [193]:
# Stand-alone layer: 5 units fed by 2 previous-layer units, gated by a 2-D network input.
d = DendriticGatedLayer(n_units=5, n_prev_layer_units=2, n_network_inputs=2)