## Task 1 - Binary Classification Analysis

Run Supervised Learning + Reinforcement on Nanowire Networks for classification. 

Author: Alon Loeffler

Required Files/Folders: learning_functions.py | edamame | asn_nw_00350_nj_01350_seed_1581_avl_10.00_disp_01.00_lx_50.00_ly_50.00.mat

In [None]:
#IMPORTS:
#Append path to Ruomin's Edamame Package (Nanowire Simulations)
import sys
import os

sys.path.append('../../edamame') #point to edamame locally

#Choose which nw:
nwChoice=700 #350 or 700

if nwChoice == 350:
    fileName='/PhysicalReinforcementLearning/generate_networks/Network Data/asn_nw_00350_nj_01350_seed_1581_avl_10.00_disp_01.00_lx_50.00_ly_50.00'
elif nwChoice == 700:
    fileName='/PhysicalReinforcementLearning/generate_networks/Network Data/asn_nw_00698_nj_02582_seed_002_avl_10.00_disp_01.00_lx_75.00_ly_75.00'

#point to network data path ^ 

saveFig='../../Data/Figures/'  #Define Figure Save location locally 
dataLoc='../../Data/' #Define Data Save Location locally 

#import edamame (neuromorphic nanowire python package by Ruomin Zhu)
from edamame import * 
import numpy as np
import matplotlib.pyplot as plt
import copy
from scipy.io import loadmat, savemat
import networkx as nx
from tqdm.notebook import tqdm_notebook as tqdm
from IPython.core.debugger import set_trace

import pickle 
import _pickle as cPickle
import gzip

from learning_functions import genGridNW,point_on_line,dist,getWeightedGraph
from learning_functions import calc_cost,setupStimulus,setupSourcesOnly,runTesting,getNWState,calcOutputs

In [None]:
#Load Network Connectivity + Build Network for analysis
nw=loadmat(fileName)
g=nx.from_numpy_array(nw['adj_matrix'])
adjMat=nw['adj_matrix']

num_drain_training=1
num_drain_testing=2
numSources=4
numSamples=40
nbacks         = [num_drain_testing]
nback=nbacks[0]
numDT=200

connectivity,sources,drain_pool,nwSize,nwJunctions=buildNetworks(fileName,True,num_drain_testing,numSources)

In [None]:
#Parameters
nwJunctions=nw['number_of_junctions'][0][0]
nwWires=nw['number_of_wires'][0][0]
onAmp = 0.3
onAmpTest = 0.1
numDT=200
nbacks=[2]
sampleTimes=numDT*(nbacks[0]+1)
signalLen=numDT*len(alltargets)

In [None]:
#Convert training and testing epochs to times:
nback=nbacks[0]
tmp=numDT*nback
A=np.array(range(0,signalLen+1,int(tmp+numDT))) #training 
B=np.array(range(0+tmp,signalLen+1,int(tmp+numDT)))# testing
c=[]
for k in range(len(B)):
    c.append(range(A[k],B[k]))
testTimes=np.array(np.hstack(c))
l=range(signalLen)
trainTimes=np.array(np.hstack([m for m in l if m not in testTimes]))

#test and train samples:
testSamples=[]
for i in range(-1,len(alltargets),nback+1):
    if i > 0:
        testSamples.append(i)
testSamples=np.array(testSamples)
trainSamples = np.setdiff1d(np.array(range(len(alltargets))),testSamples)

#Target data in time
targets=np.array(alltargets)[testSamples]

In [None]:
#Load Simulation Data

#Auxiliary Data
filamentVal = '0p5'
reinforcement='no'
loadName='data_'+str(nwWires)+'nw_'+str(nwJunctions)+'nj_2x2_wThresh_Vtrn'+str(onAmp)+'_Vtst'+str(onAmpTest)+'_T200_simple_nback'+str(nbacks[0])+'_filament'+filamentVal+'_DataNoSim_'+reinforcement+'Reinforcement'
with open(dataLoc+'Sim Results/'+loadName+'.pkl', 'rb') as f:
    dataNoSim=pickle.load(f)[0]
    
t=dataNoSim['current']
thresholds=dataNoSim['threshold']
accuracy=dataNoSim['accuracy']
alltargets=dataNoSim['targets']
trainingOrder=dataNoSim['training order']
trainingLabels=dataNoSim['training labels']
testingLabels=dataNoSim['testing labels']
allLabels=dataNoSim['all labels']
numTestingSamples=dataNoSim['num test samples']
params=dataNoSim['Parameters']
    
