# Figure generation notebook

This notebook generates figures for the paper "A factor graph EM algorithm for inference of kinetic microstates from patch clamp measurements", and is packaged with all the data needed to generate the figures.

For Figures 3-6, data can be generated from the simulation code in the parent directory of this repository. Results generated by this code are included in this package.

For Figure 7, the raw patch clamp data is included in this package.

All scripts should be executed in the same working directory as the data.

In [None]:
# required libraries and processing functions - execute this cell first

import numpy as np
import csv
import matplotlib
import matplotlib.pyplot as plt

# reads results if known parameters or last EM results only are selected
def readCSV(f):
    """Read a results CSV whose last row holds the results.

    Every row before the last one is returned as a parameter row, and the
    last row is returned as the results row. Fields are left as strings.

    Parameters
    ----------
    f : path of the CSV file to read

    Returns
    -------
    paramList : list of rows (each a list of str), all rows except the last
    results : the last row, as a list of str

    Raises
    ------
    IndexError if the file contains no rows.
    """
    # newline='' hands newline processing to the csv module, as its
    # documentation asks; otherwise newlines inside quoted fields are altered
    with open(f, newline='') as csvfile:
        rows = [list(row) for row in csv.reader(csvfile)]

    return rows[:-1], rows[-1]

def readCSVpc(f,encoding=None):
    """Read every row of a CSV file.

    Parameters
    ----------
    f : path of the CSV file to read
    encoding : text encoding passed to open(); None selects open()'s default

    Returns
    -------
    A list of rows, each row a list of str fields.
    """
    # a single code path suffices: open(f, encoding=None) is the same as open(f).
    # newline='' hands newline processing to the csv module, as its
    # documentation asks.
    with open(f, encoding=encoding, newline='') as csvfile:
        return [list(row) for row in csv.reader(csvfile)]

# reads results if full EM results are selected
def readCSVem(f):
    """Read a full-EM results CSV and split it into parameters and results.

    The first half of the rows are returned as parameter rows and the second
    half as result rows. With an odd number of rows the extra row lands in
    the results half. Fields are left as strings.

    Parameters
    ----------
    f : path of the CSV file to read

    Returns
    -------
    paramList : first half of the rows (list of lists of str)
    results : remaining rows (list of lists of str)
    """
    # newline='' hands newline processing to the csv module, as its
    # documentation asks
    with open(f, newline='') as csvfile:
        rows = [list(row) for row in csv.reader(csvfile)]

    half = len(rows) // 2 # half the rows are parameters, half are results
    return rows[:half], rows[half:]

def readCSVOneLine(f,encoding=None):
    """Return only the first row of a CSV file, as a list of str fields.

    The whole file is read through readCSVpc and the first row is selected,
    so an empty file raises IndexError.
    """
    rows = readCSVpc(f,encoding=encoding)
    first_row = rows[0]
    return first_row

def processResults(paramList,results,paramIndex,unique=True):
    """Normalise results by their estimate counts, optionally pooled per parameter value.

    Parameters
    ----------
    paramList : sequence of parameter rows; element 0 of each row is the
        number of estimates, element paramIndex is the parameter of interest
    results : sequence of values convertible to float, one per parameter row
    paramIndex : column of paramList that holds the parameter of interest
    unique : if true, results and counts are summed over rows sharing the
        same parameter value before dividing; if false, every row is kept
        separately

    Returns
    -------
    uniqueIndices : the parameter values -- a sorted list of distinct floats
        when unique is true, otherwise a numpy array with one entry per row
    cResults : numpy array of (summed) results divided by (summed) counts
    """
    n_results = len(results)
    values = [float(r) for r in results]
    # element 0 of each parameter row holds the number of estimates
    counts = [float(paramList[i][0]) for i in range(n_results)]
    allIndices = [float(row[paramIndex]) for row in paramList]

    # truthiness rather than `is False`, so numpy booleans are honoured too
    if not unique:
        uniqueIndices = np.array(allIndices)
        cResults = np.array(values)
        indicesCount = np.array(counts)
    else:
        uniqueIndices = sorted(set(allIndices))
        # dict lookup replaces the repeated linear list searches
        position = {value: j for j, value in enumerate(uniqueIndices)}

        cResults = np.zeros(len(uniqueIndices))
        indicesCount = np.zeros(len(uniqueIndices))
        for value, count, index in zip(values, counts, allIndices):
            j = position[index]
            indicesCount[j] += count
            cResults[j] += value

    cResults = cResults / indicesCount

    return uniqueIndices,cResults

