In [18]:
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import time
import math
import numpy as np
import random
import torch
from torch.utils.data import Dataset, DataLoader

In [69]:
class DendriticGatedLayer(nn.Module):
    """A layer whose individual synapses are switched on/off by fixed random half-plane gates.

    Every synapse (unit ``i`` of this layer <- unit ``j`` of the previous layer) owns a random
    gating vector ``halfplane_gates[i, j]`` of length ``n_network_inputs``. The synapse is open
    when the projection of the *network* input onto that vector exceeds ``gate_threshold``;
    closed synapses contribute nothing to the unit's output.

    Args:
        n_units: number of units in this layer.
        n_prev_layer_units: number of units feeding into this layer.
        n_network_inputs: dimensionality of the network input used for gating.
    """

    def __init__(self, n_units, n_prev_layer_units, n_network_inputs):
        super().__init__()
        # Learnable synaptic weights drawn uniformly from [0, 1); shape (n_units, n_prev_layer_units).
        self.weights = nn.Parameter(torch.rand(n_units, n_prev_layer_units))
        # Fixed (non-learnable) gating hyperplanes; shape (n_units, n_prev_layer_units, n_network_inputs).
        # Registered as a buffer so it follows .to(device) / .double() and is saved in state_dict().
        self.register_buffer(
            "halfplane_gates",
            torch.randn(n_units, n_prev_layer_units, n_network_inputs),
        )
        # Boolean gate pattern of the most recent forward() call, kept for inspection.
        self.gate_states = None

        self.n_units = n_units
        self.n_prev_layer_units = n_prev_layer_units
        self.n_network_inputs = n_network_inputs

    def forward(self, X, network_inputs, gate_threshold=0):
        """Compute the gated linear response of every unit in the layer.

        Args:
            X: activity of the previous layer, shape (..., n_prev_layer_units).
            network_inputs: network input used for gating, shape (..., n_network_inputs).
                Leading (batch) dimensions of ``X`` and ``network_inputs`` broadcast together.
            gate_threshold: a synapse is open where its gate projection is strictly greater
                than this value.

        Returns:
            Tensor of shape (..., n_units) whose entry ``i`` is
            ``sum_j gate_states[i, j] * weights[i, j] * X[j]``.
        """
        # Project the network input onto every synapse's gate vector -> (..., n_units, n_prev_layer_units).
        projections = torch.einsum("upk,...k->...up", self.halfplane_gates, network_inputs)
        # Each row holds the gate states of the afferent synapses of one unit in this layer.
        self.gate_states = projections > gate_threshold

        # Closed synapses get zero weight.
        active_weights = self.gate_states * self.weights

        # Batched matrix-vector product: (..., u, p) @ (..., p, 1) -> (..., u).
        return torch.matmul(active_weights, X.unsqueeze(-1)).squeeze(-1)


class DendriticGatedNet(nn.Module):
    """Two dendritic-gated hidden layers followed by a standard linear read-out.

    Both hidden layers are gated by the raw network input ``X``.

    Args:
        n_inputs: dimensionality of the network input.
        n_hiddens: number of units in each hidden layer.
        n_out: number of outputs of the final ``nn.Linear`` read-out.
    """

    def __init__(self, n_inputs, n_hiddens, n_out):
        super().__init__()
        self.n_inputs = n_inputs
        self.n_hiddens = n_hiddens
        self.n_out = n_out

        self.layer1 = DendriticGatedLayer(
            n_units=n_hiddens,
            n_prev_layer_units=n_inputs,
            n_network_inputs=n_inputs
        )

        self.layer2 = DendriticGatedLayer(
            n_units=n_hiddens,
            n_prev_layer_units=n_hiddens,
            n_network_inputs=n_inputs
        )

        self.outlayer = nn.Linear(n_hiddens, n_out)

    def forward(self, X, gate_threshold=0):
        """Run the input through both gated layers and the linear read-out.

        Args:
            X: network input, shape (..., n_inputs). It is both the first layer's input and
                the gating signal for every hidden layer.
            gate_threshold: forwarded to each hidden layer's gating comparison.

        Returns:
            Tensor of shape (..., n_out).
        """
        # NOTE(review): no activation is applied between layers; the only nonlinearity is the
        # input-dependent gating. Insert one here if that is not intended.
        hidden1 = self.layer1(X, X, gate_threshold)
        hidden2 = self.layer2(hidden1, X, gate_threshold)
        return self.outlayer(hidden2)
    

