## Task 3 - Random Pattern n-back Analysis

Author: Alon Loeffler

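Required files/folders: learning_functions.py | edamame | create_colormap.py (from the aux_functions folder) | network data file asn_nw_00698_nj_02582_seed_002_avl_10.00_disp_01.00_lx_75.00_ly_75.00.mat (loaded below; asn_nw_00350_nj_01350_seed_1581_avl_10.00_disp_01.00_lx_50.00_ly_50.00.mat is the alternative 350-wire network) | saved simulation results in the `dataLoc` folder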


In [None]:
#IMPORTS:
#Append path to Ruomin's Edamame Package (Nanowire Simulations)
import sys
import os

sys.path.append('../') #point to PhysicalReinforcementLearning Folder locally
sys.path.append('../../edamame') #point to edamame locally

#Choose which nw:
nwChoice=700 #350 or 700

# Network data file (path stem without ".mat") for each supported network size.
# NOTE: the network-loading cell further down uses a hard-coded relative path to the
# 698-wire file; fileName is only referenced by its commented-out loadmat(fileName).
if nwChoice == 350:
    fileName='/PhysicalReinforcementLearning/generate_networks/Network Data/asn_nw_00350_nj_01350_seed_1581_avl_10.00_disp_01.00_lx_50.00_ly_50.00'
elif nwChoice == 700:
    fileName='/PhysicalReinforcementLearning/generate_networks/Network Data/asn_nw_00698_nj_02582_seed_002_avl_10.00_disp_01.00_lx_75.00_ly_75.00'
else:
    # Fail fast instead of silently leaving fileName undefined.
    raise ValueError('nwChoice must be 350 or 700, got '+str(nwChoice))

#point to network data path ^ 

saveFig='../../Data/Figures/'  #Define Figure Save location (not in github repo)
dataLoc='../../Data/Associative Learning/EquilProp/' #Define Data Save Location (not in github repo)

#import edamame (neuromorphic nanowire python package by Ruomin Zhu)
from edamame import * 
import numpy as np
import matplotlib.pyplot as plt
import copy
from scipy.io import loadmat, savemat
import networkx as nx
from tqdm.notebook import tqdm_notebook as tqdm
from IPython.core.debugger import set_trace

import pickle 
import _pickle as cPickle
import gzip

from learning_functions import genGridNW,point_on_line,dist,getWeightedGraph
from learning_functions import calc_cost,setupStimulus,setupSourcesOnly,runTesting,getNWState,calcOutputs

In [None]:
#Functions:
def buildNetworks(nw,ManualSources=True,numDrains=2,numSources=9):
    """
    Build the edamame connectivity object for a loaded network and place electrodes.

    Parameters
    ----------
    nw : dict
        Network data as returned by scipy.io.loadmat (scalars wrapped as [[value]]).
    ManualSources : bool
        If True and numSources is 9 or 4, the generated source wires are replaced
        by hard-coded wire indices.
    numDrains : int
        Number of drain electrodes, placed on target coordinates along x = 0.
    numSources : int
        Number of source electrodes, placed on target coordinates along
        x = 50 (x = 75 for networks with more than 500 wires).

    Returns
    -------
    tuple
        (connectivity, elecSource, elecDrain, nwSize, nwJunctions)
    """
    nwSize = nw['number_of_wires'][0][0]
    nwJunctions = nw['number_of_junctions'][0][0]
    print('Network '+str(nwSize)+ ' Loaded')
    connectivity = connectivity__(wires_dict=nw)

    # loadmat wraps every scalar as a 2-D array: unwrap each one to its value.
    scalarFields = (
        'avg_length', 'number_of_junctions', 'centroid_dispersion', 'dispersion',
        'gennorm_shape', 'length_x', 'length_y', 'number_of_wires',
        'numOfWires', 'numOfJunctions', 'theta', 'this_seed',
    )
    for field in scalarFields:
        setattr(connectivity, field, getattr(connectivity, field)[0][0])

    # End-point coordinates of every wire.
    xa, xb = connectivity.xa[0], connectivity.xb[0]
    ya, yb = connectivity.ya[0], connectivity.yb[0]

    largeNetwork = nwSize > 500

    # Drains: evenly spaced target points along x = 0.
    drainX = np.zeros(numDrains)*5
    drainY = np.linspace(-1, 76 if largeNetwork else 51, numDrains)
    elecDrain = genGridNW(xa,xb,ya,yb,drainX,drainY)

    # Sources: evenly spaced target points along the opposite side.
    sourceXVal, sourceYMax = (75, 74) if largeNetwork else (50, 49)
    sourceX = np.ones(numSources)*sourceXVal
    sourceY = np.linspace(-2, sourceYMax, numSources)
    elecSource = genGridNW(xa,xb,ya,yb,sourceX,sourceY)

    # Optional hand-picked source wires override the generated ones.
    if ManualSources:
        if numSources == 9:
            elecSource = ([678, 260, 491, 173, 628, 424, 301, 236, 483] if largeNetwork
                          else [23,320,194,74, 145, 317, 129, 34, 141])
        elif numSources == 4:
            elecSource = [320,42,161,141]

    return connectivity,elecSource,elecDrain,nwSize,nwJunctions

### Load NWN Network

In [None]:
#Network Data:
# NOTE(review): hard-coded to the 698-wire network; nwChoice / fileName from the first
# cell are not used here, and the result/figure file names below also assume
# 698 wires / 2582 junctions.
nw=loadmat('../generate_networks/Network Data/asn_nw_00698_nj_02582_seed_002_avl_10.00_disp_01.00_lx_75.00_ly_75.00.mat')

# nw=loadmat(fileName)
g=nx.from_numpy_array(nw['adj_matrix'])  # wire graph used for all network drawings below
adjMat=nw['adj_matrix']

num_drain_training=1
num_drain_testing=7   # number of drain electrodes; also the maximum n-back (see nbacks)
numSources=9
numSamples=70         # number of trials; each spans (nback+1) samples in the time-window cell
nbacks         = [num_drain_testing]
nback=nbacks[0]
numDT=200             # timesteps per simulated sample

# reinforcement=True #do we want to activate reinforcement?

#Build networks again (in case we didn't run simulation - e.g. loadOnly == True)
# sources = source wire indices, drain_pool = drain wire indices (as returned by buildNetworks)
connectivity,sources,drain_pool,nwSize,nwJunctions=buildNetworks(nw,True,num_drain_testing,numSources)

#NW Parameters
# nwJunctions is recomputed from the same field buildNetworks already read.
nwJunctions=nw['number_of_junctions'][0][0]
nwWires=nw['number_of_wires'][0][0]

### Load Sim Data for each n-back

In [None]:
#LOAD ALL NONSIM DATA
b='b0p5' #filament decay value (b)
loadReinf='wReinforcement' #noReinforcement or wReinforcement

loadName='698nw_2582nj_working_memory_nback_'+loadReinf+'_7drains_data_'+b

#LOAD ALL OTHER DATA TOGETHER
# NOTE: pickle.load executes code from the file -- only load trusted, locally produced results.
with open(dataLoc+'/Sim Results/working memory nback/'+loadName+'.pkl', 'rb') as f:
        dataNoSim = pickle.load(f) 
# dataNoSim is an indexed container; the meanings below are inferred from how later
# cells use each entry -- confirm against the code that saved it.
accuracy=dataNoSim[3]    # accuracy series (plotted against 'Epoch')
alltargets=dataNoSim[4]  # per-sample targets
thresholds=dataNoSim[1]  # Theta thresholds; column 0 = target, column 1 = non-target (per the Theta plot)


In [None]:
#LOAD SIM DATA
loadName='698nw_2582nj_working_memory_nback_'+loadReinf+'_7drains_sim_'+b
# One pickle per simulated sample. Each of the numSamples trials holds
# (max n-back + 1) samples, where max n-back = num_drain_testing (nbacks=[num_drain_testing]),
# i.e. 70 * (7 + 1) = 560 files. Derived from the parameters instead of hard-coding 560.
numSims=numSamples*(num_drain_testing+1)
sims=[]
#LOAD EACH SIM INDIVIDUALLY
# NOTE: pickle.load executes code from the file -- only load trusted, locally produced results.
for i in tqdm(range(numSims)):
    with open(dataLoc+'/Sim Results/working memory nback/'+loadName+str(i)+'.pkl', 'rb') as f:
        sims.append(pickle.load(f))

In [None]:
nbacks=list(range(1,8))  # n-back values 1..7 (8 and 9 are disabled)


In [None]:
#Convert data from epochs to timesteps, or vice-versa
#
# A trial is `nback` training samples followed by one test sample, each numDT
# timesteps long (nback = max n-back). This cell builds the timestep indices of the
# training part of every trial and the sample indices of test and training samples.

numDT=200  # timesteps per sample
sampleTimes=numDT*(nbacks[-1]+1)  # timesteps per trial
# Use numDT rather than a duplicated literal 200 so the length follows numDT.
signalLen=numDT*len(alltargets)

nback=nbacks[-1] #max nback
tmp=numDT*nback
A=np.array(range(0,signalLen+1,int(tmp+numDT))) #training (trial start timesteps)
B=np.array(range(0+tmp,signalLen+1,int(tmp+numDT)))# testing (test-sample start timesteps)
# len(A) >= len(B) (same step, A starts earlier), so zip pairs exactly the first len(B) entries.
c=[range(trainStart,testStart) for trainStart,testStart in zip(A,B)]
trainTimes=np.array(np.hstack(c))
l=range(signalLen)
# testTimes=np.array(np.hstack([m for m in l if m not in testTimes]))

#test and train samples: the last sample of each trial (nback, 2*nback+1, ...) is the test sample.
# np.arange keeps an integer dtype even when empty, so the fancy indexing below stays valid.
testSamples=np.arange(nback,len(alltargets),nback+1)
testSamples=testSamples[testSamples>0]
trainSamples = np.setdiff1d(np.array(range(len(alltargets))),testSamples)

#Targets
targets=np.array(alltargets)[testSamples]
trainTargets=np.array(alltargets)[trainSamples]
trainTargets=trainTargets[:,0]

In [None]:
#Separate drain current to training and testing
#
# Training samples with target 1 (and a 10-column electrodeCurrent) contribute their
# first electrode column to the training trace; every other sample goes to the test
# branch, keeping its first 7 electrode columns. drain1 holds both, in sample order.
# NOTE(review): training samples whose target is not 1 also fall into the test
# branch -- confirm this is intended.
elecItrain=[];elecItest=[];inputSignal=[]
drain1train=[];drain1test=[];voltage=[];drain1=[]
t=0
# NOTE(review): resets the global nback to the smallest value in nbacks (1 for
# [1,...,7]); later cells that read nback see this value.
nback=nbacks[0]
# Loop-invariant: build the membership set once instead of re-slicing and
# linearly scanning the arrays for every sample.
trainTargetOne=set(np.asarray(trainSamples[trainTargets==1]).tolist())
for i in range(len(sims)):
    voltage.append(sims[i].wireVoltage[:,drain_pool])  # wire voltages at the drain wires
    if i in trainTargetOne:
        if sims[i].electrodeCurrent.shape[1] == 10: #if training
            tmp=sims[i].electrodeCurrent[:,0]
            elecItrain.append(tmp)
            drain1train.append(tmp)
            drain1.append(tmp)
    else: #only test:
        elecItest.append(sims[i].electrodeCurrent[:,:7])
        drain1test.append(elecItest[t][:,0])
        drain1.append(elecItest[t][:,0])
        t+=1

# Concatenate the per-sample traces. Stacking the lists directly gives the same
# result as np.hstack(np.array(...)) for equal-length samples, and avoids the
# ValueError NumPy >= 1.24 raises when np.array is given ragged input.
drain1=np.hstack(drain1)
drain1test=np.hstack(drain1test)
drain1train=np.hstack(drain1train)
voltage=np.vstack(voltage)


In [None]:
#Save Reinforcement figure if with reinforcement
if loadReinf == 'wReinforcement':
    # dataNoSim[1]: threshold history; column 0 = target, column 1 = non-target
    # (as labelled below). Only every 10th entry is drawn.
    thetaHistory=np.array(dataNoSim[1])
    # Plot at the original indices of the subsampled points so the x axis keeps
    # its units instead of being compressed 10x.
    thetaIdx=np.arange(len(thetaHistory))[::10]
    plt.figure()
    ax=plt.gca()
    ax.plot(thetaIdx,thetaHistory[:,0][::10],label='Target',c='b')
    ax.plot(thetaIdx,thetaHistory[:,1][::10],label='Non-Target',c='r')
    ax.set_xlabel('Epoch')
    # Raw string: '\T' is an invalid escape sequence in a normal string literal
    # (DeprecationWarning; SyntaxWarning since Python 3.12).
    ax.set_ylabel(r'$\Theta$')
    plt.legend()
    plt.savefig(saveFig+'698nw_2582nj_working_memory_nback_'+loadReinf+'_Theta_Threshold_'+b+'.pdf',format='pdf',dpi=300)

In [None]:
#Plot Accuracy vs Time
plt.figure()
ax=plt.gca()
ax.plot(np.array(accuracy),label='Target',c='b')
ax.set_xlabel('Epoch')
ax.set_ylabel('Acc')
plt.legend()
# '_' separator added before 'accuracy' so the name follows the pattern of the other
# figures (e.g. ..._wReinforcement_Theta_Threshold_b0p5.pdf) instead of
# '..._wReinforcementaccuracy_time_...'.
plt.savefig(saveFig+'698nw_2582nj_working_memory_nback_'+loadReinf+'_accuracy_time_'+b+'.pdf',format='pdf',dpi=300)

In [None]:
#Plot Current Strength vs Time:
# Plots the drain-0 current of all test samples (drain1test, concatenated) and saves it as PDF.

# NOTE(review): seaborn is imported but not used in this cell.
import seaborn as sns
plt.rcParams['pdf.fonttype'] = 42  # embed TrueType fonts so PDF text stays editable

fig,ax = plt.subplots(1,1, figsize=(15, 8), dpi=300)
x=(drain1test)
p1=ax.plot(x,'b',alpha=0.7)
# p2=ax.plot((voltage[:,1]),'r',alpha=0.5)
# plt.legend([p1[0],p2[0]],['Target Drain'],loc='lower right',title='Current')


# Per-n-back plot settings. None of xlimmax / vlinesStep / vlinesStart / dotpos is used
# in this cell, and other nback values (e.g. 1) match no branch.
if nback == 7:
    xlimmax=830
    vlinesStep=40
    vlinesStart=30
    dotpos=2

if nback == 2:
#     xlimmax=1230
    vlinesStep=600
    vlinesStart=500
    dotpos=3
elif nback == 3:
    xlimmax=1630
    vlinesStep=80
    vlinesStart=70
    dotpos=4
        
ax.set_xlabel('Time')
ax.set_ylabel('Current')

fig.savefig(saveFig+'698w_2582j_current_time_accuracy_nback_random_patterns_'+loadReinf+'_'+b+'.pdf',format='pdf',dpi=300)


In [None]:
#Current vs Time inset
# Zoomed view of drain-0 current (drain1) against the target threshold over
# timesteps [tstart, tend), saved as PDF.

# signal_expand is not defined in this notebook (presumably from `from edamame import *`);
# presumably it repeats each threshold value numDT times to align with drain1 -- TODO confirm.
drain1thresh=signal_expand(np.array(thresholds)[:,0],numDT)
tstart=55000
tend=56000
# dur1=tend-np.where(drain1thresh>5e-6)[0][0]
# dur2=np.where(drain1thresh>5e-6)[0][0]-tstart
plt.plot(drain1[tstart:tend],label='Drain 1',c='b')
# plt.plot(drain2[tstart:tend],label='Drain 2')
plt.plot(drain1thresh[tstart:tend],c='black',linestyle='dashed',label='Threshold')
# plt.hlines(drain1thresh[dur2],0,dur2,colors='black',linestyles='dashed',label='Threshold')
# plt.hlines(drain1thresh[dur1],0,dur1,colors='black',linestyles='dashed',label='Threshold')

ax=plt.gca()
start, end = ax.get_xlim()  # only referenced by the commented-out tick code below
# plt.xticks(np.arange(start-start,end+start,np.abs(start-end)/6),labels=range(tstart,tend,int(dur/7)))
# ax.set_xticklabels(range(tstart-100,tend+100,int(dur/8)))
ax.set_xlabel('Timesteps')
ax.set_ylabel('Current (A)')
plt.legend()
plt.savefig(saveFig+'698w_2582j_current_time_oneEpoch_nback_random_patterns_t_'+str(tstart)+'_'+str(tend)+'_'+loadReinf+'_'+b+'.pdf',format='pdf')

In [None]:
#Sort data by n-back for histograms
import pandas as pd

# n-back distance: positions (np.argwhere) of the -1 markers minus positions of the
# 0 markers in dataNoSim[7].
# NOTE(review): assumes dataNoSim[7] has equally many -1 and 0 entries, paired in order -- confirm.
nbackVals=np.argwhere(np.array(dataNoSim[7])==-1)-np.argwhere(np.array(dataNoSim[7])==0)
nbackVals=nbackVals.reshape(-1) #datastructure = [nbacks, experiments, epochs]
# One row per entry: its n-back value and accuracy (pandas requires equal lengths).
nbackAcc=pd.DataFrame({'nback':nbackVals,'Accuracy':accuracy})


In [None]:
#Statistical Analysis (Means and SEs)
# For each n-back value (the same values the bar chart uses as x positions):
# group size, mean accuracy and standard error of the mean.
numVals=[]
seAcc=[]
meanAcc=[]
for i in nbacks:
    inGroup=nbackAcc['nback']==i
    groupAcc=nbackAcc.loc[inGroup,'Accuracy']

    numVals.append(int(inGroup.sum()))
    meanAcc.append(groupAcc.mean(skipna=True))
    # SE = std / sqrt(n), with n = number of non-NaN accuracies in THIS group.
    # The previous version divided by sqrt(len(nbackAcc.where(...))), but where()
    # keeps every row, so n was the size of the whole table and SE was too small.
    # Column selection by name also replaces the deprecated positional Series[0]/[1].
    seAcc.append(groupAcc.std(skipna=True)/np.sqrt(groupAcc.count()))

In [None]:
#Histograms
plt.rcParams['pdf.fonttype'] = 42

plt.bar(nbacks,meanAcc,yerr=seAcc,alpha=0.8, ecolor='black', capsize=5)
plt.xlabel('nback')
plt.ylabel('Accuracy')
ax=plt.gca()
ax.set_xticks(range(8))
ax.set_xticklabels(range(8))
ax.set_ylim([0,1.05])
plt.title(loadReinf)
# plt.savefig(saveFig+'698nw_2582nj_nback_working_memory_no_reinforcement_b0p1.pdf',format='pdf')


## Network Connectivity

In [None]:
sys.path.append('/import/silo2/aloe8475/Documents/CODE/aux_functions/') #point to edamame locally

from matplotlib import animation
from matplotlib import colorbar as clbr
from PIL import Image

from matplotlib import rcParams
from create_colormap import get_continuous_cmap,hex_to_rgb,rgb_to_dec

#Functions:
def getWeightedGraph(adjMat,jC,jV,jF,edgeList,numWires,time,edge_mode):#, this_TimeStamp = 0):
    """
    Build a graph whose edge weights are junction quantities at one timestep.

    Shadows the getWeightedGraph imported from learning_functions.

    Parameters
    ----------
    adjMat : ignored
        Kept for call compatibility; a fresh numWires x numWires matrix is built.
    jC, jV, jF : ndarray
        Junction conductance, voltage and filament state, indexed [time, junction].
    edgeList : ndarray
        Wire-index pairs, one row per junction; column 0/1 are the two wires.
        Presumably in the same junction order as jC/jV/jF -- confirm.
    numWires : int
        Number of nodes.
    time : int
        Row of jC/jV/jF to use.
    edge_mode : str
        'conductance' (jC), 'current' (jV*jC) or 'filament' (jF). Any other value
        leaves every weight at zero, giving a graph without edges.

    Returns
    -------
    networkx.Graph
        Edges only where the chosen weight is non-zero (nx.from_numpy_array drops zeros).
    """
    if edge_mode=='conductance':
        junctionWeights=jC[time,:]
    elif edge_mode=='current':
        junctionWeights=jV[time,:]*jC[time,:]
    elif edge_mode=='filament':
        junctionWeights=jF[time,:]
    else:
        junctionWeights=None

    adjMat = np.zeros((numWires, numWires))
    if junctionWeights is not None:
        # Fill both triangles: junctions are undirected.
        adjMat[edgeList[:,0], edgeList[:,1]] = junctionWeights
        adjMat[edgeList[:,1], edgeList[:,0]] = junctionWeights
    # The previous version also computed np.max(jC) / np.max(jV*jC) over the whole
    # time series on every call and discarded the result. For 'current' that
    # allocated a full jV*jC copy per call (i.e. per animation frame).

    WeightedGraph = nx.from_numpy_array(adjMat)
    # Kept as-is: this copy fixes the edge iteration order that later cells rely on.
    WeightedGraph=nx.DiGraph.to_undirected(WeightedGraph)

    return WeightedGraph

def image_draw_voltage(time,cmap,maxWeights,minWeights):
    """
    Animation callback: draw wire voltages at `time` into the module-level axes `ax`.

    Nodes of `g` are coloured by sim.wireVoltage[time] using `cmap`. drains[1] is marked
    red and drains[0] blue. maxWeights / minWeights are accepted (FuncAnimation fargs)
    but not used.

    NOTE(review): relies on globals `ax`, `g`, `sim`, `G` and `drains`. `sim` and `drains`
    are not defined anywhere in this notebook (the loaded data is `sims`, the drain list
    `drain_pool`), and `G` is whichever weighted graph was built last -- confirm before use.
    """
    ax.clear()

    # The layout was previously computed twice per frame, with the first result
    # immediately overwritten by an identical call; kamada_kawai_layout is expensive.
    pos=nx.kamada_kawai_layout(g)

    node_weight=sim.wireVoltage[time]

    nx.draw_networkx_nodes(g,pos=pos,node_color=node_weight,cmap=cmap,node_size=50,ax=ax)
    nx.draw_networkx_edges(G,pos=pos,ax=ax,edge_color='grey')
    # nx.draw_networkx_nodes(g,pos=pos,nodelist=sources,node_color='g',node_size=50,ax=ax)
    nx.draw_networkx_nodes(g,pos=pos,nodelist=[drains[1]],node_color='r',node_size=10,ax=ax)
    nx.draw_networkx_nodes(g,pos=pos,nodelist=[drains[0]],node_color='b',node_size=10,ax=ax)
    ax.set_title(str(time))
    
import matplotlib.colors as clrs

def image_draw_current_filament(time,edge_mode,edge_weight,cmap,maxWeights,minWeights):
    """
    Animation callback: draw the junction-weighted network at timestep `time`.

    The module-level `animationType` ('conductance', 'current' or 'filament') selects
    the edge quantity. Edges are coloured with `cmap` on a symmetric-log scale between
    minWeights and maxWeights. Electrode colouring:
    - sources: black;
    - the wires in activeSources[...] for this target: green;
    - drains: all red during test times; otherwise the drain for this target red and
      the remaining drains black.

    edge_mode and edge_weight are accepted for FuncAnimation fargs compatibility but
    are not used.

    Relies on module-level state: ax, g, sims, adjMat, jC, jV, jF, animationType,
    targetValsTime, alltargets, sources, activeSources, drain_pool, testTimes.
    """
    ax.clear()

    pos=nx.kamada_kawai_layout(g)
    edgeList=sims[0].connectivity.edge_list
    numWires=g.number_of_nodes()
    G=getWeightedGraph(adjMat,jC,jV,jF,edgeList,numWires,time,animationType)

    edges=G.edges()
    weights=[G[u][v]['weight'] for u,v in edges]

    # Linear range of the symmetric-log colour scale (was the same 1e-7 for every
    # animation type). A stray bare `Q` statement after it raised NameError on every frame.
    logVal=1e-7
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=clrs.SymLogNorm(logVal,base=10,vmin=minWeights, vmax=maxWeights))
    sm.set_array([])
    normweights=sm.norm(weights).data
    if len(normweights)==0:
        normweights=np.zeros(g.number_of_edges())

    normMin=sm.norm(minWeights)
    normMax=sm.norm(maxWeights)
    nx.draw_networkx_nodes(g,pos=pos,node_color='grey',node_size=20,ax=ax)

    thisTarget=int(targetValsTime[time])
    x=(alltargets==thisTarget)
    nx.draw_networkx_edges(G,pos=pos,ax=ax,edge_color=normweights,width=2,edge_cmap=cmap,edge_vmin=normMin,edge_vmax=normMax,alpha=0.7)

    # Electrodes: every source black, then the activeSources entry for this target green.
    nx.draw_networkx_nodes(g,pos=pos,nodelist=sources,node_color='k',node_size=100,ax=ax)
    nx.draw_networkx_nodes(g,pos=pos,nodelist=list(activeSources[np.where(x)[0][0]+1]),node_color='g',node_size=100,ax=ax)

    isTesting = time in testTimes
    if isTesting:
        nx.draw_networkx_nodes(g,pos=pos,nodelist=drain_pool,node_color='r',node_size=100,ax=ax)
    else:
        # Drain for this target red, the others black. List concatenation replaces
        # np.hstack(np.array([a, b])): np.array of two unequal-length slices is ragged
        # and raises ValueError on NumPy >= 1.24. For thisTarget == 1 this yields
        # drain_pool[1:], exactly as the former special case did.
        otherDrains=list(drain_pool[:thisTarget-1])+list(drain_pool[thisTarget:])
        nx.draw_networkx_nodes(g,pos=pos,nodelist=[drain_pool[thisTarget-1]],node_color='r',node_size=100,ax=ax)
        nx.draw_networkx_nodes(g,pos=pos,nodelist=otherDrains,node_color='k',node_size=100,ax=ax)

    # Label the drain wires with their index, shifted slightly left of the node.
    shift = [-0.1, 0]
    shifted_pos ={node: node_pos + shift for node, node_pos in pos.items()}
    labels = {node: node for node in g.nodes() if node in drain_pool}
    nx.draw_networkx_labels(g, shifted_pos, labels=labels, horizontalalignment="left")

    timeVals='Testing' if isTesting else 'Training'
    ax.set_title('T = ' + str(time)+ ' | ' + timeVals + ' | Target ' + str(targetValsTime[time]))

    
#New loadmat:
import scipy.io as spio

def loadmatNew(filename):
    """
    Load a .mat file with MATLAB structs turned into nested Python dicts.

    Use this instead of calling scipy.io.loadmat directly. The file is read with
    struct_as_record=False and squeeze_me=True, so structs arrive as mat_struct
    objects, which _check_keys then converts to dicts.
    """
    raw = spio.loadmat(filename, struct_as_record=False, squeeze_me=True)
    converted = _check_keys(raw)
    return converted

def _check_keys(dict):
    '''
    checks if entries in dictionary are mat-objects. If yes
    todict is called to change them to nested dictionaries
    '''
    for key in dict:
        if isinstance(dict[key], spio.matlab.mio5_params.mat_struct):
            dict[key] = _todict(dict[key])
    return dict        

def _todict(matobj):
    '''
    A recursive function which constructs from matobjects nested dictionaries
    '''
    dict = {}
    for strg in matobj._fieldnames:
        elem = matobj.__dict__[strg]
        if isinstance(elem, spio.matlab.mio5_params.mat_struct):
            dict[strg] = _todict(elem)
        else:
            dict[strg] = elem
    return dict


### Combine Junction Data and Define Test/Train Times

In [None]:
#combine all sim voltages, conductance and filament together
# Collect every sample's junction time series and the sources it used.
jV=[]
jF=[]
jC=[]
activeSources=[]
for i in range(len(sims)):
    jV.append(sims[i].junctionVoltage)   
    jF.append(sims[i].filamentState) #Junction Filament Negative
    jC.append(sims[i].junctionConductance)
    activeSources.append(sims[i].sources)

#reshape 
# Concatenate along time into (len(sims)*numDT, nwJunctions) arrays, so row t is global
# timestep t. Assumes every sim has exactly numDT timesteps of nwJunctions values.
jV=np.array(jV).reshape(len(jV)*numDT,nwJunctions) 
jF=np.array(jF).reshape(len(jF)*numDT,nwJunctions)
jC=np.array(jC).reshape(len(jC)*numDT,nwJunctions)

In [None]:
#Define test and training start and end times
targetValsTime=np.array(dataNoSim[4]).reshape(-1)
alltargets=np.array(dataNoSim[4])[:,0].reshape(-1)
counts=list(range(numSamples))  # trial indices 0..numSamples-1

# Each trial is trialSamples samples of numDT timesteps: the training samples
# followed by one test sample. Start/End values are INCLUSIVE timestep indices.
# The trial length uses the max n-back (nbacks[-1], as the epoch/timestep
# conversion cell does), not the global nback, which the drain-separation cell
# resets to nbacks[0].
trialSamples=nbacks[-1]+1
TestTimeStart=((np.array(counts)+1)*numDT*trialSamples)-numDT
TestTimeEnd=((np.array(counts)+1))*numDT*trialSamples-1
TrainTimeStart=((np.array(counts)+1)*numDT*trialSamples)-numDT*trialSamples
TrainTimeEnd=((np.array(counts)+1)*numDT*trialSamples)-numDT-1


# TestTimeEnd is inclusive, so arange needs End+1; np.arange(Start, End) silently
# dropped the last timestep of every test window.
vals=[]
for i in range(len(counts)):
    vals.append(np.arange(TestTimeStart[i],TestTimeEnd[i]+1))
testTimes=np.array(vals).reshape(-1)


xlims=[TrainTimeStart[0],TrainTimeEnd[-1]]


In [None]:
# Network Connectivity Map Plot
# For each time in `times`: draw the network with edges coloured by the chosen junction
# quantity, mark the electrodes, and save the figure plus the edge weights.

plt.rcParams['animation.embed_limit'] = 2**64

edge_mode='custom'
animationType = 'conductance'


#Define animation types (current, conductance, filament)
if animationType =='current':
    minWeights=0
    jI=np.array(jV*jC)
    edge_weight = jI
    maxWeights=np.max(jI)
    cmap=plt.cm.magma_r
elif animationType== 'conductance':
    minWeights=0
    maxWeights=np.max(jC)
#     cmap=plt.cm.coolwarm
    colorlist=  ["a0051f","b8323d","cf5f5a","e68c78","fdb995","e4c5bc","cad0e3","cddff1","cfe6f8","d0edff"]#["710000","fad39d","e8f5fd","003470"]

    colorlist.reverse()
    cmap=get_continuous_cmap(colorlist) #hexlist

    edge_weight=jC
elif animationType== 'filament':
    minWeights=0
    maxWeights=np.max(jF)
    colorlist=["950700","c93a28","fd6c50","fb986a","f7c59f","efd6b8","e6e6d0"]
    colorlist.reverse()
    cmap=get_continuous_cmap(colorlist) #hexlist
#     cmap=plt.cm.Pastel1
    edge_weight=jF
    #binary threshold (only show above or below a certain value)

times=[1400, 55800] #choose times

# The graph layout and junction list are the same for every plotted time, so compute
# them once instead of per iteration (kamada_kawai_layout is expensive).
pos=nx.kamada_kawai_layout(g)
edgeList=sims[0].connectivity.edge_list
numWires=g.number_of_nodes()

for time in tqdm(times): #loop through times

    f,[ax,cax] = plt.subplots(1,2, gridspec_kw={"width_ratios":[50,1]},frameon=False, figsize=(10, 8), dpi=300)
    canvas_width, canvas_height = f.canvas.get_width_height()
    ax.axis('off')
    cax.axis('off')

    G=getWeightedGraph(adjMat,jC,jV,jF,edgeList,numWires,time,animationType)

    # edge_weights=nx.get_edge_attributes(G,'weight')
    # G.remove_edges_from((e for e, w in edge_weights.items() if w <1e-7)) #threshold Gj if required

    edges=G.edges()
    weights=[G[u][v]['weight'] for u,v in edges]

    #Log transform of weights
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=clrs.SymLogNorm(1e-7,base=10,vmin=minWeights, vmax=maxWeights))
    sm.set_array([])

    normweights=sm.norm(weights).data #apply log transformation
    normMin=sm.norm(minWeights)
    normMax=sm.norm(maxWeights)

    #Draw original network
    h=nx.draw_networkx_nodes(g,pos=pos,node_color='grey',node_size=20,ax=ax)

    thisTarget=int(targetValsTime[time])
    x=(alltargets==thisTarget)

    #Draw subgraph edges (Gj)
    h2=nx.draw_networkx_edges(G,pos=pos,ax=ax,edge_color=normweights,width=2,edge_cmap=cmap,edge_vmin=normMin,edge_vmax=normMax,alpha=0.6)

    #Change colors of electrode nodes: every source black, then the activeSources entry green
    nx.draw_networkx_nodes(g,pos=pos,nodelist=sources,node_color='k',node_size=100,ax=ax)
    nx.draw_networkx_nodes(g,pos=pos,nodelist=list(activeSources[np.where(x)[0][0]+1]),node_color='g',node_size=100,ax=ax)

    isTesting = time in testTimes
    if isTesting:
        nx.draw_networkx_nodes(g,pos=pos,nodelist=drain_pool,node_color='r',node_size=100,ax=ax)
    else:
        # Drain for this target red, the others black. List concatenation replaces
        # np.hstack(np.array([a, b])), which is ragged for unequal-length slices and
        # raises ValueError on NumPy >= 1.24. For thisTarget == 1 this is drain_pool[1:].
        otherDrains=list(drain_pool[:thisTarget-1])+list(drain_pool[thisTarget:])
        nx.draw_networkx_nodes(g,pos=pos,nodelist=[drain_pool[thisTarget-1]],node_color='r',node_size=100,ax=ax)
        nx.draw_networkx_nodes(g,pos=pos,nodelist=otherDrains,node_color='k',node_size=100,ax=ax)

    timeVals='Testing' if isTesting else 'Training'

    ax.set_title('T = ' + str(time)+ ' | ' + timeVals + ' | Target ' + str(targetValsTime[time]))

    cbar = plt.colorbar(sm, ax=ax,
                        fraction = 0.05, label=animationType)

    # SAVE PLOTS AND JUNCTION CONDUCTANCES (edge weights in G.edges() order)
    print(loadReinf+'_'+b)
    plt.savefig(saveFig+str(nwWires)+'nw_'+str(nwJunctions)+'nj_nback_WorkingMemory_'+loadReinf+'_7drains_3x3_'+b+'_t'+str(time)+timeVals+'.pdf',format='pdf')
    with open(dataLoc+str(nwWires)+'nw_'+str(nwJunctions)+'nj_nback_WorkingMemory_'+loadReinf+'_7drains_3x3_'+b+'_JC_t'+str(time)+timeVals+'.pkl', 'wb') as f1:
        pickle.dump(weights, f1)
    print('Saved')

    plt.close(f)

In [None]:
#Animation
# One frame every 200 timesteps (range(...)[::200]) from xlims[0] to xlims[1]; each frame
# is drawn by image_draw_current_filament into the global `ax`.
# NOTE(review): `f` is the last figure from the map loop above, which that loop already
# closed with plt.close(f) -- confirm this is the intended figure.
anim1 = animation.FuncAnimation(f, image_draw_current_filament, 
                               frames=range(xlims[0],xlims[1],1)[::200], 
                               interval=20, 
                               repeat=False,fargs=(edge_mode,edge_weight,cmap,maxWeights,minWeights))

In [None]:
plt.close('all')  # close every open figure to free memory before saving the animation

In [None]:
#Save Animation (requires ffmpeg)
currentPath=os.getcwd()
# NOTE: FFwriter is not passed to save(); anim1.save uses matplotlib's default writer.
FFwriter = animation.FFMpegWriter()
os.chdir(saveFig)
try:
    anim1.save(str(nwWires)+'nw_'+str(nwJunctions)+'nj_nback_WorkingMemory_'+loadReinf+'_3x3_7drains_'+b+'_'+animationType+'_t'+str(xlims[0])+'_'+str(xlims[1])+'.mp4', fps=10,
              progress_callback = lambda i, n: print(f'Saving frame {i} of {n}'),)
finally:
    # Always restore the working directory, even if saving fails (e.g. ffmpeg missing);
    # otherwise every later relative path (saveFig, dataLoc) resolves from saveFig.
    os.chdir(currentPath)

## Map Difference

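Calculate the change in junction conductance between t = 14 s and t = 558 s (timesteps 1400 and 55800), $\Delta G_j = G_j(558\,\mathrm{s}) - G_j(14\,\mathrm{s})$, for the runs with and without reinforcement, and the difference between the two (reinforced − non-reinforced).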

This produces Figures 5e and 5f in the paper.

In [None]:
#Load Junction Conductance:
# Load the saved per-edge junction conductances (one list per time) for the runs
# without and with reinforcement.
# NOTE(review): the map cell writes these files to dataLoc+<name>, but these paths add
# '/Sim Results/working memory nback/', so the files must have been moved there.
# np.array below also needs every saved list to have the same length.
b='b0p5'
jCnoReinforcement=[]
jCwReinforcement=[]
times=[1400, 55800]
for time in times:
    #LOAD 
    with open(dataLoc+'/Sim Results/working memory nback/'+str(nwWires)+'nw_'+str(nwJunctions)+'nj_nback_WorkingMemory_noReinforcement_7drains_3x3_'+b+'_JC_t'+str(time)+'Testing.pkl', 'rb') as f:
        jCnoReinforcement.append(pickle.load(f)) 
    with open(dataLoc+'/Sim Results/working memory nback/'+str(nwWires)+'nw_'+str(nwJunctions)+'nj_nback_WorkingMemory_wReinforcement_7drains_3x3_'+b+'_JC_t'+str(time)+'Testing.pkl', 'rb') as f:
        jCwReinforcement.append(pickle.load(f)) 
        
jCnoReinforcement=np.array(jCnoReinforcement)
jCwReinforcement=np.array(jCwReinforcement)

In [None]:
#Plot Histograms
# Overlaid histograms of junction conductances above 2e-5 at times[whichEpoch],
# without (blue) and with (red) reinforcement.
plt.rcParams['pdf.fonttype'] = 42  # embed TrueType fonts so PDF text stays editable

whichEpoch = 1  # index into times: 0 -> t=1400, 1 -> t=55800


reducedJCnoReinf=jCnoReinforcement[whichEpoch][jCnoReinforcement[whichEpoch]>2e-5]
reducedJCwReinf=jCwReinforcement[whichEpoch][jCwReinforcement[whichEpoch]>2e-5]

plt.figure(figsize=(12,8))
plt.hist(reducedJCnoReinf,facecolor='b',label='No Reinforcement',bins=15)
plt.hist(reducedJCwReinf,facecolor='r',alpha=0.7,label='With Reinforcement',bins=15)
plt.title('t = '+str(times[whichEpoch]))
plt.legend()
plt.xlabel('Gj')
plt.ylabel('Frequency')
plt.ylim([0,275])
plt.savefig(saveFig+str(nwWires)+'nw_'+str(nwJunctions)+'nj_nback_WorkingMemory_'+b+'_JC_hist_'+str(times[whichEpoch])+'.pdf',format='pdf',dpi=300)

In [None]:
#Calculate Differences
# Per-edge conductance change between the two saved times (times[1] - times[0]) for
# each run, and the extra change with reinforcement (totalDiff).
# NOTE(review): element-wise subtraction assumes the saved edge lists have the same
# length and edge order at both times and in both runs -- confirm.
noDiff=jCnoReinforcement[1]-jCnoReinforcement[0]
wDiff=jCwReinforcement[1]-jCwReinforcement[0]
totalDiff=wDiff-noDiff
zeroEdges=totalDiff==0  # mask of edges whose change is identical in both runs

In [None]:
#Plot Differences:
# Draw the network with edges coloured by the selected conductance change and save it.
# plt.close('all')
rcParams['animation.embed_limit'] = 2**64
plt.rcParams['pdf.fonttype'] = 42


reinf='noReinforcement' #wReinforcement = reinforcement #noReiforcement = no reinforcement #'' = diff between reinforcement and no reinf
if reinf=='noReinforcement':
    chosenWeight=noDiff
elif reinf=='wReinforcement':
    chosenWeight = wDiff
else:
    chosenWeight = totalDiff

edge_mode='custom'
animationType = 'diff' #the animation we want is the difference between reinfrocement and no reinforcement

minWeights=np.nanmin(chosenWeight)
maxWeights=np.nanmax(chosenWeight)
cmap=plt.cm.coolwarm

#Uncomment for manual colormap

# colorlist=  ["a0051f","b8323d","cf5f5a","e68c78","fdb995","e4c5bc","cad0e3","cddff1","cfe6f8","d0edff"]#["710000","fad39d","e8f5fd","003470"]

# colorlist.reverse()
# cmap=get_continuous_cmap(colorlist) #hexlist

edge_weight=chosenWeight

#Choose which times to plot
times=[55800]#[600, 1400, 55000, 55800, 108800,110200]

# The graph layout and junction list are the same for every plotted time, so compute
# them once instead of per iteration (kamada_kawai_layout is expensive).
pos=nx.kamada_kawai_layout(g)
edgeList=sims[0].connectivity.edge_list
numWires=g.number_of_nodes()

for time in tqdm(times): #loop through each time

    f,[ax,cax] = plt.subplots(1,2, gridspec_kw={"width_ratios":[50,1]},frameon=False, figsize=(10, 8), dpi=300)
    canvas_width, canvas_height = f.canvas.get_width_height()
    ax.axis('off')
    cax.axis('off')

    #get weighted subgraph (edges = junctions with non-zero conductance at `time`)
    G=getWeightedGraph(adjMat,jC,jV,jF,edgeList,numWires,time,"conductance")

    # Overwrite each edge weight with the conductance change, in G's edge order.
    # NOTE(review): assumes chosenWeight is ordered like G.edges() here (the map cell
    # saved weights in G.edges() order) -- confirm the edge sets match.
    count=0
    for u,v,d in G.edges(data=True):
        d['weight']=chosenWeight[count]
        count+=1

    edge_weights=nx.get_edge_attributes(G,'weight')
    G.remove_edges_from((e for e, w in edge_weights.items() if abs(w) <1e-7)) #remove edges with |delta Gj| < 1e-7
    edges=G.edges()
    weights=[G[u][v]['weight'] for u,v in edges]

    #transform colors to log. base=10 made explicit: it matches the other plots in this
    # notebook, and Matplotlib < 3.4 defaulted to base e here.
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=clrs.SymLogNorm(1e-7,base=10,vmin=minWeights, vmax=maxWeights))
    sm.set_array([])

    normweights=sm.norm(weights).data #apply normalization to weights
    normMin=sm.norm(minWeights)
    normMax=sm.norm(maxWeights)

    #Draw original network
    h=nx.draw_networkx_nodes(g,pos=pos,node_color='grey',node_size=20,ax=ax)

    thisTarget=int(targetValsTime[time])
    x=(alltargets==thisTarget)
    #Draw subgraph edges (Gj)
    h2=nx.draw_networkx_edges(G,pos=pos,ax=ax,edge_color=normweights,width=2,edge_cmap=cmap,edge_vmin=normMin,edge_vmax=normMax,alpha=0.6)

    #Change colors of electrode nodes: every source black, then the activeSources entry green
    nx.draw_networkx_nodes(g,pos=pos,nodelist=sources,node_color='k',node_size=100,ax=ax)
    nx.draw_networkx_nodes(g,pos=pos,nodelist=list(activeSources[np.where(x)[0][0]+1]),node_color='g',node_size=100,ax=ax)

    isTesting = time in testTimes
    if isTesting:
        nx.draw_networkx_nodes(g,pos=pos,nodelist=list(drain_pool),node_color='r',node_size=100,ax=ax)
    else:
        # Drain for this target red, the others black. List concatenation replaces
        # np.hstack(np.array([a, b])), which is ragged for unequal-length slices and
        # raises ValueError on NumPy >= 1.24. For thisTarget == 1 this is drain_pool[1:].
        otherDrains=list(drain_pool[:thisTarget-1])+list(drain_pool[thisTarget:])
        nx.draw_networkx_nodes(g,pos=pos,nodelist=[drain_pool[thisTarget-1]],node_color='r',node_size=100,ax=ax)
        nx.draw_networkx_nodes(g,pos=pos,nodelist=otherDrains,node_color='k',node_size=100,ax=ax)

    timeVals='Testing' if isTesting else 'Training'

    ax.set_title('T = ' + str(time)+ ' | ' + timeVals + ' | Target ' + str(targetValsTime[time]))

    cbar = plt.colorbar(sm, ax=ax,
                        fraction = 0.05, label=animationType)

    # SAVE PLOTS
    plt.savefig(saveFig+str(nwWires)+'nw_'+str(nwJunctions)+'nj_nback_WorkingMemory_7drains_3x3_'+b+'_t'+str(time)+timeVals+'_topological_reconfig_'+reinf+'.pdf',format='pdf')