In [None]:
using PyPlot, HDF5, NaNStatistics, Statistics, PyCall, Cairo, Images, ProgressMeter, HypothesisTests
using _Data

np = pyimport("numpy")
pd = pyimport("pandas")

include("Decoder_Functions.jl")

include("../project_place_cell/functions/func_map.jl")


# Global matplotlib styling for small publication panels: 7 pt Arial, thin lines, short
# ticks with tight padding, and no top/right spines.
rc_params = PyDict(pyimport("matplotlib")["rcParams"]);
for (key, value) in (
        "font.sans-serif" => ["Arial"],
        "font.size" => 7,
        "lines.linewidth" => 1,
        "lines.markersize" => 4,
        "xtick.major.size" => 2,
        "ytick.major.size" => 2,
        "axes.spines.top" => false,
        "axes.spines.right" => false,
        "xtick.major.pad" => 2,
        "ytick.major.pad" => 2,
        "axes.labelpad" => 2,
    )
    rc_params[key] = value
end

# Convert image matrices of several pixel encodings into a Cairo RGB24 image surface.
# Only the UInt32 method builds the surface; every other method repacks its pixels and
# forwards to it.
#
# `reinterpret` returns a lazy ReinterpretArray rather than a `Matrix`, so the UInt32 method
# accepts any AbstractMatrix{UInt32} and materialises it (a no-op for a plain Matrix).
# Without this, the ARGB32/RGB24 methods below raised a MethodError.
cim(img::AbstractMatrix{UInt32}) =
    CairoImageSurface(convert(Matrix{UInt32}, img), Cairo.FORMAT_RGB24; flipxy = false)
cim(img::Matrix{ARGB32}) = cim(reinterpret(UInt32, img))
cim(img::Matrix{RGB24}) = cim(reinterpret(UInt32, img))
# Grayscale bytes: replicate the channel into RGB24 (no `cim` method accepts Gray pixels).
function cim(img::Matrix{UInt8})
    g = reinterpret(N0f8, img)
    return cim(RGB24.(g, g, g))
end
# Three-channel bytes (rows x cols x [R, G, B]) -> RGB24.
cim(img::Array{UInt8,3}) = cim(RGB24.(reinterpret(N0f8, img[:,:,1]), reinterpret(N0f8, img[:,:,2]), reinterpret(N0f8, img[:,:,3])))
"""
    downsample(img, n=4)

Block-average `img` over non-overlapping `n`×`n` tiles. Only complete tiles are used, and
trailing rows/columns that do not fill a tile are dropped. The stop index `n*div(size, n)`
keeps every complete tile; the former `size-1` bound silently dropped the last tile
whenever `n` divided the size exactly.
"""
downsample(img,n=4) = +([img[i:n:n*div(size(img,1),n), j:n:n*div(size(img,2),n)] for i = 1:n, j = 1:n]...)/(n*n);

In [None]:
save_filename = "decoding_revision_new"

# Corner-cue datasets: [experiment_filename, server, experimenter].
# `server` selects the data root: "/data" when running on host "roli-<server>", otherwise
# "/nfs/data<server>". `experimenter` names the Dataset owner whose behavior.h5 is read.
datasets_corner_cue = 
[
    ["20220407_152537", 4, "jen"],
    ["20220406_111526", 9, "jen"],
    ["20220407_090156", 5, "jen"],
    ["20220417_165530", 2, "jen"],
    ["20220406_153842", 9, "jen"],
    ["20220405_171444", 4, "jen"],
    ["20220416_160516", 6, "jen"]
];

# Per-dataset number interpolated into the "neuron_spatial_info_15_<length>_..." file name,
# indexed by dataset position. NOTE(review): there are 9 entries but only 7 datasets above;
# confirm each entry still lines up with its dataset.
lengths = [90, 90, 90, 89, 90, 90, 90, 90, 90];
n_datasets = length(datasets_corner_cue)

## Example fish

In [None]:
# Load decoding results and raw activity for one example fish.
use_fish = 7
experiment_filename = datasets_corner_cue[use_fish][1]
server = datasets_corner_cue[use_fish][2]
data_root = gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)"
ds_Lorenz = Dataset(experiment_filename, "lorenz", data_root)
ds_Chuyu = Dataset(experiment_filename, "chuyu", data_root)

# Decoder output (positions in spatial bins), the place cells it used, the per-sweep validity
# mask and the neuron orderings for the activity rasters. (alt. file suffix: _fig4_brain_region)
decoding_keys = ("place_cell_included", "x_in_bins", "y_in_bins", "x_predicted_in_bins",
                 "y_predicted_in_bins", "bool_index", "activity_sorting_x", "activity_sorting_y")
place_candidates_unique, x_in_bins, y_in_bins, x_predicted, y_predicted, bool_index, sorting_x, sorting_y =
    h5open(joinpath(data_path(ds_Lorenz), "$(experiment_filename)_$(save_filename).h5"), "r") do f
        map(key -> read(f, key), decoding_keys)
    end
bool_index = Bool.(bool_index)

# dF activity, memory-mapped; rows are sweeps, columns are neurons.
A_dFF = h5open(joinpath(data_path(ds_Chuyu), "NMF_merge.h5"), "r") do f
    HDF5.readmmap(f["A_dF"])
end

n_sweeps = size(A_dFF, 1);

In [None]:
# Median decoding error over valid sweeps, scaled by 47/41 (bin -> mm factor used throughout).
let valid = Bool.(bool_index)
    nanmedian(Decoder.get_distance(x_in_bins, x_predicted, y_in_bins, y_predicted)[valid]) * 47/41
end

In [None]:
# Per-sweep decoding error of the valid sweeps, scaled by 47/41.
errors = let distance = Decoder.get_distance(x_in_bins, x_predicted, y_in_bins, y_predicted)
    distance[Bool.(bool_index)] * 47/41
end;

In [None]:
# True (left) and decoded (right) positions of valid sweeps, coloured by decoding error.
figure(figsize=(8,4))
for (panel, (xs, ys)) in enumerate(((x_in_bins, y_in_bins), (x_predicted, y_predicted)))
    subplot(1, 2, panel)
    scatter(xs[bool_index], ys[bool_index], c=errors)
end

In [None]:
# Which sweeps are valid for decoding (1) vs excluded (0).
plot(map(Bool, bool_index))

In [None]:
# True, then decoded, x over the valid sweeps only.
let valid = Bool.(bool_index)
    figure(figsize=(20,5))
    for trace in (x_in_bins, x_predicted)
        plot(trace[valid])
    end
end

In [None]:
# dF trace of neuron (column) 10 across all sweeps.
let neuron = 10
    plot(A_dFF[:, neuron])
end

In [None]:
# Check that decoding in 1-min steps introduces no periodicity: plot the sweep-to-sweep
# change of the decoded x and mark every 120th sweep (presumably the step boundaries,
# i.e. 120 sweeps per minute).
decoded_step_change = x_predicted[1:n_sweeps-1] .- x_predicted[2:n_sweeps]
# `collect` is required: `[1:120:n_sweeps]` is a 1-element vector holding the range, whose
# length does not match the y values.
step_boundaries = collect(1:120:n_sweeps)

plot(decoded_step_change)
scatter(step_boundaries, fill(5, length(step_boundaries)), c="orange", s=2)
xlim(3600,5600)
xlabel("time")
ylabel("change in decoded x")

In [None]:
# Time window of the example fish used for the panels below.
# Alternatives per fish: fish 1: 8000 to 10000, fish 5: 3600:5300.
xmin = 7700
xmax = 9500
ymin = 0
ymax = 50

# True (default colour) vs decoded (black) trajectory: x first, then y.
for (true_pos, decoded_pos) in ((x_in_bins, x_predicted), (y_in_bins, y_predicted))
    figure(figsize=(20,3))
    plot(true_pos[xmin:xmax], linewidth=0.5)
    plot(decoded_pos[xmin:xmax], color="k", linewidth=0.5)
    ylim(ymin,ymax)
end


In [None]:

# Short excerpt for the activity panels: 361 sweeps starting 1450 sweeps after xmin.
xmin_ = xmin + 1450
xmax_ = xmin_ + 360

let excerpt = xmin_:xmax_
    for (trace, extra) in ((x_in_bins, NamedTuple()), (y_in_bins, (markersize = 0.1,)))
        figure(figsize=(2,1.2))
        plot(trace[excerpt]; color="black", extra...)
    end
end


In [None]:
# x-y trajectory of the excerpt.
let excerpt = xmin_:xmax_
    figure(figsize=(2,1.2))
    plot(x_in_bins[excerpt], y_in_bins[excerpt], color="black", markersize = 0.1)
end


In [None]:
shift=4
bins=4

# Place-cell activity, time-shifted and time-binned with the Decoder helpers.
A_dFF_place_cells = A_dFF[:, place_candidates_unique]
A_dFF_place_cells_shifted, bool_index = Decoder.shift_time(shift, A_dFF_place_cells, [bool_index])

A_dFF_place_cells_binned = Decoder.bin_activity(bins, A_dFF_place_cells_shifted)

# z-scored dF: each neuron is normalised by its mean/std over the excerpt xmin_:xmax_ only
# (not the whole recording). Each column is sliced once instead of three times, and the
# output is sized directly rather than from a full `[:, :]` copy.
excerpt_window = xmin_:xmax_
A_dFF_place_cells_binned_z = zeros(Float32, size(A_dFF_place_cells_binned))
for i in axes(A_dFF_place_cells_binned, 2)
    trace = A_dFF_place_cells_binned[:, i]
    A_dFF_place_cells_binned_z[:, i] .= (trace .- nanmean(trace[excerpt_window])) ./ nanstd(trace[excerpt_window])
end
# Keep the excerpt, subsampled every `bins` sweeps.
A_dFF_place_cells_binned_z = A_dFF_place_cells_binned_z[xmin_:xmax_,:][1:bins:end,:]

# NOTE(review): these limits are in decoder-bin units (from x_in_bins) but are later reused
# as y-limits of the mm-scaled true/decoded trace panel; confirm the units are intended.
ymin = nanminimum(x_in_bins)
ymax = nanmaximum(x_in_bins)

In [None]:
"""
    bin_neuron_columns(A, bins)

Average groups of `bins` consecutive columns (neurons) of `A` into one column each.
Column `g` of the Float32 result is the NaN-ignoring mean of columns
`(g-1)*bins+1 : min(g*bins, size(A, 2))`. Every input column contributes to exactly one
output column, and a final partial group is kept.

This replaces a loop with several defects: it used overlapping windows `j:j+bins` starting
at column `bins`, skipped the first `bins-1` and the last neurons, and left the last output
column NaN. It also threw `InexactError` unless the neuron count was a multiple of `bins`.
"""
function bin_neuron_columns(A::AbstractMatrix, bins::Integer)
    bins > 0 || throw(ArgumentError("bins must be positive, got $bins"))
    n_cols = size(A, 2)
    n_groups = cld(n_cols, bins)
    out = fill(NaN32, size(A, 1), n_groups)
    for g in 1:n_groups
        cols = (g - 1) * bins + 1:min(g * bins, n_cols)
        out[:, g] .= nanmean(A[:, cols], dim=2)
    end
    return out
end

# Rasters with neurons sorted by their x (resp. y) ordering, averaged in groups of 10.
bins=10
A_dFF_figure_x = A_dFF_place_cells_binned_z[:, sorting_x]
A_dFF_figure_binned_x = bin_neuron_columns(A_dFF_figure_x, bins)