#Sim data
loadName='data_'+str(nwWires)+'nw_'+str(nwJunctions)+'nj_2x2_wThresh_Vtrn'+str(onAmp)+'_Vtst'+str(onAmpTest)+'_T200_simple_nback'+str(nbacks[0])+'_filament'+filamentVal+'_DataSimOnly_'+reinforcement+'Reinforcement'

with open(dataLoc+'Sim Results/'+loadName+'.pkl', 'rb') as f:
    sims=pickle.load(f)
sims=sims[0]['sim']

In [None]:
#Electrode current and voltage for training and testing

# jV=[]
# jF=[]
# jC=[]
# activeSources=[]
elecItrain=[];elecItest=[];inputSignal=[]
drain1=[];drain2=[];voltage=[]
t=0
nback=nbacks[0]
for i in range(len(sims)):
#     jV.append(sim[i].junctionVoltage)   
#     jF.append(sim[i].filamentState) #Junction Filament Negative
#     jC.append(sim[i].junctionConductance)
#     activeSources.append(sim[i].sources)
    voltage.append(sims[i].wireVoltage[:,drain_pool])
    if sims[i].electrodeCurrent.shape[1] == 5:
        tmp=sims[i].electrodeCurrent[:,0]
        elecItrain.append(tmp)
        if alltargets[i][0]==0:
            drain1.append(tmp)
            drain2.append([0]*tmp)
        else:
            drain2.append(tmp)
            drain1.append([0]*tmp)
    else: #only test:
#         if len(elecItest[i]) < 600:
        elecItest.append(sims[i].electrodeCurrent[:,:2])
        drain1.append(elecItest[t][:,0])
        drain2.append(elecItest[t][:,1])
        t+=1
    
    
drain1=np.hstack(np.array(drain1))
drain2=np.hstack(np.array(drain2))
voltage=np.vstack(np.array(voltage))


In [None]:
#Threshold Figure
fig=plt.figure(figsize=(15,8))
p1=plt.plot(np.array(thresholds[::3])[:,0],c='b')
p2=plt.plot(np.array(thresholds[::3])[:,1],c='r')
plt.xlabel('Samples')
plt.ylabel('Voltage Threshold')
plt.legend([p1[0],p2[0]],['Drain 1','Drain 2'],loc='lower right',title='Current')
plt.ylim([-5e-7,1.05e-5])
# if patternSize==2:
#     fig.savefig(saveFig+'698nw_2582nj_2x2_simple_nback2_filament'+filamentVal+'_threshValues.pdf',format='pdf',dpi=300)
# else:
fig.savefig(saveFig+loadName+'_threshValues.pdf',format='pdf',dpi=300)

In [None]:
#Current Figure
plt.figure(figsize=(15,10))
plt.plot(drain1)
plt.plot(drain2)
ax=plt.gca()


In [None]:
#Convert Data to Pandas dataframe
import pandas as pd
nback=nbacks[0]
thisTrainInputs=[[] for i in range(len(trainingLabels))]
for trial in tqdm(range(len([trainingLabels]))): #for each trial
    #bin training samples
    thisLabels=np.array(allLabels)
    tmp=trainingInputs[trainingOrder]
    
    #plot training inputs that are to be tested:
    count = 0
    for inpt in tmp:
        thisTrainInputs[count]=(inpt.reshape(-1,1))
        count+=1
    trainlabel=trainingLabels#[:-4]
    trainlabel=np.array(trainlabel).reshape(-1, nback)
    thisTestTime = testTimes
    targetTimes = signal_expand(targets,numDT)
    
    currents=np.array(t)
    thisTarget=targets
    testlabel=testingLabels[:numTestingSamples]
        
    if nback == 2:    
        #seperate to classes 

        c1c1notest=[0]
        c1c2notest=[0]
        c2c1notest=[0]
        c2c2notest=[0]
        count1=0
        count2=0
        c2count=[]
        c1count=[]
        classOrder=[]
        targetClass=[]
        for label in trainlabel:
            if np.all(label == [0,0]):
                classOrder.append('c1-c1')
                targetClass.append(1)
                count1+=1
            elif np.all(label == [0,1]):
                classOrder.append('c1-c2')
                targetClass.append(1)
                count1+=1
            elif np.all(label == [1,1]):
                classOrder.append('c2-c2')
                targetClass.append(2)
                count2+=1
            elif np.all(label == [1,0]):
                classOrder.append('c2-c1')
                targetClass.append(2)
                count2+=1
            c2count.append(count2)
            c1count.append(count1)
       #count how many times the network has seen c1-c1,c1-c2/c2-c1,c2-c2 at each test point:
        for i in range(len(trainlabel)):
            a=np.sum(np.array_equiv(trainlabel[i],np.array([0,0])))
            b=np.sum(np.array_equiv(trainlabel[i],np.array([0,1])))
            c=np.sum(np.array_equiv(trainlabel[i],np.array([1,0])))
            d=np.sum(np.array_equiv(trainlabel[i],np.array([1,1])))

            c1c1notest.append(c1c1notest[i]+a)
            c1c2notest.append(c1c2notest[i]+b)
            c2c1notest.append(c2c1notest[i]+c)
            c2c2notest.append(c2c2notest[i]+d)


        #count how many times the network has seen c1/c2 at each test point:
        count1notest=[]
        count2notest=[]
        temp1=trainlabel.reshape(-1)
        for i in range(len(temp1)):
            if i == 0:
                count1notest.append(0)
                count2notest.append(0)
            if i % nback == 0 and i > 0:
                count1notest.append(np.sum(temp1[:i]==0))
                count2notest.append(np.sum(temp1[:i]==1))

        count1plustest=[]
        count2plustest=[]
        temp1=thisLabels
        n=nback+1
        for i in range(len(temp1)):
            if i == 0:
                count1plustest.append(0)
                count2plustest.append(0)
            if i % n == 0 and i > 0: 
                count1plustest.append(np.sum(temp1[:i]==0))
                count2plustest.append(np.sum(temp1[:i]==1))            

        cumsumdiffnotest=np.array(count1notest)-np.array(count2notest)
        cumsumdifftest=np.array(count1plustest)-np.array(count2plustest)    

