In [1]:
import torch
import torch.nn as nn
import math

# Each gate is a class

In [2]:

class Gate():
    """
    Implements gates for transient rewiring of a network. This is inspired by the paper:
    
    Nikolić, D. (2023). Where is the mind within the brain? Transient selection of subnetworks 
    by metabotropic receptors and G protein-gated ion channels. 
    Computational Biology and Chemistry, 103, 107820.

    Note: currently works only for 1D layers of neurons
    
    x, y: determine which connection is being gated by this gate, where x and y are indexes of neuron in
        input and output layer for that weight, respectively.
        x can have the value "bias" in case the gate gates the bias parameter.
        
    trigger_layer: in which layer the trigger for this gate is located, input or output. The values
            are "down" for input and "up" for output.
            
    trigger_neuron_id: the neuron index that will provide inputs for triggering this gate
    
    g_activated: the gating value g that will be set once the gate is activated. 
    
    activation_threshold: The minimal output value of the trigger neuron necessary to activate the gates. 
    The default value is 0.1. 
    
    g_default: The value of the gating parameter in its inactivatged state. The deafult value is 1.0.
    
    duration: deterimines the number of iteratations after which the gate will return back 
        to its default state. The default number of iterations is 5.
    
    Usage:
    gate = Gate(g, neuron_id)
    
    
    We have the following methods:
    
    
    get_trigger_neuron(): returns the trigger neuron index and the layer (up or down)
    
    sniff(): check the input and decides whether to activate the gate; it does nothing if the gates is
            already active.
            
    """

    def __init__(self, x, y, trigger_layer, trigger_neuron_id, g_activated, activation_threshold = 0.1, g_default = 1.0, duration = 5):

        if x == "bias":
            self.n_x = -1
        else:
            self.n_x = x
        self.n_y = y
        if trigger_layer == "down":
            self.layer = trigger_layer #up or down
        else:
            self.layer = "up"
        self.g_activated = g_activated
        self.neuron_id = trigger_neuron_id
        self.activation_threshold = activation_threshold
        self.g_default = g_default
        self.duration = duration
        self.state = self.g_default
        self.counter = 0
        self.activated = False

    def get_trigger_neuron(self):
        return self.layer, self.neuron_id

    def sniff(self, input):
        if self.counter == 0 and input > self.activation_threshold:
                self.state = self.g_activated
                self.counter = self.duration
                self.activated = True
        self.counter_()

    def counter_(self):
        if self.counter > 0:
            self.counter -= 1
        else:
            self.state = self.g_default
            self.activated =  False
            
    def de_activate(self):
        self.activated = False
        self.counter = 0
        self.state = self.g_default
    

# Here we expand the functionality of a PyTorch layer to accommodate gates

In [3]:

class GatedLinear(nn.Module):
    """
    This overwrites the Linear function of the nn module. The new forward function also applies gates
    
    The constructor receivies a list of gates to operate within this layer
    
    forward() is extended so that it respects the gate values
    
    think(): this is a function that updates gates of that layer
    """
    def __init__(self, in_features, out_features, set_gates=None, bias=True):
        super(GatedLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        if set_gates is not None:
            self.gating = True
            self.g_weight = nn.Parameter(torch.ones(out_features, in_features))
            self.g_weight.requires_grad=False
            if bias:
                self.g_bias = nn.Parameter(torch.ones(out_features))
                self.g_bias.requires_grad=False
            else:
                self.register_parameter('g_bias', None)
            
        else:
            self.gating = False
            self.register_parameter('g_weight', None)
            self.register_parameter('g_bias', None)
            
        self.reset_parameters()        
        if set_gates is not None:        
            self.gates = set_gates()      

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input, use_gates = True):            
        if self.gating is True and use_gates is True:
            linear_output = torch.matmul(input, self.weight.t() * self.g_weight.t()) 
        else:
            linear_output = torch.matmul(input, self.weight.t())
            
        if self.bias is not None:
            if self.gating is True and use_gates is True:
                linear_output += self.g_bias * self.bias
            else:
                linear_output += self.bias
            
        return linear_output

    def think(self, output_down, output_up):
        self.g_weight.fill_(1.0)
        self.g_bias.fill_(1.0)
        
        for g in self.gates:
            layer, n_id = g.get_trigger_neuron()
            if layer == "down":
                g.sniff(output_down[n_id])
            else:
                g.sniff(output_up[n_id])
            
            if g.n_x >= 0:
                self.g_weight[g.n_x, g.n_y] *= g.state
            else:
                self.g_bias[g.n_y] *= g.state
        return self.gates
    
    def de_activate_all_gates(self):
        for g in self.gates:
            g.de_activate()
    

# Creating a neural network layer with gates in 2 different regions

In [4]:

# Define gates for region 1
def region1_gates():
    g1 = Gate(0, 1, "down", trigger_neuron_id=0, g_activated=10, activation_threshold=0.5)
    g2 = Gate(1, 0, "down", trigger_neuron_id=1, g_activated=5, activation_threshold=0.3)
    return [g1, g2]

# Define gates for region 2
def region2_gates():
    g3 = Gate(0, 1, "down", trigger_neuron_id=0, g_activated=0.1, activation_threshold=0.4)
    g4 = Gate(1, 0, "down", trigger_neuron_id=1, g_activated=2, activation_threshold=0.6)
    return [g3, g4]

