Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/MouseLand/rastermap
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Jan 16, 2020
2 parents 6cedd0c + ec0a963 commit 11d4641
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 44 deletions.
147 changes: 104 additions & 43 deletions rastermap/gui.py
Expand Up @@ -8,9 +8,10 @@
from pyqtgraph import GraphicsScene
from scipy.stats import zscore
from matplotlib import cm
from rastermap.roi import gROI
from rastermap.roi import gROI, dbROI
import rastermap.run
from rastermap import Rastermap
from sklearn.cluster import DBSCAN

def triangle_area(p):
area = 0.5 * np.abs(p[0,0] * p[1,1] - p[0,0] * p[2,1] +
Expand Down Expand Up @@ -180,7 +181,7 @@ def __init__(self):
self.p1 = self.win.addPlot(row=1, col=2, colspan=3,
rowspan=3, invertY=True, padding=0)
self.p1.setMouseEnabled(x=False, y=False)
self.img = pg.ImageItem(autoDownsample=False)
self.img = pg.ImageItem(autoDownsample=True)
self.p1.hideAxis('left')
self.p1.setMenuEnabled(False)
self.p1.scene().contextMenuItem = self.p1
Expand Down Expand Up @@ -231,16 +232,31 @@ def __init__(self):
self.makegrid.clicked.connect(self.make_grid)
self.makegrid.setStyleSheet(self.styleInactive)
self.makegrid.setEnabled(False)
self.makegrid.setFixedWidth(100)
self.makegrid.setFixedWidth(200)
self.l0.addWidget(self.makegrid, rs+7, 0, 1, 1)
self.gridsize = QtGui.QLineEdit(self)
self.gridsize.setValidator(QtGui.QIntValidator(0, 500))
self.gridsize.setText("10")
self.gridsize.setValidator(QtGui.QIntValidator(2, 20))
self.gridsize.setText("5")
self.gridsize.setFixedWidth(45)
self.gridsize.setAlignment(QtCore.Qt.AlignRight)
self.gridsize.returnPressed.connect(self.make_grid)
self.l0.addWidget(self.gridsize, rs+7, 1, 1, 1)

self.dbbutton = QtGui.QPushButton("DBSCAN clusters, ms=")
self.dbbutton.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
self.dbbutton.clicked.connect(self.dbscan)
self.dbbutton.setStyleSheet(self.styleInactive)
self.dbbutton.setEnabled(False)
self.dbbutton.setFixedWidth(200)
self.l0.addWidget(self.dbbutton, rs+8, 0, 1, 1)
self.min_samples = QtGui.QLineEdit(self)
self.min_samples.setValidator(QtGui.QIntValidator(5, 200))
self.min_samples.setText("50")
self.min_samples.setFixedWidth(45)
self.min_samples.setAlignment(QtCore.Qt.AlignRight)
self.min_samples.returnPressed.connect(self.dbscan)
self.l0.addWidget(self.min_samples, rs+8, 1, 1, 1)

ysm = QtGui.QLabel("<font color='white'>y-binning</font>")
ysm.setFixedWidth(100)
self.l0.addWidget(ysm, rs+6, 0, 1, 1)
Expand Down Expand Up @@ -295,11 +311,13 @@ def __init__(self):
# self.load_behavior('C:/Users/carse/github/TX4/beh.npy')
self.file_iscell = None
#self.fname = '/media/carsen/DATA2/grive/rastermap/DATA/embedding.npy'
#self.load_proc(self.fname)
self.fname = 'D:/grive/cshl_suite2p/TX39/embedding.npy'
self.load_proc(self.fname)

self.show()
self.win.show()


def add_imgROI(self):
if hasattr(self, 'imgROI'):
self.pfull.removeItem(self.imgROI)
Expand Down Expand Up @@ -343,23 +361,24 @@ def smooth_activity(self):
self.sp_smoothed /= 12

def plot_activity(self):
self.smooth_activity()
nn = self.sp_smoothed.shape[0]
nt = self.sp_smoothed.shape[1]
self.imgfull.setImage(self.sp_smoothed)
self.imgfull.setLevels([self.sat[0],self.sat[1]])
self.img.setImage(self.sp_smoothed)
self.img.setLevels([self.sat[0],self.sat[1]])
self.p1.setXRange(0, nt, padding=0)
self.p1.setYRange(0, nn, padding=0)
self.p1.setLimits(xMin=0,xMax=nt,yMin=0,yMax=nn)
self.pfull.setXRange(0, nt, padding=0)
self.pfull.setYRange(0, nn, padding=0)
self.pfull.setLimits(xMin=-1,xMax=nt+1,yMin=-1,yMax=nn+1)
self.imgROI.setPos(-.5,-.5)
self.imgROI.setSize([nt+.5,nn+.5])
self.imgROI.maxBounds = QtCore.QRectF(-1.,-1.,nt+1,nn+1)
if 0:
if self.embedded:
self.smooth_activity()
nn = self.sp_smoothed.shape[0]
nt = self.sp_smoothed.shape[1]
self.imgfull.setImage(self.sp_smoothed)
self.imgfull.setLevels([self.sat[0],self.sat[1]])
self.img.setImage(self.sp_smoothed)
self.img.setLevels([self.sat[0],self.sat[1]])
self.p1.setXRange(0, nt, padding=0)
self.p1.setYRange(0, nn, padding=0)
self.p1.setLimits(xMin=0,xMax=nt,yMin=0,yMax=nn)
self.pfull.setXRange(0, nt, padding=0)
self.pfull.setYRange(0, nn, padding=0)
self.pfull.setLimits(xMin=-1,xMax=nt+1,yMin=-1,yMax=nn+1)
self.imgROI.setPos(-.5,-.5)
self.imgROI.setSize([nt+.5,nn+.5])
self.imgROI.maxBounds = QtCore.QRectF(-1.,-1.,nt+1,nn+1)
else:
nn = self.sp.shape[0]
nt = self.sp.shape[1]
self.imgfull.setImage(self.sp)
Expand All @@ -382,8 +401,10 @@ def plot_activity(self):
def plot_colorbar(self):
nneur = self.colormat_plot.shape[0]
self.colorimg.setImage(self.colormat_plot)
N = int(self.smooth.text())
NN = self.sp_smoothed.shape[0]*N
if self.embedded:
N = int(self.smooth.text())
else:
N = 1
self.p3.setYRange(self.yrange[0]*N, self.yrange[-1]*N)
self.p3.setXRange(0,10)
self.p3.setLimits(yMin=self.yrange[0]*N,yMax=self.yrange[-1]*N,xMin=0,xMax=10)
Expand Down Expand Up @@ -415,17 +436,17 @@ def imgROI_range(self):
yrange = (np.arange(0,int(sizey)) + np.floor(posy)).astype(np.int32)
xrange = xrange[xrange>=0]
yrange = yrange[yrange>=0]
#if self.embedded:
xrange = xrange[xrange<self.sp_smoothed.shape[1]]
yrange = yrange[yrange<self.sp_smoothed.shape[0]]
#else:
# xrange = xrange[xrange<self.sp.shape[1]]
# yrange = yrange[yrange<self.sp.shape[0]]
if self.embedded:
xrange = xrange[xrange<self.sp_smoothed.shape[1]]
yrange = yrange[yrange<self.sp_smoothed.shape[0]]
else:
xrange = xrange[xrange<self.sp.shape[1]]
yrange = yrange[yrange<self.sp.shape[0]]
return xrange,yrange

def imgROI_position(self):
xrange,yrange = self.imgROI_range()
if 1:
if self.embedded:
self.img.setImage(self.sp_smoothed[np.ix_(yrange,xrange)])
else:
self.img.setImage(self.sp[np.ix_(yrange,xrange)])
Expand All @@ -437,7 +458,10 @@ def imgROI_position(self):
axy = self.p3.getAxis('left')
axx = self.p1.getAxis('bottom')
self.plot_colorbar()
N = int(self.smooth.text())
if self.embedded:
N = int(self.smooth.text())
else:
N = 1
axy.setTicks([[(0,str(self.yrange[0])),(self.yrange[-1]*N,str(self.yrange[-1]*N))]])
axx.setTicks([[(0.0,str(xrange[0])),(float(xrange.size),str(xrange[-1]))]])

Expand Down Expand Up @@ -481,7 +505,9 @@ def ROI_selection(self, loaded=False):

self.colormat_plot = self.colormat.copy()
self.plot_activity()
print('plotted activity')
self.plot_colorbar()
print('plotted colorbar')
self.win.show()

def update_selected(self, ineur):
Expand All @@ -501,19 +527,52 @@ def update_selected(self, ineur):
ineur = self.selected[ineur]
self.xp.setData(pos=self.embedding[ineur,:][np.newaxis,:])

def dbscan(self):
ms = int(self.min_samples.text())
# remove previous ROIs
if len(self.ROIs) > 0:
for n in range(len(self.ROIs)):
self.ROI_delete()

db = DBSCAN(eps=0.8, min_samples=ms).fit(self.embedding)
ilabels = np.unique(db.labels_)
ilabels = ilabels[ilabels>=0]
print(ilabels)
#ilabels = ilabels[:1]
for i in ilabels:
self.dbROI_add((db.labels_==i).nonzero()[0])
self.ROI_selection()

def make_grid(self):
ng = int(self.gridsize.text())
if len(self.ROIs) > 0:
for n in range(len(self.ROIs)):
self.ROI_delete()
sz = self.embedding.max() / ng
print(sz)
corners = np.array([j*sz for j in range(0,ng)])
print(corners)
#for j in range(ng):
# for k in range(ng):
# prect = np.array([[corners[j],corners[k]],
# [corners[j],corners[k]],
sz = (self.embedding.max() - self.embedding.min()) / ng
corners = np.linspace(self.embedding.min(), self.embedding.max(), ng+1)
jet = cm.get_cmap('jet')
jet = jet(np.linspace(0,1,ng**2))
jet = jet[:,:3]
for j in range(ng):
for k in range(ng):
prect = [np.array([[corners[j],corners[k]],
[corners[j+1],corners[k]],
[corners[j+1],corners[k+1]],
[corners[j],corners[k+1]],
[corners[j],corners[k]]])]
pos = [np.array([[corners[j+1],corners[k]+sz/2],
[corners[j],corners[k]+sz/2]])]
self.ROI_add(pos, prect, color=jet[j+k*ng]*255.0)
self.ROI_selection()

def dbROI_add(self, selected, color=None):
if color is None:
color = np.random.randint(255,size=(3,))
self.ROIs.append(dbROI(selected, color, self))
self.Rselected.append(self.ROIs[-1].selected)
self.Rcolors.append(np.reshape(np.tile(self.ROIs[-1].color, 10 * self.Rselected[-1].size),
(self.Rselected[-1].size, 10, 3)))
self.ROIorder.append(len(self.ROIs)-1)

def ROI_add(self, pos, prect, color=None):
if color is None:
Expand Down Expand Up @@ -570,9 +629,12 @@ def enable_embedded(self):
self.updateROI.setEnabled(True)
self.saveROI.setEnabled(True)
self.makegrid.setEnabled(True)
self.dbbutton.setEnabled(True)

self.updateROI.setStyleSheet(self.styleUnpressed)
self.saveROI.setStyleSheet(self.styleUnpressed)
self.makegrid.setStyleSheet(self.styleUnpressed)
self.dbbutton.setStyleSheet(self.styleUnpressed)

def disable_embedded(self):
self.updateROI.setEnabled(False)
Expand Down Expand Up @@ -633,8 +695,6 @@ def mouse_moved_bar(self, pos):
ineur = min(self.colormat.shape[0]-1, max(0, int(np.floor(y))))
self.update_selected(ineur)



def plot_clicked(self, event):
"""left-click chooses a cell, right-click flips cell to other view"""
flip = False
Expand Down Expand Up @@ -778,6 +838,7 @@ def load_mat(self, name=None):
self.ROI_selection()
self.enable_loaded()
self.show()
print('done loading')
self.loaded = True

def load_iscell(self):
Expand Down
30 changes: 29 additions & 1 deletion rastermap/roi.py
Expand Up @@ -11,6 +11,34 @@ def triangle_area(p0, p1, p2):
p2[:,0] * p0[1] - p2[:,0] * p1[1])
return area