# a formatter to allow axis decorations with order of magnitude
class MagnitudeFormatter(matplotlib.ticker.ScalarFormatter):
    """ScalarFormatter whose order of magnitude can be pinned to a fixed exponent.

    Parameters
    ----------
    exponent : the power of ten to use for the axis offset (e.g. -3 for
        10^-3), or None to let ScalarFormatter choose it automatically
    """
    def __init__(self, exponent=None):
        super().__init__()
        self._fixed_exponent = exponent

    def _set_order_of_magnitude(self):
        # compare against None rather than truthiness so that an explicitly
        # requested exponent of 0 is honoured instead of silently ignored
        if self._fixed_exponent is not None:
            self.orderOfMagnitude = self._fixed_exponent
        else:
            super()._set_order_of_magnitude()

def parseConfResult(r):
    """Parse a result string of the form '[ a b]' into the integer pair (a, b).

    The first and last characters (the brackets) are dropped and the rest is
    split on whitespace; the first two tokens are converted to int.
    """
    inner = r[1:-1] # drop the enclosing brackets
    tokens = inner.split() # whitespace split also discards padding at either end
    return tuple(int(tokens[k]) for k in (0, 1))

def string2float_list(l):
    """Return a new list holding float(x) for every element x of l, in order."""
    return [float(item) for item in l]

# plots f with respect to t, but omits points where f = x
def plotExclude(ax,t,f,x,line='-',clr='green',lw=2):
    """Plot f against t on ax, leaving gaps wherever f equals x.

    Each maximal run of consecutive samples with f != x is drawn with its own
    ax.plot call, so runs are never joined across an excluded sample.

    Parameters
    ----------
    ax : object with a matplotlib-style plot method
    t, f : sequences of abscissae and ordinates, indexed over range(len(t))
    x : the value of f to leave out
    line, clr, lw : passed to ax.plot as linestyle, color and linewidth
    """
    seg_t = []
    seg_f = []
    final = len(t) - 1
    for k in range(len(t)):
        keep = f[k] != x
        if keep:
            seg_t.append(t[k])
            seg_f.append(f[k])
        # a non-empty segment means a run is open; it closes on an excluded
        # sample or at the last sample
        if seg_t and (not keep or k == final):
            ax.plot(seg_t,seg_f,linestyle=line,color=clr,linewidth=lw)
            seg_t = []
            seg_f = []

def readData(pathPrefix,fname_pre,fname_raw,fname_10,fname_100,sample_period=500e-6,block_size=50):
    """Load one recording: preprocessed trace, raw trace, and two estimate files.

    Parameters
    ----------
    pathPrefix : string prepended to every file name
    fname_pre : CSV whose first row is the preprocessed trace
    fname_raw : CSV of the raw trace (read as utf-8-sig); the second column
        is used
    fname_10, fname_100 : estimate files; row index 2 holds the estimates and
        row index 3 the estimates with reliability (the first two rows hold
        parameters and are ignored)
    sample_period : spacing of the raw time axis (default 500e-6, the value
        previously hard-coded; presumably seconds -- confirm)
    block_size : number of raw samples per preprocessed sample (default 50,
        the value previously hard-coded)

    Returns
    -------
    data, rawdata, e_10_est, e_10_conf, e_100_est, e_100_conf : lists of float
    t : numpy time axis for data, spaced sample_period*block_size
    t_raw : numpy time axis for rawdata, spaced sample_period
    """
    # read the preprocessed and raw traces
    data = string2float_list(readCSVOneLine(pathPrefix + fname_pre))
    raw_rows = np.array(readCSVpc(pathPrefix + fname_raw,encoding='utf-8-sig'))
    rawdata = string2float_list(raw_rows[:,1])

    # read the estimates ... estimates are the 3rd line of the file, estimates
    # with reliability are the 4th line; the first 2 lines have the parameters
    e_10 = readCSVpc(pathPrefix + fname_10)
    e_100 = readCSVpc(pathPrefix + fname_100)

    e_10_est = string2float_list(e_10[2])
    e_10_conf = string2float_list(e_10[3])
    e_100_est = string2float_list(e_100[2])
    e_100_conf = string2float_list(e_100[3])

    # same left-to-right multiplication order as before, so the defaults
    # reproduce the earlier time axes exactly
    t = np.arange(len(data))*sample_period*block_size
    t_raw = np.arange(len(rawdata))*sample_period

    return data,rawdata,e_10_est,e_10_conf,e_100_est,e_100_conf,t,t_raw

