## Task 2 - Binary Classification Analysis 

Analysis of supervised learning, with and without reinforcement, on nanowire networks for binary classification (n-back task).

Author: Alon Loeffler

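Required Files/Folders: learning_functions.py | edamame | asn_nw_00350_nj_01350_seed_1581_avl_10.00_disp_01.00_lx_50.00_ly_50.00.mat (nwChoice=350) | asn_nw_00698_nj_02582_seed_002_avl_10.00_disp_01.00_lx_75.00_ly_75.00.mat (nwChoice=700, default) | saved simulation results in `../../Data/Sim Results/` | experimental CSVs in `/PhysicalReinforcementLearning/experimental_results/Task 2/`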

In [2]:
#IMPORTS:
#Append path to Ruomin's Edamame Package (Nanowire Simulations)
import sys
import os

sys.path.append('../../edamame') #point to edamame locally

#Choose which nw:
nwChoice=700 #350 or 700

# Network data file, plus the wire / junction counts encoded in its name
# ("asn_nw_<wires>_nj_<junctions>_..."). nwSize and nwJunctions are used to build
# the result-file names in the cells below, so they must be defined here.
if nwChoice == 350:
    fileName='/PhysicalReinforcementLearning/generate_networks/Network Data/asn_nw_00350_nj_01350_seed_1581_avl_10.00_disp_01.00_lx_50.00_ly_50.00'
    nwSize=350
    nwJunctions=1350
elif nwChoice == 700:
    fileName='/PhysicalReinforcementLearning/generate_networks/Network Data/asn_nw_00698_nj_02582_seed_002_avl_10.00_disp_01.00_lx_75.00_ly_75.00'
    nwSize=698
    nwJunctions=2582
else:
    raise ValueError('nwChoice must be 350 or 700, got '+str(nwChoice))

#point to network data path ^

# Amplitudes that appear as "Vtrn<onAmp>_Vtst<onAmpTest>" in the result-file names.
# Values match the hard-coded "Vtrn0.3_Vtst0.1" names used later in this notebook;
# presumably train/test voltages -- TODO confirm against the runs being analysed.
onAmp=0.3
onAmpTest=0.1

saveFig='../../Data/Figures/'  #Define Figure Save location (not in github repo)
dataLoc='../../Data/' #Define Data Save Location (not in github repo)

#import edamame (neuromorphic nanowire python package by Ruomin Zhu)
from edamame import * 
import numpy as np
import matplotlib.pyplot as plt
import copy
from scipy.io import loadmat, savemat
import networkx as nx
from tqdm.notebook import tqdm_notebook as tqdm
from IPython.core.debugger import set_trace

import pickle 
import _pickle as cPickle
import gzip

from learning_functions import genGridNW,point_on_line,dist,getWeightedGraph
from learning_functions import calc_cost,setupStimulus,setupSourcesOnly,runTesting,getNWState,calcOutputs

### Load simulation data for a single n-back

In [None]:
nbacks = [2]  #Define which nback to load
filamentVal = '2'  #change filament vals here (0p1,0p2,0p5,2,5)
reinforcement = 'no'  #'w' or 'no'

# Result pickle for this configuration ("DataSimOnly" variant). Runs without
# reinforcement carry an extra "_noReinforcement" tag before the suffix.
runTag = '_DataSimOnly' if reinforcement == 'w' else '_noReinforcement_DataSimOnly'
loadName = ('data_' + str(nwSize) + 'nw_' + str(nwJunctions) + 'nj_3x3_wThresh_Vtrn' + str(onAmp)
            + '_Vtst' + str(onAmpTest) + '_T200_plus_cross_nback' + str(nbacks[0])
            + '_filament' + filamentVal + runTag)

with open(dataLoc + 'Sim Results/' + loadName + '.pkl', 'rb') as pklFile:
    sims = pickle.load(pklFile)[0]

In [None]:
# Summary pickle ("DataNoSim") for the same configuration as the cell above.
loadName = ('data_' + str(nwSize) + 'nw_' + str(nwJunctions) + 'nj_3x3_wThresh_Vtrn' + str(onAmp)
            + '_Vtst' + str(onAmpTest) + '_T200_plus_cross_nback' + str(nbacks[0])
            + '_filament' + filamentVal
            + ('_DataNoSim' if reinforcement == 'w' else '_noReinforcement_DataNoSim'))

with open(dataLoc + 'Sim Results/' + loadName + '.pkl', 'rb') as pklFile:
    dataNoSim = pickle.load(pklFile)[0]

#Accuracy: first 40 entries of every accuracy record, as arrays.
acc = [np.array(record[:40]) for record in dataNoSim['accuracy']]

#Current: one row per entry of dataNoSim['current'].
currentRecords = dataNoSim['current']
curr = np.zeros((len(currentRecords), len(currentRecords[0][0])))
for row, record in enumerate(currentRecords):
    firstTrace = np.array(record[0][0])
    # NOTE(review): only firstTrace[0] is stored; if record[0][0] is 1-D this is a
    # scalar that fills the whole row -- confirm this is the intended layout.
    curr[row] = np.array(firstTrace[0])

#Thresholds: first 40 entries of every threshold record, as arrays.
thresh = [np.array(record[:40]) for record in dataNoSim['threshold']]
    

In [None]:
#Statistical Analyses: mean and standard error of the thresholds across records,
#column 0 -> drain 1, column 1 -> drain 2 (as labelled in the threshold plot).
threshArr = np.array(thresh)
threshMean = threshArr.mean(axis=0)
threshSE = threshArr.std(axis=0) / np.sqrt(len(thresh))

meanThresh1, meanThresh2 = threshMean[:, 0], threshMean[:, 1]
seThresh1, seThresh2 = threshSE[:, 0], threshSE[:, 1]

# Mean and standard error of the electrode current across records, split by column.
# NOTE(review): `curr` is built with np.zeros as a 2-D array (one row per entry of
# dataNoSim['current']), so np.mean(..., axis=0) / np.std(..., axis=0) are 1-D and the
# `[:,0]` / `[:,1]` indexing below raises IndexError ("too many indices"). None of these
# four values is used by a later cell in this notebook; confirm the intended layout of
# dataNoSim['current'] before fixing.
meanCurr1=np.mean(np.array(curr),axis=0)[:,0]
meanCurr2=np.mean(np.array(curr),axis=0)[:,1]

seCurr1=np.std(np.array(curr),axis=0)[:,0]/np.sqrt(len(curr))
seCurr2=np.std(np.array(curr),axis=0)[:,1]/np.sqrt(len(curr))


In [None]:
#Plot accuracy: mean +/- SE across the entries of dataNoSim['accuracy'], per index (epoch).
accArr = np.array(acc)
mean = accArr.mean(axis=0)
se = accArr.std(axis=0) / np.sqrt(len(acc))
x = range(len(mean))  # kept as a global: the original threshold plot reuses it

plt.figure()
ax = plt.gca()
ax.plot(x, mean, c='k')
ax.fill_between(x, mean + se, mean - se, color='k', alpha=0.6)
ax.set_ylim([-0.05, 1.05])
# plt.savefig(saveFig+'698nw_2582nj_working_memory_nback_'+reinforcement+'Reinforcement_accuracy_b'+filamentVal+'.pdf',format='pdf',dpi=300)

In [None]:
#Plot thresholds: mean +/- SE of each drain's threshold per epoch.
# Own x range (instead of `x` from the accuracy cell) so this cell does not depend
# on that cell having been run first.
epochs=range(len(meanThresh1))
plt.figure()
ax=plt.gca()
ax.plot(epochs,meanThresh1,label='Drain 1',c='b')
ax.fill_between(epochs,meanThresh1-seThresh1,meanThresh1+seThresh1,color='b',alpha=0.6)
ax.plot(epochs,meanThresh2,label='Drain 2',c='r')
ax.fill_between(epochs,meanThresh2-seThresh2,meanThresh2+seThresh2,color='r',alpha=0.6)
ax.set_xlabel('Epoch')
# Raw string: '\T' is an invalid escape sequence (SyntaxWarning on Python 3.12+);
# the resulting text is unchanged.
ax.set_ylabel(r'$\Theta$')
ax.set_ylim([0.2e-5,1.05e-5])
plt.legend()
plt.savefig(saveFig+'698nw_2582nj_working_memory_nback_'+reinforcement+'Reinforcement_Theta_Threshold_b'+filamentVal+'.pdf',format='pdf',dpi=300)

In [None]:
#Plot Accuracy for each class:
# Mean accuracy over the first 40 entries, split by test label (0 / 1); each class
# mean is drawn as a single point at x=0.
# NOTE(review): `accuracy` and `testingLabels` are not defined anywhere in this notebook
# (the loaded accuracies are stored in `acc`); define them before running this cell.
# testingLabels presumably needs to be a NumPy array -- with a plain list, `== 0` is a
# single bool rather than an element-wise mask.
p1acc=np.array(accuracy[:40])[testingLabels==0]
p2acc=np.array(accuracy[:40])[testingLabels==1]

plt.plot(np.mean(p1acc),'x')
plt.plot(np.mean(p2acc),'+')
plt.ylabel('Accuracy')
plt.xticks([0])
ax=plt.gca()
# NOTE(review): tick labelled n-back 4, while the cells above load nbacks=[2] -- confirm.
ax.set_xticklabels([4])
ax.set_xlabel('n-back')
ax.set_ylim([0,1])

In [None]:
#Electrode current and voltage for training and testing
# Split each simulation's electrode current into training runs (10 columns; only
# column 0 is kept) and test runs (columns 0 and 1 = drain 1 and drain 2), then
# concatenate each drain's test currents end to end.
elecItrain=[];elecItest=[];inputSignal=[]
drain1=[];drain2=[]
t=0  # number of test runs seen
nback=nbacks[0]
for simRun in sims['sim']:
    current = simRun.electrodeCurrent
    if current.shape[1] == 10:
        elecItrain.append(current[:,0])
    else: #only test:
        testCurrent = current[:,:2]
        elecItest.append(testCurrent)
        drain1.append(testCurrent[:,0])
        drain2.append(testCurrent[:,1])
        t+=1

# Concatenate the lists directly: test runs may differ in length, and wrapping a
# ragged list in np.array() raises ValueError on NumPy >= 1.24. For equal-length
# runs the result is identical.
drain1=np.hstack(drain1)
drain2=np.hstack(drain2)

In [None]:
#Plot currents: drain 1 then drain 2 test currents on one large axes.
plt.figure(figsize=(15,10))
for drainTrace in (drain1, drain2):
    plt.plot(drainTrace)
ax=plt.gca()


In [None]:
#Plot current vs Accuracy:
#Current Strength vs Accuracy:
# NOTE(review): `targetsNew` and `accuracy` are not defined anywhere in this notebook;
# they must exist (presumably one entry per test sample) before this cell runs.
plt.rcParams['pdf.fonttype'] = 42

# Spacing of the dashed per-sample markers for each supported n-back. Resolved before
# any figure is created so an unsupported n-back fails with a clear message instead of
# a NameError on vlinesStart further down. (xlimmax and dotpos are not used in this cell.)
if nback == 1:
    xlimmax=830
    vlinesStep=40
    vlinesStart=30
    dotpos=2
elif nback == 2:
#     xlimmax=1230
    vlinesStep=600
    vlinesStart=500
    dotpos=3
elif nback == 3:
    xlimmax=1630
    vlinesStep=80
    vlinesStart=70
    dotpos=4
else:
    raise ValueError('No current/accuracy plot layout defined for nback='+str(nback))

fig,ax = plt.subplots(1,1, figsize=(12, 8), dpi=300)
p1=ax.plot(drain1,'b',alpha=0.5)
p2=ax.plot(drain2,'r',alpha=0.5)
plt.legend([p1[0],p2[0]],['Drain 1','Drain 2'],loc='lower right',title='Current')

# Class colour of each test sample: targets < 1 -> 'C1' -> blue, otherwise 'C2' -> red.
colrs=['C1' if val < 1 else 'C2' for val in targetsNew]
testColors=['b' if c == 'C1' else 'r' for c in colrs]

ax.set_xlabel('Time')
ax.set_ylabel('Current')
# ax.set_xlim([-600,24400])

ax.vlines(range(vlinesStart,len(drain1),vlinesStep),np.min(drain2),np.max(drain2),linestyle='dashed',color=testColors,alpha=0.5)

ax2=ax.twinx().twiny()
ax2.scatter(list(range(0,len(targetsNew),1)),accuracy,c=testColors)
ax2.set_ylim([-0.05,1.05])

ax2.set_xlabel('Sample Num')
ax2.set_ylabel('Accuracy')

# fig.savefig(saveFig+'698w_2582j_current_time_accuracy_nback2_2x2_b2.pdf',format='pdf',dpi=300)


In [None]:
#Reaction Time analysis: for every test run, the first sample index at which each
#drain's current reaches 0.6e-6 (NaN when it never does); NaN entries are then dropped.
def _firstCrossing(trace):
    """Return the first index where `trace` >= 0.6e-6, or np.nan if it never gets there."""
    hits = np.where(trace >= 0.6e-6)[0]
    return np.min(hits) if hits.size > 0 else np.nan

rt1 = np.array([_firstCrossing(run[:, 0]) for run in elecItest])
rt2 = np.array([_firstCrossing(run[:, 1]) for run in elecItest])
rt1 = rt1[~np.isnan(rt1)]
rt2 = rt2[~np.isnan(rt2)]

In [None]:
#Plot reaction times:
def box_plot(data, edge_color, fill_color, ax=None):
    """Draw one box plot of `data` with a uniform edge colour and a filled box.

    Parameters
    ----------
    data : array-like
        Values passed straight to ``Axes.boxplot``.
    edge_color : color
        Applied to the boxes, whiskers, fliers, means, medians and caps.
    fill_color : color
        Face colour of the boxes (drawn with alpha 0.7).
    ax : matplotlib Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``) rather than
        silently relying on a global variable named ``ax``.

    Returns
    -------
    dict
        The artists dictionary returned by ``Axes.boxplot``.
    """
    if ax is None:
        ax = plt.gca()
    bp = ax.boxplot(data, patch_artist=True)

    for element in ['boxes', 'whiskers', 'fliers', 'means', 'medians', 'caps']:
        plt.setp(bp[element], color=edge_color)

    for patch in bp['boxes']:
        patch.set(facecolor=fill_color, alpha=0.7)

    return bp
fig, ax = plt.subplots()
ax.set_ylim([0, 600])
# Drain 1 reaction times in blue, drain 2 in red (both drawn at the default position).
for rtValues, colour in ((rt1, 'b'), (rt2, 'r')):
    box_plot(rtValues, colour, colour)

# plt.boxplot([rt2],patch_artist=True,boxprops=dict(facecolor='w', color='r'))


### Load all n-backs

In [None]:
nbacks = [2, 3, 4, 5, 6]
data = []
accuracies = []
filamentVal = '0p5'
reinforcement = 'no'  #w or no
numEpochs = 40

# Load the "DataNoSim" result pickle of every n-back and keep its accuracy record.
for n in tqdm(nbacks):
#     saveName='data_'+str(nwSize)+'nw_'+str(nwJunctions)+'nj_3x3_wThresh_Vtrn0.3_Vtst0.1_beta0p5_T200_plus_cross_nback'+str(n)+'_filament0p5'
    saveName = ('data_' + str(nwSize) + 'nw_' + str(nwJunctions) + 'nj_3x3_wThresh_Vtrn' + str(onAmp)
                + '_Vtst' + str(onAmpTest) + '_T200_plus_cross_nback' + str(n)
                + '_filament' + filamentVal + '_' + reinforcement + 'Reinforcement_DataNoSim')
    with open(dataLoc + '/Sim Results/' + saveName + '.pkl', 'rb') as pklFile:
        loaded = pickle.load(pklFile)
    data.append(loaded)
    accuracies.append(loaded[0]['accuracy'])

In [None]:
#Accuracies: one row per n-back, one entry per run, each run truncated to its first
#numEpochs epochs -> shape (len(accuracies), runs per n-back, numEpochs).
#Built from the data instead of a hard-coded reshape(5,10,40), so changing `nbacks`
#or the number of runs no longer breaks it. Ragged input (short runs, unequal run
#counts) cannot form a regular 3-D array.
acc_arr=np.array([[np.array(run[:numEpochs]) for run in nbackRuns] for nbackRuns in accuracies])

In [None]:
#SEM across epochs: average each n-back's accuracy over runs (axis 1), then take the
#mean and std of that per-epoch curve; SE = std / sqrt(numEpochs).
#The plot cells need both *_Reinf and *_noReinf results, so run the load cells and
#this cell once with reinforcement='w' and once with reinforcement='no'.
runMeanAcc=np.mean(acc_arr,axis=1)  # one row per n-back, one column per epoch
epochMeans=[np.mean(row) for row in runMeanAcc]
epochStds=[np.std(row) for row in runMeanAcc]
# Uses this cell's own stds (previously read `stdAcc`, which is undefined on a first
# run and otherwise holds stale per-experiment values from another cell).
epochSEs=np.array(epochStds)/np.sqrt(numEpochs)

if reinforcement == 'w':
    meanAcc_epochs_Reinf=epochMeans
    stdAcc_epochs_Reinf=epochStds
    seAcc_epochs_Reinf=epochSEs
else:
    meanAcc_epochs_noReinf=epochMeans
    stdAcc_epochs_noReinf=epochStds
    seAcc_epochs_noReinf=epochSEs

In [None]:
#SEM across experiments: for each n-back, mean and std over all runs x epochs
#(each run truncated to numEpochs); SE = std / sqrt(numEpochs) for both settings.
#Values are wrapped in one-element lists so np.array(...) has shape (len(accuracies), 1).
meanAccPerNback=[];stdAcc=[]
for nbackRuns in accuracies: #for each nback
    runs=np.array([np.array(run[:numEpochs]) for run in nbackRuns])
    meanAccPerNback.append([np.mean(runs)])
    stdAcc.append([np.std(runs)])
seAccPerNback=np.array(stdAcc)/np.sqrt(numEpochs)

if reinforcement == 'w':
    meanAcc_Reinf=meanAccPerNback
    seAcc_Reinf=seAccPerNback
else:
    meanAcc_noReinf=meanAccPerNback
    seAcc_noReinf=seAccPerNback

### LOAD EXPERIMENTAL DATA

In [None]:
import pandas as pd
#EXPERIMENTAL RESULTS:
#LOAD EXPERIMENTAL DATA: dataExp = without reinforcement, dataExp2 = with reinforcement.
loadLoc='/PhysicalReinforcementLearning/experimental_results/Task 2/'
# Distinct names so the network `fileName` chosen in the imports cell is not overwritten.
expFileNoReinf='non_reinforce_xp_pl__11_40_24_PD.csv'
expFileReinf='reinforce_xp_pl__11_39_57_PD.csv'

dataExp=pd.read_csv(loadLoc+expFileNoReinf)
dataExp2=pd.read_csv(loadLoc+expFileReinf)

# Used as sqrt(numEpochsExp) when turning the experimental spread into an SE.
numEpochsExp=100

In [None]:
#PLOT MEAN AND SEM ACROSS EXPERIMENTS
# Simulation (red = no reinforcement, blue = reinforcement, +/- SE bands) against the
# experimental accuracies (black squares; dashed = no reinforcement). Needs the
# meanAcc_*/seAcc_* results for both reinforcement settings.

plt.rcParams['pdf.fonttype'] = 42

plt.figure(figsize=(10,8))
nbacks=[2,3,4,5,6]
yerr1=np.array(seAcc_noReinf).reshape(-1)
y1 = np.array(meanAcc_noReinf).reshape(-1)
yerr2=np.array(seAcc_Reinf).reshape(-1)
y2 = np.array(meanAcc_Reinf).reshape(-1)

# Labels are attached to the simulation lines themselves: plt.legend(labels) alone
# assigns labels by artist order, which on matplotlib >= 3.5 includes the
# fill_between bands and so mislabels the legend.
plt.plot(nbacks,y1,'--o',c='r',label='No Reinforcement')
plt.fill_between(nbacks,y1+yerr1,y1-yerr1,alpha=0.7,color='r')
plt.plot(nbacks,y2,'-o',c='b',label='Reinforcement')
plt.fill_between(nbacks,y2+yerr2,y2-yerr2,alpha=0.7,color='b')

plt.plot(nbacks,dataExp['globalAccuracy'],'--s',c='k')
plt.plot(nbacks,dataExp2['globalAccuracy'],'-s',c='k')

plt.xlabel('n-back')
plt.ylabel('Accuracy')
plt.ylim([-0.04,1.02])

plt.legend()

plt.savefig(saveFig+'698nw_2582nj_plus_cross_nback23456_Sim_Experiment_b'+filamentVal+'.pdf',format='pdf',dpi=300)

In [None]:
#PLOT MEAN AND SEM ACROSS EPOCHS (Figure 3 in Paper)
# Simulation (red = no reinforcement, blue = reinforcement) and experiment (black;
# dashed = no reinforcement), each with a +/- SE band. Needs the meanAcc_epochs_* /
# seAcc_epochs_* results for both reinforcement settings.

plt.rcParams['pdf.fonttype'] = 42

plt.figure(figsize=(10,8))
nbacks=[2,3,4,5,6]
yerr1=np.array(seAcc_epochs_noReinf).reshape(-1)
y1 = np.array(meanAcc_epochs_noReinf).reshape(-1)
yerr2=np.array(seAcc_epochs_Reinf).reshape(-1)
y2 = np.array(meanAcc_epochs_Reinf).reshape(-1)

# Labels are attached to the simulation lines themselves: plt.legend(labels) alone
# assigns labels by artist order, which on matplotlib >= 3.5 includes the
# fill_between bands and so mislabels the legend.
plt.plot(nbacks,y1,'--o',c='r',label='No Reinforcement')
plt.fill_between(nbacks,y1+yerr1,y1-yerr1,alpha=0.7,color='r')
plt.plot(nbacks,y2,'-o',c='b',label='Reinforcement')
plt.fill_between(nbacks,y2+yerr2,y2-yerr2,alpha=0.7,color='b')

# Experimental SE: the 'mix_accuracy' column divided by sqrt(numEpochsExp).
seExpReinf=dataExp2['mix_accuracy'].values/np.sqrt(numEpochsExp)
seExpNoReinf=dataExp['mix_accuracy'].values/np.sqrt(numEpochsExp)

plt.plot(nbacks,dataExp['globalAccuracy'],'--s',c='k')
plt.fill_between(nbacks,dataExp['globalAccuracy']+seExpNoReinf,dataExp['globalAccuracy']-seExpNoReinf,alpha=0.7,color='k')

plt.plot(nbacks,dataExp2['globalAccuracy'],'-s',c='k')
plt.fill_between(nbacks,dataExp2['globalAccuracy']+seExpReinf,dataExp2['globalAccuracy']-seExpReinf,alpha=0.7,color='k')

plt.xlabel('n-back')
plt.ylabel('Accuracy')
plt.ylim([-0.04,1.02])

plt.legend()

plt.savefig(saveFig+'698nw_2582nj_plus_cross_nback23456_Sim_SEMepochs_Experiment_b'+filamentVal+'.pdf',format='pdf',dpi=300)

In [None]:
#Split accuracies: load the per-n-back result pickles used for the per-pattern analysis.

nbacks = [2, 3, 4, 5, 6]#,7,8]

def _splitAccName(n):
    """Return the result-file stem (without '.pkl') for n-back `n`."""
    #Cross and Plus:
    return 'data_698nw_2582nj_3x3_wThresh_Vtrn0.3_Vtst0.1_beta0p5_T200_plus_cross_nback' + str(n) + '_filament5'
    #simple:
    # return 'data_698nw_2582nj_3x3_wThresh_Vtrn0.3_Vtst0.1_beta0p5_T200_oneDrainTrain_nback'+str(n)+'_filament0p5'

outputs = []
for n in tqdm(nbacks):
    with open(dataLoc + _splitAccName(n) + '.pkl', 'rb') as pklFile:
        outputs.append(pickle.load(pklFile))

# outputs[0]=[outputs[0]] #fix annoying data structure for nback 2 of the simple case

In [None]:
# Per n-back: accuracies of the first numEpochs test samples split by label
# (0 -> pattern 1, 1 -> pattern 2), with their mean and std.
p1Acc = []; p2Acc = []
p1avg = []; p2avg = []; p1std = []; p2std = []
for output in outputs:
    summary = output[0]
    epochAcc = np.array(summary['accuracy'])[:numEpochs]
    labels = summary['testing labels']
    pattern1 = epochAcc[labels == 0]
    pattern2 = epochAcc[labels == 1]
    p1Acc.append(pattern1)
    p2Acc.append(pattern2)
    p1avg.append(np.mean(pattern1))
    p2avg.append(np.mean(pattern2))
    p1std.append(np.std(pattern1))
    p2std.append(np.std(pattern2))

In [None]:
#Plot split accuracies: mean accuracy of each pattern for every loaded n-back.
plt.plot(p1avg,'-o',label='Pattern 1')
plt.plot(p2avg,'-s',label='Pattern 2')
ax=plt.gca()
ax.set_title('b = 5')

# Fix the tick positions before labelling them (one tick per entry of `nbacks`):
# calling set_xticklabels on the default locator triggers a matplotlib warning and
# only matches labels to ticks by order.
ax.set_xticks(range(len(nbacks)))
ax.set_xticklabels(nbacks)

ax.set_xlabel('n-back')
ax.set_ylabel('Accuracy')
ax.set_ylim([0,1.1])
plt.legend()

plt.savefig(saveFig+'698nw_2582nj_plus_cross_Adrian_nback_Accuracy_b5.pdf',format='pdf',dpi=300)