In [None]:
# Import Packages
import time
import numpy as np

from matplotlib import pyplot as plt
import matplotlib.cm as cm
import matplotlib.gridspec as gridspec
import seaborn as sns

from tqdm import tqdm
import scipy
from scipy.special import factorial, factorial2, erfinv, erf

import jax
import jax.numpy as jnp
from jax import random, jit, vmap
from functools import partial
from math import comb

import matplotlib as mpl
from IPython.display import clear_output, display


import seaborn as sns
# Global seaborn styling for every figure produced by this notebook
sns.set_style("whitegrid")
sns.set_context("poster")

# Root JAX PRNG key; later cells split it before drawing random patterns
key = jax.random.PRNGKey(2023)

# Helper Functions:

In [7]:
def generate_patterns(key, num_neurons, num_patterns):
    """Draw unbiased random +/-1 patterns.

    Args:
        key: JAX PRNG key used for the Bernoulli draw.
        num_neurons: Length of each pattern (second output axis).
        num_patterns: Number of patterns (first output axis).

    Returns:
        float32 array of shape (num_patterns, num_neurons) with entries in {-1, +1}.
    """
    shape = (num_patterns, num_neurons)
    coin_flips = random.bernoulli(key, p=0.5, shape=shape)
    # True -> +1, False -> -1
    return jnp.where(coin_flips, 1, -1).astype('float32')

def generate_correlated_patterns(key, eps, num_neurons, num_patterns):
    """Draw biased random +/-1 patterns.

    Each entry is +1 with probability 0.5 + 0.5*eps and -1 otherwise, so
    eps = 0 gives unbiased patterns.

    Args:
        key: JAX PRNG key used for the Bernoulli draw.
        eps: Bias parameter entering the +1 probability as 0.5 + 0.5*eps.
        num_neurons: Length of each pattern (second output axis).
        num_patterns: Number of patterns (first output axis).

    Returns:
        float32 array of shape (num_patterns, num_neurons) with entries in {-1, +1}.
    """
    prob_plus = 0.5 + 0.5*eps
    shape = (num_patterns, num_neurons)
    coin_flips = random.bernoulli(key, p=prob_plus, shape=shape)
    # True -> +1, False -> -1
    return jnp.where(coin_flips, 1, -1).astype('float32')

def generate_sequences(key, num_neurons, num_patterns, num_sequences=100):
    """Generate a batch of independent random pattern sequences.

    Args:
        key: JAX PRNG key; it is split into one subkey per sequence and never
            used directly as a random source.
        num_neurons: Number of neurons N (length of each pattern).
        num_patterns: Number of patterns P in each sequence.
        num_sequences: Number of sequences S in the batch.

    Returns:
        float32 array of shape (num_sequences, num_patterns, num_neurons) with
        entries in {-1, +1}.
    """
    # One dedicated subkey per sequence. The previous `key, _ = split(key)`
    # chain used each key both to draw patterns and as the parent of the next
    # key, which is key reuse.
    subkeys = random.split(key, num_sequences)
    # generate_patterns takes (key, num_neurons, num_patterns) and already
    # returns (num_patterns, num_neurons); the arguments used to be swapped
    # and the result silently reshaped into place.
    seq_list = [generate_patterns(subkey, num_neurons, num_patterns) for subkey in subkeys]
    return jnp.stack(seq_list)

# Theoretical Predictions:

In [8]:
### Polynomial DenseNet / MixedNet
def theory_PDN_trans(N, d):
    """Theoretical transition capacity of the Polynomial DenseNet.

    Evaluates N**d / (2 * (2d-1)!! * log N).

    Args:
        N: Number of neurons.
        d: Polynomial degree.
    """
    normalizer = 2 * factorial2(2*d-1) * jnp.log(N)
    return N**d / normalizer

def theory_PDN_seq(N, d):
    """Theoretical sequence capacity of the Polynomial DenseNet.

    Evaluates N**d / (2 * (d+1) * (2d-1)!! * log N).

    Args:
        N: Number of neurons.
        d: Polynomial degree.
    """
    normalizer = 2 * (d+1) * factorial2(2*d-1) * jnp.log(N)
    return N**d / normalizer

def gamma(dS, dA, lambd):
    """Degree-dependent factor used by the Polynomial MixedNet capacity formulas.

    Args:
        dS: Degree of the symmetric term.
        dA: Degree of the asymmetric term.
        lambd: Weight of the asymmetric term.

    Returns:
        (2*dS-1)!! when dS < dA; lambd**2 * (2*dA-1)!! when dS > dA; and for
        equal degrees (lambd**2 + 1)*(2*dS-1)!! plus 2*lambd*((dS-1)!!)**2,
        where the last term is only present for even dS.
    """
    if dS > dA:
        return lambd**2 * factorial2(2*dA-1)
    if dS < dA:
        return factorial2(2*dS-1)
    # Equal degrees: the cross term only contributes for even degree
    is_even = (dS % 2 == 0)
    return (lambd**2 + 1)*factorial2(2*dS-1) + 2*lambd*factorial2(dS-1)**2  * is_even

def theory_PMN_trans(N, dA, dS, lambd):
    """Theoretical transition capacity of the Polynomial MixedNet.

    Args:
        N: Number of neurons.
        dA: Degree of the asymmetric term.
        dS: Degree of the symmetric term.
        lambd: Weight of the asymmetric term.
    """
    d_min = np.min([dS, dA])
    prefactor = (lambd-1)**2 / (2 * gamma(dS, dA, lambd))
    return prefactor * N**(d_min)/ jnp.log(N)

def theory_PMN_seq(N, dA, dS, lambd):
    """Theoretical sequence capacity of the Polynomial MixedNet.

    Same form as theory_PMN_trans with an extra factor min(dS, dA) in the
    denominator.

    Args:
        N: Number of neurons.
        dA: Degree of the asymmetric term.
        dS: Degree of the symmetric term.
        lambd: Weight of the asymmetric term.
    """
    prefactor = (lambd-1)**2 / (2 * gamma(dS, dA, lambd) * np.min([dS,dA]))
    scaling = N**(np.min([dS, dA]))
    return prefactor * scaling/ jnp.log(N)

# Exponential DenseNet / MixedNet
# Base of the exponential capacity scaling: exp(2) / cosh(2)
beta = jnp.exp(2) / jnp.cosh(2)

def theory_EDN_trans(N):
    """Theoretical transition capacity of the Exponential DenseNet.

    Evaluates beta**(N-1) / (2 * log N) for N neurons.
    """
    log_term = 2 * jnp.log(N)
    return beta**(N-1) / log_term

def theory_EDN_seq(N):
    """Theoretical sequence capacity of the Exponential DenseNet.

    Evaluates beta**N / (2 * log(beta) * N) for N neurons.
    """
    normalizer = 2 * jnp.log(beta) * N
    return beta**(N) / normalizer

def theory_EMN_trans(N, lambd):
    """Theoretical transition capacity of the Exponential MixedNet.

    Rescales the Exponential DenseNet transition capacity by
    (lambd-1)**2 / (lambd**2 + 1).
    """
    mixing_factor = (lambd-1)**2 / (lambd**2 + 1)
    return mixing_factor * theory_EDN_trans(N)

def theory_EMN_seq(N, lambd):
    """Theoretical sequence capacity of the Exponential MixedNet.

    Rescales the Exponential DenseNet sequence capacity by
    (lambd-1)**2 / (lambd**2 + 1).
    """
    mixing_factor = (lambd-1)**2 / (lambd**2 + 1)
    return mixing_factor * theory_EDN_seq(N)

# Model Definitions:

In [4]:
# Polynomial Networks
@jit
def poly_asym_update(patterns, state, dA):
    """Asymmetric polynomial update field (Polynomial DenseNet), excluding self-coupling.

    Args:
        patterns: (P, N) array of stored patterns.
        state: (N,) current network state.
        dA: Exponent applied to the normalized overlaps.

    Returns:
        (N,) field: each pattern's successor (cyclic, via roll by -1) weighted
        by that pattern's overlap raised to the power dA.
    """
    P, N = patterns.shape
    overlaps = (patterns@state).reshape(P,1)
    # Remove each neuron's own contribution, then normalize by the N-1 others
    cavity = (overlaps - (patterns*state)) / (N-1)
    successors = jnp.roll(patterns,-1, axis=0)
    return jnp.einsum('ij,ij->j', successors, jnp.power(cavity, dA))

