# Biased Hopfield networks

## Functions

In [None]:
%matplotlib inline
#%matplotlib widget
import sys
import numpy as np
import torch
import matplotlib.pyplot as pl
import pandas as pd
from scipy.stats import kurtosis, skew
from matplotlib.ticker import FormatStrFormatter
import math
from numpy import trapz
from tqdm import tqdm
from scipy.stats import norm
import matplotlib.ticker as mtick
from matplotlib.ticker import PercentFormatter


def learn_patterns(patterns, sparseness):
    """Build covariance-rule weight matrices from 0/1 patterns.

    ``patterns`` is indexed as (batch, node, pattern). Every entry is
    centred by the mean activity ``sparseness``, the outer products are
    summed over patterns, and the result is divided by the number of nodes.
    Returns a (batch, node, node) tensor; the diagonal is NOT zeroed here.
    """
    n_nodes = patterns.size(dim=1)
    centred = patterns - sparseness
    return centred @ centred.transpose(1, 2) / n_nodes

def activation_function(activity,threshold,stochastic,temperature):
    """Binarise a net-input tensor to ±1 states with an optional noisy threshold.

    A unit becomes +1 when ``activity - threshold - noise > 0`` and -1 when
    that quantity is <= 0. With ``stochastic == 'on'`` the per-unit noise is
    ``0.125 * logit(u) * temperature / 2`` with ``u ~ U[0, 1)``. Otherwise
    the noise is zero.

    Returns a new tensor with the shape, dtype and device of ``activity``.
    """
    output = activity.clone()
    if stochastic == 'on':
        # Draw on the CPU generator and then move, as before, so seeded runs
        # keep the same random stream. Only the target device is no longer
        # hard-coded to CUDA.
        temp = torch.rand(output.size()).to(output.device)
        threshold_stochastic = 0.125*torch.log(temp/(1-temp))*temperature/2
    else:
        threshold_stochastic = torch.zeros(output.size(), device=output.device)
    # Compute the drive once, before anything is overwritten. The second mask
    # used to be evaluated on the already-updated output, so a unit just set
    # to -1 was flipped back to +1 whenever threshold + noise < -1.
    drive = activity - threshold - threshold_stochastic
    output[drive <= 0] = -1
    output[drive > 0] = 1
    return output
  
def network_output(state, weights, threshold, sparseness_ratio, stochastic, temperature):
    """Perform one update of all nodes at once.

    The ±1 state is mapped to {0, 1} and shifted by the mean activity
    ``sparseness_ratio``. It is then multiplied by ``weights``, and the
    resulting net input is binarised by ``activation_function``.
    """
    centred_rates = state/2 + 1/2 - sparseness_ratio
    net_input = torch.matmul(weights, centred_rates)
    return activation_function(net_input, threshold, stochastic, temperature)

def patterns_distance(pattern1, pattern2):
    """Normalised Hamming distance between ±1 patterns, one value per row.

    For ±1 entries, ``(1 - overlap) / 2`` is the fraction of differing
    positions, where overlap is the mean of ``pattern1 * pattern2`` along
    dimension 1.
    """
    n_nodes = pattern1.size(dim=1)
    overlap = torch.sum(pattern1*pattern2, 1)/n_nodes
    return (1 - overlap)/2
    
#####
def set_up_the_model(iterations, N, nr_patterns, sparseness_ratio, device='cuda'):
    """Draw random sparse patterns and build the matching weight matrices.

    Parameters
    ----------
    iterations : int
        Number of independent networks (leading batch dimension).
    N : int
        Number of nodes per network.
    nr_patterns : int
        Number of stored patterns per network.
    sparseness_ratio : float
        Probability that a node is ON in a pattern. It is also the bias
        passed to ``learn_patterns``.
    device : str or torch.device, optional
        Device of the returned tensors (default ``'cuda'``, as before).

    Returns
    -------
    patterns : tensor (iterations, N, nr_patterns) with ±1 entries.
    W : tensor (iterations, N, N) with a zeroed diagonal.
    """
    # The former `W = torch.zeros((iterations, N, N)).cuda()` was dead code
    # (overwritten below) that cost iterations*N*N floats of device memory.
    # Random numbers are still drawn on the CPU generator and then moved, so
    # seeded runs reproduce the same patterns.
    uniform = torch.rand(iterations, N, nr_patterns).to(device)
    patterns = (uniform < sparseness_ratio).to(uniform.dtype)  # 1 = ON, 0 = OFF
    W = learn_patterns(patterns, sparseness_ratio)
    W = W*(1-torch.eye(N, N, device=device))  # no self-connections
    patterns = patterns*2-1  # {0, 1} -> {-1, +1}
    return patterns, W
          
def update_network(starting_state, W, threshold, update_steps, sparseness_ratio, stochastic='off', temperature=0):
    """Apply ``network_output`` ``update_steps`` times starting from ``starting_state``.

    The input tensor is not modified. The final state is returned with
    singleton dimensions squeezed out.
    """
    state = starting_state.clone()
    for _ in range(update_steps):
        state = network_output(state, W, threshold, sparseness_ratio, stochastic, temperature)
    final_state = state.squeeze()
    return final_state

#####
def disturb_a_pattern(pattern, nr_flips, how='flip'):
    """Corrupt the first ``nr_flips`` nodes (dimension 1) of ``pattern``.

    ``how='on'`` sets them to +1 and ``how='off'`` sets them to -1. Any
    other value (default ``'flip'``) inverts their sign.
    """
    if how == 'on':
        perturb = turn_nodes_on
    elif how == 'off':
        perturb = turn_nodes_off
    else:
        perturb = flip_nodes
    return perturb(pattern, nr_flips)
    
def flip_nodes(state, nr_flips):
    """Return a copy of ``state`` with the first ``nr_flips`` columns sign-inverted.

    ``state`` is indexed as (batch, node) and is not modified. A
    non-positive ``nr_flips`` returns an unchanged copy.
    """
    output = state.clone()
    if nr_flips > 0:
        output[:, 0:nr_flips] *= -1
    return output

def turn_nodes_on(state, nr_flips):
    """Return a copy of ``state`` with the first ``nr_flips`` columns set to +1.

    ``state`` is indexed as (batch, node) and is not modified. A
    non-positive ``nr_flips`` returns an unchanged copy.
    """
    output = state.clone()
    if nr_flips > 0:
        output[:, 0:nr_flips] = 1
    return output

def turn_nodes_off(state, nr_flips):
    """Return a copy of ``state`` with the first ``nr_flips`` columns set to -1.

    ``state`` is indexed as (batch, node) and is not modified. A
    non-positive ``nr_flips`` returns an unchanged copy.
    """
    output = state.clone()
    if nr_flips > 0:
        output[:, 0:nr_flips] = -1
    return output

#####
def dilute_connections(W, dilution_prob, random_matrix):
    """Return a copy of ``W`` with entries zeroed wherever ``random_matrix < dilution_prob``.

    ``random_matrix`` must match ``W``'s shape. In this notebook it holds
    uniform draws, so each connection is removed with probability
    ``dilution_prob``. ``W`` itself is not modified.
    """
    removed = random_matrix < dilution_prob
    return W.masked_fill(removed, 0)