# wherever we find a value from old_map in v, we replace it with the corresponding value of new_map
def remap(old_map,new_map,v):
    """Translate the values of v from old_map entries to new_map entries.

    Returns a float numpy array of len(v). Element i is new_map[j] for the
    last j with v[i] == old_map[j]; elements that match nothing stay 0.
    """
    out = np.zeros(len(v))
    for pos, value in enumerate(v):
        for k, old in enumerate(old_map):
            if value == old:
                out[pos] = new_map[k]

    return out


            

### Figure 3

This code generates Figure 3 in the paper, reading the data files in the Figure3 directory.

In [None]:
# Figure 3: per-iteration traces for CFTR (left) and nAChR (right)
fig3, (ax, ax2) = plt.subplots(ncols=2)

# both panels are drawn the same way: the first five result rows of the file,
# truncated to 100 entries and divided by 20000
# (presumably the number of estimates per run -- confirm)
for axis, path in ((ax, 'Figure3/cftr_example.csv'), (ax2, 'Figure3/ach_example.csv')):
    _, traces = readCSVem(path)
    for j in range(5):
        trace = [float(count)/20000 for count in traces[j][:100]]
        iterations = np.arange(0,len(trace))
        axis.plot(iterations,trace,'b-',linewidth=1)

# panel decorations
ax.set_box_aspect(1)
ax.set_title('(a): CFTR')
ax.set_xlabel('Iterations')
ax.set_ylabel('$P_e$')

ax2.set_box_aspect(1)
ax2.set_title('(b): nAChR')
ax2.set_xlabel('Iterations')

fig3.tight_layout(pad=1)

plt.savefig('EM-visualization.pdf',bbox_inches='tight',pad_inches=0.05)
plt.show()


### Figure 4

This code generates Figure 4 in the paper, reading the data files in the Figure4 directory.

In [None]:
# Figure 4: P_e against ligand concentration for CFTR (left) and nAChR (right).
# Each panel shows the per-row results (dots), their average per concentration
# (blue line), and the curve from the 'known' file (red line).
atpc = 1/90 # for converting probabilities back to ATP concentrations -- CFTR only

fig4, (ax, ax2) = plt.subplots(ncols=2)

# ---- CFTR panel ----
em_params,em_results = readCSV('Figure4/summary-CFTR-1.csv')
scatter_x,scatter_y = processResults(em_params,em_results,3,unique=False)
mean_x,mean_y = processResults(em_params,em_results,3,unique=True)

known_params,known_results = readCSV('Figure4/known-CFTR-1.csv')
known_x,known_y = processResults(known_params,known_results,3)

# x values are scaled by atpc for CFTR only
ax.plot(np.array(scatter_x)*atpc,scatter_y,'b.',markersize=1,markeredgecolor='#7777ff',markerfacecolor='#7777ff')
ax.plot(np.array(mean_x)*atpc,mean_y,'b-')
ax.plot(np.array(known_x)*atpc,known_y,'r-')

ax.set_box_aspect(1)
ax.set_title('(a): CFTR')
ax.set_xlabel('[ATP] (M)')
ax.set_ylabel('$P_e$')
ax.set_ylim([0.1,0.5])

