In [1]:
import torch
import torch.nn as nn
import math

# Each gate is an instance of the `Gate` class defined below

In [2]:
class Gate:
    """
    Implements gates for transient rewiring of a network. This is inspired by the paper:

    Nikolić, D. (2023). Where is the mind within the brain? Transient selection of subnetworks
    by metabotropic receptors and G protein-gated ion channels.
    Computational Biology and Chemistry, 103, 107820.

    Note: currently works only for 1D layers of neurons.

    x, y: determine which connection is being gated by this gate, where x and y are the indexes of
        the neurons in the input and output layer for that weight, respectively.
        x can have the value "bias" in case the gate gates the bias of output neuron y; this is
        stored internally as n_x = -1.

    trigger_layer: in which layer the trigger for this gate is located, input or output. The values
        are "down" for input and "up" for output. Any value other than "down" is treated as "up".

    trigger_neuron_id: the neuron index that will provide inputs for triggering this gate.

    g_activated: the gating value g that will be set once the gate is activated.

    activation_threshold: the trigger input must be strictly greater than this value to activate
        the gate. The default value is 0.1.

    g_default: the value of the gating parameter in its inactive state. The default value is 1.0.

    duration: the number of sniff() calls (counting the activating call) during which the gate
        stays in its activated state before returning to its default state. The default is 5.
        With duration 0 an activation is undone within the same sniff() call.

    Usage:
    gate = Gate(x, y, "down", trigger_neuron_id, g_activated)


    We have the following methods:

    get_trigger_neuron(): returns the trigger layer ("up" or "down") and the trigger neuron index.

    sniff(input): checks the input and decides whether to activate the gate; a new activation is
        only possible once the countdown of the previous one has reached zero. Every call advances
        the countdown by one step.

    de_activate(): immediately returns the gate to its default state.
    """

    def __init__(self, x, y, trigger_layer, trigger_neuron_id, g_activated,
                 activation_threshold=0.1, g_default=1.0, duration=5):
        # A bias gate is represented by the sentinel input index -1.
        self.n_x = -1 if x == "bias" else x
        self.n_y = y
        # Normalize the trigger layer: only "down" is recognized, everything else means "up".
        self.layer = "down" if trigger_layer == "down" else "up"
        self.g_activated = g_activated
        self.neuron_id = trigger_neuron_id
        self.activation_threshold = activation_threshold
        self.g_default = g_default
        self.duration = duration
        self.state = self.g_default  # current gating value, read by the layer that owns the gate
        self.counter = 0             # remaining steps of the current activation
        self.activated = False

    def get_trigger_neuron(self):
        """Return (layer, neuron_id) of the neuron that triggers this gate."""
        return self.layer, self.neuron_id

    def sniff(self, input):
        """
        Inspect the trigger neuron's output and activate the gate if it exceeds the threshold.

        An already running activation is not restarted. In either case the countdown is
        advanced by one step afterwards.
        """
        if self.counter == 0 and input > self.activation_threshold:
            self.state = self.g_activated
            self.counter = self.duration
            self.activated = True
        self.counter_()

    def counter_(self):
        """Advance the countdown; once it is already at zero, return to the default state."""
        if self.counter > 0:
            self.counter -= 1
        else:
            self.state = self.g_default
            self.activated = False

    def de_activate(self):
        """Immediately reset the gate to its default (inactive) state."""
        self.activated = False
        self.counter = 0
        self.state = self.g_default
    

# Here we implement a linear layer, modeled on PyTorch's nn.Linear, whose weights and bias can be gated

In [3]:
class GatedLinear(nn.Module):
    """
    A linear layer (modeled on nn.Linear) whose weights and bias can be transiently scaled by gates.

    The effective weight of the connection from input neuron i to output neuron j is
    weight[j, i] * g_weight[j, i]; the effective bias of output neuron j is bias[j] * g_bias[j].
    The gating factors start at 1.0 and are rebuilt from the gate states by think().

    in_features, out_features: sizes of the input and output layer.

    set_gates: optional zero-argument callable returning the list of gates that operate within
        this layer. When None, the layer has no gates and behaves like a plain linear layer.

    bias: whether the layer has a learnable bias (default True).

    forward(input, use_gates=True): the linear map; the gating factors are applied when the layer
        has gates and use_gates is True.

    think(output_down, output_up): lets every gate sniff its trigger neuron and rebuilds the
        gating factors from the resulting gate states.

    de_activate_all_gates(): returns every gate to its default state.
    """
    def __init__(self, in_features, out_features, set_gates=None, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Allocated uninitialized here; reset_parameters() fills them below.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)

        self.gating = set_gates is not None
        if self.gating:
            # Gating factors are registered on the module but excluded from autograd.
            self.g_weight = nn.Parameter(torch.ones(out_features, in_features), requires_grad=False)
            if bias:
                self.g_bias = nn.Parameter(torch.ones(out_features), requires_grad=False)
            else:
                self.register_parameter('g_bias', None)
        else:
            self.register_parameter('g_weight', None)
            self.register_parameter('g_bias', None)

        self.reset_parameters()
        # A layer without gates gets an empty list, so think() and de_activate_all_gates()
        # are safe no-ops instead of failing on a missing attribute.
        self.gates = set_gates() if self.gating else []

    def reset_parameters(self):
        """Kaiming-uniform weights (a=sqrt(5)); bias uniform in [-1/sqrt(fan_in), 1/sqrt(fan_in)]."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input, use_gates=True):
        """Compute input @ W^T + b, with W and b scaled by the gating factors when gating applies."""
        apply_gates = self.gating and use_gates is True
        weight = self.weight * self.g_weight if apply_gates else self.weight
        linear_output = torch.matmul(input, weight.t())
        if self.bias is not None:
            bias = self.g_bias * self.bias if apply_gates else self.bias
            linear_output = linear_output + bias
        return linear_output

    def think(self, output_down, output_up):
        """
        Update the gates of this layer and rebuild the gating factors.

        output_down: activations of the input ("down") layer, indexable by neuron id.
        output_up: activations of the output ("up") layer, indexable by neuron id.

        Each gate sniffs its trigger neuron. The gating factors are reset to 1.0 and each gate's
        state is multiplied into the factor of the connection it gates, so several gates on the
        same connection combine multiplicatively. The new factors only affect later forward()
        calls.

        Returns the list of gates (empty for a layer without gates).

        Raises ValueError if a gate targets the bias of a layer created with bias=False.
        """
        if not self.gating:
            return self.gates
        # In-place edits of the gating factors must not be recorded by autograd.
        with torch.no_grad():
            self.g_weight.fill_(1.0)
            if self.g_bias is not None:
                self.g_bias.fill_(1.0)
            for g in self.gates:
                layer, n_id = g.get_trigger_neuron()
                g.sniff(output_down[n_id] if layer == "down" else output_up[n_id])
                if g.n_x >= 0:
                    # weight/g_weight have shape (out_features, in_features):
                    # row = output neuron (n_y), column = input neuron (n_x).
                    self.g_weight[g.n_y, g.n_x] *= g.state
                elif self.g_bias is None:
                    raise ValueError("a gate targets the bias, but this layer has no bias")
                else:
                    self.g_bias[g.n_y] *= g.state
        return self.gates

    def de_activate_all_gates(self):
        """Reset every gate of this layer to its default state."""
        for g in self.gates:
            g.de_activate()

# Define gates for each region with adjustable thresholds and gating factors

In [4]:

class AdaptiveGate(Gate):
    """A Gate whose activation threshold can be shifted at run time, e.g. by a learning rule."""

    def adjust_threshold(self, delta, lower=0, upper=1):
        """
        Add delta to activation_threshold, then clamp the result to [lower, upper].

        delta: amount added to the current threshold (may be negative).
        lower, upper: clamping bounds; the defaults keep the threshold within [0, 1].
        """
        self.activation_threshold = max(lower, min(self.activation_threshold + delta, upper))

def adaptive_region1_gates():
    """Gate factory for region 1: two input-triggered gates with g_activated 10 and 5."""
    # (x, y, trigger neuron, g_activated, activation threshold)
    specs = ((0, 1, 0, 10, 0.5),
             (1, 0, 1, 5, 0.3))
    return [AdaptiveGate(x, y, "down", trigger_neuron_id=nid, g_activated=g, activation_threshold=thr)
            for x, y, nid, g, thr in specs]

def adaptive_region2_gates():
    """Gate factory for region 2: two input-triggered gates with g_activated 0.1 and 2."""
    # (x, y, trigger neuron, g_activated, activation threshold)
    specs = ((0, 1, 0, 0.1, 0.4),
             (1, 0, 1, 2, 0.6))
    return [AdaptiveGate(x, y, "down", trigger_neuron_id=nid, g_activated=g, activation_threshold=thr)
            for x, y, nid, g, thr in specs]

# Create two GatedLinear layers with adaptive gates

In [5]:

# Two 2-input / 2-output gated layers ("regions"), each with its own gate factory.
# They are built in order (region 1 first), so the random initialization sequence is unchanged.
gl_region1, gl_region2 = (GatedLinear(2, 2, set_gates=factory)
                          for factory in (adaptive_region1_gates, adaptive_region2_gates))

# Initialize weights and biases

In [6]:

# Overwrite the random initial weights and biases with fixed values so the run is reproducible.
# Copying in place (instead of assigning new nn.Parameter objects) keeps the parameters that the
# layers registered at construction time, and copy_ raises if a value cannot be broadcast to the
# parameter's shape.
with torch.no_grad():
    gl_region1.weight.copy_(torch.tensor([[0.8, 0.2], [0.3, 0.7]]))
    gl_region1.bias.copy_(torch.tensor([0.1, -0.2]))
    gl_region2.weight.copy_(torch.tensor([[0.5, 0.5], [0.4, 0.6]]))
    gl_region2.bias.copy_(torch.tensor([0.2, -0.1]))

# Define input patterns and target outputs for learning

In [7]:

# Training patterns and their target outputs, paired by position.
inputs = [torch.tensor(pattern) for pattern in ([0.2, 0.2], [0.6, 0.2], [0.3, 0.4], [0.7, 0.6])]
targets = [torch.tensor(goal) for goal in ([0.3, 0.0], [0.6, 0.3], [0.4, 0.2], [0.8, 0.5])]

# Simple learning loop

In [8]:

learning_rate = 0.01
n_epochs = 35          # iterate multiple times to simulate learning
error_tolerance = 0.1  # total absolute error above which the gate thresholds are adjusted
toggle_region = True

# NOTE: think() is not called in this loop, so the gates never sniff and the gating factors stay
# unchanged during training; the threshold adjustments only matter once think() is used.
for epoch in range(n_epochs):
    print(f"Epoch {epoch + 1}")
    for i, inp in enumerate(inputs):
        # Alternate between the regions after every input. The toggle carries over between epochs,
        # so with 4 inputs region 1 always sees inputs 1 and 3, and region 2 inputs 2 and 4.
        current_layer = gl_region1 if toggle_region else gl_region2

        # Get the target output for comparison
        target = targets[i]

        # The update below is a hand-written delta rule, so no autograd graph is needed.
        with torch.no_grad():
            out = current_layer.forward(inp, use_gates=True)

            # Calculate error as a simple difference
            error = target - out
            print(f"Input {i + 1}, Target: {target.numpy()}, Output: {out.detach().numpy()}, Error: {error.detach().numpy()}")

            # Adjust the gate thresholds only when the total absolute error is significant.
            # Region 1 lowers its thresholds (its gates activate more easily); region 2 raises them.
            if error.abs().sum() > error_tolerance:
                adjustment = -learning_rate if toggle_region else learning_rate
                for gate in current_layer.gates:
                    gate.adjust_threshold(adjustment)

            # Delta-rule update: weight[j, k] += lr * error[j] * inp[k]; bias[j] += lr * error[j]
            current_layer.weight += learning_rate * error.unsqueeze(1) * inp.unsqueeze(0)
            current_layer.bias += learning_rate * error

        # Print gate thresholds after adjustment
        print(f"Gate Thresholds: {[gate.activation_threshold for gate in current_layer.gates]}")

        # Toggle region for the next input
        toggle_region = not toggle_region

    print("\n")

Epoch 1
Input 1, Target: [0.3 0. ], Output: [0.3 0. ], Error: [0. 0.]
Gate Thresholds: [0.5, 0.3]
Input 2, Target: [0.6 0.3], Output: [0.6        0.26000002], Error: [0.         0.03999999]
Gate Thresholds: [0.4, 0.6]
Input 3, Target: [0.4 0.2], Output: [0.42000002 0.17      ], Error: [-0.02000001  0.03      ]
Gate Thresholds: [0.5, 0.3]
Input 4, Target: [0.8 0.5], Output: [0.84999996 0.540616  ], Error: [-0.04999995 -0.04061598]
Gate Thresholds: [0.4, 0.6]


Epoch 2
Input 1, Target: [0.3 0. ], Output: [0.299772 0.000342], Error: [ 0.00022802 -0.000342  ]
Gate Thresholds: [0.5, 0.3]
Input 2, Target: [0.6 0.3], Output: [0.59923005 0.2599345 ], Error: [0.00076997 0.0400655 ]
Gate Thresholds: [0.4, 0.6]
Input 3, Target: [0.4 0.2], Output: [0.41975263 0.17037112], Error: [-0.01975262  0.02962889]
Gate Thresholds: [0.5, 0.3]
Input 4, Target: [0.8 0.5], Output: [0.8490869  0.54048157], Error: [-0.04908687 -0.04048157]
Gate Thresholds: [0.4, 0.6]


Epoch 3
Input 1, Target: [0.3 0. ], Output: 

# Test data: new (unseen) input patterns plus two patterns from the training set, with the expected reaction and output range for each

In [9]:
# Unseen input patterns used to probe how the trained network generalizes
new_inputs = [
    torch.tensor([0.5, 0.1]),  # Unseen pattern similar to training set but with minor differences
    torch.tensor([0.9, 0.8]),  # High values to test strong gating
    torch.tensor([0.2, 0.7]),  # Mixed values to activate different gates
    torch.tensor([0.1, 0.1])   # Low values to see if gates stay inactive
]

# Define the expected reactions for both new and training inputs
expected_reactions = [
    "similar to low activation, gates likely inactive",      # New input 1
    "strong response, high activation expected in gates",    # New input 2
    "moderate response, partial gate activation",            # New input 3
    "minimal response, likely no gate activation",           # New input 4
    "training data, should fall within trained range",       # Training input 5
    "training data, should fall within trained range"        # Training input 6
]

# Control variable for alternating regions during testing
toggle_region = True

# Define expected output ranges for each input, including two training inputs
# NOTE(review): the training target for [0.3, 0.4] is [0.4, 0.2], whose second value lies outside
# the "Output 2" range (0.3, 0.5) given for Input 6 -- confirm the intended range.
expected_ranges = [
    {"Output 1": (0.2, 0.4), "Output 2": (0.0, 0.1)},    # Input 1 (New)
    {"Output 1": (0.8, 1.2), "Output 2": (0.6, 1.0)},    # Input 2 (New)
    {"Output 1": (0.4, 0.7), "Output 2": (0.2, 0.4)},    # Input 3 (New)
    {"Output 1": (0.1, 0.3), "Output 2": (-0.1, 0.1)},   # Input 4 (New)
    {"Output 1": (0.6, 0.8), "Output 2": (0.2, 0.4)},    # Input 5 (Training)
    {"Output 1": (0.4, 0.6), "Output 2": (0.3, 0.5)}     # Input 6 (Training)
]

# Test the new inputs followed by training inputs 2 and 3; reusing new_inputs keeps the
# two lists from drifting apart.
test_inputs = new_inputs + [
    torch.tensor([0.6, 0.2]),  # Training input
    torch.tensor([0.3, 0.4])   # Training input
]

# Every test input needs exactly one expected reaction and one expected range
assert len(test_inputs) == len(expected_reactions) == len(expected_ranges)

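# Run the test inputs through the network, alternating between the two regions, and compare the outputs with the expected ranges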

In [10]:


# Explanation of the expected reaction for each entry of test_inputs, in the same order.
training_explanation = (
    "This input was part of the training set, so we expect a response that falls within the trained range."
)
expected_explanations = [
    (
        "Input has moderate values, likely triggering low activation or no gating effects. "
        "We expect an output that is influenced minimally by the gates, similar to when all gates are inactive."
    ),
    (
        "High values in the input are likely to activate multiple gates, as thresholds may be exceeded. "
        "This should result in a significant amplification of the output, especially for gates with strong gating factors."
    ),
    (
        "This mixed input pattern may lead to partial activation, where only some gates are triggered. "
        "The output should reflect a moderate change in values, as only certain connections are gated."
    ),
    (
        "Low input values should fall below most gate thresholds, leading to little or no activation. "
        "We expect the output to be close to the default (ungated) response, with minimal impact from gating."
    ),
    training_explanation,
    training_explanation,
]

print("=== Testing Network with New and Training Data ===")
for i, new_input in enumerate(test_inputs):
    # Select the region based on toggle
    current_layer = gl_region1 if toggle_region else gl_region2
    with torch.no_grad():
        output = current_layer.forward(new_input, use_gates=True)
        # think() runs after forward(): the gate changes it makes only affect later forward() calls
        # on this layer. The printed output therefore reflects the gating factors from before this
        # input, while the printed "Gate States" are the states after it.
        current_layer.think(new_input, output)  # Update gate states based on input
    observed_output1 = output[0].item()
    observed_output2 = output[1].item()

    # Expected ranges for the current input; deviations are measured from each range's midpoint
    expected_range_output1 = expected_ranges[i]["Output 1"]
    expected_range_output2 = expected_ranges[i]["Output 2"]
    low1, high1 = expected_range_output1
    low2, high2 = expected_range_output2
    deviation_output1 = abs(observed_output1 - ((low1 + high1) / 2))
    deviation_output2 = abs(observed_output2 - ((low2 + high2) / 2))

    # Check if within expected ranges
    in_range_output1 = low1 <= observed_output1 <= high1
    in_range_output2 = low2 <= observed_output2 <= high2

    # Indexing (rather than an if/elif chain) fails loudly if test_inputs outgrows the explanations
    expected_explanation = expected_explanations[i]

    # Display observed results along with detailed expectations
    print(f"New Input {i + 1}: {new_input.numpy()}")
    print(f"Observed Output: {output.detach().numpy()}")
    print(f"Gate States: {[gate.state for gate in current_layer.gates]}")
    print(f"Expected Range for Output 1: {expected_range_output1}, Observed Output 1: {observed_output1}, In Range: {in_range_output1}")
    print(f"Expected Range for Output 2: {expected_range_output2}, Observed Output 2: {observed_output2}, In Range: {in_range_output2}")
    print(f"Deviation for Output 1: {deviation_output1:.3f}, Deviation for Output 2: {deviation_output2:.3f}")
    print(f"Expected Reaction: {expected_reactions[i]}")
    print(f"Explanation: {expected_explanation}")
    print(f"Gate Thresholds: {[gate.activation_threshold for gate in current_layer.gates]}")
    print("-" * 50)

    # Toggle region for next input
    toggle_region = not toggle_region

=== Testing Network with New and Training Data ===
New Input 1: [0.5 0.1]
Observed Output: [0.51424766 0.02862865]
Gate States: [1.0, 1.0]
Expected Range for Output 1: (0.2, 0.4), Observed Output 1: 0.5142476558685303, In Range: False
Expected Range for Output 2: (0.0, 0.1), Observed Output 2: 0.028628647327423096, In Range: True
Deviation for Output 1: 0.214, Deviation for Output 2: 0.021
Expected Reaction: similar to low activation, gates likely inactive
Explanation: Input has moderate values, likely triggering low activation or no gating effects. We expect an output that is influenced minimally by the gates, similar to when all gates are inactive.
Gate Thresholds: [0.5, 0.3]
--------------------------------------------------
New Input 2: [0.9 0.8]
Observed Output: [1.0270493  0.73551595]
Gate States: [0.1, 2]
Expected Range for Output 1: (0.8, 1.2), Observed Output 1: 1.0270493030548096, In Range: True
Expected Range for Output 2: (0.6, 1.0), Observed Output 2: 0.7355159521102905, I