def dilute_nodes_p(patterns, dilution_prob):
    """Drop the first ``int(N * dilution_prob)`` nodes (dimension 1) from ``patterns``.

    N is ``patterns.size(1)``. The result is a new tensor on the same
    device as ``patterns``; the input is left untouched.
    """
    N = patterns.size(dim=1)
    # Build the index on the input's device instead of hard-coding .cuda().
    # index_select already returns a copy, so no clone is needed.
    keep_index = torch.arange(int(N*dilution_prob), N, device=patterns.device)
    return torch.index_select(patterns, 1, keep_index)
    
def dilute_nodes_W(W, dilution_prob):
    """Drop the rows and columns of the first ``int(N * dilution_prob)`` nodes from ``W``.

    ``W`` is indexed as (batch, node, node) and N is ``W.size(1)``. The
    result is a new tensor on ``W``'s device.
    """
    # N used to be read from the *global* `patterns`, which only worked
    # because the caller's patterns happened to have N nodes too.
    N = W.size(dim=1)
    keep_index = torch.arange(int(N*dilution_prob), N, device=W.device)
    W_n = torch.index_select(W, 1, keep_index)
    W_n = torch.index_select(W_n, 2, keep_index)
    return W_n

#####
def calc_auc(y_matrix, x, y_max):
    """Normalised trapezoidal area under each row of ``y_matrix``.

    Each row is integrated over the sample points ``x`` and divided by
    ``x[-1] * y_max``. When ``x`` starts at 0, this is the fraction of the
    bounding box that lies under the curve.

    y_matrix : 2-D torch tensor, one curve per row.
    x        : 1-D torch tensor of sample points.
    Returns a 1-D float32 torch tensor with one value per row.
    """
    # NumPy 2.0 renamed trapz -> trapezoid and deprecated the old name;
    # fall back to trapz on older NumPy.
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    x_np = x.numpy()
    output = torch.zeros(len(y_matrix))
    for i in range(len(output)):
        output[i] = integrate(y_matrix[i, :].numpy(), x_np) / (x[-1]*y_max)
    return output

def calc_dropping_point(y_matrix, x, threshold):
    """Normalised position just before each curve first falls below ``threshold``.

    For every row of ``y_matrix``, find the first sample whose value is below
    ``threshold``. Return the x-value of the sample just before it (clamped
    to the first sample), divided by ``x[-1]``. A row that never falls below
    ``threshold`` returns ``x[-1] / x[-1]`` (i.e. 1). Previously such a row
    raised a ValueError from np.min on an empty array.
    """
    output = torch.zeros(len(y_matrix))
    last = len(x) - 1
    for i in range(len(output)):
        below = np.flatnonzero(y_matrix[i, :].numpy() < threshold)
        # Last sample before the first drop; the final sample if it never drops.
        idx = max(below[0] - 1, 0) if below.size else last
        output[i] = x[idx] / x[-1]
    return output

def log2_noninf(x):
    """log2(x) for positive x, 0 for x == 0 (so p*log2(p) terms vanish at p = 0).

    Any other input (negative or NaN) falls through and returns None.
    """
    if x == 0:
        return 0
    if x > 0:
        return np.log2(x)
    return None

def patterns_mutual_info(pattern1, pattern2):
    """Mutual information (in bits) between two ±1 NumPy patterns.

    Computed as H(pattern1) - H(pattern1 | pattern2). The conditional term
    sums, over the values -1 and +1 of pattern2 that actually occur, the
    frequency of that value times the entropy of pattern1 restricted to
    those positions.
    """
    n1 = len(pattern1)
    q_off = np.sum(pattern1==-1)/n1
    q_on = np.sum(pattern1==1)/n1
    entropy = -q_off * log2_noninf(q_off) - q_on * log2_noninf(q_on)

    n2 = len(pattern2)
    conditional_entropy = 0
    for value in (-1, 1):
        given = pattern2 == value
        weight = np.sum(given)/n2
        if not weight > 0:
            continue
        count = np.sum(given)
        r_off = np.sum(given & (pattern1==-1))/count
        r_on = np.sum(given & (pattern1==1))/count
        conditional_entropy = conditional_entropy - weight * (r_off*log2_noninf(r_off) + r_on*log2_noninf(r_on))
    return entropy - conditional_entropy

def patterns_mutual_info_normalized(pattern1, pattern2):
    """Mutual information between the patterns divided by that of pattern1 with itself.

    ``patterns_mutual_info(pattern1, pattern1)`` is the entropy of pattern1,
    so the result is 1 for an exact copy. It is undefined (0/0) when
    pattern1 is constant.
    """
    shared = patterns_mutual_info(pattern1, pattern2)
    self_information = patterns_mutual_info(pattern1, pattern1)
    return shared / self_information
        
def calc_f_score(pattern1, pattern2):
    """F1 score of ``pattern2`` as a prediction of the ON (+1) nodes of ``pattern1``.

    Both inputs are ±1 NumPy arrays. Returns 0 when there is no true
    positive, which also avoids a 0/0.
    """
    truth_on = pattern1 == 1
    guess_on = pattern2 == 1
    hits = np.sum(truth_on & guess_on)
    if hits == 0:
        return 0
    false_alarms = np.sum((pattern1 == -1) & guess_on)
    misses = np.sum(truth_on & (pattern2 == -1))
    precision = hits / (hits + false_alarms)
    recall = hits / (hits + misses)
    return 2*(precision*recall)/(precision+recall)

## 0. determining the tolerance

In [None]:
# Section 0 settings: network size, repetitions and the mean activity levels p.
# NOTE: torch.manual_seed does not seed np.random.
torch.manual_seed(seed=7)
N = 5000
iterations = 100
sparseness_ratio = np.array([0.01,0.02,0.05,0.1,0.15,0.2,0.3,0.4,0.5])
# Per-p threshold; the expression simplifies to p*(1-p)*(1-2p)/2.
threshold = (sparseness_ratio**3-sparseness_ratio**2 + sparseness_ratio-2*sparseness_ratio**2+sparseness_ratio**3)/2

In [None]:
# Tolerance calibration. For each mean activity level p, draw a random ±1
# pattern with (approximately) p*N ON nodes and flip its first n nodes for
# n = 0 .. N/5 - 1. Record the largest n whose normalised mutual information
# (resp. F-score) with the original is still above 0.88.
# This cell uses np.random, which torch.manual_seed does not seed; seed it
# explicitly so the calibration is reproducible.
np.random.seed(7)
mutual_info_temp = np.zeros(int(N/5))
f_score_temp = np.zeros(int(N/5))
mutual_info_quantile = np.zeros((len(sparseness_ratio),iterations))
f_score_quantile = np.zeros((len(sparseness_ratio),iterations))
for s in enumerate(sparseness_ratio):
    print(s[0])
    for i in range(iterations):
        pattern = np.random.rand(N)  
        pattern[pattern<np.quantile(pattern, s[1])] = 1 #this means setting the exact number of ON nodes
        pattern[pattern!=1] = 0
        pattern = pattern*2-1        
        for n in range(int(N/5)):
            disturbed_pattern = pattern.copy()
            disturbed_pattern[0:n] *= -1
            mutual_info_temp[n] = patterns_mutual_info_normalized(pattern, disturbed_pattern)
            f_score_temp[n] = calc_f_score(pattern, disturbed_pattern)
        # largest number of flips whose score is still above 0.88
        mutual_info_quantile[s[0],i] = np.max(np.where(mutual_info_temp>0.88))
        f_score_quantile[s[0],i] = np.max(np.where(f_score_temp>0.88))