# Same rule applied independently to a batch of (patterns, state) pairs
batched_poly_asym_update = vmap((poly_asym_update), in_axes=(0, 0, None))

@jit
def poly_sym_update(patterns, state, dS):
    """Symmetric polynomial update field (Polynomial MixedNet), excluding self-coupling.

    Args:
        patterns: (P, N) array of stored patterns.
        state: (N,) current network state.
        dS: Exponent applied to the normalized overlaps.

    Returns:
        (N,) field: each pattern weighted by its own overlap raised to the
        power dS.
    """
    P, N = patterns.shape
    overlaps = (patterns@state).reshape(P,1)
    # Remove each neuron's own contribution, then normalize by the N-1 others
    cavity = (overlaps - (patterns*state)) / (N-1)
    return jnp.einsum('ij,ij->j', patterns, jnp.power(cavity, dS))

# Same rule applied independently to a batch of (patterns, state) pairs
batched_poly_sym_update = vmap((poly_sym_update), in_axes=(0, 0, None))

## Polynomial DenseNet
class PDN(object):
    """Polynomial DenseNet run on a batch of sequences with synchronous updates."""

    def initialize(self, sequences, tau, dA):
        """Store the (S, P, N) sequence batch and the model hyperparameters."""
        self.sequences = sequences
        self.S, self.P, self.N = sequences.shape
        self.tau = tau
        self.dA = dA

    def predict(self, num_timestep):
        """Roll the network forward, stopping at the first recall error.

        Returns:
            (success, state_history, alignment_history) where success is False
            as soon as any update produces a bit flip.
        """
        self.num_timestep = num_timestep

        # History starts with the last pattern followed by the first pattern
        last_pattern = self.sequences[:,-1,:]
        first_pattern = self.sequences[:,0,:]
        self.s_history = np.stack((last_pattern, first_pattern))
        self.alignment_history = np.zeros((self.num_timestep, self.S, self.P))

        recalled = True
        for step in range(self.num_timestep - 1):
            if self._run(step) > 0:
                recalled = False
                break

        return recalled, jnp.array(self.s_history), jnp.array(self.alignment_history)

    def _run(self, timestep):
        """
        Synchronous update; returns the number of bits that differ from
        pattern timestep+1 summed over all sequences.
        """
        current = self.s_history[-1]

        # Update network
        field = batched_poly_asym_update(self.sequences, current, self.dA)
        updated = jnp.sign(field)

        # Overlap of the new state with every stored pattern, per sequence
        self.alignment_history[timestep,:,:] = jnp.einsum('ijk,ik->ij', self.sequences, updated) / (self.N)
        self.s_history = np.vstack((self.s_history, updated[None]))

        # Each differing +/-1 entry contributes |difference| = 2
        target = self.sequences[:,timestep+1]
        return jnp.sum(jnp.absolute(updated - target)) / 2

## Polynomial MixedNet for tau = 1
class PMN(object):
    """Polynomial MixedNet (tau = 1) run on a batch of sequences with synchronous updates."""

    def initialize(self, sequences, lambd, tau, dS, dA):
        """Store the (S, P, N) sequence batch and the model hyperparameters."""
        self.sequences = sequences
        self.S, self.P, self.N = sequences.shape
        self.Q = self.P - 1
        self.lambd = lambd
        self.tau = tau
        self.dS = dS
        self.dA = dA

    def predict(self, num_timestep):
        """Roll the network forward, stopping at the first recall error.

        Returns:
            (success, state_history, alignment_history) where success is False
            as soon as any update produces a bit flip.
        """
        self.num_timestep = num_timestep

        # History starts with the last pattern followed by the first pattern
        last_pattern = self.sequences[:,-1,:]
        first_pattern = self.sequences[:,0,:]
        self.s_history = np.stack((last_pattern, first_pattern))
        self.alignment_history = np.zeros((self.num_timestep, self.S, self.P))

        recalled = True
        for step in range(self.num_timestep - 1):
            if self._run(step) > 0:
                recalled = False
                break

        return recalled, np.array(self.s_history), np.array(self.alignment_history)

    def _run(self, timestep):
        """
        Synchronous update; returns the number of bits that differ from
        pattern timestep+1 summed over all sequences.
        """
        current = self.s_history[-1]

        # Symmetric term plus lambd-weighted asymmetric term
        sym_field = batched_poly_sym_update(self.sequences, current, self.dS)
        asym_field = batched_poly_asym_update(self.sequences, current, self.dA)
        updated = jnp.sign(sym_field + self.lambd * asym_field)

        # Overlap of the new state with every stored pattern, per sequence
        self.alignment_history[timestep,:,:] = (jnp.einsum('ijk,ik->ij', self.sequences, updated) / self.N)
        self.s_history = np.vstack((self.s_history, updated[None]))

        # Each differing +/-1 entry contributes |difference| = 2
        target = self.sequences[:,timestep+1]
        return jnp.sum(jnp.absolute(updated - target)) / 2
    
# Exponential Networks

@jit
def exp_asym_update(patterns, state):
    """Asymmetric exponential update field (Exponential DenseNet), excluding self-coupling.

    Args:
        patterns: (P, N) array of stored patterns.
        state: (N,) current network state.

    Returns:
        (N,) field: each pattern's successor (cyclic, via roll by -1) weighted
        by exp(overlap - (N-1)), where the overlap omits the neuron itself.
    """
    P, N = patterns.shape
    overlaps = (patterns@state).reshape(P,1)
    # Subtracting N-1 keeps the exponent non-positive for +/-1 inputs
    exponent = ( overlaps - (patterns*state) ) - (N-1)
    successors = jnp.roll(patterns,-1, axis=0)
    return jnp.einsum( 'ij,ij->j', successors, jnp.exp(exponent) )

# Same rule applied independently to a batch of (patterns, state) pairs
batched_exp_asym_update = vmap((exp_asym_update), in_axes=(0, 0))

@jit
def exp_sym_update(patterns, state):
    """Symmetric exponential update field (Exponential MixedNet), excluding self-coupling.

    Args:
        patterns: (P, N) array of stored patterns.
        state: (N,) current network state.

    Returns:
        (N,) field: each pattern weighted by exp(overlap - (N-1)), where the
        overlap omits the neuron itself.
    """
    P, N = patterns.shape
    overlaps = (patterns@state).reshape(P,1)
    # Subtracting N-1 keeps the exponent non-positive for +/-1 inputs
    exponent = (overlaps - (patterns*state)) - (N-1)
    return jnp.einsum('ij,ij->j', patterns, jnp.exp(exponent))

# Same rule applied independently to a batch of (patterns, state) pairs
batched_exp_sym_update = vmap((exp_sym_update), in_axes=(0, 0))


## Exponential DenseNet
class EDN(object):
    """Exponential DenseNet run on a batch of sequences with synchronous updates."""

    def initialize(self, sequences, tau):
        """Store the (S, P, N) sequence batch and tau.

        This model has no mixing weight, so nothing named `lambd` is stored.
        The previous `self.lambd = lambd` read a name that is not a parameter
        and raised NameError unless a notebook-level `lambd` happened to exist.
        """
        self.sequences = sequences
        self.S, self.P, self.N = sequences.shape
        self.tau = tau

    def predict(self, num_timestep):
        """Roll the network forward, stopping at the first recall error.

        Returns:
            (success, state_history, alignment_history) where success is False
            as soon as any update produces a bit flip.
        """
        self.num_timestep = num_timestep

        # History starts with the last pattern followed by the first pattern
        self.s_history = np.stack((self.sequences[:,-1,:], self.sequences[:,0,:]))
        self.alignment_history = np.zeros((self.num_timestep, self.S, self.P))

        for t in range(self.num_timestep - 1):
            bitflips = self._run(t)
            if bitflips > 0:
                return False, np.array(self.s_history), np.array(self.alignment_history)

        return True, np.array(self.s_history), np.array(self.alignment_history)

    def _run(self, timestep):
        """
        Synchronous update; returns the number of bits that differ from
        pattern timestep+1 summed over all sequences.
        """
        t = timestep
        states = self.s_history[-1]

        # Update network
        h = batched_exp_asym_update(self.sequences, states)
        states = jnp.sign(h)

        # Compute Network Alignment and add to history
        alignments = jnp.einsum('ijk,ik->ij', self.sequences, states) / (self.N)

        self.alignment_history[t,:,:] = alignments
        self.s_history = np.vstack((self.s_history, states[None]))

        # Each differing +/-1 entry contributes |difference| = 2
        bitflips = jnp.sum(jnp.absolute(states - self.sequences[:,t+1])) / 2
        return bitflips
    