# uses the custom formatter defined above -- -3 means 10^-3
ax.xaxis.set_major_formatter(MagnitudeFormatter(-3))

# ---- nAChR panel ----
em_params,em_results = readCSV('Figure4/summary-ACh-1.csv')
scatter_x,scatter_y = processResults(em_params,em_results,3,unique=False)
mean_x,mean_y = processResults(em_params,em_results,3,unique=True)

known_params,known_results = readCSV('Figure4/known-ACh-1.csv')
known_x,known_y = processResults(known_params,known_results,3)

ax2.plot(scatter_x,scatter_y,'.',markersize=1,markeredgecolor='#7777ff',markerfacecolor='#7777ff')
ax2.plot(mean_x,mean_y,'b-')
ax2.plot(known_x,known_y,'r-')

ax2.set_box_aspect(1)
ax2.set_title('(b): nAChR')
ax2.set_xlabel('[ACh] (M)')

fig4.tight_layout(pad=1)

plt.savefig('EM-concentration.pdf',bbox_inches='tight',pad_inches=0.05)
plt.show()

### Figure 5

This code generates Figure 5 in the paper, reading the data files in the Figure5 directory.

In [None]:
# Figure 5: results restricted to confident estimates.
# Top row: P_e; bottom row: confidence fraction.
# Left column: CFTR; right column: nAChR.
atpc = 1/90 # for converting probabilities back to ATP concentrations -- CFTR only

fig5, ((ax,ax2),(ax3,ax4)) = plt.subplots(nrows=2,ncols=2,figsize=(6,6))

# ---- CFTR (ax, ax3) ----
conf_params = []
conf_errors = []
conf_num = []

a,b = readCSV('Figure5/CFTR_conf_em.csv')

# each result is a string '[ errors num ]'; split it into its two integers
for i in range(0,len(b)):
    c,d = parseConfResult(b[i])
    conf_errors.append(c)
    conf_num.append(d)
    conf_params.append(a[i])

# processResults gives us an average over each parameter value, but divides the result by 20000,
# which is not what we need here ... so the factor 20000 is multiplied back in below
qq,rr = processResults(conf_params,conf_errors,3,unique=False)

xxatp = np.array(qq)*atpc
zz1 = np.array([rr[i]*20000/conf_num[i] for i in range(0,len(rr))])
zz2 = np.array([conf_num[i]/20000 for i in range(0,len(rr))])

ax.plot(xxatp,zz1,'.',markersize=1,markeredgecolor='#7777ff',markerfacecolor='#7777ff')
ax3.plot(xxatp,zz2,'.',markersize=1,markeredgecolor='#77bf77',markerfacecolor='#77bf77')

# averages per unique parameter value (same values as plotted above)
ss,tt = processResults(conf_params,zz1,3,unique=True)
uu,vv = processResults(conf_params,zz2,3,unique=True)

ssatp = np.array(ss)*atpc
uuatp = np.array(uu)*atpc # x values belonging to vv (previously taken from ss)

ax.plot(ssatp,tt*20000,'b-')
ax3.plot(uuatp,vv*20000,'g-')

ax.set_box_aspect(1)
ax.set_title('(a): CFTR')
ax.set_ylabel('$P_e$')

ax3.set_box_aspect(1)
ax3.set_title('(b): CFTR')
ax3.set_xlabel('[ATP] (M)')
ax3.set_ylabel('Confidence fraction')

# custom formatter defined above -- -3 means 10^-3
ax.xaxis.set_major_formatter(MagnitudeFormatter(-3))
ax3.xaxis.set_major_formatter(MagnitudeFormatter(-3))

# ---- nAChR (ax2, ax4) ----
conf_params = []
conf_errors = []
conf_num = []

a,b = readCSV('Figure5/ACh_conf_em.csv')

for i in range(0,len(b)):
    c,d = parseConfResult(b[i])
    conf_errors.append(c)
    conf_num.append(d)
    conf_params.append(a[i])