In [None]:
pl.rcParams.update({'font.size': 11})
#print(np.round(np.mean(mutual_info_quantile,1)/N,4))
# Tolerance (fraction of flipped nodes) per mean activity level p, hard-coded.
# Presumably this is the rounded output of the commented-out print above.
tolerance = np.array([0.0025,0.0041,0.007,0.0101,0.0119,0.0134,0.015,0.0159,0.0162])

fig, axs = pl.subplots(1, 1, figsize=(7, 4), facecolor='w', edgecolor='k')
axs.plot(sparseness_ratio, tolerance, '-s', color='darkgreen')
axs.set_title("tolerance level based on mutual information",size=15)
axs.set_ylabel("tolerance",size=14)
axs.set_xlabel("mean activity level (p)",size=14)
axs.xaxis.set_major_formatter(mtick.PercentFormatter(1.0)) 
axs.set_ylim([0, 0.02])
axs.grid(linewidth = 0.6)
axs.yaxis.set_major_formatter(mtick.PercentFormatter(1.0)) 

# Label every point with its value. The offsets differ by index range,
# presumably hand-tuned for legibility.
for i in range(len(sparseness_ratio)):
    if i>6:        
        pl.annotate(tolerance[i],
                    (sparseness_ratio[i],tolerance[i]),
                    textcoords="offset points",
                    xytext=(-21,-16))
    elif i>5:        
        pl.annotate(tolerance[i],
                    (sparseness_ratio[i],tolerance[i]),
                    textcoords="offset points",
                    xytext=(-15,-16))        
    elif i>1:
        pl.annotate(tolerance[i],
            (sparseness_ratio[i],tolerance[i]),
            textcoords="offset points",
            xytext=(3.5,-12))
    else:
        pl.annotate(tolerance[i],
            (sparseness_ratio[i],tolerance[i]),
            textcoords="offset points",
            xytext=(6,-5))    
pl.show()    

## 1. dependence of error on threshold for different sparseness

In [None]:
# GENERAL SETTINGS (see methods)
N = 5000
update_steps = 50
iterations = 10
batches = 10
sparseness_ratio = np.array([0.01,0.02,0.05,0.1,0.15,0.2,0.3,0.4,0.5])
threshold = (sparseness_ratio**3-sparseness_ratio**2 + sparseness_ratio-2*sparseness_ratio**2+sparseness_ratio**3)/2
tolerance = np.array([0.0025,0.0041,0.007,0.0101,0.0119,0.0134,0.015,0.0159,0.0162])

In [None]:
torch.manual_seed(seed=7)

# Load fixed at 0.1 (nr_patterns = 0.1*N) for the threshold sweep.
nr_patterns = int(0.1*N)
# Threshold grid: five fixed values, 20 evenly spaced values in [-0.04, 0.1]
# and the per-p thresholds, sorted ascending.
thresholds = np.sort(np.append(np.append([-0.2, -0.05, 0.005, 0.11, 0.2], np.linspace(-0.04, 0.1, 20)), threshold))

In [None]:
# distances1[b, s, k, t]: normalised distance between the first stored pattern
# and the state reached after update_steps updates started from it, for
# batch b, mean activity sparseness_ratio[s], network k and threshold thresholds[t].
distances1 = torch.zeros((batches,len(sparseness_ratio),iterations,len(thresholds)))
for b in range(batches):    
    print(b)
    for s in enumerate(sparseness_ratio):
        # fresh patterns/weights per (batch, p), reused for every threshold
        patterns, W = set_up_the_model(iterations, N, nr_patterns, s[1])
        for th in enumerate(thresholds):
            distances1[b,s[0],:,th[0]] = patterns_distance(patterns[:,:,0], update_network(patterns[:,:,0].unsqueeze(-1), W, th[1], update_steps, s[1])) #only the first pattern

In [None]:
# Hit ratio: fraction of the batches*iterations runs whose final distance is
# within the tolerance for that p. average_distance is the mean distance.
hit_ratio = torch.zeros((len(sparseness_ratio),len(thresholds)))
average_distance = torch.zeros((len(sparseness_ratio),len(thresholds)))
for s in enumerate(sparseness_ratio):
    for th in enumerate(thresholds):
        hit_ratio[s[0],th[0]] = torch.sum(distances1[:,s[0],:,th[0]]<=tolerance[s[0]])/(iterations*batches)
        average_distance[s[0],th[0]] = torch.mean(distances1[:,s[0],:,th[0]])
#####        
# 3x3 grid, one panel per p (row-major). The red dashed line marks that p's threshold.
fig, axs = pl.subplots(3, 3, figsize=(15, 8), facecolor='w', edgecolor='k')
fig.subplots_adjust(hspace=0.45, wspace=0.2)
for i in range(axs.shape[0]):
    for j in range(axs.shape[1]):
        axs[i,j].plot(thresholds, hit_ratio[i*axs.shape[1]+j,:], color='darkgreen')
        axs[i,j].set_title('p='+"{:.0%}".format(np.round(sparseness_ratio[i*axs.shape[1]+j],3)),size=14)
        axs[i,j].set_ylim((-0.1, 1.1))
        axs[i,j].axvline(x=threshold[i*axs.shape[1]+j], color='r', linestyle='--', linewidth=1.2)
        axs[i,j].grid(linewidth = 0.6)
        if j==0: axs[i,j].set_ylabel('hit ratio', fontsize=14)
        if i==2: axs[i,j].set_xlabel('threshold', fontsize=14)

## 2. examining the load parameter for sparse networks

In [None]:
torch.manual_seed(seed=7)

# Loads (nr_patterns / N) from 0.05 to 2.1: 14 values up to 0.31, 12 up to 1
# and 11 up to 2.1, converted to pattern counts.
nr_patterns = (np.append(np.linspace(0.05, 0.31, 14), np.append(np.linspace(0.34, 1, 12), np.linspace(1.1, 2.1, 11)))*N).astype(int)

In [None]:
# distances2[b, s, k, a]: final distance for load nr_patterns[a]. It starts
# at 1, so (p, load) combinations skipped below count as failures.
distances2 = torch.ones((batches,len(sparseness_ratio),iterations,len(nr_patterns)))
for b in range(batches):
    print(b)
    for s in enumerate(sparseness_ratio):
        for n in enumerate(nr_patterns):
            if(n[1]<N or s[1]<=0.02): #this is merely to make the program faster; for sparsity>2%, the performance drops to 0 earlier than load=1.
                patterns, W = set_up_the_model(iterations, N, int(n[1]), s[1])
                distances2[b,s[0],:,n[0]] = patterns_distance(patterns[:,:,0], update_network(patterns[:,:,0].unsqueeze(-1), W, threshold[s[0]], update_steps, s[1])) #only the first pattern