A_dFF_figure_y = A_dFF_place_cells_binned_z[:, sorting_y]
A_dFF_figure_binned_y = bin_neuron_columns(A_dFF_figure_y, bins)

In [None]:
# Map z-scored activity through the 256-level "hot" colormap with a fixed [-1, 3] range,
# giving one RGBA tuple per raster pixel.
# NOTE(review): `cm.get_cmap` is deprecated/removed in recent matplotlib releases.
hot = plt.cm.get_cmap("hot", 256)
norm = matplotlib.colors.Normalize(vmin=-1.0, vmax=3.0)
to_rgba(v) = hot(norm(v))

A_dFF_figure_final_x = to_rgba.(A_dFF_figure_binned_x);
A_dFF_figure_final_y = to_rgba.(A_dFF_figure_binned_y);

In [None]:
# Preview of the y-sorted raster. `permutedims` (not `'`) because the elements are RGBA
# tuples, and the recursive `adjoint` has no method for tuples.
imshow(permutedims(A_dFF_figure_final_y), origin="lower", interpolation="nearest")

In [None]:
# Preview of the x-sorted raster. `permutedims` (not `'`) because the elements are RGBA
# tuples, and the recursive `adjoint` has no method for tuples.
imshow(permutedims(A_dFF_figure_final_x), origin="lower")

In [None]:
"""
    save_raster_png(path, rgba_tuples)

Convert a matrix of RGBA tuples (colormap output) to an 8-bit RGBA image and save it to
`path`. Columns are reversed and the result transposed, so the last column ends up in the
top row, matching the `origin="lower"` previews. Returns the un-flipped RGBA matrix.
"""
function save_raster_png(path::AbstractString, rgba_tuples::AbstractMatrix)
    png = map(c -> RGBA{N0f8}(c...), rgba_tuples)
    # `permutedims` rather than `'`: adjoint is recursive and not meant for colour elements.
    Images.save(path, permutedims(png[:, end:-1:1]))
    return png
end

A_dFF_png_x = save_raster_png("Figure_panels/panel2_x_background.png", A_dFF_figure_final_x)
A_dFF_png_y = save_raster_png("Figure_panels/panel2_y_background.png", A_dFF_figure_final_y)

In [None]:
# Convert bin coordinates to mm: subtract the per-axis offset (8 for x, 15 for y) and scale
# by 47/41. NOTE: x_predicted / y_predicted are rebound to mm, so re-running this cell
# converts them twice.
bins_to_mm(v, offset) = (v .- offset) .* 47 ./ 41

x_in_mm = bins_to_mm(x_in_bins, 8);
y_in_mm = bins_to_mm(y_in_bins, 15);

x_predicted = bins_to_mm(x_predicted, 8);
y_predicted = bins_to_mm(y_predicted, 15);

In [None]:

# Panel 2 overlays. The true x (first figure) and true y (second figure) of the excerpt
# xmin_:xmax_ are drawn in white on a twin (right, mm) axis. The left "Neurons" axis spans
# 0-360 samples x 0-1000. The activity raster itself is not drawn here (the imshow in the
# second figure is commented out).
# NOTE(review): presumably these transparent PDFs are composited over the
# panel2_*_background.png rasters saved earlier; confirm.
fig, ax1 = subplots(figsize=(2, 0.85))
ax2 = ax1.twinx()

# NOTE(review): the +1.5 mm (x) / +1 mm (y) display offsets are not explained here.
ax2.plot(1.5 .+ x_in_mm[xmin_:xmax_], color="white")

ax1.set_xlabel("")#"Time")
ax1.set_xticks([])
ax1.set_xlim(0, 360)
ax1.set_ylim(0, 1000)
# NOTE(review): the top tick is placed at 950 but labelled "1000".
ax1.set_yticks([0, 500, 950], labels=["0", "", "1000"])
ax1.set_ylabel("Neurons", labelpad=-10)

ax2.set_yticks([0, 25, 50], labels=["0", "", "50"])
ax2.set_ylabel("True \$\\it{x}\$ (mm)")
ax2.yaxis.set_label_coords(1.1, 0.5)
ax2.set_ylim(0, 50)
# Hide y tick marks on both axes; re-enable the right spine (turned off globally in rcParams).
ax1.tick_params("y", length=0)
ax2.tick_params("y", length=0)
ax2.spines["right"].set_visible(true)

tight_layout(pad=0)
savefig("Figure_panels/panel2_x_activity.pdf", format="pdf",  transparent=true, dpi=300)

# Same layout for y: shorter panel with a time axis (sample 360 labelled "3" min).
fig, ax1 = subplots(figsize=(2,0.65))
ax2 = ax1.twinx()
#m = ax1.imshow(A_dFF_figure_y[:,1:1:end]', origin="lower", aspect="auto")
ax2.plot(1 .+ y_in_mm[xmin_:xmax_], color="white")
ax1.set_yticks([0, 500, 1000], labels=["0", "", "1000"])
ax1.set_xlabel("Time (min)", labelpad=-5)
ax1.set_xlim(0, 360)
ax1.set_ylim(0, 1000)
ax1.set_xticks([0, 120, 240, 360], labels=["0", "", "", "3"])
ax1.set_ylabel("Neurons", loc="bottom", labelpad=-10)
ax2.set_yticks([0, 25], labels=["0", "25"])
ax2.set_ylabel("True \$\\it{y}\$ (mm)")
ax2.yaxis.set_label_coords(1.1, 0.45)
ax2.set_ylim(0, 25)

ax1.tick_params("y", length=0)
ax2.tick_params("y", length=0)
ax2.spines["right"].set_visible(true)

tight_layout(pad=0)
savefig("Figure_panels/panel2_y_activity.pdf", format="pdf",  transparent=true, dpi=300)

In [None]:
# Panel 2: true (black) vs decoded (red) position over sweeps xmin:xmax, in mm with a +1 mm
# display offset. x is in the top axes and y in the bottom axes (shared time axis).

fig, ax = subplots(2,1, figsize=(3,1.7), sharex=true)

ax[1].plot(1 .+ x_in_mm[xmin:xmax], color="k", linewidth=0.5)
ax[1].plot(1 .+ x_predicted[xmin:xmax], color="red", linewidth=0.5, alpha=0.9)
# NOTE(review): if cells run in order, ymin/ymax were last set from the extrema of x_in_bins
# (bin units), while these traces are in mm; confirm the limits are intended.
ax[1].set_ylim(ymin,ymax)
ax[1].set_xlim(0, xmax-xmin)
ax[1].set_yticks([0, 25, 50], labels=["0", "", "50"])
ax[1].set_ylabel("\$\\it{x}\$ (mm)", rotation=90, labelpad=-5)

ax[2].plot(1 .+ y_in_mm[xmin:xmax], color="k", linewidth=0.5, label="True")
ax[2].plot(1 .+ y_predicted[xmin:xmax], color="red", linewidth=0.5, alpha=0.9, label="Decoded")
ax[2].set_ylim(ymin,ymax)
ax[2].set_yticks([0, 25, 50], labels=["0", "", "50"])
ax[2].set_ylabel("\$\\it{y}\$ (mm)", rotation=90, labelpad=-5)
# Ticks every 600 sweeps. The 4 hard-coded labels assume xmax-xmin == 1800, with the last
# tick labelled 15 min (120 sweeps per minute, cf. the commented-out formula).
ax[2].set_xticks(0:600:xmax-xmin)
ax[2].set_xticklabels(["0", "", "", "15"])#round.(Int32, (1:600:(xmax-xmin))./120))
ax[2].set_xlabel("Time (min)", labelpad=-5)

fig.legend(loc=(0.15, 0.35))

tight_layout(pad=0.5)
savefig("Figure_panels/panel2.pdf", format="pdf",  transparent=true, dpi=300)

# Median decoding error (mm) over the plotted window.
print(nanmedian(Decoder.get_distance(x_in_mm[xmin:xmax], x_predicted[xmin:xmax], y_in_mm[xmin:xmax], y_predicted[xmin:xmax])))

# fig = figure(figsize=(2,1.2))
# plot(x_in_mm[xmin:xmax], y_in_mm[xmin:xmax], "k.", markersize = 0.1)
# for i = xmin:xmax-1

#     plot(x_in_mm[i:i+1], y_in_mm[i:i+1], color = [0, 0, 0], alpha = 0.5)
# end
# xlim(150,3150)
# ylim(500,2250)
# axis("equal")
# axis("off")
# tight_layout(pad=0)
# savefig("Figure_panels/panel1.pdf", format="pdf",  transparent=true, dpi=300)


## error by position

In [None]:
# Python modules for image registration (EuclideanTransform / warp) and median filtering.
# `pyimport` replaces the deprecated `@pyimport` macro; the returned module objects support
# the same `module.function(...)` call syntax used below.
skimage_transform = pyimport("skimage.transform")
ndimage = pyimport("scipy.ndimage")

In [None]:
center_loc_all = []
img_bg_all = []
chamber_roi_all = []

# For each dataset collect: the last background frame from behavior.h5, the binarised
# chamber ROI, and the mean of the points of the ROI's first contour.
for i in 1:length(datasets_corner_cue)
    experiment_filename, server, experimenter = datasets_corner_cue[i]
    root = gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)"

    ds_Chuyu = Dataset(experiment_filename, "chuyu", root)
    ds = Dataset(experiment_filename, experimenter, root)

    background = h5open(ds, "behavior.h5"; raw = true) do f
        read(f, "img_bg")
    end

    # orientation-corrected chamber geometry for this dataset
    geometry_file = joinpath(data_path(ds_Chuyu), "for_place_calculation_chamber_geometry_$(experiment_filename)_n60.h5")
    roi = h5open(geometry_file) do f
        read(f, "chamber_roi")
    end
    roi[roi .!= 0] .= 1  # binarise

    contour_points = find_contours(roi)[1]
    centroid = nanmean(hcat(getindex.(contour_points, 1), getindex.(contour_points, 2)), dims=1)

    push!(center_loc_all, centroid)
    push!(img_bg_all, background[:, :, end])
    push!(chamber_roi_all, roi)
end







In [None]:
# Manual registration of each dataset onto a common chamber frame: translate by
# shift_all[name] (pixels; first number: - is up, + is down), then flip if the dataset is
# listed in flip_h / flip_v.
shift_all = Dict{Any,Any}(
    "20220407_152537" => [88, 20],
    "20220405_171444" => [65,-205],
    "20220406_111526" => [0,0],
    "20220417_110011" => [150, 100],
    "20220407_090156" => [-15, -5],
    "20220417_165530" => [-25, -290], #should be rotated a bit...
    "20220813_103255" => [15, -250], #should be rotated a bit...
    "20220406_153842" => [145, -215],
    "20220818_123314" => [-20,-260],
    "20220819_094333" => [-80,-265],
    "20220818_163902" => [-80,-270],
    "20220416_160516" => [-110,-270],
)

flip_h = ["20220407_152537", "20220406_111526", "20220407_090156", "20220417_165530", "20220813_103255", "20220818_123314", "20220819_094333", "20220818_163902", "20220416_160516"]
flip_v = ["20220405_171444", "20220417_165530", "20220813_103255", "20220406_153842", "20220818_123314", "20220819_094333", "20220818_163902", "20220416_160516"]

# Sum of the registered ROIs (per-pixel overlap), plus an overlay of all registered
# backgrounds with dataset 4 underneath in "hot".
chamber_roi_sum = zeros(Float64, size(chamber_roi_all[1]))

