# Mathematical Framework of Loihi
The purpose of this notebook is to present the mathematical framework of the Loihi chip and to simplify parameter tuning for future compartments and networks.

## Neural Model 
Loihi uses the leaky integrate-and-fire (LIF) model, which has two internal state variables:
### 1) $u_i(t)$: the synaptic response current
$$
u_i(t) = u_i(t - 1) \cdot (2^{12} - \delta_i^{(u)}) \cdot 2^{-12} + 2^{6+wgtExp} \sum_jw_{ij}\cdot s_j(t)
$$
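- $i$ indicates the index of the post-synaptic neuron (in Loihi, neurons are represented as compartments)
- $j$ indicates the index of the pre-synaptic neuron
- $w_{ij}$ is the synaptic weight from neuron $j$ to neuron $i$, $s_j(t) \in \{0, 1\}$ indicates whether neuron $j$ spiked at timestep $t$, and $wgtExp$ is the weight exponent described in the Synaptic Connection Model section
- $\delta_i^{(u)}$ represents the current decay (default = 4096). It is set from the current time constant $\tau$ (in timesteps):
$$
compartmentCurrentDecay = \delta_i^{(u)} = (1/\tau) \cdot 2^{12}
$$
Without input, the current therefore shrinks by a factor of $(1 - 1/\tau)$ every timestep, which for large $\tau$ approximates an exponential decay:
$$
u_i(t) = u_i(0) \cdot (1 - 1/\tau)^t \approx u_i(0) \cdot e^{-t/\tau}
$$
With the default value of 4096 ($\tau = 1$), the previous current is discarded every timestep, so $u_i(t)$ only contains the input of the current timestep.
- $u_i(t)$ is the compartment's state current at timestep $t$. It integrates incoming weighted spikes from the dendritic accumulators, and possibly inputs from other compartments, but otherwise decays exponentially.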
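To see the current equation in action, the sketch below implements one timestep of it in floating point. The helper name `update_current` is hypothetical, and the sketch ignores the integer rounding the hardware performs; it only illustrates how an input spike is scaled by $2^{6 + wgtExp}$ and then decays by a factor of $(1 - 1/\tau)$ per timestep.

```python
def update_current(u_prev, current_decay, spikes, weights, wgt_exp=0):
    """One timestep of u_i(t) = u_i(t-1) * (2^12 - decay) * 2^-12 + 2^(6+wgtExp) * sum_j w_ij * s_j(t).

    Hypothetical helper: a floating-point sketch that ignores hardware rounding.
    """
    decayed = u_prev * (2 ** 12 - current_decay) / 2 ** 12
    weighted_spikes = sum(w * s for w, s in zip(weights, spikes))
    return decayed + 2 ** (6 + wgt_exp) * weighted_spikes


# One presynaptic neuron with weight 10 that spikes only at t = 0; current time constant tau = 4
tau = 4
current_decay = int(2 ** 12 / tau)  # 1024
u = 0.0
trace = []
for t in range(4):
    spikes = [1] if t == 0 else [0]
    u = update_current(u, current_decay, spikes, weights=[10])
    trace.append(u)
print(trace)  # [640.0, 480.0, 360.0, 270.0]

# With the default decay of 4096 the previous current is discarded every timestep
print(update_current(640.0, 4096, spikes=[0], weights=[10]))  # 0.0
```

Each step multiplies the current by $(4096 - 1024) / 4096 = 0.75$, i.e. $(1 - 1/\tau)$ for $\tau = 4$.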
### 2) $v_i(t)$: the membrane voltage
$$
v_i(t) = v_i(t-1) \cdot (2^{12} - VoltageDecay) \cdot 2^{-12} + u_i(t) + biasMant \cdot 2^{biasExp}
$$
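- The voltage decay is set from the voltage time constant $\tau$ in the same way as the current decay (default = 0, which corresponds to $\tau \to \infty$: no leak, so the voltage is a perfect integrator):
$$
VoltageDecay = (1/\tau) \cdot 2^{12}
$$
- The compartment voltage $v_i(t)$ integrates the compartment current $u_i(t)$, the compartment bias $biasMant \cdot 2^{biasExp}$, and possibly inputs from other compartments.
- The compartment spikes when $v_i(t)$ crosses the voltage threshold, so the threshold, together with the input and the decays, determines the firing rate. The threshold is set through its mantissa (default $vTh = 6400$, i.e. $vthMant = 100$):
$$
vTh = vthMant \cdot 2^6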
$$
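Putting the voltage equation and the threshold together, the following sketch simulates a single compartment driven only by its bias. It assumes, rather than takes from the Loihi documentation, that a spike is emitted when $v_i(t)$ exceeds $vTh$, that the voltage is reset to 0 immediately afterwards, and that there is no refractory period. The name `simulate_bias_driven` is hypothetical and the arithmetic is floating point.

```python
def simulate_bias_driven(num_steps, bias_mant, bias_exp=0, vth_mant=100, voltage_decay=0):
    """Return the timesteps at which a bias-driven compartment spikes (floating-point sketch)."""
    v_th = vth_mant * 2 ** 6          # vTh = vthMant * 2^6
    bias = bias_mant * 2 ** bias_exp  # biasMant * 2^biasExp
    u = 0.0                           # no synaptic input, so u_i(t) stays 0
    v = 0.0
    spike_times = []
    for t in range(num_steps):
        v = v * (2 ** 12 - voltage_decay) / 2 ** 12 + u + bias
        if v > v_th:                  # assumed spike condition
            spike_times.append(t)
            v = 0.0                   # assumed reset to zero, no refractory period
    return spike_times


# No leak (default VoltageDecay = 0): v grows by 1000 per step and crosses 6400 every 7 steps
print(simulate_bias_driven(20, bias_mant=1000))  # [6, 13]

# VoltageDecay = 1024 (tau = 4): v saturates at 1000 / (1 - 0.75) = 4000 < 6400, so no spikes
print(simulate_bias_driven(20, bias_mant=1000, voltage_decay=1024))  # []
```

Under these assumptions, a leaky compartment with constant bias settles at $bias \cdot 2^{12} / VoltageDecay$, so it only fires if that value exceeds $vTh$.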




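```python
# Voltage threshold: vTh = vthMant * 2^6
def get_vth(vthMant):
    return vthMant * (2 ** 6)

def get_vthMant(vTh):
    return vTh / (2 ** 6)

# Current decay <-> time constant: currentDecay = 2^12 / tau
def get_curr_decay_tau(currentDecay):
    return (2 ** 12) / currentDecay

def get_curr_decay(tau):
    return (1 / tau) * (2 ** 12)

# Fraction of the voltage kept each timestep: (2^12 - voltageDecay) / 2^12
def act_voltage_decay(voltageDecay):
    return ((2 ** 12) - voltageDecay) / (2 ** 12)

# Voltage decay <-> time constant: voltageDecay = 2^12 / tau
def get_voltage_decay(tau):
    return (1 / tau) * (2 ** 12)
```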





## Synaptic Connection Model
- The weight $w$ can take values in the range $[-256, 256]$.
- The formulas below show how a synaptic weight is quantized before it is accumulated into the compartment current:
$$
\begin{aligned}
numLsbBits &= 8 - (numWeightBits - \text{IS\_MIXED}) \\
actWeight &= (weight >> numLsbBits) << numLsbBits
\end{aligned}
$$
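- `numWeightBits` specifies the number of bits, and therefore the precision, of a synaptic `weight`. It can take integer values from 0 to 8. The quantized weight is a multiple of $2^{numLsbBits}$, so fewer weight bits give coarser weights. `IS_MIXED` is 1 when the synapses use mixed signs (both positive and negative weights) and 0 otherwise; according to the formula, setting it costs one bit of weight precision.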
- Before the weight is accumulated into the current, it is additionally scaled by a power of two:
$$
2^{6 + wgtExp}
$$
- In summary, assuming full 8-bit resolution with `IS_MIXED = 0` (so that $numLsbBits = 0$ and $actWeight = w$), the weight contribution integrated into the current is:
$$
w_i = w \cdot 2^{6 + wgtExp}
$$
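The effect of `numWeightBits` is easiest to see by sweeping it for a fixed weight. The sketch below re-implements the two formulas above in standalone form (`quantize_weight` and `weight_contribution` are hypothetical names, not part of a Loihi API). It shows that a weight of 101 is stored exactly only with 8 bits and is otherwise rounded down to a multiple of $2^{numLsbBits}$.

```python
def quantize_weight(weight, num_weight_bits=8, is_mixed=0):
    """actWeight = (weight >> numLsbBits) << numLsbBits, with numLsbBits = 8 - (numWeightBits - IS_MIXED)."""
    num_lsb_bits = 8 - (num_weight_bits - is_mixed)
    return (weight >> num_lsb_bits) << num_lsb_bits


def weight_contribution(weight, num_weight_bits=8, is_mixed=0, wgt_exp=0):
    """Quantized weight scaled by 2^(6 + wgtExp), i.e. what one spike adds to u_i(t)."""
    return quantize_weight(weight, num_weight_bits, is_mixed) * 2 ** (6 + wgt_exp)


quantized = {bits: quantize_weight(101, bits) for bits in range(1, 9)}
print(quantized)
# {1: 0, 2: 64, 3: 96, 4: 96, 5: 96, 6: 100, 7: 100, 8: 101}

# IS_MIXED = 1 costs one bit of precision: 8 weight bits then behave like 7
print(quantize_weight(101, 8, is_mixed=1))  # 100

# Each increment of wgtExp doubles the contribution of a spike
print([weight_contribution(101, wgt_exp=e) for e in (0, 1, 2)])  # [6464, 12928, 25856]
```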

```python
def calculate_actual_weight(weight, num_weight_bits, wgt_exp, is_mixed):
    """
    Calculate the weight contribution that gets integrated into the compartment current.

    Parameters:
    weight (int): The synaptic weight, in the range [-256, 256].
    num_weight_bits (int): The number of bits specifying the precision of the synaptic weight (0 to 8).
    wgt_exp (int): The weight exponent used for the additional power-of-two scaling.
    is_mixed (int): The IS_MIXED flag (0 or 1); a value of 1 costs one bit of weight precision.

    Returns:
    int or float: The quantized weight scaled by 2^(6 + wgt_exp); a float only when 6 + wgt_exp < 0.
    """
    # Number of least significant bits discarded by the quantization
    num_lsb_bits = 8 - (num_weight_bits - is_mixed)

    # Zero out the discarded bits by shifting right and then left again
    act_weight = (weight >> num_lsb_bits) << num_lsb_bits

    # Apply the 2^(6 + wgtExp) scaling before integration into the current
    weight_component = act_weight * (2 ** (6 + wgt_exp))

    return weight_component

# Example usage
weight = 10
num_weight_bits = 8
wgt_exp = 0
is_mixed = 0

actual_weight = calculate_actual_weight(weight, num_weight_bits, wgt_exp, is_mixed)
print("Actual Weight:", actual_weight)
```


Actual Weight: 640


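```python
# Voltage decay for tau = 4 is 1024, so 75% of the voltage is kept each step
actVoltDecay = act_voltage_decay(int(1 / 4 * 2 ** 12))
print(actVoltDecay)

# Threshold for vthMant = 100
voltageTh = get_vth(100)
print(voltageTh)

# Time constant realized by a current decay of 3200
tau = get_curr_decay_tau(int(6400 / 2))
print(tau)

# Current decay for tau = 5 (not an integer)
currDecay = get_curr_decay(5)
print(currDecay)
```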

0.75
6400
1.28
819.2
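Some of the values above are not integers (for example, `get_curr_decay(5)` gives 819.2, and `get_vthMant` can return a fraction). Assuming the decay settings and `vthMant` only accept integers, a desired $\tau$ or threshold has to be rounded, which slightly changes the value actually realized. The hypothetical helpers below round to the nearest integer, clamp the decay to $[0, 4096]$ so that the decay factor $(2^{12} - decay) \cdot 2^{-12}$ stays between 0 and 1, and report the effective value.

```python
def to_decay_register(tau):
    """Nearest integer decay for a desired time constant, clamped to [0, 4096] (assumed valid range)."""
    return max(0, min(4096, round(2 ** 12 / tau)))


def effective_tau(decay):
    """Time constant realized by an integer decay value (infinite for decay = 0, i.e. no leak)."""
    return float("inf") if decay == 0 else 2 ** 12 / decay


def to_vth_mant(v_th):
    """Nearest integer mantissa for a desired threshold; the realized vTh is a multiple of 2^6."""
    return round(v_th / 2 ** 6)


decay = to_decay_register(5)
print(decay, effective_tau(decay))  # 819, and a time constant slightly above 5 (about 5.0012)

mant = to_vth_mant(6500)
print(mant, mant * 2 ** 6)  # 102 6528
```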


```python
import ipywidgets as widgets
from IPython.display import display

# Define functions for calculations
def get_vth(vthMant):
    return vthMant * (2 ** 6)

def get_vthMant(vTh):
    return vTh / (2 ** 6)

def get_curr_decay_tau(currentDecay):
    return (2 ** 12) / currentDecay

def get_curr_decay(tau):
    return (1 / tau) * (2 ** 12)

def act_voltage_decay(voltageDecay):
    return ((2 ** 12) - voltageDecay) / (2 ** 12)

def get_voltage_decay(tau):
    return (1 / tau) * (2 ** 12)

def calculate_actual_weight(weight, num_weight_bits, wgt_exp, is_mixed):
    num_lsb_bits = 8 - (num_weight_bits - is_mixed)
    act_weight = (weight >> num_lsb_bits) << num_lsb_bits
    weight_component = act_weight * (2 ** (6 + wgt_exp))
    return weight_component

# Create widgets for inputs
vthMant_input = widgets.IntText(value=100, description='vthMant:')
vTh_input = widgets.FloatText(value=6400, description='vTh:')
currentDecay_input = widgets.IntText(value=4096, description='Current Decay:')
tau_input = widgets.FloatText(value=10, description='Tau:')
voltageDecay_input = widgets.IntText(value=0, description='Voltage Decay:')
weight_input = widgets.IntText(value=10, description='Weight:')
num_weight_bits_input = widgets.IntText(value=8, description='Num Weight Bits:')
wgt_exp_input = widgets.IntText(value=0, description='Weight Exp:')
is_mixed_input = widgets.IntText(value=0, description='Is Mixed:')

# Create output areas
output_vth = widgets.Output()
output_vthMant = widgets.Output()
output_curr_decay_tau = widgets.Output()
output_curr_decay = widgets.Output()
output_voltage_decay = widgets.Output()
output_act_voltage_decay = widgets.Output()
output_weight = widgets.Output()

# Define update functions
def update_vth(change):
    with output_vth:
        output_vth.clear_output()
        print(get_vth(vthMant_input.value))

def update_vthMant(change):
    with output_vthMant:
        output_vthMant.clear_output()
        print(get_vthMant(vTh_input.value))

def update_curr_decay_tau(change):
    with output_curr_decay_tau:
        output_curr_decay_tau.clear_output()
        print(get_curr_decay_tau(currentDecay_input.value))

def update_curr_decay(change):
    with output_curr_decay:
        output_curr_decay.clear_output()
        print(get_curr_decay(tau_input.value))

def update_voltage_decay(change):
    with output_voltage_decay:
        output_voltage_decay.clear_output()
        print(get_voltage_decay(tau_input.value))

def update_act_voltage_decay(change):
    with output_act_voltage_decay:
        output_act_voltage_decay.clear_output()
        print(act_voltage_decay(voltageDecay_input.value))

def update_weight(change):
    with output_weight:
        output_weight.clear_output()
        print(calculate_actual_weight(weight_input.value, num_weight_bits_input.value, wgt_exp_input.value, is_mixed_input.value))

# Attach update functions to input widgets (the single tau input drives both the current and the voltage decay)
vthMant_input.observe(update_vth, names='value')
vTh_input.observe(update_vthMant, names='value')
currentDecay_input.observe(update_curr_decay_tau, names='value')
tau_input.observe(update_curr_decay, names='value')
tau_input.observe(update_voltage_decay, names='value')
voltageDecay_input.observe(update_act_voltage_decay, names='value')
weight_input.observe(update_weight, names='value')
num_weight_bits_input.observe(update_weight, names='value')
wgt_exp_input.observe(update_weight, names='value')
is_mixed_input.observe(update_weight, names='value')

# Display the widgets and outputs
display(widgets.VBox([vthMant_input, output_vth]))
display(widgets.VBox([vTh_input, output_vthMant]))
display(widgets.VBox([currentDecay_input, output_curr_decay_tau]))
display(widgets.VBox([tau_input, output_curr_decay, output_voltage_decay]))
display(widgets.VBox([voltageDecay_input, output_act_voltage_decay]))
display(widgets.VBox([weight_input, num_weight_bits_input, wgt_exp_input, is_mixed_input, output_weight]))

# Observers only fire on changes, so fill every output once with the default values
for update in (update_vth, update_vthMant, update_curr_decay_tau, update_curr_decay,
               update_voltage_decay, update_act_voltage_decay, update_weight):
    update(None)
```