In [None]:
# Per (p, load): hit ratio, mean distance and 90th-percentile distance over
# all batches*iterations runs.
hit_ratio = torch.zeros((len(sparseness_ratio),len(nr_patterns)))
average_distance = torch.zeros((len(sparseness_ratio),len(nr_patterns)))
quantile_distance = torch.zeros((len(sparseness_ratio),len(nr_patterns)))
for s in enumerate(sparseness_ratio):
    for n in enumerate(nr_patterns):
        hit_ratio[s[0],n[0]] = torch.sum(distances2[:,s[0],:,n[0]]<=tolerance[s[0]])/(iterations*batches)
        average_distance[s[0],n[0]] = torch.mean(distances2[:,s[0],:,n[0]])
        quantile_distance[s[0],n[0]] = np.quantile(distances2[:,s[0],:,n[0]], 0.9)
#####        

# One panel per p: hit ratio against the load parameter nr_patterns / N.
fig, axs = pl.subplots(3, 3, figsize=(15, 8), facecolor='w', edgecolor='k')
fig.subplots_adjust(hspace=0.45, wspace=0.2)
for i in range(axs.shape[0]):
    for j in range(axs.shape[1]):
        axs[i,j].fill_between(nr_patterns/N, hit_ratio[i*axs.shape[1]+j,:], color='seagreen')
        axs[i,j].plot(nr_patterns/N, hit_ratio[i*axs.shape[1]+j,:], color='darkgreen')
        axs[i,j].set_title('p='+"{:.0%}".format(np.round(sparseness_ratio[i*axs.shape[1]+j],3)),size=14)
        axs[i,j].set_ylim((-0.1, 1.1))
        axs[i,j].grid(linewidth = 0.6)
        if j==0: axs[i,j].set_ylabel('hit ratio', fontsize=14)
        if i==2: axs[i,j].set_xlabel('load parameter', fontsize=14)

## 3. investigating and showing the discontinuous behaviour

In [None]:
# 5%: histograms of the final distance over all batches*iterations runs for
# p = sparseness_ratio[2] at the six consecutive loads nr_patterns[15:21].
b = np.append([0,0.005],np.linspace(0.01,0.5,99))
fig, axs = pl.subplots(2, 3, figsize=(16, 6), facecolor='w', edgecolor='k')
fig.subplots_adjust(hspace=0.37, wspace=0.17)

for i in range(axs.shape[0]):
    for j in range(axs.shape[1]):
        # Pass a flat 1-D array. Axes.hist reads a (1, K) array as K one-sample
        # datasets (one per column); once matplotlib converts torch tensors to
        # NumPy, that raises a ValueError for the single colour given.
        n, bins, patches = axs[i,j].hist(distances2[:,2,:,15+i*axs.shape[1]+j].reshape(-1).numpy(), bins=b, color='seagreen')
        axs[i,j].set_ylim([0,iterations*batches])
        axs[i,j].set_title(r'$\alpha$='+str(np.round(nr_patterns[15+i*axs.shape[1]+j]/N,2)),size=14)
        axs[i,j].xaxis.set_major_formatter(FormatStrFormatter('%.3f'))
        if j==0: axs[i,j].set_ylabel('frequency', fontsize=14)
        if i==1: axs[i,j].set_xlabel('distance', fontsize=14)

In [None]:
# 10%: histograms of the final distance over all batches*iterations runs for
# p = sparseness_ratio[3] at the six consecutive loads nr_patterns[10:16].
b = np.append([0,0.005],np.linspace(0.01,0.5,99))
fig, axs = pl.subplots(2, 3, figsize=(16, 6), facecolor='w', edgecolor='k')
fig.subplots_adjust(hspace=0.37, wspace=0.17)

for i in range(axs.shape[0]):
    for j in range(axs.shape[1]):
        # Pass a flat 1-D array. Axes.hist reads a (1, K) array as K one-sample
        # datasets (one per column); once matplotlib converts torch tensors to
        # NumPy, that raises a ValueError for the single colour given.
        n, bins, patches = axs[i,j].hist(distances2[:,3,:,10+i*axs.shape[1]+j].reshape(-1).numpy(), bins=b, color='seagreen')
        axs[i,j].set_ylim([0,iterations*batches])
        axs[i,j].set_title(r'$\alpha$='+str(np.round(nr_patterns[10+i*axs.shape[1]+j]/N,2)),size=14)
        axs[i,j].xaxis.set_major_formatter(FormatStrFormatter('%.3f'))
        if j==0: axs[i,j].set_ylabel('frequency', fontsize=14)
        if i==1: axs[i,j].set_xlabel('distance', fontsize=14)

In [None]:
# 50%: histograms of the final distance over all batches*iterations runs for
# p = sparseness_ratio[8] at the six consecutive loads nr_patterns[1:7].
b = np.append([0,0.005],np.linspace(0.01,0.5,99))
fig, axs = pl.subplots(2, 3, figsize=(16, 6), facecolor='w', edgecolor='k')
fig.subplots_adjust(hspace=0.37, wspace=0.17)

for i in range(axs.shape[0]):
    for j in range(axs.shape[1]):
        # Pass a flat 1-D array. Axes.hist reads a (1, K) array as K one-sample
        # datasets (one per column); once matplotlib converts torch tensors to
        # NumPy, that raises a ValueError for the single colour given.
        n, bins, patches = axs[i,j].hist(distances2[:,8,:,1+i*axs.shape[1]+j].reshape(-1).numpy(), bins=b, color='seagreen')
        axs[i,j].set_ylim([0,iterations*batches])
        axs[i,j].set_title(r'$\alpha$='+str(np.round(nr_patterns[1+i*axs.shape[1]+j]/N,2)),size=14)
        axs[i,j].xaxis.set_major_formatter(FormatStrFormatter('%.3f'))
        if j==0: axs[i,j].set_ylabel('frequency', fontsize=14)
        if i==1: axs[i,j].set_xlabel('distance', fontsize=14)

### the 90th percentile for the two examples

In [None]:
# Re-seed torch before the two fine-grained load sweeps.
torch.manual_seed(seed=7)

In [None]:
# p = 5%: 12 loads evenly spaced between 0.52 and 0.58.
# NOTE(review): the float16 linspace presumably makes the product with N
# float16 too, which would quantise nr_patterns (float16 spacing is 2 between
# 2048 and 4096). Distances are also stored as float16. Confirm this is intended.
nr_patterns = (np.linspace(0.52, 0.58, 12, dtype=np.float16)*N).astype(int)
sparseness_ratio = 0.05
threshold = (sparseness_ratio**3-sparseness_ratio**2 + sparseness_ratio-2*sparseness_ratio**2+sparseness_ratio**3)/2