class dbROI():
"""
ROI that is premade from DBSCAN clustering
"""
def __init__(self, selected, color, parent=None):
self.color = color
self.selected = selected
self.positions = parent.embedding[self.selected, :]
self.pen = pg.mkPen(pg.mkColor(self.color),
width=1,
style=QtCore.Qt.SolidLine)
self.ROIplot = pg.ScatterPlotItem(pos=self.positions, pen=self.pen, symbol='o', size=2)
parent.p0.addItem(self.ROIplot)
parent.p0.removeItem(parent.xp)
parent.p0.addItem(parent.xp)

def inROI(self, Y):
dists = np.zeros((Y.shape[0],))
for k,y in enumerate(Y):
dists[k] = (((self.positions - self.y)**2).sum(axis=1)**0.5).min()
inroi = dists < 0.5

return Y[inroi], dists[inroi]

def remove(self, parent):
'''remove ROI'''
parent.p0.removeItem(self.ROIplot)

class gROI():
'''
draw a line segment which is the gradient over which to plot the points
Expand All @@ -21,7 +49,7 @@ def __init__(self, pos, prect, color, parent=None):
self.d = ((prect[0][0,:] - prect[0][1,:])**2).sum()**0.5 / 2
#self.slope = (pos[1,1] - pos[0,1]) / (pos[1,0] - pos[0,0])
#self.yint = pos[1,0] - self.slope * pos[0,0]
np.save('groi.npy', {'prect': self.prect, 'pos': self.pos})
#np.save('groi.npy', {'prect': self.prect, 'pos': self.pos})
self.color = color
self.pen = pg.mkPen(pg.mkColor(self.color),
width=3,
Expand Down

0 comments on commit 11d4641

Please sign in to comment.