#         if len(accuracy)!= len(targetClass):
#             temp=np.append(accuracy,np.nan)
#         set_trace()
        df=pd.DataFrame({'Class':np.array(targetClass),'Order':classOrder,'Trial':trial,
                         'Accuracy':accuracy,'C1NoTest':count1notest,
                         'C2NoTest':count2notest,'CountDiff':np.array(c2count)-np.array(c1count),
                         'C1+Test':count1plustest,'C2+Test':count2plustest,'C1C1count':c1c1notest[1:],
                         'C1C2count':c1c2notest[1:], 'C2C1count':c2c1notest[1:],'C2C2count':c2c2notest[1:],
                        })
   
    
    
    counts=[]
    for i in range(len(df)):
        counts.append(i)
        
    d = {'OrderVal':counts,'Train Time Start':((np.array(counts)+1)*numDT*(nback+1))-numDT*(nback+1),
                     'Train Time End':((np.array(counts)+1)*numDT*(nback+1))-numDT-1,
                     'Test Time Start':((np.array(counts)+1)*numDT*(nback+1))-numDT,
                     'Test Time End':((np.array(counts)+1))*numDT*(nback+1)-1,}
    df = df.join(pd.DataFrame(d, index=df.index))


In [None]:
#Display dataframe:
df.head()

In [None]:
#Calculate accuracies:
dfAccuracy = df.groupby('Order').agg([np.mean,np.std])
dfAccuracy= dfAccuracy['Accuracy']

In [None]:
#Generate Figure 2: 
#Current Strength vs Accuracy:

targetsNew=targets[:,0]
plt.rcParams['pdf.fonttype'] = 42


fig,ax = plt.subplots(1,1, figsize=(12, 8), dpi=300)
x=(drain1)
p1=ax.plot(x,'b',alpha=0.5)
p2=ax.plot((drain2),'r',alpha=0.5)
plt.legend([p1[0],p2[0]],['Drain 1','Drain 2'],loc='lower right',title='Current')


if nback == 1:
    xlimmax=830
    vlinesStep=40
    vlinesStart=30
    dotpos=2

if nback == 2:
#     xlimmax=1230
    vlinesStep=600
    vlinesStart=500
    dotpos=3
elif nback == 3:
    xlimmax=1630
    vlinesStep=80
    vlinesStart=70
    dotpos=4
    
colrs=[]
for val in targetsNew:
    if val < 1:
        colrs.append('C1')
    else:
        colrs.append('C2')    

testColors=[]
for i in colrs:
    if i == 'C1':
        testColors.append('b')
    else:
        testColors.append('r')
        
ax.set_xlabel('Time')
ax.set_ylabel('Current')
ax.set_ylim([-5e-7,1.05e-5])

# ax.vlines(range(vlinesStart,len(x),vlinesStep),np.min(drain2),np.max(drain2),linestyle='dashed',color=testColors,alpha=0.5)

ax2=ax.twinx().twiny()
ax2.scatter(list(range(0,len(targetsNew),1)),df['Accuracy'],c=testColors)
ax2.set_ylim([-0.05,1.05])

# ax2.set_xticklabels(np.array(range(0,121,20)))#np.round(np.array(range(vlinesStart,len(x)-20,vlinesStep))/39)-1)