## Exponential MixedNet for tau = 1
class EMN(object):
    """Exponential MixedNet (tau = 1) run on a batch of sequences with synchronous updates."""

    def initialize(self, sequences, lambd, tau):
        """Store the (S, P, N) sequence batch and the model hyperparameters."""
        self.sequences = sequences
        self.S, self.P, self.N = sequences.shape
        self.lambd = lambd
        self.tau = tau

    def predict(self, num_timestep):
        """Roll the network forward, stopping at the first recall error.

        Returns:
            (success, state_history, alignment_history) where success is False
            as soon as any update produces a bit flip.
        """
        self.num_timestep = num_timestep

        # History starts with the last pattern followed by the first pattern
        last_pattern = self.sequences[:,-1,:]
        first_pattern = self.sequences[:,0,:]
        self.s_history = np.stack((last_pattern, first_pattern))
        self.alignment_history = np.zeros((self.num_timestep, self.S, self.P))

        recalled = True
        for step in range(self.num_timestep - 1):
            if self._run(step) > 0:
                recalled = False
                break

        return recalled, np.array(self.s_history), np.array(self.alignment_history)

    def _run(self, timestep):
        """
        Synchronous update; returns the number of bits that differ from
        pattern timestep+1 summed over all sequences.
        """
        current = self.s_history[-1]

        # Symmetric term plus lambd-weighted asymmetric term
        sym_field = batched_exp_sym_update(self.sequences, current)
        asym_field = batched_exp_asym_update(self.sequences, current)
        updated = jnp.sign(sym_field + self.lambd * asym_field)

        # Overlap of the new state with every stored pattern, per sequence
        self.alignment_history[timestep,:,:] = jnp.einsum('ijk,ik->ij', self.sequences, updated) / (self.N)
        self.s_history = np.vstack((self.s_history, updated[None]))

        # Each differing +/-1 entry contributes |difference| = 2
        target = self.sequences[:,timestep+1]
        return jnp.sum(jnp.absolute(updated - target)) / 2

# Transition Simulations:

In [5]:
# Polynomial Networks

## Polynomial DenseNet
@jit
def compute_bitflips_PDN(state_pair, patterns, dA):
    """Count wrong bits after one Polynomial DenseNet update.

    Args:
        state_pair: Stacked pair; index 0 is the starting state and index 1
            the state the update should produce.
        patterns: (P, N) stored patterns.
        dA: Polynomial degree of the update rule.

    Returns:
        Number of entries where the predicted state differs from the target.
    """
    current = state_pair[0]
    target = state_pair[1]

    # Predict the next state with the asymmetric polynomial rule
    predicted = jnp.sign(poly_asym_update(patterns, current, dA))

    # Each differing +/-1 entry contributes |difference| = 2
    return jnp.sum(jnp.abs(predicted - target))/2

# Evaluate every (state, target) pair against the same pattern set
batched_compute_bitflips_PDN = vmap(compute_bitflips_PDN, in_axes=(0, None, None))

def simulate_PDN_transitions(key, N, P, dA):
    """Find a pattern count at which every single-step transition is error-free.

    Starting from P random patterns, checks each transition
    pattern[i] -> pattern[i+1] (cyclic). On any bit flip, P is reduced by 1%
    and fresh patterns are drawn.

    Args:
        key: JAX PRNG key; a fresh subkey is drawn for every attempt.
        N: Number of neurons.
        P: Initial number of patterns to try.
        dA: Polynomial degree of the update rule.

    Returns:
        The first P for which no transition had any bit flip.
    """
    while True:
        # Split into a carry key and a subkey so the key used to draw patterns
        # is never split again. The previous `key, _ = split(key)` reused it
        # and put attempts on the same key chain as the caller's trials.
        key, subkey = random.split(key)
        patterns = generate_patterns(subkey, N, P)

        # Pair Current State with Next State
        state_pairs = jnp.stack((patterns, jnp.roll(patterns,-1, axis=0)), axis = 1)

        # Total number of bit flips (errors) over all transitions
        total_bitflips = jnp.sum(batched_compute_bitflips_PDN(state_pairs, patterns, dA))
        if total_bitflips == 0:
            return P
        P = int(0.99*P)

## Polynomial MixedNet
@jit
def compute_bitflips_PMN(state_pair, patterns, dS, dA, lambd):
    """Count wrong bits after one Polynomial MixedNet update.

    Args:
        state_pair: Stacked pair; index 0 is the starting state and index 1
            the state the update should produce.
        patterns: (P, N) stored patterns.
        dS: Degree of the symmetric term.
        dA: Degree of the asymmetric term.
        lambd: Weight of the asymmetric term.

    Returns:
        Number of entries where the predicted state differs from the target.
    """
    current = state_pair[0]
    target = state_pair[1]

    # Symmetric term plus lambd-weighted asymmetric term
    sym_field = poly_sym_update(patterns, current, dS)
    asym_field = poly_asym_update(patterns, current, dA)
    predicted = jnp.sign(sym_field + lambd * asym_field)

    # Each differing +/-1 entry contributes |difference| = 2
    return jnp.sum(jnp.abs(predicted - target))/2

# Evaluate every (state, target) pair against the same pattern set
batched_compute_bitflips_PMN = vmap(compute_bitflips_PMN, in_axes=(0, None, None, None, None))

def simulate_PMN_transitions(key, N, P, dA, dS, lambd):
    """Find a pattern count at which every single-step transition is error-free.

    Starting from P random patterns, checks each transition
    pattern[i] -> pattern[i+1] (cyclic) under the Polynomial MixedNet rule.
    On any bit flip, P is reduced by 1% and fresh patterns are drawn.

    Note the parameter order: dA comes before dS.

    Args:
        key: JAX PRNG key; a fresh subkey is drawn for every attempt.
        N: Number of neurons.
        P: Initial number of patterns to try.
        dA: Degree of the asymmetric term.
        dS: Degree of the symmetric term.
        lambd: Weight of the asymmetric term.

    Returns:
        The first P for which no transition had any bit flip.
    """
    while True:
        # Split into a carry key and a subkey so the key used to draw patterns
        # is never split again. The previous `key, _ = split(key)` reused it
        # and put attempts on the same key chain as the caller's trials.
        key, subkey = random.split(key)
        patterns = generate_patterns(subkey, N, P)

        # Pair Current State with Next State
        state_pairs = jnp.stack((patterns, jnp.roll(patterns,-1, axis=0)), axis = 1)

        # Total number of bit flips (errors) over all transitions
        total_bitflips = jnp.sum(batched_compute_bitflips_PMN(state_pairs, patterns, dS, dA, lambd))
        if total_bitflips == 0:
            return P
        P = int(0.99*P)

## Exponential DenseNet
@jit
def compute_bitflips_EDN(state_pair, patterns):
    """Count wrong bits after one Exponential DenseNet update.

    Args:
        state_pair: Stacked pair; index 0 is the starting state and index 1
            the state the update should produce.
        patterns: (P, N) stored patterns.

    Returns:
        Number of entries where the predicted state differs from the target.
    """
    current = state_pair[0]
    target = state_pair[1]

    # Predict the next state with the asymmetric exponential rule
    predicted = jnp.sign(exp_asym_update(patterns, current))

    # Each differing +/-1 entry contributes |difference| = 2
    return jnp.sum(jnp.abs(predicted - target))/2

# Evaluate every (state, target) pair against the same pattern set
batched_compute_bitflips_EDN = vmap(compute_bitflips_EDN, in_axes=(0, None))

def simulate_EDN_transitions(key, N, P):
    """Find a pattern count at which every single-step transition is error-free.

    Starting from P random patterns, checks each transition
    pattern[i] -> pattern[i+1] (cyclic) under the Exponential DenseNet rule.
    On any bit flip, P is reduced by 1% and fresh patterns are drawn.

    Args:
        key: JAX PRNG key; a fresh subkey is drawn for every attempt.
        N: Number of neurons; N == 1 short-circuits to 0.
        P: Initial number of patterns to try.

    Returns:
        0 when N == 1, otherwise the first P for which no transition had any
        bit flip.
    """
    if N == 1:
        return 0

    while True:
        # Split into a carry key and a subkey so the key used to draw patterns
        # is never split again. The previous `key, _ = split(key)` reused it
        # and put attempts on the same key chain as the caller's trials.
        key, subkey = random.split(key)
        patterns = generate_patterns(subkey, N, P)

        # Pair Current State with Next State
        state_pairs = jnp.stack((patterns, jnp.roll(patterns,-1, axis=0)), axis = 1)

        # Total number of bit flips (errors) over all transitions
        total_bitflips = jnp.sum(batched_compute_bitflips_EDN(state_pairs, patterns))
        if total_bitflips == 0:
            return P
        P = int(0.99*P)