# processResults gives us an average over each parameter value, but divides the result by 20000,
# which is not what we need here ... so the factor 20000 is multiplied back in below
qq,rr = processResults(conf_params,conf_errors,3,unique=False)

zz1 = [rr[i]*20000/conf_num[i] for i in range(0,len(rr))]
zz2 = [conf_num[i]/20000 for i in range(0,len(rr))]

ax2.plot(qq,zz1,'.',markersize=1,markeredgecolor='#7777ff',markerfacecolor='#7777ff')
ax4.plot(qq,zz2,'g.',markersize=1)

ss,tt = processResults(conf_params,zz1,3,unique=True)
uu,vv = processResults(conf_params,zz2,3,unique=True)

ax2.plot(ss,tt*20000,'b-')
ax4.plot(uu,vv*20000,'g-')

ax2.set_box_aspect(1)
ax2.set_title('(c): nAChR')

ax4.set_box_aspect(1)
ax4.set_title('(d): nAChR')
ax4.set_xlabel('[ACh] (M)')

fig5.tight_layout(pad=1)

plt.savefig('EM-confidence.pdf',bbox_inches='tight',pad_inches=0.05)

### Figure 6

This code generates Figure 6 in the paper, reading the data files in the Figure6 directory. The figure also includes the lines used in Figure 4, so the relevant files are read from the Figure4 directory.

In [None]:
# Figure 6: noisy results for CFTR (left) and nAChR (right), drawn over the
# noise-free averaged curves of Figure 4 for reference.
atpc = 1/90 # for converting probabilities back to ATP concentrations -- CFTR only

fig6, (ax, ax2) = plt.subplots(ncols=2)

# CFTR

#### noise-free results (only the averaged curves are drawn, so the
#### per-row unique=False values are not computed here)

pe,re = readCSV('Figure4/summary-CFTR-1.csv')

uiu,cru = processResults(pe,re,3,unique=True)

pk,rk = readCSV('Figure4/known-CFTR-1.csv')

uik,crk = processResults(pk,rk,3)

ax.plot(np.array(uiu)*atpc,cru,linestyle='solid',color='#44cccc')
ax.plot(np.array(uik)*atpc,crk,linestyle='solid',color='#ff9933')

#### noisy results

pe,re = readCSV('Figure6/CFTR-noisy-em.csv')

uie,cre = processResults(pe,re,3,unique=False)
uiu,cru = processResults(pe,re,3,unique=True)

pk,rk = readCSV('Figure6/CFTR-noisy-kp.csv')

uik,crk = processResults(pk,rk,3)

ax.plot(np.array(uie)*atpc,cre,'b.',markersize=1,markeredgecolor='#7777ff',markerfacecolor='#7777ff')
ax.plot(np.array(uiu)*atpc,cru,'b-')
ax.plot(np.array(uik)*atpc,crk,'r-')

ax.set_box_aspect(1)
ax.set_title('(a): CFTR')
ax.set_xlabel('[ATP] (M)')
ax.set_ylabel('$P_e$')

# uses the custom formatter defined above -- -3 means 10^-3
ax.xaxis.set_major_formatter(MagnitudeFormatter(-3))

# ACh

#### noise-free results (averaged curves only, as for CFTR)

pe,re = readCSV('Figure4/summary-ACh-1.csv')

uiu,cru = processResults(pe,re,3,unique=True)

pk,rk = readCSV('Figure4/known-ACh-1.csv')

uik,crk = processResults(pk,rk,3)

ax2.plot(uiu,cru,linestyle='solid',color='#44cccc')
ax2.plot(uik,crk,linestyle='solid',color='#ff9933')

#### noisy results

pe,re = readCSV('Figure6/ACh-noisy-em.csv')

uie,cre = processResults(pe,re,3,unique=False)
uiu,cru = processResults(pe,re,3,unique=True)

pk,rk = readCSV('Figure6/ACh-noisy-kp.csv')

uik,crk = processResults(pk,rk,3)

ax2.plot(uie,cre,'.',markersize=1,markeredgecolor='#7777ff',markerfacecolor='#7777ff')
ax2.plot(uiu,cru,'b-')
ax2.plot(uik,crk,'r-')