# testTimeVals=np.array(range(len(accuracy)))*dotpos+dotpos


# ax2.scatter(x=range(vlinesStart,len(x)-20,vlinesStep),y=accuracy,c=testColors)

ax2.set_xlabel('Sample Num')
ax2.set_ylabel('Accuracy')
# sns.scatterplot(x=testTimeVals,y='Accuracy',data=accuracy,hue=colrs[:-1],palette=['r','b'],ax=ax2)
# ax2.legend(title='Accuracy',loc='lower left')
# ax2.set_xlim([-20/20,(xlimmax+10)/20])
# ax2.set_ylabel('Accuracy')

# ax3=ax2.twinx()
# p1=ax3.plot(testTimeVals[:-1],count1plustest[:-1],color='c',marker='x')
# p2=ax3.plot(testTimeVals[:-1],count2plustest[:-1],color='m',marker='x')
# ax3.set_xlim([-20/20,1240/20])

# plt.vlines(range(3,testTimeVals.values[-1],3),0,1,linestyle='dashed',color='k',)
# if patternSize==2:
#     fig.savefig(saveFig+saveName+'.pdf',format='pdf',dpi=300)
# else:
fig.savefig(saveFig+'698w_2582j_current_time_accuracy_nback2_2x2_'+filamentVal+'_'+reinforcement+'_reinforcement.pdf',format='pdf',dpi=300)



In [None]:
#Generate Figure 2 Insets:
tstart=4800
tstartepoch=int((tstart/numDT)/3)
tend=7200
tendepoch=int((tend/numDT)/3)
dur=tend-tstart

plt.plot(drain1[tstart:tend],label='Drain 1',c='b')
plt.plot(drain2[tstart:tend],label='Drain 2',c='r')
thresh1=[np.array(thresholds)[(tstartepoch),0]]*int(len(drain1[tstart:tend])/2)
thresh2=[np.array(thresholds)[(tendepoch),0]]*int(len(drain1[tstart:tend])/2)
thresh=np.hstack((thresh1,thresh2))
plt.plot(thresh,'--',c='b')
thresh1=[np.array(thresholds)[(tstartepoch),1]]*int(len(drain2[tstart:tend])/2)
thresh2=[np.array(thresholds)[(tendepoch),1]]*int(len(drain2[tstart:tend])/2)
thresh=np.hstack((thresh1,thresh2))
plt.plot(thresh,'--',c='r')

ax=plt.gca()
ax.set_xticks(range(0,2400+1,400))
ax.set_xticklabels(range(tstart,tend+1,400))
ax.set_xlabel('Timesteps')
ax.set_ylabel('Current (A)')
plt.legend()
plt.savefig(saveFig+'698w_2582j_current_time_oneEpoch_nback2_2x2_'+filamentVal+'_'+reinforcement+'reinforcement.pdf',format='pdf')

In [None]:
#Generate Supp Figure 1:
#Voltage vs Accuracy:
import seaborn as sns
plt.rcParams['pdf.fonttype'] = 42


fig,ax = plt.subplots(1,1, figsize=(15, 8), dpi=300)
x=(voltage[:,0])
p1=ax.plot(x,'b',alpha=0.5)
p2=ax.plot((voltage[:,1]),'r',alpha=0.5)
plt.legend([p1[0],p2[0]],['Drain 1','Drain 2'],loc='lower right',title='Voltage')


if nback == 1:
    xlimmax=830
    vlinesStep=40
    vlinesStart=30
    dotpos=2

if nback == 2:
#     xlimmax=1230
    vlinesStep=600
    vlinesStart=500
    dotpos=3
elif nback == 3:
    xlimmax=1630
    vlinesStep=80
    vlinesStart=70
    dotpos=4
    
colrs=[]
for val in targetsNew:
    if val < 1:
        colrs.append('C1')
    else:
        colrs.append('C2')    

testColors=[]
for i in colrs:
    if i == 'C1':
        testColors.append('b')
    else:
        testColors.append('r')
        
ax.set_xlabel('Time')
ax.set_ylabel('Voltage')
# ax.set_xlim([-600,24400])

# ax.vlines(range(vlinesStart,len(x),vlinesStep),np.min(voltage[:,1]),np.max(voltage[:,1]),linestyle='dashed',color=testColors,alpha=0.5)

# ax2=ax.twinx().twiny()
# ax2.scatter(list(range(0,len(targetsNew),1)),df['Accuracy'][:-1],c=testColors)
# ax2.set_ylim([-0.05,1.05])

# ax2.set_xticklabels(np.array(range(0,121,20)))#np.round(np.array(range(vlinesStart,len(x)-20,vlinesStep))/39)-1)

