In [None]:
using ProgressMeter, PyCall, PyPlot, Cairo, Images, HDF5, MultivariateStats, Interpolations, Lasso, Distributions, ImageFiltering
using _Data
using  NaNStatistics, Statistics
#using ScikitLearn

include("../Decoding/Decoder_Functions.jl")
include("../Decoding/Decoder_Pipeline.jl")
include("../project_place_cell/functions/func_map.jl")
np = pyimport("numpy")



# Matplotlib defaults for small publication figures: 7 pt Arial text, line width 1,
# marker size 4, short ticks with tight padding, and no top/right axis spines.
rc_params = PyDict(pyimport("matplotlib")["rcParams"]);
for (rc_key, rc_value) in (
        "font.sans-serif" => ["Arial"],
        "font.size" => 7,
        "lines.linewidth" => 1,
        "lines.markersize" => 4,
        "xtick.major.size" => 2,
        "ytick.major.size" => 2,
        "axes.spines.top" => false,
        "axes.spines.right" => false,
        "xtick.major.pad" => 2,
        "ytick.major.pad" => 2,
        "axes.labelpad" => 2,
    )
    rc_params[rc_key] = rc_value
end

# Convert images to Cairo image surfaces. Every method ends up in the Matrix{UInt32} one,
# which wraps the 32-bit pixels as an RGB24 surface without swapping x and y.
cim(img::Matrix{UInt32}) = CairoImageSurface(img, Cairo.FORMAT_RGB24; flipxy = false)
# `reinterpret` of an Array returns a lazy ReinterpretArray, which is not a Matrix{UInt32}
# and so would not dispatch to the method above; `collect` materialises it as a Matrix.
cim(img::Matrix{ARGB32}) = cim(collect(reinterpret(UInt32, img)))
cim(img::Matrix{RGB24}) = cim(collect(reinterpret(UInt32, img)))
# 8-bit grayscale: no `cim` method accepts a Gray matrix, so the single channel is
# replicated into the R, G and B components of an RGB24 image.
function cim(img::Matrix{UInt8})
    levels = reinterpret(N0f8, img)
    return cim(RGB24.(levels, levels, levels))
end
# 8-bit image with colour channels along dimension 3: channels 1, 2, 3 become R, G, B.
cim(img::Array{UInt8,3}) = cim(RGB24.(reinterpret(N0f8, img[:,:,1]), reinterpret(N0f8, img[:,:,2]), reinterpret(N0f8, img[:,:,3])));
"""
    downsample(img, n=4)

Downsample a 2-D array by averaging `n`×`n` blocks. Along each dimension only the first
`n * div(size(img, d) - 1, n)` entries are used, so a trailing partial block is dropped —
and when a dimension is an exact multiple of `n`, the last full block is dropped as well.
"""
function downsample(img, n=4)
    row_stop = n * div(size(img, 1) - 1, n)
    col_stop = n * div(size(img, 2) - 1, n)
    # One strided sub-array per (row offset, column offset); row offset varies fastest and
    # the arrays are summed left to right, matching the original summation order.
    phase_grids = (img[r:n:row_stop, c:n:col_stop] for c in 1:n for r in 1:n)
    return foldl(+, phase_grids) / (n * n)
end;

In [None]:
# Corner-cue recordings as [experiment timestamp, data-server number, experimenter].
datasets_corner_cue = Vector{Any}[
    ["20220407_152537", 4, "jen"],
    ["20220406_111526", 9, "jen"],
    ["20220407_090156", 5, "jen"],
    ["20220417_165530", 25, "jen"],
    ["20220406_153842", 9, "jen"],
    ["20220405_171444", 25, "jen"],
    ["20220416_160516", 6, "jen"],
];

# Data-server number holding Chuyu's derived files, one per dataset (same order as above).
chuyu_server = [4, 9, 5, 2, 9, 4, 6]

# Value embedded in the "neuron_spatial_info_15_<lengths[i]>_..." file names.
# NOTE(review): 9 entries for 7 datasets — only the first 7 are ever indexed; confirm.
lengths = [90, 90, 90, 89, 90, 90, 90, 90, 90];

