In [18]:
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import time
import math
import numpy as np
import random
import torch
from torch.utils.data import Dataset, DataLoader

In [47]:
class DendriticGatedLayer(nn.Module):
    def __init__(self, n_units, n_prev_layer_units, n_network_inputs):
        super().__init__()
        self.weights = nn.Parameter(torch.zeros(n_units, n_prev_layer_units), requires_grad=True)
        #self.halfplane_gates = torch.normal(torch.Tensor([[0]*n_units]*n_network_inputs)).T
        self.halfplane_gates = torch.randn_like(torch.zeros(n_units,n_prev_layer_units,n_network_inputs))
        
        self.n_units = n_units
        self.n_prev_layer_units = n_prev_layer_units
        self.n_network_inputs = n_network_inputs
        
        
    def forward(self, X, network_inputs, gate_threshold=0):
        self.gate_states = torch.matmul(self.halfplane_gates, network_inputs)
        self.gate_states = self.gate_states > gate_threshold
        '''
            self.gate_states is now a (n_units x n_prev_layer_units) array.
            Each row gives the gate state of the afferent synapses of a single unit in the current layer
        '''
        
        
        
class DendriticGatedNet(nn.Module):
    def __init__(self, n_inputs, n_hiddens, n_out):
        super().__init__()
        self.n_inputs = n_inputs
        self.n_hiddens = n_hiddens
        self.n_out = n_out
        
        self.layer1 = DendriticGatedLayer(
            n_units=n_hiddens,
            n_prev_layer_units=n_inputs,
            n_network_inputs=n_inputs
        )
        
        self.layer2 = DendriticGatedLayer(
            n_units=n_hiddens,
            n_prev_layer_units=n_hiddens,
            n_network_inputs=n_inputs
        )
        
        self.outlayer = nn.Linear(n_hiddens, n_out)
    
    def forward(self, X):
        '''
            TODO
        '''
        None
    

In [23]:
d = DendriticGatedLayer(10,5,2)

In [26]:
d.halfplane_gates.shape

torch.Size([10, 5, 2])

In [37]:
nints = torch.Tensor([1, 2])

In [46]:
torch.matmul(d.halfplane_gates, nints) 

tensor([[-2.3464, -0.6449,  1.3075, -2.6411,  1.0632],
        [ 0.4197, -3.1043, -3.6943,  0.1570, -2.1862],
        [-1.4945, -0.0937, -3.6442, -1.5359,  0.6434],
        [-3.3923, -0.8410, -0.9234,  0.3539,  1.9439],
        [ 0.2346,  2.2789, -0.0342, -1.6482,  2.5359],
        [-0.6972, -2.7491, -2.1556, -3.6523, -1.6041],
        [-0.8810,  1.1065, -0.5562,  2.8299,  0.1198],
        [-1.5890,  2.3093,  1.5424,  3.0762,  2.5471],
        [ 1.4515,  0.7046,  0.8563, -0.6275, -2.5453],
        [ 2.7130,  1.5057, -2.4481, -1.1823, -0.2223]])

In [39]:
d.halfplane_gates

tensor([[[-0.8190, -0.7637],
         [-0.4799, -0.0825],
         [ 0.0427,  0.6324],
         [-1.2150, -0.7130],
         [ 1.2596, -0.0982]],

        [[-0.4832,  0.4515],
         [-0.5670, -1.2686],
         [-0.5599, -1.5672],
         [-1.4474,  0.8022],
         [-0.0559, -1.0652]],

        [[ 0.6129, -1.0537],
         [ 0.6510, -0.3724],
         [ 0.2602, -1.9522],
         [-0.3411, -0.5974],
         [-1.6851,  1.1642]],

        [[-0.6132, -1.3896],
         [-1.3078,  0.2334],
         [-0.1705, -0.3765],
         [ 0.2383,  0.0578],
         [ 1.2325,  0.3557]],

        [[-0.3953,  0.3149],
         [ 0.3031,  0.9879],
         [-0.2411,  0.1035],
         [ 1.1008, -1.3745],
         [ 0.0364,  1.2497]],

        [[-1.9182,  0.6105],
         [-1.6942, -0.5275],
         [-1.2167, -0.4694],
         [-0.1722, -1.7400],
         [-0.4286, -0.5878]],

        [[-0.8983,  0.0086],
         [-0.7074,  0.9069],
         [-0.1669, -0.1947],
         [ 1.0059,  0.9120],
  