Skip to content

Commit

Permalink
Add appropriate noise model; set default params
Browse files Browse the repository at this point in the history
  • Loading branch information
jenniferColonell committed Aug 18, 2019
1 parent cb8afea commit aafcb77
Show file tree
Hide file tree
Showing 9 changed files with 199 additions and 113 deletions.
Binary file added eMouse_drift/SC026_3Bstag_noiseModel.mat
Binary file not shown.
51 changes: 38 additions & 13 deletions eMouse_drift/benchmark_drift_simulation.m
@@ -1,13 +1,12 @@
function benchmark_drift_simulation(rez, GTfilepath, simRecfilepath, sortType, bAutoMerge)
function benchmark_drift_simulation(rez, GTfilepath, simRecfilepath, sortType, bAutoMerge, varargin)

%for testing outside a script. comment out for normal calling!
% load('D:\drift_simulations\74U_norm_64site_20um_600sec_20min\rezFinal.mat');
% GTfilepath = 'D:\drift_simulations\74U_norm_64site_20um_600sec_20min\eMouseGroundTruth.mat';
% simRecfilepath = 'D:\drift_simulations\74U_norm_64site_20um_600sec_20min\eMouseSimRecord.mat';
% load('D:\drift_simulations\74U_norm_64site_20um_600sec_20min\ks2_master_060919\rezFinal.mat');
% GTfilepath = 'D:\drift_simulations\74U_norm_64site_20um_600sec_20min\ks2_master_060919\eMouseGroundTruth.mat';
% simRecfilepath = 'D:\drift_simulations\74U_norm_64site_20um_600sec_20min\ks2_master_060919\eMouseSimRecord.mat';
% sortType = 2;
% bAutoMerge = 0;


load(GTfilepath);