# testTimeVals=np.array(range(len(accuracy)))*dotpos+dotpos


# ax2.scatter(x=range(vlinesStart,len(x)-20,vlinesStep),y=accuracy,c=testColors)

ax2.set_xlabel('Sample Num')
ax2.set_ylabel('Accuracy')
# sns.scatterplot(x=testTimeVals,y='Accuracy',data=accuracy,hue=colrs[:-1],palette=['r','b'],ax=ax2)
# ax2.legend(title='Accuracy',loc='lower left')
# ax2.set_xlim([-20/20,(xlimmax+10)/20])
# ax2.set_ylabel('Accuracy')

# ax3=ax2.twinx()
# p1=ax3.plot(testTimeVals[:-1],count1plustest[:-1],color='c',marker='x')
# p2=ax3.plot(testTimeVals[:-1],count2plustest[:-1],color='m',marker='x')
# ax3.set_xlim([-20/20,1240/20])

# plt.vlines(range(3,testTimeVals.values[-1],3),0,1,linestyle='dashed',color='k',)
# if patternSize==2:
#     fig.savefig(saveFig+saveName+'.pdf',format='pdf',dpi=300)
# else:
fig.savefig(saveFig+'698w_2582j_voltage_time_accuracy_nback2_2x2_'+filamentVal+'_'+reinforcement+'reinforcement.pdf',format='pdf',dpi=300)


In [None]:
#Voltage Insets:
tstart=600
tend=1800
dur=tend-tstart
plt.plot(voltage[tstart:tend,0],label='Drain 1')
plt.plot(voltage[tstart:tend,1],label='Drain 2')
plt.hlines(np.array(thresholds)[:,0],0,dur,colors='black',linestyles='dashed',label='Threshold')
ax=plt.gca()
# ax.set_xticks(range(0,10))
ax.set_xticklabels(range(tstart-200,tend+200,int(dur/6)))
ax.set_xlabel('Timesteps')
ax.set_ylabel('Voltage (V)')
plt.legend()
plt.savefig(saveFig+'698w_2582j_voltage_time_oneEpoch_nback2_2x2_'+filamentVal+'_'+reinforcement+'reinforcement.pdf',format='pdf')

## Network Connectivity

In [1]:
sys.path.append('/import/silo2/aloe8475/Documents/CODE/aux_functions/') #point to edamame locally

from matplotlib import animation
from matplotlib import colorbar as clbr
from PIL import Image

from matplotlib import rcParams
from create_colormap import get_continuous_cmap,hex_to_rgb,rgb_to_dec

#Functions:
def getWeightedGraph(adjMat,jC,jV,jF,edgeList,numWires,time,edge_mode):#, this_TimeStamp = 0):
#     edgeList = network['edge_list']
    adjMat = np.zeros((numWires, numWires))
#     set_trace()
    if edge_mode=='conductance':
        adjMat[edgeList[:,0], edgeList[:,1]] = jC[time,:]#network.junctionSwitch[this_TimeStamp,:]
        adjMat[edgeList[:,1], edgeList[:,0]] = jC[time,:]#network.junctionSwitch[this_TimeStamp,:]
        maxWeights=np.max(jC)

    elif edge_mode=='current':
        adjMat[edgeList[:,0], edgeList[:,1]] = jV[time,:]*jC[time,:]
        adjMat[edgeList[:,1], edgeList[:,0]] = jV[time,:]*jC[time,:]
        maxWeights=np.max(jV*jC)
        
        
    elif edge_mode=='filament':
        adjMat[edgeList[:,0], edgeList[:,1]] = jF[time,:]
        adjMat[edgeList[:,1], edgeList[:,0]] = jF[time,:]
        
    WeightedGraph = nx.from_numpy_array(adjMat)
    WeightedGraph=nx.DiGraph.to_undirected(WeightedGraph)
    
    return WeightedGraph

def image_draw_voltage(time,cmap,maxWeights,minWeights):
    ax.clear() 
    pos=nx.kamada_kawai_layout(g)
           
    #ALON'S CODE:    
    node_weight=sim.wireVoltage[time]

    pos=nx.kamada_kawai_layout(g)
    nodeList=g.nodes
    numWires=g.number_of_nodes()

    h=nx.draw_networkx_nodes(g,pos=pos,node_color=node_weight,cmap=cmap,node_size=50,ax=ax)

    h2=nx.draw_networkx_edges(G,pos=pos,ax=ax,edge_color='grey')
    # nx.draw_networkx_nodes(g,pos=pos,nodelist=sources,node_color='g',node_size=50,ax=ax)
    nx.draw_networkx_nodes(g,pos=pos,nodelist=[drains[1]],node_color='r',node_size=10,ax=ax)
    nx.draw_networkx_nodes(g,pos=pos,nodelist=[drains[0]],node_color='b',node_size=10,ax=ax)
    ax.set_title(str(time))
    
