In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('fivethirtyeight')
from copy import deepcopy

In [None]:
!pip install brian2
!pip install neurodynex3
from brian2 import *

## Hopfield network model

In this simplified Hopfield model, each neuron can only take the values +1 or -1. The network stores pixel patterns and attempts to retrieve them given a cue: a part of an original pattern, which is used to set the initial state $S_i(t=0)$ of every neuron in the network. The activation of all the neurons in the network then evolves as follows:

$$S_i(t+1) = sgn\left(\sum_j{w_{ij} S_j(t)}\right), $$

where the weight of every synaptic connection is calculated as,

$$w_{ij} = \frac{1}{N} \sum_{\mu}{p_i^\mu p_j^\mu}, $$

where in turn, $N$ is the number of neurons, and $p_i^\mu$ is the state of the $i^{th}$ neuron for encoding pattern $\mu$.

The code below is taken from this [link](https://neuronaldynamics-exercises.readthedocs.io/en/latest/exercises/hopfield-network.html). Follow the accompanying exercises to be better prepared for your assignment. Remember to install all needed libraries before trying to run it. 

In [None]:
%matplotlib inline
from neurodynex3.hopfield_network import network, pattern_tools, plot_tools

# Side length of the square patterns; the network below gets pattern_size**2 neurons.
pattern_size = 4

# create an instance of the class HopfieldNetwork
hopfield_net = network.HopfieldNetwork(nr_neurons= pattern_size**2)
# instantiate a pattern factory
factory = pattern_tools.PatternFactory(pattern_size, pattern_size)
# create a checkerboard pattern and add it to the pattern list
checkerboard = factory.create_checkerboard()
pattern_list = [checkerboard]

# add random patterns to the list
# (4 random patterns plus the checkerboard: 5 patterns are stored in total)
pattern_list.extend(factory.create_random_pattern_list(nr_patterns=4, on_probability=0.5))
plot_tools.plot_pattern_list(pattern_list)
# how similar are the random patterns and the checkerboard? Check the overlaps
# NOTE: overlap_matrix is only used by the plot call that is commented out below.
overlap_matrix = pattern_tools.compute_overlap_matrix(pattern_list)
#plot_tools.plot_overlap_matrix(overlap_matrix)

# let the hopfield network "learn" the patterns. Note: they are not stored
# explicitly but only network weights are updated !
hopfield_net.store_patterns(pattern_list)

# create a noisy version of a pattern and use that to initialize the network
# (the checkerboard with nr_of_flips=4 is the retrieval cue)
noisy_init_state = pattern_tools.flip_n(checkerboard, nr_of_flips=4)
hopfield_net.set_state_from_pattern(noisy_init_state)

# from this initial state, let the network dynamics evolve.
states = hopfield_net.run_with_monitoring(nr_steps=3)

# each network state is a vector. reshape it to the same shape used to create the patterns.
states_as_patterns = factory.reshape_patterns(states)
# plot the states of the network
# reference_idx=0 points at the checkerboard, the first entry of pattern_list
plot_tools.plot_state_sequence_and_overlap(states_as_patterns, pattern_list, reference_idx=0, suptitle="Network dynamics")


## Our own Hopfield network model

You will write your own code to build a Hopfield network. I am leaving sample expected results below as a guide. 

You are free to use your own functions; there is no need to follow this guide.  
Suggested functions:  
1. get_patterns()  
2. perturb_pattern()  
3. calculate_weights()  
4. network_evolution()  
5. run_network()

In [None]:
# Sample run of the hand-written Hopfield network on five random 4x4 patterns.
# The helpers used here (get_patterns, perturb_pattern, calculate_weights,
# network_evolution) are defined in later cells of this notebook; run those
# cells first.
plist = get_patterns(4, 5)

# Perturb a deep copy so the stored patterns themselves stay intact.
cue = deepcopy(plist)
S = perturb_pattern(cue[0], 3)

# Weights are built from the stored patterns and the pattern side length.
wghts = calculate_weights(plist, 4)

# Record the cue followed by three successive network updates; each update is
# computed from the previous state.
state_list = [S]
for i in range(3):
    S = network_evolution(S, wghts)
    state_list.append(S)

plot_tools.plot_pattern_list(plist)

# Make sure every recorded state has the 4x4 layout the plot helpers are given.
state_list = [state.reshape(4, 4) for state in state_list]

plot_tools.plot_pattern(state_list[3])
plot_tools.plot_state_sequence_and_overlap(state_list, plist, reference_idx=0, suptitle="Network dynamics")

## Assignment 1

Q1. Can you write your own Hopfield network model that works more or less like the one simulated above?

Q2. Run the model with different parameters to figure out how the model's capacity to retrieve the correct pattern in response to a cue deteriorates as a function of   
(a) the informativeness of the cue  
(b) the number of other patterns stored in the network  
(c) the size of the network  

Present your answers with plots and/or math.

You can use plot_tools as is

In [None]:
def get_patterns(size, num):
    """Generate random square patterns with entries drawn from {-1, +1}.

    Parameters
    ----------
    size : int
        Side length of each pattern; every pattern has shape (size, size).
    num : int
        Number of patterns to generate.

    Returns
    -------
    list of numpy.ndarray
        `num` arrays of shape (size, size) whose entries are -1 or 1,
        drawn with numpy's global random state.
    """
    n_pixels = size * size
    # Each pattern is drawn as a flat vector and then folded into a square.
    return [
        np.random.choice([-1, 1], size=n_pixels).reshape(size, size)
        for _ in range(num)
    ]

In [None]:
def perturb_pattern(x, n):
    """Blank out `n` distinct, randomly chosen entries of a pattern.

    The chosen entries are set to 0, so the result is a partial cue in which
    `n` pixels carry no information. The array is modified in place and the
    same object is returned.

    Parameters
    ----------
    x : numpy.ndarray
        Pattern to perturb (any shape). Modified in place.
    n : int
        Number of entries to blank out. Values larger than the number of
        entries in `x` blank out the whole pattern.

    Returns
    -------
    numpy.ndarray
        `x` itself, after perturbation.

    Raises
    ------
    ValueError
        If `n` is negative.
    """
    if n < 0:
        raise ValueError("n must be non-negative")

    total = x.size
    count = min(n, total)
    if count > 0:
        # Sample without replacement so exactly `count` different entries are hit.
        targets = np.random.choice(total, size=count, replace=False)
        # x.flat writes through to x even when x is not contiguous in memory
        # (x.ravel() would return a copy in that case and the write would be lost).
        x.flat[targets] = 0
    return x

In [None]:
def calculate_weights(plist, n, num_patterns=None):
    """Compute the Hopfield weight matrix w_ij = (1/N) * sum_mu p_i^mu p_j^mu.

    Parameters
    ----------
    plist : sequence of numpy.ndarray
        Stored patterns; each must contain n**2 entries (e.g. shape (n, n)).
    n : int
        Side length of the square patterns, so the network has N = n**2 neurons.
    num_patterns : int, optional
        Expected number of patterns. It is not needed for the computation and
        is accepted so the function can also be called as
        calculate_weights(plist, n, num_patterns); when given it must equal
        len(plist).

    Returns
    -------
    numpy.ndarray
        Float array of shape (n**2, n**2). The diagonal is kept as given by
        the formula above (it is not zeroed). An empty `plist` yields an
        all-zero matrix.

    Raises
    ------
    ValueError
        If `num_patterns` disagrees with len(plist), or a pattern does not
        contain n**2 entries.
    """
    n_neurons = n ** 2

    if num_patterns is not None and num_patterns != len(plist):
        raise ValueError(
            "num_patterns (%d) does not match the number of patterns given (%d)"
            % (num_patterns, len(plist))
        )

    if len(plist) == 0:
        return np.zeros((n_neurons, n_neurons))

    # One row per pattern, one column per neuron.
    patterns = np.stack([np.asarray(p).reshape(-1) for p in plist]).astype(float)
    if patterns.shape[1] != n_neurons:
        raise ValueError(
            "each pattern must have %d entries, got %d"
            % (n_neurons, patterns.shape[1])
        )

    # patterns.T @ patterns sums the outer products p p^T over all patterns.
    return patterns.T @ patterns / n_neurons

In [None]:
def network_evolution(S, w):
    """Advance the network by one synchronous update step.

    Every neuron is set to the sign of its input field sum_j w_ij * S_j.
    A field of exactly zero is resolved to +1 so that the returned state
    only contains the values -1 and +1.

    Parameters
    ----------
    S : numpy.ndarray
        Current state, in any shape holding N entries (e.g. (n, n) or (N,)).
        It is not modified.
    w : numpy.ndarray
        Weight matrix of shape (N, N).

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as `S` containing the next state.
    """
    state = np.asarray(S)
    # Input field of every neuron, computed as a single matrix-vector product.
    field = np.asarray(w, dtype=float) @ state.reshape(-1)
    # Strictly negative field -> -1; zero or positive field -> +1.
    next_state = np.where(field < 0, -1.0, 1.0)
    return next_state.reshape(state.shape)

In [None]:
# Single demo run of the hand-written network: five random 4x4 patterns,
# with 10 entries of the first pattern requested for perturbation.
plist = get_patterns(4, 5)

# Perturb a deep copy so the stored patterns themselves stay intact.
cue = deepcopy(plist)
S = perturb_pattern(cue[0], 10)

# The side length passed here must match the one used to generate the patterns.
wghts = calculate_weights(plist, 4)

# Record the cue followed by three successive updates. Each update starts from
# the previous state, so the sequence shows the network actually evolving.
state_list = [S]
for i in range(3):
    S = network_evolution(S, wghts)
    state_list.append(S)

plot_tools.plot_pattern_list(plist)
plot_tools.plot_pattern(state_list[0])
plot_tools.plot_state_sequence_and_overlap(state_list, plist, reference_idx=0, suptitle="Network dynamics")

In [None]:
def run_network(pattern_size,pattern_num,noise_flip,iter):
  """Store random patterns, perturb the first one and plot the retrieval run.

  Parameters
  ----------
  pattern_size : int
      Side length of the square patterns.
  pattern_num : int
      Number of random patterns to generate and store.
  noise_flip : int
      Perturbation count handed to perturb_pattern for the cue.
  iter : int
      Number of network update steps to run.

  Returns
  -------
  None
      The function only produces plots: the cue, the stored patterns, and the
      state sequence (plotted with reference_idx=0).
  """
  stored = get_patterns(pattern_size, pattern_num)

  # perturb_pattern works in place, so hand it a deep copy of the patterns.
  state = perturb_pattern(deepcopy(stored)[0], noise_flip)
  history = [state]
  plot_tools.plot_pattern(state)

  weights = calculate_weights(stored, pattern_size)

  # Each step starts from the state produced by the previous one.
  for _ in range(iter):
    state = network_evolution(state, weights)
    history.append(state)

  plot_tools.plot_pattern_list(stored)
  plot_tools.plot_state_sequence_and_overlap(history, stored, reference_idx=0, suptitle="Network dynamics")

In [None]:
# Larger experiment: 20 random 20x20 patterns, cue perturbed with
# noise_flip=170, followed by 4 update steps.
run_network(pattern_size=20, pattern_num=20, noise_flip=170, iter=4)