## Exponential MixedNet
@jit
def compute_bitflips_EMN(state_pair, patterns, lambd):
    """Count wrong bits after one Exponential MixedNet update.

    Args:
        state_pair: Stacked pair; index 0 is the starting state and index 1
            the state the update should produce.
        patterns: (P, N) stored patterns.
        lambd: Weight of the asymmetric term.

    Returns:
        Number of entries where the predicted state differs from the target.
    """
    current = state_pair[0]
    target = state_pair[1]

    # Symmetric term plus lambd-weighted asymmetric term
    sym_field = exp_sym_update(patterns, current)
    asym_field = exp_asym_update(patterns, current)
    predicted = jnp.sign(sym_field + lambd * asym_field)

    # Each differing +/-1 entry contributes |difference| = 2
    return jnp.sum(jnp.abs(predicted - target))/2

# Evaluate every (state, target) pair against the same pattern set
batched_compute_bitflips_EMN = vmap(compute_bitflips_EMN, in_axes=(0, None, None))

def simulate_EMN_transitions(key, N, P, lambd):
    """Find a pattern count at which every single-step transition is error-free.

    Starting from P random patterns, checks each transition
    pattern[i] -> pattern[i+1] (cyclic) under the Exponential MixedNet rule.
    On any bit flip, P is reduced by 1% and fresh patterns are drawn.

    Args:
        key: JAX PRNG key; a fresh subkey is drawn for every attempt.
        N: Number of neurons; N == 1 short-circuits to 0.
        P: Initial number of patterns to try.
        lambd: Weight of the asymmetric term.

    Returns:
        0 when N == 1, otherwise the first P for which no transition had any
        bit flip.
    """
    if N == 1:
        return 0

    while True:
        # Split into a carry key and a subkey so the key used to draw patterns
        # is never split again. The previous `key, _ = split(key)` reused it
        # and put attempts on the same key chain as the caller's trials.
        key, subkey = random.split(key)
        patterns = generate_patterns(subkey, N, P)

        # Pair Current State with Next State
        state_pairs = jnp.stack((patterns, jnp.roll(patterns,-1, axis=0)), axis = 1)

        # Total number of bit flips (errors) over all transitions
        total_bitflips = jnp.sum(batched_compute_bitflips_EMN(state_pairs, patterns, lambd))
        if total_bitflips == 0:
            return P
        P = int(0.99*P)

In [None]:
# Polynomial DenseNet: empirical transition capacity for each degree dA and size N
lambd = 2.5
num_trials = 20
ds = np.linspace(1,4,4).astype(int)
Ns = np.linspace(10,100,10).astype(int)
PDN_transition_capacity = np.zeros((len(ds), len(Ns), num_trials))

for i_dA, dA in enumerate(ds):
    for i_N, N in enumerate(Ns):
        display(f'PDN: dA = {dA}, N = {N}, current P = {PDN_transition_capacity[i_dA, i_N]}')
        # Start the downward search at twice the theoretical capacity
        P = round(theory_PDN_trans(N, dA) * 2)
        
        for T in range(num_trials):
            if T > 4:
                # After five trials, start from 1.25x the best capacity seen so far
                P = int(np.max(PDN_transition_capacity[i_dA, i_N]) * 1.25)
            # Give each trial its own subkey. Passing the carried key made the
            # simulator and this loop advance the same split chain, so
            # consecutive trials replayed each other's keys.
            key, subkey = random.split(key)
            PDN_transition_capacity[i_dA, i_N, T] = simulate_PDN_transitions(subkey, N, P, dA)

            display(f'PDN: dA = {dA}, N = {N}, T = {T}, final P = {PDN_transition_capacity[i_dA, i_N, T]}')
            # Checkpoint after every trial
            np.save('final_logs/PDN_transition_capacity', PDN_transition_capacity)            
        
np.save('final_logs/PDN_transition_capacity', PDN_transition_capacity)

# Polynomial MixedNet: empirical transition capacity for each (dS, dA, N)
PMN_transition_capacity = np.zeros((len(ds), len(ds), len(Ns), num_trials))

for i_dS, dS in enumerate(ds):
    for i_dA, dA in enumerate(ds):
        for i_N, N in enumerate(Ns):
            display(f'PMN: dS = {dS}, dA = {dA}, N = {N}, current P = {PMN_transition_capacity[i_dS, i_dA, i_N]}')
            # Start the downward search at twice the theoretical capacity
            P = round(theory_PMN_trans(N, dA, dS, lambd) * 2)

            for T in range(num_trials):
                if T > 4:
                    # After five trials, start from 1.25x the best capacity seen so far
                    P = int(np.max(PMN_transition_capacity[i_dS, i_dA, i_N]) * 1.25)
                # Give each trial its own subkey so trials do not share a key chain
                key, subkey = random.split(key)
                # simulate_PMN_transitions is declared as (key, N, P, dA, dS, lambd).
                # The previous positional call passed (dS, dA) and so stored the
                # result for the swapped degrees; keywords make the binding explicit.
                PMN_transition_capacity[i_dS, i_dA, i_N, T] = simulate_PMN_transitions(
                    subkey, N, P, dA=dA, dS=dS, lambd=lambd)

                display(f'PMN: dA = {dA}, N = {N}, T = {T}, final P = {PMN_transition_capacity[i_dS, i_dA, i_N, T]}')
                # Checkpoint after every trial
                np.save('final_logs/PMN_transition_capacity', PMN_transition_capacity)            
        
np.save('final_logs/PMN_transition_capacity', PMN_transition_capacity)

# Exponential DenseNet: empirical transition capacity for each size N
lambd = 2.5
num_trials  = 20
Ns = np.linspace(1,25,25).astype(int)
EDN_transition_capacity = np.zeros((len(Ns), num_trials))

for i_N, N in enumerate(Ns):
    for T in range(num_trials):
        # Start the downward search at twice the theoretical capacity
        P = round(theory_EDN_trans(N) * 2)
        if T > 4:
            # After five trials, start from 1.25x the best capacity seen so far
            P = int(jnp.max(EDN_transition_capacity[int(i_N)]) * 1.25)
        # Give each trial its own subkey so trials do not share a key chain
        key, subkey = random.split(key)
        # simulate_EDN_transitions takes (key, N, P); the extra `lambd`
        # argument passed previously raised a TypeError.
        EDN_transition_capacity[i_N, T] = simulate_EDN_transitions(subkey, N, P)
        # The result array is 2-D and this model has no degree, so the message
        # no longer references the stale dA / i_dA from the polynomial loops.
        display(f'EDN: N = {N}, T = {T}, final P = {EDN_transition_capacity[i_N, T]}')
        # Checkpoint after every trial
        np.save('final_logs/EDN_transition_capacity', EDN_transition_capacity)    

np.save('final_logs/EDN_transition_capacity', EDN_transition_capacity)

# Exponential MixedNet: empirical transition capacity for each size N
EMN_transition_capacity = np.zeros((len(Ns), num_trials))

for i_N, N in enumerate(Ns):
    for T in range(num_trials):
        # Start the downward search at twice the theoretical capacity
        P = round(theory_EMN_trans(N, lambd) * 2)
        if T > 4:
            # After five trials, start from 1.25x the best capacity seen so far
            P = int(jnp.max(EMN_transition_capacity[int(i_N)]) * 1.25)
        # Give each trial its own subkey so trials do not share a key chain
        key, subkey = random.split(key)
        # Store into the EMN array. The result used to be written into
        # EDN_transition_capacity, which left this array all zeros (so P
        # collapsed to 0 after trial 4) and overwrote the EDN results in memory.
        EMN_transition_capacity[i_N, T] = simulate_EMN_transitions(subkey, N, P, lambd)
        print(f'Number of Neurons = {N}, Initial Number of Patterns = {int(P)}, Final Number of Patterns = {EMN_transition_capacity[i_N, T]}, Trial = {T}')
        # Checkpoint after every trial
        np.save('final_logs/EMN_transition_capacity', EMN_transition_capacity)    

np.save('final_logs/EMN_transition_capacity', EMN_transition_capacity)