# distances2_1[b, k, a]: final distance of network k in batch b at load nr_patterns[a]
distances2_1 = torch.zeros((batches,iterations,len(nr_patterns)) ,dtype=torch.float16)
for b in range(batches):    
    print(b)
    for n in enumerate(nr_patterns):
        patterns, W = set_up_the_model(iterations, N, int(n[1]), sparseness_ratio)
        distances2_1[b,:,n[0]] = patterns_distance(patterns[:,:,0], update_network(patterns[:,:,0].unsqueeze(-1), W, threshold, update_steps, sparseness_ratio)) #only the first pattern

In [None]:
# p = 50%: 12 loads evenly spaced between 0.10 and 0.16 (same float16 caveat
# as the p = 5% sweep).
nr_patterns = (np.linspace(0.1, 0.16, 12, dtype=np.float16)*N).astype(int)
sparseness_ratio = 0.5
threshold = (sparseness_ratio**3-sparseness_ratio**2 + sparseness_ratio-2*sparseness_ratio**2+sparseness_ratio**3)/2

# distances2_2[b, k, a]: final distance of network k in batch b at load nr_patterns[a]
distances2_2 = torch.zeros((batches,iterations,len(nr_patterns)) ,dtype=torch.float16)
for b in range(batches):    
    print(b)
    for n in enumerate(nr_patterns):
        patterns, W = set_up_the_model(iterations, N, int(n[1]), sparseness_ratio)
        distances2_2[b,:,n[0]] = patterns_distance(patterns[:,:,0], update_network(patterns[:,:,0].unsqueeze(-1), W, threshold, update_steps, sparseness_ratio)) #only the first pattern

In [None]:
pl.rcParams.update({'font.size': 11})
# 90th percentile of the final distance over all batches*iterations runs, per load.
percentile_distance1 = np.zeros(12)
percentile_distance2 = np.zeros(12)
for n in range(12):
    percentile_distance1[n] = np.quantile(distances2_1[:,:,n],0.9)
    percentile_distance2[n] = np.quantile(distances2_2[:,:,n],0.9)
#####
fig, axs = pl.subplots(1, 2, figsize=(12, 5), facecolor='w', edgecolor='k')
fig.subplots_adjust(wspace=0.3)

# Left panel, p = 5%: rebuild the load grid used for distances2_1.
nr_patterns = (np.linspace(0.52, 0.58, 12, dtype=np.float16)*N).astype(int)
axs[0].plot(nr_patterns/N, percentile_distance1, color='darkgreen', linewidth=2)
axs[0].set_title('p='+"{:.0%}".format(0.05),size=14)
#axs[0].set_xlim((0.1, 0.16))
axs[0].set_ylim((0, 0.4))
axs[0].set_xlabel('load parameter', fontsize=14)
axs[0].set_ylabel('90th percentile of distance', fontsize=14)
# Discrete first/second differences of the percentile curve. The load with the
# largest second difference is marked in red; the +1 aligns each second
# difference with its centre sample.
# NOTE(review): ties in the maximum give several values in max_curvature.
first_derivative = np.diff(percentile_distance1)
second_derivative = np.diff(np.diff(percentile_distance1))
max_curvature = nr_patterns[np.where(second_derivative==np.max(second_derivative))[0]+1]/N
max_curvature_dist = percentile_distance1[np.where(second_derivative==np.max(second_derivative))[0]+1]
print(max_curvature)
print(max_curvature_dist)
print(nr_patterns[np.where(first_derivative==np.max(first_derivative))[0]]/N)
axs[0].axvline(x=max_curvature, color='red', linestyle='--', linewidth=1)
axs[0].text(max_curvature-0.003,0.1,'$\\alpha$='+' '.join(map(str, np.round(max_curvature,4))),rotation=90, size=12)
axs[0].axhline(y=max_curvature_dist, color='red', linestyle='--', linewidth=1)
axs[0].text(0.525,max_curvature_dist+0.006,'distance='+' '.join(map(str, np.round(max_curvature_dist,4))), size=12)


# Right panel, p = 50%: same analysis on the load grid used for distances2_2.
nr_patterns = (np.linspace(0.1, 0.16, 12, dtype=np.float16)*N).astype(int)
axs[1].plot(nr_patterns/N, percentile_distance2, color='darkgreen', linewidth=2)
axs[1].set_title('p='+"{:.0%}".format(0.5),size=14)
axs[1].set_ylim((0, 0.4))
axs[1].set_xlabel('load parameter', fontsize=14)
axs[1].set_ylabel('90th percentile of distance', fontsize=14)
first_derivative = np.diff(percentile_distance2)
second_derivative = np.diff(np.diff(percentile_distance2))
max_curvature = nr_patterns[np.where(second_derivative==np.max(second_derivative))[0]+1]/N
max_curvature_dist = percentile_distance2[np.where(second_derivative==np.max(second_derivative))[0]+1]
print(max_curvature)
print(max_curvature_dist)
print(nr_patterns[np.where(first_derivative==np.max(first_derivative))[0]]/N)
axs[1].axvline(x=max_curvature, color='red', linestyle='--', linewidth=1)
axs[1].text(max_curvature-0.003,0.1,'$\\alpha$='+' '.join(map(str, np.round(max_curvature,4))),rotation=90, size=12)
axs[1].axhline(y=max_curvature_dist, color='red', linestyle='--', linewidth=1)
axs[1].text(0.105,max_curvature_dist+0.006,'distance='+' '.join(map(str, np.round(max_curvature_dist,4))), size=12)
pl.show()

## 4. (example) dependence of performance on weight dilution for different sparseness values

In [None]:
torch.manual_seed(seed=7)

# Load fixed at 0.12 for the weight-dilution example.
nr_patterns = int(0.12*N)

In [None]:
# Dilution probabilities 0, 0.05, ..., 1.
noise = torch.linspace(0, 1, 21)

# distances3[b, s, k, d]: final distance for dilution probability noise[d].
distances3 = torch.zeros((batches, len(sparseness_ratio), iterations, len(noise)))
for b in range(batches):
    print(b)
    for s in enumerate(sparseness_ratio):
        patterns, W = set_up_the_model(iterations, N, nr_patterns, s[1])
        # One uniform matrix per (batch, p), shared by every dilution level,
        # so a higher probability removes a superset of the connections.
        random_matrix = torch.rand(iterations,N,N).cuda()
        for n in enumerate(noise):
            W_n = dilute_connections(W, n[1], random_matrix)
            distances3[b,s[0],:,n[0]] = patterns_distance(patterns[:,:,0], update_network(patterns[:,:,0].unsqueeze(-1), W_n, threshold[s[0]], update_steps, s[1]))

In [None]:
pl.rcParams.update({'font.size': 11})
# Hit ratio and mean distance per (p, dilution probability).
hit_ratio = torch.zeros((len(sparseness_ratio),len(noise)))
average_distance = torch.zeros((len(sparseness_ratio),len(noise)))
for s in enumerate(sparseness_ratio):
    for n in enumerate(noise):
        hit_ratio[s[0],n[0]] = torch.sum(distances3[:,s[0],:,n[0]]<=tolerance[s[0]])/(iterations*batches)
        average_distance[s[0],n[0]] = torch.mean(distances3[:,s[0],:,n[0]])