In [70]:
# Seed torch's global RNG so the random weights and half-plane gates of `d` are reproducible.
torch.manual_seed(0)
d = DendriticGatedLayer(10, 5, 2)

In [71]:
# Expected: (n_units, n_prev_layer_units, n_network_inputs) = (10, 5, 2).
d.halfplane_gates.size()

torch.Size([10, 5, 2])

In [72]:
# Example network input driving the gates (length n_network_inputs = 2); float32 to match the gates.
nints = torch.tensor([1.0, 2.0])

In [73]:
# Learnable weights of `d`, shape (n_units, n_prev_layer_units) = (10, 5).
d.weights

Parameter containing:
tensor([[0.3554, 0.1049, 0.2705, 0.2355, 0.0898],
        [0.9788, 0.0422, 0.1106, 0.5441, 0.8272],
        [0.6372, 0.7918, 0.4002, 0.0792, 0.2660],
        [0.0500, 0.4185, 0.7117, 0.8096, 0.8407],
        [0.7752, 0.5017, 0.8510, 0.9329, 0.8001],
        [0.9708, 0.6695, 0.5735, 0.0756, 0.6888],
        [0.7754, 0.7473, 0.7093, 0.3587, 0.6916],
        [0.9657, 0.5634, 0.3656, 0.2018, 0.1697],
        [0.0021, 0.8924, 0.3840, 0.9474, 0.6527],
        [0.0797, 0.9404, 0.4061, 0.9397, 0.5388]], requires_grad=True)

In [74]:
# Manually reproduce the gating step of DendriticGatedLayer.forward for `nints` with threshold 0.
d.gate_states = (d.halfplane_gates @ nints).gt(0)

In [68]:
# Zero out the weights of closed synapses.
# NOTE(review): the saved output below has execution count 68, i.e. it predates the cell that built
# this `d` (its values do not match the `d.weights` output above) — re-run top to bottom.
d.gate_states * d.weights

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3371, 0.3376, 0.0000, 0.7723, 0.5539],
        [0.9596, 0.4867, 0.0000, 0.0000, 0.4136],
        [0.8265, 0.0000, 0.0000, 0.8945, 0.1709],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6275, 0.0000, 0.0000, 0.6731, 0.0000],
        [0.9331, 0.5181, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.1989, 0.3595, 0.0000, 0.1478],
        [0.0000, 0.6895, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4793, 0.0000, 0.9342, 0.3264]], grad_fn=<MulBackward0>)

In [39]:
# Fixed random gating hyperplanes of `d`, shape (n_units, n_prev_layer_units, n_network_inputs) = (10, 5, 2).
# NOTE(review): the saved output below has execution count 39, earlier than the cells that define the
# class and create `d`, so it is stale — re-run the notebook top to bottom.
d.halfplane_gates

tensor([[[-0.8190, -0.7637],
         [-0.4799, -0.0825],
         [ 0.0427,  0.6324],
         [-1.2150, -0.7130],
         [ 1.2596, -0.0982]],

        [[-0.4832,  0.4515],
         [-0.5670, -1.2686],
         [-0.5599, -1.5672],
         [-1.4474,  0.8022],
         [-0.0559, -1.0652]],

        [[ 0.6129, -1.0537],
         [ 0.6510, -0.3724],
         [ 0.2602, -1.9522],
         [-0.3411, -0.5974],
         [-1.6851,  1.1642]],

        [[-0.6132, -1.3896],
         [-1.3078,  0.2334],
         [-0.1705, -0.3765],
         [ 0.2383,  0.0578],
         [ 1.2325,  0.3557]],

        [[-0.3953,  0.3149],
         [ 0.3031,  0.9879],
         [-0.2411,  0.1035],
         [ 1.1008, -1.3745],
         [ 0.0364,  1.2497]],

        [[-1.9182,  0.6105],
         [-1.6942, -0.5275],
         [-1.2167, -0.4694],
         [-0.1722, -1.7400],
         [-0.4286, -0.5878]],

        [[-0.8983,  0.0086],
         [-0.7074,  0.9069],
         [-0.1669, -0.1947],
         [ 1.0059,  0.9120],
  