ax2.set_box_aspect(1)
ax2.set_title('(b): nAChR')
ax2.set_xlabel('[ACh] (M)')

fig6.tight_layout(pad=1)

plt.savefig('EM-noisy.pdf',bbox_inches='tight',pad_inches=0.05)
plt.show()


### Figure 7

This code generates Figure 7 in the paper, reading the data files in the Figure7 directory. Preprocessed data is created from the raw data by taking a block average over non-overlapping blocks of size 50, then multiplying by 1.75 so that the features lie roughly in the range [0,1]. Final state estimates are generated with the code in the parent directory; for example, for the n17000 dataset and 10 EM iterations, the command is:

cat n17000-preprocessed.csv | python3 main.py -t -l -ev -i=10 > n17000-r10.csv

In [None]:
# Figure 7: two patch clamp recordings, one per column.
# Rows: raw trace, preprocessed trace, state estimates.
recordings = [
    readData('Figure7/',
             'n17000-preprocessed.csv',
             'n17000-raw.csv',
             'n17000-r10.csv',
             'n17000-r100.csv'),
    readData('Figure7/',
             'n17004-1-preprocessed.csv',
             'n17004-1-raw.csv',
             'n17004-1-r10.csv',
             'n17004-1-r100.csv'),
]

fig7, ((ax1,ax4), (ax2,ax5), (ax3,ax6)) = plt.subplots(nrows=3,ncols=2,figsize=(9,7))

# estimate values in the files -> vertical position on the state axis;
# -1 is kept as -1 and is the value plotExclude leaves out below
file_states = [0.,1.,2.,3.,4.,5.,6.,-1.]
axis_rows = [2.,3.,4.,5.,6.,0.,1.,-1.]

columns = [
    ((ax1,ax2,ax3),('(a)','(b)','(c)')),
    ((ax4,ax5,ax6),('(d)','(e)','(f)')),
]

for recording,(axes,titles) in zip(recordings,columns):
    data,rawdata,e_10_est,e_10_conf,e_100_est,e_100_conf,t,t_raw = recording
    raw_ax,scaled_ax,state_ax = axes

    # top row: raw trace
    raw_ax.plot(t_raw,np.array(rawdata),color='#7c2dad',linewidth=1.)
    raw_ax.set_ylim(-0.2,1)
    raw_ax.set_xlim(t_raw[0],t_raw[-1])
    raw_ax.set_title(titles[0])

    # middle row: preprocessed trace
    scaled_ax.plot(t,np.array(data),color='#000000',linewidth=1.)
    scaled_ax.set_ylim(-0.2,1.4)
    scaled_ax.set_xlim(t[0],t[-1])
    scaled_ax.set_title(titles[1])

    # bottom row: r10 estimates (grey), r100 estimates (red), and the r100
    # estimates with reliability (orange) wherever they are not -1
    state_ax.plot(t,remap(file_states,axis_rows,e_10_est),color='#888888')
    state_ax.plot(t,remap(file_states,axis_rows,e_100_est),'r-')
    plotExclude(state_ax,t,remap(file_states,axis_rows,e_100_conf),-1.,clr='#ffaa33',lw=3)
    state_ax.set_yticks([0.,1.,2.,3.,4.,5.,6.])
    state_ax.set_yticklabels(['C3','C4','C1a','C1b','C2','O1','O2'])
    state_ax.set_xlim(t[0],t[-1])
    state_ax.set_ylim(-0.4,6.4)
    # shaded band spans the rows labelled C3..C2
    state_ax.axhspan(-0.4,4.5,facecolor='#dddddd',alpha=0.5)
    state_ax.set_xlabel('Time (s)')
    state_ax.set_title(titles[2])

# y-axis labels appear on the left column only
ax1.set_ylabel('Current (pA)')
ax2.set_ylabel('Current (scaled)')
ax3.set_ylabel('State estimate')

fig7.tight_layout(pad=1)
fig7.savefig('patchclamp-data.pdf')