import matplotlib.colors as clrs

def image_draw_current_filament(time,edge_mode,edge_weight,cmap,maxWeights,minWeights):
    ax.clear() 
#     pos=nx.kamada_kawai_layout(g)
        
    #RUOMIN'S CODE:
#     draw_graph(sim,time,edge_mode=edge_mode,edge_weight=edge_weight,edge_colorbar=False,node_size=30,figsize=(10,8),edge_cmap=cmap,with_labels=False,norm=True,ax=ax,maxWeight=maxWeights,minWeight=minWeights)

    #ALON'S CODE:    
    
#     edge_mode=animationType
#     edgeList=sim.connectivity.edge_list
#     numWires=g.number_of_nodes()
#     G=getWeightedGraph(sim,edgeList,numWires,time,edge_mode)
    
    pos=nx.kamada_kawai_layout(g)
    edgeList=sims[0].connectivity.edge_list
    numWires=g.number_of_nodes()
    G=getWeightedGraph(adjMat,jC,jV,jF,edgeList,numWires,time,animationType)#, this_TimeStamp = 0):

    edge_weights=nx.get_edge_attributes(G,'weight')
#     G.remove_edges_from((e for e, w in edge_weights.items() if w <1e-)) 
    
    edges=G.edges()
    weights=[G[u][v]['weight'] for u,v in edges]
    
    if animationType=='filament':
        logVal=1e-7
    else:
        logVal=1e-7
    Q
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=clrs.SymLogNorm(logVal,base=10,vmin=minWeights, vmax=maxWeights))
    sm.set_array([])
    normweights=sm.norm(weights).data
    if len(normweights)==0:
        normweights=np.zeros(g.number_of_edges())
    
    normMin=sm.norm(minWeights)
    normMax=sm.norm(maxWeights)
    h=nx.draw_networkx_nodes(g,pos=pos,node_color='grey',node_size=20,ax=ax)
    
    thisTarget=int(targetValsTime[time])
    x=(alltargets==thisTarget)
    #     h.set_zorder(1)
    h2=nx.draw_networkx_edges(G,pos=pos,ax=ax,edge_color=normweights,width=2,edge_cmap=cmap,edge_vmin=normMin,edge_vmax=normMax,alpha=0.7)
    if thisTarget == 1:
        nx.draw_networkx_nodes(g,pos=pos,nodelist=sources,node_color='k',node_size=100,ax=ax)
        nx.draw_networkx_nodes(g,pos=pos,nodelist=list(activeSources[np.where(x)[0][0]+1]),node_color='g',node_size=100,ax=ax)

        if time in testTimes:
            nx.draw_networkx_nodes(g,pos=pos,nodelist=drain_pool,node_color='r',node_size=100,ax=ax)
        else:
            nx.draw_networkx_nodes(g,pos=pos,nodelist=[drain_pool[0]],node_color='r',node_size=100,ax=ax)
            nx.draw_networkx_nodes(g,pos=pos,nodelist=list(drain_pool[1:]),node_color='k',node_size=100,ax=ax)
    else:
        nx.draw_networkx_nodes(g,pos=pos,nodelist=sources,node_color='k',node_size=100,ax=ax)
        nx.draw_networkx_nodes(g,pos=pos,nodelist=list(activeSources[np.where(x)[0][0]+1]),node_color='g',node_size=100,ax=ax) #need to write this better

        if time in testTimes:
            nx.draw_networkx_nodes(g,pos=pos,nodelist=drain_pool,node_color='r',node_size=100,ax=ax)
        else:
            nx.draw_networkx_nodes(g,pos=pos,nodelist=[drain_pool[thisTarget-1]],node_color='r',node_size=100,ax=ax)
            nx.draw_networkx_nodes(g,pos=pos,nodelist=list(np.hstack(np.array([drain_pool[:thisTarget-1],drain_pool[thisTarget:]]))),node_color='k',node_size=100,ax=ax)


    
    # shift position a little bit
    shift = [-0.1, 0]
    shifted_pos ={node: node_pos + shift for node, node_pos in pos.items()}
#     set_trace()
    # Just some text to print in addition to node ids  
    
    labels = {}
    for node in g.nodes():
        if node in drain_pool:
        #set the node name as the key and the label as its value 
            labels[node] = node