#####        
# One panel per p: hit ratio against the dilution probability.
fig, axs = pl.subplots(3, 3, figsize=(15, 8), facecolor='w', edgecolor='k')
fig.subplots_adjust(hspace=0.45, wspace=0.2)
for i in range(axs.shape[0]):
    for j in range(axs.shape[1]):
        axs[i,j].fill_between(noise, hit_ratio[i*axs.shape[1]+j,:], color='seagreen')
        axs[i,j].plot(noise, hit_ratio[i*axs.shape[1]+j,:], color='darkgreen')
        axs[i,j].set_title('p='+"{:.0%}".format(np.round(sparseness_ratio[i*axs.shape[1]+j],3)),size=14)
        axs[i,j].set_ylim((-0.1, 1.1))
        axs[i,j].grid(linewidth = 0.6)
        if j==0: axs[i,j].set_ylabel('hit ratio', fontsize=14)
        if i==2: axs[i,j].set_xlabel('dilution probability', fontsize=14)
#####
# Normalised area under each hit-ratio curve, as a function of p.
fig, axs = pl.subplots(1, 1, figsize=(6, 4), facecolor='w', edgecolor='k')
axs.plot(sparseness_ratio, calc_auc(hit_ratio, noise, 1), '-s', color='darkgreen')
axs.set_xlabel("mean activity level (p)",size=14)
axs.set_ylabel("normalized AUC",size=14)
axs.set_title("Area Under Curve",size=15)
axs.set_ylim([0,0.6])
axs.grid(linewidth = 0.6)
axs.xaxis.set_major_formatter(mtick.PercentFormatter(1.0)) 
pl.show()

In [None]:
# 3x3 grid, one panel per p: mean final distance against the dilution probability.
fig, axs = pl.subplots(3, 3, figsize=(15, 8), facecolor='w', edgecolor='k')
fig.subplots_adjust(hspace=0.45, wspace=0.2)
for i in range(axs.shape[0]):
    for j in range(axs.shape[1]):
        axs[i,j].fill_between(noise, average_distance[i*axs.shape[1]+j,:], color="seagreen")
        axs[i,j].plot(noise, average_distance[i*axs.shape[1]+j,:], color="darkgreen")
        axs[i,j].set_title('p='+"{:.0%}".format(np.round(sparseness_ratio[i*axs.shape[1]+j],3)),size=14)
        axs[i,j].set_ylim((-0.1, 1.1))
        axs[i,j].grid(linewidth = 0.6)
        if j==0: axs[i,j].set_ylabel('average distance', fontsize=14)
        if i==2: axs[i,j].set_xlabel('dilution probability', fontsize=14)

## 5. robustness

In [None]:
torch.manual_seed(seed=7)

# Loads (nr_patterns / N) used in the robustness analysis.
nr_patterns = (np.array([0.05,0.1,0.15,0.2,0.3,0.4,0.5])*N).astype(int)

In [None]:
# Robustness sweep. For every batch, load and mean activity level p, build
# `iterations` networks and measure the final distance to the first stored
# pattern under seven perturbations:
#   distances_f / _f1 / _f2 : first nr_flips nodes flipped / set ON / set OFF
#   distances_n1            : additive N(0, 1) weight noise scaled by noise_w
#   distances_n2            : connection dilution with probability noise_w_dil
#   distances_n3            : removal of the first noise_n_dil fraction of nodes
#   distances_stoch         : stochastic threshold with temperature noise_stoc
nr_flips = torch.linspace(0, N, 21)
noise_w = torch.linspace(0, 0.002, 20)
noise_w_dil = torch.linspace(0, 1, 21)
noise_n_dil = torch.linspace(0, 0.9, 19)
noise_stoc = torch.linspace(0, 1, 21)

distances_f = torch.zeros((batches,len(nr_patterns), len(sparseness_ratio), len(nr_flips), iterations))
distances_f1 = torch.zeros((batches,len(nr_patterns), len(sparseness_ratio), len(nr_flips), iterations))
distances_f2 = torch.zeros((batches,len(nr_patterns), len(sparseness_ratio), len(nr_flips), iterations))
distances_n1 = torch.zeros((batches,len(nr_patterns), len(sparseness_ratio), len(noise_w), iterations))
distances_n2 = torch.zeros((batches,len(nr_patterns), len(sparseness_ratio), len(noise_w_dil), iterations))
distances_n3 = torch.zeros((batches,len(nr_patterns), len(sparseness_ratio), len(noise_n_dil), iterations))
distances_stoch = torch.zeros((batches,len(nr_patterns), len(sparseness_ratio), len(noise_stoc), iterations))

for b in range(batches):
    for pat in enumerate(nr_patterns):
        print(pat[0])
        for s in enumerate(sparseness_ratio):
            patterns, W = set_up_the_model(iterations, N, pat[1], s[1])
            # Random matrices shared by every level of the corresponding noise type.
            random_matrix_norm = torch.normal(0,1,(iterations,N,N)).cuda()
            random_matrix_unif = torch.rand(iterations,N,N).cuda()
            for f in enumerate(nr_flips):
                distances_f[b,pat[0],s[0],f[0],:] = patterns_distance(patterns[:,:,0], update_network(disturb_a_pattern(patterns[:,:,0], int(f[1])).unsqueeze(-1), W, threshold[s[0]], update_steps, s[1]))
                distances_f1[b,pat[0],s[0],f[0],:] = patterns_distance(patterns[:,:,0], update_network(disturb_a_pattern(patterns[:,:,0], int(f[1]), 'on').unsqueeze(-1), W, threshold[s[0]], update_steps, s[1]))
                distances_f2[b,pat[0],s[0],f[0],:] = patterns_distance(patterns[:,:,0], update_network(disturb_a_pattern(patterns[:,:,0], int(f[1]), 'off').unsqueeze(-1), W, threshold[s[0]], update_steps, s[1]))
            for n in enumerate(noise_w):
                W_n = W+random_matrix_norm*n[1]
                W_n = W_n*(1-torch.eye(N,N).cuda())  # keep self-connections at zero
                distances_n1[b,pat[0],s[0],n[0],:] = patterns_distance(patterns[:,:,0], update_network(patterns[:,:,0].unsqueeze(-1), W_n, threshold[s[0]], update_steps, s[1]))
            for n in enumerate(noise_w_dil):
                W_n = dilute_connections(W, n[1], random_matrix_unif)
                distances_n2[b,pat[0],s[0],n[0],:] = patterns_distance(patterns[:,:,0], update_network(patterns[:,:,0].unsqueeze(-1), W_n, threshold[s[0]], update_steps, s[1]))
            for n in enumerate(noise_n_dil):
                patterns_n    = dilute_nodes_p(patterns, n[1])
                W_n           = dilute_nodes_W(W, n[1])
                distances_n3[b,pat[0],s[0],n[0],:] = patterns_distance(patterns_n[:,:,0], update_network(patterns_n[:,:,0].unsqueeze(-1), W_n, threshold[s[0]], update_steps, s[1]))
                # Networks whose diluted pattern has no ON node left count as
                # failures. The index must be on distances_n3's device (CPU):
                # indexing a CPU tensor with CUDA indices raises a RuntimeError
                # in current PyTorch.
                index_pattern_n_0 = torch.where(torch.sum(patterns_n[:,:,0]==1,1)==0)[0].to(distances_n3.device)
                distances_n3[b,pat[0],s[0],n[0],index_pattern_n_0] = 1
            for n in enumerate(noise_stoc):
                distances_stoch[b,pat[0],s[0],n[0],:] = patterns_distance(patterns[:,:,0], update_network(patterns[:,:,0].unsqueeze(-1), W, threshold[s[0]], update_steps, s[1], 'on', n[1]))