figure(figsize=(20,10))
for (i, filename) in enumerate(first.(datasets_corner_cue))
    shift = collect(shift_all[filename][1:2])

    tform = skimage_transform.EuclideanTransform(translation=shift)
    bg = skimage_transform.warp(img_bg_all[i], tform)
    roi = skimage_transform.warp(chamber_roi_all[i], tform)

    if filename in flip_h
        bg, roi = bg[:, end:-1:1], roi[:, end:-1:1]
    end
    if filename in flip_v
        bg, roi = bg[end:-1:1, :], roi[end:-1:1, :]
    end

    chamber_roi_sum .+= roi

    imshow(bg', origin="lower", alpha=0.2)
    i == 4 && imshow(bg', origin="lower", cmap="hot", alpha=1, zorder=-1)
end

In [None]:
# Summed (interpolated) overlap of the registered chamber ROIs.
imshow(chamber_roi_sum', origin="lower")
colorbar()

In [None]:
# Binarise the ROI overlap, then smooth the mask with a 20-pixel median filter.
# NOTE(review): with a 4e-19 threshold any non-zero overlap becomes 1, i.e. effectively the
# union of the registered ROIs rather than their intersection; confirm that is intended.
roi_threshold = 4e-19
chamber_roi_consistent = copy(chamber_roi_sum)
is_below = chamber_roi_consistent .< roi_threshold
chamber_roi_consistent[is_below] .= 0
chamber_roi_consistent[chamber_roi_consistent .>= roi_threshold] .= 1

blurred_chamber_roi_consistent = ndimage.median_filter(chamber_roi_consistent, size=20)
figure(figsize=(20,10))
imshow(blurred_chamber_roi_consistent', origin="lower")

In [None]:
# Size of the first background frame; used to mirror coordinates of flipped datasets.
max_x, max_y = size(img_bg_all[1])

# Overlay every dataset's fish trajectory after the same translation and flips as the images.
figure(figsize=(20,10))
for i in 1:n_datasets
    experiment_filename, server = datasets_corner_cue[i][1], datasets_corner_cue[i][2]
    ds_Chuyu = Dataset(experiment_filename, "chuyu", gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")

    xs, ys = h5open(joinpath(data_path(ds_Chuyu), "for_place_calculation_chamber_geometry_$(experiment_filename)_n60.h5")) do f
        read(f, "x_fish_sweep_mean"), read(f, "y_fish_sweep_mean")
    end

    # The first shift entry is applied to y, the second to x.
    dy, dx = shift_all[experiment_filename]
    xs = xs .- dx
    ys = ys .- dy

    if experiment_filename in flip_v
        xs = max_x .- xs
    end
    if experiment_filename in flip_h
        ys = max_y .- ys
    end

    scatter(xs, ys, alpha=0.5, s=2)
end


In [None]:
n_pos=40

all_x = Vector{Float32}()
all_y = Vector{Float32}()
all_errors = Vector{Float32}()

count_maps_all_map = []
count_maps_all = []
error_maps_all = []
repr_maps_all = []

# Per dataset, on an n_pos-wide spatial grid in the common (registered) chamber frame:
#   error_map       - median decoding error (mm) of the kept sweeps in each bin
#   count_map_quant - number of kept sweeps in each bin
#   count_map       - MAP.calculate_mask_map_digital output (gaussian-filtered)
#   repr_map        - NaN-ignoring sum over place cells of map_field(activity / mask map)
# "Kept" sweeps are valid for decoding (bool_index) AND inside the consensus chamber ROI.
figure(figsize=(4,2))
for i in 1:n_datasets
    experiment_filename = datasets_corner_cue[i][1]
    server = datasets_corner_cue[i][2]

    ds_Chuyu = Dataset(experiment_filename, "chuyu", gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")
    ds_Lorenz = Dataset(experiment_filename, "lorenz", gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")

    file = h5open(joinpath(data_path(ds_Lorenz), "$(experiment_filename)_$(save_filename).h5"), "r")
        place_candidates_unique = read(file, "place_cell_included")
        x_in_bins = read(file, "x_in_bins")
        y_in_bins = read(file, "y_in_bins")
        x_predicted = read(file, "x_predicted_in_bins")
        y_predicted = read(file, "y_predicted_in_bins")
        bool_index = BitVector(read(file, "bool_index"))
    close(file)

    position_file = h5open(joinpath(data_path(ds_Chuyu), "for_place_calculation_chamber_geometry_$(experiment_filename)_n60.h5"))
        x_fish_sweep_mean = read(position_file,"x_fish_sweep_mean")
        y_fish_sweep_mean = read(position_file,"y_fish_sweep_mean")
    close(position_file)

    # Register the trajectory onto the common frame (same shifts/flips as the images).
    x_fish_shifted = x_fish_sweep_mean .- shift_all[experiment_filename][2]
    y_fish_shifted = y_fish_sweep_mean .- shift_all[experiment_filename][1]
    if experiment_filename in flip_v
        x_fish_shifted = max_x .- x_fish_shifted
    end
    if experiment_filename in flip_h
        y_fish_shifted = max_y .- y_fish_shifted
    end

    inside_chamber = fill(false, length(y_fish_shifted))
    idxs = [k for k in 1:length(y_fish_shifted) if 1 == blurred_chamber_roi_consistent[round(Int32, x_fish_shifted[k]), round(Int32, y_fish_shifted[k])]]
    inside_chamber[idxs] .= true
    kept = bool_index .& inside_chamber

    errors = Decoder.get_distance(x_in_bins, x_predicted, y_in_bins, y_predicted) .* 47/41
    errors_kept = errors[kept]
    x_kept = x_fish_shifted[kept]
    y_kept = y_fish_shifted[kept]

    all_errors = vcat(all_errors, errors_kept)
    all_x = vcat(all_x, x_kept)
    all_y = vcat(all_y, y_kept)

    # Square bins of width `interval`, with n_pos bins along the longer extent of the kept positions.
    dig_min_x = floor(Int64, minimum(x_kept));
    dig_max_x = floor(Int64, maximum(x_kept));

    dig_min_y = floor(Int64, minimum(y_kept));
    dig_max_y = floor(Int64, maximum(y_kept));

    interval = maximum([(dig_max_y-dig_min_y+2)/n_pos,(dig_max_x-dig_min_x+2)/n_pos])

    x_bins = collect(dig_min_x-1:interval:dig_min_x+interval*(n_pos)+1);
    y_bins = collect(dig_min_y-1:interval:dig_min_y+interval*(n_pos)+1);

    x_digital = numpy.digitize(x_kept, x_bins)
    y_digital = numpy.digitize(y_kept, y_bins);
    loc_digital = (y_digital.-1).*n_pos.+x_digital;

    error_map = fill(NaN, maximum(x_digital), maximum(y_digital))
    count_map_quant = fill(NaN, maximum(x_digital), maximum(y_digital))
    for ix in unique(x_digital), iy in unique(y_digital)
        in_bin = (x_digital .== ix) .& (y_digital .== iy)
        error_map[ix, iy] = nanmedian(errors_kept[in_bin])
        count_map_quant[ix, iy] = nansum(in_bin)
    end

    temp = MAP.calculate_mask_map_digital(loc_digital, n_pos; use_gaussian_filter=true);
    count_map = temp[1:maximum(x_digital), 1:maximum(y_digital)]

    # Place-cell spatial maps (memory-mapped).
    file = h5open(joinpath(data_path(ds_Chuyu), "neuron_spatial_info_15_$(lengths[i])_chamber_geometry_$(experiment_filename)_sigma1_n60_A_dF.h5"))
        place_cell_index = read(file, "place_cell_index")
        place_map_all_1 = HDF5.readmmap(file["place_map_all"])
        mask_map_all_1 = HDF5.readmmap(file["mask_map_all"])
        activity_num_map_all_1 = HDF5.readmmap(file["activity_num_map_all"])
    close(file)

    # calculate representation density
    for_place_calculation_file = h5open(joinpath(data_path(ds_Chuyu), "for_place_calculation_chamber_geometry_$(experiment_filename)_n60.h5"))
        mask_valid = read(for_place_calculation_file,"mask_valid")
    close(for_place_calculation_file)

    valid_bins = findall(mask_valid)
    mask_valid_x_min = minimum(getindex.(valid_bins, 1))
    mask_valid_y_min = minimum(getindex.(valid_bins, 2))

    # one coarse field map per place cell
    coarse_maps = fill(NaN32, size(place_map_all_1, 1), size(place_map_all_1, 2), length(place_cell_index))
    @showprogress for (k, neuron_index) in enumerate(place_cell_index)
        test_map = activity_num_map_all_1[:,:,neuron_index]./mask_map_all_1[:,:,neuron_index]
        coarse_maps[:,:,k] = map_field(test_map) # get >80% from 95percentile
    end

    # Per-dataset one-bin offsets of the valid-region origin (presumably hand-tuned alignment).
    x_add=0
    y_add=0
    if experiment_filename in ["20220407_152537", "20220406_153842"]
        y_add=1
    end
    if experiment_filename in ["20220407_090156", "20220406_153842", "20220405_171444"]
        x_add=1
    end

    mask_valid_x_min = mask_valid_x_min + x_add
    mask_valid_y_min = mask_valid_y_min + y_add

    # 40 x 20 window (x, y) of the field-count map, starting at the valid-region origin.
    sum_map_coarse_binary = numpy.nansum(coarse_maps, axis= 2);
    repr_map = sum_map_coarse_binary[mask_valid_x_min:mask_valid_x_min+39, mask_valid_y_min:mask_valid_y_min+19]
    # NOTE(review): here flip_v reverses dim 2 (y) and flip_h reverses dim 1 (x), the opposite
    # of the trajectory flips above (flip_v -> x, flip_h -> y). Confirm the place maps are
    # stored in a frame where this swap is correct.
    if experiment_filename in flip_v
        repr_map = repr_map[:, end:-1:1]
    end
    if experiment_filename in flip_h
        repr_map = repr_map[end:-1:1, :]
    end

    push!(repr_maps_all, repr_map)
    push!(count_maps_all_map, count_map)
    push!(count_maps_all, count_map_quant)
    push!(error_maps_all, error_map)

    figure()
    imshow(repr_map', origin="lower")
    figure()
    imshow(count_map', origin="lower")
end

repr_maps_all = cat(repr_maps_all..., dims=3);
count_maps_all_map = cat(count_maps_all_map..., dims=3);
count_maps_all = cat(count_maps_all..., dims=3);
error_maps_all = cat(error_maps_all..., dims=3);

In [None]:
# Common spatial grid spanning all datasets' kept positions (same construction as the
# per-dataset grids); used for the map extents below.
dig_min_x = floor(Int64, minimum(all_x));
dig_max_x = floor(Int64, maximum(all_x));

dig_min_y = floor(Int64, minimum(all_y));
dig_max_y = floor(Int64, maximum(all_y));

interval = maximum([(dig_max_y-dig_min_y+2)/n_pos,(dig_max_x-dig_min_x+2)/n_pos])

# Fish-median maps (NaN-ignoring median across the dataset axis).
repr_maps = nanmedian(repr_maps_all, dim=3);

error_maps = nanmedian(error_maps_all, dim=3);
count_maps_show = nanmedian(count_maps_all_map, dim=3);
# Downstream "count" maps are the smoothed occupancy maps. The median of the raw counts was
# previously computed here and immediately overwritten, so that dead computation is dropped.
count_maps = count_maps_show;

In [None]:
# Mask: pixels inside the consensus chamber (non-zero) become NaN and are not drawn; the
# remaining zeros cover everything outside the chamber.
chamber_roi_mask = Float32.(copy(blurred_chamber_roi_consistent))
chamber_roi_mask[chamber_roi_mask.!=0].= NaN32

# Upper end of the error colour scale (mm). The colorbar ticks are derived from it so the top
# tick is always inside the displayed range (a tick at 15 with vmax=12 was never drawn).
error_vmax = 12

fig = figure(figsize=(1.8,0.9))
m=imshow(error_maps', origin="lower", extent = [dig_min_x-(interval/2), dig_max_x+(interval/2), dig_min_y-(interval/2), dig_max_y+(interval/2)], cmap="hot", vmin=0, vmax=error_vmax)

# Chamber outline: every 50th contour point, with the last one prepended to close the loop.
countour = find_contours(blurred_chamber_roi_consistent)[1];
countour_array = hcat(getindex.(countour, 1), getindex.(countour,2));
plot(vcat(countour_array[1:50:end,1][end], countour_array[1:50:end,1]),vcat(countour_array[1:50:end,2][end], countour_array[1:50:end,2]),c="k", alpha=0.3)

imshow(chamber_roi_mask',origin="lower",  cmap="binary")

xlim(dig_min_x-interval, dig_max_x+interval)
ylim(dig_min_y-interval, dig_max_y+interval)
axis("off")
tight_layout(pad=0)
plt.savefig("Figure_panels/error_by_position.pdf", format="pdf", transparent = true, dpi=300)
# Colorbar drawn in its own figure (not saved here).
figure(figsize=(2,1))
cb = colorbar(m, pad=0.15,fraction=0.07, orientation="horizontal", location="top")
cb.set_ticks([0, error_vmax])
cb.set_label("Decoder error (mm)", labelpad=-5)

In [None]:
# Mask: pixels inside the consensus chamber (non-zero) become NaN and are not drawn; the
# remaining zeros, drawn with the "binary" colormap, cover the area outside the chamber.
chamber_roi_mask = Float32.(copy(blurred_chamber_roi_consistent))
chamber_roi_mask[chamber_roi_mask.!=0].= NaN32

# NOTE(review): `cm.get_cmap` is deprecated/removed in recent matplotlib; `Blues` is also
# used by later cells.
Blues = plt.cm.get_cmap("Blues", 100)

# Fish-median smoothed occupancy divided by 25 (NOTE(review): normalisation constant not
# explained here), colour scale fixed to [0, 1].
fig = figure(figsize=(1.4,0.6))
m=imshow((count_maps_show./25)', origin="lower", extent = [dig_min_x-(interval/2), dig_max_x+(interval/2)+interval, dig_min_y-(interval/2), dig_max_y+(interval/2)], cmap=Blues, vmin=0, vmax=1)

# Chamber outline: every 50th contour point, with the last one prepended to close the loop.
countour = find_contours(blurred_chamber_roi_consistent)[1];
countour_array = hcat(getindex.(countour, 1), getindex.(countour,2));
plot(vcat(countour_array[1:50:end,1][end], countour_array[1:50:end,1]),vcat(countour_array[1:50:end,2][end], countour_array[1:50:end,2]),c="k", alpha=0.3)

imshow(chamber_roi_mask',origin="lower",  cmap="binary")

xlim(dig_min_x-interval, dig_max_x+interval)
ylim(dig_min_y-interval, dig_max_y+interval)
axis("off")
tight_layout(pad=0)
plt.savefig("Figure_panels/occupancy.pdf", format="pdf", transparent = true, dpi=300)
# Colorbar drawn in its own figure (not saved here).
figure(figsize=(2,1))
cb = colorbar(m, pad=0.15,fraction=0.07, orientation="horizontal", location="top")
#cb.set_ticks([0, 1, 2], labels=["1", "10", "100"])
cb.set_label("Occupancy", labelpad=0)

In [None]:
n_pos=40

# `all_errors` and `count_maps_all` are deliberately NOT reset here. This loop never refills
# them, and resetting them left `count_maps_all` as an empty vector, which broke the later
# cells that index `count_maps_all[:, :, i]`.
all_x = Vector{Float32}()
all_y = Vector{Float32}()

count_maps_all_map = []
enter_maps_all = []
exit_maps_all = []

# Per dataset, on the same grid as the error maps: the number of kept sweeps in each bin
# (count_map_quant) and how often the fish enters / leaves each bin between consecutive kept
# sweeps (enter_map / exit_map).
# NOTE(review): consecutive kept sweeps need not be adjacent in time, because invalid and
# out-of-chamber sweeps are dropped first.
figure(figsize=(4,2))
for i in 1:n_datasets
    experiment_filename = datasets_corner_cue[i][1]
    server = datasets_corner_cue[i][2]

    ds_Chuyu = Dataset(experiment_filename, "chuyu", gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")
    ds_Lorenz = Dataset(experiment_filename, "lorenz", gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")

    file = h5open(joinpath(data_path(ds_Lorenz), "$(experiment_filename)_$(save_filename).h5"), "r")
        bool_index = BitVector(read(file, "bool_index"))
    close(file)

    position_file = h5open(joinpath(data_path(ds_Chuyu), "for_place_calculation_chamber_geometry_$(experiment_filename)_n60.h5"))
        x_fish_sweep_mean = read(position_file,"x_fish_sweep_mean")
        y_fish_sweep_mean = read(position_file,"y_fish_sweep_mean")
    close(position_file)

    # Register the trajectory onto the common chamber frame.
    x_fish_shifted = x_fish_sweep_mean .- shift_all[experiment_filename][2]
    y_fish_shifted = y_fish_sweep_mean .- shift_all[experiment_filename][1]
    if experiment_filename in flip_v
        x_fish_shifted = max_x .- x_fish_shifted
    end
    if experiment_filename in flip_h
        y_fish_shifted = max_y .- y_fish_shifted
    end

    # Keep sweeps that are valid for decoding and inside the consensus chamber ROI.
    inside_chamber = fill(false, length(y_fish_shifted))
    idxs = [k for k in 1:length(y_fish_shifted) if 1 == blurred_chamber_roi_consistent[round(Int32, x_fish_shifted[k]), round(Int32, y_fish_shifted[k])]]
    inside_chamber[idxs] .= true
    kept = bool_index .& inside_chamber
    x_kept = x_fish_shifted[kept]
    y_kept = y_fish_shifted[kept]

    all_x = vcat(all_x, x_kept)
    all_y = vcat(all_y, y_kept)

    dig_min_x = floor(Int64, minimum(x_kept));
    dig_max_x = floor(Int64, maximum(x_kept));

    dig_min_y = floor(Int64, minimum(y_kept));
    dig_max_y = floor(Int64, maximum(y_kept));

    interval = maximum([(dig_max_y-dig_min_y+2)/n_pos,(dig_max_x-dig_min_x+2)/n_pos])

    x_bins = collect(dig_min_x-1:interval:dig_min_x+interval*(n_pos)+1);
    y_bins = collect(dig_min_y-1:interval:dig_min_y+interval*(n_pos)+1);

    x_digital = numpy.digitize(x_kept, x_bins)
    y_digital = numpy.digitize(y_kept, y_bins);

    enter_map = fill(NaN, maximum(x_digital), maximum(y_digital))
    exit_map = fill(NaN, maximum(x_digital), maximum(y_digital))
    count_map_quant = fill(NaN, maximum(x_digital), maximum(y_digital))
    for ix in unique(x_digital), iy in unique(y_digital)
        in_bin = (x_digital .== ix) .& (y_digital .== iy)
        was_in = in_bin[1:end-1]
        is_in = in_bin[2:end]
        # Entering means "was in any other bin" -> "is in (ix, iy)", i.e. !(x == ix && y == iy).
        # The former test `(x != ix) & (y != iy)` counted only diagonal moves and missed
        # moves along a single axis.
        enter_map[ix, iy] = count(.!was_in .& is_in)
        exit_map[ix, iy] = count(was_in .& .!is_in)
        count_map_quant[ix, iy] = nansum(in_bin)
    end

    push!(count_maps_all_map, count_map_quant)
    push!(enter_maps_all, enter_map)
    push!(exit_maps_all, exit_map)

    figure()
    imshow(count_map_quant', origin="lower")
end

count_maps_all_map = cat(count_maps_all_map..., dims=3);
enter_maps_all = cat(enter_maps_all..., dims=3);
exit_maps_all = cat(exit_maps_all..., dims=3);

In [None]:
# Mean number of bin entries over all (bin, dataset) cells. Cells left at their NaN initial
# value are ignored; the hard-coded 20*40*7 reshape is no longer needed.
nanmean(enter_maps_all)

In [None]:
# Standard deviation of bin exits over all (bin, dataset) cells, ignoring NaN cells.
nanstd(exit_maps_all)

In [None]:
# Distribution of per-bin occupancy counts across all datasets. NaN (never-filled) bins are
# dropped first: a NaN makes numpy's automatic histogram range fail and poisons mean/std.
occupancy_counts = filter(!isnan, vec(count_maps_all))
hist(occupancy_counts, bins=100)

println(mean(occupancy_counts))
println(std(occupancy_counts))

In [None]:
# Occupancy counts of the first dataset.
let first_dataset = count_maps_all_map[:, :, 1]
    imshow(first_dataset)
end

In [None]:
# Mask: pixels inside the consensus chamber (non-zero) become NaN and are not drawn; the
# remaining zeros, drawn with the "binary" colormap, cover the area outside the chamber.
chamber_roi_mask = Float32.(copy(blurred_chamber_roi_consistent))
chamber_roi_mask[chamber_roi_mask.!=0].= NaN32


# Fish-median place-field representation map, without its last y-row; colour scale 0-250.
fig = figure(figsize=(1.8,0.9))
m=imshow(repr_maps[:,1:end-1]', origin="lower", extent = [dig_min_x-(interval/2), dig_max_x+(interval/2)+interval, dig_min_y-(interval/2), dig_max_y+(interval/2)], cmap="hot", vmin=0, vmax=250, zorder=-1)


# Chamber outline: every 50th contour point, with the last one prepended to close the loop.
countour = find_contours(blurred_chamber_roi_consistent)[1];
countour_array = hcat(getindex.(countour, 1), getindex.(countour,2));
plot(vcat(countour_array[1:50:end,1][end], countour_array[1:50:end,1]),vcat(countour_array[1:50:end,2][end], countour_array[1:50:end,2]),c="k", alpha=0.3)

imshow(chamber_roi_mask',origin="lower",  cmap="binary")



xlim(dig_min_x-interval, dig_max_x+interval)
ylim(dig_min_y-interval, dig_max_y+interval)
axis("off")
tight_layout(pad=0)
plt.savefig("Figure_panels/representation_map.pdf", format="pdf", transparent = true, dpi=300)
# Colorbar drawn in its own figure (not saved here).
figure(figsize=(2,1))
cb = colorbar(m, pad=0.15,fraction=0.07, orientation="horizontal", location="top")
#cb.set_ticks([0, 1, 2], labels=["1", "10", "100"])
cb.set_label("Place representation (a.u.)", labelpad=0)

In [None]:
# Correlation between place-field representation and occupancy: pooled over the fish-median
# maps (last y-row excluded, as in the representation map figure), then per fish.
# `vec` replaces the hard-coded 760/800-element reshapes, so any map size works.
println(nancor(vec(repr_maps[:, 1:end-1]), vec(count_maps[:, 1:end-1])))
scatter(vec(repr_maps[:, 1:end-1]), vec(count_maps[:, 1:end-1]))
corrs = fill(NaN, n_datasets)
for i in 1:n_datasets
    corrs[i] = nancor(vec(repr_maps_all[:, :, i]), vec(count_maps_all[:, :, i]))
end
println(nanstd(corrs))
println(nanmean(corrs))
println(corrs)

In [None]:
# Representation and occupancy maps of a single dataset (`fish_nr`). Each map is scaled to
# its own upper percentile: 99th for representation, 95th for occupancy.
fish_nr = 3


# Mask: in-chamber pixels become NaN (not drawn); zeros cover the area outside the chamber.
chamber_roi_mask = Float32.(copy(blurred_chamber_roi_consistent))
chamber_roi_mask[chamber_roi_mask.!=0].= NaN32

fig = figure(figsize=(1.8,0.9))
m=imshow(repr_maps_all[:,:,fish_nr]', origin="lower", extent = [dig_min_x-(interval/2), dig_max_x+(interval/2)+interval, dig_min_y-(interval/2), dig_max_y+(interval/2)], cmap="hot", vmin=0, vmax = nanpctile(repr_maps_all[:,:,fish_nr], 99), zorder=-1)


# Chamber outline: every 50th contour point, with the last one prepended to close the loop.
countour = find_contours(blurred_chamber_roi_consistent)[1];
countour_array = hcat(getindex.(countour, 1), getindex.(countour,2));
plot(vcat(countour_array[1:50:end,1][end], countour_array[1:50:end,1]),vcat(countour_array[1:50:end,2][end], countour_array[1:50:end,2]),c="k", alpha=0.3)

imshow(chamber_roi_mask',origin="lower",  cmap="binary")



xlim(dig_min_x-interval, dig_max_x+interval)
ylim(dig_min_y-interval, dig_max_y+interval)
axis("off")
tight_layout(pad=0)
plt.savefig("Figure_panels/representation_map_fish_$(fish_nr).pdf", format="pdf", transparent = true, dpi=300)
figure(figsize=(2,1))
cb = colorbar(m, pad=0.15,fraction=0.07, orientation="horizontal", location="top")
#cb.set_ticks([0, 1, 2], labels=["1", "10", "100"])
cb.set_label("Place representation (a.u.)", labelpad=0)



# Occupancy counts of the same dataset divided by their 95th percentile (colour scale 0-1).
# Uses the `Blues` colormap created in the occupancy-figure cell.
fig = figure(figsize=(1.8,0.9))
m=imshow((count_maps_all[:,:,fish_nr]./nanpctile(count_maps_all[:,:,fish_nr], 95))', origin="lower", extent = [dig_min_x-(interval/2), dig_max_x+(interval/2)+interval, dig_min_y-(interval/2), dig_max_y+(interval/2)], cmap=Blues, vmin=0, vmax=1)

plot(vcat(countour_array[1:50:end,1][end], countour_array[1:50:end,1]),vcat(countour_array[1:50:end,2][end], countour_array[1:50:end,2]),c="k", alpha=0.3)

imshow(chamber_roi_mask',origin="lower",  cmap="binary")



xlim(dig_min_x-interval, dig_max_x+interval)
ylim(dig_min_y-interval, dig_max_y+interval)
axis("off")
tight_layout(pad=0)
plt.savefig("Figure_panels/occupancy_map_fish_$(fish_nr).pdf", format="pdf", transparent = true, dpi=300)
figure(figsize=(2,1))
cb = colorbar(m, pad=0.15,fraction=0.07, orientation="horizontal", location="top")
#cb.set_ticks([0, 1, 2], labels=["1", "10", "100"])
cb.set_label("Occupancy (a.u.)", labelpad=0)

In [None]:
# Across-fish standard deviation of the representation maps, per spatial bin (dims=3 is the fish axis).
# NOTE(review): `std` is not NaN-aware, so a bin that is NaN for any fish is NaN here;
# `nanstd` would ignore those fish instead — confirm which is intended.
repr_var_map = std(repr_maps_all, dims=3)[:,:,1];

# Colour scale fixed at 0-150 (the percentile-based alternative is left commented out).
fig = figure(figsize=(1.8,0.9))
m=imshow(repr_var_map', origin="lower", extent = [dig_min_x-(interval/2), dig_max_x+(interval/2)+interval, dig_min_y-(interval/2), dig_max_y+(interval/2)], cmap="hot", vmin=0, vmax=150)#nanpctile(repr_var_map, 95))

# Reuses the closed chamber outline (countour_array) and outside-chamber mask from the cell above.
plot(vcat(countour_array[1:50:end,1][end], countour_array[1:50:end,1]),vcat(countour_array[1:50:end,2][end], countour_array[1:50:end,2]),c="k", alpha=0.3)

imshow(chamber_roi_mask',origin="lower",  cmap="binary")



xlim(dig_min_x-interval, dig_max_x+interval)
ylim(dig_min_y-interval, dig_max_y+interval)
axis("off")
tight_layout(pad=0)
plt.savefig("Figure_panels/representation_std_map.pdf", format="pdf", transparent = true, dpi=300)
# Separate colorbar-only figure (not saved here).
figure(figsize=(2,1))
cb = colorbar(m, pad=0.15,fraction=0.07, orientation="horizontal", location="top")
#cb.set_ticks([0, 1, 2], labels=["1", "10", "100"])
cb.set_label("Representation std (a.u.)", labelpad=0)

In [None]:
"""
    flatten(x)

Return all elements of the array `x` as a one-dimensional vector in column-major order.
Like `reshape`, the result shares memory with `x` rather than copying it.
"""
flatten(x) = vec(x)

In [None]:
# Per-fish scatter of occupancy (count map) vs. place representation, bin by bin, in a
# 3x3 grid. Each panel is titled with the Pearson correlation, collected in `corrs`.
corrs = Float32[]
for fish_nr in 1:n_datasets
    # Flatten each map once and reuse it for both the scatter and the correlation.
    occupancy_vals = flatten(count_maps_all[:,:,fish_nr])
    repr_vals = flatten(repr_maps_all[:,:,fish_nr])

    subplot(3,3,fish_nr)
    scatter(occupancy_vals, repr_vals)
    push!(corrs, cor(occupancy_vals, repr_vals))
    xlim(0,50)
    title(corrs[fish_nr])
end
# Lay out once, after every panel (including its title) exists.
tight_layout()

In [None]:
# Correlation between the flattened `count_maps_show` and `repr_maps` arrays.
# NOTE(review): both are assumed to be defined in an earlier cell (not visible here) — confirm.
cor(flatten(count_maps_show), flatten(repr_maps))

In [None]:
# Summary panel: one dot per fish showing its occupancy-vs-representation correlation.
fish_positions = 1:length(corrs)

figure(figsize=(1.5,1.5))
scatter(fish_positions, corrs, color="black")
yticks([-0.4, 0, 0.4])
ylabel("Correlation", labelpad=-2)

ylim(-0.4, 0.4)
xticks(fish_positions)
# Single x label. An earlier xlabel("Dataset") was overwritten before the figure was saved.
xlabel("Fish")

tight_layout(pad=0.1)
plt.savefig("Figure_panels/correlations_all.pdf", format="pdf", transparent = true, dpi=300)

In [None]:
# Print the mean, then the standard deviation, of the per-fish correlations.
for summary_stat in (mean, std)
    println(summary_stat(corrs))
end

## overall error distribution

In [None]:
# Per-fish histograms of the decoding error (mm) in a 3x3 grid. Panels 8 and 9 are
# blanked, and panel 9 carries the shared legend.
fig, ax = subplots(3,3,figsize=(5,3.5))

# Histogram styling is the same for every fish, so it is set up once.
# `np` is the numpy module imported at the top of the notebook. The bare name `numpy`
# used previously is not among those imports.
bins= np.linspace(0,50,30)
greys = plt.cm.get_cmap("Greys", 10)
color1 = greys(4)

for use_fish in 1:length(datasets_corner_cue)
    experiment_filename = datasets_corner_cue[use_fish][1]
    server = datasets_corner_cue[use_fish][2]
    ds_Lorenz = Dataset(experiment_filename, "lorenz", gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")
    ds_Chuyu = Dataset(experiment_filename, "chuyu", gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")

    # do-block: the HDF5 file is closed even if one of the reads throws
    place_candidates_unique, x_in_bins, y_in_bins, x_predicted, y_predicted, bool_index =
        h5open(joinpath(data_path(ds_Lorenz), "$(experiment_filename)_$(save_filename).h5"), "r") do file  # _fig4_data_median
            print(keys(file))
            (read(file, "place_cell_included"),
             read(file, "x_in_bins"),
             read(file, "y_in_bins"),
             read(file, "x_predicted_in_bins"),
             read(file, "y_predicted_in_bins"),
             read(file, "bool_index"))
        end
    bool_index = BitArray(bool_index)
    # Convert bin distance to mm (47 = chamber length, cf. the "Chamber length" line below).
    # NOTE(review): the pooled-error cell scales by 47/41 instead of 47/40 — confirm the bin count.
    errors = Decoder.get_distance(x_in_bins, x_predicted, y_in_bins, y_predicted)[bool_index] * 47/40;

    subplot(3,3, use_fish)
    h = hist(errors, bins=bins, histtype="stepfilled", fc=(color1[1], color1[2], color1[3], 0.5),ec=color1)
    # NOTE(review): number of non-NaN errors divided by 120 — the meaning of 120 is not visible here.
    println(sum(.!isnan.(errors))/120)
    axvline(nanmedian(errors), color="black", label="Median error")
    axvline(15.08, color="black", linestyle="--", label="Control 3")
    axvline(47, color="black", linestyle=":", label="Chamber length")
    # Histogram peak rounded to a multiple of 250; used for the label position and the top tick.
    y_top = round(maximum(h[1])/250)*250
    plt.text(25, y_top, "Fish $(use_fish)")
    ylabel("Count", labelpad=-7)
    yticks([0, y_top])

    xlabel("Error (mm)")
    xticks(0:10:50)
    xlim(0, 50)
end
subplot(3,3, 8)
axis("off")
subplot(3,3, 9)
axis("off")
# Every panel draws the same labelled lines, so the legend entries are taken from panel 1.
handles, labels = ax[1].get_legend_handles_labels()
ax[9].legend(handles, labels, loc="lower right", frameon=false, handlelength=1.3)
fig.subplots_adjust(hspace=4)
tight_layout(pad=0)
savefig("Figure_panels/error_distribution_separate.pdf", format="pdf",  transparent=true, dpi=300)

In [None]:
using StatsBase

In [None]:
# Pool the decoding errors (mm) of all fish and plot their overall distribution. Reference
# lines mark the median error, control C3 and the chamber length.
fish_errors = fill(NaN, 50, n_datasets)   # per-fish 50-bin histograms, normalised by sample count
all_errors = Float64[]                     # concrete element type (was `[]`, i.e. Vector{Any})
for use_fish in 1:n_datasets
    experiment_filename = datasets_corner_cue[use_fish][1]
    server = datasets_corner_cue[use_fish][2]
    ds_Lorenz = Dataset(experiment_filename, "lorenz", gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")
    ds_Chuyu = Dataset(experiment_filename, "chuyu", gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")

    # do-block: the HDF5 file is closed even if one of the reads throws
    place_candidates_unique, x_in_bins, y_in_bins, x_predicted, y_predicted, bool_index =
        h5open(joinpath(data_path(ds_Lorenz), "$(experiment_filename)_$(save_filename).h5"), "r") do file
            (read(file, "place_cell_included"),
             read(file, "x_in_bins"),
             read(file, "y_in_bins"),
             read(file, "x_predicted_in_bins"),
             read(file, "y_predicted_in_bins"),
             read(file, "bool_index"))
        end
    bool_index = BitArray(bool_index)
    # NOTE(review): the per-fish panel cell scales by 47/40 instead of 47/41 — confirm the bin count.
    errors = Decoder.get_distance(x_in_bins, x_predicted, y_in_bins, y_predicted)[bool_index] * 47/41;

    # Only the counts are needed. PyPlot's `hist` would also draw the bars into the current
    # (or a new, stray) figure. The NaN-ignoring range reproduces matplotlib's automatic
    # bin range, so the counts are the same.
    bin_counts = np.histogram(errors, bins=50, range=(nanminimum(errors), nanmaximum(errors)))[1]
    fish_errors[:, use_fish] .= bin_counts ./ length(errors)

    append!(all_errors, errors)
end

data = Dict()
data["Timepoints"] = all_errors
#data["Timepoints2"] = all_errors
#data["Timepoints3"] = all_errors

df = pd.DataFrame(data)

fig, ax = subplots(figsize=(1.5,1.15))
bins= np.linspace(0,50,50)
greys = plt.cm.get_cmap("Greys", 10)
color1 = greys(4)
h=hist(all_errors, bins=bins, histtype="stepfilled", fc=(color1[1], color1[2], color1[3], 0.5),ec=color1)

axvline(nanmedian(all_errors), color="black", label="Median error")
axvline(15.08, ymin=0, ymax=0.5, color="black", linestyle="--", label="Control 3")
axvline(47, ymin=0, ymax=0.5, color="black", linestyle=":", label="Chamber length")
println(nanmedian(all_errors))

# One value for both the y-limit and the top tick. Previously the tick was 6500 while the
# axis stopped at 4500, so the tick was never shown. NOTE(review): confirm 4500 covers the peak.
y_top = 4500
ylabel("Timepoints", labelpad=-12)
xlabel("Error (mm)")
xlim(0, 50)
ylim(0, y_top)
xticks(0:10:50)
yticks([0, y_top])
leg = fig.legend(loc=(0.35,0.6), frameon=false, handlelength=1.3)

tight_layout(pad=0)
savefig("Figure_panels/error_distribution_together.pdf", format="pdf",  transparent=true, dpi=300)

In [None]:
# Bin edges of the most recent (pooled) error histogram.
print(bins)

In [None]:
# Per-bin counts of that histogram (first element of the tuple returned by `hist`).
print(h[1])

## error by population

In [None]:
# Decoding error (mm) per neuron population and fish: rows = datasets, columns = `names`.
# Populations absent from a fish's file stay NaN.
names = ["place_cells_optic_tectum", "place_cells_rhomb_mes", "place_cells_rhomb", "place_cells_di", "random_cells", "place_cells_subpallium", "place_cells_pallium", "baseline_behavior_informed", "place_cells_tel", "place_cells_mes", "place_cells", "place_cells_habenula", "baseline_uniformly_random"]
x = fill(NaN, length(datasets_corner_cue), length(names))

# Value stored under `key` in the parallel arrays `ks`/`vs` (first match), or NaN if the key is absent.
_value_or_nan(ks, vs, key) = (k = findfirst(==(key), ks); isnothing(k) ? NaN : vs[k])

for i in 1:length(datasets_corner_cue)
    experiment_filename = datasets_corner_cue[i][1]
    server = datasets_corner_cue[i][2]

    ds_Lorenz = Dataset(experiment_filename, "lorenz", gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")

    # do-block: the HDF5 file is closed even if a read throws
    errors_by_population_in_mm_keys, errors_by_population_in_mm_values =
        h5open(joinpath(data_path(ds_Lorenz), "$(experiment_filename)_$(save_filename).h5"), "r") do file #_fig4_brain_region.h5
            read(file, "errors_by_population_in_mm_keys"), read(file, "errors_by_population_in_mm_values")
        end
    x[i, :] = [_value_or_nan(errors_by_population_in_mm_keys, errors_by_population_in_mm_values, n) for n in names]
end



In [None]:
# Inspect the fish × population error matrix (rows = datasets, columns = population names).
x

In [None]:
# Exploratory plot: error of every population (columns of `x`, labelled by `names`), with
# one dot per fish and a horizontal bar at the across-fish median.
fig, ax = plt.subplots(figsize=(5,5))

n_fish, n_populations = size(x)
for i in 1:n_populations
    m = nanmedian(x[:,i])
    ax.plot((i-0.3, i+0.3), (m, m), "k-",linewidth=1, zorder=-1)

    ax.scatter(fill(i, n_fish), x[:, i], s=10, alpha=0.8)
end
ylabel("error (mm)")
ax.set_xticks(1:n_populations, labels=names, rotation=45,ha="right", rotation_mode="anchor");

In [None]:
# Display labels for the selected populations, followed by the three controls C1-C3.
names = ["PC", "Tel. PC", "M.+R. PC", "OT. PC", "C1", "C2", "C3"]

In [None]:
# Select and reorder columns of `x`. Indices refer to the population list of the
# error-by-population cell:
#   11 place_cells -> "PC", 9 place_cells_tel -> "Tel. PC", 2 place_cells_rhomb_mes -> "M.+R. PC",
#   1 place_cells_optic_tectum -> "OT. PC", 5 random_cells -> "C1",
#   13 baseline_uniformly_random -> "C2", 8 baseline_behavior_informed -> "C3".
order = [11, 9, 2, 1, 5, 13, 8];
x_ordered = x[:,order]

In [None]:
# Per-fish errors of the first selected column ("PC", all place cells).
x_ordered[:,1]

In [None]:
# Wilcoxon signed-rank test (HypothesisTests) on the paired per-fish differences
# "Tel. PC" (column 2) minus "M.+R. PC" (column 3).
# NOTE(review): NaN differences are not removed here — confirm no fish is missing either value.
SignedRankTest(x_ordered[:,2] .- x_ordered[:,3])

In [None]:
# Cross-check of the paired Tel. PC vs. M.+R. PC comparison with SciPy's Wilcoxon test.
# `pyimport` replaces the deprecated `@pyimport` macro. A distinct name avoids clashing
# with a `scipy` module binding that `@pyimport scipy` elsewhere would create.
scipy_stats = pyimport("scipy.stats")

scipy_stats.wilcoxon(x_ordered[:,2], x_ordered[:,3])


In [None]:
# NaN-aware across-fish standard deviation of each selected column.
nanstd(x_ordered[:, :], dim=1)

In [None]:
# NaN-aware across-fish median of each selected column (input converted to Float32 first).
nanmedian(Float32.(x_ordered[:, :]), dim=1)

In [None]:
# Mean error of control C3 (column 7) across fish. The following section overwrites
# baseline3 with the hard-coded value 15.080.
baseline3 = nanmean(x_ordered[:,7])

In [None]:
# Figure panel F2e: decoding error per population and control. Dots are individual fish,
# the black bar is the across-fish median, and the controls (columns 5-7) use a lighter grey.
fig, ax = plt.subplots(figsize=(1.25,1.32))

for i in axes(x_ordered, 2)
    col = i in (5, 6, 7) ? "lightgrey" : "grey"
    column_vals = x_ordered[:, i]

    m = nanmedian(column_vals)
    ax.plot((i-0.3, i+0.3), (m, m), "k-", linewidth=1, zorder=-1)
    ax.scatter(fill(i, n_datasets), column_vals, color=col, s=1, alpha=0.8)
end

ax.set_ylim(0, 25)
ax.set_xlim(0.5,7.5)
ax.set_ylabel("Error (mm)")
ax.set_xticks(1:7)
ax.set_yticks([0,5,10,15,20,25], [0,"","","","","25"])
ax.set_xticklabels(names, rotation=45, ha="right", rotation_mode="anchor")
ax.text(3.75,-12, "Controls")   # group label below the rotated tick labels
tight_layout(pad=0)
savefig("Figure_panels/F2e_error_by_region.pdf", format="pdf", transparent=true, dpi=300)

## Error by number of cells

In [None]:
# Control C3 reference level (mm) for the following panels. It is hard-coded to the same
# value used for the "Control 3" lines in the error histograms (15.08).
baseline3 = 15.080

In [None]:
# Decoding error (mm) vs. number of place cells used: rows = datasets, columns = `names`
# (cell counts, largest first). The counts are the same for every fish, so they are set once.
names = [10000, 5000, 1000, 500, 250, 100, 50, 25, 10, 5, 4, 3, 2, 1]
x = fill(NaN, length(datasets_corner_cue), length(names))

for i in 1:n_datasets
    experiment_filename = datasets_corner_cue[i][1]
    server = datasets_corner_cue[i][2]

    ds_Lorenz = Dataset(experiment_filename, "lorenz", gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")

    # do-block: the HDF5 file is closed even if a read throws.
    # NOTE(review): the stored keys (`amount`) are not used; the x-axis is the hard-coded
    # `names` above — confirm they match.
    amount, error_by_amount_place_cells = h5open(joinpath(data_path(ds_Lorenz), "$(experiment_filename)_$(save_filename).h5"), "r") do file
        read(file, "errors_by_amount_120_in_mm_keys"), read(file, "errors_by_amount_120_in_mm_values")
    end
    println(error_by_amount_place_cells)

    # Right-aligned: if a fish stores fewer values, they fill the last (smallest-count)
    # columns. This assumes the values are ordered like `names` — TODO confirm.
    x[i, end-length(error_by_amount_place_cells)+1:end] = error_by_amount_place_cells
end
print(x)
# NOTE(review): only fish 1,2,3,4,6 are kept; the reason is not visible here.
x=x[[1,2,3,4,6], :]

In [None]:
# Inspect the error-by-number-of-cells matrix (fish subset kept above).
x

In [None]:
rc_params["mathtext.default"] = "regular"

# Panel: decoding error vs. number of cells included (mean ± std across the kept fish),
# with the C3 control level as a dashed reference line.
m = nanmean(x, dim=1)
s = nanstd(x, dim=1)

fig, ax = plt.subplots(figsize=(1.1,1.2))
ax.axhline(baseline3, color="black", ls="--", linewidth=1, label="baseline") # behavior baseline
ax.text(1500, baseline3+1, "C3")
ax.plot(names, m, color="black", linewidth=1)
ax.fill_between(names, m .- s, m .+ s, color="lightgrey", alpha=0.7)

ax.set_xlabel("Cells included")
ax.set_ylabel("Error (mm)")

# Log x-axis with labels only on even decades.
ax.set_xscale("log")
ax.set_xticks([1, 10, 100, 1000, 10000], labels=["\$10^0\$", "", "\$10^2\$", "", "\$10^4\$"], minor=false)
ax.set_xlim(1, 10000)
ax.set_ylim(0, 30)

tight_layout(pad=0)
savefig("Figure_panels/error_by_amount.pdf", format="pdf", transparent=true, dpi=300)

In [None]:
# Mean error per cell count.
m

In [None]:
# Standard deviation of the error per cell count.
s

In [None]:
# Cell counts (x-axis values).
names

In [None]:
# Cell count with the lowest mean error.
# NOTE(review): Julia's argmin treats NaN as the smallest value, so a NaN in `m` would be
# returned; NaNStatistics' nanargmin ignores NaN — confirm `m` has no NaN.
names[argmin(m)]

## minimum cells

In [None]:
# Decoding error (mm) when cells are added greedily, one column per number of cells (1..amount_cells).
amount_cells = 25
x = fill(NaN, length(datasets_corner_cue), amount_cells)

for i in 1:n_datasets
    experiment_filename = datasets_corner_cue[i][1]
    server = datasets_corner_cue[i][2]

    ds_Lorenz = Dataset(experiment_filename, "lorenz", gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")

    # do-block: the HDF5 file is closed even if the read throws. The selected cell indices
    # ("errors_by_amount_greedy_in_mm_cells") are not needed here and are no longer read.
    error_by_amount_place_cells_greedy = h5open(joinpath(data_path(ds_Lorenz), "$(experiment_filename)_$(save_filename).h5"), "r") do file
        read(file, "errors_by_amount_greedy_in_mm_values")
    end

    # Expected: exactly amount_cells values (the plot uses 1:amount_cells as x);
    # otherwise this assignment throws a DimensionMismatch.
    x[i, :] = error_by_amount_place_cells_greedy
end

In [None]:
rc_params["mathtext.default"] = "regular"

# Panel: decoding error vs. number of greedily selected cells (mean ± std across fish).
m = nanmean(x, dim=1)
s = nanstd(x, dim=1)
n_included = 1:amount_cells

fig, ax = plt.subplots(figsize=(1.1,1.1))
ax.plot(n_included, m, color="black", linewidth=1)
ax.fill_between(n_included, m .- s, m .+ s, color="lightgrey", alpha=0.7)

ax.set_xlabel("Cells included")
ax.set_ylabel("Error (mm)")
ax.set_xticks([1, 5, 10, 15, 20, 25])
ax.set_xlim(1, 25)
ax.set_ylim(0, 30)

tight_layout(pad=0)
savefig("Figure_panels/error_by_amount_minimum.pdf", format="pdf", transparent=true, dpi=300)

In [None]:
# Mean error per number of greedily selected cells.
m

In [None]:
# Standard deviation per number of greedily selected cells.
s

## error by long shift

In [None]:
# Decoding error (mm) per temporal shift: rows = datasets, 22 shift values per fish.
x = fill(NaN, length(datasets_corner_cue), 22)
names = []

for i in 1:n_datasets
    experiment_filename = datasets_corner_cue[i][1]
    server = datasets_corner_cue[i][2]

    ds_Lorenz = Dataset(experiment_filename, "lorenz", gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")

    # do-block: the HDF5 file is closed even if a read throws
    amount, error_by_amount_place_cells = h5open(joinpath(data_path(ds_Lorenz), "$(experiment_filename)_$(save_filename).h5"), "r") do file
        read(file, "long_shift_errors_mm_keys"), read(file, "long_shift_errors_mm_values")
    end

    # Shift values (s). The last fish's keys are kept; they are assumed identical across fish — TODO confirm.
    names = amount
    x[i, :] = error_by_amount_place_cells
end

In [None]:
# Panel: decoding error vs. temporal shift (mean ± std across fish), with the C3 level dashed.
m = nanmean(x, dim=1)
s = nanstd(x, dim=1)

fig, ax = plt.subplots(figsize=(1.1,1.2))
ax.axhline(baseline3, color="black", ls="--", linewidth=1, label="baseline") # behavior baseline
ax.text(2, baseline3+1, "C3")
ax.plot(names, m, color="black", linewidth=1)
ax.fill_between(names, m .- s, m .+ s, color="lightgrey", alpha=0.7)

ax.set_xlabel("Shift (s)")
ax.set_ylabel("Error (mm)")
ax.set_xlim(1, 1000)
ax.set_ylim(0, 20)
# Log x-axis without minor ticks. NOTE(review): a shift of 0 (compared further below)
# cannot appear on a log axis — confirm this is intended.
ax.set_xscale("log", subs=[])
ax.set_xticks([1, 10, 100, 1000], minor=false)

tight_layout(pad=0)
savefig("Figure_panels/error_by_shift.pdf", format="pdf",  transparent=true, dpi=300)

In [None]:
# Shift values read from "long_shift_errors_mm_keys" (x-axis of the panel above).
names

In [None]:
# Per-fish error at shift 0 vs. shift 5, with the identity line for reference.
error_shift_0 = x[:, names .== 0]
error_shift_5 = x[:, names .== 5]
figure(figsize=(4,4))
scatter(error_shift_0, error_shift_5)
plot([0,10],[0,10])
xlim(0,10);ylim(0,10)

In [None]:
# One-sided Wilcoxon signed-rank test: is the per-fish error at shift 0 greater than at shift 5?
# `pyimport` replaces the deprecated `@pyimport` macro. A distinct name avoids clashing
# with a `scipy` module binding that `@pyimport scipy` elsewhere would create.
scipy_stats = pyimport("scipy.stats")

scipy_stats.wilcoxon(x[:,names .== 0], x[:,names .== 5], alternative="greater")

In [None]:
# Shift with the lowest mean error.
# NOTE(review): argmin treats NaN as the smallest value; use nanargmin if `m` can contain NaN.
names[argmin(m)]

## fish animation

In [None]:
# For rotation:     ["20220407_090156", 5, "20220407_104712", 5, "jen"],

In [None]:
# Load what the fish animations below need for one example experiment: activity traces,
# place-cell statistics, tracking data (behavior.h5) and a reader for the raw IR frames.
# `pyimport` replaces the deprecated `@pyimport` macro; the binding name and call syntax are unchanged.
skimage_transform = pyimport("skimage.transform")

# Fig1: 20220407_090156,9
# rotation 20220406_153842 to 20220406_171558 server 9

#use_fish = 3
experiment_filename = "20220407_090156"#"20220407_104712" #datasets_corner_cue[use_fish][1] #"20220407_104712"
server = 5# datasets_corner_cue[use_fish][2]

ds_Chuyu = Dataset(experiment_filename, "chuyu", gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")
ds_frames = Dataset(string(gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)", "/jen/data_raw/$(experiment_filename)"))
ds = Dataset(experiment_filename, "jen", gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")



neuron_merge_activity_file = h5open(joinpath(data_path(ds_Chuyu), "roi_mean_activity.h5"))
    A_dFF = HDF5.readmmap(neuron_merge_activity_file["A_dFF"]);
close(neuron_merge_activity_file)


file = h5open(joinpath(data_path(ds_Chuyu), "spatial_info_4.h5"), "r") #_whole #spatial_info_4 done on merged cells
    valid_neurons = HDF5.readmmap(file["valid_neurons"])
    specifity_z = HDF5.readmmap(file["specifity_z"])
    specifity = HDF5.readmmap(file["specifity"])
    bool_index = HDF5.readmmap(file["bool_index"])
close(file)

# z-score of each neuron's specificity relative to the whole population
population_z = (specifity .- nanmean(specifity)) ./ nanstd(specifity);
bool_index = BitArray(bool_index)

place_cells_unique = findall(specifity_z .> 20)

# The orientation-corrected chamber ROI/background is optional: failure is tolerated, but reported.
try
    orientation_correction_file = h5open(joinpath(data_path(ds_Chuyu), "orientation_correction.h5"))
        chamber_roi = read(orientation_correction_file,"chamber_roi_r")
        img_bg_end = read(orientation_correction_file,"img_bg_end_r")
    close(orientation_correction_file)
catch err
    @warn "orientation_correction.h5 could not be read; continuing without it" exception=err
end

position_file = h5open(joinpath(data_path(ds_Chuyu), "for_place_calculation_4.h5"))
    x_fish_sweep_mean = read(position_file,"x_fish_sweep_mean")
    y_fish_sweep_mean = read(position_file,"y_fish_sweep_mean")
close(position_file)


# Note the naming: y_fish/x_fish are the yolk positions, x_offset/y_offset the frame offsets.
C, heading, img_bg, y_fish, x_offset, x_fish, y_offset = h5open(ds, "behavior.h5"; raw = true) do file
    read(file, "C"),
    read(file, "heading"), 
    read(file, "img_bg"),
    read(file, "fish_yolk_y"), 
    read(file, "offset_x"), 
    read(file, "fish_yolk_x"), 
    read(file, "offset_y")
end;

# Circular mean of the heading over a 125-frame window centred on frame i*125 (one value
# per 125-frame block); the first and last entries stay NaN.
# NOTE(review): depending on length(heading), the window of the second-to-last block can
# run past the end of `heading` — confirm.
heading_sweep_mean = fill(NaN, length(heading[1:125:end]))
    
for i in 2:length(heading_sweep_mean)-1
    heading_sweep_mean[i] = angular_mean(heading[(i*125-62):(i*125+62)])
end

place_candidates_unique = findall((population_z[valid_neurons] .> 3) .& (specifity_z[valid_neurons] .> 15))

# Raw IR frame reader (ds_frames was constructed above; the identical second construction was dropped).
frames_reader = Reader(ds_frames, :ir);

In [None]:
file = h5open(joinpath(data_path(ds), "NMF.h5"), "r")
    z_all = HDF5.readmmap(file["z_all"])
    centroid_x_all = HDF5.readmmap(file["centroid_x_all"])
    centroid_y_all = HDF5.readmmap(file["centroid_y_all"])
close(file)

file = h5open(joinpath(data_path(ds_Chuyu), "spatial_info_4.h5"), "r") #_whole #spatial_info_4 done on merged cells
    valid_neurons = HDF5.readmmap(file["valid_neurons"])
    specifity_z = HDF5.readmmap(file["specifity_z"])
    specifity = HDF5.readmmap(file["specifity"])
close(file)

n_neurons = size(A_dFF, 2)
n_sweeps = size(A_dFF, 1)

population_z = (specifity .- nanmean(specifity)) ./ nanstd(specifity);

"""
    _mean_location_per_label(labels, xs, ys, zs, n)

Return an `n × 3` Float32 matrix whose row `k` is the mean `(x, y, z)` of all entries
with `labels .== k`. Rows with no matching entry are NaN (as `mean` of an empty selection
would be). Labels that are not an integer in `1:n` are ignored.
A single pass over the entries replaces one full scan of `labels` per neuron.
"""
function _mean_location_per_label(labels, xs, ys, zs, n)
    sums = zeros(Float64, n, 3)
    n_rois = zeros(Int, n)
    for k in eachindex(labels)
        lbl = labels[k]
        (isinteger(lbl) && 1 <= lbl <= n) || continue
        j = Int(lbl)
        sums[j, 1] += xs[k]
        sums[j, 2] += ys[k]
        sums[j, 3] += zs[k]
        n_rois[j] += 1
    end
    return Float32.(sums ./ n_rois)   # 0/0 -> NaN for neurons without ROIs
end

# average the locations of rois to get the locations of neurons
# (the do-block also closes neuron_label.h5, which was previously left open)
neuron_label = h5open(joinpath(data_path(ds_Chuyu), "neuron_label.h5"), "r") do f
    read(f, "neuron_label")
end;
cell_locs = _mean_location_per_label(neuron_label, centroid_x_all, centroid_y_all, z_all, n_neurons)

centroid_x_all = cell_locs[:, 1];
centroid_y_all = cell_locs[:, 2];
z_all = cell_locs[:, 3];

place_candidates_unique = findall(specifity_z .> 25);


In [None]:
# Counter that the next cell increments and prints.
celli=38

In [None]:
celli+=1
print(celli)

# NOTE(review): `celli` is incremented and printed, but the plotted cell is the hard-coded
# index below — presumably the counter was used to browse candidates; confirm.
cell = 45126

# Preview of one cell: grey = positions of NaN activity, red = frame positions where the
# smoothed activity exceeds 3 std (std over tmin:tmax).
activity_smothed = Decoder.bin_activity(2, A_dFF[:, cell])

tmin=1
tmax=10799

imshow(img_bg[:,:,end]', origin="lower", cmap="binary_r")
nan_sweeps = isnan.(activity_smothed[tmin:tmax+1])
scatter(x_fish_sweep_mean[tmin:tmax+1][nan_sweeps], y_fish_sweep_mean[tmin:tmax+1][nan_sweeps], color = [0.5, 0.5, 0.5], s=4, alpha = 1)
# The whole-trace std computed here earlier was overwritten by this one before any use.
stand = nanstd(activity_smothed[tmin:tmax])
m = activity_smothed[tmin:tmax] .> 3*stand
scatter(256 .+ x_offset[tmin*125:125:tmax*125][m], 256 .+ y_offset[tmin*125:125:tmax*125][m], color = [1,0,0], s=8, alpha = 1)
axis("equal")
axis("off")

In [None]:
# one cell animation
# Renders one PNG per step into tmp/. Each frame shows the background image with the
# current fish crop pasted in, the trajectory so far, and red dots where the selected
# cell's smoothed activity exceeds 3 std.

fig, ax = plt.subplots()

im = ax.imshow(img_bg[:,:,end]', origin="lower", cmap="binary_r")
traj_plt, = ax.plot([],[], color = [0.8, 0.8, 0.8], alpha = 0.35, linewidth=1)
scat = ax.scatter([],[], color=[1,0,0], s=8, alpha = 1)
scat_nan = ax.scatter([],[], color = [0.5, 0.5, 0.5], s=4, alpha = 1)

frame_idx=1


axis("equal")
axis("off")
tight_layout(pad=0)

activity_smothed = Decoder.bin_activity(2, A_dFF[:, cell])

# Activity normalised to its maximum with negatives clipped to 0. It is only referenced by
# the commented-out colour-coded variants below.
activity_color = activity_smothed ./ nanmaximum(activity_smothed)
activity_color[activity_color .< 0] .= 0


stand = nanstd(activity_smothed[tmin:tmax])
jump = 250
mkpath("tmp")   # output directory for the frames; savefig fails if it does not exist
@showprogress for j = tmin*125:jump:(tmax*125)-1
    
    img = copy(img_bg[:,:,end])
        
    #off=-100+256-256
    #img[x_offset[j]-off:x_offset[j]+300-off, y_offset[j]-off: y_offset[j]+300-off] .= read(frames_reader, j)[100:400, 100:400]
    
    frame = Float32.(read(frames_reader, j))

    # Brightest pixel in the central 101x101 window is taken as the fish centre.
    offset = argmax(frame[256-50:256+50, 256-50:256+50]) #could do com maybe
    center_of_fish = [206+offset[1], 206+offset[2]]
    
    # Rectangular body mask around the centre, rotated by the heading (radians -> degrees).
    mask = fill(false, size(frame));
    mask[center_of_fish[1]-80:center_of_fish[1]+60, center_of_fish[2]-30:center_of_fish[2]+30] .= true
    mask = skimage_transform.rotate(mask, heading[j]*360/(2*pi), center = [206+offset[2], 206+offset[1]])#[size(mask)[1]/2, size(mask)[2]/2])

    # Place the 512x512 frame mask at this frame's offset in the background image.
    mask2 = fill(false, size(img))
    mask2[x_offset[j]:x_offset[j]+511, y_offset[j]: y_offset[j]+511] .= mask
    
    # Brightness adjustment: z-score inside the mask, then rescale around the background mean.
    frame = (frame .- mean(frame[mask])) ./ std(frame[mask])
    frame = (frame .* 700) .+ mean(img)

    
    img[mask2] .= frame[mask]
    
    
    im.set_data(img')
    
    # trajectory so far
    traj_plt.set_ydata(y_fish[tmin*125:5:j])
    traj_plt.set_xdata(x_fish[tmin*125:5:j])
    
    if j%125 == 0
          
        i = j ÷ 125   # index of the current 125-frame block (exact, since j is a multiple of 125)

        scat_nan.set_offsets(hcat(256 .+ x_offset[tmin*125:125:i*125][isnan.(activity_smothed[tmin:i])], 256 .+ y_offset[tmin*125:125:i*125][isnan.(activity_smothed[tmin:i])]))


        m = activity_smothed[tmin:i] .> 3*stand
        if sum(m) >= 1
            
            scat.set_offsets(hcat(x_fish[tmin*125:125:i*125][m], y_fish[tmin*125:125:i*125][m]))
            #scat.set_offsets(hcat(256 .+ x_offset[tmin*125:125:i*125][m], 256 .+ y_offset[tmin*125:125:i*125][m]))
            #scat.set_array([[x,0,0] for x in activity_color[tmin:i][m]])
            #ax.scatter(256 .+ x_offset[tmin*125:125:i*125][m], 256 .+ y_offset[tmin*125:125:i*125][m], color = [[x,0,0] for x in activity_color[tmin:i][m]], s=8, alpha = 1)
        end

    end
    
    ax.set_xlim(150,3150)
    ax.set_ylim(500,2250)

    
    fig.savefig("tmp/frame_$(frame_idx).png")
    frame_idx +=1
end


In [None]:
#Using background fish, try with higher background framerate
#multi-cell
# Same rendering as the single-cell animation, but overlays the above-threshold positions
# of several cells, each in its own colour.

# for first video place_candidates_unique[[561, 595, 747, 528]]

cells = place_candidates_unique[[81, 85, 39]]
colors = [[0,1,1], [1,1,0], [1,0,1]]


fig, ax = plt.subplots()

im = ax.imshow(img_bg[:,:,end]', origin="lower", cmap="binary_r")
traj_plt, = ax.plot([],[], color = [0.8, 0.8, 0.8], alpha = 0.35, linewidth=1)

scat_nan = ax.scatter([],[], color = [0.5, 0.5, 0.5], s=8, alpha = 1)

# One scatter artist per cell, drawn in that cell's colour.
scat = [ax.scatter([],[], color=c, s=8, alpha = 1) for c in colors]


frame_idx=1

tmin=3000
tmax=10000

axis("equal")
axis("off")
tight_layout(pad=0)

# For each cell: smoothed activity, activity normalised to its max with negatives clipped
# to 0, and the std used for the 3-std threshold. These are built directly; the previous
# try/catch accumulation used exceptions for control flow, and its catch branch would
# have produced a flat, wrongly structured result.
activity_smothed = [Decoder.bin_activity(2, A_dFF[:, cell]) for cell in cells]
activity_color = map(activity_smothed) do act
    ac = act ./ nanmaximum(act)
    ac[ac .< 0] .= 0
    ac
end
stand = [nanstd(act) for act in activity_smothed]

mkpath("tmp")   # output directory for the frames; savefig fails if it does not exist
jump = 125
@showprogress for j = tmin*125:jump:(tmax*125)-1
    
    img = copy(img_bg[:,:,end])
        
    #off=-100+256-256
    #img[x_offset[j]-off:x_offset[j]+300-off, y_offset[j]-off: y_offset[j]+300-off] .= read(frames_reader, j)[100:400, 100:400]
    
    frame = Float32.(read(frames_reader, j))

    # Brightest pixel in the central 101x101 window is taken as the fish centre.
    offset = argmax(frame[256-50:256+50, 256-50:256+50]) #could do com maybe
    center_of_fish = [206+offset[1], 206+offset[2]]
    
    # Rectangular body mask around the centre, rotated by the heading (radians -> degrees).
    mask = fill(false, size(frame));
    mask[center_of_fish[1]-80:center_of_fish[1]+60, center_of_fish[2]-30:center_of_fish[2]+30] .= true
    mask = skimage_transform.rotate(mask, heading[j]*360/(2*pi), center = [206+offset[2], 206+offset[1]])#[size(mask)[1]/2, size(mask)[2]/2])

    # Place the 512x512 frame mask at this frame's offset in the background image.
    mask2 = fill(false, size(img))
    mask2[x_offset[j]:x_offset[j]+511, y_offset[j]: y_offset[j]+511] .= mask
    
    # Brightness adjustment: z-score inside the mask, then rescale around the background mean.
    frame = (frame .- mean(frame[mask])) ./ std(frame[mask])
    frame = (frame .* 700) .+ mean(img)

    
    img[mask2] .= frame[mask]
    
    
    im.set_data(img')
    
    # trajectory so far
    traj_plt.set_ydata(y_fish[tmin*125:5:j])
    traj_plt.set_xdata(x_fish[tmin*125:5:j])
    
    if j%125 == 0
        i = j ÷ 125   # index of the current 125-frame block (exact, since j is a multiple of 125)

        # NaN positions in grey. This was previously reset inside the per-cell loop, so only
        # the last cell's NaN mask was ever drawn. Computing it once for that cell keeps the
        # rendered output identical.
        nan_sweeps = isnan.(activity_smothed[end][tmin:i])
        scat_nan.set_offsets(hcat(256 .+ x_offset[tmin*125:125:i*125][nan_sweeps], 256 .+ y_offset[tmin*125:125:i*125][nan_sweeps]))

        for i_cell in eachindex(cells)
            m = activity_smothed[i_cell][tmin:i] .> 3*stand[i_cell]
            if sum(m) >= 1
                scat[i_cell].set_offsets(hcat(x_fish[tmin*125:125:i*125][m], y_fish[tmin*125:125:i*125][m]))
                #scat.set_offsets(hcat(256 .+ x_offset[tmin*125:125:i*125][m], 256 .+ y_offset[tmin*125:125:i*125][m]))
                #scat.set_array([[x,0,0] for x in activity_color[tmin:i][m]])
                #ax.scatter(256 .+ x_offset[tmin*125:125:i*125][m], 256 .+ y_offset[tmin*125:125:i*125][m], color = [[x,0,0] for x in activity_color[tmin:i][m]], s=8, alpha = 1)
            end
        end
    end
    
    ax.set_xlim(150,3150)
    ax.set_ylim(500,2250)

    
    fig.savefig("tmp/frame_$(frame_idx).png")
    frame_idx +=1
end

In [None]:
using FileIO, ImageIO, Images

# Assemble the rendered frames tmp/frame_1.png ... into an animated GIF (test.gif, 60 fps).
# Frames are written into one preallocated 3-D array instead of collecting them and
# splatting ~1750 arguments into `cat`.
n_gif_frames = 1750   # NOTE(review): hard-coded; must not exceed the number of frames rendered into tmp/

first_frame = load("tmp/frame_1.png")
B = similar(first_frame, size(first_frame)..., n_gif_frames)
B[:, :, 1] = first_frame
@showprogress for i in 2:n_gif_frames
    B[:, :, i] = load("tmp/frame_$(i).png")
end
FileIO.save("test.gif", B; fps=60)
# Clear the global `A` left over from earlier cells.
A = nothing