#     labels[1] = 'Drain 1'
#     labels[2] = 'Drain 2'
    nx.draw_networkx_labels(g, shifted_pos, labels=labels, horizontalalignment="left")
    if time in testTimes:
        timeVals='Testing'
    else:
        timeVals='Training'
    
    ax.set_title('T = ' + str(time)+ ' | ' + timeVals + ' | Target ' + str(targetValsTime[time]))

    
#New loadmat:
import scipy.io as spio

def loadmatNew(filename):
    '''
    this function should be called instead of direct spio.loadmat
    as it cures the problem of not properly recovering python dictionaries
    from mat files. It calls the function check keys to cure all entries
    which are still mat-objects
    '''
    data = spio.loadmat(filename, struct_as_record=False, squeeze_me=True)
    return _check_keys(data)

def _check_keys(dict):
    '''
    checks if entries in dictionary are mat-objects. If yes
    todict is called to change them to nested dictionaries
    '''
    for key in dict:
        if isinstance(dict[key], spio.matlab.mio5_params.mat_struct):
            dict[key] = _todict(dict[key])
    return dict        

def _todict(matobj):
    '''
    A recursive function which constructs from matobjects nested dictionaries
    '''
    dict = {}
    for strg in matobj._fieldnames:
        elem = matobj.__dict__[strg]
        if isinstance(elem, spio.matlab.mio5_params.mat_struct):
            dict[strg] = _todict(elem)
        else:
            dict[strg] = elem
    return dict


NameError: name 'sys' is not defined

In [None]:
#combine all sim voltages, conductance and filament together
jV=[]
jF=[]
jC=[]
activeSources=[]
for i in range(len(sims)):
    jV.append(sims[i].junctionVoltage)   
    jF.append(sims[i].filamentState) #Junction Filament Negative
    jC.append(sims[i].junctionConductance)
    activeSources.append(sims[i].sources)
    
#Reshape data
jV=np.array(jV).reshape(len(jV)*numDT,nwJunctions) 
jF=np.array(jF).reshape(len(jF)*numDT,nwJunctions)
jC=np.array(jC).reshape(len(jC)*numDT,nwJunctions)


In [None]:
#Times for network connectivity analysis:
targetValsTime=np.array(dataNoSim['targets']).reshape(-1)
# alltargets=np.array(dataNoSim['targets']).reshape(-1)
counts=[]
for i in range(numSamples):
    counts.append(i)
    
    
TestTimeStart=((np.array(counts)+1)*numDT*(nback+1))-numDT
TestTimeEnd=((np.array(counts)+1))*numDT*(nback+1)-1
TrainTimeStart=((np.array(counts)+1)*numDT*(nback+1))-numDT*(nback+1)
TrainTimeEnd=((np.array(counts)+1)*numDT*(nback+1))-numDT-1


vals=[]
for i in range(len(counts)):
    vals.append(np.arange(TestTimeStart[i],TestTimeEnd[i]))
testTimes=np.array(vals).reshape(-1)


xlims=[TrainTimeStart[0],TrainTimeEnd[-1]]

In [2]:
#Generate Network Connectivity Figures (Supp Figure 2)
# plt.close('all')
rcParams['animation.embed_limit'] = 2**64


edge_mode='custom'
animationType = 'conductance'


if animationType =='current':
    minWeights=0
    jI=np.array(jV*jC)
    edge_weight = jI
    maxWeights=np.max(jI)
    cmap=plt.cm.magma_r
elif animationType== 'conductance':
    minWeights=0
    maxWeights=np.max(jC)
#     cmap=plt.cm.coolwarm
    colorlist=  ["a0051f","b8323d","cf5f5a","e68c78","fdb995","e4c5bc","cad0e3","cddff1","cfe6f8","d0edff"]#["710000","fad39d","e8f5fd","003470"]

    colorlist.reverse()
    cmap=get_continuous_cmap(colorlist) #hexlist

    edge_weight=jC
elif animationType== 'filament':
    minWeights=0
    maxWeights=np.max(jF)
    colorlist=["950700","c93a28","fd6c50","fb986a","f7c59f","efd6b8","e6e6d0"]


    colorlist.reverse()
    cmap=get_continuous_cmap(colorlist) #hexlist
#     cmap=plt.cm.Pastel1
    edge_weight=jF
    #binary threshold (only show above or below a certain value)

times=[4800, 5300, 5400, 5900]

