<a href="https://colab.research.google.com/github/Razumovskyy/NN_connectivity_vs_entropy/blob/main/NN_entropy_connectivity1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import itertools
import random
from collections import Counter

In [1]:
def generate_environment(N, h):
    """Sample an environment: ``h`` distinct binary patterns of length ``N``.

    Patterns are drawn uniformly without replacement from {0,1}^N, so the
    entropy of the environment is essentially log_2 h.

    Args:
        N: Number of bits per pattern (non-negative).
        h: Number of distinct patterns to draw; 0 <= h <= 2**N.
            h == 2**N is allowed and returns every pattern (maximal entropy).

    Returns:
        A list of ``h`` distinct tuples of 0/1 ints, each of length ``N``.

    Raises:
        ValueError: If ``N`` is negative or ``h`` exceeds 2**N (a negative
            ``h`` is rejected by ``random.sample`` with ValueError as well).
    """
    if N < 0:
        raise ValueError("N must be non-negative")
    total = 2 ** N
    # Only h > 2**N is impossible; h == 2**N simply selects all patterns.
    if h > total:
        raise ValueError("number of unique patterns can't be greater than number of all patterns")

    # Sample pattern *indices* instead of materialising all 2**N tuples.
    # random.sample picks the same positions for a range as for a list of the
    # same length, and position i of itertools.product([0, 1], repeat=N) is
    # the N-bit binary form of i (most significant bit first), so the result
    # matches sampling from the full product list under the same seed.
    indices = random.sample(range(total), h)
    return [tuple((i >> (N - 1 - bit)) & 1 for bit in range(N)) for i in indices]

In [6]:
def generate_ensemble(num_environments, N, h):
    """Return a list of ``num_environments`` independently sampled environments.

    Each environment comes from ``generate_environment(N, h)``. The exhaustive
    ensemble would contain every h-subset of {0,1}^N, i.e. C(2**N, h)
    environments, which is usually far too many. Because environments are
    sampled uniformly, a smaller random ensemble is used in practice. Draws are
    independent, so the same environment may appear more than once.
    """
    return [generate_environment(N, h) for _ in range(num_environments)]

In [7]:
def connectivity(N, alpha):
    """Return the connectivity K = int(alpha * N) of a neuron.

    ``alpha`` may depend on N (alpha = alpha(N)), tending to alpha_0 as
    N -> inf. The float product is truncated toward zero by ``int``.
    NOTE(review): truncation is sensitive to float rounding, e.g.
    0.29 * 100 == 28.999999999999996 gives K = 28; confirm this is intended.
    """
    return int(alpha * N)

In [9]:
"""
Projection functions n and nu from the paper. They express how much
information about the environment is available to a neuron during learning.
nu is the normalized version of n (n divided by the number of patterns h).

Each vector a is an element of {0,1}^K.
The projection functions have to be evaluated for every such vector a.
"""
def n_func(env, a, K):
    """Projection n: count the patterns in ``env`` whose first K bits equal ``a``.

    Args:
        env: Iterable of indexable bit patterns (e.g. the tuples produced by
            generate_environment).
        a: Indexable vector with at least K entries.
        K: Number of leading bits compared (the neuron's connectivity).

    Returns:
        The number of matching patterns, as an int.
    """
    def prefix_matches(pattern):
        # Bit-by-bit comparison; stops at the first mismatching position.
        return all(pattern[i] == a[i] for i in range(K))

    return sum(1 for pattern in env if prefix_matches(pattern))

def nu_func(env, a, K, h):
    """Projection nu: n_func normalized by the environment size ``h``.

    Returns the fraction (a float) of the ``h`` patterns in ``env`` whose
    first K bits equal ``a``. ``h == 0`` raises ZeroDivisionError.
    """
    matches = n_func(env, a, K)
    return matches / h

In [None]:
def distinguish(env1, env2, K, h):
    """Distinguishability measure of two environments for a neuron with K inputs.

    Computes

        d = 1/2 * sum_{a in {0,1}^K} |nu(env1, a) - nu(env2, a)|,

    where nu(env, a) is the fraction of the ``h`` patterns in ``env`` whose
    first K bits equal ``a`` (the same quantity as nu_func).
    Values lie between 0 and 1 when both environments hold ``h`` patterns:
    0 -- complete indistinguishability
    1 -- complete distinguishability

    Only vectors ``a`` that occur as a K-bit prefix in env1 or env2 can
    contribute; for every other ``a`` both nu terms are 0. Counting prefixes
    once per environment therefore gives the same sum in O(h*K) time instead
    of O(2**K * h * K) for enumerating all of {0,1}^K.

    Args:
        env1, env2: Iterables of indexable bit patterns.
        K: Number of leading bits the neuron sees.
        h: Normalizing pattern count. NOTE(review): presumably
            len(env1) == len(env2) == h; this is not checked.

    Returns:
        The distinguishability d as a float.

    Raises:
        ValueError: If some pattern has fewer than K bits.
    """
    def prefix_counts(env):
        # Multiset of K-bit prefixes; counts[a] equals n_func(env, a, K).
        counts = Counter()
        for x in env:
            if len(x) < K:
                raise ValueError(f"pattern {x!r} has fewer than K={K} bits")
            counts[tuple(x[:K])] += 1
        return counts

    counts1 = prefix_counts(env1)
    counts2 = prefix_counts(env2)
    # A Counter returns 0 for missing keys, which matches n == 0 for prefixes
    # that appear in only one of the environments.
    observed = counts1.keys() | counts2.keys()
    return 0.5 * sum(abs(counts1[a] / h - counts2[a] / h) for a in observed)