# Sequence Simulations:

In [6]:
# Simulate Polynomial Networks

# Simulate Polynomial DenseNet
def simulate_PDN_sequences(key, N, tau, dA, P, S):
    """Find a sequence length that a Polynomial DenseNet recalls without error.

    Draws S random sequences of P patterns, runs the network for P * tau
    steps, and on any recall error reduces P by 1% and retries with fresh
    sequences.

    Args:
        key: JAX PRNG key; a fresh subkey is drawn for every attempt.
        N: Number of neurons.
        tau: Steps per pattern (timesteps = P * tau).
        dA: Polynomial degree of the update rule.
        P: Initial sequence length to try.
        S: Number of sequences simulated in parallel.

    Returns:
        The first P for which all S sequences were recalled without bit flips.
    """
    while True:
        # Split into a carry key and a subkey so the key handed to the
        # generator is never split again here. The previous
        # `key, _ = split(key)` shared one key chain between retries, the
        # sequence generator and the caller.
        key, subkey = random.split(key)
        sequences = generate_sequences(subkey, N, P, S)

        model = PDN()
        model.initialize(sequences, tau, dA)
        success, _, _ = model.predict(P * tau)

        if success:
            return P
        P = int(0.99*P)

## Simulate Polynomial MixedNet
def simulate_PMN_sequences(key, N, lambd, tau, dS, dA, P, S=100):
    """Find a sequence length that a Polynomial MixedNet recalls without error.

    Draws S random sequences of P patterns, runs the network for P * tau
    steps, and on any recall error reduces P by 1% and retries with fresh
    sequences.

    Args:
        key: JAX PRNG key; a fresh subkey is drawn for every attempt.
        N: Number of neurons.
        lambd: Weight of the asymmetric term.
        tau: Steps per pattern (timesteps = P * tau).
        dS: Degree of the symmetric term.
        dA: Degree of the asymmetric term.
        P: Initial sequence length to try.
        S: Number of sequences simulated in parallel. This used to be read
            from an undefined name (a notebook-level global); it is now an
            explicit parameter whose default matches generate_sequences.

    Returns:
        The first P for which all S sequences were recalled without bit flips.
    """
    while True:
        # Split into a carry key and a subkey so the key handed to the
        # generator is never split again here. The previous
        # `key, _ = split(key)` shared one key chain between retries, the
        # sequence generator and the caller.
        key, subkey = random.split(key)
        sequences = generate_sequences(subkey, N, P, S)

        model = PMN()
        model.initialize(sequences, lambd, tau, dS, dA)
        success, _, _ = model.predict(P * tau)

        if success:
            return P
        P = int(0.99*P)

# Simulate Exponential Networks

## Simulate Exponential DenseNet
def simulate_EDN_sequences(key, N, tau, P, S):
    """Find a sequence length that an Exponential DenseNet recalls without error.

    Draws S random sequences of P patterns, runs the network for P * tau
    steps, and on any recall error reduces P by 1% and retries with fresh
    sequences.

    Args:
        key: JAX PRNG key; a fresh subkey is drawn for every attempt.
        N: Number of neurons; N == 1 short-circuits to 0.
        tau: Steps per pattern (timesteps = P * tau).
        P: Initial sequence length to try.
        S: Number of sequences simulated in parallel.

    Returns:
        0 when N == 1, otherwise the first P for which all S sequences were
        recalled without bit flips.
    """
    if N == 1:
        return 0

    while True:
        # Split into a carry key and a subkey so the key handed to the
        # generator is never split again here. The previous
        # `key, _ = split(key)` shared one key chain between retries, the
        # sequence generator and the caller.
        key, subkey = random.split(key)
        sequences = generate_sequences(subkey, N, P, S)

        model = EDN()
        model.initialize(sequences, tau)
        success, _, _ = model.predict(P * tau)

        if success:
            return P
        P = int(0.99*P)
        
## Simulate Exponential MixedNet
def simulate_EMN_sequences(key, N, lambd, tau, P, S):
    """Find a sequence length that an Exponential MixedNet recalls without error.

    Draws S random sequences of P patterns, runs the network for P * tau
    steps, and on any recall error reduces P by 1% and retries with fresh
    sequences.

    Args:
        key: JAX PRNG key; a fresh subkey is drawn for every attempt.
        N: Number of neurons; N == 1 short-circuits to 0.
        lambd: Weight of the asymmetric term.
        tau: Steps per pattern (timesteps = P * tau).
        P: Initial sequence length to try.
        S: Number of sequences simulated in parallel.

    Returns:
        0 when N == 1, otherwise the first P for which all S sequences were
        recalled without bit flips.
    """
    if N == 1:
        return 0

    while True:
        # Split into a carry key and a subkey so the key handed to the
        # generator is never split again here. The previous
        # `key, _ = split(key)` shared one key chain between retries, the
        # sequence generator and the caller.
        key, subkey = random.split(key)
        sequences = generate_sequences(subkey, N, P, S)

        model = EMN()
        model.initialize(sequences, lambd, tau)
        success, _, _ = model.predict(P * tau)

        if success:
            return P
        P = int(0.99*P)

In [None]:
# Shared settings for all sequence-capacity experiments in this cell
num_trials = 20
lambd = 2.5
tau = 1
S = 100

## Polynomial DenseNet: empirical sequence capacity for each degree dA and size N
ds = np.linspace(1,4,4).astype('int')
Ns = np.linspace(10,100,10).astype('int')
PDN_sequence_capacity = np.zeros((len(ds), len(Ns), num_trials))

for i_dA, dA in enumerate(ds):
    for i_N, N in enumerate(Ns):
        display(f'PDN: dA = {dA}, N = {N}, current P = {PDN_sequence_capacity[i_dA, i_N]}')
        # Start the downward search at twice the theoretical capacity
        P = round(theory_PDN_seq(N, dA) * 2)
        
        for T in range(num_trials):
            if T > 4:
                # After five trials, start from 1.25x the best capacity seen so
                # far. Index by loop position (i_dA, i_N), matching the write
                # below; the value-derived indices dA-1 and int(N/10)-1 were
                # only right for these particular grids.
                P = int(np.max(PDN_sequence_capacity[i_dA, i_N]) * 1.25)
            # Give each trial its own subkey so trials do not share a key chain
            key, subkey = random.split(key)
            PDN_sequence_capacity[i_dA, i_N, T] = simulate_PDN_sequences(subkey, N, tau, dA, P, S)

            clear_output(wait=True)
            display(f'PDN: dA = {dA}, N = {N}, T = {T}, current P = {PDN_sequence_capacity[i_dA, i_N, T]}')
            # Checkpoint after every trial
            np.save('final_logs/PDN_sequence_capacity', PDN_sequence_capacity)            
        
np.save('final_logs/PDN_sequence_capacity', PDN_sequence_capacity)

## Polynomial MixedNet: empirical sequence capacity for each (dS, dA, N)
ds = np.linspace(1,4,4).astype('int')
Ns = np.linspace(10,100,10).astype('int')
PMN_sequence_capacity = np.zeros((len(ds), len(ds), len(Ns), num_trials))    

for i_dS, dS in enumerate(ds):
    for i_dA, dA in enumerate(ds):
        for i_N, N in enumerate(Ns):
            display(f'PMN: dS = {dS}, dA = {dA}, N = {N}, current P = {PMN_sequence_capacity[i_dS, i_dA, i_N]}')

            # Start the downward search at twice the theoretical capacity
            P = round(theory_PMN_seq(N, dA, dS, lambd) * 2)

            for T in range(num_trials):
                if T > 4:
                    # After five trials, start from 1.25x the best capacity
                    # seen so far. Index by loop position, matching the write
                    # below; dS-1 / dA-1 were only right while ds == [1, 2, 3, 4].
                    P = int(np.max(PMN_sequence_capacity[i_dS, i_dA, i_N]) * 1.25)
                # Give each trial its own subkey so trials do not share a key chain
                key, subkey = random.split(key)
                PMN_sequence_capacity[i_dS, i_dA, i_N, T] = simulate_PMN_sequences(subkey, N, lambd, tau, dS, dA, P)

                clear_output(wait=True)
                display(f'PMN: dS = {dS}, dA = {dA}, N = {N}, T = {T}, current P = {PMN_sequence_capacity[i_dS, i_dA, i_N, T]}')
                # Checkpoint after every trial
                np.save('final_logs/PMN_sequence_capacity', PMN_sequence_capacity)            
        