# Instantiate two GatedLinear layers, each with a different set of gates
gl_region1 = GatedLinear(2, 2, set_gates=region1_gates)
gl_region2 = GatedLinear(2, 2, set_gates=region2_gates)

# Manually set the weights and biases to observe the effect of gating
gl_region1.weight = nn.Parameter(torch.tensor([[0.8, 0.2],
                                               [0.3, 0.7]]))
gl_region1.bias = nn.Parameter(torch.tensor([0.1, -0.2]))

gl_region2.weight = nn.Parameter(torch.tensor([[0.5, 0.5],
                                               [0.4, 0.6]]))
gl_region2.bias = nn.Parameter(torch.tensor([0.2, -0.1]))

# Define a sequence of inputs to test gating effects

In [5]:
# Define a sequence of inputs to test alternating gating effects
inputs = [
    torch.tensor([0.2, 0.2]),  # Should only activate low-threshold gates
    torch.tensor([0.6, 0.2]),  # Will activate region1 more
    torch.tensor([0.3, 0.4]),  # Mixed activation
    torch.tensor([0.7, 0.6])   # Strong activation in region2
]
  

# Compare to the outputs without using gates

In [6]:
# Control variable for alternating between regions
toggle_region = True  # Start with region 1

print("=== Alternating Outputs for Each Region ===")
for i, inp in enumerate(inputs, 1):
    if toggle_region:
        # Use Region 1 Gates
        out = gl_region1.forward(inp, use_gates=True)
        gl_region1.think(inp, out)  # Update gates in region 1
        print(f"Input {i} (Region 1): {inp.numpy()}, Output: {out.detach().numpy()}, Gate States: {[g.state for g in gl_region1.gates]}")
    else:
        # Use Region 2 Gates
        out = gl_region2.forward(inp, use_gates=True)
        gl_region2.think(inp, out)  # Update gates in region 2
        print(f"Input {i} (Region 2): {inp.numpy()}, Output: {out.detach().numpy()}, Gate States: {[g.state for g in gl_region2.gates]}")
    
    # Toggle region for the next input
    toggle_region = not toggle_region

=== Alternating Outputs for Each Region ===
Input 1 (Region 1): [0.2 0.2], Output: [0.3 0. ], Gate States: [1.0, 1.0]
Input 2 (Region 2): [0.6 0.2], Output: [0.6        0.26000002], Gate States: [0.1, 1.0]
Input 3 (Region 1): [0.3 0.4], Output: [0.42000002 0.17      ], Gate States: [1.0, 5]
Input 4 (Region 2): [0.7 0.6], Output: [0.58 0.54], Gate States: [0.1, 1.0]


# Explanation of Alternating Outputs for Each Region

This example demonstrates how the outputs change when toggling between two gate regions. Each region has a distinct set of gates with specific activation thresholds and amplification/reduction factors, resulting in different outputs even for the same input pattern. By alternating between Region 1 and Region 2, we observe repeatable outputs for each region when it is activated.

#### Results Breakdown

- **Input 1 (Region 1)**: `[0.2, 0.2]`
  - **Gate States**: `[1.0, 1.0]`
    - The input values are below the activation thresholds for Region 1’s gates (`g1` and `g2`), so both gates remain in their default state with a gating factor of `1.0`.
  - **Output**: `[0.3, 0.0]`
    - Since no gates are activated, the output is determined purely by the initial weights and biases of Region 1’s `GatedLinear` layer.

- **Input 2 (Region 2)**: `[0.6, 0.2]`
  - **Gate States**: `[0.1, 1.0]`
    - Region 2 is now active, and `g3` activates with a gating factor of `0.1` because `0.6` exceeds its threshold of `0.4`. `g4` remains in the default state of `1.0`.
  - **Output**: `[0.6, 0.26000002]`
    - The activated gate (`g3`) reduces the impact of the connection it controls, leading to a different output pattern compared to Region 1.

- **Input 3 (Region 1)**: `[0.3, 0.4]`
  - **Gate States**: `[1.0, 5]`
    - Region 1 is reactivated, and `g2` activates with a gating factor of `5` because `0.4` exceeds its threshold of `0.3`. `g1` remains inactive.
  - **Output**: `[0.42000002, 0.17]`
    - With `g2` amplifying its connection, the output reflects this increased contribution, showing a unique pattern for Region 1 with this input.

- **Input 4 (Region 2)**: `[0.7, 0.6]`
  - **Gate States**: `[0.1, 1.0]`
    - Back to Region 2, `g3` activates with a gating factor of `0.1` (due to its low activation threshold of `0.4`), while `g4` stays inactive at `1.0`.
  - **Output**: `[0.58, 0.54]`
    - Region 2 produces a consistent output for the same gate configuration as seen with Input 2, demonstrating repeatable behavior for this region's settings.

### Summary

This output illustrates how different regions of gates, with distinct thresholds and amplification factors, result in unique outputs. By alternating between Region 1 and Region 2, we observe consistent outputs for each region when revisited, demonstrating the controlled, repeatable influence of gating configurations on network behavior.
