# Code to cluster cells by GLIF and spike shape parameters using iterative binary clustering and generate confidence intervals for clustering similarity
### Teeter et al. 2018
#### This notebook runs the iterative binary clustering approach to generate clusters from the GLIF and spike-shape parameters. It outputs graphs showing confidence intervals on clustering similarity, measured using the Adjusted Rand Index and the Adjusted Variation of Information. These graphs correspond to Supplemental Figure 17 in the paper.

### 1) Install required packages

In [7]:
###Install (when missing) and attach every package this notebook loads.
###mclust is included in the install check: adjustedRandIndex() is called later
###in the notebook, and the original code attached mclust without ever
###installing it.
###Attach order matters: plotrix comes after gplots so that plotrix's plotCI
###masks the gplots one.
required_pkgs <- c("ape", "e1071", "gplots", "plotrix", "mclust")
for (pkg in required_pkgs) {
  if (!require(pkg, character.only = TRUE)) {
    install.packages(pkg, repos = "http://cran.us.r-project.org")
    ###library() stops with an error if the package still cannot be loaded,
    ###instead of returning FALSE silently the way require() does
    library(pkg, character.only = TRUE)
  }
}

Loading required package: plotrix
In library(package, lib.loc = lib.loc, character.only = TRUE, logical.return = TRUE, : there is no package called 'plotrix'


  There is a binary version available (and will be installed) but the
  source version is later:
        binary source
plotrix  3.6-1    3.7

package 'plotrix' successfully unpacked and MD5 sums checked

The downloaded binary packages are in
	C:\Users\menonv\AppData\Local\Temp\RtmpsBeaGa\downloaded_packages


Loading required package: plotrix

Attaching package: 'plotrix'

The following object is masked from 'package:gplots':

    plotCI



### 2) Load data and metadata

In [2]:
###model parameters: tab-separated table whose first column supplies the row names;
###the first two data columns are kept as metadata, the rest are the parameters
dat <- read.table(
  "GLIF_param_plus_spike_features_7_27_17.csv",
  header = TRUE, sep = "\t", row.names = 1, as.is = TRUE, check.names = FALSE
)
metadata <- dat[, 1:2]
fulldat <- dat[, -(1:2)]

###Cre line metadata: columns 2-4 hold 0-255 RGB values, column 5 names each colour
crecols <- read.csv("cre_colors.csv", header = FALSE, as.is = TRUE)
newcols <- setNames(rgb(crecols[, 2:4], maxColorValue = 255), crecols[, 5])
###one colour per cell, looked up by matching the cell's cre value against column 1
colvec <- newcols[match(metadata$cre, crecols[, 1])]
cre_order <- c(
  "Htr3a", "Ndnf", "Vip", "Sst", "Pvalb", "Nkx2-1", "Chat", "Chrna2",
  "Cux2", "Nr5a1", "Scnn1a-Tg2", "Scnn1a-Tg3", "Rorb", "Rbp4", "Ntsr1", "Ctgf"
)

###features: comma-separated table; the first two data columns are metadata and
###only the named electrophysiological features below are used for clustering
featdat <- read.table(
  "features_7_27_17.csv",
  header = TRUE, sep = ",", row.names = 1, as.is = TRUE, check.names = FALSE
)
featmetadata <- featdat[, 1:2]
featfulldat <- featdat[, c(
  "tau", "ri", "vrest",
  "threshold_i_long_square", "threshold_v_long_square",
  "peak_v_long_square", "fast_trough_v_long_square", "trough_v_long_square",
  "upstroke_downstroke_ratio_long_square", "upstroke_downstroke_ratio_short_square",
  "sag", "f_i_curve_slope", "latency", "max_burstiness_across_sweeps"
)]





### 3) Apply log transform to skewed parameters/features

In [3]:
###helper shared by both tables (the original repeated this loop twice).
###For every column whose values are all strictly positive or all strictly
###negative, compare the skewness of the magnitudes with the skewness of their
###log10; when the log10 version has the lower (signed) skewness the column is
###replaced by log10(magnitude). Columns that contain zero or mix signs are left
###untouched. Note that an all-negative column becomes log10(-x), so its sign
###information is dropped.
###  datmat: data frame or matrix of numeric columns
###  returns: datmat with the selected columns transformed
log_transform_skewed <- function(datmat) {
  for (jj in seq_len(ncol(datmat))) {
    colvals <- datmat[, jj]
    lowest <- min(colvals)
    ###product > 0 means min and max share a sign and neither is zero
    if (lowest * max(colvals) > 0) {
      magnitudes <- if (lowest > 0) colvals else -colvals
      logged <- log10(magnitudes)
      if (skewness(magnitudes) > skewness(logged)) {
        datmat[, jj] <- logged
      }
    }
  }
  datmat
}

###model parameters
fulldat <- log_transform_skewed(fulldat)
fulldat_all <- fulldat

###features
featfulldat <- log_transform_skewed(featfulldat)
featfulldat_all <- featfulldat

### 4) Load clustering and clustering overlap functions

In [4]:
###function to separate data into two clusters and check for cluster separation using SVM-based prediction
###  fulldat:   numeric matrix/data frame; rows are the items being clustered
###             (callers rely on row names to identify them), columns are features
###  startseed: integer offset added to the iteration number to seed each SVM resample
###  meth:      agglomeration method handed to hclust()
###returns a list with
###  fraction_incorrect: when both clusters have more than 5 members, 100 values giving
###                      the fraction of held-out rows an SVM assigns to the wrong
###                      cluster; otherwise c(1,1), which marks the split as unusable
###  clustids:           cluster membership (1 or 2) per row, from cutree()
cluster_into_two <- function(fulldat, startseed, meth = 'ward.D') {
  ###drop constant columns, then z-score the remaining ones
  fulldat <- scale(fulldat[, apply(fulldat, 2, var) > 0])
  ###distance between rows = 1 - Pearson correlation of their feature vectors
  hc <- hclust(as.dist(1 - cor(t(fulldat), method = "pearson")), method = meth)
  clustids <- cutree(hc, 2)

  ###assess predictability using SVM prediction###
  inds1 <- which(clustids == 1)
  inds2 <- which(clustids == 2)
  if (length(inds1) > 5 && length(inds2) > 5) {
    ###train on half of each cluster, test on the other half
    sampfrac1 <- round(0.5 * length(inds1))
    sampfrac2 <- round(0.5 * length(inds2))
    fraction_incorrect <- numeric(100)
    for (tt in seq_len(100)) {
      set.seed(tt + startseed)
      sampvec <- c(sample(inds1, sampfrac1), sample(inds2, sampfrac2))
      ###use only the features that vary within the training rows
      setcols <- which(apply(fulldat[sampvec, ], 2, var) > 0)
      svmfit <- svm(x = fulldat[sampvec, setcols], y = clustids[sampvec], type = "C-classification")
      svmpred <- predict(svmfit, fulldat[-sampvec, setcols])
      conf <- table(svmpred, clustids[-sampvec])
      ###off-diagonal cells of the 2x2 confusion table are the misclassified rows
      fraction_incorrect[tt] <- (conf[2, 1] + conf[1, 2]) / sum(conf)
    }
  } else {
    fraction_incorrect <- c(1, 1)
  }

  list(fraction_incorrect = fraction_incorrect, clustids = clustids)
}

###function to cluster iteratively using binary splits
###  keepcols:    columns of fulldat_all used for clustering
###  fulldat_all: data matrix/data frame with one named row per cell
###  fraclim:     a split is accepted only when the largest SVM misclassification
###               fraction reported by cluster_into_two() is <= fraclim
###  splitlim:    currently unused; kept so existing calls keep working
###  startseed:   seed offset passed to cluster_into_two(); child i is run with startseed+i
###  outlist:     accumulator list with $clustnames (named character vector of cluster
###               labels) and $fracmat (one row of misclassification fractions per
###               accepted split)
###  methall:     hclust agglomeration method, now forwarded to every level of the
###               recursion (it was previously dropped after the first split, so deeper
###               splits always fell back to the default)
###returns outlist with "_1"/"_2" appended to the label of every cell for each accepted split
recursive_clustering <- function(keepcols, fulldat_all, fraclim = 0.2, splitlim = 50, startseed, outlist, methall = "ward.D") {
  clustmat <- fulldat_all[, keepcols]
  tempout <- cluster_into_two(clustmat, startseed, meth = methall)

  ###reject the split: nothing is recorded and recursion stops on this branch
  if (is.na(tempout$fraction_incorrect[1])) {
    return(outlist)
  }
  if (max(tempout$fraction_incorrect, na.rm = TRUE) > fraclim) {
    return(outlist)
  }

  cellnames <- names(tempout$clustids)
  outlist$clustnames[cellnames] <- paste(outlist$clustnames[cellnames], tempout$clustids, sep = "_")
  outlist$fracmat <- rbind(outlist$fracmat, tempout$fraction_incorrect)

  for (ii in 1:2) {
    in_child <- tempout$clustids == ii
    ###only children with at least 10 cells are split further
    if (sum(in_child) >= 10) {
      outlist <- recursive_clustering(
        keepcols, fulldat_all[cellnames[in_child], ],
        fraclim = fraclim, splitlim = splitlim,
        startseed = startseed + ii, outlist = outlist, methall = methall
      )
    }
  }
  outlist
}

###function to calculate Variation of Information or Adjusted Rand Index
###  xvec, yvec: two label vectors of equal length describing the same items
###  functype:   1 -> Variation of Information; any other value -> Adjusted Rand Index
###  credistmat, clustdistmat: accepted but not used by this function
calc_cluster_diff <- function(xvec, yvec, functype = 1, credistmat = c(), clustdistmat = c()) {
  if (functype != 1) {
    return(adjustedRandIndex(xvec, yvec))
  }
  ###contingency counts, then each count as a share of its row total and column total
  counts <- table(xvec, yvec)
  share_of_row <- sweep(counts, 1, rowSums(counts), "/")
  share_of_col <- sweep(counts, 2, colSums(counts), "/")
  weighted <- counts * (log(share_of_row) + log(share_of_col))
  ###empty cells give 0 * log(0) = NaN, so only occupied cells are summed
  -(sum(weighted[counts > 0]) / length(xvec))
}

###function to calculate score based on 100 random permutations
###Returns 100 scores, each comparing xvec with a shuffled copy of yvec.
###Permutation k is drawn right after set.seed(k), so results are reproducible;
###the global RNG state is reset as a side effect.
rand_cluster_diff <- function(xvec, yvec, functype = 1, credistmat = c(), clustdistmat = c()) {
  vapply(
    seq_len(100),
    function(permseed) {
      set.seed(permseed)
      calc_cluster_diff(xvec, sample(yvec), functype, credistmat, clustdistmat)
    },
    numeric(1)
  )
}


### 5) Run clustering on GLIF parameters, GLIF parameters+spike shape features, and electrophysiological features using full data set

In [6]:
###specify prefix for output file names###
pref <- "iterative_binary_clustering_2018"
parametersets <- c("Features","Featuresnospike","GLIF1","GLIF2","GLIF3","GLIF4","GLIF1_spike_shape","GLIF2_spike_shape","GLIF3_spike_shape","GLIF4_spike_shape")
fraclimval <- 0.2  ###maximum fraction of incorrectly classified cells in test set (see recursive_clustering function in section 4)
###splitlimval was referenced below without ever being defined; the call only
###worked because recursive_clustering never evaluates its splitlim argument.
###Define it (matching that function's default) unless it already exists.
if (!exists("splitlimval")) {
  splitlimval <- 50
}
methall <- 'ward.D'

###column indices used by each parameter set; the two "Features" sets index
###featfulldat_all, every GLIF set indexes fulldat_all
keepcols_by_set <- list(
  GLIF1 = c(1,3,4,5,8),
  GLIF2 = c(1,3,4,5,8,9,10),
  GLIF3 = c(2,3,4,5,6,7,8),
  GLIF4 = c(2,3,4,5,6,7,8,9,10),
  GLIF1_spike_shape = c(1,3,4,5,8,13:16),
  GLIF2_spike_shape = c(1,3,4,5,8,9,10,13:16),
  GLIF3_spike_shape = c(2,3,4,5,6,7,8,13:16),
  GLIF4_spike_shape = c(2,3,4,5,6,7,8,9,10,13:16),
  Features = seq_len(ncol(featfulldat_all)),
  Featuresnospike = c(1,2,3,4,5,8,11,12,13,14)
)

for (nameval in parametersets) {
  keepcols <- keepcols_by_set[[nameval]]
  if (nameval %in% c("Features","Featuresnospike")) {
    startmat <- featfulldat_all
  } else {
    startmat <- fulldat_all
  }

  print(paste0("clustering ",nameval," model, using the following parameters: ",paste(colnames(startmat)[keepcols],collapse=",")))

  ###every cell starts in cluster "1"; accepted splits append "_1"/"_2"
  startnames <- setNames(rep("1", nrow(startmat)), rownames(startmat))
  outlist <- list()
  outlist$clustnames <- startnames
  outlist$fracmat <- c()
  ###methall is now passed explicitly instead of relying on the function default
  allclusts <- recursive_clustering(keepcols, startmat, fraclim = fraclimval, splitlim = splitlimval, startseed = 1, outlist = outlist, methall = methall)

  ###cluster-by-Cre-line composition table for the cells present in metadata
  shared_cells <- intersect(names(allclusts$clustnames), rownames(metadata))
  temptab <- table(allclusts$clustnames[shared_cells], metadata$cre[match(shared_cells, rownames(metadata))])
  ###keep only the part of each Cre line name before the first "-"
  colnames(temptab) <- vapply(strsplit(colnames(temptab), "-"), `[`, character(1), 1)
  ###append a "Cluster N" label column, numbered in reverse row order
  temptab <- cbind(temptab, paste0("Cluster ", rev(seq_len(nrow(temptab)))))
  write.csv(temptab, file = paste0("composition_",pref,"_",nameval,".csv"))

  ###final cluster label of every cell
  temptab <- allclusts$clustnames
  write.csv(temptab, file = paste0("cluster_ids_",pref,"_",nameval,".csv"))
}

[1] "clustering Features model, using the following parameters: tau,ri,vrest,threshold_i_long_square,threshold_v_long_square,peak_v_long_square,fast_trough_v_long_square,trough_v_long_square,upstroke_downstroke_ratio_long_square,upstroke_downstroke_ratio_short_square,sag,f_i_curve_slope,latency,max_burstiness_across_sweeps"
[1] "clustering Featuresnospike model, using the following parameters: tau,ri,vrest,threshold_i_long_square,threshold_v_long_square,trough_v_long_square,sag,f_i_curve_slope,latency,max_burstiness_across_sweeps"
[1] "clustering GLIF1 model, using the following parameters: R_input,C,El,th_inf,spike_cut_length"
[1] "clustering GLIF2 model, using the following parameters: R_input,C,El,th_inf,spike_cut_length,reset_slope,reset_intercept"
[1] "clustering GLIF3 model, using the following parameters: R_ASC,C,El,th_inf,total charge 1/300+1/100,total charge 1/3+1/10+1/100,spike_cut_length"
[1] "clustering GLIF4 model, using the following parameters: R_ASC,C,El,th_inf,total ch

### 6) Run clustering on GLIF parameters, GLIF parameters+spike shape features, and electrophysiological features using bootstrapped subsets comprising 80% of the cells


In [5]:
###specify prefix for output file names###
pref <- "iterative_binary_clustering_2018"
parametersets <- c("Features","Featuresnospike","GLIF1","GLIF2","GLIF3","GLIF4","GLIF1_spike_shape","GLIF2_spike_shape","GLIF3_spike_shape","GLIF4_spike_shape")
fraclimval <- 0.2  ###maximum fraction of incorrectly classified cells in test set (see recursive_clustering function in section 4)
###splitlimval was referenced below without ever being defined; the call only
###worked because recursive_clustering never evaluates its splitlim argument.
###Define it (matching that function's default) unless it already exists.
if (!exists("splitlimval")) {
  splitlimval <- 50
}
methall <- 'ward.D'
sublist <- list()

###column indices used by each parameter set; the two "Features" sets index
###featfulldat_all, every GLIF set indexes fulldat_all
keepcols_by_set <- list(
  GLIF1 = c(1,3,4,5,8),
  GLIF2 = c(1,3,4,5,8,9,10),
  GLIF3 = c(2,3,4,5,6,7,8),
  GLIF4 = c(2,3,4,5,6,7,8,9,10),
  GLIF1_spike_shape = c(1,3,4,5,8,13:16),
  GLIF2_spike_shape = c(1,3,4,5,8,9,10,13:16),
  GLIF3_spike_shape = c(2,3,4,5,6,7,8,13:16),
  GLIF4_spike_shape = c(2,3,4,5,6,7,8,9,10,13:16),
  Features = seq_len(ncol(featfulldat_all)),
  Featuresnospike = c(1,2,3,4,5,8,11,12,13,14)
)

for (nameval in parametersets) {
  keepcols <- keepcols_by_set[[nameval]]
  if (nameval %in% c("Features","Featuresnospike")) {
    startmat <- featfulldat_all
  } else {
    startmat <- fulldat_all
  }

  print(paste0("clustering ",nameval," model, using the following parameters: ",paste(colnames(startmat)[keepcols],collapse=",")))

  ###every cell starts in cluster "1"; cells left out of a subsample keep that label
  startnames <- setNames(rep("1", nrow(startmat)), rownames(startmat))
  ###each run clusters a random 80% of the cells, drawn without replacement
  cellnum <- round(nrow(startmat) * 0.8)
  sublist[[nameval]] <- vector("list", 100)
  for (ii in seq_len(100)) {
    set.seed(ii)
    startmatrows <- sample(1:nrow(startmat), cellnum)
    outlist <- list()
    outlist$clustnames <- startnames
    outlist$fracmat <- c()
    ###methall is now passed explicitly instead of relying on the function default
    allclusts <- recursive_clustering(keepcols, startmat[startmatrows, ], fraclim = fraclimval, splitlim = splitlimval, startseed = 1, outlist = outlist, methall = methall)
    sublist[[nameval]][[ii]] <- allclusts$clustnames
  }
  ###saved after every parameter set, so partial results survive an interrupted run
  save(sublist, file = paste0("bootstrappedclusters_",pref,".rda"))
}


[1] "clustering Features model, using the following parameters: tau,ri,vrest,threshold_i_long_square,threshold_v_long_square,peak_v_long_square,fast_trough_v_long_square,trough_v_long_square,upstroke_downstroke_ratio_long_square,upstroke_downstroke_ratio_short_square,sag,f_i_curve_slope,latency,max_burstiness_across_sweeps"
[1] "clustering Featuresnospike model, using the following parameters: tau,ri,vrest,threshold_i_long_square,threshold_v_long_square,trough_v_long_square,sag,f_i_curve_slope,latency,max_burstiness_across_sweeps"
[1] "clustering GLIF1 model, using the following parameters: R_input,C,El,th_inf,spike_cut_length"
[1] "clustering GLIF2 model, using the following parameters: R_input,C,El,th_inf,spike_cut_length,reset_slope,reset_intercept"
[1] "clustering GLIF3 model, using the following parameters: R_ASC,C,El,th_inf,total charge 1/300+1/100,total charge 1/3+1/10+1/100,spike_cut_length"
[1] "clustering GLIF4 model, using the following parameters: R_ASC,C,El,th_inf,total ch

### 7) Calculate Adjusted Rand Index and Adjusted Variation of Information scores between all clusterings and Cre line segregation, with confidence intervals based on bootstrapping
#### This generates the left panel of Supplemental Figure 17

In [9]:
###Scores comparing every clustering with the labels in column 1 of featmetadata.
###"Adjusted" here means relative to the mean score over 100 label permutations:
###  adjusted VOI = mean(permuted VOI) - observed VOI
###  adjusted ARI = observed ARI - mean(permuted ARI)
pref <- "iterative_binary_clustering_2018"
cre_sub_voi <- list()
cre_sub_ari <- list()
load(paste0("bootstrappedclusters_", pref, ".rda"))
featclust <- featmetadata

###part 1: one adjusted score per bootstrap run and parameter set
for (nameval in names(sublist)) {
  voi_runs <- c()
  ari_runs <- c()
  for (ii in 1:length(sublist[[nameval]])) {
    checkclust <- sublist[[nameval]][[ii]]
    ###cells still labelled "1" were never assigned to a sub-cluster in this run; drop them
    checkclust <- checkclust[checkclust != 1]
    keepclust <- featclust[names(checkclust), 1]
    set.seed(ii)
    rand_voi <- rand_cluster_diff(keepclust, checkclust, 1)
    rand_ari <- rand_cluster_diff(keepclust, checkclust, 2)
    voi_runs <- c(voi_runs, mean(rand_voi) - calc_cluster_diff(keepclust, checkclust, 1))
    ari_runs <- c(ari_runs, calc_cluster_diff(keepclust, checkclust, 2) - mean(rand_ari))
  }
  cre_sub_voi[[nameval]] <- voi_runs
  cre_sub_ari[[nameval]] <- ari_runs
}

###part 2: observed and permutation-mean scores for the full-data clusterings,
###read back from the cluster id files and aligned to the rows of featclust
cre_voi <- c()
cre_ari <- c()
cre_mean_voi <- c()
cre_mean_ari <- c()
for (nameval in names(sublist)) {
  glifclust <- read.csv(paste0("cluster_ids_", pref, "_", nameval, ".csv"), as.is = TRUE)
  glifclust <- glifclust[match(rownames(featclust), glifclust[, 1]), ]
  cre_voi <- c(cre_voi, calc_cluster_diff(featclust[, 1], glifclust[, 2], 1))
  cre_ari <- c(cre_ari, calc_cluster_diff(featclust[, 1], glifclust[, 2], 2))
  rand_voi <- rand_cluster_diff(featclust[, 1], glifclust[, 2], 1)
  rand_ari <- rand_cluster_diff(featclust[, 1], glifclust[, 2], 2)
  cre_mean_voi <- c(cre_mean_voi, mean(rand_voi))
  cre_mean_ari <- c(cre_mean_ari, mean(rand_ari))
}

###Figure: bootstrap intervals (5%, 50%, 95% quantiles) of the adjusted scores for
###each parameter set, VOI in black on the left axis and ARI in red on the right axis.
###The wide right margin leaves room for the second axis.
pdf(paste0("Fig_Supp17_comparison_to_Cre_lines_with_CIs_", pref, ".pdf"), useDingbats = FALSE)
par(mar = c(5, 5, 2, 5))
###plotting order of the parameter sets; must line up with the axis labels below
nameval <- c("GLIF1","GLIF2","GLIF3","GLIF4","Featuresnospike","Features","GLIF1_spike_shape","GLIF2_spike_shape","GLIF3_spike_shape","GLIF4_spike_shape")

###black series: adjusted VOI (rows of the quantile matrix are 5%, 50%, 95%)
voi_quants <- vapply(nameval, function(mm) quantile(cre_sub_voi[[mm]], probs = c(0.05, 0.5, 0.95)), numeric(3))
L <- voi_quants[1, ]
M <- voi_quants[2, ]
U <- voi_quants[3, ]
###mval is only used for its length, i.e. the number of points on the x axis
mval <- cre_mean_voi - cre_voi
plotCI(1:length(mval), M, ui = U, li = L, col = "black", ylab = "Adjusted VOI score", xaxt = 'n', xlab = '', main = "Comparison to Cre lines", ylim = c(0, max(c(U, L))))
axis(side = 1, at = 1:length(cre_voi), labels = c("GLIF1","GLIF2","GLIF3","GLIF4","Features, no\nspike-shape","Features","GLIF1+\nSpike Shape","GLIF2+\nSpike Shape","GLIF3+\nSpike Shape","GLIF4+\nSpike Shape"), las = 2)

###red series: adjusted Rand index, overlaid on the same device with its own y scale
par(new = TRUE)
ari_quants <- vapply(nameval, function(mm) quantile(cre_sub_ari[[mm]], probs = c(0.05, 0.5, 0.95)), numeric(3))
L <- ari_quants[1, ]
M <- ari_quants[2, ]
U <- ari_quants[3, ]
mval <- cre_ari - cre_mean_ari
###points are shifted by 0.1 so the red bars do not sit on top of the black ones
plotCI((1:length(mval)) + 0.1, M, ui = U, li = L, col = "red", axes = FALSE, xlab = NA, ylab = NA, xlim = c(1, length(mval)), ylim = c(0, max(c(U, L))))
###right-hand axis: draw ticks without labels, then write the tick values in red
axis(side = 4, labels = FALSE)
at <- axTicks(4)
mtext(side = 4, text = at, at = at, col = "red", line = 1)
mtext(side = 4, line = 3, 'Adjusted Rand Index', col = 'red')
legend("topleft", c("Adjusted VOI","Adjusted Rand Index"), fill = c("black","red"))
dev.off()

### 8) Calculate Adjusted Rand and Adjusted Variation of Information Indices between GLIF clusterings and electrophysiological feature clustering, with confidence intervals based on bootstrapping
#### This generates the right panel of Supplemental Figure 17

In [11]:
###Scores comparing each clustering with the full-data "Features" clustering.
###"Adjusted" here means relative to the mean score over 100 label permutations:
###  adjusted VOI = mean(permuted VOI) - observed VOI
###  adjusted ARI = observed ARI - mean(permuted ARI)
pref <- "iterative_binary_clustering_2018"
sub_voi <- list()
sub_ari <- list()
load(paste0("bootstrappedclusters_", pref, ".rda"))
###reference clustering: one row per cell, cluster label in column 1
featclust <- read.csv(paste0("cluster_ids_", pref, "_Features.csv"), as.is = TRUE, row.names = 1)
namevals <- c("GLIF1","GLIF2","GLIF3","GLIF4","Featuresnospike","GLIF1_spike_shape","GLIF2_spike_shape","GLIF3_spike_shape","GLIF4_spike_shape")

###part 1: one adjusted score per bootstrap run and parameter set
for (nameval in namevals) {
  voi_runs <- c()
  ari_runs <- c()
  for (ii in 1:length(sublist[[nameval]])) {
    checkclust <- sublist[[nameval]][[ii]]
    ###cells still labelled "1" were never assigned to a sub-cluster in this run; drop them
    checkclust <- checkclust[checkclust != 1]
    keepclust <- featclust[names(checkclust), 1]
    set.seed(ii)
    rand_voi <- rand_cluster_diff(keepclust, checkclust, 1)
    rand_ari <- rand_cluster_diff(keepclust, checkclust, 2)
    voi_runs <- c(voi_runs, mean(rand_voi) - calc_cluster_diff(keepclust, checkclust, 1))
    ari_runs <- c(ari_runs, calc_cluster_diff(keepclust, checkclust, 2) - mean(rand_ari))
  }
  sub_voi[[nameval]] <- voi_runs
  sub_ari[[nameval]] <- ari_runs
}

###part 2: observed and permutation-mean scores for the full-data clusterings,
###read back from the cluster id files and aligned to the rows of featclust
all_voi <- c()
all_ari <- c()
mean_voi <- c()
mean_ari <- c()
for (nameval in namevals) {
  glifclust <- read.csv(paste0("cluster_ids_", pref, "_", nameval, ".csv"), as.is = TRUE)
  glifclust <- glifclust[match(rownames(featclust), glifclust[, 1]), ]
  all_voi <- c(all_voi, calc_cluster_diff(featclust[, 1], glifclust[, 2], 1))
  all_ari <- c(all_ari, calc_cluster_diff(featclust[, 1], glifclust[, 2], 2))
  rand_voi <- rand_cluster_diff(featclust[, 1], glifclust[, 2], 1)
  rand_ari <- rand_cluster_diff(featclust[, 1], glifclust[, 2], 2)
  mean_voi <- c(mean_voi, mean(rand_voi))
  mean_ari <- c(mean_ari, mean(rand_ari))
}
###Figure: bootstrap intervals (5%, 50%, 95% quantiles) of the adjusted scores for
###each parameter set, VOI in black on the left axis and ARI in red on the right axis.
###The wide right margin leaves room for the second axis.
pdf(paste0("Fig_Supp17_comparison_to_feature_clustering_with_CIs_", pref, ".pdf"), useDingbats = FALSE)
par(mar = c(5, 5, 2, 5))

###black series: adjusted VOI (rows of the quantile matrix are 5%, 50%, 95%)
voi_quants <- vapply(namevals, function(mm) quantile(sub_voi[[mm]], probs = c(0.05, 0.5, 0.95)), numeric(3))
L <- voi_quants[1, ]
M <- voi_quants[2, ]
U <- voi_quants[3, ]
###mval is only used for its length, i.e. the number of points on the x axis
mval <- mean_voi - all_voi
plotCI(1:length(mval), M, ui = U, li = L, col = "black", ylab = "Adjusted VOI score", xaxt = 'n', xlab = '', main = "Comparison to clustering by features", ylim = c(0, max(c(U, L))))
axis(side = 1, at = 1:length(all_voi), labels = c("GLIF1","GLIF2","GLIF3","GLIF4","Features, no\nspike-shape","GLIF1+\nSpike Shape","GLIF2+\nSpike Shape","GLIF3+\nSpike Shape","GLIF4+\nSpike Shape"), las = 2)

###red series: adjusted Rand index, overlaid on the same device with its own y scale
par(new = TRUE)
ari_quants <- vapply(namevals, function(mm) quantile(sub_ari[[mm]], probs = c(0.05, 0.5, 0.95)), numeric(3))
L <- ari_quants[1, ]
M <- ari_quants[2, ]
U <- ari_quants[3, ]
mval <- all_ari - mean_ari
###points are shifted by 0.1 so the red bars do not sit on top of the black ones
plotCI((1:length(mval)) + 0.1, M, ui = U, li = L, col = "red", axes = FALSE, xlab = NA, ylab = NA, xlim = c(1, length(mval)), ylim = c(0, max(c(U, L))))
###right-hand axis: draw ticks without labels, then write the tick values in red
axis(side = 4, labels = FALSE)
at <- axTicks(4)
mtext(side = 4, text = at, at = at, col = "red", line = 1)
mtext(side = 4, line = 3, 'Adjusted Rand Index', col = 'red')
legend("topleft", c("Adjusted VOI","Adjusted Rand Index"), fill = c("black","red"))
dev.off()