In [None]:
# Summaries of the robustness sweep.
# Per (load, p, perturbation level):
#   hit_ratio_*        fraction of runs with distance <= tolerance[p]
#   average_distance_* mean final distance
# Per (load, p):
#   AUC_*  normalised area under the hit-ratio curve (calc_auc), rounded to 2 decimals
#   DP_*   normalised dropping point of the hit-ratio curve at DR_threshold
#          (calc_dropping_point)
hit_ratio_flip = torch.zeros((len(nr_patterns),len(sparseness_ratio),len(nr_flips)))
hit_ratio_on = torch.zeros((len(nr_patterns),len(sparseness_ratio),len(nr_flips)))
hit_ratio_off = torch.zeros((len(nr_patterns),len(sparseness_ratio),len(nr_flips)))
hit_ratio_w = torch.zeros((len(nr_patterns),len(sparseness_ratio),len(noise_w)))
hit_ratio_w_dil = torch.zeros((len(nr_patterns),len(sparseness_ratio),len(noise_w_dil)))
hit_ratio_n_dil = torch.zeros((len(nr_patterns),len(sparseness_ratio),len(noise_n_dil)))
hit_ratio_stoc = torch.zeros((len(nr_patterns),len(sparseness_ratio),len(noise_stoc)))

average_distance_flip = torch.zeros((len(nr_patterns),len(sparseness_ratio),len(nr_flips)))
average_distance_on = torch.zeros((len(nr_patterns),len(sparseness_ratio),len(nr_flips)))
average_distance_off = torch.zeros((len(nr_patterns),len(sparseness_ratio),len(nr_flips)))
average_distance_w = torch.zeros((len(nr_patterns),len(sparseness_ratio),len(noise_w)))
average_distance_w_dil = torch.zeros((len(nr_patterns),len(sparseness_ratio),len(noise_w_dil)))
average_distance_n_dil = torch.zeros((len(nr_patterns),len(sparseness_ratio),len(noise_n_dil)))
average_distance_stoc = torch.zeros((len(nr_patterns),len(sparseness_ratio),len(noise_stoc)))

AUC_flip = torch.zeros((len(nr_patterns),len(sparseness_ratio)))
AUC_on = torch.zeros((len(nr_patterns),len(sparseness_ratio)))
AUC_off = torch.zeros((len(nr_patterns),len(sparseness_ratio)))
AUC_w = torch.zeros((len(nr_patterns),len(sparseness_ratio)))
AUC_w_dil = torch.zeros((len(nr_patterns),len(sparseness_ratio)))
AUC_n_dil = torch.zeros((len(nr_patterns),len(sparseness_ratio)))
AUC_stoc = torch.zeros((len(nr_patterns),len(sparseness_ratio)))

DP_flip = torch.zeros((len(nr_patterns),len(sparseness_ratio)))
DP_on = torch.zeros((len(nr_patterns),len(sparseness_ratio)))
DP_off = torch.zeros((len(nr_patterns),len(sparseness_ratio)))
DP_w = torch.zeros((len(nr_patterns),len(sparseness_ratio)))
DP_w_dil = torch.zeros((len(nr_patterns),len(sparseness_ratio)))
DP_n_dil = torch.zeros((len(nr_patterns),len(sparseness_ratio)))
DP_stoc = torch.zeros((len(nr_patterns),len(sparseness_ratio)))
DR_threshold = 0.95  # hit-ratio level that defines the dropping point

for pat in enumerate(nr_patterns):
    for s in enumerate(sparseness_ratio):
        for f in enumerate(nr_flips):
            hit_ratio_flip[pat[0],s[0],f[0]] = torch.sum(distances_f[:,pat[0],s[0],f[0],:]<=tolerance[s[0]])/(iterations*batches)
            hit_ratio_on[pat[0],s[0],f[0]] = torch.sum(distances_f1[:,pat[0],s[0],f[0],:]<=tolerance[s[0]])/(iterations*batches)
            hit_ratio_off[pat[0],s[0],f[0]] = torch.sum(distances_f2[:,pat[0],s[0],f[0],:]<=tolerance[s[0]])/(iterations*batches)
            average_distance_flip[pat[0],s[0],f[0]] = torch.mean(distances_f[:,pat[0],s[0],f[0],:])
            average_distance_on[pat[0],s[0],f[0]] = torch.mean(distances_f1[:,pat[0],s[0],f[0],:])
            average_distance_off[pat[0],s[0],f[0]] = torch.mean(distances_f2[:,pat[0],s[0],f[0],:])            

        for n in enumerate(noise_w):
            hit_ratio_w[pat[0],s[0],n[0]] = torch.sum(distances_n1[:,pat[0],s[0],n[0],:]<=tolerance[s[0]])/(iterations*batches)
            average_distance_w[pat[0],s[0],n[0]] = torch.mean(distances_n1[:,pat[0],s[0],n[0],:])
        for n in enumerate(noise_w_dil):
            hit_ratio_w_dil[pat[0],s[0],n[0]] = torch.sum(distances_n2[:,pat[0],s[0],n[0],:]<=tolerance[s[0]])/(iterations*batches)
            average_distance_w_dil[pat[0],s[0],n[0]] = torch.mean(distances_n2[:,pat[0],s[0],n[0],:])
        for n in enumerate(noise_n_dil):
            hit_ratio_n_dil[pat[0],s[0],n[0]] = torch.sum(distances_n3[:,pat[0],s[0],n[0],:]<=tolerance[s[0]])/(iterations*batches)
            average_distance_n_dil[pat[0],s[0],n[0]] = torch.mean(distances_n3[:,pat[0],s[0],n[0],:])
        for n in enumerate(noise_stoc):
            hit_ratio_stoc[pat[0],s[0],n[0]] = torch.sum(distances_stoch[:,pat[0],s[0],n[0],:]<=tolerance[s[0]])/(iterations*batches)
            average_distance_stoc[pat[0],s[0],n[0]] = torch.mean(distances_stoch[:,pat[0],s[0],n[0],:])        
    # Curve summaries over the perturbation axis, one value per p for this load.
    AUC_flip[pat[0],:] = np.round(calc_auc(hit_ratio_flip[pat[0],:,:], nr_flips, 1), 2)
    AUC_on[pat[0],:] = np.round(calc_auc(hit_ratio_on[pat[0],:,:], nr_flips, 1), 2)
    AUC_off[pat[0],:] = np.round(calc_auc(hit_ratio_off[pat[0],:,:], nr_flips, 1), 2)
    AUC_w[pat[0],:] = np.round(calc_auc(hit_ratio_w[pat[0],:,:], noise_w, 1), 2)
    AUC_w_dil[pat[0],:] = np.round(calc_auc(hit_ratio_w_dil[pat[0],:,:], noise_w_dil, 1), 2)
    AUC_n_dil[pat[0],:] = np.round(calc_auc(hit_ratio_n_dil[pat[0],:,:], noise_n_dil, 1), 2)
    AUC_stoc[pat[0],:] = np.round(calc_auc(hit_ratio_stoc[pat[0],:,:], noise_stoc, 1), 2)
    
    DP_flip[pat[0],:] = calc_dropping_point(hit_ratio_flip[pat[0],:,:], nr_flips, DR_threshold)
    DP_on[pat[0],:] = calc_dropping_point(hit_ratio_on[pat[0],:,:], nr_flips, DR_threshold)
    DP_off[pat[0],:] = calc_dropping_point(hit_ratio_off[pat[0],:,:], nr_flips, DR_threshold)
    DP_w[pat[0],:] = calc_dropping_point(hit_ratio_w[pat[0],:,:], noise_w, DR_threshold)
    DP_w_dil[pat[0],:] = calc_dropping_point(hit_ratio_w_dil[pat[0],:,:], noise_w_dil, DR_threshold)
    DP_n_dil[pat[0],:] = calc_dropping_point(hit_ratio_n_dil[pat[0],:,:], noise_n_dil, DR_threshold)
    DP_stoc[pat[0],:] = calc_dropping_point(hit_ratio_stoc[pat[0],:,:], noise_stoc, DR_threshold)