np.save('final_logs/PMN_sequence_capacity', PMN_sequence_capacity)

## Exponential DenseNet: empirical sequence capacity for each size N
Ns = np.linspace(1,25,25).astype('int')
EDN_sequence_capacity = np.zeros((len(Ns), num_trials))

for i_N, N in enumerate(Ns):
    display(f'EDN: N = {N}, current P = {EDN_sequence_capacity[i_N]}')
    # Start the downward search at twice the theoretical capacity
    P = round(theory_EDN_seq(N) * 2)
    
    for T in range(num_trials):
        if T > 4:
            # After five trials, start from 1.25x the best capacity seen so far.
            # Index by loop position i_N; N-1 was only right while Ns == 1..25.
            P = int(np.max(EDN_sequence_capacity[i_N]) * 1.25)
        # Give each trial its own subkey so trials do not share a key chain
        key, subkey = random.split(key)
        EDN_sequence_capacity[i_N, T] = simulate_EDN_sequences(subkey, N, tau, P, S)

        clear_output(wait=True)
        display(f'EDN: N = {N}, T = {T}, current P = {EDN_sequence_capacity[i_N, T]}')
        # Checkpoint after every trial
        np.save('final_logs/EDN_sequence_capacity', EDN_sequence_capacity)            
        
np.save('final_logs/EDN_sequence_capacity', EDN_sequence_capacity)

## Exponential MixedNet: empirical sequence capacity for each size N
Ns = np.linspace(1,25,25).astype('int')
EMN_sequence_capacity = np.zeros((len(Ns), num_trials))

for i_N, N in enumerate(Ns):
    display(f'EMN: N = {N}, current P = {EMN_sequence_capacity[i_N]}')
    # Start the downward search at twice the theoretical capacity
    P = round(theory_EMN_seq(N, lambd) * 2)

    for T in range(num_trials):
        if T > 4:
            # After five trials, start from 1.25x the best capacity seen so far.
            # Index by loop position i_N; N-1 was only right while Ns == 1..25.
            P = int(np.max(EMN_sequence_capacity[i_N]) * 1.25)
        # Give each trial its own subkey so trials do not share a key chain
        key, subkey = random.split(key)
        EMN_sequence_capacity[i_N, T] = simulate_EMN_sequences(subkey, N, lambd, tau, P, S)

        clear_output(wait=True)
        display(f'EMN: N = {N}, T = {T}, current P = {EMN_sequence_capacity[i_N, T]}')
        # Checkpoint after every trial
        np.save('final_logs/EMN_sequence_capacity', EMN_sequence_capacity)            
        
np.save('final_logs/EMN_sequence_capacity', EMN_sequence_capacity)

# Generalized Pseudoinverse Rule:

In [7]:
# Polynomial DenseNet with GPI
@jit
def poly_asym_GPI_update(patterns, PI, state, N, dA):
    """Asymmetric polynomial update field with a pseudoinverse-corrected overlap.

    Args:
        patterns: (P, N) stored patterns.
        PI: (P, P) matrix applied to the per-neuron overlaps before the
            nonlinearity (the callers pass a pseudoinverse of the pattern
            overlap matrix).
        state: (N,) current network state.
        N: Number of neurons, used for the 1/(N-1) normalization.
        dA: Exponent applied to the corrected overlaps.

    Returns:
        (N,) field: each pattern's successor (cyclic, via roll by -1) weighted
        by the corrected overlaps raised to the power dA.
    """
    # Take P from the argument itself; the previous body reshaped with a
    # notebook-level `P` that merely happened to match.
    num_patterns = patterns.shape[0]
    overlaps = ((patterns@state).reshape(num_patterns,1) - (patterns*state)) / (N-1)
    weights = jnp.power(PI @ overlaps, dA)
    # Column-wise sum over patterns: equal to the diagonal of the (N, N)
    # tensordot computed before, without building the full matrix.
    return jnp.einsum('ij,ij->j', jnp.roll(patterns,-1, axis=0), weights)

class GPI_PDN(object):
    """Polynomial DenseNet using the pseudoinverse-corrected update, for one sequence."""

    def initialize(self, patterns, tau, dA):
        """Store the (P, N) patterns and precompute the overlap pseudoinverse.

        Args:
            patterns: (P, N) stored patterns.
            tau: Number of copies of the last pattern placed at the start of
                the state history.
            dA: Polynomial degree of the update rule.
        """
        self.patterns = patterns
        self.N = patterns.shape[1]
        self.P =  patterns.shape[0]
        self.tau = tau
        self.dA = dA
        # Normalize by this model's own N; the previous code divided by a
        # notebook-level `N` that merely happened to match.
        O = jnp.tensordot(patterns.astype(float), patterns, axes = ((1, 1))) / self.N
        self.PI = jnp.linalg.pinv(O.astype(float), hermitian=True)

    def predict(self, num_timestep=1):
        """Roll the network forward, stopping at the first recall error.

        Returns:
            (success, state_history, alignment_history) where success is False
            as soon as any update produces a bit flip.
        """
        self.num_timestep = num_timestep
        
        # History starts with tau copies of the last pattern, then the first pattern
        self.s_history = np.vstack((np.tile(self.patterns[-1], (self.tau,1)), self.patterns[0]))
        self.alignment_history = np.zeros((self.P, self.num_timestep))
        
        for t in range(self.num_timestep - 1):            
            bitflips = self._run(t)
            if bitflips > 0:
                return False, np.array(self.s_history), np.array(self.alignment_history)

        return True, np.array(self.s_history), np.array(self.alignment_history)
        
    def _run(self, timestep):
        """
        Synchronous update; returns the number of bits that differ from
        pattern timestep+1.
        """
        t = timestep
        s = self.s_history[-1]
        
        # Update network 
        h = poly_asym_GPI_update(self.patterns, self.PI, s, self.N, self.dA)
        s = jnp.sign(h)

        # Compute Network Alignment and add to history
        alignment = np.tensordot(self.patterns, s, 1) / self.N
        self.alignment_history[:,t] = alignment
        self.s_history = np.vstack((self.s_history,s))        
        
        # Each differing +/-1 entry contributes |difference| = 2
        bitflips = jnp.sum(jnp.absolute(s - self.patterns[timestep+1])) / 2
        return bitflips

# Exponential DenseNet with GPI
@jit
def exp_asym_GPI_update(patterns, PI, state, N):
    """Asymmetric exponential update field with a pseudoinverse-corrected overlap.

    Args:
        patterns: (P, N) stored patterns.
        PI: (P, P) matrix applied to the shifted per-neuron overlaps before
            the exponential (the callers pass a pseudoinverse of the pattern
            overlap matrix).
        state: (N,) current network state.
        N: Number of neurons, used for the -(N-1) shift.

    Returns:
        (N,) field: each pattern's successor (cyclic, via roll by -1) weighted
        by the exponential of the corrected overlaps.
    """
    # Take P from the argument itself; the previous body reshaped with a
    # notebook-level `P` that merely happened to match.
    num_patterns = patterns.shape[0]
    shifted = (patterns@state).reshape(num_patterns,1) - (patterns*state) - (N-1)
    weights = jnp.exp(PI @ shifted)
    # Column-wise sum over patterns: equal to the diagonal of the (N, N)
    # tensordot computed before, without building the full matrix.
    return jnp.einsum('ij,ij->j', jnp.roll(patterns,-1, axis=0), weights)

class GPI_EDN(object):
    """Exponential DenseNet using the pseudoinverse-corrected update, for one sequence."""

    def initialize(self, patterns, tau):
        """Store the (P, N) patterns and precompute the overlap pseudoinverse.

        Args:
            patterns: (P, N) stored patterns.
            tau: Number of copies of the last pattern placed at the start of
                the state history.
        """
        self.patterns = patterns
        self.N = patterns.shape[1]
        self.P =  patterns.shape[0]
        self.tau = tau
        # Normalize by this model's own N; the previous code divided by a
        # notebook-level `N` that merely happened to match.
        O = jnp.tensordot(patterns.astype(float), patterns, axes = ((1, 1))) / self.N
        self.PI = jnp.linalg.pinv(O.astype(float), hermitian=True)

    def predict(self, num_timestep=1):
        """Roll the network forward, stopping at the first recall error.

        Returns:
            (success, state_history, alignment_history) where success is False
            as soon as any update produces a bit flip.
        """
        self.num_timestep = num_timestep
        
        # History starts with tau copies of the last pattern, then the first pattern
        self.s_history = np.vstack((np.tile(self.patterns[-1], (self.tau,1)), self.patterns[0]))
        self.alignment_history = np.zeros((self.P, self.num_timestep))
        
        for t in range(self.num_timestep - 1):            
            bitflips = self._run(t)
            if bitflips > 0:
                return False, np.array(self.s_history), np.array(self.alignment_history)

        return True, np.array(self.s_history), np.array(self.alignment_history)
  
        
    def _run(self, timestep):
        """
        Synchronous update; returns the number of bits that differ from
        pattern timestep+1.
        """
        t = timestep
        s = self.s_history[-1]
        
        # Update network 
        h_2 = exp_asym_GPI_update(self.patterns, self.PI, s, self.N)
        s = jnp.sign(h_2)

        # Compute Network Alignment and add to history
        alignment = np.tensordot(self.patterns, s, 1) / self.N
        self.alignment_history[:,t] = alignment
        self.s_history = np.vstack((self.s_history,s))        
        
        # Each differing +/-1 entry contributes |difference| = 2
        bitflips = jnp.sum(jnp.absolute(s - self.patterns[timestep+1])) / 2
        return bitflips