for time in tqdm(times):
    
    f,[ax,cax] = plt.subplots(1,2, gridspec_kw={"width_ratios":[50,1]},frameon=False, figsize=(10, 8), dpi=300)
    canvas_width, canvas_height = f.canvas.get_width_height()
    ax.axis('off')
    cax.axis('off')

    # clrplt=np.linspace(minWeights,maxWeights,1000).reshape(1000,1)
    # img=cax.imshow(clrplt,cmap=cmap,norm=clrs.LogNorm())
    # cax.set_visible(False)
    # formatter = LogFormatter(10, labelOnlyBase=True) 
    # norm=clrs.SymLogNorm(1,base=10,vmin=minWeights,vmax=maxWeights)
    # cb1=clbr.ColorbarBase(cax,cmap=cmap,norm=norm)


    pos=nx.kamada_kawai_layout(g)
    edgeList=sims[0].connectivity.edge_list
    numWires=g.number_of_nodes()
    G=getWeightedGraph(adjMat,jC,jV,jF,edgeList,numWires,time,animationType)#, this_TimeStamp = 0):

    # edge_weights=nx.get_edge_attributes(G,'weight')
    # G.remove_edges_from((e for e, w in edge_weights.items() if w <1e-10)) 

    edges=G.edges()
    weights=[G[u][v]['weight'] for u,v in edges]

    # normweights = (np.array(weights)-minWeights) / (maxWeights - minWeights)
    # weights = [cmap(color_normal(e)) for e in weights]
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=clrs.SymLogNorm(1e-7,base=10,vmin=minWeights, vmax=maxWeights))
    sm.set_array([])

    normweights=sm.norm(weights).data
    normMin=sm.norm(minWeights)
    normMax=sm.norm(maxWeights)

    h=nx.draw_networkx_nodes(g,pos=pos,node_color='grey',node_size=20,ax=ax)

    # h2=nx.draw_networkx_edges(G,pos=pos,ax=ax,edge_color=normweights,width=2,edge_cmap=cmap,edge_vmin=normMin,edge_vmax=normMax)

    thisTarget=int(targetValsTime[time])
#     x=(targetValsTime==thisTarget)
    #     h.set_zorder(1)
    h2=nx.draw_networkx_edges(G,pos=pos,ax=ax,edge_color=normweights,width=2,edge_cmap=cmap,edge_vmin=normMin,edge_vmax=normMax,alpha=0.6)
    if thisTarget == 1:
        nx.draw_networkx_nodes(g,pos=pos,nodelist=sources,node_color='k',node_size=100,ax=ax)
        nx.draw_networkx_nodes(g,pos=pos,nodelist=list(activeSources[thisTarget]),node_color='g',node_size=100,ax=ax)

        if time in testTimes:
            nx.draw_networkx_nodes(g,pos=pos,nodelist=drain_pool,node_color='r',node_size=100,ax=ax)
        else:
            nx.draw_networkx_nodes(g,pos=pos,nodelist=[drain_pool[0]],node_color='r',node_size=100,ax=ax)
            nx.draw_networkx_nodes(g,pos=pos,nodelist=list(drain_pool[1:]),node_color='k',node_size=100,ax=ax)
    else:
        nx.draw_networkx_nodes(g,pos=pos,nodelist=sources,node_color='k',node_size=100,ax=ax)

    #     set_trace()

        nx.draw_networkx_nodes(g,pos=pos,nodelist=list(activeSources[thisTarget]),node_color='g',node_size=100,ax=ax) #need to write this better

        if time in testTimes:
            nx.draw_networkx_nodes(g,pos=pos,nodelist=drain_pool,node_color='r',node_size=100,ax=ax)
        else:
            nx.draw_networkx_nodes(g,pos=pos,nodelist=[drain_pool[thisTarget-1]],node_color='r',node_size=100,ax=ax)
            nx.draw_networkx_nodes(g,pos=pos,nodelist=list(np.hstack(np.array([drain_pool[:thisTarget-1],drain_pool[thisTarget:]]))),node_color='k',node_size=100,ax=ax)




    if time in testTimes:
        timeVals='Testing'
    else:
        timeVals='Training'

    ax.set_title('T = ' + str(time)+ ' | ' + timeVals + ' | Target ' + str(targetValsTime[time]))

    cbar = plt.colorbar(sm, ax=ax,
                        fraction = 0.05, label=animationType)


#     print(loadReinf+'_'+b)
    plt.savefig(saveFig+str(nwWires)+'nw_'+str(nwJunctions)+'nj_nback_WorkingMemory_'+reinforcement+'Reinforcement_2drains_2x2_'+filamentVal+'_t'+str(time)+timeVals+'.pdf',format='pdf')
#     with open(dataLoc+'/Sim Results/working memory nback/'+str(nwWires)+'nw_'+str(nwJunctions)+'nj_nback_WorkingMemory_'+loadReinf+'_7drains_3x3_'+b+'_JC_t'+str(time)+timeVals+'.pkl', 'wb') as f1:
#         pickle.dump(weights, f1) 
#     print('Saved')
    
    plt.close(f)