n_datasets = length(datasets_corner_cue)

## EDF 1 d

In [None]:
# Analysis parameters for the EDF 1d example cells.
activity_bins = 7
activity_shift = 4
n_pos = 60                # number of map bins along the long side of the chamber
long_axis_in_mm = 47      # 47 for rectangular, 33 for others
use_amount = 1000         # number of neurons requested from Decoder.get_top_neurons
file_name = "_decoding_revision.h5"

# Index (into datasets_corner_cue) of the dataset analysed in the EDF 1d cells.
i = 3

experiment_filename, server, experimenter = datasets_corner_cue[i]

# Each dataset lives under /data when running on its own roli-<server> host,
# and under /nfs/data<server> everywhere else.
ds_Lorenz = Dataset(experiment_filename, "lorenz", gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")
ds_Chuyu = Dataset(experiment_filename, "chuyu", gethostname() == "roli-$(chuyu_server[i])" ? "/data" : "/nfs/data$(chuyu_server[i])")
ds = Dataset(experiment_filename, experimenter, gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")


In [None]:
# List the files in Chuyu's directory for the selected dataset.
ds_Chuyu |> path |> readdir

In [None]:

# Read the "C", "heading" and "img_bg" datasets from the raw behaviour file.
C, heading, img_bg = h5open(ds, "behavior.h5"; raw = true) do behavior_file
    Tuple(read(behavior_file, key) for key in ("C", "heading", "img_bg"))
end;



# orientation-corrected fish location (time binned)
# The do-block guarantees the HDF5 file is closed even if one of the reads throws.
(chamber_roi, x_fish_sweep_mean, y_fish_sweep_mean, speed_mm_s, loc_digital,
 x_digital, y_digital, x_bins, y_bins) =
    h5open(joinpath(data_path(ds_Chuyu), "for_place_calculation_chamber_geometry_$(experiment_filename)_n$(n_pos).h5"), "r") do position_file
        (read(position_file, "chamber_roi"),
         read(position_file, "x_fish_sweep_mean"),
         read(position_file, "y_fish_sweep_mean"),
         read(position_file, "speed_mm_s"),
         read(position_file, "loc_digital"),
         read(position_file, "x_digital"),
         read(position_file, "y_digital"),
         read(position_file, "x_bins"),
         read(position_file, "y_bins"))
    end

# Time bins in which the fish moves faster than 0.1 mm/s.
moving_valid = speed_mm_s .> 0.1;

# Floored extent of the fish trajectory.
min_x, max_x = floor.(Int64, extrema(x_fish_sweep_mean))
min_y, max_y = floor.(Int64, extrema(y_fish_sweep_mean))

# Size of one map bin: the larger of the two trajectory spans (plus a 2-unit margin)
# divided by n_pos, so that both axes fit into n_pos bins.
interval = max((max_y-min_y+2)/n_pos, (max_x-min_x+2)/n_pos)


# Per-neuron spatial-information results for this dataset (sigma1, n$(n_pos) grid per the
# file name). Original notes: #_whole #spatial_info_4 done on merged cells.
# The do-block guarantees the HDF5 file is closed even if one of the reads throws.
(place_cell_index, valid_neurons, specificity_shuffle_z, specificity_population_z,
 specificity, specificity_shuffle_p, bool_index) =
    h5open(joinpath(data_path(ds_Chuyu), "neuron_spatial_info_15_$(lengths[i])_chamber_geometry_$(experiment_filename)_sigma1_n$(n_pos)_A_dF.h5"), "r") do spatial_file
        print(keys(spatial_file))
        (read(spatial_file, "place_cell_index"),
         HDF5.readmmap(spatial_file["valid_neurons"]),
         HDF5.readmmap(spatial_file["specificity_shuffle_z"]),
         HDF5.readmmap(spatial_file["specificity_population_z"]),
         HDF5.readmmap(spatial_file["specificity"]),
         read(spatial_file, "specificity_shuffle_p"),
         HDF5.readmmap(spatial_file["bool_index"]))
    end

bool_index = BitArray(bool_index)



# Merged NMF sources: activity matrix and per-source centroid coordinates.
# The do-block guarantees the HDF5 file is closed even if one of the reads throws.
(A_dFF, z_all, centroid_x_all, centroid_y_all) =
    h5open(joinpath(data_path(ds_Chuyu), "NMF_merge.h5"), "r") do nmf_file
        (HDF5.readmmap(nmf_file["A_dF"]),
         HDF5.readmmap(nmf_file["Z_all"]),
         HDF5.readmmap(nmf_file["X_all"]),
         HDF5.readmmap(nmf_file["Y_all"]))
    end

n_neurons = size(A_dFF, 2)  # columns of A_dFF are neurons (indexed as A_dFF[:, neuron])
n_sweeps = size(A_dFF, 1)   # rows are time bins



"""
    bin_to_px(x, y, offset=true)

Convert (fractional) map-bin coordinates `x`, `y` into the coordinate frame of the fish
trajectory: bin `b` maps to `(b - 0.5) * interval`. With `offset=true` the result is
additionally shifted by `min_x - 1` / `min_y - 1`; with `offset=false` only the scaled
coordinates are returned. Works element-wise on arrays. Reads the globals `interval`,
`min_x` and `min_y`.
"""
function bin_to_px(x, y, offset=true)
    scaled_x = (x .- 0.5) .* interval
    scaled_y = (y .- 0.5) .* interval
    offset || return scaled_x, scaled_y
    return scaled_x .+ (min_x - 1), scaled_y .+ (min_y - 1)
end

"""
    px_to_bin(x, y)

Inverse of `bin_to_px(x, y, true)`: convert trajectory coordinates into fractional map-bin
coordinates, `0.5 + (x - (min_x - 1)) / interval` (likewise for `y`). Works element-wise on
arrays and reads the globals `interval`, `min_x` and `min_y`.
"""
px_to_bin(x, y) = (0.5 .+ (x .- (min_x - 1)) ./ interval,
                   0.5 .+ (y .- (min_y - 1)) ./ interval)

# Fish trajectory expressed in fractional map-bin coordinates.
(x_in_bins, y_in_bins) = px_to_bin(x_fish_sweep_mean, y_fish_sweep_mean)

# Top `use_amount` candidate neurons as ranked by Decoder.get_top_neurons from the
# population and shuffle specificity z-scores.
place_candidates_unique = Decoder.get_top_neurons(
    use_amount,
    specificity_population_z,
    specificity_shuffle_z,
);

In [None]:
# compute place maps
# For every neuron, build an n_pos × n_pos mean-activity map and count map with
# MAP.calculate_map_direct (Gaussian filter, sigma = 1); unfilled entries stay NaN32.
mean_map_all = fill(NaN32, n_pos, n_pos, size(A_dFF,2))
count_map_all = fill(NaN32, n_pos, n_pos, size(A_dFF,2))
@showprogress for cell in axes(A_dFF, 2)
    cell_activity, cell_loc = MAP.valid_activity_loc(A_dFF[:, cell], bool_index, loc_digital)
    cell_mean_map, cell_count_map, _ = MAP.calculate_map_direct(cell_activity, cell_loc, n_pos;
        at_least_visit = 0, use_gaussian_filter = true, sigma = 1, filter_mask = nothing)
    mean_map_all[:, :, cell] = cell_mean_map
    count_map_all[:, :, cell] = cell_count_map
end

In [None]:
# helper function
"""
    whether_in(vector, collection)

Return an array holding, for every element of `vector` in order, the result of
`element ∈ collection`.
"""
whether_in(vector, collection) = collect(item ∈ collection for item in vector)


# compute for each neuron the times when fish is in its place field
# in_field_masks[n, f] is true when the binned fish position (x_digital[f], y_digital[f])
# lies on a non-zero label of the component image that map_components_peak returns for
# neuron n's place map (threshold 9/10, components_size_threshold 20).
in_field_masks = falses(n_neurons, n_sweeps)
@showprogress for cell in 1:size(A_dFF, 2)
    _, component_labels, _ = map_components_peak(mean_map_all[:, :, cell]; threshold = 9/10, components_size_threshold = 20)
    field_mask = component_labels .!= 0
    in_field_masks[cell, :] = [field_mask[x_digital[t], y_digital[t]] for t in 1:n_sweeps]
end

In [None]:
# function to define individual traversals (fish leaves place field for less than gap_in_s before it comes back in)
"""
    get_traversals(dat, gap_in_s=5; frames_per_s=2, min_frames=2)

Split the sorted frame indices `dat` (e.g. `findall` of an in-field mask) into traversals.
Consecutive indices stay in the same traversal unless they are more than
`gap_in_s * frames_per_s` frames apart. Returns, in time order, `[first_frame, last_frame]`
for every traversal whose last and first frame differ by more than `min_frames`.

`frames_per_s = 2` corresponds to the 120 frames per minute used for the traces in this
notebook. The final traversal, which is not followed by a gap, is closed after the loop so
it is returned as well.
"""
function get_traversals(dat, gap_in_s=5; frames_per_s=2, min_frames=2)
    traversals = Vector{Vector{Int64}}()
    isempty(dat) && return traversals

    max_gap = gap_in_s * frames_per_s
    run_start = first(dat)
    for k in firstindex(dat):lastindex(dat)-1
        if dat[k+1] - dat[k] > max_gap
            # Gap found: close the current traversal and start a new one at the next frame.
            if dat[k] - run_start > min_frames
                push!(traversals, [run_start, dat[k]])
            end
            run_start = dat[k+1]
        end
    end
    # Close the last traversal, which has no trailing gap to terminate it.
    if last(dat) - run_start > min_frames
        push!(traversals, [run_start, last(dat)])
    end
    return traversals
end

In [None]:
# main code to plot place fields and calcium traces
# For the 20 top-ranked candidate place cells: left panel = smoothed place map, right panel
# = activity trace over a `minutes`-long window starting at frame `start`, normalised to
# its maximum, with in-field traversals shaded red.

# Candidates re-ranked by Decoder.get_top_neurons from two per-neuron scores: the number of
# NaN frames in A_dFF and the population specificity z-score.
cand = Decoder.get_top_neurons(1000, sum(isnan.(A_dFF[:, place_candidates_unique]), dims=1)[1,:], specificity_population_z[place_candidates_unique]);

start = round(Int, 25.5*120)   # first frame of the window (25.5 min at 120 frames/min)
minutes = 30                   # window length in minutes

for neuron_idx in place_candidates_unique[cand[1:20]]
    trace_window = start:start+minutes*120

    figure(figsize=(20,4))

    # Left: place map normalised by its 95th percentile (NaNs ignored), 0–1 jet colour scale.
    subplot(2,2,1)
    title("$(neuron_idx)")
    imshow(mean_map_all[:,:,neuron_idx]' ./ nanpctile(mean_map_all[:,:,neuron_idx], 95), origin="lower", extent = [x_bins[1], x_bins[end], y_bins[1], y_bins[end]], cmap="jet", vmin=0, vmax=1)
    # Chamber overlay: entries equal to 1 become NaN and are therefore left unpainted.
    chamber_roi2 = Float32.(copy(chamber_roi))
    chamber_roi2[chamber_roi2 .== 1] .= NaN
    imshow(chamber_roi2', cmap="binary", origin="lower")
    axis("off")
    xlim(400,3000)
    ylim(2300, 450)  # reversed limits flip the y axis

    # Right: shade traversals lying fully inside the window. Traversals come out of
    # get_traversals in time order, so stop at the first one ending after the window.
    subplot(2,2,2)
    for t in get_traversals(findall(in_field_masks[neuron_idx, :]))
        t[2] > last(trace_window) && break
        t[1] < start && continue
        axvspan(t[1]-start, t[2]-start, color="red", alpha=0.2, lw=0)
    end

    trace_values = A_dFF[trace_window, neuron_idx]
    plot(trace_values ./ nanmaximum(trace_values), color="black")
    xlabel("Time (minutes)", labelpad=-5); ylabel("Activity (a.u.)")
    xticks([0, minutes/2*120, minutes*120], [0, "", minutes])
    ylim(-0.5, 1)
end
    

In [None]:
# check orientation of chamber
# Show the last plane of the background image, transposed (`'`) so that the first array
# axis runs along matplotlib's x axis, with the origin at the bottom-left.
imshow(img_bg[:,:,end]', origin="lower")

In [None]:
# export place field
# Save the smoothed place map of one hand-picked example neuron of the selected dataset.
neuron_idx = 58664

fig = figure(figsize=(4,2))
# Map normalised by its 99th percentile (NaNs ignored) on a fixed 0–1 jet colour scale.
# NOTE(review): the overview loop normalises by the 95th percentile instead — intentional?
sp1 = imshow(mean_map_all[:,:,neuron_idx]' ./ nanpctile(mean_map_all[:,:,neuron_idx], 99), origin="lower", extent = [x_bins[1], x_bins[end], y_bins[1], y_bins[end]], cmap="jet", vmin=0, vmax=1)
# Chamber overlay: entries equal to 1 become NaN, which matplotlib leaves unpainted.
chamber_roi2 = Float32.(copy(chamber_roi))
chamber_roi2[chamber_roi2 .== 1] .= NaN
imshow(chamber_roi2', cmap="binary", origin="lower")
# plot(vcat(countour_array[1:50:end,1][end], countour_array[1:50:end,1]),vcat(countour_array[1:50:end,2][end], countour_array[1:50:end,2]),c="k", alpha=0.1)

tight_layout()
axis("off")
xlim(400,3000)
ylim(2300, 450)  # reversed limits flip the y axis
# File name contains the global dataset index `i` and the neuron index.
fig.savefig("EDF1d_example_neuron_map_$(i)_$(neuron_idx).pdf", bbox_inches="tight",transparent = true,pad_inches = 0);

In [None]:
# export calcium trace
# Save the max-normalised activity trace of `neuron_idx` (set in the previous cell) over a
# `minutes`-long window starting at frame `start`, with in-field traversals shaded red.

figure(figsize=(5, 1))

start = round(Int, 25.5*120)   # window start frame (25.5 min at 120 frames per minute)
minutes = 30
    
traversals = get_traversals(findall(in_field_masks[neuron_idx, :]))

# NOTE(review): exits, entries and in_field are computed but not used below.
exits = Int[t[2] for t in traversals];
entries = Int[t[1] for t in traversals];


in_field = Float64.(copy(in_field_masks[neuron_idx, :]))
in_field[in_field .== 0.0] .= NaN

# Shade traversals lying fully inside the window; traversals are in time order, so the
# loop stops at the first one that ends after the window.
for t in traversals
    if t[2] > start+120*minutes
        break
    end
    if t[1] < start
        continue
    end
    axvspan(t[1]-start, t[2]-start, color="red", alpha=0.2, lw=0)
end


plot(A_dFF[start:start+minutes*120, neuron_idx] ./ nanmaximum(A_dFF[start:start+minutes*120, neuron_idx]), linewidth=0.5, color="black")
xlabel("Time (min)", labelpad=-5); ylabel("Activity (a.u.)")
#xticks(0:50:1200)
xlim(1,minutes*120)
ylim(-0.15, 1)
yticks([0,1])
# Frame ticks at 0, half and full window, labelled in minutes.
xticks([0, minutes/2*120, minutes*120], [0, "", minutes])
#for i in exits
#    axvline(i - start)
#end
tight_layout(pad=0.5)
# File name contains the global dataset index `i` and the neuron index.
savefig("EDF1d_example_neuron_$(i)_$(neuron_idx)_full.pdf", format="pdf",  transparent=true, dpi=300)

## EDF 1 f

In [None]:
# Collect each corner-cue dataset's per-time-bin swimming speed (`speed_mm_s`, mm/s) from
# its chamber-geometry position file; the EDF 1 f–h cells below only use `speed_all`.
# Only this one array is read per dataset, and the do-block closes the file even if the
# read fails. The per-dataset variables stay local to the closures instead of overwriting
# the globals of the dataset loaded for the EDF 1 d cells.
speed_all = map(1:n_datasets) do i
    experiment_filename = datasets_corner_cue[i][1]
    chuyu_root = gethostname() == "roli-$(chuyu_server[i])" ? "/data" : "/nfs/data$(chuyu_server[i])"
    ds_Chuyu = Dataset(experiment_filename, "chuyu", chuyu_root)

    position_path = joinpath(data_path(ds_Chuyu), "for_place_calculation_chamber_geometry_$(experiment_filename)_n60.h5")
    h5open(position_path, "r") do position_file
        read(position_file, "speed_mm_s")
    end
end;


In [None]:
# Ad-hoc check: list the files of dataset 20220405_171444 on data server 4 (hard-coded path).
readdir("/nfs/data4/chuyu/data/20220405_171444")

In [None]:
# EDF 1f: histogram of the pooled swimming speed of all datasets (0.5 mm/s bins, 0–10 mm/s).
figure(figsize=(1.5,1.5))
hist(vcat(speed_all...), bins=0:0.5:10, color="lightgray"); 
xlabel("Speed (mm/s)", labelpad=-5); ylabel("Total time (min)", labelpad=-10)
# Counts are time bins; ticks at multiples of 120 are labelled in minutes
# (assumes 120 time bins per minute, as for the activity traces above).
yticks([0, 25, 50, 75, 100, 125].*120, [0, "", "", "", "", 125])
xticks([0,2,4,6,8,10], [0, "", "", "", "", 10])

tight_layout(pad=0.5)
savefig("EDF1f_speed_blurring_time.pdf", format="pdf",  transparent=true, dpi=300)


In [None]:
# Mean and standard deviation of the pooled swimming speed of all datasets (mm/s).
speed_mean, speed_std = let pooled_speed = vcat(speed_all...)
    (mean(pooled_speed), std(pooled_speed))
end
speed_std

In [None]:
# Display the pooled mean speed (mm/s).
speed_mean

## EDF 1 g

In [None]:
# Fluorescence half-decay values for the rates named in each variable (80 Hz … 1 Hz).
# They are multiplied by speed (mm/s) below to give distances in mm, so they are presumably
# half-decay times in seconds — NOTE(review): confirm source and units.
half_decay_80Hz = 1.205
half_decay_40Hz = 1.103
half_decay_20Hz = 0.994
half_decay_10Hz = 0.783
half_decay_5Hz = 0.559
half_decay_1Hz = 0.442

In [None]:
# EDF 1g: distribution of the estimated spatial blurring (pooled speed × half-decay time)
# for the best case (10 Hz half-decay) and the worst case (80 Hz half-decay).
figure(figsize=(1.9,1.6))
hist(vcat(speed_all...) .* half_decay_10Hz, bins=0:0.25:10, label="Best case");
hist(vcat(speed_all...) .* half_decay_80Hz, bins=0:0.25:10, label="Worst case", alpha=0.5);

# Markers: (mean − SD) speed × 10 Hz half-decay and (mean + SD) speed × 80 Hz half-decay.
axvline( (mean(vcat(speed_all...)) - std(vcat(speed_all...))) .* half_decay_10Hz, color="darkblue", ymin=0, ymax=1)
axvline( (mean(vcat(speed_all...)) + std(vcat(speed_all...))) .* half_decay_80Hz, color="darkorange", ymin=0, ymax=0.7)

xlabel("Estimated fluorescence \nhalf decay (mm)", labelpad=-5);
ylabel("Total time (min)", labelpad=-10)
legend(handlelength=1.5, frameon=false, loc=(0.25,0.7))

# Counts are time bins; ticks at multiples of 120 are labelled in minutes.
yticks([0, 25, 50, 75, 100, 125].*120, [0, "", "", "", "", 125])
xticks([0,5,10], [0, "",10])

tight_layout(pad=0.5)
savefig("EDF1g_speed_blurring_space.pdf", format="pdf",  transparent=true, dpi=300)
    

In [None]:
# Spatial blurring estimates (mm): worst case = (mean + SD) speed × 80 Hz half-decay,
# best case = (mean − SD) speed × 10 Hz half-decay.
blurring_half_decay_mean_worst, blurring_half_decay_mean_best = let pooled_speed = vcat(speed_all...)
    speed_mu = mean(pooled_speed)
    speed_sigma = std(pooled_speed)
    ((speed_mu + speed_sigma) .* half_decay_80Hz, (speed_mu - speed_sigma) .* half_decay_10Hz)
end
blurring_half_decay_mean_best

## EDF 1 h

In [None]:
# Idealised 1-D place field: 2000 samples (= 20 mm, i.e. 100 samples per mm) equal to 1 on
# samples 800–1200 and 0 elsewhere.
field = [800 <= k <= 1200 ? 1.0 : 0.0 for k in 1:2000];

In [None]:
firing_decay = zeros(1000)  # not used below
# Best-case blurring distance (mm) used as the half-decay length.
half_time_mm = blurring_half_decay_mean_best # +blurring_half_decay_sd
# Exponential decay over 0:2000 samples at 100 samples per mm; it reaches 0.5 after
# half_time_mm mm.
decay = map(k -> exp(-log(2)/half_time_mm/100 * k), 0:2000);

In [None]:
# Best-case blurred field: the decay runs outward from both field edges
# (mirrored to the left of sample 800, forward to the right of sample 1200).
field2 = copy(field)
field2[1:800] = reverse(decay[1:800])
field2[1200:2000] = decay[1:801];

In [None]:
firing_decay = zeros(1000)  # not used below
# Worst-case blurring distance (mm) used as the half-decay length.
half_time_mm = blurring_half_decay_mean_worst # +blurring_half_decay_sd
decay = map(k -> exp(-log(2)/half_time_mm/100 * k), 0:2000);

# Worst-case blurred field, built the same way as field2.
field3 = copy(field)
field3[1:800] = reverse(decay[1:800])
field3[1200:2000] = decay[1:801];

In [None]:
# EDF 1h: the idealised field (black) with its best-case (dark blue, field2) and worst-case
# (orange, field3) blurred versions. The grey line at 0.8 presumably marks the default
# map_field threshold fraction (8/10) relative to the field's peak of 1.
figure(figsize=(1,1.5))
#axhline(1, c="black")
axhline(0.8, c="grey")

plot(field2, c="darkblue")
plot(field3, c="darkorange")

plot(field, color="black")
# 2000 samples span 20 mm: ticks at samples 0/1000/2000 are labelled -10/0/10 mm.
xticks(0:1000:2000, labels=-10:10:10)
xlabel("x (mm)")
ylabel("Activity (a.u.)")


tight_layout(pad=0)
savefig("EDF1h_speed_blurring_explanation.pdf", format="pdf",  transparent=true, dpi=300)

In [None]:
"""
    map_field(which_map; threshold = 8/10, bottom_activity = 0, top_percentile = 95)

Return a Bool mask of the entries of `which_map` lying above
`(top_activity - bottom_activity) * threshold + bottom_activity`, where `top_activity` is
the `top_percentile`-th percentile of the non-NaN values of `which_map`. The percentile uses
linear interpolation, the same default as `numpy.nanpercentile`. NaN entries are never part
of the field, and a map without any non-NaN value gives an all-`false` mask.
"""
function map_field(which_map; threshold = 8/10, bottom_activity = 0, top_percentile = 95)
    # `numpy` was never bound in this file (only `np` is); Statistics.quantile on the
    # non-NaN values reproduces nanpercentile's default linear interpolation.
    finite_values = filter(!isnan, vec(which_map))
    top_activity = isempty(finite_values) ? NaN : quantile(finite_values, top_percentile / 100)
    field_threshold = (top_activity - bottom_activity) * threshold + bottom_activity
    return which_map .> field_threshold
end

In [None]:
# Extra above-threshold width caused by best-case blurring: difference in above-threshold
# samples, divided by 100 samples per mm and by the 2 field edges.
(count(map_field(field2)) - count(map_field(field))) / 200


In [None]:
# Extra above-threshold width caused by worst-case blurring: difference in above-threshold
# samples, divided by 100 samples per mm and by the 2 field edges.
(count(map_field(field3)) - count(map_field(field))) / 200