In [None]:
# GPI_PDN: sequence capacity on correlated patterns for each degree d and bias eps
N = 100
tau = 1
num_trials = 20
eps = np.linspace(0, 0.95, 20)
ds = np.linspace(1, 4, 4).astype(int)
GPI_PDN_capacity = np.zeros((len(ds), len(eps), num_trials))


for i_d, d in enumerate(ds):
    for i_e, e in enumerate(eps):

        display(f'GPI_PDN: N = {N}, d = {d}, eps = {e:.2f}, current P = {GPI_PDN_capacity[i_d, i_e]}')
        # Start the downward search at 20x the theoretical sequence capacity
        P = round(theory_PDN_seq(N, d) * 20)
        
        for T in range(num_trials):
            if T > 4:
                # After five trials, start from 1.25x the best capacity seen so far
                P = int(jnp.max(GPI_PDN_capacity[i_d, i_e])*1.25)
            
            while True:
                # Draw patterns from a dedicated subkey; the carried key is
                # only ever split, never used as a random source.
                key, subkey = random.split(key)
                timesteps = P * tau
                patterns = generate_correlated_patterns(subkey, e, N, P)
                            
                model = GPI_PDN()
                model.initialize(patterns, tau, d)
                success, system_states, system_alignments = model.predict(timesteps)

                if success:
                    # First axis of the alignment history is the pattern count P
                    GPI_PDN_capacity[i_d, i_e, T] = system_alignments.shape[0]
                    
                    clear_output(wait=True)
                    display(f'End: d = {d}, N = {N}, eps = {e:.2f}, T = {T}, final P = {GPI_PDN_capacity[i_d, i_e, T]}')
                    
                    # Save results
                    np.save('final_logs/GPI_PDN_sequence_capacity', GPI_PDN_capacity)            
                    break 

                # Recall failed: shrink the sequence by 1% and retry
                P = int(P*0.99)
                    
                    
# GPI_EDN: sequence capacity on correlated patterns for each bias eps
N = 20
tau = 1
num_trials = 20
eps = np.linspace(0, 0.95, 20)
GPI_EDN_capacity = np.zeros((len(eps), num_trials))



for i_e, e in enumerate(eps):
    display(f'GPI_EDN: N = {N}, eps = {e:.2f}, current P = {GPI_EDN_capacity[i_e]}')
    # Start the downward search at the theoretical sequence capacity
    P = round(theory_EDN_seq(N) * 1)

    for T in range(num_trials):
        if T > 4:
            # After five trials, start from 1.25x the best capacity seen so far
            P = int(jnp.max(GPI_EDN_capacity[i_e])*1.25)

        while True:
            # Draw patterns from a dedicated subkey; the carried key is only
            # ever split, never used as a random source.
            key, subkey = random.split(key)
            timesteps = P * tau
            patterns = generate_correlated_patterns(subkey, e, N, P)

            model = GPI_EDN()
            model.initialize(patterns, tau)
            success, system_states, system_alignments = model.predict(timesteps)

            if success:
                # First axis of the alignment history is the pattern count P
                GPI_EDN_capacity[i_e, T] = system_alignments.shape[0]

                clear_output(wait=True)
                # This model has no degree. The message used to print `d`, a
                # stale loop variable left over from the GPI_PDN loop (and a
                # NameError if this section runs on its own).
                display(f'End: N = {N}, eps = {e:.2f}, T = {T}, final P = {GPI_EDN_capacity[i_e, T]}')

                # Save results
                np.save('final_logs/GPI_EDN_sequence_capacity', GPI_EDN_capacity)            
                break 

            # Recall failed: shrink the sequence by 1% and retry
            P = int(P*0.99)

# Maximal Degree for Polynomial DenseNet

In [9]:
def max_d_trans(N, d):
    """Shrink d until the transition-capacity condition holds.

    Starting from d, repeatedly replaces d with int(0.99*d) while
    N**d / log(N) < 2 * (2d-1)!!, and returns the first d for which that
    inequality is false.

    Args:
        N: Number of neurons.
        d: Starting (upper) degree.
    """
    def too_large(degree):
        return (N**degree / np.log(N)) < (2 * factorial2(2*degree-1))

    while too_large(d):
        d = int(0.99*d)
    return d
        
def max_d_seq(N, d):
    """Return the largest degree not exceeding `d` with predicted sequence capacity >= 1.

    The test `N**d / log(N) < 2 * (d+1) * (2d-1)!!` is equivalent to
    `theory_PDN_seq(N, d) < 1`. Starting from `d`, the degree is lowered one
    step at a time until the test fails, and that degree is returned.

    Args:
        N: number of neurons.
        d: largest degree to consider (expected to be a positive integer).

    Returns:
        The first degree, searching downward from `d`, for which the predicted
        capacity is not below one, or 0 if no degree >= 1 qualifies (this
        happens for small N, where even d = 1 fails the test).
    """
    while d >= 1:
        capacity_below_one = (N**d / np.log(N)) < (2 * (d+1) * factorial2(2*d-1))
        if not capacity_below_one:
            return d
        # Step by exactly one: a multiplicative step (0.99*d) skips degrees once d > 100.
        d -= 1
    # Stop explicitly instead of evaluating the test at d = 0, where the outcome
    # depends on how factorial2 treats a negative argument.
    return 0

In [None]:
# For each network size N, simulate transitions at every polynomial degree from 1 up to
# (but excluding) twice the degree returned by max_d_trans(N, 100), with num_trials
# repetitions each, and store whatever simulate_PDN_transitions returns.
Ns = [10, 15, 20]

num_trials = 20
# Size the degree axis from the largest N rather than a hard-coded 20, so that editing Ns
# cannot overflow the array (assumes max_d_trans is non-decreasing in N).
ds = list(range(1, 2*max_d_trans(max(Ns),100)))
max_degree = np.zeros((len(Ns), len(ds), num_trials))

for i_N, N in enumerate(Ns):
    # Smaller N use a shorter degree range; their unused trailing slots stay at zero.
    ds = list(range(1, 2*max_d_trans(N,100)))
    for i_d, d in enumerate(ds):
        # Number of patterns: 2.5 times the theoretical transition capacity.
        P = round(theory_PDN_trans(N, d) * 2.5)
        for T in range(num_trials):
            key, _ = random.split(key)
            max_degree[i_N, i_d, T] = simulate_PDN_transitions(key, N, P, d)

            clear_output(wait=True)
            display(f'PAHN: N = {N}, d = {d}, T = {T}, current P = {max_degree[i_N, i_d, T]}')
    # Checkpoint after each N is finished.
    np.save('final_logs/max_degree_capacity', max_degree)

# MNIST Experiments

In [20]:
# Import Packages
import time
import numpy as np

from matplotlib import pyplot as plt
import matplotlib.cm as cm
import matplotlib.gridspec as gridspec
from tqdm import tqdm
import tqdm.notebook as tq
import scipy
from scipy import signal
from scipy.special import factorial2, erfinv, comb
from IPython.display import clear_output, display

import seaborn as sns
sns.set_style("white")
sns.set_context("poster")

import jax
import jax.numpy as jnp
from jax import random, jit, vmap, pmap

from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

key = random.PRNGKey(2023)

In [21]:
# Import MNIST dataset and convert it to Black and White for Binary Image Sequence Recall
# Seed NumPy's global random generator for this section.
np.random.seed(2023)