In [None]:
pl.rcParams.update({'font.size': 12})
# One colour per load in nr_patterns (seven loads, seven colours).
colors = np.array(['tab:blue', 'tab:red', 'tab:green', 'tab:purple', 'tab:olive', 'tab:cyan', 'tab:brown'])
# 4x2 grid for the seven perturbation types; the unused eighth panel is removed.
fig, axs = pl.subplots(4, 2, figsize=(15, 20), facecolor='w', edgecolor='k')
fig.subplots_adjust(hspace=0.4, wspace=0.14)
fig.delaxes(axs[3,1])
# AUC against p, one line per load (legend label = load nr_patterns / N).
for pat in enumerate(nr_patterns):
    axs[0,0].plot(sparseness_ratio, AUC_stoc[pat[0],:], '-s', label=str(pat[1]/N), color=colors[pat[0]])
    axs[0,1].plot(sparseness_ratio, AUC_w[pat[0],:], '-s', label=str(pat[1]/N), color=colors[pat[0]])
    axs[1,0].plot(sparseness_ratio, AUC_w_dil[pat[0],:], '-s', label=str(pat[1]/N), color=colors[pat[0]])
    axs[1,1].plot(sparseness_ratio, AUC_n_dil[pat[0],:], '-s', label=str(pat[1]/N), color=colors[pat[0]])
    axs[2,0].plot(sparseness_ratio, AUC_flip[pat[0],:], '-s', label=str(pat[1]/N), color=colors[pat[0]])
    axs[2,1].plot(sparseness_ratio, AUC_on[pat[0],:], '-s', label=str(pat[1]/N), color=colors[pat[0]])    
    axs[3,0].plot(sparseness_ratio, AUC_off[pat[0],:], '-s', label=str(pat[1]/N), color=colors[pat[0]])    
    
# NOTE(review): this loop also formats the deleted axs[3,1]; calling legend()
# on it (j == 1) presumably only triggers a "no labelled artists" warning.
for i in range(4):    
    for j in range(2):    
        #axs[i,j].set_ylim([0,1])
        axs[i,j].xaxis.set_major_formatter(mtick.PercentFormatter(1.0)) 
        axs[i,j].set_xlabel("mean activity level (p)", size=14)
        if j==0:
            axs[i,j].set_ylabel("normalized AUC", size=14)
        if j==1 or i==3:
            axs[i,j].legend(bbox_to_anchor =(1, 1))

axs[0,0].set_title("robustness against stochastic threshold", size=16)
axs[0,1].set_title("robustness against weight noise", size=16)
axs[1,0].set_title("robustness against connection loss", size=16)
axs[1,1].set_title("robustness against node loss", size=16)
axs[2,0].set_title("robustness against state flips", size=16)
axs[2,1].set_title("robustness against states turning on", size=16)
axs[3,0].set_title("robustness against states turning off", size=16)

pl.show()

In [None]:
import seaborn as sns
# The AUC tables as annotated heat maps: rows = load (nr_patterns / N),
# columns = mean activity level p.
fig, axs = pl.subplots(4, 2, figsize=(16, 20), facecolor='w', edgecolor='k')
fig.subplots_adjust(hspace=0.4, wspace=0.12)
fig.delaxes(axs[3,1])
sns.heatmap(AUC_stoc, linewidth=1, annot=True, xticklabels=list(map("{:.0%}".format, sparseness_ratio)), yticklabels=np.round(nr_patterns/N,2), cbar=False, ax=axs[0,0])
sns.heatmap(AUC_w,    linewidth=1, annot=True, xticklabels=list(map("{:.0%}".format, sparseness_ratio)), yticklabels=np.round(nr_patterns/N,2), cbar=False, ax=axs[0,1])
sns.heatmap(AUC_w_dil,linewidth=1, annot=True, xticklabels=list(map("{:.0%}".format, sparseness_ratio)), yticklabels=np.round(nr_patterns/N,2), cbar=False, ax=axs[1,0])
sns.heatmap(AUC_n_dil,linewidth=1, annot=True, xticklabels=list(map("{:.0%}".format, sparseness_ratio)), yticklabels=np.round(nr_patterns/N,2), cbar=False, ax=axs[1,1])
sns.heatmap(AUC_flip, linewidth=1, annot=True, xticklabels=list(map("{:.0%}".format, sparseness_ratio)), yticklabels=np.round(nr_patterns/N,2), cbar=False, ax=axs[2,0])
sns.heatmap(AUC_on,   linewidth=1, annot=True, xticklabels=list(map("{:.0%}".format, sparseness_ratio)), yticklabels=np.round(nr_patterns/N,2), cbar=False, ax=axs[2,1])
sns.heatmap(AUC_off,  linewidth=1, annot=True, xticklabels=list(map("{:.0%}".format, sparseness_ratio)), yticklabels=np.round(nr_patterns/N,2), cbar=False, ax=axs[3,0])

# Put the smallest load at the bottom of each heat map.
for i in range(4):    
    for j in range(2):    
        #axs[i,j].legend()
        axs[i,j].invert_yaxis()
        axs[i,j].set_xlabel("mean activity level (p)", size=14)
        if j==0:
            axs[i,j].set_ylabel("load parameter", size=14)

axs[0,0].set_title("robustness against stochastic threshold", size=16)
axs[0,1].set_title("robustness against weight noise", size=16)
axs[1,0].set_title("robustness against connection loss", size=16)
axs[1,1].set_title("robustness against node loss", size=16)
axs[2,0].set_title("robustness against state flips", size=16)
axs[2,1].set_title("robustness against states turning on", size=16)
axs[3,0].set_title("robustness against states turning off", size=16)

pl.show()