if bAutoMerge
Expand All @@ -16,6 +15,15 @@ function benchmark_drift_simulation(rez, GTfilepath, simRecfilepath, sortType, b
testClu = rez.st3(:,2) ;
end

bOutFile = 0;
%fprintf( 'length of vargin: %d\n', numel(varargin));
if( numel(varargin) == 1)
%path for output file
bOutFile = 1;
fprintf( 'output filename: %s\n', varargin{1} );
out_fid = fopen( varargin{1}, 'w' );
end

testRes = rez.st3(:,1);

[testRes, tOrder] = sort(testRes);
Expand Down Expand Up @@ -122,19 +130,36 @@ function benchmark_drift_simulation(rez, GTfilepath, simRecfilepath, sortType, b
meanPos(i) = mean(yDriftRec(ind,2));
end

fprintf('GTlabel\tnSpike\tmeanAmp\tmeanPos\tundetected\tbestMiss\tbestFP\tbestScore\tautoMiss\tautoFP\tautoScore\tnMerges\tphy labels\n');
if( bOutFile )
fprintf(out_fid, 'GTlabel\tnSpike\tmeanAmp\tmeanPos\tundetected\tbestMiss\tbestFP\tbestScore\tautoMiss\tautoFP\tautoScore\tnMerges\tphy labels\n');
else
fprintf('GTlabel\tnSpike\tmeanAmp\tmeanPos\tundetected\tbestMiss\tbestFP\tbestScore\tautoMiss\tautoFP\tautoScore\tnMerges\tphy labels\n');
end

nMerges = zeros(1,NN);
for i = 1:NN
nMerges(i) = length(allMerges{i})-1;
fprintf('%d\t%d\t%.3f\t%.3f\t%.3f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%d\t', ...
i, nSpike(i), meanAmp(i), ...
meanPos(i), unDetected(i), bestMiss(i), bestFP(i),bestScore(i), autoMiss(i),...
autoFP(i), autoScore(i), nMerges(i));
for j = 1:length(allMerges{i})-1
fprintf( '%d,', allMerges{i}(j)-1 );
if( bOutFile)
fprintf(out_fid, '%d\t%d\t%.3f\t%.3f\t%.3f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%d\t', ...
i, nSpike(i), meanAmp(i), ...
meanPos(i), unDetected(i), bestMiss(i), bestFP(i),bestScore(i), autoMiss(i),...
autoFP(i), autoScore(i), nMerges(i));
for j = 1:length(allMerges{i})-1
fprintf( out_fid, '%d,', allMerges{i}(j)-1 );
end
fprintf(out_fid,'%d\n',allMerges{i}(length(allMerges{i}))-1);
else
fprintf('%d\t%d\t%.3f\t%.3f\t%.3f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%d\t', ...
i, nSpike(i), meanAmp(i), ...
meanPos(i), unDetected(i), bestMiss(i), bestFP(i),bestScore(i), autoMiss(i),...
autoFP(i), autoScore(i), nMerges(i));
for j = 1:length(allMerges{i})-1
fprintf( '%d,', allMerges{i}(j)-1 );
end
fprintf('%d\n',allMerges{i}(length(allMerges{i}))-1);
end
fprintf('%d\n',allMerges{i}(length(allMerges{i}))-1);
end

fclose(out_fid);

end
5 changes: 4 additions & 1 deletion eMouse_drift/config_eMouse_drift_KS2.m
Expand Up @@ -8,7 +8,7 @@
ops.minfr_goodchannels = 0.1;

% threshold on projections (like in Kilosort1, can be different for last pass like [10 4])
ops.Th = [10 4];
ops.Th = [6 2];

% how important is the amplitude penalty (like in Kilosort1, 0 means not used, 10 is average, 50 is a lot)
ops.lam = 10;
Expand All @@ -27,6 +27,9 @@

% threshold crossings for pre-clustering (in PCA projection space)
ops.ThPre = 8;
ops.reorder = 1; % whether to reorder batches for drift correction.
ops.nskip = 25; % how many batches to skip for determining spike PCs

%% danger, changing these settings can lead to fatal errors
% options for determining PCs
ops.spkTh = -6; % spike threshold in standard deviations (-6)
Expand Down
5 changes: 3 additions & 2 deletions eMouse_drift/make_eMouseChannelMap_3A_short.m
@@ -1,6 +1,7 @@
function [chanMapName] = make_eMouseChannelMap_3A_short(fpath,NchanTOT)
function [chanMapName] = make_eMouseChannelMap_3B_short(fpath,NchanTOT)
% create a channel Map file for simulated data on a section of
% an imec 3A probe (eMouse)
% an imec 3B probe to use with eMouse
% essentially identical to the 3A version

% total number of channels = 385 (in real 3A, 384 channels + digital)
chanMap = (1:NchanTOT)';
Expand Down
58 changes: 58 additions & 0 deletions eMouse_drift/make_eMouseChannelMap_3B_short.m
@@ -0,0 +1,58 @@
function [chanMapName] = make_eMouseChannelMap_3B_short(fpath,NchanTOT)
% Create a channel map file for simulated data on a section of
% an imec 3B probe to use with eMouse.
% Essentially identical to the 3A version.
%
% Inputs:
%   fpath    - folder in which the channel map .mat file is saved
%   NchanTOT - total number of channels in the simulated probe section;
%              may be even or odd (e.g. 385 = 384 channels + digital)
%
% Output:
%   chanMapName - name (without folder) of the saved file,
%                 'chanMap_3B_<NchanTOT>sites.mat'
%
% The saved file contains chanMap, connected, xcoords, ycoords, kcoords,
% fs and NchanTOT.

% channels are mapped one-to-one, in order
chanMap = (1:NchanTOT)';

% channels to ignore in analysis and when adding data
% in real 3B data, include the reference channels and the digital channel
% replicate the refChans that are within range of the short probe to
% preserve the geometry for channel to channel correlation in noise
% generation
% NOTE(review): the comparison is strict, so with NchanTOT == 192 channel
% 192 stays connected -- confirm that this is intended.
allRef = 192;
refChan = allRef(allRef < NchanTOT);

connected = true(NchanTOT,1);
connected(refChan) = false;

% site coordinates: x repeats with a period of four channels, and each
% pair of consecutive channels shares one y value.
xcoords = zeros(NchanTOT,1,'double');
xcoords(1:4:NchanTOT) = 43;
xcoords(2:4:NchanTOT) = 11;
xcoords(3:4:NchanTOT) = 59;
xcoords(4:4:NchanTOT) = 27;

% channels 2k-1 and 2k both get y = 20*k. Computing this in closed form
% (rather than filling odd and even channels separately from
% 1:floor(NchanTOT/2)) also works when NchanTOT is odd, where the last
% channel has no partner.
ycoords = 20*ceil((1:NchanTOT)'/2);

% Often, multi-shank probes or tetrodes will be organized into groups of
% channels that cannot possibly share spikes with the rest of the probe. This helps
% the algorithm discard noisy templates shared across groups. In
% this case, we set kcoords to indicate which group the channel belongs to.
% In our case all channels are on the same shank in a single group so we
% assign them all to group 1.
% Note that kcoords is not yet implemented in KS2 (08/15/2019)

kcoords = ones(NchanTOT,1);

% at this point in Kilosort we do data = data(connected, :), ycoords =
% ycoords(connected), xcoords = xcoords(connected) and kcoords =
% kcoords(connected) and no more channel map information is needed (in particular
% no "adjacency graphs" like in KlustaKwik).
% Now we can save our channel map for the eMouse.

% would be good to also save the sampling frequency here
fs = 30000;

chanMapName = sprintf('chanMap_3B_%dsites.mat', NchanTOT);

save(fullfile(fpath, chanMapName), 'chanMap', 'connected', 'xcoords', 'ycoords', 'kcoords', 'fs', 'NchanTOT' )
74 changes: 52 additions & 22 deletions eMouse_drift/make_eMouseData_drift.m
Expand Up @@ -24,7 +24,8 @@ function make_eMouseData_drift(fpath, KS2path, chanMapName, useGPU, useParPool)
drift.y0 = 3800; %in um, position along probe where motion is largest
%y = 0 is the tip of the probe
drift.halfDistance = 1000; %in um, distance along probe over which the motion decays
drift.amplitude = 20; %in um
drift.amplitude = 10; %in um for a sine wave
% peak variation is 2Xdrift.amplitude
drift.halfLife = 2; %in seconds
drift.period = 600; %in seconds
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Expand All @@ -49,10 +50,10 @@ function make_eMouseData_drift(fpath, KS2path, chanMapName, useGPU, useParPool)

if useDefault
%get waveforms from eMouse folder in KS2
filePath{1} = [KS2path,'\eMouse_drift\','kampff_St_unit_waves_allNeg_2X'];
fileCopies(1) = 2;
filePath{2} = [KS2path,'\eMouse_drift\','121817_single_unit_waves_allNeg.mat'];
fileCopies(2) = 2;
filePath{1} = [KS2path,'\eMouse_drift\','kampff_St_unit_waves_allNeg_2X.mat'];
fileCopies(1) = 2;
filePath{2} = [KS2path,'\eMouse_drift\','121817_SU_waves_allNeg_gridEst.mat'];
fileCopies(2) = 2;
else
%fill in paths to waveform files
filePath = {};
Expand All @@ -67,9 +68,14 @@ function make_eMouseData_drift(fpath, KS2path, chanMapName, useGPU, useParPool)
pairDist = 50; % distance between paired units
bPlot = 0; %make diagnostic plots of waveforms

% Noise can either be generated from a gaussian distribution, or modeled on
% noise from real data, matching the frequency spectrum and cross channel
% correlation. The sample noise data is taken from a 3B2 recording performed
% by Susu Chen. Note that the noise data should come from a probe with
% the same geometry as the model probe.
if ( strcmp(noise_model,'fromData') )
if useDefault
nmPath = [KS2path,'\eMouse_drift\','Waksman_3A_noiseModel.mat'];
nmPath = [KS2path,'\eMouse_drift\','SC026_3Bstag_noiseModel.mat'];
else
%fill in path to desired noise model.mat file
end
Expand Down Expand Up @@ -173,16 +179,25 @@ function make_eMouseData_drift(fpath, KS2path, chanMapName, useGPU, useParPool)
for i = 1:nUnit
uData.uColl(i).waves = uData.uColl(i).waves*scaleIntensity(i);
end
end

end

%deprecated
%allowed use of scattered interpolants (non-grid sampling of the
%waveforms -- but the calculations are 5-10X slower than gridded
%interpolants. For non-square sampling of waveforms (e.g. Neuropixels
%probes) -- make a scattered interpolant and then sample on a gridded
%interpolant to feed to the simulator.
%for these units, create an interpolant which will be used to
%calculate the waveform at arbitrary sites

if (uType(unitRange(1)) == 1)
[uFcurr, uRcurr] = makeGridInt( uData.uColl, nt );
else
[uFcurr, uRcurr] = makeScatInt( uData.uColl, nt );
end
% if (uType(unitRange(1)) == 1)
% [uFcurr, uRcurr] = makeGridInt( uData.uColl, nt );
% else
% [uFcurr, uRcurr] = makeScatInt( uData.uColl, nt );
% end

% with both types gridded, always make a gridded interpolant
[uFcurr, uRcurr] = makeGridInt( uData.uColl, nt );

%append these to the array over all units
uF = [uF, uFcurr];
Expand All @@ -195,7 +210,12 @@ function make_eMouseData_drift(fpath, KS2path, chanMapName, useGPU, useParPool)
NN = NN + nUnit;
end


% calculate a size for each unit
uSize = zeros(1,NN);
for i = 1:NN
uSize(i) = (uR(i).maxX - uR(i).minX) * (uR(i).maxY - uR(i).minY);
end

% distribute units along the length of the probe, either in pairs separated
% by unitDist um, or fully randomly.
% for now, keep x = the original position (i.e. don't try to recenter)
Expand Down Expand Up @@ -407,6 +427,9 @@ function make_eMouseData_drift(fpath, KS2path, chanMapName, useGPU, useParPool)
yDriftRec = zeros( length(spk_times), 5, 'double' );
allspks = 0;

% The parpool option can speed up the calculation when using scattered
% interpolants AND running with a large number of workers (>8). With all
% gridded interpolants, the overhead is too large and the parpool gives no speedup.
if (useParPool)
%delete any currently running pool
delete(gcp('nocreate'))
Expand Down Expand Up @@ -446,7 +469,7 @@ function make_eMouseData_drift(fpath, KS2path, chanMapName, useGPU, useParPool)
dat = dat*rms_noise;
%fprintf( 'Noise mean = %.3f; std = %.3f\n', mean(dat(:,1)), std(dat(:,1)));
elseif ( strcmp(noise_model,'fromData') )
enoise = makeNoise( NT, noiseFromData, chanMap, NchanTOT );
enoise = makeNoise( NT, noiseFromData, chanMap, connected, NchanTOT );
if t_all>0
enoise(1:buff, :) = enoise_old(NT-buff + [1:buff], :);
end
Expand Down Expand Up @@ -510,7 +533,7 @@ function make_eMouseData_drift(fpath, KS2path, chanMapName, useGPU, useParPool)
tempWav = squeeze(currWavArray(i,1:currNsiteArray(i),:));
dat(ts(i) + tRange, uSites) = dat(ts(i) + tRange, uSites) + tempWav';
else
%calculate the interpolants now
%calculate the interpolations now
[tempWav, uSites] = ...
intWav( uF{cc}, uX(cc), currYPos, uR(cc), xcoords, ycoords, connected, nt );

Expand Down Expand Up @@ -649,9 +672,10 @@ function make_eMouseData_drift(fpath, KS2path, chanMapName, useGPU, useParPool)

function [uWav, uSites] = intWav( currF, xPos, yPos, uR, xcoords, ycoords, connected, nt )


% figure out for which sites we need to calculate the waveform
uSites = findSites( xPos, yPos, xcoords, ycoords, connected, uR );

% given an array of sites on the probe, calculate the waveform using
% the interpolant determined for this unit
% xPos and yPos are the positions of the current unit
Expand All @@ -666,7 +690,10 @@ function make_eMouseData_drift(fpath, KS2path, chanMapName, useGPU, useParPool)
tq = (double(repmat(1:nt, 1, nSites )))';
%remember, y = rows in the grid, and x = columns in the grid
%interpolation, and scattered interpolation set to match.
nVal = numel(currF.Values);
%tic
uWav = currF( yq, xq, tq );
%fprintf( '%d\t%d\t%.3f\n', numel(uSites), nVal, 1000*toc);
uWav = (reshape(uWav', [nt,nSites]))';

end
Expand Down Expand Up @@ -760,21 +787,22 @@ function make_eMouseData_drift(fpath, KS2path, chanMapName, useGPU, useParPool)
end


function eNoise = makeNoise( noiseSamp, noiseModel,chanMap, NchanTOT )
function eNoise = makeNoise( noiseSamp,noiseModel,chanMap,connected,NchanTOT )

%if chanMap is a short version of a 3A probe, use the first
%nChan good channels to generate noise, then copy that array
%into an NT X NChanTot array

nChan = numel(chanMap); %number of noise channels to generate
tempNoise = zeros( noiseSamp, nChan, 'single' );
goodChan = sum(connected);
tempNoise = zeros( noiseSamp, goodChan, 'single' );
nT_fft = noiseModel.nm.nt; %number of time points in the original time series
fftSamp = noiseModel.nm.fft;

noiseBatch = ceil(noiseSamp/nT_fft);
lastWind = noiseSamp - (noiseBatch-1)*nT_fft; %in samples

for j = 1:nChan
for j = 1:goodChan
for i = 1:noiseBatch-1
tStart = (i-1)*nT_fft+1;
tEnd = i * nT_fft;
Expand All @@ -788,15 +816,17 @@ function make_eMouseData_drift(fpath, KS2path, chanMapName, useGPU, useParPool)
end

%unwhiten this array
Wrot = noiseModel.nm.Wrot(1:nChan,1:nChan);
Wrot = noiseModel.nm.Wrot(1:goodChan,1:goodChan);
tempNoise_unwh = tempNoise/Wrot;

%scale to uV; will get scaled back to bits at the end
tempNoise_unwh = tempNoise_unwh/noiseModel.nm.bitPerUV;

%to get the final noise array, map to an array including all channels
eNoise = zeros(noiseSamp, NchanTOT, 'single');
eNoise(:,chanMap) = tempNoise_unwh;
%indices of the good channels
goodChanIndex = find(connected);
eNoise(:,chanMap(goodChanIndex)) = tempNoise_unwh;

end

Expand Down

0 comments on commit aafcb77

Please sign in to comment.