# Load the MNIST train/test splits (images and labels) through Keras.
(train_X, train_y), (test_X, test_y) = mnist.load_data()
# Number of distinct label values in the training labels (computed before one-hot encoding).
nclasses = np.unique(train_y).size
# Pixel values at or above this level become 1, everything below becomes 0.
threshold = 125

def shapex(X):
    """Binarize an image array against `threshold` and append a trailing axis of length 1.

    Args:
        X: NumPy array of pixel intensities.

    Returns:
        Array with the dtype of `X` and shape `X.shape + (1,)`, holding 1 where
        `X >= threshold` and 0 where `X < threshold`.
    """
    binary = (X >= threshold).astype(X.dtype)
    return binary.reshape(*binary.shape, 1)
# Binarize the images and one-hot encode the labels.
train_X = shapex(train_X)
test_X = shapex(test_X)
train_y = to_categorical(train_y)
test_y = to_categorical(test_y)

# Reshape to form a sequence of images in order 0 -> 9 and then repeating
num_samples = 2500  # images kept per digit class
# For every class, keep the first num_samples training images carrying that label.
numbers = [train_X[train_y[:, i] == 1][:num_samples] for i in range(nclasses)]
assert all(len(class_images) == num_samples for class_images in numbers), \
    'every class must provide at least num_samples training images'
number_array = np.array(numbers)

# Interleave the classes: position k of the sequence holds an image of class k % nclasses.
np_seq = np.empty((nclasses * num_samples, *number_array.shape[2:]), dtype=number_array.dtype)
for i in range(nclasses):
    np_seq[i::nclasses] = number_array[i]

# Flatten each 28x28 image and map pixel values {0, 1} to {-1, +1}.
images = jnp.where(np_seq.reshape(nclasses * num_samples, 28*28), 1, -1).astype('int16')
# Pair every image with its successor in the sequence (the last image wraps around to the first).
image_pairs = jnp.stack((images, jnp.roll(images,-1, axis=0)), axis = 1)

In [24]:
# Define Update Rules
@jit
def predict_image_poly(state_pair, patterns, d = 10):
    """Predict the next state from the current one with the polynomial update rule.

    Args:
        state_pair: indexable whose element 0 is the current state vector.
            Further elements (e.g. the true next state) are ignored here.
        patterns: 2-D array with one stored pattern per row; the row after a
            pattern (cyclically, via `jnp.roll`) is treated as its successor.
        d: exponent applied to the overlaps after dividing them by the
            number of columns of `patterns`.

    Returns:
        Element-wise sign of the successor patterns weighted by the powered overlaps.
    """
    current_state = state_pair[0]
    num_neurons = patterns.shape[1]

    # Normalized overlap of the current state with every stored pattern.
    overlaps = patterns @ current_state / num_neurons
    successors = jnp.roll(patterns, -1, axis=0)

    # Sum the successor patterns, each weighted by overlap**d.
    field = jnp.tensordot(successors, jnp.power(overlaps, d), axes=(0, 0))
    return jnp.sign(field)

# Map over the leading axis of the state pairs; patterns and d are shared across the batch.
batched_predict_image_poly = vmap(predict_image_poly, in_axes=(0, None, None))

@jit
def predict_image_exp(state_pair, patterns):
    """Predict the next state from the current one with the exponential update rule.

    Args:
        state_pair: indexable whose element 0 is the current state vector.
            Further elements (e.g. the true next state) are ignored here.
        patterns: 2-D array with one stored pattern per row; the row after a
            pattern (cyclically, via `jnp.roll`) is treated as its successor.

    Returns:
        Element-wise sign of the successor patterns weighted by
        exp(overlap - N), where N is the number of columns of `patterns`.
    """
    current_state = state_pair[0]
    num_neurons = patterns.shape[1]

    # For +/-1 valued inputs the overlap is at most N, so the exponent is never positive.
    weights = jnp.exp(patterns @ current_state - num_neurons)
    successors = jnp.roll(patterns, -1, axis=0)

    field = jnp.tensordot(successors, weights, axes=(0, 0))
    return jnp.sign(field)

# Map over the leading axis of the state pairs; patterns are shared across the batch.
batched_predict_image_exp = vmap(predict_image_exp, in_axes=(0, None))

In [None]:
# Figure: true successor images next to the images predicted by the polynomial rule
# (one column per degree in ds) and by the exponential rule, at the time steps in ts.
true_images = image_pairs[:,1,:]
ds = [1, 5, 25]
ts = [0, 101, 202, 303]


def _show_image(ax, image, title=None):
    """Draw one flattened 28x28 image on `ax` without ticks; set `title` if given."""
    ax.imshow(image.reshape(28,28), cmap='Greys')
    ax.tick_params(left = False, right = False , labelleft = False, labelbottom = False, bottom = False)
    if title is not None:
        ax.set_title(title, fontsize=40)


# One row per time step; columns are: truth, one per polynomial degree, exponential.
# The grid is derived from ts and ds so that editing either list keeps the figure consistent.
fig, axes = plt.subplots(nrows=len(ts), ncols=len(ds) + 2, figsize=(20, 18), squeeze=False)

for t_i, t in enumerate(ts):
    # Only the top row carries a column title.
    _show_image(axes[t_i, 0], true_images[t], 'True' if t_i == 0 else None)

poly_predicted_images = [batched_predict_image_poly(image_pairs, images, d) for d in ds]
for d_i, d in enumerate(ds):
    predicted_images = poly_predicted_images[d_i]
    print(f'Degree = {d}')
    # The degree-1 column is labelled 'SeqNet' instead of 'Poly: d = 1'.
    column_title = 'SeqNet' if d == 1 else f'Poly: $d = {d}$'
    for t_i, t in enumerate(ts):
        _show_image(axes[t_i, d_i+1], predicted_images[t], column_title if t_i == 0 else None)


print('Exp')
exp_predicted_images = batched_predict_image_exp(image_pairs, images)
for t_i, t in enumerate(ts):
    _show_image(axes[t_i, -1], exp_predicted_images[t], 'Exp' if t_i == 0 else None)

fig.suptitle('Image Sequence Recall', fontsize=35)
fig.tight_layout()
fig.savefig('plots/Image_Sequence.pdf', format = 'pdf', dpi=300)

# Excess Kurtosis Experiments

In [17]:
from scipy.stats import kurtosis

# Function for calculating the xtalk_all and all_stored values
def calculate_values(ind_n, ind_p, n_list, p_list, n_rep):
    """Sample `n_rep` crosstalk sums for the (n, p) pair selected by the two indices.

    With n = n_list[ind_n] and p = p_list[ind_p], each sample is the sum over
    p - 1 terms of a random +/-1 sign times exp(2*k - 2*(n - 1)), where k is
    drawn from Binomial(n - 1, 0.5). Draws come from NumPy's global random state.

    Returns:
        (xtalk, stored): `xtalk` has shape (n_rep,), and `stored` is the
        boolean array `xtalk > -1`.
    """
    other_bits = n_list[ind_n] - 1
    draw_shape = (n_rep, p_list[ind_p] - 1)

    # Keep this draw order (uniform first, then binomial) so seeded runs reproduce.
    signs = 2 * (np.random.rand(*draw_shape) > 0.5) - 1
    counts = np.random.binomial(other_bits, 0.5, draw_shape)

    terms = signs * np.exp(2 * counts - 2 * other_bits)
    xtalk = terms.sum(axis=1)
    return xtalk, xtalk > -1

In [None]:
# Sweep over network sizes n and pattern counts p, drawing n_rep crosstalk samples per pair.
n_list = np.arange(10, 21)
p_list = np.floor(10 ** np.arange(2, 4.25, 0.25)).astype(int)
n_rep = 50000

b = np.exp(2) / np.cosh(2)

grid_shape = (len(n_list), len(p_list))
xtalk_all = np.zeros(grid_shape + (n_rep,))
all_stored = np.zeros(grid_shape + (n_rep,))

# np.ndindex walks the (n, p) grid in the same row-major order as two nested range loops.
for ind_n, ind_p in np.ndindex(*grid_shape):
    samples, stored = calculate_values(ind_n, ind_p, n_list, p_list, n_rep)
    xtalk_all[ind_n, ind_p] = samples
    all_stored[ind_n, ind_p] = stored
    # Checkpoint the crosstalk samples after every (n, p) pair.
    np.save('final_logs/edn_crosstalk', xtalk_all)
np.save('final_logs/edn_crosstalk', xtalk_all)